diff --git a/meson.build b/meson.build index 413f04cfe7..afc2562e78 100644 --- a/meson.build +++ b/meson.build @@ -204,6 +204,7 @@ files += [ 'src/neural/network_random.cc', 'src/neural/network_record.cc', 'src/neural/network_rr.cc', + 'src/neural/network_switch.cc', 'src/neural/network_trivial.cc', 'src/neural/onnx/adapters.cc', 'src/neural/onnx/builder.cc', diff --git a/src/neural/blas/network_blas.cc b/src/neural/blas/network_blas.cc index a212ce931e..b08f47d731 100644 --- a/src/neural/blas/network_blas.cc +++ b/src/neural/blas/network_blas.cc @@ -228,11 +228,7 @@ BlasComputation::BlasComputation( ffn_activation_(ffn_activation), attn_policy_(attn_policy), attn_body_(attn_body), - network_(network) { -#ifdef USE_DNNL - omp_set_num_threads(1); -#endif -} + network_(network) { } template using EigenMatrixMap = @@ -457,6 +453,9 @@ void BlasComputation::MakeEncoderLayer( template void BlasComputation::ComputeBlocking() { +#ifdef USE_DNNL + if (!use_eigen) omp_set_num_threads(1); +#endif const auto& value_head = weights_.value_heads.at("winner"); const auto& policy_head = weights_.policy_heads.at("vanilla"); // Retrieve network key dimensions from the weights structure. @@ -967,6 +966,7 @@ BlasNetwork::BlasNetwork(const WeightsFile& file, } if (use_eigen) { + Eigen::setNbThreads(1); CERR << "Using Eigen version " << EIGEN_WORLD_VERSION << "." << EIGEN_MAJOR_VERSION << "." << EIGEN_MINOR_VERSION; CERR << "Eigen max batch size is " << max_batch_size_ << "."; diff --git a/src/neural/network_switch.cc b/src/neural/network_switch.cc new file mode 100644 index 0000000000..5b016a3e08 --- /dev/null +++ b/src/neural/network_switch.cc @@ -0,0 +1,198 @@ +/* + This file is part of Leela Chess Zero. + Copyright (C) 2024 The LCZero Authors + + Leela Chess is free software: you can redistribute it and/or modify + it under the terms of the GNU General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + Leela Chess is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU General Public License for more details. + + You should have received a copy of the GNU General Public License + along with Leela Chess. If not, see . + + Additional permission under GNU GPL version 3 section 7 + + If you modify this Program, or any covered work, by linking or + combining it with NVIDIA Corporation's libraries from the NVIDIA CUDA + Toolkit and the NVIDIA CUDA Deep Neural Network library (or a + modified version of those libraries), containing parts covered by the + terms of the respective license agreement, the licensors of this + Program grant you additional permission to convey the resulting work. + */ + +#include + +#include "chess/bitboard.h" +#include "neural/factory.h" +#include "neural/network.h" +#include "utils/logging.h" + +namespace lczero { + +namespace { + +class SwitchNetwork; + +class SwitchComputation : public NetworkComputation { + public: + SwitchComputation(std::unique_ptr main_comp, + std::unique_ptr endgame_comp, + int threshold, int threads) + : main_comp_(std::move(main_comp)), + endgame_comp_(std::move(endgame_comp)), + threshold_(threshold), + threads_(threads) {} + + void AddInput(InputPlanes&& input) override { + auto pieces = input[1].mask | input[2].mask | input[3].mask | + input[4].mask | input[7].mask | input[8].mask | + input[9].mask | input[10].mask; + if (BitBoard(pieces).count() > threshold_) { + is_endgame_.push_back(false); + rev_idx_.push_back(main_cnt_); + main_cnt_++; + main_comp_->AddInput(std::move(input)); + } else { + is_endgame_.push_back(true); + rev_idx_.push_back(endgame_cnt_); + endgame_cnt_++; + endgame_comp_->AddInput(std::move(input)); + } + } + + void ComputeBlocking() override { + if (threads_ > 1 && main_cnt_ > 0 && endgame_cnt_ > 0) { + std::thread main( + [](NetworkComputation* comp) { comp->ComputeBlocking(); }, + main_comp_.get()); + endgame_comp_->ComputeBlocking(); + main.join(); + } else { + if (main_cnt_ > 0) main_comp_->ComputeBlocking(); + if (endgame_cnt_ > 0) endgame_comp_->ComputeBlocking(); + } + } + + int GetBatchSize() const override { return main_cnt_ + endgame_cnt_; } + + float GetQVal(int sample) const override { + if (is_endgame_[sample]) { + return endgame_comp_->GetQVal(rev_idx_[sample]); + } + return main_comp_->GetQVal(rev_idx_[sample]); + } + + float GetDVal(int sample) const override { + if (is_endgame_[sample]) { + return endgame_comp_->GetDVal(rev_idx_[sample]); + } + return main_comp_->GetDVal(rev_idx_[sample]); + } + + float GetMVal(int sample) const override { + if (is_endgame_[sample]) { + return endgame_comp_->GetMVal(rev_idx_[sample]); + } + return main_comp_->GetMVal(rev_idx_[sample]); + } + + float GetPVal(int sample, int move_id) const override { + if (is_endgame_[sample]) { + return endgame_comp_->GetPVal(rev_idx_[sample], move_id); + } + return main_comp_->GetPVal(rev_idx_[sample], move_id); + } + + private: + std::unique_ptr main_comp_; + std::unique_ptr endgame_comp_; + int main_cnt_ = 0; + int endgame_cnt_ = 0; + std::vector rev_idx_; + std::vector is_endgame_; + int threshold_; + int threads_; +}; + +class SwitchNetwork : public Network { + public: + SwitchNetwork(const std::optional& weights, + const OptionsDict& options) { + auto backends = NetworkFactory::Get()->GetBackendsList(); + + threshold_ = options.GetOrDefault("threshold", 6); + threads_ = options.GetOrDefault("threads", 1); + + auto& main_options = + options.HasSubdict("main") ? options.GetSubdict("main") : options; + + main_net_ = NetworkFactory::Get()->Create( + main_options.GetOrDefault("backend", backends[0]), weights, + main_options); + + std::optional endgame_weights; + if (!options.IsDefault("endgame_weights")) { + auto name = options.Get("endgame_weights"); + CERR << "Loading endgame weights file from: " << name; + endgame_weights = LoadWeightsFromFile(name); + } + + auto& endgame_options = + options.HasSubdict("endgame") ? options.GetSubdict("endgame") : options; + + endgame_net_ = NetworkFactory::Get()->Create( + endgame_options.GetOrDefault("backend", backends[0]), + endgame_weights ? endgame_weights : weights, endgame_options); + + capabilities_ = main_net_->GetCapabilities(); + capabilities_.Merge(endgame_net_->GetCapabilities()); + } + + std::unique_ptr NewComputation() override { + std::unique_ptr main_comp = main_net_->NewComputation(); + std::unique_ptr endgame_comp = + endgame_net_->NewComputation(); + return std::make_unique( + std::move(main_comp), std::move(endgame_comp), threshold_, threads_); + } + + const NetworkCapabilities& GetCapabilities() const override { + return capabilities_; + } + + int GetMiniBatchSize() const override { + return std::min(main_net_->GetMiniBatchSize(), + endgame_net_->GetMiniBatchSize()); + } + + bool IsCpu() const override { + return main_net_->GetMiniBatchSize() | endgame_net_->GetMiniBatchSize(); + } + + void InitThread(int id) override { + main_net_->InitThread(id); + endgame_net_->InitThread(id); + } + + private: + std::unique_ptr main_net_; + std::unique_ptr endgame_net_; + int threshold_; + int threads_; + NetworkCapabilities capabilities_; +}; + +std::unique_ptr MakeSwitchNetwork( + const std::optional& weights, const OptionsDict& options) { + return std::make_unique(weights, options); +} + +REGISTER_NETWORK("switch", MakeSwitchNetwork, -800) + +} // namespace +} // namespace lczero diff --git a/src/utils/optionsdict.cc b/src/utils/optionsdict.cc index b515914c36..624b33e3ea 100644 --- a/src/utils/optionsdict.cc +++ b/src/utils/optionsdict.cc @@ -124,6 +124,13 @@ class Lexer { static const std::string kNumberChars = "0123456789-."; if (kNumberChars.find(str_[idx_]) != std::string::npos) { ReadNumber(); + // If the next character isn't a separator, this is probably an identifier + // starting with a digit. + if (idx_ != str_.size() && !std::isspace(str_[idx_]) && + str_[idx_] != ',' && str_[idx_] != ')') { + idx_ = last_offset_; + ReadIdentifier(); + } return; }