Skip to content

Commit

Permalink
[onert-micro] Support training configure tool
Browse files Browse the repository at this point in the history
This pr supports training configure tool in onert-micro.

ONE-DCO-1.0-Signed-off-by: Artem Balyshev <[email protected]>
  • Loading branch information
Artem Balyshev committed Aug 5, 2024
1 parent ae71417 commit 8237e03
Show file tree
Hide file tree
Showing 46 changed files with 1,361 additions and 232 deletions.
5 changes: 5 additions & 0 deletions onert-micro/onert-micro/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,11 @@ if (DIS_FLOAT)
add_definitions(-DDIS_FLOAT)
endif()

# To enable memory estimate
if (OM_MEMORY_ESTIMATE)
add_definitions(-DOM_MEMORY_ESTIMATE)
endif()

# To enable training part
if (ENABLE_TRAINING)
add_definitions(-DENABLE_TRAINING)
Expand Down
5 changes: 5 additions & 0 deletions onert-micro/onert-micro/include/OMConfig.h
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ enum OMMetrics
MAE_METRICS,
CROSS_ENTROPY_METRICS,
ACCURACY,
NONE,
};

/*
Expand All @@ -65,6 +66,8 @@ enum OMLoss
* beta_squares - used by ADAM optimizer
* epsilon - used by ADAM optimizer
* num_Step - used by ADAM optimizer
* training_config_info_data - pointer to the training config data, to store training specific
* scenario (default null)
*/
struct OMTrainingContext
{
Expand All @@ -79,6 +82,8 @@ struct OMTrainingContext
uint32_t num_step = 0;
uint32_t num_epoch = 0;
uint32_t epochs = 0;

char *training_config_info_data = nullptr;
};

