Skip to content

Commit

Permalink
Update DALI TensorFlow examples to work with 2.11 (#4554)
Browse files Browse the repository at this point in the history
- updates RN50 example
- updates YOLO example
- updates EfficientDet example

Signed-off-by: Janusz Lisiecki <[email protected]>
  • Loading branch information
JanuszL committed Jan 9, 2023
1 parent 07da929 commit c572c3f
Show file tree
Hide file tree
Showing 5 changed files with 8 additions and 20 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,10 @@ def get_optimizer(params, *args):

if params["optimizer"].lower() == "sgd":
logging.info("Use SGD optimizer")
optimizer = tf.keras.optimizers.SGD(learning_rate, momentum=params["momentum"])
optimizer = tf.keras.optimizers.legacy.SGD(learning_rate, momentum=params["momentum"])
elif params["optimizer"].lower() == "adam":
logging.info("Use Adam optimizer")
optimizer = tf.keras.optimizers.Adam(learning_rate)
optimizer = tf.keras.optimizers.legacy.Adam(learning_rate)
else:
raise ValueError("optimizers should be adam or sgd")
return optimizer
Expand Down
7 changes: 2 additions & 5 deletions docs/examples/use_cases/tensorflow/resnet-n/nvutils/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,16 +19,13 @@
from distutils.version import StrictVersion

import tensorflow as tf
import tensorflow.keras as keras
import keras
import os
import time
import re
import numpy as np
import horovod.tensorflow.keras as hvd

from tensorflow.python.keras.optimizer_v2 import (gradient_descent as
gradient_descent_v2)
from tensorflow.python.keras import backend
from keras import backend
print(tf.__version__)
if StrictVersion(tf.__version__) > StrictVersion("2.1.0"):
if StrictVersion(tf.__version__) >= StrictVersion("2.4.0"):
Expand Down
13 changes: 3 additions & 10 deletions docs/examples/use_cases/tensorflow/resnet-n/nvutils/runner_ctl.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,19 +20,12 @@
from distutils.version import StrictVersion

import tensorflow as tf
import tensorflow.keras as keras
from tensorflow.python.ops import data_flow_ops
import keras
import os
import sys
import time
import argparse
import datetime
import random
import re
import numpy as np
from tensorflow.python.keras.optimizer_v2 import (gradient_descent as
gradient_descent_v2)
from tensorflow.python.keras import backend

from keras import backend
print(tf.__version__)
if StrictVersion(tf.__version__) > StrictVersion("2.1.0"):
if StrictVersion(tf.__version__) >= StrictVersion("2.4.0"):
Expand Down
2 changes: 1 addition & 1 deletion docs/examples/use_cases/tensorflow/yolov4/src/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ def dataset_fn(input_context):
with strategy.scope():
model = YOLOv4Model()
model.compile(
optimizer=tf.keras.optimizers.SGD(learning_rate=lr_fn)
optimizer=tf.keras.optimizers.legacy.SGD(learning_rate=lr_fn)
)

if start_weights:
Expand Down
2 changes: 0 additions & 2 deletions qa/TL3_EfficientDet_convergence/test_tensorflow.sh
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,4 @@ python train.py
--ckpt_dir . \
--output_filename out_weights_1.h5 2>&1 | tee $LOG

popd

CLEAN_AND_EXIT ${PIPESTATUS[0]}

0 comments on commit c572c3f

Please sign in to comment.