Skip to content

Commit

Permalink
1. Add point cloud encoding policy - pct_policy
Browse files Browse the repository at this point in the history
    2. Add custom keras layers module

PiperOrigin-RevId: 658711302
  • Loading branch information
jaindeepali authored and copybara-github committed Aug 13, 2024
1 parent 5f4733a commit d797379
Show file tree
Hide file tree
Showing 12 changed files with 882 additions and 0 deletions.
73 changes: 73 additions & 0 deletions iris/policies/layers/keras_image_encoder_layer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
# Copyright 2024 Google LLC.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""A keras layer for encoding image into patches."""

from typing import Tuple

import tensorflow as tf


class ImageEncoder(tf.keras.layers.Layer):
"""Keras layer for encoding image into patches."""

def __init__(self,
patch_height: int,
patch_width: int,
stride_height: int,
stride_width: int,
normalize_positions: bool = True) -> None:
"""Initializes Keras layer for encoding image into patches.
Args:
patch_height: Height of image patch for encoding.
patch_width: Width of image patch for encoding.
stride_height: Stride (shift) height for consecutive image patches.
stride_width: Stride (shift) width for consecutive image patches.
normalize_positions: True to normalize patch center positions.
"""
super().__init__()
self._patch_height = patch_height
self._patch_width = patch_width
self._stride_height = stride_height
self._stride_width = stride_width
self._normalize_positions = normalize_positions

