Skip to content

Commit

Permalink
[onert/train] Register LayerScopeTensor to registry (#14235)
Browse files Browse the repository at this point in the history
This PR registers LayerScopeTensor from each layer into tensor registry.

ONE-DCO-1.0-Signed-off-by: seunghui youn <[email protected]>

--------------------------------------

draft : #13486
  • Loading branch information
zetwhite authored Oct 21, 2024
1 parent 760c5d5 commit b480c56
Showing 1 changed file with 37 additions and 1 deletion.
38 changes: 37 additions & 1 deletion runtime/onert/backend/train/BackendContext.cc
Original file line number Diff line number Diff line change
Expand Up @@ -262,7 +262,43 @@ FunctionMap BackendContext::generateFunctionMap()

void BackendContext::planLayerScopeTensors([[maybe_unused]] const FunctionMap &fn_map)
{
// TODO: Register LayerScopeTensors
const auto &ops = trainable_graph()->operations();

auto register_tensors = [this](const ir::OperationIndex &op_idx,
std::optional<LayerScopeTensors> &&tensors) {
if (not tensors.has_value())
return;

auto ls_tensors = tensors.value();
for (auto i = 0u; i < ls_tensors.size(); ++i)
{
LayerScopeTensorIndex tensor_idx(op_idx, i);
_tensor_builder->registerLayerScopeTensor(tensor_idx, ls_tensors[i]);

VERBOSE(BackendContext) << "(idx:" << tensor_idx << ") registered" << std::endl;
}
return;
};

for (auto &pair : fn_map)
{
const auto &op_idx = pair.first;
auto &fn_seq = pair.second;

const ir::IOperation *op = &ops.at(op_idx);
const auto trainable_op = dynamic_cast<const ir::train::TrainableOperation *>(op);
assert(trainable_op != nullptr);

if (not trainable_op->isRequiredForBackward())
continue;

VERBOSE(BackendContext) << "register layerscope tensor for " << trainable_op->name()
<< std::endl;

fn_seq->iterate([&](exec::train::ITrainableFunction &fn) {
register_tensors(op_idx, (&fn)->registerLayerScopeTensors());
});
}

const auto ctx_data = data();
TensorPlanner tensor_planner{*ctx_data->tgraph.get(), ctx_data->external_operands};
Expand Down

0 comments on commit b480c56

Please sign in to comment.