Skip to content

Commit

Permalink
use keras for init variables
Browse files Browse the repository at this point in the history
  • Loading branch information
sineeli committed Nov 6, 2024
1 parent 581d152 commit c6b700a
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 21 deletions.
19 changes: 4 additions & 15 deletions keras_hub/src/models/retinanet/retinanet_image_converter.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from keras_hub.src.api_export import keras_hub_export
from keras_hub.src.layers.preprocessing.image_converter import ImageConverter
from keras_hub.src.models.retinanet.retinanet_backbone import RetinaNetBackbone
from keras_hub.src.utils.tensor_utils import convert_preprocessing_inputs
from keras_hub.src.utils.tensor_utils import preprocessing_function


Expand Down Expand Up @@ -30,24 +29,14 @@ def call(self, inputs):
x = inputs
# Rescaling Image
if self.scale is not None:
x = x * convert_preprocessing_inputs(
self._expand_non_channel_dims(self.scale, x)
)
x = x * self._expand_non_channel_dims(self.scale, x)
if self.offset is not None:
x = x + convert_preprocessing_inputs(
self._expand_non_channel_dims(self.offset, x)
)

x = x + self._expand_non_channel_dims(self.offset, x)
# By default normalize using imagenet mean and std
if self.norm_mean:
x = x - convert_preprocessing_inputs(
self._expand_non_channel_dims(self.norm_mean, x)
)

x = x - self._expand_non_channel_dims(self.norm_mean, x)
if self.norm_std:
x = x / convert_preprocessing_inputs(
self._expand_non_channel_dims(self.norm_std, x)
)
x = x / self._expand_non_channel_dims(self.norm_std, x)

return x

Expand Down
16 changes: 10 additions & 6 deletions keras_hub/src/models/retinanet/retinanet_object_detector_test.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import numpy as np
import pytest
from keras import ops
from keras import random

from keras_hub.src.models.resnet.resnet_backbone import ResNetBackbone
from keras_hub.src.models.retinanet.anchor_generator import AnchorGenerator
Expand Down Expand Up @@ -69,14 +70,17 @@ def setUp(self):
}

self.input_size = 512
self.images = np.random.uniform(
low=0, high=255, size=(1, self.input_size, self.input_size, 3)
).astype("float32")
self.images = random.uniform(
shape=(1, self.input_size, self.input_size, 3),
minval=0,
maxval=255,
dtype="float32",
)
self.labels = {
"boxes": np.array(
"boxes": ops.convert_to_tensor(
[[[20.0, 10.0, 12.0, 11.0], [30.0, 20.0, 40.0, 12.0]]]
),
"classes": np.array([[0, 2]]),
"classes": ops.convert_to_tensor([[0, 2]]),
}

self.train_data = (self.images, self.labels)
Expand Down

0 comments on commit c6b700a

Please sign in to comment.