[onert-micro] Support training configure tool #13592

Merged
5 changes: 5 additions & 0 deletions onert-micro/onert-micro/CMakeLists.txt
@@ -40,6 +40,11 @@ if (DIS_FLOAT)
add_definitions(-DDIS_FLOAT)
endif()

# To enable memory estimate
if (OM_MEMORY_ESTIMATE)
add_definitions(-DOM_MEMORY_ESTIMATE)
endif()
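
Both switches are plain compile definitions, so a configure-time invocation along these lines (build directory and source path assumed) enables the estimation path alongside training:

cmake -DOM_MEMORY_ESTIMATE=ON -DENABLE_TRAINING=ON <path-to-onert-micro>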

# To enable training part
if (ENABLE_TRAINING)
add_definitions(-DENABLE_TRAINING)
5 changes: 5 additions & 0 deletions onert-micro/onert-micro/include/OMConfig.h
@@ -42,6 +42,7 @@ enum OMMetrics
CROSS_ENTROPY_METRICS,
ACCURACY,
SPARSE_CROSS_ENTROPY_ACCURACY,
NONE,
};

/*
@@ -67,6 +68,8 @@ enum OMLoss
* beta_squares - used by ADAM optimizer
* epsilon - used by ADAM optimizer
* num_step - used by ADAM optimizer
* training_config_info_data - pointer to the training config data, which stores a
* training-specific scenario (default: null)
*/
struct OMTrainingContext
{
@@ -81,6 +84,8 @@ struct OMTrainingContext
uint32_t num_step = 0;
uint32_t num_epoch = 0;
uint32_t epochs = 0;

char *training_config_info_data = nullptr;
};
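
As a usage sketch for the new field: the caller loads the raw config file and hands the buffer to the training context. The `training_context` member name on OMConfig is an assumption here, since the enclosing struct is not shown in this diff:

#include <fstream>
#include <iterator>
#include <vector>

// Sketch: load a training config file and attach it to the training context.
std::vector<char> loadTrainConfig(const char *path)
{
  std::ifstream file(path, std::ios::binary);
  return std::vector<char>(std::istreambuf_iterator<char>(file),
                           std::istreambuf_iterator<char>());
}

void attachTrainConfig(onert_micro::OMConfig &config, std::vector<char> &raw)
{
  // The buffer must outlive training; ownership stays with the caller.
  config.training_context.training_config_info_data = raw.data(); // member name assumed
}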

/*
5 changes: 5 additions & 0 deletions onert-micro/onert-micro/include/OMTrainingInterpreter.h
@@ -92,6 +92,11 @@ class OMTrainingInterpreter
void *getInputData(uint32_t position);
void *getInputDataAt(uint32_t position);
void *getOutputDataAt(uint32_t position);

#ifdef OM_MEMORY_ESTIMATE
size_t getPeakFootprintMemory() { return _training_runtime_module.getPeakFootprintMemory(); }
size_t getCurrentFootprintMemory() { return _training_runtime_module.getCurrentFootprintMemory(); }
#endif // OM_MEMORY_ESTIMATE
};

} // namespace onert_micro
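
When the build defines OM_MEMORY_ESTIMATE, a training loop can sample both counters between steps; a minimal sketch (model import and training calls elided):

#include <cstddef>
#include <cstdio>

// Sketch: report the interpreter's memory footprint (setup elided).
void reportFootprint(onert_micro::OMTrainingInterpreter &interpreter)
{
#ifdef OM_MEMORY_ESTIMATE
  const size_t peak = interpreter.getPeakFootprintMemory();   // high-water mark, bytes
  const size_t cur = interpreter.getCurrentFootprintMemory(); // live allocations, bytes
  std::printf("peak: %zu bytes, current: %zu bytes\n", peak, cur);
#endif // OM_MEMORY_ESTIMATE
}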
10 changes: 10 additions & 0 deletions onert-micro/onert-micro/include/core/OMKernelType.h
@@ -25,6 +25,16 @@ namespace onert_micro
namespace core
{

// Enum to indicate the degree (rank) of training, i.e. which part of the operation's weights is trained
enum OpTrainableRankType
{
ALL = 0, // 0 - Train all weights in the operation
ONLY_BIAS = 1, // 1 - Train bias only in the operation
UP_1_2_PART = 2, // 2 - Train the upper 1/2 part of the operation
LOWER_1_2_PART = 3, // 3 - Train the lower 1/2 part of the operation
// TODO add more
};
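
These are the per-op values carried by the training config file; a table like the following is what getTrainableOpsIndexes() (introduced below) returns, with the op indexes here purely illustrative:

#include <cstdint>
#include <unordered_map>

// Illustrative only: graph op index -> train rank, encoded as uint8_t.
std::unordered_map<uint16_t, uint8_t> trainable_ops = {
  {0, onert_micro::core::ALL},       // op 0: train all of its weights
  {4, onert_micro::core::ONLY_BIAS}, // op 4: train only its bias
};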

enum OMKernelType
{
Normal,
27 changes: 24 additions & 3 deletions onert-micro/onert-micro/include/core/OMRuntimeContext.h
@@ -19,8 +19,12 @@

#include "OMStatus.h"

#include "core/OMRuntimeShape.h"
#include "core/OMRuntimeStorage.h"

#include "reader/OMCircleReader.h"
#include "reader/OMWeightOnlyFormatReader.h"
#include "reader/OMTrainingConfigFileReader.h"

#include <cstdint>

@@ -32,8 +36,9 @@
class OMRuntimeContext
{
private:
reader::OMCircleReader _reader;
reader::OMWeightOnlyFormatReader _wof_reader;
reader::OMCircleReader _reader{};
reader::OMWeightOnlyFormatReader _wof_reader{};
reader::OMTrainingConfigReader _train_config_reader{};

public:
OMRuntimeContext() = default;
@@ -66,7 +71,23 @@
return Ok;
}

const bool isConstTensor(uint32_t tensor_index) { return _reader.isConstTensor(tensor_index); }
OMStatus setTrainConfigFile(char *train_config_file_ptr)
{
OMStatus status = Ok;
_train_config_reader.parse(train_config_file_ptr);

status = _train_config_reader.validate(&_reader);
if (status != Ok)
return status;
return Ok;
}

std::unordered_map<uint16_t, uint8_t> getTrainableOpsIndexes()
{
return _train_config_reader.getTrainableOpsIndexes();
}

bool isConstTensor(uint32_t tensor_index) { return _reader.isConstTensor(tensor_index); }

const reader::CircleValues *getCircleOutputs() { return _reader.outputs(); }
const reader::CircleValues *getCircleInputs() { return _reader.inputs(); }
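
A sketch of how the two new entry points combine on the configure path (buffer loading and model import are assumed to have happened already):

// Sketch: validate the config buffer against the model, then query trainable ops.
onert_micro::OMStatus wireTrainConfig(onert_micro::core::OMRuntimeContext &ctx,
                                      char *config_data)
{
  onert_micro::OMStatus status = ctx.setTrainConfigFile(config_data);
  if (status != onert_micro::Ok)
    return status; // magic number, schema version, or op indexes failed validation
  // Graph op index -> train rank (see OpTrainableRankType).
  auto trainable = ctx.getTrainableOpsIndexes();
  (void)trainable; // consumed by the training configure logic
  return onert_micro::Ok;
}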
12 changes: 11 additions & 1 deletion onert-micro/onert-micro/include/core/OMRuntimeShape.h
@@ -44,9 +44,17 @@ class OMRuntimeShape

OMRuntimeShape(const circle::Tensor *tensor)
{
if (tensor == nullptr or tensor->shape() == nullptr)
if (tensor == nullptr)
return;

// Shape is scalar
if (tensor->shape() == nullptr or tensor->shape()->size() == 0)
{
_size = 1;
_dims[0] = 1;
return;
}

_size = tensor->shape()->size();
std::memcpy(_dims, tensor->shape()->data(), sizeof(int32_t) * _size);
}
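
With this change, a tensor whose shape table is missing or empty reads as a scalar holding one element instead of an empty shape; a small sketch (scalar_tensor is a hypothetical circle::Tensor pointer):

#include <cassert>

// scalar_tensor: hypothetical circle::Tensor* whose shape() is null or empty.
onert_micro::core::OMRuntimeShape shape(scalar_tensor);
assert(shape.flatSize() == 1); // scalar: exactly one element
// Passing a null tensor still yields an empty shape, and flatSize() returns 0
// for that case thanks to the guard added below.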
@@ -55,6 +63,8 @@
// vector.
inline int flatSize() const
{
if (_size == 0)
return 0;
int buffer_size = 1;
const int *dims_data = reinterpret_cast<const int *>(dimsData());
for (int i = 0; i < _size; i++)
10 changes: 5 additions & 5 deletions onert-micro/onert-micro/include/core/OMRuntimeStorage.h
@@ -34,7 +34,7 @@ class OMRuntimeStorage
{
private:
#ifndef DIS_DYN_SHAPES
std::unordered_map<uint16_t, uint32_t> _tensor_index_to_dynamic_tensor_size;
std::unordered_map<uint16_t, OMRuntimeShape> _tensor_index_to_dynamic_tensor_size;
#endif
std::unordered_map<uint16_t, uint8_t *> _tensor_index_to_data;
std::unordered_map<uint16_t, OMKernelType> _operator_index_to_kernel_type;
@@ -70,18 +70,18 @@
return Ok;
}
#ifndef DIS_DYN_SHAPES
int32_t getDynamicTensorSize(uint16_t tensor_index)
OMRuntimeShape getDynamicRuntimeShape(uint16_t tensor_index)
{
auto it = _tensor_index_to_dynamic_tensor_size.find(tensor_index);
if (it == _tensor_index_to_dynamic_tensor_size.end())
return -1;
return {}; // Return empty

return it->second;
}

OMStatus setDynamicTensorSize(uint16_t tensor_index, uint32_t dynamic_size)
OMStatus setDynamicRuntimeShape(uint16_t tensor_index, const OMRuntimeShape &shape)
{
_tensor_index_to_dynamic_tensor_size[tensor_index] = dynamic_size;
_tensor_index_to_dynamic_tensor_size[tensor_index] = shape;
return Ok;
}
#endif // DIS_DYN_SHAPES
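
Since the map now stores whole shapes, a default-constructed OMRuntimeShape doubles as the "nothing recorded" sentinel, and the flatSize() == 0 guard added above makes that test well-defined; a sketch (resized_shape is an assumed, already-computed shape):

// Sketch: record a dynamically resized shape, then query it back.
void trackDynamicShape(onert_micro::core::OMRuntimeStorage &storage, uint16_t tensor_index,
                       const onert_micro::core::OMRuntimeShape &resized_shape)
{
  storage.setDynamicRuntimeShape(tensor_index, resized_shape);

  auto shape = storage.getDynamicRuntimeShape(tensor_index);
  if (shape.flatSize() == 0)
  {
    // Empty shape returned: no dynamic shape recorded for this tensor index.
  }
}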
onert-micro/onert-micro/include/core/OMTrainingRuntimeModule.h
@@ -92,6 +92,11 @@ class OMTrainingRuntimeModule : public OMRuntimeModule
OMStatus loadCheckpointData(OMConfig &config, const char *data);

void *getInputData(int32_t index);

#ifdef OM_MEMORY_ESTIMATE
size_t getPeakFootprintMemory();
size_t getCurrentFootprintMemory();
#endif // OM_MEMORY_ESTIMATE
};

} // namespace core
7 changes: 7 additions & 0 deletions onert-micro/onert-micro/include/core/memory/OMMemoryManager.h
@@ -20,6 +20,7 @@
#include "OMStatus.h"

#include <cstdint>
#include <cstdlib>

namespace onert_micro
{
@@ -30,6 +31,12 @@ namespace memory

struct OMMemoryManager
{
// Needed for configure-tool estimations
#ifdef OM_MEMORY_ESTIMATE
static size_t peak_memory_allocated;
static size_t cur_memory_allocated;
static OMStatus deallocateMemory(uint32_t size, uint8_t *data);
#endif // OM_MEMORY_ESTIMATE
static OMStatus allocateMemory(uint32_t size, uint8_t **data);
static OMStatus deallocateMemory(uint8_t *data);
};
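
The counter definitions and the size-aware deallocate live in the corresponding .cpp, which this diff omits; the following is a hypothetical sketch of the bookkeeping they imply (the include path and the UnknownError status name are assumptions):

#include "core/memory/OMMemoryManager.h" // include path assumed

#include <algorithm> // std::max
#include <cstdlib>   // std::malloc, std::free

namespace onert_micro
{
namespace core
{
namespace memory
{

// Hypothetical sketch only; the real definitions are in OMMemoryManager.cpp.
#ifdef OM_MEMORY_ESTIMATE
size_t OMMemoryManager::peak_memory_allocated = 0;
size_t OMMemoryManager::cur_memory_allocated = 0;
#endif // OM_MEMORY_ESTIMATE

OMStatus OMMemoryManager::allocateMemory(uint32_t size, uint8_t **data)
{
  *data = static_cast<uint8_t *>(std::malloc(size));
  if (*data == nullptr)
    return UnknownError; // status name assumed
#ifdef OM_MEMORY_ESTIMATE
  cur_memory_allocated += size;
  peak_memory_allocated = std::max(peak_memory_allocated, cur_memory_allocated);
#endif // OM_MEMORY_ESTIMATE
  return Ok;
}

#ifdef OM_MEMORY_ESTIMATE
// Size-aware variant: lets the counter shrink when buffers are freed.
OMStatus OMMemoryManager::deallocateMemory(uint32_t size, uint8_t *data)
{
  std::free(data);
  cur_memory_allocated -= size; // caller passes the size used at allocation
  return Ok;
}
#endif // OM_MEMORY_ESTIMATE

} // namespace memory
} // namespace core
} // namespace onert_micro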
16 changes: 4 additions & 12 deletions onert-micro/onert-micro/include/core/memory/OMRuntimeAllocator.h
@@ -45,18 +45,6 @@
OMRuntimeAllocator(OMRuntimeAllocator &&) = default;
~OMRuntimeAllocator() = default;

void saveAllocPlan(std::vector<std::vector<uint16_t>> &&alloc_plan)
{
_alloc_plan.clear();
_alloc_plan = std::move(alloc_plan);
}

void saveDeallocPlan(std::vector<std::vector<uint16_t>> &&dealloc_plan)
{
_dealloc_plan.clear();
_dealloc_plan = std::move(dealloc_plan);
}

std::vector<std::vector<uint16_t>> &getAllocPlan() { return _alloc_plan; }

std::vector<std::vector<uint16_t>> &getDeallocPlan() { return _dealloc_plan; }
@@ -67,6 +55,10 @@

OMStatus allocate(size_t kernel_index, OMRuntimeContext *context, OMRuntimeStorage *storage);
OMStatus deallocate(size_t kernel_index, OMRuntimeStorage *storage);
// Needed for configure-tool estimations
#ifdef OM_MEMORY_ESTIMATE
OMStatus deallocate(size_t kernel_index, OMRuntimeStorage *storage, OMRuntimeContext *context);
#endif // OM_MEMORY_ESTIMATE
};

} // namespace memory
onert-micro/onert-micro/include/core/reader/OMTrainingConfigFileReader.h
@@ -0,0 +1,79 @@
/*
* Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef ONERT_MICRO_CORE_READER_TRAINING_CONFIG_FILE_READER_H
#define ONERT_MICRO_CORE_READER_TRAINING_CONFIG_FILE_READER_H

#include "OMStatus.h"
#include "OMCircleReader.h"

#include <unordered_map>

namespace onert_micro
{
namespace core
{
namespace reader
{
namespace
{

enum TrainConfigFileFieldsOffsets
{
MAGIC_NUMBER_FIELD = 0,
SCHEMA_VERSION_FIELD = 2,
NUM_LAYERS_FIELD = 4,
FIRST_LAYER_INDEX_FIELD = 8
};

} // namespace

constexpr uint16_t train_config_file_magic_number = 29;
constexpr uint8_t train_config_file_schema_version = 1;

/**
* @brief Loads training config files and provides helper functions
*/
class OMTrainingConfigReader
{
public:
OMTrainingConfigReader() = default;
OMTrainingConfigReader(const OMTrainingConfigReader &) = delete;
OMTrainingConfigReader(OMTrainingConfigReader &&) = default;
OMTrainingConfigReader &operator=(const OMTrainingConfigReader &) = delete;
OMTrainingConfigReader &operator=(OMTrainingConfigReader &&) = delete;
~OMTrainingConfigReader() = default;

public:
// Validate _train_config_ptr against the circle model held by the reader.
OMStatus validate(OMCircleReader *reader);

// Save pointer
void parse(char *ptr) { _train_config_ptr = ptr; }

// Read and return indexes of trainable layers from the config file:
// the key is the op index in the graph, the value is its training rank (see OpTrainableRankType)
std::unordered_map<uint16_t, uint8_t> getTrainableOpsIndexes();

private:
char *_train_config_ptr = nullptr;
};

} // namespace reader
} // namespace core
} // namespace onert_micro

#endif // ONERT_MICRO_CORE_READER_TRAINING_CONFIG_FILE_READER_H
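
For reference, the offsets and constants above imply the following on-disk layout; the field widths are deduced from the gaps between offsets and should be treated as assumptions:

// Inferred training config file layout (field widths are assumptions):
//   offset 0 : uint16_t magic number, must equal 29
//   offset 2 : uint8_t  schema version, must equal 1 (followed by one spare byte)
//   offset 4 : uint32_t number of trainable layers
//   offset 8 : per-layer records, each pairing a graph op index with its
//              train rank (see OpTrainableRankType)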
10 changes: 9 additions & 1 deletion onert-micro/onert-micro/include/core/train/OMTrainingHandler.h
@@ -77,7 +77,8 @@ class OMTrainingHandler
OMRuntimeContext &context);

// Handle updating weights with the current optimizer
OMStatus updateWeights(const OMConfig &config, OMRuntimeContext &context);
OMStatus updateWeights(const OMConfig &config, OMRuntimeContext &context,
OMRuntimeStorage &storage);

// Evaluate metric and save result in metric_val
// Warning: 1) assume that all metric_val for all OMMetrics types actually are float values.
@@ -95,6 +96,13 @@

// Reset and deallocate all internal states
void reset();

#ifdef OM_MEMORY_ESTIMATE

// Reset and deallocate all states
void reset(core::OMRuntimeContext &context, core::OMRuntimeStorage &storage);

#endif // OM_MEMORY_ESTIMATE
};

} // namespace train
20 changes: 20 additions & 0 deletions onert-micro/onert-micro/include/core/train/OMTrainingStorage.h
@@ -51,6 +51,9 @@ class OMTrainingStorage
// Note: initially it is null
std::unique_ptr<onert_micro::train::optimizers::Adam> _adam_optimizer = nullptr;

// Store rank types
std::unordered_map<uint16_t, core::OpTrainableRankType> _tensor_index_to_train_rank;

public:
OMTrainingStorage() = default;
OMTrainingStorage(const OMTrainingStorage &) = delete;
@@ -70,6 +73,16 @@
_target_index_to_target_data[target_index] = data;
}

void addTrainRank(uint16_t tensor_index, core::OpTrainableRankType train_rank)
{
_tensor_index_to_train_rank[tensor_index] = train_rank;
}

std::unordered_map<uint16_t, core::OpTrainableRankType> &getTensorIndexToRankTypeTable()
{
return _tensor_index_to_train_rank;
}
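
For illustration, the configure path would cast each parsed rank (a uint8_t in the config, see getTrainableOpsIndexes()) back to the enum when filling this table; the namespace path follows the surrounding headers, and the index and value are hypothetical:

#include <cstdint>

// Illustrative: store one parsed (index, rank) entry into the training storage.
void storeParsedRank(onert_micro::core::train::OMTrainingStorage &storage)
{
  uint8_t raw_rank = 1;      // hypothetical value from getTrainableOpsIndexes()
  uint16_t tensor_index = 7; // hypothetical index
  storage.addTrainRank(tensor_index,
                       static_cast<onert_micro::core::OpTrainableRankType>(raw_rank));
}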

// Choose and set optimizer defined in config
OMStatus setOptimizer(const OMConfig &config);

@@ -88,6 +101,13 @@
return _target_index_to_target_data[target_index];
}

#ifdef OM_MEMORY_ESTIMATE

// Reset and deallocate all states
void reset(core::OMRuntimeContext &context, core::OMRuntimeStorage &storage);

#endif // OM_MEMORY_ESTIMATE

// Reset and deallocate all states
void reset();
};
3 changes: 3 additions & 0 deletions onert-micro/onert-micro/include/execute/OMTestUtils.h
@@ -47,6 +47,9 @@ std::vector<U> checkKernel(uint32_t num_inputs,

assert(num_inputs == interpreter.getNumberOfInputs());

interpreter.reset();
interpreter.allocateInputs();

for (uint32_t i = 0; i < num_inputs; ++i)
{
T *input_data = reinterpret_cast<T *>(interpreter.getInputDataAt(i));