Skip to content

Commit

Permalink
[onert-micro] support weight-quantized (int8) FullyConnected kernel
Browse files Browse the repository at this point in the history
- FullyConnected kernel (input:FLOAT32 + weights:INT8)

ONE-DCO-1.0-Signed-off-by: Evgenii Maltsev [email protected]
  • Loading branch information
Torrero committed Oct 1, 2024
1 parent 5d8026f commit 96a6f80
Show file tree
Hide file tree
Showing 6 changed files with 202 additions and 15 deletions.
2 changes: 2 additions & 0 deletions onert-micro/onert-micro/include/core/OMKernelData.h
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,8 @@ struct FullyConnectedParams
int32_t weights_offset;
int32_t output_offset;
int32_t output_multiplier;
const float *weights_scales;
bool is_channel_wise_quant;
int output_shift;
// uint8_t, etc, activation params.
int32_t quantized_activation_min;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -76,11 +76,12 @@ OMStatus FullyConnected(const core::FullyConnectedParams &params, const InputTyp
return Ok;
}

template <>
OMStatus inline FullyConnected<float>(const core::FullyConnectedParams &params,

template <typename WeightType>
OMStatus inline FullyConnected(const core::FullyConnectedParams &params,
const float *input_data,
const core::OMRuntimeShape &filter_shape,
const float *filter_data, const float *bias_data,
const WeightType *filter_data, const float *bias_data,
const core::OMRuntimeShape &output_shape, float *output_data)
{
const float output_activation_min = params.float_activation_min;
Expand All @@ -93,12 +94,24 @@ OMStatus inline FullyConnected<float>(const core::FullyConnectedParams &params,

for (int b = 0; b < batches; ++b)
{
const float *weight_scale_ptr = params.weights_scales;
for (int out_c = 0; out_c < output_depth; ++out_c)
{
float total = 0.f;
for (int d = 0; d < accum_depth; ++d)
{
total += input_data[b * accum_depth + d] * filter_data[out_c * accum_depth + d];
auto input_value = input_data[b * accum_depth + d];
if (std::is_same<WeightType, float>::value)
{
total += input_value * filter_data[out_c * accum_depth + d];
}
else
{
const float filter_scale = *weight_scale_ptr;
const float filter_value =
static_cast<float>(filter_data[out_c * accum_depth + d]) * filter_scale;
total += input_value * filter_value;
}
}
float bias_value = 0.0f;
if (bias_data)
Expand All @@ -107,6 +120,12 @@ OMStatus inline FullyConnected<float>(const core::FullyConnectedParams &params,
}
output_data[out_c + output_depth * b] =
std::min(std::max(total + bias_value, output_activation_min), output_activation_max);

if (std::is_same<WeightType, int8_t>::value)
{
if (params.is_channel_wise_quant)
weight_scale_ptr++;
}
}
}
return Ok;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,82 @@ const std::vector<float> reference_output_data = {263.84323, 260.84323, 259.8432

} // namespace fully_connected_float

namespace fully_connected_float_weights_quantized_int8
{

/*
 * FullyConnected Kernel:
 * Input - float32
 * Weight - int8
 * Bias - float32
 * Out - float32
 *
 * Input(1, 4) Weight(4, 4) Bias(4)
 * \ | /
 * \ | /
 * FullyConnected
 * |
 * Output(1, 4)
 */

// Serialized circle (flatbuffer) model containing a single FullyConnected op
// with float32 input/bias/output and int8 weight data — the hybrid
// (weight-quantized) execution path. NOTE(review): the four little-endian
// floats near the end of the weight tensor section are presumably the
// per-channel quantization scales read by the kernel — confirm against the
// circle schema when regenerating this blob.
const unsigned char test_kernel_model_circle[] = {
0x20, 0x00, 0x00, 0x00, 0x43, 0x49, 0x52, 0x30, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x12, 0x00,
0x18, 0x00, 0x00, 0x00, 0x04, 0x00, 0x08, 0x00, 0x0c, 0x00, 0x10, 0x00, 0x00, 0x00, 0x14, 0x00,
0x12, 0x00, 0x00, 0x00, 0xd8, 0x00, 0x00, 0x00, 0xf4, 0x00, 0x00, 0x00, 0xc0, 0x00, 0x00, 0x00,
0x08, 0x00, 0x00, 0x00, 0x24, 0x00, 0x00, 0x00, 0x07, 0x00, 0x00, 0x00, 0xe0, 0x02, 0x00, 0x00,
0xc8, 0x02, 0x00, 0x00, 0x24, 0x02, 0x00, 0x00, 0xc0, 0x01, 0x00, 0x00, 0x90, 0x01, 0x00, 0x00,
0x74, 0x00, 0x00, 0x00, 0x30, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x44, 0x00, 0x00, 0x00,
0x04, 0x00, 0x00, 0x00, 0xcc, 0xff, 0xff, 0xff, 0x08, 0x00, 0x00, 0x00, 0x06, 0x00, 0x00, 0x00,
0x0c, 0x00, 0x00, 0x00, 0x4f, 0x4e, 0x45, 0x5f, 0x6f, 0x70, 0x5f, 0x74, 0x61, 0x62, 0x6c, 0x65,
0x00, 0x00, 0x00, 0x00, 0x22, 0xfe, 0xff, 0xff, 0x04, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x08, 0x00, 0x0c, 0x00, 0x04, 0x00, 0x08, 0x00,
0x08, 0x00, 0x00, 0x00, 0x08, 0x00, 0x00, 0x00, 0x05, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00,
0x4f, 0x4e, 0x45, 0x5f, 0x73, 0x6f, 0x75, 0x72, 0x63, 0x65, 0x5f, 0x74, 0x61, 0x62, 0x6c, 0x65,
0x00, 0x00, 0x00, 0x00, 0x62, 0xfe, 0xff, 0xff, 0x04, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00,
0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x6f, 0x75, 0x74, 0x00,
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x09, 0x00, 0x00, 0x00,
0x6e, 0x6e, 0x70, 0x61, 0x63, 0x6b, 0x61, 0x67, 0x65, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00,
0x10, 0x00, 0x00, 0x00, 0x0c, 0x00, 0x0c, 0x00, 0x07, 0x00, 0x00, 0x00, 0x00, 0x00, 0x08, 0x00,
0x0c, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x09, 0x09, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00,
0x14, 0x00, 0x00, 0x00, 0x00, 0x00, 0x0e, 0x00, 0x18, 0x00, 0x04, 0x00, 0x08, 0x00, 0x0c, 0x00,
0x10, 0x00, 0x14, 0x00, 0x0e, 0x00, 0x00, 0x00, 0x38, 0x00, 0x00, 0x00, 0x2c, 0x00, 0x00, 0x00,
0x20, 0x00, 0x00, 0x00, 0x14, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00,
0x6d, 0x61, 0x69, 0x6e, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x38, 0x00, 0x00, 0x00,
0x01, 0x00, 0x00, 0x00, 0x03, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
0x04, 0x00, 0x00, 0x00, 0x7c, 0x01, 0x00, 0x00, 0xc4, 0x00, 0x00, 0x00, 0x70, 0x00, 0x00, 0x00,
0x44, 0x00, 0x00, 0x00, 0x00, 0x00, 0x0e, 0x00, 0x14, 0x00, 0x00, 0x00, 0x08, 0x00, 0x0c, 0x00,
0x07, 0x00, 0x10, 0x00, 0x0e, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x08, 0x18, 0x00, 0x00, 0x00,
0x0c, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x90, 0xfe, 0xff, 0xff, 0x01, 0x00, 0x00, 0x00,
0x03, 0x00, 0x00, 0x00, 0x03, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00,
0x02, 0x00, 0x00, 0x00, 0xe0, 0xfe, 0xff, 0xff, 0x18, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00,
0x04, 0x00, 0x00, 0x00, 0x03, 0x00, 0x00, 0x00, 0x6f, 0x75, 0x74, 0x00, 0xc4, 0xfe, 0xff, 0xff,
0x02, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x08, 0xff, 0xff, 0xff,
0x34, 0x00, 0x00, 0x00, 0x03, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00,
0x62, 0x69, 0x61, 0x73, 0x00, 0x00, 0x00, 0x00, 0xa6, 0xff, 0xff, 0xff, 0x04, 0x00, 0x00, 0x00,
0x10, 0x00, 0x00, 0x00, 0x00, 0x00, 0x80, 0x3f, 0x00, 0x00, 0x00, 0xc0, 0x00, 0x00, 0x40, 0xc0,
0x00, 0x00, 0x80, 0x40, 0x01, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x00, 0x00, 0x0e, 0x00,
0x18, 0x00, 0x08, 0x00, 0x07, 0x00, 0x0c, 0x00, 0x10, 0x00, 0x14, 0x00, 0x0e, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x09, 0x94, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x08, 0x00, 0x00, 0x00,
0x40, 0x00, 0x00, 0x00, 0x06, 0x00, 0x00, 0x00, 0x77, 0x65, 0x69, 0x67, 0x68, 0x74, 0x00, 0x00,
0x00, 0x00, 0x06, 0x00, 0x08, 0x00, 0x04, 0x00, 0x06, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00,
0x10, 0x00, 0x00, 0x00, 0x2a, 0x00, 0x55, 0x7f, 0x00, 0x7f, 0x00, 0x00, 0x00, 0x00, 0x7f, 0x00,
0x00, 0x00, 0x00, 0x7f, 0x0c, 0x00, 0x0c, 0x00, 0x00, 0x00, 0x00, 0x00, 0x04, 0x00, 0x08, 0x00,
0x0c, 0x00, 0x00, 0x00, 0x30, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x06, 0x83, 0xc1, 0x3c, 0x04, 0x02, 0x01, 0x3d,
0x85, 0x42, 0x21, 0x3d, 0x06, 0x83, 0x41, 0x3d, 0x02, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00,
0x04, 0x00, 0x00, 0x00, 0x0c, 0x00, 0x10, 0x00, 0x04, 0x00, 0x00, 0x00, 0x08, 0x00, 0x0c, 0x00,
0x0c, 0x00, 0x00, 0x00, 0x18, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00,
0x02, 0x00, 0x00, 0x00, 0x69, 0x6e, 0x00, 0x00, 0xf0, 0xff, 0xff, 0xff, 0x02, 0x00, 0x00, 0x00,
0x01, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x04, 0x00, 0x04, 0x00, 0x04, 0x00, 0x00, 0x00};

// Single batch of four float32 activations fed to the graph's "in" tensor.
const std::vector<float> input_data = {17.491695, 15.660671, 4.7347794, -15.796822};

// Expected float32 outputs for input_data — presumably computed with the
// dequantized weights (int8 value * per-channel scale); verify against the
// reference kernel if the model blob above is ever regenerated.
const std::vector<float> reference_output_data = {-19.529659, 60.642685, 20.673897, -90.780930};

} // namespace fully_connected_float_weights_quantized_int8

class TestDataFloatFullyConnected : public TestDataFullyConnectedBase<float>
{
public:
Expand All @@ -109,6 +185,20 @@ class TestDataFloatFullyConnected : public TestDataFullyConnectedBase<float>
~TestDataFloatFullyConnected() override = default;
};

// Test-data holder for the hybrid FullyConnected kernel: float32 activations
// paired with weight-quantized (int8) filter data. Wires the serialized model
// and the expected input/output vectors into the common float test-data base.
class TestDataFloatWQInt8FullyConnected : public TestDataFullyConnectedBase<float>
{
public:
  TestDataFloatWQInt8FullyConnected()
  {
    namespace src = fully_connected_float_weights_quantized_int8;
    _test_kernel_model_circle = src::test_kernel_model_circle;
    _input_data = src::input_data;
    _reference_output_data = src::reference_output_data;
  }

  ~TestDataFloatWQInt8FullyConnected() override = default;
};

} // namespace test_model
} // namespace onert_micro

Expand Down
34 changes: 29 additions & 5 deletions onert-micro/onert-micro/src/execute/kernels/FullyConnected.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -147,11 +147,35 @@ onert_micro::execute::execute_kernel_CircleFullyConnected(const OMExecuteArgs &e
if (status != Ok)
return status;

status =
pal::FullyConnected(params, core::utils::castInputData<float>(input_data),
OMRuntimeShape(weight), core::utils::castInputData<float>(weight_data),
core::utils::castInputData<float>(bias_data), OMRuntimeShape(output),
core::utils::castOutputData<float>(output_data));
switch (weight->type())
{
case circle::TensorType_FLOAT32:
{

status = pal::FullyConnected(
params, core::utils::castInputData<float>(input_data), OMRuntimeShape(weight),
core::utils::castInputData<float>(weight_data),
core::utils::castInputData<float>(bias_data), OMRuntimeShape(output),
core::utils::castOutputData<float>(output_data));
}
break;
case circle::TensorType_INT8:
{
// weight quantized INT8 mode
params.weights_scales =
reinterpret_cast<const float *>(weight->quantization()->scale()->data());
params.is_channel_wise_quant = weight->quantization()->scale()->size() > 1;

status = pal::FullyConnected(
params, core::utils::castInputData<float>(input_data), OMRuntimeShape(weight),
core::utils::castInputData<int8_t>(weight_data),
core::utils::castInputData<float>(bias_data), OMRuntimeShape(output),
core::utils::castOutputData<float>(output_data));
}
break;
default:
assert(false && "Unsupported hybrid weight type");
}
}
break;
#endif // DIS_FLOAT
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,15 @@ TEST_F(FullyConnectedTest, S16_P)
EXPECT_THAT(output_data_vector, test_data_kernel.get_output_data_by_index(0));
}