def call(self, images: tf.Tensor) -> Tuple[tf.Tensor, tf.Tensor]:
batch_shape, image_height, image_width, channels = images.shape
if batch_shape is None:
batch_shape = tf.shape(images)[0]
patches = tf.image.extract_patches(
images,
sizes=[1, self._patch_height, self._patch_width, 1],
strides=[1, self._stride_height, self._stride_width, 1],
rates=[1, 1, 1, 1],
padding='VALID')
encoding = tf.reshape(
patches,
[batch_shape, -1, self._patch_height * self._patch_width * channels])
pos_x = tf.range(self._patch_height // 2, image_height, self._stride_height)
pos_y = tf.range(self._patch_width // 2, image_width, self._stride_width)
if self._normalize_positions:
pos_x /= image_height
pos_y /= image_width
x, y = tf.meshgrid(pos_x, pos_y)
x = tf.transpose(x)
y = tf.transpose(y)
centers = tf.stack([x, y], axis=-1)
centers = tf.reshape(centers, (-1, 2))
centers = tf.tile(centers, (batch_shape, 1))
centers = tf.reshape(centers, (batch_shape, -1, 2))
centers = tf.cast(centers, 'float32')
return encoding, centers
40 changes: 40 additions & 0 deletions iris/policies/layers/keras_image_encoder_layer_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
# Copyright 2024 Google LLC.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from iris.policies.layers import keras_image_encoder_layer
import numpy as np
import tensorflow as tf
from absl.testing import absltest


class ImageEncoderTest(absltest.TestCase):

def test_layer_output(self):
"""Tests the output of ImageEncoder layer."""
input_layer = tf.keras.layers.Input(
batch_input_shape=(2, 5, 6, 2), dtype="float", name="input")
output_layer = keras_image_encoder_layer.ImageEncoder(
patch_height=2,
patch_width=2,
stride_height=1,
stride_width=1)(input_layer)
model = tf.keras.models.Model(inputs=[input_layer], outputs=[output_layer])
images = np.arange(2*5*6*2).reshape((2, 5, 6, 2))
encoding, centers = model.predict(images)[0]
self.assertEqual(encoding.shape, (2, 20, 8))
self.assertEqual(centers.shape, (2, 20, 2))


if __name__ == "__main__":
absltest.main()
75 changes: 75 additions & 0 deletions iris/policies/layers/keras_masking_attention_layer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
# Copyright 2024 Google LLC.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""A keras layer for masking based attention."""

from typing import Callable
import tensorflow as tf


class FavorMaskingAttention(tf.keras.layers.Layer):
"""A keras layer for masking based attention.
A layer that creates a representation of the RGB(D)-image using attention
mechanism from https://arxiv.org/abs/2009.14794. It leverages Performer-ReLU
(go/performer) attention module in order to bypass explicit materialization of
the L x L attention tensor, where L is the number of patches (potentially even
individual pixels). This reduces time complexity of the attention module from
quadratic to linear in L and provides a gateway to processing high-resolution
images, where explicitly calculating attention tensor is not feasible. The
ranking procedure is adopted from https://arxiv.org/abs/2003.08165, where
scores of patches are defined as sums of the entries of the corresponding
column in the attention tensor. After ranking, top K tokens are preserved and
the rest of them are masked by 0.
"""

def __init__(
self,
kernel_transformation: Callable[..., tf.Tensor],
top_k: int = 5) -> None: # pytype: disable=annotation-type-mismatch
"""Initializes FavorMaskingAttention layer.
Args:
kernel_transformation: Transformation used to get finite kernel features.
top_k: Number of top patches that will be chosen to "summarize" entire
image.
"""
super().__init__()
self._kernel_transformation = kernel_transformation
self._top_k = top_k

def call(self,
queries: tf.Tensor,
keys: tf.Tensor,
values: tf.Tensor) -> tf.Tensor:
queries_prime = self._kernel_transformation(
data=tf.expand_dims(queries, axis=2),
is_query=True)
queries_prime = tf.squeeze(queries_prime, axis=2)
keys_prime = self._kernel_transformation(
data=tf.expand_dims(keys, axis=2),
is_query=False)
keys_prime = tf.squeeze(keys_prime, axis=2)
_, length, _ = queries_prime.shape
all_ones = tf.ones([1, length])
reduced_queries_prime = tf.matmul(all_ones, queries_prime)
scores = tf.matmul(reduced_queries_prime, keys_prime, transpose_b=True)
scores = tf.reshape(scores, (-1, length))
sorted_idxs = tf.argsort(scores, axis=-1, direction='DESCENDING')
cutoff = tf.gather(
scores, sorted_idxs[:, self._top_k], axis=1, batch_dims=1)
cond = scores > tf.expand_dims(cutoff, -1)
return tf.where(tf.expand_dims(cond, -1),
values,
tf.zeros_like(values))
46 changes: 46 additions & 0 deletions iris/policies/layers/keras_masking_attention_layer_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
# Copyright 2024 Google LLC.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from iris.policies.layers import keras_masking_attention_layer
from lingvo.core import favor_attention as favor
import numpy as np
import tensorflow as tf
from absl.testing import absltest


class FavorMaskingAttentionTest(absltest.TestCase):

def test_layer_output(self):
"""Tests the output of RankingAttention layer."""
query_layer = tf.keras.layers.Input(
batch_input_shape=(2, 3, 4), dtype="float", name="query")
key_layer = tf.keras.layers.Input(
batch_input_shape=(2, 3, 4), dtype="float", name="keys")
value_layer = tf.keras.layers.Input(
batch_input_shape=(2, 3, 4), dtype="float", name="values")
output_layer = keras_masking_attention_layer.FavorMaskingAttention(
kernel_transformation=favor.relu_kernel_transformation,
top_k=2)(query_layer, key_layer, value_layer)
model = tf.keras.models.Model(
inputs=[query_layer, key_layer, value_layer], outputs=[output_layer])
queries = np.arange(2 * 3 * 4).reshape((2, 3, 4))
top_values = model.predict((queries, queries, queries))
self.assertEqual(top_values.shape, (2, 3, 4))
true_values = np.arange(2 * 3 * 4).reshape((2, 3, 4))
true_values[:, 0, :] = 0
np.testing.assert_array_almost_equal(top_values, true_values, 1)


if __name__ == "__main__":
absltest.main()
40 changes: 40 additions & 0 deletions iris/policies/layers/keras_positional_encoding_layer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
# Copyright 2024 Google LLC.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""A keras layer for positional encoding."""

from typing import Tuple

import tensorflow as tf


class PositionalEncoding(tf.keras.layers.Layer):
"""Keras layer for positional encoding."""

def call(self,
seq_len: int,
encoding_dimension: int) -> Tuple[tf.Tensor, tf.Tensor]:
num_freq = encoding_dimension // 2
indices = tf.expand_dims(tf.range(seq_len), 0)
indices = tf.tile(indices, [num_freq, 1])
freq_fn = lambda k: 1.0/(10000 ** (2*k/encoding_dimension))
freq = tf.keras.layers.Lambda(freq_fn)(tf.range(num_freq))
freq = tf.expand_dims(freq, 1)
freq = tf.tile(freq, [1, seq_len])
args = tf.multiply(freq, tf.cast(indices, dtype=tf.float64))
sin_enc = tf.math.sin(args)
cos_enc = tf.math.sin(args)
encoding = tf.keras.layers.Concatenate(axis=0)([sin_enc, cos_enc])
encoding = tf.expand_dims(tf.transpose(encoding), 0)
return encoding
27 changes: 27 additions & 0 deletions iris/policies/layers/keras_positional_encoding_layer_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
# Copyright 2024 Google LLC.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from iris.policies.layers import keras_positional_encoding_layer
from absl.testing import absltest


class PositionalEncodingTest(absltest.TestCase):

def test_layer_output(self):
"""Tests the output of PositionalEncoding layer."""
encoding = keras_positional_encoding_layer.PositionalEncoding()(7, 4)
self.assertEqual(encoding.shape, (1, 7, 4))

if __name__ == "__main__":
absltest.main()
70 changes: 70 additions & 0 deletions iris/policies/layers/keras_ranking_attention_layer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
# Copyright 2024 Google LLC.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""A keras layer for ranking based attention."""

from typing import Callable
import tensorflow as tf


class FavorRankingAttention(tf.keras.layers.Layer):
"""A keras layer for ranking based attention.
A layer that creates a representation of the RGB(D)-image using attention
mechanism from https://arxiv.org/abs/2009.14794. It leverages Performer-ReLU
(go/performer) attention module in order to bypass explicit materialization of
the L x L attention tensor, where L is the number of patches (potentially even
individual pixels). This reduces time complexity of the attention module from
quadratic to linear in L and provides a gateway to processing high-resolution
images, where explicitly calculating attention tensor is not feasible. The
ranking procedure is adopted from https://arxiv.org/abs/2003.08165, where
scores of patches are defined as sums of the entries of the corresponding
column in the attention tensor.
"""

def __init__(
self,
kernel_transformation: Callable[..., tf.Tensor],
top_k: int = 5) -> None: # pytype: disable=annotation-type-mismatch
"""Initializes FavorRankingAttention layer.
Args:
kernel_transformation: Transformation used to get finite kernel features.
top_k: Number of top patches that will be chosen to "summarize" entire
image.
"""
super().__init__()
self._kernel_transformation = kernel_transformation
self._top_k = top_k

def call(self,
queries: tf.Tensor,
keys: tf.Tensor,
values: tf.Tensor) -> tf.Tensor:
queries_prime = self._kernel_transformation(
data=tf.expand_dims(queries, axis=1),
is_query=True)
queries_prime = tf.squeeze(queries_prime, axis=1)
keys_prime = self._kernel_transformation(
data=tf.expand_dims(keys, axis=1),
is_query=False)
keys_prime = tf.squeeze(keys_prime, axis=1)
_, length, _ = queries_prime.shape
all_ones = tf.ones([1, length])
reduced_queries_prime = tf.matmul(all_ones, queries_prime)
scores = tf.matmul(reduced_queries_prime, keys_prime, transpose_b=True)
scores = tf.reshape(scores, (-1, length))
sorted_idxs = tf.argsort(scores, axis=-1, direction='DESCENDING')
top_idxs = sorted_idxs[:, :self._top_k]
return tf.gather(values, top_idxs, axis=1, batch_dims=1)
Loading

0 comments on commit d797379

Please sign in to comment.