/*
Expand Down
5 changes: 5 additions & 0 deletions onert-micro/onert-micro/include/OMTrainingInterpreter.h
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,11 @@ class OMTrainingInterpreter
void *getInputData(uint32_t position);
void *getInputDataAt(uint32_t position);
void *getOutputDataAt(uint32_t position);

#ifdef OM_MEMORY_ESTIMATE
size_t getPeakFootprintMemory() { return _training_runtime_module.getPeakFootprintMemory(); }
size_t getCurrentFootprintMemory() { return _training_runtime_module.getPeakFootprintMemory(); }
#endif // OM_MEMORY_ESTIMATE
};

} // namespace onert_micro
Expand Down
10 changes: 10 additions & 0 deletions onert-micro/onert-micro/include/core/OMKernelType.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,16 @@ namespace onert_micro
namespace core
{

// Enum to indicate the degree(rank) to which part of the operation we will train
enum OpTrainableRankType
{
ALL = 0, // 0 - Train all weights in the operation
ONLY_BIAS = 1, // 1 - Train bias only in the operation
UP_1_2_PART = 2, // 2 - Train the upper 1/2 part of the operation
LOWER_1_2_PART = 3, // 3 - Train the lower 1/2 part of the operation
// TODO add more
};

enum OMKernelType
{
Normal,
Expand Down
27 changes: 24 additions & 3 deletions onert-micro/onert-micro/include/core/OMRuntimeContext.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,12 @@

#include "OMStatus.h"

#include "core/OMRuntimeShape.h"
#include "core/OMRuntimeStorage.h"

#include "reader/OMCircleReader.h"
#include "reader/OMWeightOnlyFormatReader.h"
#include "reader/OMTrainingConfigFileReader.h"

#include <cstdint>

Expand All @@ -32,8 +36,9 @@ namespace core
class OMRuntimeContext
{
private:
reader::OMCircleReader _reader;
reader::OMWeightOnlyFormatReader _wof_reader;
reader::OMCircleReader _reader{};
reader::OMWeightOnlyFormatReader _wof_reader{};
reader::OMTrainingConfigReader _train_config_reader{};

public:
OMRuntimeContext() = default;
Expand Down Expand Up @@ -66,7 +71,23 @@ class OMRuntimeContext
return Ok;
}

const bool isConstTensor(uint32_t tensor_index) { return _reader.isConstTensor(tensor_index); }
OMStatus setTrainConfigFile(char *train_config_file_ptr)
{
OMStatus status = Ok;
_train_config_reader.parse(train_config_file_ptr);

status = _train_config_reader.validate(&_reader);
if (status != Ok)
return status;
return Ok;
}

std::unordered_map<uint16_t, uint8_t> getTrainableOpsIndexes()
{
return _train_config_reader.getTrainableOpsIndexes();
}

bool isConstTensor(uint32_t tensor_index) { return _reader.isConstTensor(tensor_index); }

const reader::CircleValues *getCircleOutputs() { return _reader.outputs(); }
const reader::CircleValues *getCircleInputs() { return _reader.inputs(); }
Expand Down
2 changes: 2 additions & 0 deletions onert-micro/onert-micro/include/core/OMRuntimeShape.h
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,8 @@ class OMRuntimeShape
// vector.
inline int flatSize() const
{
if (_size == 0)
return 0;
int buffer_size = 1;
const int *dims_data = reinterpret_cast<const int *>(dimsData());
for (int i = 0; i < _size; i++)
Expand Down
10 changes: 5 additions & 5 deletions onert-micro/onert-micro/include/core/OMRuntimeStorage.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ class OMRuntimeStorage
{
private:
#ifndef DIS_DYN_SHAPES
std::unordered_map<uint16_t, uint32_t> _tensor_index_to_dynamic_tensor_size;
std::unordered_map<uint16_t, OMRuntimeShape> _tensor_index_to_dynamic_tensor_size;
#endif
std::unordered_map<uint16_t, uint8_t *> _tensor_index_to_data;
std::unordered_map<uint16_t, OMKernelType> _operator_index_to_kernel_type;
Expand Down Expand Up @@ -70,18 +70,18 @@ class OMRuntimeStorage
return Ok;
}
#ifndef DIS_DYN_SHAPES
int32_t getDynamicTensorSize(uint16_t tensor_index)
OMRuntimeShape getDynamicRuntimeShape(uint16_t tensor_index)
{
auto it = _tensor_index_to_dynamic_tensor_size.find(tensor_index);
if (it == _tensor_index_to_dynamic_tensor_size.end())
return -1;
return {}; // Return empty

return it->second;
}

OMStatus setDynamicTensorSize(uint16_t tensor_index, uint32_t dynamic_size)
OMStatus setDynamicRuntimeShape(uint16_t tensor_index, const OMRuntimeShape &shape)
{
_tensor_index_to_dynamic_tensor_size[tensor_index] = dynamic_size;
_tensor_index_to_dynamic_tensor_size[tensor_index] = shape;
return Ok;
}
#endif // DIS_DYN_SHAPES
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,11 @@ class OMTrainingRuntimeModule : public OMRuntimeModule
OMStatus loadCheckpointData(OMConfig &config, const char *data);

void *getInputData(int32_t index);

#ifdef OM_MEMORY_ESTIMATE
size_t getPeakFootprintMemory();
size_t getCurrentFootprintMemory();
#endif // OM_MEMORY_ESTIMATE
};

} // namespace core
Expand Down
7 changes: 7 additions & 0 deletions onert-micro/onert-micro/include/core/memory/OMMemoryManager.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
#include "OMStatus.h"

#include <cstdint>
#include <stdlib.h>

namespace onert_micro
{
Expand All @@ -30,6 +31,12 @@ namespace memory

struct OMMemoryManager
{
// Need for configure tool estimations
#ifdef OM_MEMORY_ESTIMATE
static size_t peak_memory_allocated;
static size_t cur_memory_allocated;
static OMStatus deallocateMemory(uint32_t size, uint8_t *data);
#endif // OM_MEMORY_ESTIMATE
static OMStatus allocateMemory(uint32_t size, uint8_t **data);
static OMStatus deallocateMemory(uint8_t *data);
};
Expand Down
16 changes: 4 additions & 12 deletions onert-micro/onert-micro/include/core/memory/OMRuntimeAllocator.h
Original file line number Diff line number Diff line change
Expand Up @@ -45,18 +45,6 @@ class OMRuntimeAllocator
OMRuntimeAllocator(OMRuntimeAllocator &&) = default;
~OMRuntimeAllocator() = default;

void saveAllocPlan(std::vector<std::vector<uint16_t>> &&alloc_plan)
{
_alloc_plan.clear();
_alloc_plan = std::move(alloc_plan);
}

void saveDeallocPlan(std::vector<std::vector<uint16_t>> &&dealloc_plan)
{
_dealloc_plan.clear();
_dealloc_plan = std::move(dealloc_plan);
}

std::vector<std::vector<uint16_t>> &getAllocPlan() { return _alloc_plan; }

std::vector<std::vector<uint16_t>> &getDeallocPlan() { return _dealloc_plan; }
Expand All @@ -67,6 +55,10 @@ class OMRuntimeAllocator

OMStatus allocate(size_t kernel_index, OMRuntimeContext *context, OMRuntimeStorage *storage);
OMStatus deallocate(size_t kernel_index, OMRuntimeStorage *storage);
// Need for configure tool estimations
#ifdef OM_MEMORY_ESTIMATE
OMStatus deallocate(size_t kernel_index, OMRuntimeStorage *storage, OMRuntimeContext *context);
#endif // OM_MEMORY_ESTIMTE
};

} // namespace memory
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
/*
* Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef ONERT_MICRO_CORE_READER_TRAINING_CONFIG_FILE_READER_H
#define ONERT_MICRO_CORE_READER_TRAINING_CONFIG_FILE_READER_H

#include "OMStatus.h"
#include "OMCircleReader.h"

#include <unordered_map>

namespace onert_micro
{
namespace core
{
namespace reader
{
namespace
{

enum TrainConfigFileFieldsOffsets
{
MAGIC_NUMBER_FIELD = 0,
SCHEMA_VERSION_FIELD = 2,
NUM_LAYERS_FIELD = 4,
FIRST_LAYER_INDEX_FIELD = 8
};

} // namespace

constexpr uint16_t train_config_file_magic_number = 29;
constexpr uint8_t train_config_file_schema_version = 1;

/**
* @brief Loads Training Config files and provides helpers functions
*/
class OMTrainingConfigReader
{
public:
OMTrainingConfigReader() = default;
OMTrainingConfigReader(const OMTrainingConfigReader &) = delete;
OMTrainingConfigReader(OMTrainingConfigReader &&) = default;
OMTrainingConfigReader &operator=(const OMTrainingConfigReader &) = delete;
OMTrainingConfigReader &&operator=(const OMTrainingConfigReader &&) = delete;
~OMTrainingConfigReader() = default;

public:
// To validate _train_config_ptr and compare with circle model saved in reader.
OMStatus validate(OMCircleReader *reader);

// Save pointer
void parse(char *ptr) { _train_config_ptr = ptr; }

// Read and return indexes of trainable layers from config file
// first it is op index in graph, second is rank of the training (see OpTrainableRank)
std::unordered_map<uint16_t, uint8_t> getTrainableOpsIndexes();

private:
char *_train_config_ptr;
};

} // namespace reader
} // namespace core
} // namespace onert_micro