// Positive test for the hybrid FullyConnected path: float32 activations with
// int8-quantized weights must reproduce the precomputed reference output.
TEST_F(FullyConnectedTest, FloatWQInt8_P)
{
  using namespace onert_micro;
  test_model::TestDataFloatWQInt8FullyConnected kernel_data;
  const auto actual_output = execute::testing::checkKernel<float>(1, &kernel_data);
  EXPECT_THAT(actual_output, kernel_data.get_output_data_by_index(0));
}

TEST_F(FullyConnectedTest, Wrong_weight_shape_NEG)
{
onert_micro::test_model::NegTestDataWrongWeightShapeFullyConnectedKernel test_data_kernel;
Expand Down
55 changes: 49 additions & 6 deletions onert-micro/onert-micro/src/import/kernels/FullyConnected.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ constexpr uint32_t outputTensorIdx = 0;
OMStatus
onert_micro::import::configure_kernel_CircleFullyConnected(const OMConfigureArgs &config_args)
{

OMRuntimeContext &runtime_context = config_args.runtime_context;
uint16_t op_index = config_args.kernel_index;
OMRuntimeStorage &runtime_storage = config_args.runtime_storage;
Expand All @@ -50,7 +51,6 @@ onert_micro::import::configure_kernel_CircleFullyConnected(const OMConfigureArgs
const circle::Tensor *input = runtime_kernel.inputs[inputTensorIdx];
const circle::Tensor *weight = runtime_kernel.inputs[weightTensorIdx];
const circle::Tensor *bias = runtime_kernel.inputs[biasTensorIdx];

const circle::Tensor *output = runtime_kernel.outputs[outputTensorIdx];

assert(input != nullptr);
Expand All @@ -60,13 +60,56 @@ onert_micro::import::configure_kernel_CircleFullyConnected(const OMConfigureArgs

OMStatus status = Ok;

if ((input->type() == circle::TensorType_FLOAT32 &&
weight->type() != circle::TensorType_FLOAT32) or
(input->type() == circle::TensorType_INT8 && weight->type() != circle::TensorType_INT8) or
(input->type() == circle::TensorType_INT16 && weight->type() != circle::TensorType_INT16))
#ifndef DIS_FLOAT
if (weight->type() == circle::TensorType_FLOAT32)
{

status = utils::checkCondition(input->type() == circle::TensorType_FLOAT32 and
output->type() == circle::TensorType_FLOAT32 and
(!bias or bias->type() == circle::TensorType_FLOAT32));
if (status != Ok)
return status;
}
#endif // DIS_FLOAT
#ifndef DIS_QUANT
if (weight->type() == circle::TensorType_UINT8)
{

status = utils::checkCondition(input->type() == circle::TensorType_UINT8 and
output->type() == circle::TensorType_UINT8 and
(!bias or bias->type() == circle::TensorType_INT32));
if (status != Ok)
return status;
}
else if (weight->type() == circle::TensorType_INT8)
{
return UnsupportedType;
status = utils::checkCondition(input->type() == circle::TensorType_INT8 or
input->type() == circle::TensorType_FLOAT32);
if (status != Ok)
return status;

status = utils::checkCondition(output->type() == circle::TensorType_INT8 or
output->type() == circle::TensorType_FLOAT32);
if (status != Ok)
return status;

status = utils::checkCondition(!bias or bias->type() == circle::TensorType_INT32 or
bias->type() == circle::TensorType_INT64 or
bias->type() == circle::TensorType_FLOAT32);
if (status != Ok)
return status;

if (input->type() == circle::TensorType_FLOAT32)
{
// hybrid mode
// Check it is channel wise quantization
status = utils::checkCondition(weight->quantization() != nullptr and
weight->quantization()->scale() != nullptr);
if (status != Ok)
return status;
}
}
#endif // DIS_QUANT

core::OMRuntimeShape weight_shape(weight);
core::OMRuntimeShape bias_shape(bias);
Expand Down

0 comments on commit 96a6f80

Please sign in to comment.