Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

enhance: Use template to remove unittest duplication #39144

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 70 additions & 6 deletions internal/core/src/common/VectorTrait.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,43 +15,107 @@
// limitations under the License.

#pragma once
#include "Types.h"

#include <string>
#include <type_traits>

#include "Array.h"
#include "Types.h"
#include "common/type_c.h"
#include "pb/common.pb.h"
#include "pb/plan.pb.h"
#include "pb/schema.pb.h"

namespace milvus {

#define GET_ELEM_TYPE_FOR_VECTOR_TRAIT \
using elem_type = std::conditional_t< \
std::is_same_v<TraitType, milvus::BinaryVector>, \
BinaryVector::embedded_type, \
std::conditional_t< \
std::is_same_v<TraitType, milvus::Float16Vector>, \
Float16Vector::embedded_type, \
std::conditional_t< \
std::is_same_v<TraitType, milvus::BFloat16Vector>, \
BFloat16Vector::embedded_type, \
FloatVector::embedded_type>>>;

#define GET_SCHEMA_DATA_TYPE_FOR_VECTOR_TRAIT \
auto schema_data_type = \
std::is_same_v<TraitType, milvus::FloatVector> \
? FloatVector::schema_data_type \
: std::is_same_v<TraitType, milvus::Float16Vector> \
? Float16Vector::schema_data_type \
: std::is_same_v<TraitType, milvus::BFloat16Vector> \
? BFloat16Vector::schema_data_type \
: BinaryVector::schema_data_type;

class VectorTrait {};

class FloatVector : public VectorTrait {
public:
using embedded_type = float;
static constexpr auto metric_type = DataType::VECTOR_FLOAT;
static constexpr int32_t dim_factor = 1;
static constexpr auto data_type = DataType::VECTOR_FLOAT;
static constexpr auto c_data_type = CDataType::FloatVector;
static constexpr auto schema_data_type =
proto::schema::DataType::FloatVector;
static constexpr auto vector_type = proto::plan::VectorType::FloatVector;
static constexpr auto placeholder_type =
proto::common::PlaceholderType::FloatVector;
};

class BinaryVector : public VectorTrait {
public:
using embedded_type = uint8_t;
static constexpr auto metric_type = DataType::VECTOR_BINARY;
static constexpr int32_t dim_factor = 8;
static constexpr auto data_type = DataType::VECTOR_BINARY;
static constexpr auto c_data_type = CDataType::BinaryVector;
static constexpr auto schema_data_type =
proto::schema::DataType::BinaryVector;
static constexpr auto vector_type = proto::plan::VectorType::BinaryVector;
static constexpr auto placeholder_type =
proto::common::PlaceholderType::BinaryVector;
};

class Float16Vector : public VectorTrait {
public:
using embedded_type = float16;
static constexpr auto metric_type = DataType::VECTOR_FLOAT16;
static constexpr int32_t dim_factor = 1;
static constexpr auto data_type = DataType::VECTOR_FLOAT16;
static constexpr auto c_data_type = CDataType::Float16Vector;
static constexpr auto schema_data_type =
proto::schema::DataType::Float16Vector;
static constexpr auto vector_type = proto::plan::VectorType::Float16Vector;
static constexpr auto placeholder_type =
proto::common::PlaceholderType::Float16Vector;
};

class BFloat16Vector : public VectorTrait {
public:
using embedded_type = bfloat16;
static constexpr auto metric_type = DataType::VECTOR_BFLOAT16;
static constexpr int32_t dim_factor = 1;
static constexpr auto data_type = DataType::VECTOR_BFLOAT16;
static constexpr auto c_data_type = CDataType::BFloat16Vector;
static constexpr auto schema_data_type =
proto::schema::DataType::BFloat16Vector;
static constexpr auto vector_type = proto::plan::VectorType::BFloat16Vector;
static constexpr auto placeholder_type =
proto::common::PlaceholderType::BFloat16Vector;
};

class SparseFloatVector : public VectorTrait {
public:
using embedded_type = float;
static constexpr auto metric_type = DataType::VECTOR_SPARSE_FLOAT;
static constexpr int32_t dim_factor = 1;
static constexpr auto data_type = DataType::VECTOR_SPARSE_FLOAT;
static constexpr auto c_data_type = CDataType::SparseFloatVector;
static constexpr auto schema_data_type =
proto::schema::DataType::SparseFloatVector;
static constexpr auto vector_type =
proto::plan::VectorType::SparseFloatVector;
static constexpr auto placeholder_type =
proto::common::PlaceholderType::SparseFloatVector;
};

template <typename T>
Expand Down
Loading
Loading