Skip to content

Commit

Permalink
[Docs] Add user documentation about Collective training interface.
Browse files Browse the repository at this point in the history
Signed-off-by: JunqiHu <[email protected]>
  • Loading branch information
Mesilenceki committed Jul 12, 2023
1 parent 7acd267 commit df11d70
Show file tree
Hide file tree
Showing 3 changed files with 485 additions and 0 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,136 @@
FROM alideeprec/deeprec-build:deeprec-dev-gpu-py38-cu116-ubuntu20.04

RUN apt-get update && \
apt-get install -y \
--allow-unauthenticated \
--no-install-recommends \
pkg-config \
libssl-dev \
libcurl4-openssl-dev \
zlib1g-dev \
libhdf5-dev \
wget \
curl \
inetutils-ping \
net-tools \
unzip \
git \
vim \
cmake \
clang-format-7 \
openssh-server openssh-client \
openmpi-bin openmpi-common libopenmpi-dev libgtk2.0-dev && \
apt-get clean && \
rm -rf /var/lib/apt/lists/*

RUN wget -nv -O /opt/openmpi-4.1.1.tar.gz \
https://www.open-mpi.org/software/ompi/v4.1/downloads/openmpi-4.1.1.tar.gz && \
cd /opt/ && tar -xvzf ./openmpi-4.1.1.tar.gz && \
cd openmpi-4.1.1 && ./configure && make && make install

RUN git clone https://github.com/DeepRec-AI/HybridBackend.git /opt/HybridBackend

ENV HYBRIDBACKEND_USE_CXX11_ABI=0 \
HYBRIDBACKEND_WITH_ARROW_HDFS=ON \
HYBRIDBACKEND_WITH_ARROW_S3=ON \
TMP=/tmp

RUN cd /opt/HybridBackend/build/arrow && \
ARROW_USE_CXX11_ABI=${HYBRIDBACKEND_USE_CXX11_ABI} \
ARROW_HDFS=${HYBRIDBACKEND_WITH_ARROW_HDFS} \
ARROW_S3=${HYBRIDBACKEND_WITH_ARROW_S3} \
./build.sh /opt/arrow

RUN pip install -U --no-cache-dir \
Cython \
nvidia-pyindex \
pybind11 \
tqdm && \
pip install -U --no-cache-dir \
nvidia-nsys-cli

ARG TF_REPO=https://github.com/DeepRec-AI/DeepRec.git
ARG TF_TAG=main

RUN git clone ${TF_REPO} -b ${TF_TAG} /opt/DeepRec

RUN wget -nv -O /opt/DeepRec/install_bazel.sh \
http://pythonrun.oss-cn-zhangjiakou.aliyuncs.com/bazel-0.26.1-installer-linux-x86_64.sh && \
chmod 777 /opt/DeepRec/install_bazel.sh && /opt/DeepRec/install_bazel.sh


ENV TF_NEED_CUDA=1 \
TF_CUDA_PATHS=/usr,/usr/local/cuda \
TF_CUDA_VERSION=11.6 \
TF_CUBLAS_VERSION=11 \
TF_CUDNN_VERSION=8 \
TF_NCCL_VERSION=2 \
TF_CUDA_CLANG=0 \
TF_DOWNLOAD_CLANG=0 \
TF_NEED_TENSORRT=0 \
TF_CUDA_COMPUTE_CAPABILITIES="7.0,8.0" \
TF_ENABLE_XLA=1 \
TF_NEED_MPI=0 \
CC_OPT_FLAGS="-march=skylake -Wno-sign-compare" \
CXX_OPT_FLAGS="-D_GLIBCXX_USE_CXX11_ABI=0"

RUN cd /opt/DeepRec && \
yes "" | bash ./configure || true

RUN --mount=type=cache,target=/var/cache/bazel.tensorflow \
cd /opt/DeepRec && \
bazel build \
--disk_cache=/var/cache/bazel.tensorflow \
--config=nogcp \
--config=cuda \
--config=xla \
--verbose_failures \
--cxxopt="${CXX_OPT_FLAGS}" \
--host_cxxopt="${CXX_OPT_FLAGS}" \
--define tensorflow_mkldnn_contraction_kernel=0 \
//tensorflow/tools/pip_package:build_pip_package

RUN mkdir -p /src/dist && \
cd /opt/DeepRec && \
./bazel-bin/tensorflow/tools/pip_package/build_pip_package \
/src/dist --gpu --project_name tensorflow

RUN pip install --no-cache-dir --user \
/src/dist/tensorflow-*.whl && \
rm -f /src/dist/tensorflow-*.whl

RUN mkdir -p \
$(pip show tensorflow | grep Location | cut -d " " -f 2)/tensorflow_core/include/third_party/gpus/cuda/ && \
ln -sf /usr/local/cuda/include \
$(pip show tensorflow | grep Location | cut -d " " -f 2)/tensorflow_core/include/third_party/gpus/cuda/include

RUN cd /opt/DeepRec/ && \
cp tensorflow/core/kernels/gpu_device_array* \
$(pip show tensorflow | grep Location | cut -d " " -f 2)/tensorflow_core/include/tensorflow/core/kernels

RUN cd /opt/DeepRec && \
bazel build --disk_cache=/var/cache/bazel.tensorflow \
-j 16 -c opt --config=opt //tensorflow/tools/pip_package:build_sok && \
./bazel-bin/tensorflow/tools/pip_package/build_sok

ENV ARROW_INCLUDE=/opt/arrow/include \
ARROW_LIB=/opt/arrow/lib \
ZSTD_LIB=/opt/arrow/lib

# Configure HybridBackend
ENV HYBRIDBACKEND_WITH_CUDA=ON \
HYBRIDBACKEND_WITH_NCCL=ON \
HYBRIDBACKEND_WITH_ARROW_ZEROCOPY=ON \
HYBRIDBACKEND_WITH_TENSORFLOW_HALF=OFF \
HYBRIDBACKEND_WITH_TENSORFLOW_DISTRO=99881015 \
HYBRIDBACKEND_USE_CXX11_ABI=0 \
HYBRIDBACKEND_USE_RUFF=1 \
HYBRIDBACKEND_WHEEL_ALIAS=-deeprec-cu116 \
TF_DISABLE_EV_ALLOCATOR=true

RUN cd /opt/HybridBackend && make -j32

RUN pip install --no-cache-dir --user \
/opt/HybridBackend/build/wheel/hybridbackend_deeprec*.whl

RUN rm -rf /opt/DeepRec /opt/HybridBackend && /opt/openmpi-4.1.1.tar.gz
176 changes: 176 additions & 0 deletions docs/docs_en/Collective-Training.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,176 @@
# Collective Training

## Background

For sparse recommendation models like DLRM, there are a large number of parameters and heavy GEMM operations. The asynchronous training paradigm of PS makes it difficult to fully utilize the GPUs in the cluster to accelerate the entire training/inference process.We try to place all the parameters on the worker, but the large amount of memory consumed by the parameters(Embedding) cannot be stored on a single GPU, so we need to perform sharding to place on all GPUs.Native Tensorflow did not support model parallel training (MP), and the community has many excellent plug-ins based on Tensorflow, such as HybridBackend (hereinafter referred to as HB), SparseOperationKit (hereinafter referred to as SOK), and so on. DeepRec provides a unified synchronous training interface `CollectiveStrategy` for users to choose and use. Users can use different synchronous training frameworks with very little code.

## Interface Introduction

1. Currently the interface supports HB and SOK, users can choose through the environment variable `COLLECTIVE_STRATEGY`. `COLLECTIVE_STRATEGY` can configure hb, sok corresponding to HB and SOK respectively. The difference from normal startup of Tensorflow tasks is that when users use synchronous training, they need to pull up through additional modules, which need to be started in the following way:

```bash
CUDA_VISIBLE_DEVICES=0,1 COLLECTIVE_STRATEGY=hb python3 -m tensorflow.python.distribute.launch <python script.py>
```
If the environment variable is not configured with `CUDA_VISIBLE_DEVICES`, the process will pull up the training sub-processes with the number of GPUs in the current environment by default.

2. In the user script, a `CollectiveStrategy` needs to be initialized to complete the construction of the model.

```python
class CollectiveStrategy:
def scope(self, *args, **kwargs):
pass
def embedding_scope(self, **kwargs):
pass
def world_size(self):
pass
def rank(self):
pass
def estimator(self):
pass
def export_saved_model(self):
pass
```

Following steps below to using synchronous training:
- Mark with strategy.scope() before the entire model definition.
- Use the embedding_scope() flag where model parallelism is required (embedding layer)
- Use export_saved_model when exporting
- (Optional) In addition, the strategy also provides the estimator interface for users to use.

## Example

**MonitoredTrainingSession**

The following example guides users how to construct Graph through tf.train.MonitoredTrainingSession.

```python
import tensorflow as tf
from tensorflow.python.distribute.group_embedding_collective_strategy import CollectiveStrategy

#STEP1: initialize a collective strategy
strategy = CollectiveStrategy()
#STEP2: define the data parallel scope
with strategy.scope(), tf.Graph().as_default():
#STEP3: define the model parallel scope
with strategy.embedding_scope():
var = tf.get_variable(
'var_1',
shape=(1000, 3),
initializer=tf.ones_initializer(tf.float32),
partitioner=tf.fixed_size_partitioner(num_shards=strategy.world_size())
)
emb = tf.nn.embedding_lookup(
var, tf.cast([0, 1, 2, 5, 6, 7], tf.int64))
fun = tf.multiply(emb, 2.0, name='multiply')
loss = tf.reduce_sum(fun, name='reduce_sum')
opt = tf.train.FtrlOptimizer(
0.1,
l1_regularization_strength=2.0,
l2_regularization_strength=0.00001)
g_v = opt.compute_gradients(loss)
train_op = opt.apply_gradients(g_v)
with tf.train.MonitoredTrainingSession('') as sess:
emb_result, loss_result, _ = sess.run([emb, loss, train_op])
print (emb_result, loss_result)
```

**Estimator**

The following example guides users how to construct Graph through tf.estimator.Estimator.
```python
import tensorflow as tf
import tensorflow_datasets as tfds
from tensorflow.python.distribute.group_embedding_collective_strategy import CollectiveStrategy

#STEP1: initialize a collective strategy
strategy = CollectiveStrategy()
#STEP2: define the data parallel scope
with strategy.scope(), tf.Graph().as_default():
def input_fn():
ratings = tfds.load("movie_lens/100k-ratings", split="train")
ratings = ratings.map(
lambda x: {
"movie_id": tf.strings.to_number(x["movie_id"], tf.int64),
"user_id": tf.strings.to_number(x["user_id"], tf.int64),
"user_rating": x["user_rating"]
})
shuffled = ratings.shuffle(1_000_000,
seed=2021,
reshuffle_each_iteration=False)
dataset = shuffled.batch(256)
return dataset

def input_receiver():
r'''Prediction input receiver.
'''
inputs = {
"movie_id": tf.placeholder(dtype=tf.int64, shape=[None]),
"user_id": tf.placeholder(dtype=tf.int64, shape=[None]),
"user_rating": tf.placeholder(dtype=tf.float32, shape=[None])
}
return tf.estimator.export.ServingInputReceiver(inputs, inputs)

def model_fn(features, labels, mode, params):
r'''Model function for estimator.
'''
del params
movie_id = features["movie_id"]
user_id = features["user_id"]
rating = features["user_rating"]

embedding_columns = [
tf.feature_column.embedding_column(
tf.feature_column.categorical_column_with_embedding(
"movie_id", dtype=tf.int64),
dimension=16,
initializer=tf.random_uniform_initializer(-1e-3, 1e-3)),
tf.feature_column.embedding_column(
tf.feature_column.categorical_column_with_embedding(
"user_id", dtype=tf.int64),
dimension=16,
initializer=tf.random_uniform_initializer(-1e-3, 1e-3))
]
#STEP3: define the model parallel scope
with strategy.embedding_scope():
with tf.variable_scope(
'embedding',
partitioner=tf.fixed_size_partitioner(
strategy.world_size)):
deep_features = [
tf.feature_column.input_layer(features, [c])
for c in embedding_columns]
emb = tf.concat(deep_features, axis=-1)
logits = tf.multiply(emb, 2.0, name='multiply')

if mode == tf.estimator.ModeKeys.TRAIN:
labels = tf.reshape(tf.to_float(labels), shape=[-1, 1])
loss = tf.reduce_mean(tf.keras.losses.binary_crossentropy(labels, logits))
step = tf.train.get_or_create_global_step()
opt = tf.train.AdagradOptimizer(learning_rate=self._args.lr)
train_op = opt.minimize(loss, global_step=step)
return tf.estimator.EstimatorSpec(
mode=mode,
loss=loss,
train_op=train_op,
training_chief_hooks=[])

return None
estimator = strategy.estimator(model_fn=model_fn,
model_dir="./",
config=None)
estimator.train_and_evaluate(
tf.estimator.TrainSpec(
input_fn=input_fn,
max_steps=50),
tf.estimator.EvalSpec(
input_fn=input_fn))
estimator.export_saved_model("./", input_receiver)
```

## 附录

- Currently DeepRec provides the corresponding GPU image for users to use (alideeprec/deeprec-release:deeprec2304-gpu-py38-cu116-ubuntu20.04-hybridbackend), users can also refer to [Dockerfile](../../cibuild/dockerfiles/Dockerfile.devel-py3.8-cu116-ubuntu20.04-hybridbackend)
- We also provides more detailed demos about the above two usage methods, see: [ModelZoo](../../modelzoo/features/grouped_embedding)

- If further optimization is required, there are more fine-tuning parameters for HB and SOK, please refer to:
[SOK](./SOK.md)[HB](https://github.com/DeepRec-AI/HybridBackend)
Loading

0 comments on commit df11d70

Please sign in to comment.