Skip to content

Commit

Permalink
Refactor Network Usage
Browse files Browse the repository at this point in the history
This update improves how Pikafish handles network
usage, making it easier to manage and modify networks in the future.

With the introduction of a dedicated Network class, creating networks has become
straightforward. See uci.cpp:
```cpp
NN::Network({EvalFileDefaultName, "None", ""})
```

The new `Network` encapsulates all network-related logic, significantly reducing
the complexity previously required to support multiple network types.

No functional change
  • Loading branch information
PikaCat-OuO committed Mar 13, 2024
1 parent 87dea9f commit 495df31
Show file tree
Hide file tree
Showing 17 changed files with 593 additions and 492 deletions.
6 changes: 3 additions & 3 deletions src/Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -55,16 +55,16 @@ PGOBENCH = $(WINE_PATH) ./$(EXE) bench
SRCS = benchmark.cpp bitboard.cpp evaluate.cpp main.cpp \
misc.cpp movegen.cpp movepick.cpp position.cpp \
search.cpp thread.cpp timeman.cpp tt.cpp uci.cpp ucioption.cpp tune.cpp \
nnue/evaluate_nnue.cpp nnue/features/half_ka_v2_hm.cpp \
nnue/nnue_misc.cpp nnue/features/half_ka_v2_hm.cpp nnue/network.cpp \
external/zip.cpp

HEADERS = benchmark.h bitboard.h evaluate.h misc.h movegen.h movepick.h magics.h \
nnue/evaluate_nnue.h nnue/features/half_ka_v2_hm.h nnue/layers/affine_transform.h \
nnue/nnue_misc.h nnue/features/half_ka_v2_hm.h nnue/layers/affine_transform.h \
nnue/layers/affine_transform_sparse_input.h nnue/layers/clipped_relu.h nnue/layers/simd.h \
nnue/layers/sqr_clipped_relu.h nnue/nnue_accumulator.h nnue/nnue_architecture.h \
nnue/nnue_common.h nnue/nnue_feature_transformer.h position.h \
search.h thread.h thread_win32_osx.h timeman.h \
tt.h tune.h types.h uci.h ucioption.h perft.h \
tt.h tune.h types.h uci.h ucioption.h perft.h nnue/network.cpp \
external/zip.h external/miniz.h

OBJS = $(notdir $(SRCS:.cpp=.o))
Expand Down
100 changes: 9 additions & 91 deletions src/evaluate.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,99 +22,18 @@
#include <cassert>
#include <cmath>
#include <cstdlib>
#include <fstream>
#include <iomanip>
#include <iostream>
#include <optional>
#include <sstream>
#include <vector>

#include "misc.h"
#include "nnue/evaluate_nnue.h"
#include "nnue/network.h"
#include "nnue/nnue_misc.h"
#include "position.h"
#include "types.h"
#include "uci.h"
#include "ucioption.h"

