Skip to content

Commit

Permalink
enhance: Use template to remove unittest duplication (#39144)
Browse files Browse the repository at this point in the history
Issue: #38666

Signed-off-by: Cai Yudong <[email protected]>
  • Loading branch information
cydrain authored Jan 13, 2025
1 parent 032292a commit 2a02bbe
Show file tree
Hide file tree
Showing 7 changed files with 291 additions and 1,151 deletions.
76 changes: 70 additions & 6 deletions internal/core/src/common/VectorTrait.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,43 +15,107 @@
// limitations under the License.

#pragma once
#include "Types.h"

#include <string>
#include <type_traits>

#include "Array.h"
#include "Types.h"
#include "common/type_c.h"
#include "pb/common.pb.h"
#include "pb/plan.pb.h"
#include "pb/schema.pb.h"

namespace milvus {

#define GET_ELEM_TYPE_FOR_VECTOR_TRAIT \
using elem_type = std::conditional_t< \
std::is_same_v<TraitType, milvus::BinaryVector>, \
BinaryVector::embedded_type, \
std::conditional_t< \
std::is_same_v<TraitType, milvus::Float16Vector>, \
Float16Vector::embedded_type, \
std::conditional_t< \
std::is_same_v<TraitType, milvus::BFloat16Vector>, \
BFloat16Vector::embedded_type, \
FloatVector::embedded_type>>>;

#define GET_SCHEMA_DATA_TYPE_FOR_VECTOR_TRAIT \
auto schema_data_type = \
std::is_same_v<TraitType, milvus::FloatVector> \
? FloatVector::schema_data_type \
: std::is_same_v<TraitType, milvus::Float16Vector> \
? Float16Vector::schema_data_type \
: std::is_same_v<TraitType, milvus::BFloat16Vector> \
? BFloat16Vector::schema_data_type \
: BinaryVector::schema_data_type;

class VectorTrait {};

class FloatVector : public VectorTrait {
public:
using embedded_type = float;
static constexpr auto metric_type = DataType::VECTOR_FLOAT;
static constexpr int32_t dim_factor = 1;
static constexpr auto data_type = DataType::VECTOR_FLOAT;
static constexpr auto c_data_type = CDataType::FloatVector;
static constexpr auto schema_data_type =
proto::schema::DataType::FloatVector;
static constexpr auto vector_type = proto::plan::VectorType::FloatVector;
static constexpr auto placeholder_type =
proto::common::PlaceholderType::FloatVector;
};

class BinaryVector : public VectorTrait {
public:
using embedded_type = uint8_t;
static constexpr auto metric_type = DataType::VECTOR_BINARY;
static constexpr int32_t dim_factor = 8;
static constexpr auto data_type = DataType::VECTOR_BINARY;
static constexpr auto c_data_type = CDataType::BinaryVector;
static constexpr auto schema_data_type =
proto::schema::DataType::BinaryVector;
static constexpr auto vector_type = proto::plan::VectorType::BinaryVector;
static constexpr auto placeholder_type =
proto::common::PlaceholderType::BinaryVector;
};

class Float16Vector : public VectorTrait {
public:
using embedded_type = float16;
static constexpr auto metric_type = DataType::VECTOR_FLOAT16;
static constexpr int32_t dim_factor = 1;
static constexpr auto data_type = DataType::VECTOR_FLOAT16;
static constexpr auto c_data_type = CDataType::Float16Vector;
static constexpr auto schema_data_type =
proto::schema::DataType::Float16Vector;
static constexpr auto vector_type = proto::plan::VectorType::Float16Vector;
static constexpr auto placeholder_type =
proto::common::PlaceholderType::Float16Vector;
};

class BFloat16Vector : public VectorTrait {
public:
using embedded_type = bfloat16;
static constexpr auto metric_type = DataType::VECTOR_BFLOAT16;
static constexpr int32_t dim_factor = 1;
static constexpr auto data_type = DataType::VECTOR_BFLOAT16;
static constexpr auto c_data_type = CDataType::BFloat16Vector;
static constexpr auto schema_data_type =
proto::schema::DataType::BFloat16Vector;
static constexpr auto vector_type = proto::plan::VectorType::BFloat16Vector;
static constexpr auto placeholder_type =
proto::common::PlaceholderType::BFloat16Vector;
};

class SparseFloatVector : public VectorTrait {
public:
using embedded_type = float;
static constexpr auto metric_type = DataType::VECTOR_SPARSE_FLOAT;
static constexpr int32_t dim_factor = 1;
static constexpr auto data_type = DataType::VECTOR_SPARSE_FLOAT;
static constexpr auto c_data_type = CDataType::SparseFloatVector;
static constexpr auto schema_data_type =
proto::schema::DataType::SparseFloatVector;
static constexpr auto vector_type =
proto::plan::VectorType::SparseFloatVector;
static constexpr auto placeholder_type =
proto::common::PlaceholderType::SparseFloatVector;
};

template <typename T>
Expand Down
Loading

0 comments on commit 2a02bbe

Please sign in to comment.