-
Notifications
You must be signed in to change notification settings - Fork 528
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
backend to "switch" to a different net for the endgame #1968
base: master
Are you sure you want to change the base?
Changes from all commits
3d27104
4ee9fb0
313738a
079fad8
f5ca7f7
8ac4cd2
3bf3091
81dd539
76b28fb
ea20b5a
1404ff3
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,198 @@ | ||
/* | ||
This file is part of Leela Chess Zero. | ||
Copyright (C) 2024 The LCZero Authors | ||
|
||
Leela Chess is free software: you can redistribute it and/or modify | ||
it under the terms of the GNU General Public License as published by | ||
the Free Software Foundation, either version 3 of the License, or | ||
(at your option) any later version. | ||
|
||
Leela Chess is distributed in the hope that it will be useful, | ||
but WITHOUT ANY WARRANTY; without even the implied warranty of | ||
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the | ||
GNU General Public License for more details. | ||
|
||
You should have received a copy of the GNU General Public License | ||
along with Leela Chess. If not, see <http://www.gnu.org/licenses/>. | ||
|
||
Additional permission under GNU GPL version 3 section 7 | ||
|
||
If you modify this Program, or any covered work, by linking or | ||
combining it with NVIDIA Corporation's libraries from the NVIDIA CUDA | ||
Toolkit and the NVIDIA CUDA Deep Neural Network library (or a | ||
modified version of those libraries), containing parts covered by the | ||
terms of the respective license agreement, the licensors of this | ||
Program grant you additional permission to convey the resulting work. | ||
*/ | ||
|
||
#include <thread> | ||
|
||
#include "chess/bitboard.h" | ||
#include "neural/factory.h" | ||
#include "neural/network.h" | ||
#include "utils/logging.h" | ||
|
||
namespace lczero { | ||
|
||
namespace { | ||
|
||
class SwitchNetwork; | ||
|
||
class SwitchComputation : public NetworkComputation { | ||
public: | ||
SwitchComputation(std::unique_ptr<NetworkComputation> main_comp, | ||
std::unique_ptr<NetworkComputation> endgame_comp, | ||
int threshold, int threads) | ||
: main_comp_(std::move(main_comp)), | ||
endgame_comp_(std::move(endgame_comp)), | ||
threshold_(threshold), | ||
threads_(threads) {} | ||
|
||
// Routes one encoded position to either the main or the endgame computation.
//
// The masks of planes 1-4 and 7-10 are OR-ed together and their set bits
// counted. A position with more than `threshold_` such bits goes to the main
// computation; every other position goes to the endgame computation. The
// routing decision and the sample's index inside the chosen sub-batch are
// recorded so the Get*Val() accessors can map a global sample index back.
// NOTE(review): planes 1-4 / 7-10 are presumably the non-pawn, non-king piece
// planes of the two sides — confirm against the input encoder.
void AddInput(InputPlanes&& input) override {
  const auto piece_mask = input[1].mask | input[2].mask | input[3].mask |
                          input[4].mask | input[7].mask | input[8].mask |
                          input[9].mask | input[10].mask;
  const bool to_endgame = BitBoard(piece_mask).count() <= threshold_;
  is_endgame_.push_back(to_endgame);
  if (to_endgame) {
    rev_idx_.push_back(endgame_cnt_);
    ++endgame_cnt_;
    endgame_comp_->AddInput(std::move(input));
  } else {
    rev_idx_.push_back(main_cnt_);
    ++main_cnt_;
    main_comp_->AddInput(std::move(input));
  }
}
|
||
// Runs both sub-computations to completion.
//
// When more than one thread is configured and both sub-batches are non-empty,
// the main computation runs on a newly started std::thread while the endgame
// computation runs on the calling thread; the new thread is joined before
// returning. Otherwise each non-empty sub-batch is computed in turn on the
// calling thread. An empty sub-batch is never computed.
void ComputeBlocking() override { | ||
if (threads_ > 1 && main_cnt_ > 0 && endgame_cnt_ > 0) { | ||
std::thread main( | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Note that starting a thread is usually slow, especially on linux. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is to test the waters. If it shows potential we can consider a better performing threading approach. (Note the default it one thread and I didn't mention it in the description). |
||
[](NetworkComputation* comp) { comp->ComputeBlocking(); }, | ||
main_comp_.get()); | ||
// NOTE(review): if this call throws, `main` is destroyed while still
// joinable, which calls std::terminate — consider joining on every path.
endgame_comp_->ComputeBlocking(); | ||
main.join(); | ||
} else { | ||
// Sequential path: main sub-batch first, then the endgame sub-batch.
if (main_cnt_ > 0) main_comp_->ComputeBlocking(); | ||
if (endgame_cnt_ > 0) endgame_comp_->ComputeBlocking(); | ||
} | ||
} | ||
|
||
// Total number of inputs added so far, across both sub-computations.
int GetBatchSize() const override {
  const int total = main_cnt_ + endgame_cnt_;
  return total;
}
|
||
// Returns the Q value for global sample index `sample`, read from whichever
// sub-computation the sample was routed to in AddInput(). `rev_idx_[sample]`
// is the sample's index inside that sub-computation's own batch.
float GetQVal(int sample) const override { | ||
if (is_endgame_[sample]) { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Maybe helper function
? |
||
return endgame_comp_->GetQVal(rev_idx_[sample]); | ||
} | ||
return main_comp_->GetQVal(rev_idx_[sample]); | ||
} | ||
|
||
float GetDVal(int sample) const override { | ||
if (is_endgame_[sample]) { | ||
return endgame_comp_->GetDVal(rev_idx_[sample]); | ||
} | ||
return main_comp_->GetDVal(rev_idx_[sample]); | ||
} | ||
|
||
float GetMVal(int sample) const override { | ||
if (is_endgame_[sample]) { | ||
return endgame_comp_->GetMVal(rev_idx_[sample]); | ||
} | ||
return main_comp_->GetMVal(rev_idx_[sample]); | ||
} | ||
|
||
float GetPVal(int sample, int move_id) const override { | ||
if (is_endgame_[sample]) { | ||
return endgame_comp_->GetPVal(rev_idx_[sample], move_id); | ||
} | ||
return main_comp_->GetPVal(rev_idx_[sample], move_id); | ||
} | ||
|
||
private: | ||
std::unique_ptr<NetworkComputation> main_comp_; | ||
std::unique_ptr<NetworkComputation> endgame_comp_; | ||
int main_cnt_ = 0; | ||
int endgame_cnt_ = 0; | ||
std::vector<size_t> rev_idx_; | ||
std::vector<bool> is_endgame_; | ||
int threshold_; | ||
int threads_; | ||
}; | ||
|
||
class SwitchNetwork : public Network { | ||
public: | ||
// Builds the two wrapped networks from `options`.
//  - "threshold" (default 6) and "threads" (default 1) are read from the
//    top-level options and later handed to every SwitchComputation.
//  - The "main" / "endgame" subdicts, when present, supply the options for
//    the respective network; otherwise the top-level options are reused.
//  - Each network's "backend" option defaults to the first entry of the
//    factory's backend list.
//  - "endgame_weights", when set, names a separate weights file for the
//    endgame network; otherwise `weights` is used for both networks.
SwitchNetwork(const std::optional<WeightsFile>& weights, | ||
const OptionsDict& options) { | ||
// NOTE(review): backends[0] is used as the fallback backend name below;
// presumably the list is never empty and is ordered by priority — confirm.
auto backends = NetworkFactory::Get()->GetBackendsList(); | ||
|
||
threshold_ = options.GetOrDefault<int>("threshold", 6); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Without looking into the code, neither variable nor the parameter name's meaning "threshold" was clear. Maybe something with |
||
threads_ = options.GetOrDefault<int>("threads", 1); | ||
|
||
auto& main_options = | ||
options.HasSubdict("main") ? options.GetSubdict("main") : options; | ||
|
||
main_net_ = NetworkFactory::Get()->Create( | ||
main_options.GetOrDefault<std::string>("backend", backends[0]), weights, | ||
main_options); | ||
|
||
// Optional dedicated weights for the endgame network.
std::optional<WeightsFile> endgame_weights; | ||
if (!options.IsDefault<std::string>("endgame_weights")) { | ||
auto name = options.Get<std::string>("endgame_weights"); | ||
CERR << "Loading endgame weights file from: " << name; | ||
endgame_weights = LoadWeightsFromFile(name); | ||
} | ||
|
||
auto& endgame_options = | ||
options.HasSubdict("endgame") ? options.GetSubdict("endgame") : options; | ||
|
||
// NOTE(review): the conditional below yields a prvalue, so the selected
// optional<WeightsFile> is copied for this call.
endgame_net_ = NetworkFactory::Get()->Create( | ||
endgame_options.GetOrDefault<std::string>("backend", backends[0]), | ||
endgame_weights ? endgame_weights : weights, endgame_options); | ||
|
||
// Start from the main network's capabilities and merge in the endgame
// network's.
capabilities_ = main_net_->GetCapabilities(); | ||
capabilities_.Merge(endgame_net_->GetCapabilities()); | ||
} | ||
|
||
std::unique_ptr<NetworkComputation> NewComputation() override { | ||
std::unique_ptr<NetworkComputation> main_comp = main_net_->NewComputation(); | ||
std::unique_ptr<NetworkComputation> endgame_comp = | ||
endgame_net_->NewComputation(); | ||
return std::make_unique<SwitchComputation>( | ||
std::move(main_comp), std::move(endgame_comp), threshold_, threads_); | ||
} | ||
|
||
// Returns the capabilities computed once in the constructor by merging the
// main and endgame networks' capabilities.
const NetworkCapabilities& GetCapabilities() const override {
  return capabilities_;
}
|
||
int GetMiniBatchSize() const override { | ||
return std::min(main_net_->GetMiniBatchSize(), | ||
endgame_net_->GetMiniBatchSize()); | ||
} | ||
|
||
bool IsCpu() const override { | ||
return main_net_->GetMiniBatchSize() | endgame_net_->GetMiniBatchSize(); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. IsCpu() |
||
} | ||
|
||
void InitThread(int id) override { | ||
main_net_->InitThread(id); | ||
endgame_net_->InitThread(id); | ||
} | ||
|
||
private: | ||
std::unique_ptr<Network> main_net_; | ||
std::unique_ptr<Network> endgame_net_; | ||
int threshold_; | ||
int threads_; | ||
NetworkCapabilities capabilities_; | ||
}; | ||
|
||
std::unique_ptr<Network> MakeSwitchNetwork( | ||
const std::optional<WeightsFile>& weights, const OptionsDict& options) { | ||
return std::make_unique<SwitchNetwork>(weights, options); | ||
} | ||
|
||
// Registers MakeSwitchNetwork under the backend name "switch".
// NOTE(review): the third argument (-800) is presumably the backend's
// priority — confirm against the REGISTER_NETWORK macro.
REGISTER_NETWORK("switch", MakeSwitchNetwork, -800) | ||
|
||
} // namespace | ||
} // namespace lczero |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -124,6 +124,13 @@ class Lexer { | |
static const std::string kNumberChars = "0123456789-."; | ||
if (kNumberChars.find(str_[idx_]) != std::string::npos) { | ||
ReadNumber(); | ||
// If the next character isn't a separator, this is probably an identifier | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Maybe IsAlpha() would work better? Otherwise we'd need to remember to check for more characters here if we extend the syntax. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The problem I'm trying to address here is that strings starting with a number (e.g. a typical net name like 744204.pb.gz) are identified as a number and parsing fails at the first unexpected character - it is easier to blacklist a couple of characters than to whitelist dozens. |
||
// starting with a digit. | ||
if (idx_ != str_.size() && !std::isspace(str_[idx_]) && | ||
str_[idx_] != ',' && str_[idx_] != ')') { | ||
idx_ = last_offset_; | ||
ReadIdentifier(); | ||
} | ||
return; | ||
} | ||
|
||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
(totally optional)
I'd make it more generic, with a vector of (computation, threshold) pairs.