namespace Stockfish {

namespace Eval {

std::string currentEvalFileName = "None";

// Tries to load a NNUE network at startup time, or when the engine
// receives a UCI command "setoption name EvalFile value .*.nnue"
// The name of the NNUE network is always retrieved from the EvalFile option.
// We search the given network in two locations: in the active working directory and
// in the engine directory.
EvalFile NNUE::load_networks(const std::string& rootDirectory,
const OptionsMap& options,
EvalFile evalFile) {

std::string user_eval_file = options[evalFile.optionName];

if (user_eval_file.empty())
user_eval_file = evalFile.defaultName;

std::vector<std::string> dirs = {"", rootDirectory};

for (const std::string& directory : dirs)
if (evalFile.current != user_eval_file)
{
std::stringstream ss = read_zipped_nnue(directory + user_eval_file);
auto description = NNUE::load_eval(ss);
if (!description.has_value())
{
std::ifstream stream(directory + user_eval_file, std::ios::binary);
description = NNUE::load_eval(stream);
}

if (description.has_value())
{
evalFile.current = user_eval_file;
evalFile.netDescription = description.value();
}
}

return evalFile;
}

// Verifies that the last net used was loaded successfully
void NNUE::verify(const OptionsMap& options, const EvalFile& evalFile) {

std::string user_eval_file = options[evalFile.optionName];

if (user_eval_file.empty())
user_eval_file = evalFile.defaultName;

if (evalFile.current != user_eval_file)
{

std::string msg1 =
"Network evaluation parameters compatible with the engine must be available.";
std::string msg2 = "The network file " + user_eval_file + " was not loaded successfully.";
std::string msg3 = "The UCI option EvalFile might need to specify the full path, "
"including the directory name, to the network file.";
std::string msg4 =
"The default net can be downloaded from: "
"https://github.com/official-pikafish/Networks/releases/download/master-net/"
+ evalFile.defaultName;
std::string msg5 = "The engine will be terminated now.";

sync_cout << "info string ERROR: " << msg1 << sync_endl;
sync_cout << "info string ERROR: " << msg2 << sync_endl;
sync_cout << "info string ERROR: " << msg3 << sync_endl;
sync_cout << "info string ERROR: " << msg4 << sync_endl;
sync_cout << "info string ERROR: " << msg5 << sync_endl;

exit(EXIT_FAILURE);
}

sync_cout << "info string NNUE evaluation using " << user_eval_file << " enabled" << sync_endl;
}
}


// Returns a static, purely materialistic evaluation of the position from
// the point of view of the given color. It can be divided by PawnValue to get
// an approximation of the material advantage on the board in terms of pawns.
Expand All @@ -127,7 +46,7 @@ int Eval::simple_eval(const Position& pos, Color c) {

// Evaluate is the evaluator for the outer world. It returns a static evaluation
// of the position from the point of view of the side to move.
Value Eval::evaluate(const Position& pos, int optimism) {
Value Eval::evaluate(const Eval::NNUE::Network& network, const Position& pos, int optimism) {

assert(!pos.checkers());

Expand All @@ -137,7 +56,7 @@ Value Eval::evaluate(const Position& pos, int optimism) {
int simpleEval = simple_eval(pos, stm);

int nnueComplexity;
Value nnue = NNUE::evaluate(pos, true, &nnueComplexity);
Value nnue = network.evaluate(pos, true, &nnueComplexity);

// Blend optimism and eval with nnue complexity and material imbalance
optimism += optimism * (nnueComplexity + std::abs(simpleEval - nnue)) / 781;
Expand All @@ -159,24 +78,23 @@ Value Eval::evaluate(const Position& pos, int optimism) {
// a string (suitable for outputting to stdout) that contains the detailed
// descriptions and values of each evaluation term. Useful for debugging.
// Trace scores are from white's point of view
std::string Eval::trace(Position& pos) {
std::string Eval::trace(Position& pos, const Eval::NNUE::Network& network) {

if (pos.checkers())
return "Final evaluation: none (in check)";

std::stringstream ss;
ss << std::showpoint << std::noshowpos << std::fixed << std::setprecision(2);

ss << '\n' << NNUE::trace(pos) << '\n';
ss << '\n' << NNUE::trace(pos, network) << '\n';

ss << std::showpoint << std::showpos << std::fixed << std::setprecision(2) << std::setw(15);

Value v;
v = NNUE::evaluate(pos);
v = pos.side_to_move() == WHITE ? v : -v;
Value v = network.evaluate(pos);
v = pos.side_to_move() == WHITE ? v : -v;
ss << "NNUE evaluation " << 0.01 * UCI::to_cp(v) << " (white side)\n";

v = evaluate(pos, VALUE_ZERO);
v = evaluate(network, pos, VALUE_ZERO);
v = pos.side_to_move() == WHITE ? v : -v;
ss << "Final evaluation " << 0.01 * UCI::to_cp(v) << " (white side)";
ss << " [with scaled NNUE, ...]";
Expand Down
28 changes: 7 additions & 21 deletions src/evaluate.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,37 +26,23 @@
namespace Stockfish {

class Position;
class OptionsMap;

namespace Eval {

std::string trace(Position& pos);

int simple_eval(const Position& pos, Color c);
Value evaluate(const Position& pos, int optimism);

// The default net name MUST follow the format nn-[SHA256 first 12 digits].nnue
// for the build process (profile-build and fishtest) to work. Do not change the
// name of the macro, as it is used in the Makefile.
// name of the macro or the location where this macro is defined, as it is used
// in the Makefile/Fishtest.
#define EvalFileDefaultName "pikafish.nnue"

struct EvalFile {
// UCI option name
std::string optionName;
// Default net name, will use one of the macros above
std::string defaultName;
// Selected net name, either via uci option or default
std::string current;
// Net description extracted from the net file
std::string netDescription;
};

namespace NNUE {
class Network;
}

EvalFile load_networks(const std::string&, const OptionsMap&, EvalFile);
void verify(const OptionsMap&, const EvalFile&);
std::string trace(Position& pos, const Eval::NNUE::Network& network);

} // namespace NNUE
int simple_eval(const Position& pos, Color c);
Value evaluate(const NNUE::Network& network, const Position& pos, int optimism);

} // namespace Eval

Expand Down
3 changes: 0 additions & 3 deletions src/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
#include <iostream>

#include "bitboard.h"
#include "evaluate.h"
#include "misc.h"
#include "position.h"
#include "tune.h"
Expand All @@ -36,8 +35,6 @@ int main(int argc, char* argv[]) {

Tune::init(uci.options);

uci.evalFile = Eval::NNUE::load_networks(uci.working_directory(), uci.options, uci.evalFile);

uci.loop();

return 0;
Expand Down
25 changes: 25 additions & 0 deletions src/misc.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
#include <cstddef>
#include <cstdint>
#include <iosfwd>
#include <memory>
#include <string>
#include <vector>

Expand All @@ -51,6 +52,30 @@ void aligned_large_pages_free(void* mem);

std::stringstream read_zipped_nnue(const std::string& fpath);

// Deleter for automating release of memory area
template<typename T>
struct AlignedDeleter {
void operator()(T* ptr) const {
ptr->~T();
std_aligned_free(ptr);
}
};

template<typename T>
struct LargePageDeleter {
void operator()(T* ptr) const {
ptr->~T();
aligned_large_pages_free(ptr);
}
};

template<typename T>
using AlignedPtr = std::unique_ptr<T, AlignedDeleter<T>>;

template<typename T>
using LargePagePtr = std::unique_ptr<T, LargePageDeleter<T>>;


void dbg_hit_on(bool cond, int slot = 0);
void dbg_mean_of(int64_t value, int slot = 0);
void dbg_stdev_of(int64_t value, int slot = 0);
Expand Down
80 changes: 0 additions & 80 deletions src/nnue/evaluate_nnue.h

This file was deleted.

Loading

0 comments on commit 495df31

Please sign in to comment.