#endif // ONERT_MICRO_CORE_READER_TRAINING_CONFIG_FILE_READER_H
10 changes: 9 additions & 1 deletion onert-micro/onert-micro/include/core/train/OMTrainingHandler.h
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,8 @@ class OMTrainingHandler
OMRuntimeContext &context);

// Handle with updating weights with current optimizer
OMStatus updateWeights(const OMConfig &config, OMRuntimeContext &context);
OMStatus updateWeights(const OMConfig &config, OMRuntimeContext &context,
OMRuntimeStorage &storage);

// Evaluate metric and save result in metric_val
// Warning: 1) assume that all metric_val for all OMMetrics types actually are float values.
Expand All @@ -95,6 +96,13 @@ class OMTrainingHandler

// Reset and deallocate all internal states
void reset();

#ifdef OM_MEMORY_ESTIMATE

// Reset and deallocate all states
void reset(core::OMRuntimeContext &context, core::OMRuntimeStorage &storage);

#endif // OM_MEMORY_ESTIMATE
};

} // namespace train
Expand Down
20 changes: 20 additions & 0 deletions onert-micro/onert-micro/include/core/train/OMTrainingStorage.h
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,9 @@ class OMTrainingStorage
// Note: initial its null
std::unique_ptr<onert_micro::train::optimizers::Adam> _adam_optimizer = nullptr;

// Store rank types
std::unordered_map<uint16_t, core::OpTrainableRankType> _tensor_index_to_train_rank;

public:
OMTrainingStorage() = default;
OMTrainingStorage(const OMTrainingStorage &) = delete;
Expand All @@ -70,6 +73,16 @@ class OMTrainingStorage
_target_index_to_target_data[target_index] = data;
}

void addTrainRank(uint16_t tensor_index, core::OpTrainableRankType train_rank)
{
_tensor_index_to_train_rank[tensor_index] = train_rank;
}

std::unordered_map<uint16_t, core::OpTrainableRankType> &getTensorIndexToRankTypeTable()
{
return _tensor_index_to_train_rank;
}

// Choose and set optimizer defined in config
OMStatus setOptimizer(const OMConfig &config);

Expand All @@ -88,6 +101,13 @@ class OMTrainingStorage
return _target_index_to_target_data[target_index];
}

#ifdef OM_MEMORY_ESTIMATE

// Reset and deallocate all states
void reset(core::OMRuntimeContext &context, core::OMRuntimeStorage &storage);

#endif // OM_MEMORY_ESTIMATE

// Reset and deallocate all states
void reset();
};
Expand Down
Loading

0 comments on commit 8237e03

Please sign in to comment.