From d1734f93cb447642db6f1257ecb6c253314f84f4 Mon Sep 17 00:00:00 2001 From: Artem Balyshev Date: Mon, 5 Aug 2024 13:38:17 +0300 Subject: [PATCH] [onert-micro] Support training configure tool This pr supports training configure tool in onert-micro. ONE-DCO-1.0-Signed-off-by: Artem Balyshev --- onert-micro/onert-micro/CMakeLists.txt | 5 + onert-micro/onert-micro/include/OMConfig.h | 5 + .../include/OMTrainingInterpreter.h | 5 + .../onert-micro/include/core/OMKernelType.h | 10 + .../include/core/OMRuntimeContext.h | 27 +- .../onert-micro/include/core/OMRuntimeShape.h | 2 + .../include/core/OMRuntimeStorage.h | 10 +- .../include/core/OMTrainingRuntimeModule.h | 5 + .../include/core/memory/OMMemoryManager.h | 7 + .../include/core/memory/OMRuntimeAllocator.h | 16 +- .../core/reader/OMTrainingConfigFileReader.h | 79 ++++++ .../include/core/train/OMTrainingHandler.h | 10 +- .../include/core/train/OMTrainingStorage.h | 20 ++ .../include/import/OMDynamicShapesHandler.h | 48 ++++ .../include/import/OMExecutionPlanCreator.h | 8 +- .../include/pal/common/PALConv2DInputGrad.h | 2 +- .../include/pal/common/PALConv2DWeightGrad.h | 8 +- .../pal/common/PALFullyConnectedWeightGrad.h | 19 +- .../onert-micro/include/pal/common/PALUtils.h | 23 ++ .../include/train/OMBackpropExecuteArgs.h | 2 + .../include/train/train_optimizers/Adam.h | 16 +- .../include/train/train_optimizers/SGD.h | 13 +- onert-micro/onert-micro/src/CMakeLists.txt | 4 +- .../onert-micro/src/core/CMakeLists.txt | 3 +- .../onert-micro/src/core/OMRuntimeContext.cpp | 2 +- .../onert-micro/src/core/OMRuntimeGraph.cpp | 10 +- .../onert-micro/src/core/OMRuntimeModule.cpp | 21 +- .../src/core/OMTrainingRuntimeModule.cpp | 42 ++- .../src/core/memory/OMMemoryManager.cpp | 25 ++ .../src/core/memory/OMRuntimeAllocator.cpp | 80 +++++- .../reader/OMTrainingConfigFileReader.cpp | 97 +++++++ .../src/core/train/OMTrainingHandler.cpp | 29 +- .../src/core/train/OMTrainingStorage.cpp | 17 ++ .../src/execute/OMKernelExecute.cpp | 4 + .../onert-micro/src/import/CMakeLists.txt | 1 + .../src/import/OMDynamicShapesHandler.cpp | 123 ++++++++ .../src/import/OMExecutionPlanCreator.cpp | 268 ++++++++++++++++-- .../src/import/kernels/FullyConnected.cpp | 6 +- .../src/import/kernels/Reshape.cpp | 4 +- .../src/train/OMBackpropExecute.cpp | 55 +++- .../onert-micro/src/train/kernels/Conv2D.cpp | 46 ++- .../src/train/kernels/FullyConnected.cpp | 53 ++-- .../onert-micro/src/train/kernels/Softmax.cpp | 6 + .../src/train/train_optimizers/Adam.cpp | 183 ++++++++++-- .../src/train/train_optimizers/SGD.cpp | 88 +++++- 45 files changed, 1359 insertions(+), 148 deletions(-) create mode 100644 onert-micro/onert-micro/include/core/reader/OMTrainingConfigFileReader.h create mode 100644 onert-micro/onert-micro/include/import/OMDynamicShapesHandler.h create mode 100644 onert-micro/onert-micro/src/core/reader/OMTrainingConfigFileReader.cpp create mode 100644 onert-micro/onert-micro/src/import/OMDynamicShapesHandler.cpp diff --git a/onert-micro/onert-micro/CMakeLists.txt b/onert-micro/onert-micro/CMakeLists.txt index a48ad1628fe..308f9b90c25 100644 --- a/onert-micro/onert-micro/CMakeLists.txt +++ b/onert-micro/onert-micro/CMakeLists.txt @@ -40,6 +40,11 @@ if (DIS_FLOAT) add_definitions(-DDIS_FLOAT) endif() +# To enable memory estimate +if (OM_MEMORY_ESTIMATE) + add_definitions(-DOM_MEMORY_ESTIMATE) +endif() + # To enable training part if (ENABLE_TRAINING) add_definitions(-DENABLE_TRAINING) diff --git a/onert-micro/onert-micro/include/OMConfig.h b/onert-micro/onert-micro/include/OMConfig.h 
index 0ffbf024b1f..40b0350afd1 100644 --- a/onert-micro/onert-micro/include/OMConfig.h +++ b/onert-micro/onert-micro/include/OMConfig.h @@ -41,6 +41,7 @@ enum OMMetrics MAE_METRICS, CROSS_ENTROPY_METRICS, ACCURACY, + NONE, }; /* @@ -65,6 +66,8 @@ enum OMLoss * beta_squares - used by ADAM optimizer * epsilon - used by ADAM optimizer * num_Step - used by ADAM optimizer + * training_config_info_data - pointer to the training config data, to store training specific + * scenario (default null) */ struct OMTrainingContext { @@ -79,6 +82,8 @@ struct OMTrainingContext uint32_t num_step = 0; uint32_t num_epoch = 0; uint32_t epochs = 0; + + char *training_config_info_data = nullptr; }; /* diff --git a/onert-micro/onert-micro/include/OMTrainingInterpreter.h b/onert-micro/onert-micro/include/OMTrainingInterpreter.h index 0c1a35defc8..b3dbcd8987b 100644 --- a/onert-micro/onert-micro/include/OMTrainingInterpreter.h +++ b/onert-micro/onert-micro/include/OMTrainingInterpreter.h @@ -92,6 +92,11 @@ class OMTrainingInterpreter void *getInputData(uint32_t position); void *getInputDataAt(uint32_t position); void *getOutputDataAt(uint32_t position); + +#ifdef OM_MEMORY_ESTIMATE + size_t getPeakFootprintMemory() { return _training_runtime_module.getPeakFootprintMemory(); } + size_t getCurrentFootprintMemory() { return _training_runtime_module.getPeakFootprintMemory(); } +#endif // OM_MEMORY_ESTIMATE }; } // namespace onert_micro diff --git a/onert-micro/onert-micro/include/core/OMKernelType.h b/onert-micro/onert-micro/include/core/OMKernelType.h index 229c96e30ef..6e5e9134b4a 100644 --- a/onert-micro/onert-micro/include/core/OMKernelType.h +++ b/onert-micro/onert-micro/include/core/OMKernelType.h @@ -25,6 +25,16 @@ namespace onert_micro namespace core { +// Enum to indicate the degree(rank) to which part of the operation we will train +enum OpTrainableRankType +{ + ALL = 0, // 0 - Train all weights in the operation + ONLY_BIAS = 1, // 1 - Train bias only in the operation + UP_1_2_PART = 2, // 2 - Train the upper 1/2 part of the operation + LOWER_1_2_PART = 3, // 3 - Train the lower 1/2 part of the operation + // TODO add more +}; + enum OMKernelType { Normal, diff --git a/onert-micro/onert-micro/include/core/OMRuntimeContext.h b/onert-micro/onert-micro/include/core/OMRuntimeContext.h index 1a54c37184f..37e4cd7e9b7 100644 --- a/onert-micro/onert-micro/include/core/OMRuntimeContext.h +++ b/onert-micro/onert-micro/include/core/OMRuntimeContext.h @@ -19,8 +19,12 @@ #include "OMStatus.h" +#include "core/OMRuntimeShape.h" +#include "core/OMRuntimeStorage.h" + #include "reader/OMCircleReader.h" #include "reader/OMWeightOnlyFormatReader.h" +#include "reader/OMTrainingConfigFileReader.h" #include @@ -32,8 +36,9 @@ namespace core class OMRuntimeContext { private: - reader::OMCircleReader _reader; - reader::OMWeightOnlyFormatReader _wof_reader; + reader::OMCircleReader _reader{}; + reader::OMWeightOnlyFormatReader _wof_reader{}; + reader::OMTrainingConfigReader _train_config_reader{}; public: OMRuntimeContext() = default; @@ -66,7 +71,23 @@ class OMRuntimeContext return Ok; } - const bool isConstTensor(uint32_t tensor_index) { return _reader.isConstTensor(tensor_index); } + OMStatus setTrainConfigFile(char *train_config_file_ptr) + { + OMStatus status = Ok; + _train_config_reader.parse(train_config_file_ptr); + + status = _train_config_reader.validate(&_reader); + if (status != Ok) + return status; + return Ok; + } + + std::unordered_map getTrainableOpsIndexes() + { + return _train_config_reader.getTrainableOpsIndexes(); + 
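
For context, a minimal sketch of how a caller might pass the new training config data through OMConfig; importTrainModel and trainSingleStep are assumed to be the existing OMTrainingInterpreter entry points (they are not shown in this patch), so treat the call sequence as illustrative only.

#include "OMConfig.h"
#include "OMTrainingInterpreter.h"
#include <cstddef>

// Illustrative only: circle_model_ptr / train_config_ptr are assumed to point
// at the loaded circle model and the training configure tool output.
onert_micro::OMStatus runTrainingStep(char *circle_model_ptr, char *train_config_ptr)
{
  onert_micro::OMConfig config{};
  config.train_mode = true;
  // New field from this patch: optional training configure tool data.
  config.training_context.training_config_info_data = train_config_ptr;

  onert_micro::OMTrainingInterpreter interpreter;
  // Assumed existing training API (not part of this diff).
  interpreter.importTrainModel(circle_model_ptr, config);

  // ... feed inputs and targets here ...

  onert_micro::OMStatus status = interpreter.trainSingleStep(config);

#ifdef OM_MEMORY_ESTIMATE
  // Added by this patch: peak footprint of all allocations so far.
  size_t peak_bytes = interpreter.getPeakFootprintMemory();
  (void)peak_bytes;
#endif // OM_MEMORY_ESTIMATE
  return status;
}
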
} + + bool isConstTensor(uint32_t tensor_index) { return _reader.isConstTensor(tensor_index); } const reader::CircleValues *getCircleOutputs() { return _reader.outputs(); } const reader::CircleValues *getCircleInputs() { return _reader.inputs(); } diff --git a/onert-micro/onert-micro/include/core/OMRuntimeShape.h b/onert-micro/onert-micro/include/core/OMRuntimeShape.h index f1262d10133..12fdfe12b29 100644 --- a/onert-micro/onert-micro/include/core/OMRuntimeShape.h +++ b/onert-micro/onert-micro/include/core/OMRuntimeShape.h @@ -55,6 +55,8 @@ class OMRuntimeShape // vector. inline int flatSize() const { + if (_size == 0) + return 0; int buffer_size = 1; const int *dims_data = reinterpret_cast(dimsData()); for (int i = 0; i < _size; i++) diff --git a/onert-micro/onert-micro/include/core/OMRuntimeStorage.h b/onert-micro/onert-micro/include/core/OMRuntimeStorage.h index 0a1e259ba03..e7731de6bbd 100644 --- a/onert-micro/onert-micro/include/core/OMRuntimeStorage.h +++ b/onert-micro/onert-micro/include/core/OMRuntimeStorage.h @@ -34,7 +34,7 @@ class OMRuntimeStorage { private: #ifndef DIS_DYN_SHAPES - std::unordered_map _tensor_index_to_dynamic_tensor_size; + std::unordered_map _tensor_index_to_dynamic_tensor_size; #endif std::unordered_map _tensor_index_to_data; std::unordered_map _operator_index_to_kernel_type; @@ -70,18 +70,18 @@ class OMRuntimeStorage return Ok; } #ifndef DIS_DYN_SHAPES - int32_t getDynamicTensorSize(uint16_t tensor_index) + OMRuntimeShape getDynamicRuntimeShape(uint16_t tensor_index) { auto it = _tensor_index_to_dynamic_tensor_size.find(tensor_index); if (it == _tensor_index_to_dynamic_tensor_size.end()) - return -1; + return {}; // Return empty return it->second; } - OMStatus setDynamicTensorSize(uint16_t tensor_index, uint32_t dynamic_size) + OMStatus setDynamicRuntimeShape(uint16_t tensor_index, const OMRuntimeShape &shape) { - _tensor_index_to_dynamic_tensor_size[tensor_index] = dynamic_size; + _tensor_index_to_dynamic_tensor_size[tensor_index] = shape; return Ok; } #endif // DIS_DYN_SHAPES diff --git a/onert-micro/onert-micro/include/core/OMTrainingRuntimeModule.h b/onert-micro/onert-micro/include/core/OMTrainingRuntimeModule.h index d9201374b5e..9b53c3bd888 100644 --- a/onert-micro/onert-micro/include/core/OMTrainingRuntimeModule.h +++ b/onert-micro/onert-micro/include/core/OMTrainingRuntimeModule.h @@ -92,6 +92,11 @@ class OMTrainingRuntimeModule : public OMRuntimeModule OMStatus loadCheckpointData(OMConfig &config, const char *data); void *getInputData(int32_t index); + +#ifdef OM_MEMORY_ESTIMATE + size_t getPeakFootprintMemory(); + size_t getCurrentFootprintMemory(); +#endif // OM_MEMORY_ESTIMATE }; } // namespace core diff --git a/onert-micro/onert-micro/include/core/memory/OMMemoryManager.h b/onert-micro/onert-micro/include/core/memory/OMMemoryManager.h index 61fc8f85a02..c7a86cbe82c 100644 --- a/onert-micro/onert-micro/include/core/memory/OMMemoryManager.h +++ b/onert-micro/onert-micro/include/core/memory/OMMemoryManager.h @@ -20,6 +20,7 @@ #include "OMStatus.h" #include +#include namespace onert_micro { @@ -30,6 +31,12 @@ namespace memory struct OMMemoryManager { + // Need for configure tool estimations +#ifdef OM_MEMORY_ESTIMATE + static size_t peak_memory_allocated; + static size_t cur_memory_allocated; + static OMStatus deallocateMemory(uint32_t size, uint8_t *data); +#endif // OM_MEMORY_ESTIMATE static OMStatus allocateMemory(uint32_t size, uint8_t **data); static OMStatus deallocateMemory(uint8_t *data); }; diff --git 
a/onert-micro/onert-micro/include/core/memory/OMRuntimeAllocator.h b/onert-micro/onert-micro/include/core/memory/OMRuntimeAllocator.h index 1e9d29a8e9b..d1d52310834 100644 --- a/onert-micro/onert-micro/include/core/memory/OMRuntimeAllocator.h +++ b/onert-micro/onert-micro/include/core/memory/OMRuntimeAllocator.h @@ -45,18 +45,6 @@ class OMRuntimeAllocator OMRuntimeAllocator(OMRuntimeAllocator &&) = default; ~OMRuntimeAllocator() = default; - void saveAllocPlan(std::vector> &&alloc_plan) - { - _alloc_plan.clear(); - _alloc_plan = std::move(alloc_plan); - } - - void saveDeallocPlan(std::vector> &&dealloc_plan) - { - _dealloc_plan.clear(); - _dealloc_plan = std::move(dealloc_plan); - } - std::vector> &getAllocPlan() { return _alloc_plan; } std::vector> &getDeallocPlan() { return _dealloc_plan; } @@ -67,6 +55,10 @@ class OMRuntimeAllocator OMStatus allocate(size_t kernel_index, OMRuntimeContext *context, OMRuntimeStorage *storage); OMStatus deallocate(size_t kernel_index, OMRuntimeStorage *storage); + // Need for configure tool estimations +#ifdef OM_MEMORY_ESTIMATE + OMStatus deallocate(size_t kernel_index, OMRuntimeStorage *storage, OMRuntimeContext *context); +#endif // OM_MEMORY_ESTIMTE }; } // namespace memory diff --git a/onert-micro/onert-micro/include/core/reader/OMTrainingConfigFileReader.h b/onert-micro/onert-micro/include/core/reader/OMTrainingConfigFileReader.h new file mode 100644 index 00000000000..fa80bb246c4 --- /dev/null +++ b/onert-micro/onert-micro/include/core/reader/OMTrainingConfigFileReader.h @@ -0,0 +1,79 @@ +/* + * Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef ONERT_MICRO_CORE_READER_TRAINING_CONFIG_FILE_READER_H +#define ONERT_MICRO_CORE_READER_TRAINING_CONFIG_FILE_READER_H + +#include "OMStatus.h" +#include "OMCircleReader.h" + +#include + +namespace onert_micro +{ +namespace core +{ +namespace reader +{ +namespace +{ + +enum TrainConfigFileFieldsOffsets +{ + MAGIC_NUMBER_FIELD = 0, + SCHEMA_VERSION_FIELD = 2, + NUM_LAYERS_FIELD = 4, + FIRST_LAYER_INDEX_FIELD = 8 +}; + +} // namespace + +constexpr uint16_t train_config_file_magic_number = 29; +constexpr uint8_t train_config_file_schema_version = 1; + +/** + * @brief Loads Training Config files and provides helpers functions + */ +class OMTrainingConfigReader +{ +public: + OMTrainingConfigReader() = default; + OMTrainingConfigReader(const OMTrainingConfigReader &) = delete; + OMTrainingConfigReader(OMTrainingConfigReader &&) = default; + OMTrainingConfigReader &operator=(const OMTrainingConfigReader &) = delete; + OMTrainingConfigReader &&operator=(const OMTrainingConfigReader &&) = delete; + ~OMTrainingConfigReader() = default; + +public: + // To validate _train_config_ptr and compare with circle model saved in reader. 
+ OMStatus validate(OMCircleReader *reader); + + // Save pointer + void parse(char *ptr) { _train_config_ptr = ptr; } + + // Read and return indexes of trainable layers from config file + // first it is op index in graph, second is rank of the training (see OpTrainableRank) + std::unordered_map getTrainableOpsIndexes(); + +private: + char *_train_config_ptr; +}; + +} // namespace reader +} // namespace core +} // namespace onert_micro + +#endif // ONERT_MICRO_CORE_READER_TRAINING_CONFIG_FILE_READER_H diff --git a/onert-micro/onert-micro/include/core/train/OMTrainingHandler.h b/onert-micro/onert-micro/include/core/train/OMTrainingHandler.h index 08efc77e439..a292b991f68 100644 --- a/onert-micro/onert-micro/include/core/train/OMTrainingHandler.h +++ b/onert-micro/onert-micro/include/core/train/OMTrainingHandler.h @@ -77,7 +77,8 @@ class OMTrainingHandler OMRuntimeContext &context); // Handle with updating weights with current optimizer - OMStatus updateWeights(const OMConfig &config, OMRuntimeContext &context); + OMStatus updateWeights(const OMConfig &config, OMRuntimeContext &context, + OMRuntimeStorage &storage); // Evaluate metric and save result in metric_val // Warning: 1) assume that all metric_val for all OMMetrics types actually are float values. @@ -95,6 +96,13 @@ class OMTrainingHandler // Reset and deallocate all internal states void reset(); + +#ifdef OM_MEMORY_ESTIMATE + + // Reset and deallocate all states + void reset(core::OMRuntimeContext &context, core::OMRuntimeStorage &storage); + +#endif // OM_MEMORY_ESTIMATE }; } // namespace train diff --git a/onert-micro/onert-micro/include/core/train/OMTrainingStorage.h b/onert-micro/onert-micro/include/core/train/OMTrainingStorage.h index 55f7f114224..6442a529333 100644 --- a/onert-micro/onert-micro/include/core/train/OMTrainingStorage.h +++ b/onert-micro/onert-micro/include/core/train/OMTrainingStorage.h @@ -51,6 +51,9 @@ class OMTrainingStorage // Note: initial its null std::unique_ptr _adam_optimizer = nullptr; + // Store rank types + std::unordered_map _tensor_index_to_train_rank; + public: OMTrainingStorage() = default; OMTrainingStorage(const OMTrainingStorage &) = delete; @@ -70,6 +73,16 @@ class OMTrainingStorage _target_index_to_target_data[target_index] = data; } + void addTrainRank(uint16_t tensor_index, core::OpTrainableRankType train_rank) + { + _tensor_index_to_train_rank[tensor_index] = train_rank; + } + + std::unordered_map &getTensorIndexToRankTypeTable() + { + return _tensor_index_to_train_rank; + } + // Choose and set optimizer defined in config OMStatus setOptimizer(const OMConfig &config); @@ -88,6 +101,13 @@ class OMTrainingStorage return _target_index_to_target_data[target_index]; } +#ifdef OM_MEMORY_ESTIMATE + + // Reset and deallocate all states + void reset(core::OMRuntimeContext &context, core::OMRuntimeStorage &storage); + +#endif // OM_MEMORY_ESTIMATE + // Reset and deallocate all states void reset(); }; diff --git a/onert-micro/onert-micro/include/import/OMDynamicShapesHandler.h b/onert-micro/onert-micro/include/import/OMDynamicShapesHandler.h new file mode 100644 index 00000000000..1daec2395c9 --- /dev/null +++ b/onert-micro/onert-micro/include/import/OMDynamicShapesHandler.h @@ -0,0 +1,48 @@ +/* + * Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
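
As an illustration of the layout this reader expects, a short sketch that builds such a buffer; buildTrainConfig and the concrete op indexes are hypothetical, only the offsets, field types and magic values come from OMTrainingConfigReader above.

#include <cstdint>
#include <cstring>
#include <utility>
#include <vector>

// Hypothetical helper: builds a training config blob with the layout read by
// OMTrainingConfigReader:
//   offset 0: uint16_t magic number (train_config_file_magic_number = 29)
//   offset 2: uint8_t  schema version (train_config_file_schema_version = 1)
//   offset 4: uint32_t number of trainable layers N
//   offset 8: N x uint16_t op indexes, followed by N x uint8_t train ranks
std::vector<char> buildTrainConfig(const std::vector<std::pair<uint16_t, uint8_t>> &layers)
{
  const uint32_t n = static_cast<uint32_t>(layers.size());
  std::vector<char> buf(8 + n * (sizeof(uint16_t) + sizeof(uint8_t)), 0);
  char *base = buf.data();

  const uint16_t magic = 29;
  const uint8_t version = 1;
  std::memcpy(base + 0, &magic, sizeof(magic));
  std::memcpy(base + 2, &version, sizeof(version));
  std::memcpy(base + 4, &n, sizeof(n));

  char *idx_ptr = base + 8;
  char *rank_ptr = base + 8 + sizeof(uint16_t) * n;
  for (const auto &layer : layers)
  {
    std::memcpy(idx_ptr, &layer.first, sizeof(layer.first));
    idx_ptr += sizeof(layer.first);
    std::memcpy(rank_ptr, &layer.second, sizeof(layer.second));
    rank_ptr += sizeof(layer.second);
  }
  return buf;
}

// Usage (hypothetical op indexes): train op 5 fully (ALL = 0) and only the
// upper half of op 7 (UP_1_2_PART = 2):
//   std::vector<char> blob = buildTrainConfig({{5, 0}, {7, 2}});
//   config.training_context.training_config_info_data = blob.data();
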
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef ONERT_MICRO_IMPORT_DYNAMIC_SHAPES_HANDLER_H +#define ONERT_MICRO_IMPORT_DYNAMIC_SHAPES_HANDLER_H + +#include "OMStatus.h" +#include "OMConfig.h" +#include "core/OMRuntimeStorage.h" +#include "core/OMRuntimeContext.h" +#include "core/train/OMTrainingStorage.h" + +namespace onert_micro +{ +namespace import +{ + +/* + * Class to handle with tensors dynamic shapes + */ +struct OMDynamicShapesHandler +{ + /* + * Import dynamic shapes from train config file data: + * Some tensors can have sparse tensor backpropagation scheme (train rank) + */ + static OMStatus importDynamicShapesFromTrainConfig(core::OMRuntimeStorage &storage, + core::OMRuntimeContext &context, + core::train::OMTrainingStorage &train_storage); +}; + +} // namespace import +} // namespace onert_micro + +#endif // ONERT_MICRO_IMPORT_DYNAMIC_SHAPES_HANDLER_H diff --git a/onert-micro/onert-micro/include/import/OMExecutionPlanCreator.h b/onert-micro/onert-micro/include/import/OMExecutionPlanCreator.h index 9385ae6dcf4..e6359a3dc45 100644 --- a/onert-micro/onert-micro/include/import/OMExecutionPlanCreator.h +++ b/onert-micro/onert-micro/include/import/OMExecutionPlanCreator.h @@ -31,12 +31,18 @@ namespace import struct OMExecutionPlanCreator { - // Create execution plan for forward graph + // Create execution plan for graph for non-train mode static OMStatus createExecutionPlan(core::OMRuntimeStorage &runtime_storage, core::OMRuntimeContext &runtime_context, core::memory::OMRuntimeAllocator &allocator, const OMConfig &configs); + // Create execution plan for forward graph for train mode + static OMStatus createForwardExecutionPlan(core::OMRuntimeStorage &runtime_storage, + core::OMRuntimeContext &runtime_context, + core::memory::OMRuntimeAllocator &allocator, + const OMConfig &configs); + // Create execution plan for backward graph static OMStatus createBackwardExecutionPlan(core::OMRuntimeStorage &runtime_storage, core::OMRuntimeContext &runtime_context, diff --git a/onert-micro/onert-micro/include/pal/common/PALConv2DInputGrad.h b/onert-micro/onert-micro/include/pal/common/PALConv2DInputGrad.h index 7405f76e1e9..91e8777d8a7 100644 --- a/onert-micro/onert-micro/include/pal/common/PALConv2DInputGrad.h +++ b/onert-micro/onert-micro/include/pal/common/PALConv2DInputGrad.h @@ -111,7 +111,7 @@ void Conv2DInputGrad(const core::FloatConv2D ¶ms, const core::OMRuntimeShape filter_y * weight_w * dloss_dinput_d + ic * weight_w * dloss_dinput_d * weight_h; assert(input_offset < dloss_doutput_shape.flatSize()); - assert(filter_offset < weight_shape.flatSize()); + // assert(filter_offset < weight_shape.flatSize()); float input_value = dloss_doutput_data[input_offset]; float filter_value = n_c_weight_data[filter_offset]; total += (input_value * filter_value); diff --git a/onert-micro/onert-micro/include/pal/common/PALConv2DWeightGrad.h b/onert-micro/onert-micro/include/pal/common/PALConv2DWeightGrad.h index 9a75ccbffff..d82a49a83b5 100644 --- a/onert-micro/onert-micro/include/pal/common/PALConv2DWeightGrad.h +++ b/onert-micro/onert-micro/include/pal/common/PALConv2DWeightGrad.h @@ -58,7 +58,8 @@ void 
Conv2DBiasGrad(const core::OMRuntimeShape &dloss_doutput_shape, void Conv2DWeightGrad(const core::FloatConv2D ¶ms, const core::OMRuntimeShape &input_shape, const float *input_data, const core::OMRuntimeShape &dloss_doutput_shape, const float *dloss_doutput_data, - const core::OMRuntimeShape &dloss_dweight_shape, float *dloss_dweight_data) + const core::OMRuntimeShape &dloss_dweight_shape, float *dloss_dweight_data, + core::OpTrainableRankType rank) { const int stride_width = params.stride_w; const int stride_height = params.stride_h; @@ -77,8 +78,11 @@ void Conv2DWeightGrad(const core::FloatConv2D ¶ms, const core::OMRuntimeShap const int dloss_dweight_h = dloss_dweight_shape.dims(1); const int dloss_dweight_w = dloss_dweight_shape.dims(2); const int dloss_dweight_d = dloss_dweight_shape.dims(3); + const int dloss_dweight_o = dloss_dweight_shape.dims(0); - for (uint32_t oc = 0; oc < dloss_doutput_d; ++oc) + auto depth_bounds = execute::pal::getUpLowerWeightTensorDepth(rank, dloss_doutput_d); + + for (uint32_t oc = 0; oc < dloss_dweight_o; ++oc) { for (uint32_t ic = 0; ic < input_d; ++ic) { diff --git a/onert-micro/onert-micro/include/pal/common/PALFullyConnectedWeightGrad.h b/onert-micro/onert-micro/include/pal/common/PALFullyConnectedWeightGrad.h index 11526495eb1..5f134f0e0df 100644 --- a/onert-micro/onert-micro/include/pal/common/PALFullyConnectedWeightGrad.h +++ b/onert-micro/onert-micro/include/pal/common/PALFullyConnectedWeightGrad.h @@ -30,19 +30,22 @@ namespace train namespace pal { -void inline FullyConnectedWeightGrad(const float *dloss_doutput_data, - const core::OMRuntimeShape &dloss_doutput_shape, - const float *input_data, - const core::OMRuntimeShape &input_shape, - float *dloss_dweight_data) +void inline FullyConnectedWeightGrad( + const float *dloss_doutput_data, const core::OMRuntimeShape &dloss_doutput_shape, + const float *input_data, const core::OMRuntimeShape &input_shape, float *dloss_dweight_data, + const core::OMRuntimeShape &weight_shape, core::OpTrainableRankType rank) { const uint32_t batches = input_shape.dims(0); const uint32_t output_depth = dloss_doutput_shape.dims(1); const uint32_t accum_depth = input_shape.dims(1); - for (uint32_t o = 0; o < output_depth; ++o) + auto depth_bounds = execute::pal::getUpLowerWeightTensorDepth(rank, output_depth); + + auto weight_depth = weight_shape.dims(0); + + for (uint32_t o = 0; o < weight_depth; ++o) { - float cur_dloss_doutput = dloss_doutput_data[o]; + float cur_dloss_doutput = dloss_doutput_data[o + depth_bounds.first]; for (uint32_t i = 0; i < accum_depth; ++i) { dloss_dweight_data[i + o * accum_depth] = cur_dloss_doutput * input_data[i]; @@ -51,7 +54,7 @@ void inline FullyConnectedWeightGrad(const float *dloss_doutput_data, for (int b = 1; b < batches; ++b) { - for (uint32_t o = 0; o < output_depth; ++o) + for (uint32_t o = depth_bounds.first; o < depth_bounds.second; ++o) { float cur_dloss_doutput = dloss_doutput_data[o + b * output_depth]; for (uint32_t i = 0; i < accum_depth; ++i) diff --git a/onert-micro/onert-micro/include/pal/common/PALUtils.h b/onert-micro/onert-micro/include/pal/common/PALUtils.h index d2f02203475..7071fb894c1 100644 --- a/onert-micro/onert-micro/include/pal/common/PALUtils.h +++ b/onert-micro/onert-micro/include/pal/common/PALUtils.h @@ -27,6 +27,29 @@ namespace execute namespace pal { +inline std::pair getUpLowerWeightTensorDepth(core::OpTrainableRankType rank, + const uint32_t output_depth) +{ + std::pair result(0u, output_depth); + + switch (rank) + { + case core::ALL: + break; + case 
core::UP_1_2_PART: + result.second = static_cast(static_cast(output_depth) / 2.f); + break; + case core::LOWER_1_2_PART: + result.first = static_cast(static_cast(output_depth) / 2.f); + break; + default: + assert("Unsupported type"); + break; + } + + return result; +} + // Table of sigmoid(i/24) at 0.16 format - 256 elements. // We use combined sigmoid and tanh look-up table, since // tanh(x) = 2*sigmoid(2*x) -1. diff --git a/onert-micro/onert-micro/include/train/OMBackpropExecuteArgs.h b/onert-micro/onert-micro/include/train/OMBackpropExecuteArgs.h index 8e91d425743..ef2e9c2b967 100644 --- a/onert-micro/onert-micro/include/train/OMBackpropExecuteArgs.h +++ b/onert-micro/onert-micro/include/train/OMBackpropExecuteArgs.h @@ -37,7 +37,9 @@ struct OMBackpropExecuteArgs core::OMRuntimeStorage &backward_storage; core::OMRuntimeContext &backward_context; bool is_last_layer; + bool is_trainable_layer; uint16_t kernel_index; + core::OpTrainableRankType train_rank_type; }; } // namespace train diff --git a/onert-micro/onert-micro/include/train/train_optimizers/Adam.h b/onert-micro/onert-micro/include/train/train_optimizers/Adam.h index 0e5ac4fbda2..4204deb74ae 100644 --- a/onert-micro/onert-micro/include/train/train_optimizers/Adam.h +++ b/onert-micro/onert-micro/include/train/train_optimizers/Adam.h @@ -20,6 +20,7 @@ #include "OMStatus.h" #include "core/OMRuntimeStorage.h" #include "core/OMRuntimeContext.h" +#include "core/OMRuntimeStorage.h" #include #include @@ -53,6 +54,14 @@ class Adam Adam &&operator=(const Adam &&) = delete; ~Adam() { fullReset(); } +#ifdef OM_MEMORY_ESTIMATE + // Reset and deallocate all internal states + void fullReset(core::OMRuntimeContext &context, core::OMRuntimeStorage &storage); + + // Reset only gradients + void reset(core::OMRuntimeContext &context, core::OMRuntimeStorage &storage); +#endif // OM_MEMORY_ESTIMATE + // Reset and deallocate all internal states void fullReset(); @@ -74,10 +83,13 @@ class Adam void setExponentAvgSquaresDataByTensorIndex(uint16_t tensor_index, uint8_t *data); // Update internal states according to Adam theory - OMStatus handle(core::OMRuntimeStorage &backward_storage, core::OMRuntimeContext &context); + OMStatus handle(core::OMRuntimeStorage &backward_storage, core::OMRuntimeContext &context, + core::OMRuntimeStorage &storage); // Update weights according to Adam theory - OMStatus updateWeights(const OMTrainingContext &training_config, core::OMRuntimeContext &context); + OMStatus updateWeights(const OMTrainingContext &training_config, core::OMRuntimeContext &context, + core::OMRuntimeStorage &storage, + std::unordered_map &); }; } // namespace optimizers diff --git a/onert-micro/onert-micro/include/train/train_optimizers/SGD.h b/onert-micro/onert-micro/include/train/train_optimizers/SGD.h index 7b3f13348e8..ee7c8c5e716 100644 --- a/onert-micro/onert-micro/include/train/train_optimizers/SGD.h +++ b/onert-micro/onert-micro/include/train/train_optimizers/SGD.h @@ -48,14 +48,23 @@ class SGD SGD &&operator=(const SGD &&) = delete; ~SGD() { reset(); } +#ifdef OM_MEMORY_ESTIMATE + // Reset and deallocate all internal states + void reset(core::OMRuntimeContext &context, core::OMRuntimeStorage &storage); +#endif // OM_MEMORY_ESTIMATE + // Reset and deallocate all internal states void reset(); // Update internal states according to SGD theory - OMStatus handle(core::OMRuntimeStorage &backward_storage, core::OMRuntimeContext &context); + OMStatus handle(core::OMRuntimeStorage &backward_storage, core::OMRuntimeContext &context, + 
core::OMRuntimeStorage &storage); // Update weights according to SGD theory - OMStatus updateWeights(const OMTrainingContext &training_config, core::OMRuntimeContext &context); + OMStatus updateWeights( + const OMTrainingContext &training_config, core::OMRuntimeContext &context, + core::OMRuntimeStorage &storage, + std::unordered_map &tensor_index_to_rank_type_map); }; } // namespace optimizers diff --git a/onert-micro/onert-micro/src/CMakeLists.txt b/onert-micro/onert-micro/src/CMakeLists.txt index 6c2d7a8a733..6c8f673ab7e 100644 --- a/onert-micro/onert-micro/src/CMakeLists.txt +++ b/onert-micro/onert-micro/src/CMakeLists.txt @@ -17,9 +17,9 @@ set(OM_INCLUDE_TRAIN_DIR "${OM_INCLUDE_DIR}/train") set(OM_SOURCE_DEV_DIR "${OM_SOURCE_DIR}/api") #OM_Interpreter lib binary name -set(OM_INTERPRETER_LIB "onert_micro_interpreter") +set(OM_INTERPRETER_LIB "onert_micro_interpreter${OM_SUFFIX}") #OM_Training_Interpreter lib binary name -set(OM_TRAINING_INTERPRETER_LIB "onert_micro_training_interpreter") +set(OM_TRAINING_INTERPRETER_LIB "onert_micro_training_interpreter${OM_SUFFIX}") #Core lib binary name set(OM_CORE_LIB "onert_micro_core${OM_SUFFIX}") #Execute lib binary name diff --git a/onert-micro/onert-micro/src/core/CMakeLists.txt b/onert-micro/onert-micro/src/core/CMakeLists.txt index 4ed5c11df70..2e11b855d7e 100644 --- a/onert-micro/onert-micro/src/core/CMakeLists.txt +++ b/onert-micro/onert-micro/src/core/CMakeLists.txt @@ -16,7 +16,8 @@ set(SOURCES memory/OMMemoryManager.cpp memory/OMRuntimeAllocator.cpp reader/OMCircleReader.cpp - reader/OMWeightOnlyFormatReader.cpp) + reader/OMWeightOnlyFormatReader.cpp + reader/OMTrainingConfigFileReader.cpp) add_library(${OM_CORE_LIB} STATIC ${SOURCES}) diff --git a/onert-micro/onert-micro/src/core/OMRuntimeContext.cpp b/onert-micro/onert-micro/src/core/OMRuntimeContext.cpp index e746d4cbd74..bc8ffb73c3d 100644 --- a/onert-micro/onert-micro/src/core/OMRuntimeContext.cpp +++ b/onert-micro/onert-micro/src/core/OMRuntimeContext.cpp @@ -22,7 +22,7 @@ using namespace onert_micro::core; const circle::Operator *OMRuntimeContext::getCircleOperatorAt(uint16_t index) { const auto *operators = _reader.operators(); - + assert(index < operators->size()); return operators->operator[](index); } diff --git a/onert-micro/onert-micro/src/core/OMRuntimeGraph.cpp b/onert-micro/onert-micro/src/core/OMRuntimeGraph.cpp index 1d7fb200d24..f4ba276703b 100644 --- a/onert-micro/onert-micro/src/core/OMRuntimeGraph.cpp +++ b/onert-micro/onert-micro/src/core/OMRuntimeGraph.cpp @@ -16,6 +16,7 @@ #include "core/OMRuntimeGraph.h" #include "core/OMDataType.h" +#include "core/memory/OMMemoryManager.h" #include "OMStatus.h" using namespace onert_micro::core; @@ -29,7 +30,14 @@ OMStatus OMRuntimeGraph::reset() return status; } -OMRuntimeGraph::~OMRuntimeGraph() { reset(); } +OMRuntimeGraph::~OMRuntimeGraph() +{ + reset(); +#ifdef OM_MEMORY_ESTIMATE + memory::OMMemoryManager::cur_memory_allocated = 0; + memory::OMMemoryManager::peak_memory_allocated = 0; +#endif // OM_MEMORY_ESTIMATE +} void *OMRuntimeGraph::getInputDataAt(uint32_t position) { diff --git a/onert-micro/onert-micro/src/core/OMRuntimeModule.cpp b/onert-micro/onert-micro/src/core/OMRuntimeModule.cpp index 166d2e89ca0..857d7413240 100644 --- a/onert-micro/onert-micro/src/core/OMRuntimeModule.cpp +++ b/onert-micro/onert-micro/src/core/OMRuntimeModule.cpp @@ -94,14 +94,31 @@ OMStatus OMRuntimeModule::importModel(const char *model_ptr, const OMConfig &con if (config.wof_ptr != nullptr) runtime_context.setWofFile(config.wof_ptr); + // 
Parse and validate Train Config File if it is exists + // WARNING: setTrainConfigFile method of RuntimeContext should follow after setModel. + if (config.train_mode and config.training_context.training_config_info_data != nullptr) + runtime_context.setTrainConfigFile(config.training_context.training_config_info_data); + // Third - optimize it until can status = optimize::OMOptimizer::optimize(runtime_storage, runtime_context, config); if (status != Ok) return status; // 4 - AllocDeallocPlan creation - import::OMExecutionPlanCreator::createExecutionPlan(runtime_storage, runtime_context, - runtime_allocator, config); + if (not config.train_mode) + { + // Non trainable mode + status = import::OMExecutionPlanCreator::createExecutionPlan(runtime_storage, runtime_context, + runtime_allocator, config); + } + else + { + // Trainable mode + status = import::OMExecutionPlanCreator::createForwardExecutionPlan( + runtime_storage, runtime_context, runtime_allocator, config); + } + if (status != Ok) + return status; } for (uint32_t i = 0; i < num_subgraph; ++i) { diff --git a/onert-micro/onert-micro/src/core/OMTrainingRuntimeModule.cpp b/onert-micro/onert-micro/src/core/OMTrainingRuntimeModule.cpp index 35ee91a47ed..f5645806516 100644 --- a/onert-micro/onert-micro/src/core/OMTrainingRuntimeModule.cpp +++ b/onert-micro/onert-micro/src/core/OMTrainingRuntimeModule.cpp @@ -16,10 +16,15 @@ #include "core/OMTrainingRuntimeModule.h" #include "import/OMExecutionPlanCreator.h" +#include "import/OMDynamicShapesHandler.h" #include "train/OMBackpropExecute.h" #include "core/train/OMCheckpointSaver.h" #include "core/train/OMCheckpointLoader.h" +#ifdef OM_MEMORY_ESTIMATE +#include "core/memory/OMMemoryManager.h" +#endif // OM_MEMORY_ESTIMATE + using namespace onert_micro::core; using namespace onert_micro; @@ -54,11 +59,21 @@ OMStatus OMTrainingRuntimeModule::importTrainModel(char *model_ptr, const OMConf if (config.wof_ptr != nullptr) runtime_context.setWofFile(config.wof_ptr); + // Parse and validate Train Config File if it is exists + // WARNING: setTrainConfigFile method of RuntimeContext should follow after setModel. + if (config.train_mode and config.training_context.training_config_info_data != nullptr) + runtime_context.setTrainConfigFile(config.training_context.training_config_info_data); + // AllocDeallocPlan backward graph creation status = import::OMExecutionPlanCreator::createBackwardExecutionPlan( runtime_storage, runtime_context, runtime_allocator, config); if (status != Ok) return status; + + // Set tensor to train rank type + auto &train_storage = _training_handler.getTrainingStorage(); + import::OMDynamicShapesHandler::importDynamicShapesFromTrainConfig( + runtime_storage, runtime_context, train_storage); } // Set current optimizer @@ -144,7 +159,7 @@ OMStatus OMTrainingRuntimeModule::trainSingleStep(OMConfig &config) // d. 
Run backward graph { onert_micro::train::OMBackpropExecuteArgs backprop_execute_args = { - forward_storage, backward_storage, backward_context, false, 0}; + forward_storage, backward_storage, backward_context, false, false, 0, ALL}; status = onert_micro::train::OMBackpropExecute::runBackward( config, backprop_execute_args, backward_graph.getRuntimeAllocator()); @@ -173,7 +188,8 @@ OMStatus OMTrainingRuntimeModule::trainSingleStep(OMConfig &config) // Get backward context OMRuntimeGraph &backward_graph = _backward_graphs.at(i); OMRuntimeContext &backward_context = backward_graph.getRuntimeContext(); - status = _training_handler.updateWeights(config, backward_context); + OMRuntimeStorage &backward_storage = backward_graph.getRuntimeStorage(); + status = _training_handler.updateWeights(config, backward_context, backward_storage); } return status; @@ -260,8 +276,13 @@ OMStatus OMTrainingRuntimeModule::reset() { graph.reset(); } - +#ifdef OM_MEMORY_ESTIMATE + auto &context = _backward_graphs.begin()->getRuntimeContext(); + auto &storage = _backward_graphs.begin()->getRuntimeStorage(); + _training_handler.reset(context, storage); +#elif _training_handler.reset(); +#endif // OM_MEMORY_ESTIMATE return status; } @@ -307,3 +328,18 @@ void *OMTrainingRuntimeModule::getInputData(int32_t index) { return _training_handler.getInputData(index); } + +#ifdef OM_MEMORY_ESTIMATE + +size_t OMTrainingRuntimeModule::getPeakFootprintMemory() +{ + return std::max(memory::OMMemoryManager::peak_memory_allocated, + memory::OMMemoryManager::cur_memory_allocated); +} + +size_t OMTrainingRuntimeModule::getCurrentFootprintMemory() +{ + return memory::OMMemoryManager::cur_memory_allocated; +} + +#endif // OM_MEMORY_ESTIMATE diff --git a/onert-micro/onert-micro/src/core/memory/OMMemoryManager.cpp b/onert-micro/onert-micro/src/core/memory/OMMemoryManager.cpp index 7ecbae1f4b0..da19ba10374 100644 --- a/onert-micro/onert-micro/src/core/memory/OMMemoryManager.cpp +++ b/onert-micro/onert-micro/src/core/memory/OMMemoryManager.cpp @@ -16,20 +16,45 @@ #include "core/memory/OMMemoryManager.h" +#include + using namespace onert_micro::core::memory; using namespace onert_micro; +size_t OMMemoryManager::peak_memory_allocated = 0; +size_t OMMemoryManager::cur_memory_allocated = 0; + OMStatus OMMemoryManager::allocateMemory(uint32_t size, uint8_t **data) { if (size == 0) return UnknownError; auto data_tmp = new uint8_t[size]; +#ifdef OM_MEMORY_ESTIMATE + + cur_memory_allocated += size; + + peak_memory_allocated = std::max(cur_memory_allocated, peak_memory_allocated); + +#endif // OM_MEMORY_ESTIMATE + *data = data_tmp; return Ok; } +#ifdef OM_MEMORY_ESTIMATE +OMStatus OMMemoryManager::deallocateMemory(uint32_t size, uint8_t *data) +{ + if (int32_t(cur_memory_allocated) - int32_t(size) < 0 and data != nullptr) + peak_memory_allocated = std::max(cur_memory_allocated, peak_memory_allocated); + cur_memory_allocated -= data != nullptr ? 
size : 0; + + delete[] data; + return Ok; +} +#endif // OM_MEMORY_ESTIMATE + OMStatus OMMemoryManager::deallocateMemory(uint8_t *data) { delete[] data; diff --git a/onert-micro/onert-micro/src/core/memory/OMRuntimeAllocator.cpp b/onert-micro/onert-micro/src/core/memory/OMRuntimeAllocator.cpp index fbb89a1d9aa..dab7b2aa54a 100644 --- a/onert-micro/onert-micro/src/core/memory/OMRuntimeAllocator.cpp +++ b/onert-micro/onert-micro/src/core/memory/OMRuntimeAllocator.cpp @@ -30,10 +30,25 @@ OMStatus OMRuntimeAllocator::clearAllTensorsData(OMRuntimeContext *context, for (auto &cur_tensor_index_data : tensor_index_to_data) { - auto tensor_index = cur_tensor_index_data.first; uint8_t *allocated_data = cur_tensor_index_data.second; +#ifdef OM_MEMORY_ESTIMATE + auto tensor_index = cur_tensor_index_data.first; + + auto tensor = context->getTensorByIndex(tensor_index); + auto num_elements = OMRuntimeShape(tensor).flatSize(); + +#ifndef DIS_DYN_SHAPES + int32_t dynamic_tensor_size = storage->getDynamicRuntimeShape(tensor_index).flatSize(); + if (dynamic_tensor_size != 0) + num_elements = dynamic_tensor_size; +#endif // DIS_DYN_SHAPES + + auto tensor_size = num_elements * sizeof(OMDataType(tensor->type())); + OMMemoryManager::deallocateMemory(tensor_size, allocated_data); +#elif OMMemoryManager::deallocateMemory(allocated_data); +#endif // OM_MEMORY_ESTIMATE } return Ok; @@ -56,13 +71,13 @@ OMStatus OMRuntimeAllocator::allocate(size_t kernel_index, OMRuntimeContext *con int32_t num_elements = tensor_shape.flatSize(); #ifndef DIS_DYN_SHAPES - int32_t dynamic_tensor_size = storage->getDynamicTensorSize(tensor_index); - if (dynamic_tensor_size != -1) + int32_t dynamic_tensor_size = storage->getDynamicRuntimeShape(tensor_index).flatSize(); + if (dynamic_tensor_size != 0) num_elements = dynamic_tensor_size; #endif // DIS_DYN_SHAPES - assert(num_elements >= 0 && "Num elements should be positive"); - if (num_elements < 0) + assert(num_elements > 0 && "Num elements should be greater zero"); + if (num_elements <= 0) return UnknownError; const auto casted_num_elements = static_cast(num_elements); const auto type_size = @@ -86,6 +101,50 @@ OMStatus OMRuntimeAllocator::allocate(size_t kernel_index, OMRuntimeContext *con return Ok; } +#ifdef OM_MEMORY_ESTIMATE +OMStatus OMRuntimeAllocator::deallocate(size_t kernel_index, OMRuntimeStorage *storage, + OMRuntimeContext *context) +{ + assert(kernel_index < _alloc_plan.size() && "Wrong kernel index"); + if (kernel_index >= _alloc_plan.size()) + return UnknownError; + + const std::vector ¤t_deallocate_plan = _dealloc_plan[kernel_index]; + + for (const uint16_t tensor_index : current_deallocate_plan) + { + uint8_t *allocated_data = nullptr; + OMStatus status = storage->getDataByTensorIndex(&allocated_data, tensor_index); + // To continue deallocate due to current tensor is not saved in storage + if (allocated_data == nullptr) + continue; + if (status != Ok) + return status; + + auto tensor = context->getTensorByIndex(tensor_index); + auto num_elements = OMRuntimeShape(tensor).flatSize(); + +#ifndef DIS_DYN_SHAPES + int32_t dynamic_tensor_size = storage->getDynamicRuntimeShape(tensor_index).flatSize(); + if (dynamic_tensor_size != 0) + num_elements = dynamic_tensor_size; +#endif // DIS_DYN_SHAPES + + auto tensor_size = num_elements * sizeof(OMDataType(tensor->type())); + status = OMMemoryManager::deallocateMemory(tensor_size, allocated_data); + if (status != Ok) + return status; + + status = storage->removeTensorFromTensorIndexToData(tensor_index); + if (status != Ok) + 
return status; + } + + return Ok; +} + +#endif // OM_MEMORY_ESTIMATE + OMStatus OMRuntimeAllocator::deallocate(size_t kernel_index, OMRuntimeStorage *storage) { assert(kernel_index < _alloc_plan.size() && "Wrong kernel index"); @@ -139,7 +198,18 @@ OMStatus OMRuntimeAllocator::allocateGraphInputs(OMRuntimeContext *context, // First clear if already allocated status = storage->getDataByTensorIndex(&allocated_data, tensor_index); +#ifdef OM_MEMORY_ESTIMATE +#ifndef DIS_DYN_SHAPES + int32_t dynamic_tensor_size = storage->getDynamicRuntimeShape(tensor_index).flatSize(); + if (dynamic_tensor_size != 0) + num_elements = dynamic_tensor_size; +#endif // DIS_DYN_SHAPES + + auto tensor_size = num_elements * sizeof(OMDataType(tensor->type())); + OMMemoryManager::deallocateMemory(tensor_size, allocated_data); +#elif OMMemoryManager::deallocateMemory(allocated_data); +#endif // OM_MEMORY_ESTIMATE // Then Allocate status = OMMemoryManager::allocateMemory(casted_num_elements * type_size, &allocated_data); diff --git a/onert-micro/onert-micro/src/core/reader/OMTrainingConfigFileReader.cpp b/onert-micro/onert-micro/src/core/reader/OMTrainingConfigFileReader.cpp new file mode 100644 index 00000000000..4ca8903c62b --- /dev/null +++ b/onert-micro/onert-micro/src/core/reader/OMTrainingConfigFileReader.cpp @@ -0,0 +1,97 @@ +/* + * Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
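
A short sketch of how the OM_MEMORY_ESTIMATE counters evolve, based on the allocateMemory/deallocateMemory changes above; the byte counts are arbitrary and footprintTrace is a hypothetical helper.

#include "core/memory/OMMemoryManager.h"
#include <cstdint>

// Only meaningful when onert-micro is built with -DOM_MEMORY_ESTIMATE, which
// enables the size-aware deallocateMemory overload and the two counters.
void footprintTrace()
{
  using onert_micro::core::memory::OMMemoryManager;

  uint8_t *a = nullptr;
  uint8_t *b = nullptr;
  OMMemoryManager::allocateMemory(100, &a);  // cur = 100, peak = 100
  OMMemoryManager::allocateMemory(50, &b);   // cur = 150, peak = 150
  OMMemoryManager::deallocateMemory(100, a); // cur = 50,  peak = 150
  OMMemoryManager::deallocateMemory(50, b);  // cur = 0,   peak = 150
  // OMTrainingRuntimeModule::getPeakFootprintMemory() would now report 150.
}
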
+ */ + +#include "core/reader/OMTrainingConfigFileReader.h" + +using namespace onert_micro::core::reader; +using namespace onert_micro; + +/* + * Validate method for Weight Only Format file to check its correctness + */ +OMStatus OMTrainingConfigReader::validate(OMCircleReader *reader) +{ + OMStatus status = Ok; + + // Validate magic number + uint16_t mag_num = 0; + std::memcpy(&mag_num, &_train_config_ptr[MAGIC_NUMBER_FIELD], sizeof(mag_num)); + assert(mag_num == train_config_file_magic_number && + "False MAGIC NUMBER, check correctness of wof file"); + if (mag_num != train_config_file_magic_number) + return FailReadWOFFile; + + // Validate schema version + uint8_t version = 0; + std::memcpy(&version, &_train_config_ptr[SCHEMA_VERSION_FIELD], sizeof(version)); + assert(version == train_config_file_schema_version && + "False MAGIC NUMBER, check correctness of wof file"); + if (version != train_config_file_schema_version) + return FailReadWOFFile; + + // Validate count of ops is not greater than current model has + assert(reader != nullptr && "Reader should exist"); + if (reader == nullptr) + return ModelNotImport; + uint32_t num_ops = reader->operators()->size(); + uint32_t num_ops_in_file = 0; + std::memcpy(&num_ops_in_file, &_train_config_ptr[NUM_LAYERS_FIELD], sizeof(num_ops_in_file)); + assert(num_ops_in_file > 0 and num_ops >= num_ops_in_file && + "Number of operators in circle should be greater than train config file has"); + if (num_ops_in_file > 0 and num_ops < num_ops_in_file) + return FailReadWOFFile; + + return status; +} +/* + * Read and return indexes of trainable layers from config file + */ +std::unordered_map OMTrainingConfigReader::getTrainableOpsIndexes() +{ + std::unordered_map result; + + // If reader is not parsed then return empty vector + if (_train_config_ptr == nullptr) + return result; + + // Read number of ops + uint32_t num_ops_in_file = 0; + std::memcpy(&num_ops_in_file, &_train_config_ptr[NUM_LAYERS_FIELD], sizeof(num_ops_in_file)); + + assert(num_ops_in_file > 0); + // Obtain pointer to the first layer index position in the file + char *cur_op_index_ptr = &_train_config_ptr[FIRST_LAYER_INDEX_FIELD]; + char *cur_op_train_rank_ptr = + &_train_config_ptr[FIRST_LAYER_INDEX_FIELD + sizeof(uint16_t) * num_ops_in_file]; + // Fill result set with indexes and its rank + for (uint32_t i = 0; i < num_ops_in_file; ++i) + { + // Read op index + uint16_t cur_op_index; + std::memcpy(&cur_op_index, cur_op_index_ptr, sizeof(cur_op_index)); + cur_op_index_ptr += sizeof(cur_op_index); + + // Read op train rank + uint8_t cur_op_train_rank; + std::memcpy(&cur_op_train_rank, cur_op_train_rank_ptr, sizeof(cur_op_train_rank)); + cur_op_train_rank_ptr += sizeof(cur_op_train_rank); + + // Insert op index and op rank + result[cur_op_index] = cur_op_train_rank; + } + + return result; +} diff --git a/onert-micro/onert-micro/src/core/train/OMTrainingHandler.cpp b/onert-micro/onert-micro/src/core/train/OMTrainingHandler.cpp index 24535625a02..140807cb836 100644 --- a/onert-micro/onert-micro/src/core/train/OMTrainingHandler.cpp +++ b/onert-micro/onert-micro/src/core/train/OMTrainingHandler.cpp @@ -99,7 +99,8 @@ OMStatus OMTrainingHandler::handleError(const OMConfig &config, OMRuntimeStorage /* * Update weights with current optimizer logic */ -OMStatus OMTrainingHandler::updateWeights(const OMConfig &config, OMRuntimeContext &context) +OMStatus OMTrainingHandler::updateWeights(const OMConfig &config, OMRuntimeContext &context, + OMRuntimeStorage &storage) { OMStatus status = Ok; @@ -113,10 
+114,15 @@ OMStatus OMTrainingHandler::updateWeights(const OMConfig &config, OMRuntimeConte if (sgd_optimizer == nullptr) return UnknownError; - status = sgd_optimizer->updateWeights(config.training_context, context); + status = sgd_optimizer->updateWeights(config.training_context, context, storage, + _training_storage.getTensorIndexToRankTypeTable()); assert(status == Ok); // Reset values +#ifdef OM_MEMORY_ESTIMATE + sgd_optimizer->reset(context, storage); +#elif sgd_optimizer->reset(); +#endif // OM_MEMORY_ESTIMATE break; } case ADAM: @@ -125,11 +131,15 @@ OMStatus OMTrainingHandler::updateWeights(const OMConfig &config, OMRuntimeConte assert(adam_optimizer != nullptr); if (adam_optimizer == nullptr) return UnknownError; - - status = adam_optimizer->updateWeights(config.training_context, context); + status = adam_optimizer->updateWeights(config.training_context, context, storage, + _training_storage.getTensorIndexToRankTypeTable()); assert(status == Ok); // Reset values +#ifdef OM_MEMORY_ESTIMATE + adam_optimizer->reset(context, storage); +#elif adam_optimizer->reset(); +#endif // OM_MEMORY_ESTIMATE break; } default: @@ -163,7 +173,7 @@ OMStatus OMTrainingHandler::updateOptimizerState(const OMConfig &config, if (sgd_optimizer == nullptr) return UnknownError; - sgd_optimizer->handle(backward_storage, context); + sgd_optimizer->handle(backward_storage, context, backward_storage); break; } case ADAM: @@ -173,7 +183,7 @@ OMStatus OMTrainingHandler::updateOptimizerState(const OMConfig &config, if (adam_optimizer == nullptr) return UnknownError; - adam_optimizer->handle(backward_storage, context); + adam_optimizer->handle(backward_storage, context, backward_storage); break; } default: @@ -188,6 +198,13 @@ OMStatus OMTrainingHandler::updateOptimizerState(const OMConfig &config, void OMTrainingHandler::reset() { _training_storage.reset(); } +#ifdef OM_MEMORY_ESTIMATE +void OMTrainingHandler::reset(core::OMRuntimeContext &context, core::OMRuntimeStorage &storage) +{ + _training_storage.reset(context, storage); +} +#endif // OM_MEMORY_ESTIMATE + /* * Evaluate metric according OMMetrics and save it into metric_val * diff --git a/onert-micro/onert-micro/src/core/train/OMTrainingStorage.cpp b/onert-micro/onert-micro/src/core/train/OMTrainingStorage.cpp index b3167facde6..7a53411eb5f 100644 --- a/onert-micro/onert-micro/src/core/train/OMTrainingStorage.cpp +++ b/onert-micro/onert-micro/src/core/train/OMTrainingStorage.cpp @@ -55,6 +55,23 @@ void OMTrainingStorage::reset() if (_sgd_optimizer) _sgd_optimizer->reset(); + if (_adam_optimizer) + _adam_optimizer->fullReset(); + + _target_index_to_target_data.clear(); + _input_index_to_input_data.clear(); +} + +#ifdef OM_MEMORY_ESTIMATE +void OMTrainingStorage::reset(core::OMRuntimeContext &context, core::OMRuntimeStorage &storage) +{ + if (_sgd_optimizer) + _sgd_optimizer->reset(context, storage); + + if (_adam_optimizer) + _adam_optimizer->fullReset(context, storage); + _target_index_to_target_data.clear(); _input_index_to_input_data.clear(); } +#endif // OM_MEMORY_ESTIMATE diff --git a/onert-micro/onert-micro/src/execute/OMKernelExecute.cpp b/onert-micro/onert-micro/src/execute/OMKernelExecute.cpp index b01ba2234f0..c907baae22c 100644 --- a/onert-micro/onert-micro/src/execute/OMKernelExecute.cpp +++ b/onert-micro/onert-micro/src/execute/OMKernelExecute.cpp @@ -81,7 +81,11 @@ OMStatus OMKernelExecute::runForward(OMExecuteArgs &execute_args, if (status != Ok) return status; +#ifdef OM_MEMORY_ESTIMATE + status = allocator.deallocate(i, &storage, &context); 
+#elif status = allocator.deallocate(i, &storage); +#endif // OM_MEMORY_ESTIMATE } return status; diff --git a/onert-micro/onert-micro/src/import/CMakeLists.txt b/onert-micro/onert-micro/src/import/CMakeLists.txt index ee8ea4d5ffc..e865ada32ff 100644 --- a/onert-micro/onert-micro/src/import/CMakeLists.txt +++ b/onert-micro/onert-micro/src/import/CMakeLists.txt @@ -4,6 +4,7 @@ set(SOURCES OMExecutionPlanCreator.cpp OMKernelConfiguration.cpp OMKernelConfigureBuilder.cpp + OMDynamicShapesHandler.cpp helpers/OMConfigureSISOKernel.cpp helpers/OMPadCommon.cpp helpers/OMConfigureTISOKernel.cpp diff --git a/onert-micro/onert-micro/src/import/OMDynamicShapesHandler.cpp b/onert-micro/onert-micro/src/import/OMDynamicShapesHandler.cpp new file mode 100644 index 00000000000..9297743c92a --- /dev/null +++ b/onert-micro/onert-micro/src/import/OMDynamicShapesHandler.cpp @@ -0,0 +1,123 @@ +/* + * Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "import/OMDynamicShapesHandler.h" +#include "core/OMKernelType.h" + +using namespace onert_micro::core; +using namespace onert_micro::import; +using namespace onert_micro; + +namespace +{ + +// Function to obtain index for current operation which has weight tensor, which will be sparse +// updated If operation don't support such behaviour - return -1 +int32_t getWeightTensorIndexForOperatorWithOpcode(const circle::OperatorCode *opcode) +{ + switch (opcode->builtin_code()) + { + // For Conv2D and for FullyConnected return index for weight tensor - is 1 + case circle::BuiltinOperator_CONV_2D: + case circle::BuiltinOperator_FULLY_CONNECTED: + return 1; + default: + break; + } + return -1; +} + +OMRuntimeShape createDynamicRuntimeShapeForOperator(OMRuntimeShape shape, + const circle::OperatorCode *opcode, + const float partition_size) +{ + assert(partition_size > 1.0); + switch (opcode->builtin_code()) + { + // For Conv2D and for FullyConnected return index for weight tensor - is 1 + case circle::BuiltinOperator_CONV_2D: + case circle::BuiltinOperator_FULLY_CONNECTED: + { + auto first_dim_val = shape.dims(0); + assert(partition_size <= static_cast(first_dim_val)); + assert(partition_size > 0); + if (partition_size == 0) + return shape; + first_dim_val = static_cast(static_cast(first_dim_val) / partition_size); + assert(first_dim_val > 0); + shape.setDim(0, first_dim_val); + } + break; + default: + break; + } + + return shape; +} + +} // namespace + +/* + * Import dynamic shapes from train config file data: + * Some tensors can have sparse tensor backpropagation scheme (train rank) + */ +OMStatus OMDynamicShapesHandler::importDynamicShapesFromTrainConfig( + core::OMRuntimeStorage &storage, core::OMRuntimeContext &context, + core::train::OMTrainingStorage &train_storage) +{ + std::unordered_map train_op_indexes_to_train_rank = + context.getTrainableOpsIndexes(); + const auto opcodes = context.getCircleOpcodes(); + + // Goes over pairs of op index and train rank value + for (auto 
&p : train_op_indexes_to_train_rank) + { + const uint16_t op_index = p.first; + const auto train_rank = static_cast(p.second); + + switch (train_rank) + { + case core::LOWER_1_2_PART: + case core::UP_1_2_PART: + { + const auto cur_op = context.getCircleOperatorAt(op_index); + const auto opcode = opcodes->operator[](cur_op->opcode_index()); + + int32_t res_index = getWeightTensorIndexForOperatorWithOpcode(opcode); + // The operation doesn't support such behaviour + if (res_index == -1) + continue; + + auto tensor_local_index = static_cast(res_index); + auto tensor_index = cur_op->inputs()->operator[](tensor_local_index); + auto tensor = context.getTensorByIndex(tensor_index); + OMRuntimeShape old_shape(tensor); + const float partition_size = 2.f; + OMRuntimeShape new_shape = + createDynamicRuntimeShapeForOperator(old_shape, opcode, partition_size); + storage.setDynamicRuntimeShape(tensor_index, new_shape); + + train_storage.addTrainRank(tensor_index, train_rank); + + break; + } + default: + continue; + } + } + + return Ok; +} diff --git a/onert-micro/onert-micro/src/import/OMExecutionPlanCreator.cpp b/onert-micro/onert-micro/src/import/OMExecutionPlanCreator.cpp index fb85184597d..a83efa06fa1 100644 --- a/onert-micro/onert-micro/src/import/OMExecutionPlanCreator.cpp +++ b/onert-micro/onert-micro/src/import/OMExecutionPlanCreator.cpp @@ -40,10 +40,43 @@ bool isTrainableWeights(const circle::OperatorCode *opcode) } } +bool isOpNeedSaveOutputData(const circle::OperatorCode *opcode, const circle::Operator *cur_op) +{ + switch (opcode->builtin_code()) + { + case circle::BuiltinOperator_FULLY_CONNECTED: + { + if (cur_op->builtin_options_as_FullyConnectedOptions() != nullptr) + return true; + } + break; + case circle::BuiltinOperator_CONV_2D: + { + if (cur_op->builtin_options_as_Conv2DOptions() != nullptr) + return true; + } + break; + default: + return false; + } + return false; +} + +bool isOpNeedSaveInputData(const circle::OperatorCode *opcode) +{ + switch (opcode->builtin_code()) + { + case circle::BuiltinOperator_MAX_POOL_2D: + return true; + default: + return false; + } +} + } // namespace /* - * Create execution plan for forward graph + * Create execution plan for graph for non trainable mode * TODO: describe creation execution plan logic */ OMStatus OMExecutionPlanCreator::createExecutionPlan(core::OMRuntimeStorage &runtime_storage, @@ -51,12 +84,117 @@ OMStatus OMExecutionPlanCreator::createExecutionPlan(core::OMRuntimeStorage &run core::memory::OMRuntimeAllocator &allocator, const OMConfig &configs) { + // Check is non trainable mode + assert(configs.train_mode != true); + if (configs.train_mode == true) + return UnknownError; + bool keep_input = configs.keep_input; - bool train_mode = configs.train_mode; std::vector> &alloc_plan = allocator.getAllocPlan(); std::vector> &dealloc_plan = allocator.getDeallocPlan(); + // First remove prev plan (if it was created) + alloc_plan.clear(); + dealloc_plan.clear(); + + using Lifetime = std::pair; + + std::map lifetimes; + + const reader::CircleOperators *operators = runtime_context.getCircleOperators(); + + const size_t num_kernels = operators->size(); + + if (not keep_input) + { + auto graph_inputs = runtime_context.getCircleInputs(); + for (const auto input_ind : *graph_inputs) + { + assert(lifetimes.count(input_ind) == 0); + lifetimes[input_ind] = Lifetime(-1, 0); + } + } + + for (int32_t index = 0; index < num_kernels; ++index) + { + auto *cur_op = operators->operator[](index); + + const auto *op_inputs = cur_op->inputs(); + const auto 
*op_outputs = cur_op->outputs(); + auto kernel_type = runtime_storage.getKernelType(index); + for (int32_t j = 0; j < op_inputs->size(); ++j) + { + const auto input_index = op_inputs->operator[](j); + + if (input_index == -1) + continue; + + // Pass constant tensors + if (runtime_context.isConstTensor(input_index)) + continue; + + if (lifetimes.count(input_index) > 0) + { + if (kernel_type == Inplace) + lifetimes.at(input_index).second = -1; + else + lifetimes.at(input_index).second = index; + } + } + + for (int32_t j = 0; j < op_outputs->size(); ++j) + { + const auto output_index = op_outputs->operator[](j); + + if (kernel_type == Inplace) + lifetimes[output_index] = Lifetime(-1, index); + else + lifetimes[output_index] = Lifetime(index, index); + } + } + auto graph_outputs = runtime_context.getCircleOutputs(); + for (const auto output_ind : *graph_outputs) + { + if (lifetimes.count(output_ind) > 0) + lifetimes.at(output_ind).second = static_cast(num_kernels); + } + + alloc_plan.assign(num_kernels, std::vector()); + dealloc_plan.assign(num_kernels + 1, std::vector()); + + for (const auto &item : lifetimes) + { + if (item.second.first != -1) + alloc_plan[item.second.first].push_back(item.first); + if (item.second.second != -1) + dealloc_plan[item.second.second].push_back(item.first); + } + + return Ok; +} + +/* + * Create execution plan for graph for non trainable mode + * TODO: describe creation execution plan logic + */ +OMStatus OMExecutionPlanCreator::createForwardExecutionPlan( + core::OMRuntimeStorage &runtime_storage, core::OMRuntimeContext &runtime_context, + core::memory::OMRuntimeAllocator &allocator, const OMConfig &configs) +{ + // Check is trainable mode + assert(configs.train_mode == true); + if (configs.train_mode != true) + return UnknownError; + + bool keep_input = configs.keep_input; + std::vector> &alloc_plan = allocator.getAllocPlan(); + std::vector> &dealloc_plan = allocator.getDeallocPlan(); + + // First remove prev plan (if it was created) + alloc_plan.clear(); + dealloc_plan.clear(); + using Lifetime = std::pair; std::map lifetimes; @@ -66,9 +204,27 @@ OMStatus OMExecutionPlanCreator::createExecutionPlan(core::OMRuntimeStorage &run const size_t num_kernels = operators->size(); uint32_t num_train_layers = configs.training_context.num_of_train_layers; - if (train_mode and num_train_layers == 0) + if (num_train_layers == 0) num_train_layers = num_kernels; + std::unordered_map trainable_ops_config = + runtime_context.getTrainableOpsIndexes(); + + // If context has config file defined trainable operations + // than ignore configs.training_context.num_of_train_layers value + // and use max value from trainable_ops_indexes to define last train op + uint16_t last_train_op_indx = num_kernels - num_train_layers; + if (!trainable_ops_config.empty()) + { + last_train_op_indx = std::numeric_limits::max(); + // Find op trainable index with min value + for (auto &p : trainable_ops_config) + { + last_train_op_indx = std::min(p.first, last_train_op_indx); + } + num_train_layers = (num_kernels - last_train_op_indx); + } + if (not keep_input) { auto graph_inputs = runtime_context.getCircleInputs(); @@ -79,6 +235,8 @@ OMStatus OMExecutionPlanCreator::createExecutionPlan(core::OMRuntimeStorage &run } } + const auto *op_codes = runtime_context.getCircleOpcodes(); + for (int32_t index = 0; index < num_kernels; ++index) { auto *cur_op = operators->operator[](index); @@ -86,6 +244,29 @@ OMStatus OMExecutionPlanCreator::createExecutionPlan(core::OMRuntimeStorage &run const auto *op_inputs = 
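// Worked example (illustrative values) of the boundary computation above:
// when the training config file defines trainable operations, the smallest
// trainable op index becomes the start of the training tail and overrides
// configs.training_context.num_of_train_layers.
#include <algorithm>
#include <cstdint>
#include <limits>
#include <unordered_map>

int main()
{
  const uint32_t num_kernels = 10;
  // op index -> train rank (e.g. 7 -> ALL, 9 -> ONLY_BIAS)
  const std::unordered_map<uint16_t, uint8_t> trainable_ops_config{{7, 0}, {9, 1}};

  uint16_t last_train_op_indx = std::numeric_limits<uint16_t>::max();
  for (const auto &p : trainable_ops_config)
    last_train_op_indx = std::min(p.first, last_train_op_indx);

  const uint32_t num_train_layers = num_kernels - last_train_op_indx;
  // last_train_op_indx == 7 and num_train_layers == 3: operators 7, 8 and 9
  // are planned as the trainable part of the graph.
  return static_cast<int>(num_train_layers);
}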
cur_op->inputs(); const auto *op_outputs = cur_op->outputs(); auto kernel_type = runtime_storage.getKernelType(index); + + uint32_t cur_opcode_index = cur_op->opcode_index(); + + assert(cur_opcode_index < op_codes->size()); + + const auto opcode = op_codes->operator[](cur_opcode_index); + + // Flag to determine is current operation needed to save input data (is this op in training part + // of the graph) + bool need_to_save_input_data = + (index >= last_train_op_indx) and + ((trainable_ops_config.find(index) != trainable_ops_config.end() and + trainable_ops_config[index] != ONLY_BIAS) or + isOpNeedSaveInputData(opcode)); + + // Flag to determine is current operation needed to save output data (is this op in training + // part of the graph) + bool need_to_save_output_data = + (index >= last_train_op_indx) and + ((trainable_ops_config.find(index) != trainable_ops_config.end() and + trainable_ops_config[index] != ONLY_BIAS) or + isOpNeedSaveOutputData(opcode, cur_op)); + for (int32_t j = 0; j < op_inputs->size(); ++j) { const auto input_index = op_inputs->operator[](j); @@ -99,7 +280,9 @@ OMStatus OMExecutionPlanCreator::createExecutionPlan(core::OMRuntimeStorage &run if (lifetimes.count(input_index) > 0) { - if (kernel_type == Inplace or train_mode and index >= (num_kernels - num_train_layers)) + // lifetimes.at(input_index).second == -2 - Means need to save data for input_index tensor + if (kernel_type == Inplace or need_to_save_input_data or + (lifetimes.at(input_index).second == -2)) lifetimes.at(input_index).second = -1; else lifetimes.at(input_index).second = index; @@ -112,8 +295,8 @@ OMStatus OMExecutionPlanCreator::createExecutionPlan(core::OMRuntimeStorage &run if (kernel_type == Inplace) lifetimes[output_index] = Lifetime(-1, index); - else if (train_mode and index >= (num_kernels - num_train_layers)) - lifetimes[output_index] = Lifetime(index, -1); + else if (need_to_save_output_data) + lifetimes[output_index] = Lifetime(index, -2); else lifetimes[output_index] = Lifetime(index, index); } @@ -130,9 +313,9 @@ OMStatus OMExecutionPlanCreator::createExecutionPlan(core::OMRuntimeStorage &run for (const auto &item : lifetimes) { - if (item.second.first != -1) + if (item.second.first >= 0) alloc_plan[item.second.first].push_back(item.first); - if (item.second.second != -1) + if (item.second.second >= 0) dealloc_plan[item.second.second].push_back(item.first); } @@ -161,15 +344,20 @@ OMStatus OMExecutionPlanCreator::createBackwardExecutionPlan( std::vector> &alloc_plan = allocator.getAllocPlan(); std::vector> &dealloc_plan = allocator.getDeallocPlan(); + // First remove prev plan (if it was created) + alloc_plan.clear(); + dealloc_plan.clear(); + using Lifetime = std::pair; std::map lifetimes; const reader::CircleOperators *operators = runtime_context.getCircleOperators(); const uint32_t num_kernels = operators->size(); - uint32_t num_train_layers = configs.training_context.num_of_train_layers == 0 - ? num_kernels - : configs.training_context.num_of_train_layers; + uint32_t num_train_layers = + configs.training_context.num_of_train_layers == 0 + ? 
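// Sketch of the Lifetime sentinels used by the forward training plan above
// (kKeepAlive / kSavedForBackward are names invented for this example):
//   -1 : tensor stays allocated for the whole forward + backward pass
//   -2 : op output that must be preserved for gradient computation; a later
//        consumer promotes it to -1 so it is never scheduled for release
// Only non-negative entries become allocation / deallocation events.
#include <cstdint>
#include <utility>

constexpr int32_t kKeepAlive = -1;
constexpr int32_t kSavedForBackward = -2;

inline bool allocatedByPlan(const std::pair<int32_t, int32_t> &lifetime)
{
  return lifetime.first >= 0; // mirrors `item.second.first >= 0` above
}

inline bool releasedByPlan(const std::pair<int32_t, int32_t> &lifetime)
{
  return lifetime.second >= 0; // mirrors `item.second.second >= 0` above
}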
num_kernels + : std::min(num_kernels, configs.training_context.num_of_train_layers); auto graph_outputs = runtime_context.getCircleOutputs(); for (const auto output_ind : *graph_outputs) @@ -178,9 +366,26 @@ OMStatus OMExecutionPlanCreator::createBackwardExecutionPlan( lifetimes[output_ind] = Lifetime(-1, 0); } - uint32_t last_node_pos = std::min(num_kernels, num_train_layers); + std::unordered_map trainable_ops_config = + runtime_context.getTrainableOpsIndexes(); + + // If context has config file defined trainable operations + // than ignore configs.training_context.num_of_train_layers value + // and use max value from trainable_ops_indexes to define last train op + uint16_t last_train_op_indx = num_kernels - num_train_layers; + if (!trainable_ops_config.empty()) + { + last_train_op_indx = std::numeric_limits::max(); + // Find op trainable index with min value + for (auto &p : trainable_ops_config) + { + last_train_op_indx = std::min(p.first, last_train_op_indx); + } + num_train_layers = (num_kernels - last_train_op_indx); + } + const auto *op_codes = runtime_context.getCircleOpcodes(); - for (int32_t index = 0; index < last_node_pos; ++index) + for (int32_t index = 0; index < num_train_layers; ++index) { uint32_t cur_op_index = num_kernels - index - 1; auto *cur_op = operators->operator[](cur_op_index); @@ -193,29 +398,52 @@ OMStatus OMExecutionPlanCreator::createBackwardExecutionPlan( const auto *op_inputs = cur_op->inputs(); const auto *op_outputs = cur_op->outputs(); + + bool is_trainable_ops = + trainable_ops_config.empty() == true + ? isTrainableWeights(opcode) + : trainable_ops_config.find(cur_op_index) != trainable_ops_config.end(); + + // Warning: this is right for Conv2D and for FullyConnected kernels + const int32_t bias_index = 2; + for (int32_t j = 0; j < op_inputs->size(); ++j) { const auto input_index = op_inputs->operator[](j); const auto is_const = runtime_context.isConstTensor(input_index); // Note: we dont need to allocate for last node and for empty tensor - if (input_index == -1 or (is_const and not isTrainableWeights(opcode)) or - ((index == last_node_pos - 1) and !is_const)) + if (input_index == -1 or (is_const and not is_trainable_ops)) { continue; } - lifetimes[input_index] = {index, -1}; + + if ((index == num_train_layers - 1) and !is_const) + { + lifetimes[input_index] = {-1, index}; + } + else if (is_const and + trainable_ops_config.find(cur_op_index) != trainable_ops_config.end() and + trainable_ops_config[cur_op_index] == ONLY_BIAS and j != bias_index) + { + // Do nothing, due to update only bias + continue; + } + else + { + lifetimes[input_index] = {index, -1}; + } } for (int32_t j = 0; j < op_outputs->size(); ++j) { const auto output_index = op_outputs->operator[](j); - - lifetimes.at(output_index).second = index; + if (lifetimes.count(output_index) > 0) + lifetimes.at(output_index).second = index; } } - alloc_plan.assign(last_node_pos, std::vector()); - dealloc_plan.assign(last_node_pos, std::vector()); + alloc_plan.assign(num_train_layers, std::vector()); + dealloc_plan.assign(num_train_layers, std::vector()); for (const auto &item : lifetimes) { diff --git a/onert-micro/onert-micro/src/import/kernels/FullyConnected.cpp b/onert-micro/onert-micro/src/import/kernels/FullyConnected.cpp index 3d6f03ba3fa..e7bd5a4b71a 100644 --- a/onert-micro/onert-micro/src/import/kernels/FullyConnected.cpp +++ b/onert-micro/onert-micro/src/import/kernels/FullyConnected.cpp @@ -80,9 +80,9 @@ onert_micro::import::configure_kernel_CircleFullyConnected(const OMConfigureArgs if 
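// Sketch of the backward-plan decision above for constant inputs of a
// trainable operation (needsGradientBuffer is a name invented for this
// example). For Conv2D / FullyConnected the input order is
// {0: input, 1: weight, 2: bias}, hence bias_index == 2 in the code above.
#include <cstdint>

inline bool needsGradientBuffer(bool only_bias, int32_t input_slot)
{
  constexpr int32_t kBiasIndex = 2; // valid for Conv2D and FullyConnected
  if (only_bias)
    return input_slot == kBiasIndex; // weights keep no gradient buffer
  return true;                       // ALL / partial ranks: plan every trainable input
}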
(input_shape.flatSize() == 1 and output_shape.flatSize() != 1) { #ifndef DIS_DYN_SHAPES - int32_t dynamic_tensor_size = - runtime_storage.getDynamicTensorSize(runtime_kernel.inputs_index[inputTensorIdx]); - if (dynamic_tensor_size == -1) + input_shape = + runtime_storage.getDynamicRuntimeShape(runtime_kernel.inputs_index[inputTensorIdx]); + if (input_shape.flatSize() == 0) return UnsupportedDynamicShapeCase; #else return UnsupportedDynamicShapeCase; diff --git a/onert-micro/onert-micro/src/import/kernels/Reshape.cpp b/onert-micro/onert-micro/src/import/kernels/Reshape.cpp index d85f6e4f9b9..0b44b7b50b9 100644 --- a/onert-micro/onert-micro/src/import/kernels/Reshape.cpp +++ b/onert-micro/onert-micro/src/import/kernels/Reshape.cpp @@ -75,8 +75,8 @@ OMStatus onert_micro::import::configure_kernel_CircleReshape(const OMConfigureAr if (status != Ok) return status; - runtime_storage.setDynamicTensorSize(runtime_kernel.outputs_index[outputTensorIdx], - input_shape_size); + runtime_storage.setDynamicRuntimeShape(runtime_kernel.outputs_index[outputTensorIdx], + input_shape); } else { diff --git a/onert-micro/onert-micro/src/train/OMBackpropExecute.cpp b/onert-micro/onert-micro/src/train/OMBackpropExecute.cpp index 4e224ad1a27..cd37f875a8c 100644 --- a/onert-micro/onert-micro/src/train/OMBackpropExecute.cpp +++ b/onert-micro/onert-micro/src/train/OMBackpropExecute.cpp @@ -37,12 +37,28 @@ OMStatus OMBackpropExecute::runBackward(const OMConfig &config, OMBackpropExecut const auto num_operators = operators->size(); const auto *op_codes = context.getCircleOpcodes(); - uint32_t num_train_layers = config.training_context.num_of_train_layers == 0 - ? num_operators - : config.training_context.num_of_train_layers; - uint32_t last_node_pos = std::min(num_operators, num_train_layers); + uint32_t num_train_layers = + config.training_context.num_of_train_layers == 0 + ? 
num_operators + : std::min(num_operators, config.training_context.num_of_train_layers); + std::unordered_map trainable_ops_config = context.getTrainableOpsIndexes(); + + // If context has config file defined trainable operations + // than ignore configs.training_context.num_of_train_layers value + // and use max value from trainable_ops_indexes to define last train op + uint16_t last_train_op_indx = num_operators - num_train_layers; + if (!trainable_ops_config.empty()) + { + last_train_op_indx = std::numeric_limits::max(); + // Find op trainable index with min value + for (auto &p : trainable_ops_config) + { + last_train_op_indx = std::min(p.first, last_train_op_indx); + } + num_train_layers = (num_operators - last_train_op_indx); + } - for (uint32_t i = 0; i < last_node_pos; ++i) + for (int32_t i = 0; i < num_train_layers; ++i) { uint32_t cur_op_index = num_operators - i - 1; auto *cur_op = operators->operator[](cur_op_index); @@ -68,8 +84,24 @@ OMStatus OMBackpropExecute::runBackward(const OMConfig &config, OMBackpropExecut args.kernel_index = cur_op_index; - if (i == last_node_pos - 1) + if (i == num_train_layers - 1) + { args.is_last_layer = true; + } + else + { + args.is_last_layer = false; + } + + if (trainable_ops_config.find(cur_op_index) != trainable_ops_config.end()) + { + args.is_trainable_layer = true; + args.train_rank_type = core::OpTrainableRankType(trainable_ops_config[cur_op_index]); + } + else + { + args.is_trainable_layer = false; + } // Calculate gradients KernelTrainFunc *train_func = nullptr; @@ -96,13 +128,22 @@ OMStatus OMBackpropExecute::runBackward(const OMConfig &config, OMBackpropExecut if (status != Ok) return status; - // Deallocate tensors data in backward storage + // Deallocate tensors data in backward storage +#ifdef OM_MEMORY_ESTIMATE + status = allocator.deallocate(i, &backward_storage, &context); + if (status != Ok) + return status; + + // Deallocate tensors data in forward storage + status = allocator.deallocate(i, &forward_storage, &context); +#else status = allocator.deallocate(i, &backward_storage); if (status != Ok) return status; // Deallocate tensors data in forward storage status = allocator.deallocate(i, &forward_storage); +#endif } return status; diff --git a/onert-micro/onert-micro/src/train/kernels/Conv2D.cpp b/onert-micro/onert-micro/src/train/kernels/Conv2D.cpp index af3f2a9d63b..3f7cdc04886 100644 --- a/onert-micro/onert-micro/src/train/kernels/Conv2D.cpp +++ b/onert-micro/onert-micro/src/train/kernels/Conv2D.cpp @@ -52,6 +52,8 @@ OMStatus onert_micro::train::train_kernel_CircleConv2D(const OMBackpropExecuteAr const circle::Tensor *weight; const circle::Tensor *output; + int32_t weight_tensor_index = -1; + uint8_t *input_data; uint8_t *dloss_dinput_data; @@ -78,6 +80,9 @@ OMStatus onert_micro::train::train_kernel_CircleConv2D(const OMBackpropExecuteAr // Bias can be nullptr assert(output != nullptr); + weight_tensor_index = runtime_kernel.inputs_index[weightTensorIdx]; + assert(weight_tensor_index != -1); + // Read forward storage { runtime_kernel.getDataFromStorage(op_index, forward_storage, context); @@ -88,7 +93,7 @@ OMStatus onert_micro::train::train_kernel_CircleConv2D(const OMBackpropExecuteAr output_data = runtime_kernel.outputs_data[outputTensorIdx]; // Bias_data can be nullptr // Output_data can be nullptr - assert(input_data != nullptr); + // Input_data can be nullptr if we don't train this layer assert(weight_data != nullptr); } @@ -147,20 +152,35 @@ OMStatus onert_micro::train::train_kernel_CircleConv2D(const 
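// Worked example (illustrative values) of the reverse traversal in
// runBackward above: with num_operators = 10 and num_train_layers = 3,
// gradients are computed for operators 9, 8 and 7, and the iteration with
// i == num_train_layers - 1 is flagged as the last (earliest) trained layer.
#include <cstdint>
#include <cstdio>

int main()
{
  const uint32_t num_operators = 10;
  const uint32_t num_train_layers = 3;
  for (uint32_t i = 0; i < num_train_layers; ++i)
  {
    const uint32_t cur_op_index = num_operators - i - 1;
    const bool is_last_layer = (i == num_train_layers - 1);
    std::printf("backprop op %u%s\n", cur_op_index, is_last_layer ? " (last)" : "");
  }
  return 0;
}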
OMBackpropExecuteAr params.pad_h = 0; params.pad_w = 0; - // 2. Calculate weight gradient - pal::Conv2DWeightGrad(params, input_shape, utils::castInputData(input_data), output_shape, - utils::castInputData(dloss_doutput_data), weight_shape, - utils::castOutputData(dloss_dweight_data)); - - // 3. Calculate bias gradient - if (dloss_dbias_data) + if (args.is_trainable_layer) { - assert(bias_data != nullptr); - if (bias_data == nullptr) - return UnknownError; + // Check is only bias updating + if (args.train_rank_type != ONLY_BIAS) + { + assert(input_data != nullptr); // FIX memory planner then + + // Get weight shape + OMRuntimeShape dynamic_shapes = backward_storage.getDynamicRuntimeShape(weight_tensor_index); + if (dynamic_shapes.flatSize() != 0) + weight_shape = dynamic_shapes; + + // 2. Calculate weight gradient + pal::Conv2DWeightGrad(params, input_shape, utils::castInputData(input_data), + output_shape, utils::castInputData(dloss_doutput_data), + weight_shape, utils::castOutputData(dloss_dweight_data), + args.train_rank_type); + } + + // 3. Calculate bias gradient + if (dloss_dbias_data) + { + assert(bias_data != nullptr); + if (bias_data == nullptr) + return UnknownError; - pal::Conv2DBiasGrad(output_shape, utils::castInputData(dloss_doutput_data), - utils::castOutputData(dloss_dbias_data)); + pal::Conv2DBiasGrad(output_shape, utils::castInputData(dloss_doutput_data), + utils::castOutputData(dloss_dbias_data)); + } } // 4. Calculate (if needed) input grad diff --git a/onert-micro/onert-micro/src/train/kernels/FullyConnected.cpp b/onert-micro/onert-micro/src/train/kernels/FullyConnected.cpp index 7aee8849779..7190422c3d8 100644 --- a/onert-micro/onert-micro/src/train/kernels/FullyConnected.cpp +++ b/onert-micro/onert-micro/src/train/kernels/FullyConnected.cpp @@ -53,6 +53,8 @@ OMStatus onert_micro::train::train_kernel_CircleFullyConnected(const OMBackpropE const circle::Tensor *weight; const circle::Tensor *output; + int32_t weight_tensor_index = -1; + uint8_t *input_data; uint8_t *dloss_dinput_data; @@ -79,6 +81,9 @@ OMStatus onert_micro::train::train_kernel_CircleFullyConnected(const OMBackpropE // Bias can be nullptr assert(output != nullptr); + weight_tensor_index = runtime_kernel.inputs_index[weightTensorIdx]; + assert(weight_tensor_index != -1); + // Read forward storage { runtime_kernel.getDataFromStorage(op_index, forward_storage, context); @@ -89,7 +94,7 @@ OMStatus onert_micro::train::train_kernel_CircleFullyConnected(const OMBackpropE output_data = runtime_kernel.outputs_data[outputTensorIdx]; // Bias_data can be nullptr // Output_data can be nullptr - assert(input_data != nullptr); + // Input_data can be nullptr assert(weight_data != nullptr); } @@ -133,23 +138,39 @@ OMStatus onert_micro::train::train_kernel_CircleFullyConnected(const OMBackpropE } } - // 2. Calculate weight gradient - pal::FullyConnectedWeightGrad(core::utils::castInputData(dloss_doutput_data), output_shape, - core::utils::castInputData(input_data), input_shape, - core::utils::castOutputData(dloss_dweight_data)); - - // 3. 
Calculate bias gradient - // Just copy dloss_doutput_data to dloss_dbias_data - // TODO: introduce training inplace - if (dloss_dbias_data) + if (args.is_trainable_layer) { - assert(bias_data != nullptr); - if (bias_data == nullptr) - return UnknownError; + // Check is only bias updating + if (args.train_rank_type != ONLY_BIAS) + { + assert(input_data != nullptr); // FIX memory planner then + + // Get weight shape + OMRuntimeShape weight_shape(weight); + OMRuntimeShape dynamic_shapes = backward_storage.getDynamicRuntimeShape(weight_tensor_index); + if (dynamic_shapes.flatSize() != 0) + weight_shape = dynamic_shapes; + + // 2. Calculate weight gradient + pal::FullyConnectedWeightGrad( + core::utils::castInputData(dloss_doutput_data), output_shape, + core::utils::castInputData(input_data), input_shape, + core::utils::castOutputData(dloss_dweight_data), weight_shape, args.train_rank_type); + } + + // 3. Calculate bias gradient + // Just copy dloss_doutput_data to dloss_dbias_data + // TODO: introduce training inplace + if (dloss_dbias_data) + { + assert(bias_data != nullptr); + if (bias_data == nullptr) + return UnknownError; - std::memcpy(dloss_dbias_data, dloss_doutput_data, - sizeof(OMDataType(output->type())) * - output_shape.dims(output_shape.dimensionsCount() - 1)); + std::memcpy(dloss_dbias_data, dloss_doutput_data, + sizeof(OMDataType(output->type())) * + output_shape.dims(output_shape.dimensionsCount() - 1)); + } } // 4. Calculate (if needed) input grad diff --git a/onert-micro/onert-micro/src/train/kernels/Softmax.cpp b/onert-micro/onert-micro/src/train/kernels/Softmax.cpp index b2ab791fa42..e4577be3ce3 100644 --- a/onert-micro/onert-micro/src/train/kernels/Softmax.cpp +++ b/onert-micro/onert-micro/src/train/kernels/Softmax.cpp @@ -111,8 +111,14 @@ OMStatus onert_micro::train::train_kernel_CircleSoftmax(const OMBackpropExecuteA core::utils::castOutputData(jacobian_row_data), core::utils::castOutputData(dloss_dinput_data)); +#ifdef OM_MEMORY_ESTIMATE + // Deallocate temporary buffer with Jacobian row + status = core::memory::OMMemoryManager::deallocateMemory( + output_shape.flatSize() * sizeof(OMDataType(output->type())), jacobian_row_data); +#elif // Deallocate temporary buffer with Jacobian row status = core::memory::OMMemoryManager::deallocateMemory(jacobian_row_data); +#endif return status; } diff --git a/onert-micro/onert-micro/src/train/train_optimizers/Adam.cpp b/onert-micro/onert-micro/src/train/train_optimizers/Adam.cpp index 0c32d42feca..60c5ee81865 100644 --- a/onert-micro/onert-micro/src/train/train_optimizers/Adam.cpp +++ b/onert-micro/onert-micro/src/train/train_optimizers/Adam.cpp @@ -26,6 +26,124 @@ using namespace onert_micro; using namespace onert_micro::train; using namespace onert_micro::train::optimizers; +namespace +{ +inline std::pair getUpLowerWeightTensorDepth(core::OpTrainableRankType rank, + const uint32_t output_depth) +{ + std::pair result(0u, output_depth); + + switch (rank) + { + case core::ALL: + break; + case core::UP_1_2_PART: + result.second = static_cast(static_cast(output_depth) / 2.f); + break; + case core::LOWER_1_2_PART: + result.first = static_cast(static_cast(output_depth) / 2.f); + break; + default: + assert("Unsupported type"); + break; + } + + return result; +} +} // namespace + +#ifdef OM_MEMORY_ESTIMATE +void Adam::fullReset(core::OMRuntimeContext &context, core::OMRuntimeStorage &storage) +{ + for (auto &cur_tensor_index_data : _tensor_to_exponent_avg) + { + uint8_t *allocated_data = cur_tensor_index_data.second; + auto tensor_index = 
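// Worked example of getUpLowerWeightTensorDepth above for a weight tensor
// with 8 output rows (Rank / depthBounds are names invented for this sketch;
// integer division is used here, matching the float version for even depths):
//   ALL            -> {0, 8}  every row is updated
//   UP_1_2_PART    -> {0, 4}  only the upper half is updated
//   LOWER_1_2_PART -> {4, 8}  only the lower half is updated
// The first element later serves as the row offset into the full weight
// buffer, e.g. f_weight_data[i + depth_bounds.first] in Adam::updateWeights.
#include <cstdint>
#include <utility>

enum class Rank : uint8_t { ALL, ONLY_BIAS, UP_1_2_PART, LOWER_1_2_PART };

std::pair<uint32_t, uint32_t> depthBounds(Rank rank, uint32_t output_depth)
{
  switch (rank)
  {
    case Rank::UP_1_2_PART:
      return {0u, output_depth / 2};
    case Rank::LOWER_1_2_PART:
      return {output_depth / 2, output_depth};
    default:
      return {0u, output_depth};
  }
}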
cur_tensor_index_data.first; + + auto tensor = context.getTensorByIndex(tensor_index); + auto num_elements = core::OMRuntimeShape(tensor).flatSize(); + +#ifndef DIS_DYN_SHAPES + int32_t dynamic_tensor_size = storage.getDynamicRuntimeShape(tensor_index).flatSize(); + if (dynamic_tensor_size != 0) + num_elements = dynamic_tensor_size; +#endif // DIS_DYN_SHAPES + + auto tensor_size = num_elements * sizeof(core::OMDataType(tensor->type())); + + core::memory::OMMemoryManager::deallocateMemory(tensor_size, allocated_data); + } + _tensor_to_exponent_avg.clear(); + + for (auto &cur_tensor_index_data : _tensor_to_exponent_avg_squares) + { + uint8_t *allocated_data = cur_tensor_index_data.second; + + auto tensor_index = cur_tensor_index_data.first; + + auto tensor = context.getTensorByIndex(tensor_index); + auto num_elements = core::OMRuntimeShape(tensor).flatSize(); + +#ifndef DIS_DYN_SHAPES + int32_t dynamic_tensor_size = storage.getDynamicRuntimeShape(tensor_index).flatSize(); + if (dynamic_tensor_size != 0) + num_elements = dynamic_tensor_size; +#endif // DIS_DYN_SHAPES + + auto tensor_size = num_elements * sizeof(core::OMDataType(tensor->type())); + + core::memory::OMMemoryManager::deallocateMemory(tensor_size, allocated_data); + } + _tensor_to_exponent_avg_squares.clear(); + + for (auto &cur_tensor_index_data : _tensor_index_to_gradient) + { + uint8_t *allocated_data = cur_tensor_index_data.second; + + auto tensor_index = cur_tensor_index_data.first; + + auto tensor = context.getTensorByIndex(tensor_index); + auto num_elements = core::OMRuntimeShape(tensor).flatSize(); + +#ifndef DIS_DYN_SHAPES + int32_t dynamic_tensor_size = storage.getDynamicRuntimeShape(tensor_index).flatSize(); + if (dynamic_tensor_size != 0) + num_elements = dynamic_tensor_size; +#endif // DIS_DYN_SHAPES + + auto tensor_size = num_elements * sizeof(core::OMDataType(tensor->type())); + + core::memory::OMMemoryManager::deallocateMemory(tensor_size, allocated_data); + } + auto tmp = core::memory::OMMemoryManager::cur_memory_allocated; + _tensor_index_to_gradient.clear(); +} + +void Adam::reset(core::OMRuntimeContext &context, core::OMRuntimeStorage &storage) +{ + for (auto &cur_tensor_index_data : _tensor_index_to_gradient) + { + uint8_t *allocated_data = cur_tensor_index_data.second; + + auto tensor_index = cur_tensor_index_data.first; + + auto tensor = context.getTensorByIndex(tensor_index); + auto num_elements = core::OMRuntimeShape(tensor).flatSize(); + +#ifndef DIS_DYN_SHAPES + int32_t dynamic_tensor_size = storage.getDynamicRuntimeShape(tensor_index).flatSize(); + if (dynamic_tensor_size != 0) + num_elements = dynamic_tensor_size; +#endif // DIS_DYN_SHAPES + + auto tensor_size = num_elements * sizeof(core::OMDataType(tensor->type())); + core::memory::OMMemoryManager::deallocateMemory(tensor_size, allocated_data); + } + _tensor_index_to_gradient.clear(); +} + +#endif // OM_MEMORY_ESTIMATE + void Adam::fullReset() { for (auto &cur_tensor_index_data : _tensor_to_exponent_avg) @@ -103,7 +221,8 @@ void Adam::setExponentAvgSquaresDataByTensorIndex(uint16_t tensor_index, uint8_t * Update internal states according to calculated gradients using Adam theory * grad(t) = grad(t - 1) + calculated_grad(t) */ -OMStatus Adam::handle(core::OMRuntimeStorage &backward_storage, core::OMRuntimeContext &context) +OMStatus Adam::handle(core::OMRuntimeStorage &backward_storage, core::OMRuntimeContext &context, + core::OMRuntimeStorage &storage) { auto &backward_tensor_to_data = backward_storage.getTensorIndexToData(); @@ -117,32 +236,36 @@ 
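// Sketch of the element-count selection repeated throughout the optimizer
// code above and below (effectiveNumElements is a name invented for this
// example): when a tensor has a registered dynamic (partitioned) runtime
// shape, optimizer buffers are sized by that shape, otherwise by the static one.
#include <cstdint>

inline uint32_t effectiveNumElements(uint32_t static_flat_size, int32_t dynamic_flat_size)
{
  // A dynamic flat size of 0 means no dynamic shape was registered for the tensor.
  return dynamic_flat_size != 0 ? static_cast<uint32_t>(dynamic_flat_size) : static_flat_size;
}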
OMStatus Adam::handle(core::OMRuntimeStorage &backward_storage, core::OMRuntimeC // This should be done due to execution plan work for (auto &tensor_to_data : backward_tensor_to_data) { - auto tensor = context.getTensorByIndex(tensor_to_data.first); - core::OMRuntimeShape shape(tensor); + auto tensor_index = tensor_to_data.first; + auto tensor = context.getTensorByIndex(tensor_index); + auto num_elements = core::OMRuntimeShape(tensor).flatSize(); + +#ifndef DIS_DYN_SHAPES + int32_t dynamic_tensor_size = storage.getDynamicRuntimeShape(tensor_index).flatSize(); + if (dynamic_tensor_size != 0) + num_elements = dynamic_tensor_size; +#endif // DIS_DYN_SHAPES - const auto flat_size = shape.flatSize(); - const auto type_size = sizeof(core::OMDataType(tensor->type())); + auto tensor_size = num_elements * sizeof(core::OMDataType(tensor->type())); // Allocate data for exponent calculation uint8_t *exponent_data = nullptr; - OMStatus status = - core::memory::OMMemoryManager::allocateMemory(flat_size * type_size, &exponent_data); + OMStatus status = core::memory::OMMemoryManager::allocateMemory(tensor_size, &exponent_data); assert(status == Ok); if (status != Ok) return UnknownError; // Set to zeros - std::memset(exponent_data, 0, flat_size * type_size); + std::memset(exponent_data, 0, tensor_size); _tensor_to_exponent_avg[tensor_to_data.first] = exponent_data; // Allocate data for exponent square calculation uint8_t *exponent_square_data = nullptr; - status = - core::memory::OMMemoryManager::allocateMemory(flat_size * type_size, &exponent_square_data); + status = core::memory::OMMemoryManager::allocateMemory(tensor_size, &exponent_square_data); assert(status == Ok); if (status != Ok) return UnknownError; // Set to zeros - std::memset(exponent_square_data, 0, flat_size * type_size); + std::memset(exponent_square_data, 0, tensor_size); _tensor_to_exponent_avg_squares[tensor_to_data.first] = exponent_square_data; } } @@ -170,14 +293,18 @@ OMStatus Adam::handle(core::OMRuntimeStorage &backward_storage, core::OMRuntimeC for (auto &tensor_to_data : backward_tensor_to_data) { auto tensor = context.getTensorByIndex(tensor_to_data.first); - core::OMRuntimeShape shape(tensor); + auto num_elements = core::OMRuntimeShape(tensor).flatSize(); - const auto flat_size = shape.flatSize(); +#ifndef DIS_DYN_SHAPES + int32_t dynamic_tensor_size = storage.getDynamicRuntimeShape(tensor_to_data.first).flatSize(); + if (dynamic_tensor_size != 0) + num_elements = dynamic_tensor_size; +#endif // DIS_DYN_SHAPES auto *grad_data = reinterpret_cast(_tensor_index_to_gradient[tensor_to_data.first]); auto *calculated_data = reinterpret_cast(tensor_to_data.second); - for (uint32_t i = 0; i < flat_size; ++i) + for (uint32_t i = 0; i < num_elements; ++i) { grad_data[i] += calculated_data[i]; } @@ -198,11 +325,12 @@ OMStatus Adam::handle(core::OMRuntimeStorage &backward_storage, core::OMRuntimeC * * w(t + 1) = w(t) - lambda * m`(t) / (sqrt(v` + epsilon)) */ -OMStatus Adam::updateWeights(const onert_micro::OMTrainingContext &training_config, - core::OMRuntimeContext &context) +OMStatus Adam::updateWeights( + const onert_micro::OMTrainingContext &training_config, core::OMRuntimeContext &context, + core::OMRuntimeStorage &storage, + std::unordered_map &tensor_index_to_rank_type_map) { assert(!_tensor_index_to_gradient.empty()); - for (auto &tensor_to_data : _tensor_index_to_gradient) { auto exponent_squares_it = _tensor_to_exponent_avg_squares.find(tensor_to_data.first); @@ -216,7 +344,15 @@ OMStatus Adam::updateWeights(const 
onert_micro::OMTrainingContext &training_conf auto tensor = context.getTensorByIndex(tensor_to_data.first); core::OMRuntimeShape shape(tensor); - const auto flat_size = shape.flatSize(); + auto original_d = shape.dims(0); + + auto num_elements = core::OMRuntimeShape(tensor).flatSize(); + +#ifndef DIS_DYN_SHAPES + int32_t dynamic_tensor_size = storage.getDynamicRuntimeShape(tensor_to_data.first).flatSize(); + if (dynamic_tensor_size != 0) + num_elements = dynamic_tensor_size; +#endif // DIS_DYN_SHAPES auto *exponent_data = reinterpret_cast(exponent_it->second); auto *exponent_square_data = reinterpret_cast(exponent_squares_it->second); @@ -224,7 +360,7 @@ OMStatus Adam::updateWeights(const onert_micro::OMTrainingContext &training_conf float beta = training_config.beta; float beta_squares = training_config.beta_squares; auto batches = static_cast(training_config.batch_size); - for (uint32_t i = 0; i < flat_size; ++i) + for (uint32_t i = 0; i < num_elements; ++i) { const auto cur_val = calculated_data[i]; exponent_data[i] = beta * exponent_data[i] + (1 - beta) * cur_val; @@ -249,12 +385,17 @@ OMStatus Adam::updateWeights(const onert_micro::OMTrainingContext &training_conf assert((1.f - beta_in_pow_batch) != 0); assert((1.f - beta_square_in_pow_batch) != 0); + auto train_it = tensor_index_to_rank_type_map.find(tensor_to_data.first); + core::OpTrainableRankType rank = train_it == tensor_index_to_rank_type_map.end() + ? core::OpTrainableRankType::ALL + : core::OpTrainableRankType(train_it->second); + auto depth_bounds = getUpLowerWeightTensorDepth(rank, original_d); - for (uint32_t i = 0; i < flat_size; ++i) + for (uint32_t i = 0; i < num_elements; ++i) { float exponent_corrected = exponent_data[i] / (1.f - beta_in_pow_batch); float exponent_square_corrected = exponent_square_data[i] / (1.f - beta_square_in_pow_batch); - f_weight_data[i] -= + f_weight_data[i + depth_bounds.first] -= lambda * (exponent_corrected / (std::sqrt(exponent_square_corrected + epsilon))); } } diff --git a/onert-micro/onert-micro/src/train/train_optimizers/SGD.cpp b/onert-micro/onert-micro/src/train/train_optimizers/SGD.cpp index a33b0ee12bb..a0be8e69647 100644 --- a/onert-micro/onert-micro/src/train/train_optimizers/SGD.cpp +++ b/onert-micro/onert-micro/src/train/train_optimizers/SGD.cpp @@ -18,11 +18,64 @@ #include "train/train_optimizers/SGD.h" #include "core/memory/OMMemoryManager.h" #include "core/OMRuntimeShape.h" +#include "core/OMDataType.h" using namespace onert_micro; using namespace onert_micro::train; using namespace onert_micro::train::optimizers; +namespace +{ +inline std::pair getUpLowerWeightTensorDepth(core::OpTrainableRankType rank, + const uint32_t output_depth) +{ + std::pair result(0u, output_depth); + + switch (rank) + { + case core::ALL: + break; + case core::UP_1_2_PART: + result.second = static_cast(static_cast(output_depth) / 2.f); + break; + case core::LOWER_1_2_PART: + result.first = static_cast(static_cast(output_depth) / 2.f); + break; + default: + assert("Unsupported type"); + break; + } + + return result; +} +} // namespace + +#ifdef OM_MEMORY_ESTIMATE + +void SGD::reset(core::OMRuntimeContext &context, core::OMRuntimeStorage &storage) +{ + for (auto &cur_tensor_index_data : _tensor_index_to_gradient) + { + uint8_t *allocated_data = cur_tensor_index_data.second; + auto tensor_index = cur_tensor_index_data.first; + + auto tensor = context.getTensorByIndex(tensor_index); + auto num_elements = core::OMRuntimeShape(tensor).flatSize(); + +#ifndef DIS_DYN_SHAPES + int32_t dynamic_tensor_size = 
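// Single-weight sketch of the Adam step performed by updateWeights above,
// assuming the same conventions as the kernel: moments are bias-corrected by
// (1 - beta^batch) and epsilon is added under the square root (adamStep is a
// name invented for this example; grad is the accumulated batch gradient).
#include <cmath>

float adamStep(float weight, float grad, float &m, float &v, float lambda, float beta,
               float beta_squares, float epsilon, float beta_pow_batch,
               float beta_squares_pow_batch)
{
  m = beta * m + (1.f - beta) * grad;                        // first-moment estimate
  v = beta_squares * v + (1.f - beta_squares) * grad * grad; // second-moment estimate
  const float m_hat = m / (1.f - beta_pow_batch);            // bias-corrected moments
  const float v_hat = v / (1.f - beta_squares_pow_batch);
  return weight - lambda * (m_hat / std::sqrt(v_hat + epsilon));
}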
storage.getDynamicRuntimeShape(tensor_index).flatSize(); + if (dynamic_tensor_size != 0) + num_elements = dynamic_tensor_size; +#endif // DIS_DYN_SHAPES + + auto tensor_size = num_elements * sizeof(core::OMDataType(tensor->type())); + core::memory::OMMemoryManager::deallocateMemory(tensor_size, allocated_data); + } + _tensor_index_to_gradient.clear(); +} + +#endif // OM_MEMORY_ESTIMATE + void SGD::reset() { for (auto &cur_tensor_index_data : _tensor_index_to_gradient) @@ -38,7 +91,8 @@ void SGD::reset() * Update internal states according to calculated gradients using Adam theory * grad(t) = grad(t - 1) + calculated_grad(t) */ -OMStatus SGD::handle(core::OMRuntimeStorage &backward_storage, core::OMRuntimeContext &context) +OMStatus SGD::handle(core::OMRuntimeStorage &backward_storage, core::OMRuntimeContext &context, + core::OMRuntimeStorage &storage) { auto &backward_tensor_to_data = backward_storage.getTensorIndexToData(); // Check is allocated or not helper buffers @@ -64,14 +118,18 @@ OMStatus SGD::handle(core::OMRuntimeStorage &backward_storage, core::OMRuntimeCo for (auto &tensor_to_data : backward_tensor_to_data) { auto tensor = context.getTensorByIndex(tensor_to_data.first); - core::OMRuntimeShape shape(tensor); + auto num_elements = core::OMRuntimeShape(tensor).flatSize(); - const auto flat_size = shape.flatSize(); +#ifndef DIS_DYN_SHAPES + int32_t dynamic_tensor_size = storage.getDynamicRuntimeShape(tensor_to_data.first).flatSize(); + if (dynamic_tensor_size != 0) + num_elements = dynamic_tensor_size; +#endif // DIS_DYN_SHAPES auto *grad_data = reinterpret_cast(_tensor_index_to_gradient[tensor_to_data.first]); auto *calculated_data = reinterpret_cast(tensor_to_data.second); - for (uint32_t i = 0; i < flat_size; ++i) + for (uint32_t i = 0; i < num_elements; ++i) { grad_data[i] += calculated_data[i]; } @@ -86,8 +144,10 @@ OMStatus SGD::handle(core::OMRuntimeStorage &backward_storage, core::OMRuntimeCo * * w(t + 1) = w(t) - lambda * grad(t) / batch_size */ -OMStatus SGD::updateWeights(const onert_micro::OMTrainingContext &training_config, - core::OMRuntimeContext &context) +OMStatus SGD::updateWeights( + const onert_micro::OMTrainingContext &training_config, core::OMRuntimeContext &context, + core::OMRuntimeStorage &storage, + std::unordered_map &tensor_index_to_rank_type_map) { assert(!_tensor_index_to_gradient.empty()); if (_tensor_index_to_gradient.empty()) @@ -97,8 +157,15 @@ OMStatus SGD::updateWeights(const onert_micro::OMTrainingContext &training_confi { auto tensor = context.getTensorByIndex(tensor_to_data.first); core::OMRuntimeShape shape(tensor); + auto num_elements = shape.flatSize(); + + auto original_d = shape.dims(0); - const auto flat_size = shape.flatSize(); +#ifndef DIS_DYN_SHAPES + int32_t dynamic_tensor_size = storage.getDynamicRuntimeShape(tensor_to_data.first).flatSize(); + if (dynamic_tensor_size != 0) + num_elements = dynamic_tensor_size; +#endif // DIS_DYN_SHAPES auto *grad_data = reinterpret_cast(tensor_to_data.second); uint8_t *weight_data = nullptr; @@ -112,10 +179,15 @@ OMStatus SGD::updateWeights(const onert_micro::OMTrainingContext &training_confi auto *f_weight_data = reinterpret_cast(weight_data); float lambda = training_config.learning_rate; const uint32_t batch_size = training_config.batch_size; + auto train_it = tensor_index_to_rank_type_map.find(tensor_to_data.first); + core::OpTrainableRankType rank = train_it == tensor_index_to_rank_type_map.end() + ? 
core::OpTrainableRankType::ALL + : core::OpTrainableRankType(train_it->second); + auto depth_bounds = getUpLowerWeightTensorDepth(rank, original_d); assert(batch_size != 0); - for (uint32_t i = 0; i < flat_size; ++i) + for (uint32_t i = 0; i < num_elements; ++i) { f_weight_data[i] -= (lambda * grad_data[i]) / (static_cast(batch_size)); }
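// Element-wise sketch of the SGD step performed by updateWeights above
// (sgdUpdate is a name invented for this example): the accumulated gradient
// is scaled by the learning rate and averaged over the batch size before
// being subtracted from the weights.
#include <cstdint>

void sgdUpdate(float *weights, const float *grads, uint32_t num_elements, float lambda,
               uint32_t batch_size)
{
  for (uint32_t i = 0; i < num_elements; ++i)
    weights[i] -= (lambda * grads[i]) / static_cast<float>(batch_size);
}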