From d7973798a67563d4b37bb9405f2f1622d5148af3 Mon Sep 17 00:00:00 2001 From: Deepali Jain Date: Fri, 2 Aug 2024 01:38:09 -0700 Subject: [PATCH] 1. Add point cloud encoding policy - pct_policy 2. Add custom keras layers module PiperOrigin-RevId: 658711302 --- .../layers/keras_image_encoder_layer.py | 73 +++++ .../layers/keras_image_encoder_layer_test.py | 40 +++ .../layers/keras_masking_attention_layer.py | 75 +++++ .../keras_masking_attention_layer_test.py | 46 ++++ .../layers/keras_positional_encoding_layer.py | 40 +++ .../keras_positional_encoding_layer_test.py | 27 ++ .../layers/keras_ranking_attention_layer.py | 70 +++++ .../keras_ranking_attention_layer_test.py | 46 ++++ .../layers/keras_trans_attention_layer.py | 79 ++++++ .../keras_trans_attention_layer_test.py | 43 +++ iris/policies/pct_policy.py | 260 ++++++++++++++++++ iris/policies/pct_policy_test.py | 83 ++++++ 12 files changed, 882 insertions(+) create mode 100644 iris/policies/layers/keras_image_encoder_layer.py create mode 100644 iris/policies/layers/keras_image_encoder_layer_test.py create mode 100644 iris/policies/layers/keras_masking_attention_layer.py create mode 100644 iris/policies/layers/keras_masking_attention_layer_test.py create mode 100644 iris/policies/layers/keras_positional_encoding_layer.py create mode 100644 iris/policies/layers/keras_positional_encoding_layer_test.py create mode 100644 iris/policies/layers/keras_ranking_attention_layer.py create mode 100644 iris/policies/layers/keras_ranking_attention_layer_test.py create mode 100644 iris/policies/layers/keras_trans_attention_layer.py create mode 100644 iris/policies/layers/keras_trans_attention_layer_test.py create mode 100644 iris/policies/pct_policy.py create mode 100644 iris/policies/pct_policy_test.py diff --git a/iris/policies/layers/keras_image_encoder_layer.py b/iris/policies/layers/keras_image_encoder_layer.py new file mode 100644 index 0000000..4b664ba --- /dev/null +++ b/iris/policies/layers/keras_image_encoder_layer.py @@ -0,0 +1,73 @@ +# Copyright 2024 Google LLC. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""A keras layer for encoding image into patches.""" + +from typing import Tuple + +import tensorflow as tf + + +class ImageEncoder(tf.keras.layers.Layer): + """Keras layer for encoding image into patches.""" + + def __init__(self, + patch_height: int, + patch_width: int, + stride_height: int, + stride_width: int, + normalize_positions: bool = True) -> None: + """Initializes Keras layer for encoding image into patches. + + Args: + patch_height: Height of image patch for encoding. + patch_width: Width of image patch for encoding. + stride_height: Stride (shift) height for consecutive image patches. + stride_width: Stride (shift) width for consecutive image patches. + normalize_positions: True to normalize patch center positions. 
+ """ + super().__init__() + self._patch_height = patch_height + self._patch_width = patch_width + self._stride_height = stride_height + self._stride_width = stride_width + self._normalize_positions = normalize_positions + + def call(self, images: tf.Tensor) -> Tuple[tf.Tensor, tf.Tensor]: + batch_shape, image_height, image_width, channels = images.shape + if batch_shape is None: + batch_shape = tf.shape(images)[0] + patches = tf.image.extract_patches( + images, + sizes=[1, self._patch_height, self._patch_width, 1], + strides=[1, self._stride_height, self._stride_width, 1], + rates=[1, 1, 1, 1], + padding='VALID') + encoding = tf.reshape( + patches, + [batch_shape, -1, self._patch_height * self._patch_width * channels]) + pos_x = tf.range(self._patch_height // 2, image_height, self._stride_height) + pos_y = tf.range(self._patch_width // 2, image_width, self._stride_width) + if self._normalize_positions: + pos_x /= image_height + pos_y /= image_width + x, y = tf.meshgrid(pos_x, pos_y) + x = tf.transpose(x) + y = tf.transpose(y) + centers = tf.stack([x, y], axis=-1) + centers = tf.reshape(centers, (-1, 2)) + centers = tf.tile(centers, (batch_shape, 1)) + centers = tf.reshape(centers, (batch_shape, -1, 2)) + centers = tf.cast(centers, 'float32') + return encoding, centers diff --git a/iris/policies/layers/keras_image_encoder_layer_test.py b/iris/policies/layers/keras_image_encoder_layer_test.py new file mode 100644 index 0000000..92efbd5 --- /dev/null +++ b/iris/policies/layers/keras_image_encoder_layer_test.py @@ -0,0 +1,40 @@ +# Copyright 2024 Google LLC. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from iris.policies.layers import keras_image_encoder_layer +import numpy as np +import tensorflow as tf +from absl.testing import absltest + + +class ImageEncoderTest(absltest.TestCase): + + def test_layer_output(self): + """Tests the output of ImageEncoder layer.""" + input_layer = tf.keras.layers.Input( + batch_input_shape=(2, 5, 6, 2), dtype="float", name="input") + output_layer = keras_image_encoder_layer.ImageEncoder( + patch_height=2, + patch_width=2, + stride_height=1, + stride_width=1)(input_layer) + model = tf.keras.models.Model(inputs=[input_layer], outputs=[output_layer]) + images = np.arange(2*5*6*2).reshape((2, 5, 6, 2)) + encoding, centers = model.predict(images)[0] + self.assertEqual(encoding.shape, (2, 20, 8)) + self.assertEqual(centers.shape, (2, 20, 2)) + + +if __name__ == "__main__": + absltest.main() diff --git a/iris/policies/layers/keras_masking_attention_layer.py b/iris/policies/layers/keras_masking_attention_layer.py new file mode 100644 index 0000000..191cf32 --- /dev/null +++ b/iris/policies/layers/keras_masking_attention_layer.py @@ -0,0 +1,75 @@ +# Copyright 2024 Google LLC. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""A keras layer for masking based attention.""" + +from typing import Callable +import tensorflow as tf + + +class FavorMaskingAttention(tf.keras.layers.Layer): + """A keras layer for masking based attention. + + A layer that creates a representation of the RGB(D)-image using attention + mechanism from https://arxiv.org/abs/2009.14794. It leverages Performer-ReLU + (go/performer) attention module in order to bypass explicit materialization of + the L x L attention tensor, where L is the number of patches (potentially even + individual pixels). This reduces time complexity of the attention module from + quadratic to linear in L and provides a gateway to processing high-resolution + images, where explicitly calculating attention tensor is not feasible. The + ranking procedure is adopted from https://arxiv.org/abs/2003.08165, where + scores of patches are defined as sums of the entries of the corresponding + column in the attention tensor. After ranking, top K tokens are preserved and + the rest of them are masked by 0. + """ + + def __init__( + self, + kernel_transformation: Callable[..., tf.Tensor], + top_k: int = 5) -> None: # pytype: disable=annotation-type-mismatch + """Initializes FavorMaskingAttention layer. + + Args: + kernel_transformation: Transformation used to get finite kernel features. + top_k: Number of top patches that will be chosen to "summarize" entire + image. + """ + super().__init__() + self._kernel_transformation = kernel_transformation + self._top_k = top_k + + def call(self, + queries: tf.Tensor, + keys: tf.Tensor, + values: tf.Tensor) -> tf.Tensor: + queries_prime = self._kernel_transformation( + data=tf.expand_dims(queries, axis=2), + is_query=True) + queries_prime = tf.squeeze(queries_prime, axis=2) + keys_prime = self._kernel_transformation( + data=tf.expand_dims(keys, axis=2), + is_query=False) + keys_prime = tf.squeeze(keys_prime, axis=2) + _, length, _ = queries_prime.shape + all_ones = tf.ones([1, length]) + reduced_queries_prime = tf.matmul(all_ones, queries_prime) + scores = tf.matmul(reduced_queries_prime, keys_prime, transpose_b=True) + scores = tf.reshape(scores, (-1, length)) + sorted_idxs = tf.argsort(scores, axis=-1, direction='DESCENDING') + cutoff = tf.gather( + scores, sorted_idxs[:, self._top_k], axis=1, batch_dims=1) + cond = scores > tf.expand_dims(cutoff, -1) + return tf.where(tf.expand_dims(cond, -1), + values, + tf.zeros_like(values)) diff --git a/iris/policies/layers/keras_masking_attention_layer_test.py b/iris/policies/layers/keras_masking_attention_layer_test.py new file mode 100644 index 0000000..f19407b --- /dev/null +++ b/iris/policies/layers/keras_masking_attention_layer_test.py @@ -0,0 +1,46 @@ +# Copyright 2024 Google LLC. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from iris.policies.layers import keras_masking_attention_layer
+from lingvo.core import favor_attention as favor
+import numpy as np
+import tensorflow as tf
+from absl.testing import absltest
+
+
+class FavorMaskingAttentionTest(absltest.TestCase):
+
+  def test_layer_output(self):
+    """Tests the output of the FavorMaskingAttention layer."""
+    query_layer = tf.keras.layers.Input(
+        batch_input_shape=(2, 3, 4), dtype="float", name="query")
+    key_layer = tf.keras.layers.Input(
+        batch_input_shape=(2, 3, 4), dtype="float", name="keys")
+    value_layer = tf.keras.layers.Input(
+        batch_input_shape=(2, 3, 4), dtype="float", name="values")
+    output_layer = keras_masking_attention_layer.FavorMaskingAttention(
+        kernel_transformation=favor.relu_kernel_transformation,
+        top_k=2)(query_layer, key_layer, value_layer)
+    model = tf.keras.models.Model(
+        inputs=[query_layer, key_layer, value_layer], outputs=[output_layer])
+    queries = np.arange(2 * 3 * 4).reshape((2, 3, 4))
+    top_values = model.predict((queries, queries, queries))
+    self.assertEqual(top_values.shape, (2, 3, 4))
+    true_values = np.arange(2 * 3 * 4).reshape((2, 3, 4))
+    true_values[:, 0, :] = 0
+    np.testing.assert_array_almost_equal(top_values, true_values, 1)
+
+
+if __name__ == "__main__":
+  absltest.main()
diff --git a/iris/policies/layers/keras_positional_encoding_layer.py b/iris/policies/layers/keras_positional_encoding_layer.py
new file mode 100644
index 0000000..865e88c
--- /dev/null
+++ b/iris/policies/layers/keras_positional_encoding_layer.py
@@ -0,0 +1,40 @@
+# Copyright 2024 Google LLC.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""A keras layer for positional encoding."""
+
+import tensorflow as tf
+
+
+class PositionalEncoding(tf.keras.layers.Layer):
+  """Keras layer for positional encoding."""
+
+  def call(self,
+           seq_len: int,
+           encoding_dimension: int) -> tf.Tensor:
+    num_freq = encoding_dimension // 2
+    indices = tf.expand_dims(tf.range(seq_len), 0)
+    indices = tf.tile(indices, [num_freq, 1])
+    freq_fn = lambda k: 1.0/(10000 ** (2*k/encoding_dimension))
+    freq = tf.keras.layers.Lambda(freq_fn)(tf.range(num_freq))
+    freq = tf.expand_dims(freq, 1)
+    freq = tf.tile(freq, [1, seq_len])
+    args = tf.multiply(freq, tf.cast(indices, dtype=tf.float64))
+    sin_enc = tf.math.sin(args)
+    cos_enc = tf.math.cos(args)
+    encoding = tf.keras.layers.Concatenate(axis=0)([sin_enc, cos_enc])
+    encoding = tf.expand_dims(tf.transpose(encoding), 0)
+    return encoding
diff --git a/iris/policies/layers/keras_positional_encoding_layer_test.py b/iris/policies/layers/keras_positional_encoding_layer_test.py
new file mode 100644
index 0000000..806c83e
--- /dev/null
+++ b/iris/policies/layers/keras_positional_encoding_layer_test.py
@@ -0,0 +1,27 @@
+# Copyright 2024 Google LLC.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from iris.policies.layers import keras_positional_encoding_layer
+from absl.testing import absltest
+
+
+class PositionalEncodingTest(absltest.TestCase):
+
+  def test_layer_output(self):
+    """Tests the output of the PositionalEncoding layer."""
+    encoding = keras_positional_encoding_layer.PositionalEncoding()(7, 4)
+    self.assertEqual(encoding.shape, (1, 7, 4))
+
+if __name__ == "__main__":
+  absltest.main()
diff --git a/iris/policies/layers/keras_ranking_attention_layer.py b/iris/policies/layers/keras_ranking_attention_layer.py
new file mode 100644
index 0000000..6a929cf
--- /dev/null
+++ b/iris/policies/layers/keras_ranking_attention_layer.py
@@ -0,0 +1,70 @@
+# Copyright 2024 Google LLC.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""A keras layer for ranking based attention."""
+
+from typing import Callable
+import tensorflow as tf
+
+
+class FavorRankingAttention(tf.keras.layers.Layer):
+  """A keras layer for ranking based attention.
+
+  A layer that creates a representation of the RGB(D)-image using attention
+  mechanism from https://arxiv.org/abs/2009.14794. It leverages Performer-ReLU
+  (go/performer) attention module in order to bypass explicit materialization of
+  the L x L attention tensor, where L is the number of patches (potentially even
+  individual pixels).
This reduces time complexity of the attention module from + quadratic to linear in L and provides a gateway to processing high-resolution + images, where explicitly calculating attention tensor is not feasible. The + ranking procedure is adopted from https://arxiv.org/abs/2003.08165, where + scores of patches are defined as sums of the entries of the corresponding + column in the attention tensor. + """ + + def __init__( + self, + kernel_transformation: Callable[..., tf.Tensor], + top_k: int = 5) -> None: # pytype: disable=annotation-type-mismatch + """Initializes FavorRankingAttention layer. + + Args: + kernel_transformation: Transformation used to get finite kernel features. + top_k: Number of top patches that will be chosen to "summarize" entire + image. + """ + super().__init__() + self._kernel_transformation = kernel_transformation + self._top_k = top_k + + def call(self, + queries: tf.Tensor, + keys: tf.Tensor, + values: tf.Tensor) -> tf.Tensor: + queries_prime = self._kernel_transformation( + data=tf.expand_dims(queries, axis=1), + is_query=True) + queries_prime = tf.squeeze(queries_prime, axis=1) + keys_prime = self._kernel_transformation( + data=tf.expand_dims(keys, axis=1), + is_query=False) + keys_prime = tf.squeeze(keys_prime, axis=1) + _, length, _ = queries_prime.shape + all_ones = tf.ones([1, length]) + reduced_queries_prime = tf.matmul(all_ones, queries_prime) + scores = tf.matmul(reduced_queries_prime, keys_prime, transpose_b=True) + scores = tf.reshape(scores, (-1, length)) + sorted_idxs = tf.argsort(scores, axis=-1, direction='DESCENDING') + top_idxs = sorted_idxs[:, :self._top_k] + return tf.gather(values, top_idxs, axis=1, batch_dims=1) diff --git a/iris/policies/layers/keras_ranking_attention_layer_test.py b/iris/policies/layers/keras_ranking_attention_layer_test.py new file mode 100644 index 0000000..eb70412 --- /dev/null +++ b/iris/policies/layers/keras_ranking_attention_layer_test.py @@ -0,0 +1,46 @@ +# Copyright 2024 Google LLC. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from iris.policies.layers import keras_ranking_attention_layer +from lingvo.core import favor_attention as favor +import numpy as np +import tensorflow as tf +from absl.testing import absltest + + +class FavorRankingAttentionTest(absltest.TestCase): + + def test_layer_output(self): + """Tests the output of RankingAttention layer.""" + query_layer = tf.keras.layers.Input( + batch_input_shape=(2, 3, 4), dtype="float", name="query") + key_layer = tf.keras.layers.Input( + batch_input_shape=(2, 3, 4), dtype="float", name="keys") + value_layer = tf.keras.layers.Input( + batch_input_shape=(2, 3, 4), dtype="float", name="values") + output_layer = keras_ranking_attention_layer.FavorRankingAttention( + kernel_transformation=favor.relu_kernel_transformation, + top_k=2)(query_layer, key_layer, value_layer) + model = tf.keras.models.Model( + inputs=[query_layer, key_layer, value_layer], outputs=[output_layer]) + queries = np.arange(2 * 3 * 4).reshape((2, 3, 4)) + top_values = model.predict((queries, queries, queries)) + self.assertEqual(top_values.shape, (2, 2, 4)) + true_values = np.arange(2 * 3 * 4).reshape((2, 3, 4)) + true_values = np.flip(true_values[:, 1:, :], 1) + np.testing.assert_array_almost_equal(top_values, true_values, 1) + + +if __name__ == "__main__": + absltest.main() diff --git a/iris/policies/layers/keras_trans_attention_layer.py b/iris/policies/layers/keras_trans_attention_layer.py new file mode 100644 index 0000000..6ee4822 --- /dev/null +++ b/iris/policies/layers/keras_trans_attention_layer.py @@ -0,0 +1,79 @@ +# Copyright 2024 Google LLC. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""A keras layer for performer attention.""" + +from typing import Callable +import tensorflow as tf + + +class FavorTransAttention(tf.keras.layers.Layer): + """A keras layer for FAVOR trans attention. + + A layer that leverages Performer-ReLU (go/performer) attention module in order + to bypass explicit materialization of the L x L attention tensor, where L is + the number of patches (potentially even individual pixels). This reduces time + complexity of the attention module from quadratic to linear in L and provides + a gateway to processing high-resolution images, where explicitly calculating + attention tensor is not feasible. Performer attention is applied to the input + sequence of tokens to transform it into an encoded sequence. + """ + + def __init__( + self, + kernel_transformation: Callable[..., tf.Tensor],) -> None: # pytype: disable=annotation-type-mismatch + """Initializes FavorTransAttention layer. + + Args: + kernel_transformation: Transformation used to get finite kernel features. 
+ """ + super().__init__() + self._kernel_transformation = kernel_transformation + + def call(self, + queries: tf.Tensor, + keys: tf.Tensor, + values: tf.Tensor) -> tf.Tensor: + + # Pass queries and keys through a non-linear kernel transformation to get + # Q' and K' + queries_prime = self._kernel_transformation( + data=tf.expand_dims(queries, axis=1), + is_query=True) + queries_prime = tf.squeeze(queries_prime, axis=1) + keys_prime = self._kernel_transformation( + data=tf.expand_dims(keys, axis=1), + is_query=False) + keys_prime = tf.squeeze(keys_prime, axis=1) + b, l, _ = queries_prime.shape + if b is None: + b = tf.shape(queries_prime)[0] + if l is None: + l = tf.shape(queries_prime)[1] + + # For applying FAVOR attention, product of K' and value vector is multiplied + # by Q' prime without having to materialize the attention matrix + # A = Q'(K')^T + # Multiply K' and value vector + kvs = tf.einsum("blm,bld->bmd", keys_prime, values) # bmd + # Multiply Q' with previous result to get attention output, x + x = tf.einsum("blm,bmd->bld", queries_prime, kvs) # bld + + # For normalization, attention output, x is divided by x_norm. x_norm is + # obtained similarly to x by replacing value vector with all ones. + kvs_norm = tf.einsum("blm,bld->bmd", keys_prime, tf.ones( + (b, l, 1))) # bmd (d=1) + x_norm = tf.einsum("blm,bmd->bld", queries_prime, kvs_norm) # bld + + return x/x_norm diff --git a/iris/policies/layers/keras_trans_attention_layer_test.py b/iris/policies/layers/keras_trans_attention_layer_test.py new file mode 100644 index 0000000..ced9ec0 --- /dev/null +++ b/iris/policies/layers/keras_trans_attention_layer_test.py @@ -0,0 +1,43 @@ +# Copyright 2024 Google LLC. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+
+from iris.policies.layers import keras_trans_attention_layer
+from lingvo.core import favor_attention as favor
+import numpy as np
+import tensorflow as tf
+from absl.testing import absltest
+
+
+class FavorTransAttentionTest(absltest.TestCase):
+
+  def test_layer_output(self):
+    """Tests the output of the FavorTransAttention layer."""
+    query_layer = tf.keras.layers.Input(
+        batch_input_shape=(2, 3, 4), dtype="float", name="query")
+    key_layer = tf.keras.layers.Input(
+        batch_input_shape=(2, 3, 4), dtype="float", name="keys")
+    value_layer = tf.keras.layers.Input(
+        batch_input_shape=(2, 3, 4), dtype="float", name="values")
+    output_layer = keras_trans_attention_layer.FavorTransAttention(
+        kernel_transformation=favor.relu_kernel_transformation)(
+            query_layer, key_layer, value_layer)
+    model = tf.keras.models.Model(
+        inputs=[query_layer, key_layer, value_layer], outputs=[output_layer])
+    queries = np.arange(2 * 3 * 4).reshape((2, 3, 4))
+    values = model.predict((queries, queries, queries))
+    self.assertEqual(values.shape, (2, 3, 4))
+    self.assertAlmostEqual(values.sum(), 305, delta=0.1)
+
+if __name__ == "__main__":
+  absltest.main()
diff --git a/iris/policies/pct_policy.py b/iris/policies/pct_policy.py
new file mode 100644
index 0000000..d269df0
--- /dev/null
+++ b/iris/policies/pct_policy.py
@@ -0,0 +1,260 @@
+# Copyright 2024 Google LLC.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Policy class that computes actions by running JAX neural networks."""
+import functools
+from typing import Dict, Union
+
+from flax import linen as nn
+import gym
+from gym.spaces import utils
+from iris.policies import jax_policy
+import jax.numpy as jnp
+import numpy as np
+from scenic.projects.meshdynamics import model
+from scenic.projects.pointcloud import models
+
+
+class PCTEncoder(nn.Module):
+  """Point Cloud Transformer (PCT) Encoder.
+
+  Attributes:
+    in_dim: Point cloud feature dimension.
+    feature_dim: Point cloud encoder feature dim.
+    emb_dim: Point cloud encoder output dim.
+    emb_idx: Index of the point whose embedding is extracted. If None, all
+      points will be returned.
+    encoder_feature_dim: Point cloud encoder feature dim.
+    kernel_size: Point cloud encoder kernel size.
+    num_attention_layers: Number of attention layers in the point cloud
+      encoder.
+    num_pre_conv_layers: Number of pre-attention convolution layers in the
+      point cloud encoder.
+    attention_type: str defining the attention algorithm; possible values are
+      'regular', 'perf-softmax', 'perf-relu' and 'pseudolocal-performer'.
+    rpe_masking_type: str defining the applied RPE mechanism; possible values
+      are 'nomask', 'fft' and 'flt'.
+    pseudolocal_sigma: Sigma parameter used when attention_type is
+      'pseudolocal-performer'.
+  """
+
+  in_dim: int = 3
+  feature_dim: int = 16
+  emb_dim: int = 8
+  emb_idx: int | None = 0
+  encoder_feature_dim: int = 16
+  kernel_size: int = 1
+  num_attention_layers: int = 2
+  num_pre_conv_layers: int = 2
+  attention_type: str = 'regular'
+  rpe_masking_type: str = 'nomask'
+  pseudolocal_sigma: float = 0.05
+
+  @nn.compact
+  def __call__(
+      self,
+      x: jnp.ndarray,
+      mask: jnp.ndarray | None = None,
+      train: bool = True,
+      coords: jnp.ndarray | None = None,
+  ):
+    """Runs the PCT encoder and generates a point cloud embedding.
+ + Args: + x: A point clound of shape (batch_sizes, N, in_dim). + mask: Optional boolean mask of shape (batch_sizes, N) applied to the point + cloud. + train: Whether the module is called during training. + coords: xyz-coordinates of the points of shape (batch_sizes, N, 3). + + Returns: + The embedding array of shape (batch_sizes, emb_dim). + """ + batch_dim = True + if jnp.ndim(x) == 2: + batch_dim = False + x = x[None, :, :] + if mask is not None: + mask = mask[None, :] + if coords is not None: + coords = coords[None, :, :] + + if self.attention_type == 'regular': + x = model.PointCloudTransformerEncoder( + in_dim=self.in_dim, + feature_dim=self.feature_dim, + out_dim=self.emb_dim, + encoder_feature_dim=self.encoder_feature_dim, + kernel_size=self.kernel_size, + num_attention_layers=self.num_attention_layers, + num_pre_conv_layers=self.num_pre_conv_layers, + )(x, mask) + else: + if self.attention_type == 'pseudolocal-performer': + attention_fn_configs = dict() + attention_fn_configs['attention_kind'] = 'performer' + attention_fn_configs['performer'] = { + 'masking_type': 'pseudolocal', + 'rf_type': 'hyper', + 'num_features': 128, + 'sigma': self.pseudolocal_sigma, + } + else: + kernel_name_translate = {'perf-softmax': 'softmax', 'perf-relu': 'relu'} + rpe_mask_name_translate = { + 'fft': 'fftmasked', + 'flt': 'sharpmasked', + 'nomask': 'nomask', + } + num_features_translate = {'perf-softmax': 64, 'perf-relu': 0} + use_rand_proj_translate = {'perf-softmax': True, 'perf-relu': False} + + attention_fn_configs = dict() + attention_fn_configs['attention_kind'] = 'performer' + attention_fn_configs['performer'] = { + 'masking_type': rpe_mask_name_translate[self.rpe_masking_type], + 'kernel_transformation': kernel_name_translate[self.attention_type], + 'num_features': num_features_translate[self.attention_type], + 'rpe_method': None, + 'num_realizations': 10, + 'num_sines': 1, + 'use_random_projections': use_rand_proj_translate[ + self.attention_type + ], + 'seed': 41, + } + + x = models.PointCloudTransformerEncoder( + in_dim=self.in_dim, + feature_dim=self.feature_dim, + out_dim=self.emb_dim, + encoder_feature_dim=self.encoder_feature_dim, + kernel_size=self.kernel_size, + num_attention_layers=self.num_attention_layers, + num_pre_conv_layers=self.num_pre_conv_layers, + attention_fn_configs=attention_fn_configs, + )(x, mask=mask, coords=coords) + + if self.emb_idx is not None: + x = x[:, self.emb_idx, :] + + if not batch_dim: + return x[0] + + return x + + +class PCTPolicyNet(nn.Module): + """Point Cloud Transformer Net. + + Attributes: + auxiliary_observations: List of names of observations other than the point + cloud. + in_dim: Point cloud feature dimension. + feature_dim: Point cloud encoder feature dim. + emb_dim: Point cloud encoder output dim. + encoder_feature_dim: Point cloud encoder feature dim. + kernel_size: Point cloud encoder kernel size. + num_attention_layers: Number of attention layers in the point cloud encoder. + fc_layer_dims: Fully connected layer dims. + act_dim: Output action dim. 
+ attention_type: str defining attention algorithm; possible values + are 'regular', 'perf-softmax', 'perf-relu' + rpe_masking_type: str defining applied RPE mechanism; possible values + are 'nomask', 'fft', 'flt' + """ + + auxiliary_observations: list[str] + point_cloud_obs_name: str = 'object_point_cloud' + point_cloud_mask_name: str = 'object_point_cloud_mask' + in_dim: int = 3 + feature_dim: int = 16 + emb_dim: int = 8 + encoder_feature_dim: int = 16 + kernel_size: int = 1 + num_attention_layers: int = 2 + fc_layer_dims: tuple[int, ...] = (8, 8) + act_dim: int = 7 + attention_type: str = 'regular' + rpe_masking_type: str = 'nomask' + + @nn.compact + def __call__(self, ob: dict[str, jnp.ndarray], train: bool = True): + """Runs PCT policy network and generates actions. + + Args: + ob: A dictionary of observations. Value with the key point_cloud_obs_name + should be a point clound of shape (N, in_dim). auxiliary_observations + keys should be present in the ob dict. + train: Whether the module is called during training. + + Returns: + The action array of size act_dim. + """ + mask = ( + ob[self.point_cloud_mask_name] + if self.point_cloud_mask_name in ob + else None + ) + x = PCTEncoder( + in_dim=self.in_dim, + feature_dim=self.feature_dim, + emb_dim=self.emb_dim, + encoder_feature_dim=self.encoder_feature_dim, + kernel_size=self.kernel_size, + num_attention_layers=self.num_attention_layers, + attention_type=self.attention_type, + rpe_masking_type=self.rpe_masking_type, + )(ob[self.point_cloud_obs_name], mask) + all_inputs = [x] + for name in self.auxiliary_observations: + all_inputs.append(ob[name]) + x = jnp.concatenate(all_inputs, axis=-1) + + for dim in self.fc_layer_dims: + x = nn.tanh(nn.Dense(dim)(x)) + x = nn.tanh(nn.Dense(self.act_dim)(x)) + return x + + +class PCTPolicy(jax_policy.JaxPolicy): + """Policy class that computes action by running PCT jax models.""" + + def __init__( + self, ob_space: gym.Space, ac_space: gym.Space, + auxiliary_observations: list[str], seed: int = 42 + ) -> None: + """Initializes a jax policy. See the base class for more details.""" + init_x = ob_space.sample() + init_x['object_point_cloud'] = init_x['object_point_cloud'].squeeze().T + super().__init__( + ob_space=ob_space, ac_space=ac_space, + model=functools.partial( + PCTPolicyNet, + act_dim=utils.flatdim(ac_space), + auxiliary_observations=auxiliary_observations), + init_x=init_x, seed=seed) + + def act(self, ob: Union[np.ndarray, Dict[str, np.ndarray]] + ) -> Union[np.ndarray, Dict[str, np.ndarray]]: + """Maps the observation to action. + + Args: + ob: The observations in reinforcement learning. + + Returns: + The action in reinforcement learning. + """ + ob['object_point_cloud'] = ob['object_point_cloud'].squeeze().T + action = self.model.apply( + self._tree_weights, + ob, + mutable=['batch_stats'])[0] + action = utils.unflatten(self._ac_space, action) + return action diff --git a/iris/policies/pct_policy_test.py b/iris/policies/pct_policy_test.py new file mode 100644 index 0000000..d26cea4 --- /dev/null +++ b/iris/policies/pct_policy_test.py @@ -0,0 +1,83 @@ +# Copyright 2024 Google LLC. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from gym import spaces +from iris.policies import pct_policy +import jax +import numpy as np +from absl.testing import absltest + + +class PCTPolicyTest(absltest.TestCase): + + def test_pct_encoder(self): + batch_size = 10 + num_points = 100 + in_dim = 3 + out_dim = 8 + pc = np.random.normal(size=(batch_size, num_points, in_dim)) + pct_enc = pct_policy.PCTEncoder(emb_dim=out_dim, attention_type='perf-relu') + params = pct_enc.init(jax.random.PRNGKey(0), pc) + emb = pct_enc.apply(params, pc) + self.assertEqual(emb.shape, (batch_size, out_dim)) + + def test_pct_encoder_with_mask(self): + batch_size = 10 + num_points = 100 + masked_points = 20 + in_dim = 3 + out_dim = 8 + pc = np.random.normal(size=(batch_size, num_points, in_dim)) + mask = np.concatenate( + [np.ones((batch_size, num_points-masked_points)), + np.zeros((batch_size, masked_points))], axis=-1) + mask = mask.astype(bool) + pct_enc = pct_policy.PCTEncoder(emb_dim=out_dim) + params = pct_enc.init(jax.random.PRNGKey(0), pc) + emb = pct_enc.apply(params, pc, mask) + self.assertEqual(emb.shape, (batch_size, out_dim)) + + emb_no_mask = pct_enc.apply(params, pc[:, :-masked_points], None) + np.testing.assert_allclose(emb_no_mask, emb, atol=1e-1, rtol=1e-1) + + def test_policy_act(self): + """Tests the act function for PCT policy.""" + + ob_space = spaces.Dict({ + 'object_position': spaces.Box( + low=-5, high=5, shape=(3,) + ), + 'object_bounding_box': spaces.Box( + low=-1, high=1, shape=(3,) + ), + 'object_point_cloud': spaces.Box(low=-1, high=1, shape=(3, 100, 1)), + }) + + policy = pct_policy.PCTPolicy( + ob_space=ob_space, + ac_space=spaces.Box(low=-3, high=3, shape=(7,)), + auxiliary_observations=['object_position', 'object_bounding_box',], + ) + + n = policy.get_weights().shape[0] + policy.update_weights(new_weights=np.zeros(n)) + + x = ob_space.sample() + + jax_act = policy.act(x) + np.testing.assert_array_almost_equal(jax_act, np.zeros(7), 2) + + +if __name__ == '__main__': + absltest.main()
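A minimal usage sketch of the new PCTPolicy, mirroring test_policy_act in
iris/policies/pct_policy_test.py above; the observation keys and shapes are the
ones used in that test and are illustrative rather than required by the API:

    from gym import spaces
    from iris.policies import pct_policy
    import numpy as np

    # Dict observation space with a (3, 100, 1) point cloud plus two auxiliary
    # observations that get concatenated with the point cloud embedding.
    ob_space = spaces.Dict({
        'object_position': spaces.Box(low=-5, high=5, shape=(3,)),
        'object_bounding_box': spaces.Box(low=-1, high=1, shape=(3,)),
        'object_point_cloud': spaces.Box(low=-1, high=1, shape=(3, 100, 1)),
    })

    policy = pct_policy.PCTPolicy(
        ob_space=ob_space,
        ac_space=spaces.Box(low=-3, high=3, shape=(7,)),
        auxiliary_observations=['object_position', 'object_bounding_box'],
    )

    # Weights are exposed as a flat vector; zeroed here as in the unit test.
    policy.update_weights(new_weights=np.zeros(policy.get_weights().shape[0]))

    action = policy.act(ob_space.sample())  # array of shape (7,)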