backend to "switch" to a different net for the endgame #1968

Open
wants to merge 11 commits into
base: master
from
1 change: 1 addition & 0 deletions meson.build
@@ -204,6 +204,7 @@ files += [
'src/neural/network_random.cc',
'src/neural/network_record.cc',
'src/neural/network_rr.cc',
'src/neural/network_switch.cc',
'src/neural/network_trivial.cc',
'src/neural/onnx/adapters.cc',
'src/neural/onnx/builder.cc',
10 changes: 5 additions & 5 deletions src/neural/blas/network_blas.cc
@@ -228,11 +228,7 @@ BlasComputation<use_eigen>::BlasComputation(
ffn_activation_(ffn_activation),
attn_policy_(attn_policy),
attn_body_(attn_body),
- network_(network) {
- #ifdef USE_DNNL
- omp_set_num_threads(1);
- #endif
- }
+ network_(network) { }

template <typename T>
using EigenMatrixMap =
@@ -457,6 +453,9 @@ void BlasComputation<use_eigen>::MakeEncoderLayer(

template <bool use_eigen>
void BlasComputation<use_eigen>::ComputeBlocking() {
#ifdef USE_DNNL
if (!use_eigen) omp_set_num_threads(1);
#endif
const auto& value_head = weights_.value_heads.at("winner");
const auto& policy_head = weights_.policy_heads.at("vanilla");
// Retrieve network key dimensions from the weights structure.
@@ -967,6 +966,7 @@ BlasNetwork<use_eigen>::BlasNetwork(const WeightsFile& file,
}

if (use_eigen) {
Eigen::setNbThreads(1);
CERR << "Using Eigen version " << EIGEN_WORLD_VERSION << "."
<< EIGEN_MAJOR_VERSION << "." << EIGEN_MINOR_VERSION;
CERR << "Eigen max batch size is " << max_batch_size_ << ".";
198 changes: 198 additions & 0 deletions src/neural/network_switch.cc
@@ -0,0 +1,198 @@
/*
This file is part of Leela Chess Zero.
Copyright (C) 2024 The LCZero Authors

Leela Chess is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.

Leela Chess is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU General Public License for more details.

You should have received a copy of the GNU General Public License
along with Leela Chess. If not, see <http://www.gnu.org/licenses/>.

Additional permission under GNU GPL version 3 section 7

If you modify this Program, or any covered work, by linking or
combining it with NVIDIA Corporation's libraries from the NVIDIA CUDA
Toolkit and the NVIDIA CUDA Deep Neural Network library (or a
modified version of those libraries), containing parts covered by the
terms of the respective license agreement, the licensors of this
Program grant you additional permission to convey the resulting work.
*/

#include <thread>

#include "chess/bitboard.h"
#include "neural/factory.h"
#include "neural/network.h"
#include "utils/logging.h"

namespace lczero {

namespace {

class SwitchNetwork;

class SwitchComputation : public NetworkComputation {
public:
SwitchComputation(std::unique_ptr<NetworkComputation> main_comp,
Member

(Totally optional.)
I'd make this more generic, with a vector of computation/threshold pairs rather than exactly two computations.
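
A minimal sketch of what the more generic design could look like, assuming an ordered list of stages that each carry an upper bound on the piece count. `Stage`, `Router`, `max_pieces`, `Add`, and `Lookup` are hypothetical names for illustration and are not lc0 types; a real version would hold a `NetworkComputation` per stage instead of a counter.

```cpp
#include <cassert>
#include <cstddef>
#include <utility>
#include <vector>

// Hypothetical stand-in for one sub-backend's batch; not an lc0 type.
struct Stage {
  int max_pieces;  // Stage takes positions with at most this many pieces.
  int count = 0;   // Number of samples routed to this stage so far.
};

// Routes each sample to the first stage whose bound it satisfies.
// Stages must be sorted by ascending max_pieces; the last stage is the
// catch-all and receives everything the earlier stages reject.
class Router {
 public:
  explicit Router(std::vector<Stage> stages) : stages_(std::move(stages)) {
    assert(!stages_.empty());
  }

  // Records the sample and returns the index of the stage that received it.
  std::size_t Add(int piece_count) {
    std::size_t s = 0;
    while (s + 1 < stages_.size() && piece_count > stages_[s].max_pieces) ++s;
    route_.push_back({s, stages_[s].count++});
    return s;
  }

  // Maps a global sample index to (stage index, index within that stage).
  std::pair<std::size_t, int> Lookup(std::size_t sample) const {
    return route_.at(sample);
  }

 private:
  std::vector<Stage> stages_;
  std::vector<std::pair<std::size_t, int>> route_;
};
```

With two stages, `{6}` followed by a catch-all, this reproduces the split in `AddInput`: counts of 6 or fewer go to the first stage and everything else to the second. The single `route_` vector plays the role of `is_endgame_` and `rev_idx_` together.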

std::unique_ptr<NetworkComputation> endgame_comp,
int threshold, int threads)
: main_comp_(std::move(main_comp)),
endgame_comp_(std::move(endgame_comp)),
threshold_(threshold),
threads_(threads) {}

void AddInput(InputPlanes&& input) override {
auto pieces = input[1].mask | input[2].mask | input[3].mask |
input[4].mask | input[7].mask | input[8].mask |
input[9].mask | input[10].mask;
if (BitBoard(pieces).count() > threshold_) {
is_endgame_.push_back(false);
rev_idx_.push_back(main_cnt_);
main_cnt_++;
main_comp_->AddInput(std::move(input));
} else {
is_endgame_.push_back(true);
rev_idx_.push_back(endgame_cnt_);
endgame_cnt_++;
endgame_comp_->AddInput(std::move(input));
}
}

void ComputeBlocking() override {
if (threads_ > 1 && main_cnt_ > 0 && endgame_cnt_ > 0) {
std::thread main(
Member

Note that starting a thread is usually slow, especially on Linux.
In other backends, e.g. network_mux.cc, we used condvar-based synchronization instead.

Member Author

This is to test the waters. If it shows potential, we can consider a better-performing threading approach. (Note that the default is one thread, and I didn't mention the option in the description.)
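
For reference, a minimal sketch of the condvar-based alternative discussed in this thread: one long-lived thread that waits on a condition variable for work, so no thread is started per `ComputeBlocking()` call. This is not the code from network_mux.cc; `PersistentWorker`, `Submit`, and `Wait` are hypothetical names, and the sketch assumes at most one task is in flight at a time.

```cpp
#include <cassert>
#include <condition_variable>
#include <functional>
#include <mutex>
#include <thread>
#include <utility>

// Hypothetical helper: a single reusable thread that runs submitted tasks.
class PersistentWorker {
 public:
  PersistentWorker() : thread_([this] { Loop(); }) {}

  ~PersistentWorker() {
    {
      std::lock_guard<std::mutex> lock(mutex_);
      stop_ = true;
    }
    cv_.notify_all();
    thread_.join();
  }

  // Hands a task to the worker. Only one task may be in flight at a time.
  void Submit(std::function<void()> task) {
    {
      std::lock_guard<std::mutex> lock(mutex_);
      assert(!busy_);
      task_ = std::move(task);
      busy_ = true;
    }
    cv_.notify_all();
  }

  // Blocks until the most recently submitted task has finished.
  void Wait() {
    std::unique_lock<std::mutex> lock(mutex_);
    cv_.wait(lock, [this] { return !busy_; });
  }

 private:
  void Loop() {
    std::unique_lock<std::mutex> lock(mutex_);
    for (;;) {
      cv_.wait(lock, [this] { return stop_ || task_; });
      if (!task_) return;  // Only reachable once stop_ has been set.
      std::function<void()> task = std::move(task_);
      task_ = nullptr;
      lock.unlock();
      task();  // Run the task without holding the lock.
      lock.lock();
      busy_ = false;
      cv_.notify_all();
    }
  }

  std::mutex mutex_;
  std::condition_variable cv_;
  std::function<void()> task_;
  bool busy_ = false;
  bool stop_ = false;
  std::thread thread_;  // Declared last so it starts after the other members.
};
```

Under these assumptions, `ComputeBlocking()` could `Submit` the main computation, run the endgame computation on the calling thread, and then `Wait`, which mirrors the current `std::thread` plus `join()` flow without the per-call thread start.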

[](NetworkComputation* comp) { comp->ComputeBlocking(); },
main_comp_.get());
endgame_comp_->ComputeBlocking();
main.join();
} else {
if (main_cnt_ > 0) main_comp_->ComputeBlocking();
if (endgame_cnt_ > 0) endgame_comp_->ComputeBlocking();
}
}

int GetBatchSize() const override { return main_cnt_ + endgame_cnt_; }

float GetQVal(int sample) const override {
if (is_endgame_[sample]) {
Member

Maybe a helper function?

NetworkComputation* GetComp(int sample) const { return is_endgame_[sample] ? endgame_comp_.get() : main_comp_.get(); }

return endgame_comp_->GetQVal(rev_idx_[sample]);
}
return main_comp_->GetQVal(rev_idx_[sample]);
}

float GetDVal(int sample) const override {
if (is_endgame_[sample]) {
return endgame_comp_->GetDVal(rev_idx_[sample]);
}
return main_comp_->GetDVal(rev_idx_[sample]);
}

float GetMVal(int sample) const override {
if (is_endgame_[sample]) {
return endgame_comp_->GetMVal(rev_idx_[sample]);
}
return main_comp_->GetMVal(rev_idx_[sample]);
}

float GetPVal(int sample, int move_id) const override {
if (is_endgame_[sample]) {
return endgame_comp_->GetPVal(rev_idx_[sample], move_id);
}
return main_comp_->GetPVal(rev_idx_[sample], move_id);
}

private:
std::unique_ptr<NetworkComputation> main_comp_;
std::unique_ptr<NetworkComputation> endgame_comp_;
int main_cnt_ = 0;
int endgame_cnt_ = 0;
std::vector<size_t> rev_idx_;
std::vector<bool> is_endgame_;
int threshold_;
int threads_;
};

class SwitchNetwork : public Network {
public:
SwitchNetwork(const std::optional<WeightsFile>& weights,
const OptionsDict& options) {
auto backends = NetworkFactory::Get()->GetBackendsList();

threshold_ = options.GetOrDefault<int>("threshold", 6);
Member

Without looking into the code, the meaning of "threshold" was not clear to me, either as a variable name or as a parameter name. Maybe something with pieces in the name? pieces_threshold? piece_count? :-\ Not sure what would be better, though.

threads_ = options.GetOrDefault<int>("threads", 1);

auto& main_options =
options.HasSubdict("main") ? options.GetSubdict("main") : options;

main_net_ = NetworkFactory::Get()->Create(
main_options.GetOrDefault<std::string>("backend", backends[0]), weights,
main_options);

std::optional<WeightsFile> endgame_weights;
if (!options.IsDefault<std::string>("endgame_weights")) {
auto name = options.Get<std::string>("endgame_weights");
CERR << "Loading endgame weights file from: " << name;
endgame_weights = LoadWeightsFromFile(name);
}

auto& endgame_options =
options.HasSubdict("endgame") ? options.GetSubdict("endgame") : options;

endgame_net_ = NetworkFactory::Get()->Create(
endgame_options.GetOrDefault<std::string>("backend", backends[0]),
endgame_weights ? endgame_weights : weights, endgame_options);

capabilities_ = main_net_->GetCapabilities();
capabilities_.Merge(endgame_net_->GetCapabilities());
}

std::unique_ptr<NetworkComputation> NewComputation() override {
std::unique_ptr<NetworkComputation> main_comp = main_net_->NewComputation();
std::unique_ptr<NetworkComputation> endgame_comp =
endgame_net_->NewComputation();
return std::make_unique<SwitchComputation>(
std::move(main_comp), std::move(endgame_comp), threshold_, threads_);
}

const NetworkCapabilities& GetCapabilities() const override {
return capabilities_;
}

int GetMiniBatchSize() const override {
return std::min(main_net_->GetMiniBatchSize(),
endgame_net_->GetMiniBatchSize());
}

bool IsCpu() const override {
return main_net_->GetMiniBatchSize() | endgame_net_->GetMiniBatchSize();
Member

This should call IsCpu() on both networks, not GetMiniBatchSize().
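
A minimal sketch of the change this comment points at, assuming the combined backend should report CPU when either wrapped network does (the "either" reading is taken from the `|` in the diff and is an assumption). `FakeNet` and `CombinedIsCpu` are hypothetical stand-ins so the snippet compiles on its own; in the PR the calls would go through `main_net_` and `endgame_net_`.

```cpp
#include <cassert>

// Hypothetical stand-in for the part of the Network interface used here.
struct FakeNet {
  bool is_cpu;
  bool IsCpu() const { return is_cpu; }
};

// Reports CPU if either wrapped network runs on the CPU. Logical OR is used
// because both operands are already bool.
inline bool CombinedIsCpu(const FakeNet& main_net, const FakeNet& endgame_net) {
  return main_net.IsCpu() || endgame_net.IsCpu();
}
```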

}

void InitThread(int id) override {
main_net_->InitThread(id);
endgame_net_->InitThread(id);
}

private:
std::unique_ptr<Network> main_net_;
std::unique_ptr<Network> endgame_net_;
int threshold_;
int threads_;
NetworkCapabilities capabilities_;
};

std::unique_ptr<Network> MakeSwitchNetwork(
const std::optional<WeightsFile>& weights, const OptionsDict& options) {
return std::make_unique<SwitchNetwork>(weights, options);
}

REGISTER_NETWORK("switch", MakeSwitchNetwork, -800)

} // namespace
} // namespace lczero
7 changes: 7 additions & 0 deletions src/utils/optionsdict.cc
@@ -124,6 +124,13 @@ class Lexer {
static const std::string kNumberChars = "0123456789-.";
if (kNumberChars.find(str_[idx_]) != std::string::npos) {
ReadNumber();
// If the next character isn't a separator, this is probably an identifier
Member

Maybe IsAlpha() would work better? Otherwise we'd need to remember to check for more characters here if we extend the syntax.

Member Author

The problem I'm trying to address here is that strings starting with a digit (e.g. a typical net name like 744204.pb.gz) are identified as a number, and parsing then fails at the first unexpected character. It is easier to blacklist a couple of characters than to whitelist dozens.
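
A standalone sketch of the fallback described above: scan number characters first, and if the next character is not a separator, rewind and treat the whole token as an identifier. `ReadNumberOrIdentifier`, `Token`, and `TokenKind` are hypothetical names, and the identifier rule here (read up to the next separator) is a simplification, not the actual `ReadIdentifier()` in optionsdict.cc.

```cpp
#include <cassert>
#include <cctype>
#include <cstddef>
#include <string>

enum class TokenKind { kNumber, kIdentifier };

struct Token {
  TokenKind kind;
  std::string text;
};

// Reads one token from `s` starting at `pos`. A run of number characters
// counts as a number only if it is followed by end of input or a separator
// (whitespace, ',' or ')'), matching the check in the diff.
inline Token ReadNumberOrIdentifier(const std::string& s, std::size_t pos = 0) {
  static const std::string kNumberChars = "0123456789-.";
  auto is_separator = [](char c) {
    return std::isspace(static_cast<unsigned char>(c)) || c == ',' || c == ')';
  };
  std::size_t end = pos;
  while (end < s.size() && kNumberChars.find(s[end]) != std::string::npos) {
    ++end;
  }
  if (end == s.size() || is_separator(s[end])) {
    return {TokenKind::kNumber, s.substr(pos, end - pos)};
  }
  // Not a clean number: rewind to the token start and read an identifier.
  end = pos;
  while (end < s.size() && !is_separator(s[end])) ++end;
  return {TokenKind::kIdentifier, s.substr(pos, end - pos)};
}
```

With this rule, `744204.pb.gz` is read as one identifier, while `6` in `6,threads=2` is still read as a number.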

// starting with a digit.
if (idx_ != str_.size() && !std::isspace(str_[idx_]) &&
str_[idx_] != ',' && str_[idx_] != ')') {
idx_ = last_offset_;
ReadIdentifier();
}
return;
}
