diff --git a/keras_nlp/layers/modeling/f_net_encoder.py b/keras_nlp/layers/modeling/f_net_encoder.py index 05da0ac44a..806a3dac8c 100644 --- a/keras_nlp/layers/modeling/f_net_encoder.py +++ b/keras_nlp/layers/modeling/f_net_encoder.py @@ -138,9 +138,9 @@ def call(self, inputs): def fourier_transform(input): # Apply FFT on the input and take the real part. - x = (input, ops.zeros_like(input)) - mixing_output = ops.fft2(x)[0] - return mixing_output + real_in, imaginary_in = (input, ops.zeros_like(input)) + real_out, _ = ops.fft2((real_in, imaginary_in)) + return real_out def add_and_norm(input1, input2, norm_layer): return norm_layer(input1 + input2)