Commit

Refactor device_type
RyanUnderhill committed Jan 31, 2025
1 parent e6b77f2 commit 4f2f084
Showing 20 changed files with 51 additions and 80 deletions.
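The unifying change: Model and GeneratorParams no longer carry a DeviceType member next to their DeviceInterface pointer; the interface itself now reports its type through a new virtual GetType(). A minimal, self-contained sketch of the pattern (names mirror the repository, but the bodies are reduced to the parts this commit touches and are illustrative only):

#include <cassert>

// Reduced sketch of the refactor applied across the files below.
enum struct DeviceType { CPU, CUDA, DML, WEBGPU, QNN, MAX };

struct DeviceInterface {
  virtual ~DeviceInterface() {}
  virtual DeviceType GetType() const = 0;  // new in this commit
};

struct CpuInterface : DeviceInterface {
  DeviceType GetType() const override { return DeviceType::CPU; }
};

// Before: callers stored `DeviceType device_type_` alongside `DeviceInterface* p_device_`
// and had to keep the two in sync. After: they ask the interface directly.
bool IsCuda(const DeviceInterface& device) {
  return device.GetType() == DeviceType::CUDA;
}

int main() {
  CpuInterface cpu;
  assert(!IsCuda(cpu));
  return 0;
}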
2 changes: 2 additions & 0 deletions src/cpu/interface.cpp
@@ -46,6 +46,8 @@ struct CpuInterface : DeviceInterface {
CpuInterface() {
}

+ DeviceType GetType() const override { return DeviceType::CPU; }

void InitOrt(const OrtApi& /*api*/, Ort::Allocator& allocator) override {
assert(!ort_allocator_);
ort_allocator_ = &allocator;
2 changes: 2 additions & 0 deletions src/cuda/interface.cpp
@@ -75,6 +75,8 @@ struct CudaInterfaceImpl final : DeviceInterface {
~CudaInterfaceImpl() {
}

+ DeviceType GetType() const override { return DeviceType::CUDA; }

void InitOrt(const OrtApi& api, Ort::Allocator& allocator) override {
Ort::api = &api;
assert(!ort_allocator_);
2 changes: 2 additions & 0 deletions src/dml/interface.cpp
@@ -115,6 +115,8 @@ struct InterfaceImpl : DeviceInterface {
Ort::ThrowOnError(Ort::api->GetExecutionProviderApi("DML", ORT_API_VERSION, reinterpret_cast<const void**>(&dml_api_)));
}

+ DeviceType GetType() const override { return DeviceType::DML; }

void InitOrt(const OrtApi& api, Ort::Allocator& allocator) override {
Ort::api = &api;
assert(!ort_allocator_);
11 changes: 5 additions & 6 deletions src/generators.cpp
@@ -203,8 +203,7 @@ GeneratorParams::GeneratorParams(const Config& config)

GeneratorParams::GeneratorParams(const Model& model)
: config{*model.config_.get()},
- p_device{model.p_device_},
- device_type{model.device_type_},
+ p_device{model.p_device_inputs_},
is_cuda_graph_enabled_{IsCudaGraphEnabled(model.config_->model.decoder.session_options)} {
use_cuda_graph = is_cuda_graph_enabled_;
if (use_cuda_graph) {
@@ -213,12 +212,12 @@ GeneratorParams::GeneratorParams(const Model& model)
}

void GeneratorParams::TryGraphCapture(int max_bs) {
- if (!is_cuda_graph_enabled_ || device_type == DeviceType::CPU) {
+ if (!is_cuda_graph_enabled_ || p_device->GetType() == DeviceType::CPU) {
// no-op
return;
}

- if (DeviceType::CUDA == device_type || DeviceType::DML == device_type) {
+ if (DeviceType::CUDA == p_device->GetType() || DeviceType::DML == p_device->GetType()) {
if (max_bs == 0) {
throw std::runtime_error("Graph capture is enabled, but max_batch_size is not set.");
}
@@ -323,8 +322,8 @@ void Generator::AppendTokens(cpu_span<const int32_t> input_ids) {
constexpr std::array<DeviceType, 3> devices_supporting_continuous_decoding{DeviceType::CPU, DeviceType::CUDA, DeviceType::WEBGPU};
if (search_->GetSequenceLength() != 0 &&
std::none_of(devices_supporting_continuous_decoding.begin(), devices_supporting_continuous_decoding.end(),
- [this](DeviceType device_type) { return device_type == state_->params_->device_type; }))
- throw std::runtime_error("Continuous decoding is not supported on the selected device type (" + to_string(state_->params_->device_type) +
+ [this](DeviceType device_type) { return device_type == state_->params_->p_device->GetType(); }))
+ throw std::runtime_error("Continuous decoding is not supported on the selected device type (" + to_string(state_->params_->p_device->GetType()) +
"). Please recreate the generator instance to avoid using continuous decoding.");

if (last_action_ == Action::generated) {
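With device_type gone from GeneratorParams, guards like the continuous-decoding check above query the scoring device through p_device->GetType(). A hedged sketch of that idiom, lifted out of Generator::AppendTokens and simplified (DeviceType, DeviceInterface, and to_string are the repository declarations; the free function itself is illustrative):

#include <algorithm>
#include <array>
#include <stdexcept>
#include <string>

// Illustrative only: the real check lives inside Generator::AppendTokens.
void ThrowIfContinuousDecodingUnsupported(const DeviceInterface& device, int sequence_length) {
  constexpr std::array<DeviceType, 3> supported{DeviceType::CPU, DeviceType::CUDA, DeviceType::WEBGPU};
  const bool device_ok = std::any_of(supported.begin(), supported.end(),
                                     [&](DeviceType d) { return d == device.GetType(); });
  if (sequence_length != 0 && !device_ok)
    throw std::runtime_error("Continuous decoding is not supported on the selected device type (" +
                             to_string(device.GetType()) + ").");
}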
12 changes: 1 addition & 11 deletions src/generators.h
@@ -64,15 +64,6 @@ struct OrtTensor {
// OgaSequences are a vector of int32 vectors
using TokenSequences = std::vector<std::vector<int32_t>>;

- enum struct DeviceType {
-   CPU,
-   CUDA,
-   DML,
-   WEBGPU,
-   QNN,
-   MAX
- };

std::string to_string(DeviceType device_type);
DeviceInterface* GetDeviceInterface(DeviceType type);

@@ -87,8 +78,7 @@ struct GeneratorParams : std::enable_shared_from_this<GeneratorParams>, LeakChec
bool use_cuda_graph{};
int BatchBeamSize() const { return search.num_beams * search.batch_size; }

- DeviceInterface* p_device{};
- DeviceType device_type{DeviceType::CPU};
+ DeviceInterface* p_device{}; // Scoring device (usually CPU, but can be CUDA)

cpu_span<int32_t> aux_input_ids{}; // Intermediate solution to be used with SetInputs function for multimodal and whisper models

2 changes: 1 addition & 1 deletion src/models/adapters.cpp
@@ -34,7 +34,7 @@ void Adapters::LoadAdapter(const char* adapter_file_path, const std::string& ada
}

adapters_.emplace(adapter_name, std::make_unique<Adapter>(adapter_file_path,
- model_->device_type_ == DeviceType::CUDA
+ model_->p_device_->GetType() == DeviceType::CUDA
? &model_->p_device_->GetAllocator()
: nullptr));
}
2 changes: 1 addition & 1 deletion src/models/captured_graph_pool.cpp
@@ -19,7 +19,7 @@ void CapturedGraphInfoRecycler::operator()(CapturedGraphInfo* captured_graph_inf
}

CapturedGraphInfoPtr CapturedGraphPool::ReserveCapturedGraph(const Model& model, const GeneratorParams& params) const {
- if (!params.use_cuda_graph || (model.device_type_ != DeviceType::CUDA)) {
+ if (!params.use_cuda_graph || (model.p_device_->GetType() != DeviceType::CUDA)) {
return nullptr;
}

10 changes: 5 additions & 5 deletions src/models/decoder_only_pipeline.cpp
@@ -58,9 +58,9 @@ bool IntermediatePipelineState::HasOutput(std::string_view name) const {
}

bool IntermediatePipelineState::SupportsPrimaryDevice() const {
- if (model_.device_type_ == DeviceType::CPU || model_.device_type_ == DeviceType::QNN) {
+ if (model_.p_device_->GetType() == DeviceType::CPU || model_.p_device_->GetType() == DeviceType::QNN) {
return true;
- } else if (model_.device_type_ == DeviceType::CUDA) {
+ } else if (model_.p_device_->GetType() == DeviceType::CUDA) {
if (!model_.config_->model.decoder.pipeline[id_].session_options.has_value()) {
// No session options, so this session uses the default session options.
// Default session options supports the cuda device type.
@@ -134,7 +134,7 @@ void DecoderOnlyPipelineState::RunPipeline(int total_length, DeviceSpan<int32_t>
if (!pipeline_state->SupportsPrimaryDevice()) {
std::ostringstream oss;
oss << "Managed input " << input_name << " resides on the primary device type ("
- << to_string(model_.device_type_) << "). "
+ << to_string(model_.p_device_->GetType()) << "). "
<< "But the pipeline model "
<< model_.config_->model.decoder.pipeline[pipeline_state->id_].model_id
<< " is expecting it to reside elsewhere.";
@@ -159,7 +159,7 @@ void DecoderOnlyPipelineState::RunPipeline(int total_length, DeviceSpan<int32_t>
if (!pipeline_state->SupportsPrimaryDevice()) {
std::ostringstream oss;
oss << "Managed output " << output_name << " resides on the primary device type ("
- << to_string(model_.device_type_) << "). "
+ << to_string(model_.p_device_->GetType()) << "). "
<< "But the pipeline model "
<< model_.config_->model.decoder.pipeline[pipeline_state->id_].model_id
<< " is expecting it to reside elsewhere.";
@@ -178,7 +178,7 @@ void DecoderOnlyPipelineState::RunPipeline(int total_length, DeviceSpan<int32_t>
if (!pipeline_state->SupportsPrimaryDevice()) {
std::ostringstream oss;
oss << "Managed input " << input_name << " resides on the primary device type ("
- << to_string(model_.device_type_) << "). "
+ << to_string(model_.p_device_->GetType()) << "). "
<< "But the pipeline model "
<< model_.config_->model.decoder.pipeline[pipeline_state->id_].model_id
<< " is expecting it to reside elsewhere.";
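The same pattern drives the pipeline's primary-device check: instead of comparing model_.device_type_, it asks model_.p_device_->GetType(). A reduced sketch of that decision (the real IntermediatePipelineState::SupportsPrimaryDevice inspects the per-pipeline session options before accepting CUDA; here that is collapsed into a boolean parameter):

// Reduced sketch; session-options handling is simplified to a flag.
bool SupportsPrimaryDevice(const DeviceInterface& device, bool uses_default_session_options) {
  const DeviceType type = device.GetType();
  if (type == DeviceType::CPU || type == DeviceType::QNN)
    return true;
  if (type == DeviceType::CUDA)
    return uses_default_session_options;  // default session options support the cuda device type
  return false;
}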
22 changes: 4 additions & 18 deletions src/models/logits.cpp
@@ -12,7 +12,7 @@ Logits::Logits(State& state)
type_{model_.session_info_->GetOutputDataType(model_.config_->model.decoder.outputs.logits)} {
output_raw_ = OrtValue::CreateTensor(model_.p_device_inputs_->GetAllocator(), shape_, type_);

- if (model_.device_type_ == DeviceType::CUDA && !model_.config_->model.eos_token_ids.empty()) {
+ if (model_.p_device_inputs_->GetType() == DeviceType::CUDA && !model_.config_->model.eos_token_ids.empty()) {
auto& cpu_ids = model_.config_->model.eos_token_ids;
cuda_eos_token_ids_ = model_.p_device_->Allocate<int32_t>(cpu_ids.size());
copy(std::span<const int32_t>{cpu_ids}, cuda_eos_token_ids_.CpuSpan());
@@ -70,9 +70,9 @@ DeviceSpan<float> Logits::Get() {
if (logits_.empty() || logits_of_last_token->GetTensorMutableRawData() != logits_.Span().data())
logits_ = WrapTensor<float>(*model_.p_device_inputs_, *logits_of_last_token);

- if (model_.device_type_ == DeviceType::CUDA) {
+ if (model_.p_device_inputs_->GetType() == DeviceType::CUDA) {
if (!cuda_eos_token_ids_.empty())
- model_.p_device_->LaunchHandleEOSArray(
+ model_.p_device_inputs_->LaunchHandleEOSArray(
logits_.Span().data(),
static_cast<int>(shape_[0]) /* batch_beam_size*/,
static_cast<int>(shape_[2]) /* vocab_size */,
@@ -107,21 +107,7 @@ void Logits::Update(const DeviceSpan<int32_t>& next_tokens, size_t new_kv_length
}

shape_[1] = new_kv_length;
- StaticBuffer* sb_logits = type_ == Ort::TypeToTensorType<Ort::Float16_t> ? sb_logits16_ : sb_logits32_;
- output_raw_ = !sb_logits ? OrtValue::CreateTensor(model_.p_device_inputs_->GetAllocator(), shape_, type_)
-                          : sb_logits->CreateTensorOnStaticBuffer(shape_, type_);
-
- if (state_.GetCapturedGraphInfo()) {
-   if (!sb_logits16_ && !sb_logits32_) {
-     if (type_ == Ort::TypeToTensorType<float>) {
-       sb_logits32_ = state_.GetCapturedGraphInfo()->sb_logits32_.get();
-     }
-     if (type_ == Ort::TypeToTensorType<Ort::Float16_t>) {
-       sb_logits16_ = state_.GetCapturedGraphInfo()->sb_logits16_.get();
-     }
-   }
- }
-
+ output_raw_ = OrtValue::CreateTensor(model_.p_device_inputs_->GetAllocator(), shape_, type_);
state_.outputs_[output_index_] = output_raw_.get();
}

4 changes: 0 additions & 4 deletions src/models/logits.h
@@ -39,10 +39,6 @@ struct Logits {
// OrtValue wrapped in a DeviceMemory object to make it universal
DeviceSpan<float> logits_;

- // Used for decoding runs with cuda graphs.
- StaticBuffer* sb_logits32_{};
- StaticBuffer* sb_logits16_{};

DeviceSpan<int32_t> cuda_eos_token_ids_; // eos_token_ids from params, but in cuda accessible memory
};

24 changes: 9 additions & 15 deletions src/models/model.cpp
@@ -301,10 +301,10 @@ Model::Model(std::unique_ptr<Config> config) : config_{std::move(config)} {
Model::~Model() = default;

void Model::InitDeviceAllocator(OrtSession& session) {
- EnsureDeviceOrtInit(session, device_type_);
+ EnsureDeviceOrtInit(session, p_device_->GetType());

// Only CUDA does every input on the device
- if (device_type_ == DeviceType::CUDA)
+ if (p_device_->GetType() == DeviceType::CUDA)
p_device_inputs_ = p_device_;
else
p_device_inputs_ = GetDeviceInterface(DeviceType::CPU);
@@ -413,8 +413,7 @@ void Model::CreateSessionOptionsFromConfig(const Config::SessionOptions& config_
// Device type determines the scoring device.
// Only use the primary session options to determine the device type
if (is_primary_session_options) {
- device_type_ = DeviceType::CUDA; // Scoring will use CUDA
- p_device_ = GetDeviceInterface(device_type_);
+ p_device_ = GetDeviceInterface(DeviceType::CUDA);

// Create and set our cudaStream_t
ort_provider_options->UpdateValue("user_compute_stream", p_device_->GetCudaStream());
@@ -451,15 +450,10 @@ void Model::CreateSessionOptionsFromConfig(const Config::SessionOptions& config_
InitDmlInterface(p_device_luid);
}

- if (!disable_graph_capture) {
-   session_options.AddConfigEntry("ep.dml.enable_graph_capture", "1");
-   session_options.AddConfigEntry("ep.dml.disable_memory_arena", "1");
- }

SetDmlProvider(session_options);

if (is_primary_session_options)
- device_type_ = DeviceType::DML; // We use a DML allocator for input/output caches, but other tensors will use CPU tensors
+ p_device_ = GetDeviceInterface(DeviceType::DML); // We use a DML allocator for input/output caches, but other tensors will use CPU tensors
#endif
} else if (provider_options.name == "qnn") {
session_options.AddConfigEntry("ep.share_ep_contexts", "1");
@@ -473,12 +467,12 @@ void Model::CreateSessionOptionsFromConfig(const Config::SessionOptions& config_
// on the other hand, not sure if is_primary_session_options is the right thing to check here.
if (const auto opt_it = opts.find("enable_htp_shared_memory_allocator");
opt_it != opts.end() && opt_it->second == "1") {
- device_type_ = DeviceType::QNN;
+ p_device_ = GetDeviceInterface(DeviceType::QNN);
}

session_options.AppendExecutionProvider("QNN", opts);
} else if (provider_options.name == "webgpu") {
- device_type_ = DeviceType::WEBGPU;
+ p_device_ = GetDeviceInterface(DeviceType::WEBGPU);
std::unordered_map<std::string, std::string> opts;
for (auto& option : provider_options.options) {
opts.emplace(option.first, option.second);
Expand All @@ -488,9 +482,9 @@ void Model::CreateSessionOptionsFromConfig(const Config::SessionOptions& config_
throw std::runtime_error("Unknown provider type: " + provider_options.name);
}

- if (!p_device_) {
-   p_device_ = GetDeviceInterface(device_type_);
- }
+ // Fallback to CPU if no provider specific interface was set
+ if (!p_device_)
+   p_device_ = GetDeviceInterface(DeviceType::CPU);
}

void Model::CreateSessionOptions() {
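After this hunk, the scoring device is chosen purely by assigning p_device_ per execution provider, with the new explicit CPU fallback at the end. A condensed, hypothetical view of that selection flow (the real CreateSessionOptionsFromConfig also builds session options and applies the provider-specific settings shown above):

#include <string>

// Hypothetical condensed helper; GetDeviceInterface is the factory declared in generators.h.
DeviceInterface* SelectScoringDevice(const std::string& provider_name, bool qnn_shared_memory) {
  DeviceInterface* p_device = nullptr;
  if (provider_name == "cuda")
    p_device = GetDeviceInterface(DeviceType::CUDA);
  else if (provider_name == "dml")
    p_device = GetDeviceInterface(DeviceType::DML);      // DML allocator for input/output caches
  else if (provider_name == "webgpu")
    p_device = GetDeviceInterface(DeviceType::WEBGPU);
  else if (provider_name == "qnn" && qnn_shared_memory)  // enable_htp_shared_memory_allocator == "1"
    p_device = GetDeviceInterface(DeviceType::QNN);

  // Fallback to CPU if no provider specific interface was set
  if (!p_device)
    p_device = GetDeviceInterface(DeviceType::CPU);
  return p_device;
}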
1 change: 0 additions & 1 deletion src/models/model.h
@@ -136,7 +136,6 @@ struct Model : std::enable_shared_from_this<Model>, LeakChecked<Model> {
std::unique_ptr<Config> config_;
std::unique_ptr<OrtSessionOptions> session_options_;

- DeviceType device_type_{DeviceType::CPU};
mutable DeviceInterface* p_device_{}; // The device we're running on (matches device_type_) used for things that work the same on all devices
mutable DeviceInterface* p_device_inputs_{}; // For some model inputs, the device might be the CPU device (all but KV cache currently)
mutable DeviceInterface* p_device_kvcache_{}; // The kvcache is always allocated in device memory (TODO: Remove in favor of just p_device_?)
4 changes: 2 additions & 2 deletions src/models/position_inputs.cpp
@@ -142,7 +142,7 @@ void DefaultPositionInputs::UpdatePositionIDs(int total_length, int new_kv_lengt
state_.inputs_[posid_input_index_] = position_ids_.get();
}

- if (model_.device_type_ == DeviceType::CUDA)
+ if (model_.p_device_inputs_->GetType() == DeviceType::CUDA)
model_.p_device_inputs_->UpdatePositionIds(position_ids_->GetTensorMutableRawData(), static_cast<int>(position_ids_shape_[0]), total_length, new_kv_length, type_);
else {
type_ == Ort::TypeToTensorType<int32_t> ? UpdatePositionIDsImpl<int32_t>(total_length, new_kv_length)
@@ -170,7 +170,7 @@ void DefaultPositionInputs::UpdateAttentionMask(int total_length, int new_kv_len
CreateNextAttentionMaskTensor(total_length);
state_.inputs_[mask_input_index_] = attention_mask_.get();

- if (model_.device_type_ == DeviceType::CUDA) {
+ if (model_.p_device_inputs_->GetType() == DeviceType::CUDA) {
int max_length = sb_attention_mask_ ? state_.params_->search.max_length : total_length;
bool update_only = sb_attention_mask_ && !is_first_mask_update_;
model_.p_device_inputs_->UpdateAttentionMask(attention_mask_next_->GetTensorMutableRawData(),
2 changes: 1 addition & 1 deletion src/models/whisper.cpp
@@ -182,7 +182,7 @@ DeviceSpan<float> Whisper_State::Run(int current_length, DeviceSpan<int32_t>& ne
auto src_data = init_presents_[i]->GetTensorRawData();
auto dest_data = presents_[i]->GetTensorMutableRawData();

- switch (model_.device_type_) {
+ switch (model_.p_device_inputs_->GetType()) {
#if 0 // USE_CUDA
case DeviceType::CUDA:
if (self_attn_kv_cache_element_type == ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16) {
2 changes: 1 addition & 1 deletion src/python/python.cpp
@@ -412,7 +412,7 @@ PYBIND11_MODULE(onnxruntime_genai, m) {
}))
.def_property_readonly("type", [](const Model& model) { return model.config_->model.type; })
.def_property_readonly(
"device_type", [](const Model& model) { return to_string(model.device_type_); }, "The device type the model is running on")
"device_type", [](const Model& model) { return to_string(model.p_device_->GetType()); }, "The device type the model is running on")
.def("create_multimodal_processor", [](const Model& model) { return model.CreateMultiModalProcessor(); });

pybind11::class_<PyGenerator>(m, "Generator")
2 changes: 2 additions & 0 deletions src/qnn/interface.cpp
@@ -46,6 +46,8 @@ struct InterfaceImpl : DeviceInterface {
InterfaceImpl() {
}

+ DeviceType GetType() const override { return DeviceType::QNN; }

void InitOrt(const OrtApi& /*api*/, Ort::Allocator& allocator) override {
assert(!ort_allocator_);
ort_allocator_ = &allocator;
11 changes: 11 additions & 0 deletions src/smartptrs.h
@@ -83,8 +83,19 @@ struct DeviceSpan {
friend struct DeviceSpan; // All DeviceSpans are friends
};

+ enum struct DeviceType {
+   CPU,
+   CUDA,
+   DML,
+   WEBGPU,
+   QNN,
+   MAX
+ };

struct DeviceInterface {
virtual ~DeviceInterface() {}

+ virtual DeviceType GetType() const = 0;
virtual void InitOrt(const OrtApi& api, Ort::Allocator& allocator) = 0;
virtual Ort::Allocator& GetAllocator() = 0;

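DeviceType now lives in smartptrs.h next to DeviceInterface, while the to_string(DeviceType) helper stays declared in generators.h. For readers following the error messages and the Python binding above, a plausible implementation of that helper (an assumption for illustration, not necessarily the repository's actual one) would be:

#include <string>

std::string to_string(DeviceType device_type) {
  switch (device_type) {
    case DeviceType::CPU: return "CPU";
    case DeviceType::CUDA: return "CUDA";
    case DeviceType::DML: return "DML";
    case DeviceType::WEBGPU: return "WEBGPU";
    case DeviceType::QNN: return "QNN";
    default: return "Unknown DeviceType";
  }
}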
2 changes: 2 additions & 0 deletions src/webgpu/interface.cpp
@@ -46,6 +46,8 @@ struct InterfaceImpl : DeviceInterface {
InterfaceImpl() {
}

+ DeviceType GetType() const override { return DeviceType::WEBGPU; }

void InitOrt(const OrtApi& /*api*/, Ort::Allocator& allocator) override {
assert(!ort_allocator_);
ort_allocator_ = &allocator;
1 change: 0 additions & 1 deletion test/sampling_benchmark.cpp
@@ -32,7 +32,6 @@ struct SamplingBenchmark {
params->search.max_length = 10;
params->search.batch_size = batch_size_;
params->p_device = Generators::GetDeviceInterface(device_type_);
- params->device_type = device_type_;

std::random_device rd;
std::mt19937 engine(rd());
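Test code follows suit: the benchmark no longer sets params->device_type, only the device interface. A trimmed sketch of the setup after this change (the fixture members and the rest of the benchmark are omitted; the helper below is illustrative):

#include <memory>

// Illustrative helper mirroring the benchmark setup above.
void ConfigureParams(const std::shared_ptr<Generators::GeneratorParams>& params,
                     Generators::DeviceType device_type, int batch_size) {
  params->search.max_length = 10;
  params->search.batch_size = batch_size;
  params->p_device = Generators::GetDeviceInterface(device_type);
  // No separate device_type field to assign anymore; callers read
  // params->p_device->GetType() when they need it.
}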