diff --git a/README.md b/README.md index 472d1b9..ea32292 100644 --- a/README.md +++ b/README.md @@ -14,7 +14,8 @@ Pre-reqs: # go get -u github.com/notnil/chess go get -u github.com/Tilps/chess go get -u github.com/nightlyone/lockfile - +go get -u github.com/jaypipes/ghw +go get -u github.com/shettyh/threadpool ``` Pull or download the `master` branch @@ -56,4 +57,4 @@ Building the client for each platform: GOOS=windows GOARCH=amd64 go build -o lczero-client.exe GOOS=darwin GOARCH=amd64 go build -o lczero-client_mac GOOS=linux GOARCH=amd64 go build -o lczero-client_linux -``` \ No newline at end of file +``` diff --git a/appveyor.yml b/appveyor.yml index bd4fa2c..691eca1 100644 --- a/appveyor.yml +++ b/appveyor.yml @@ -14,6 +14,8 @@ environment: install: - go get -u github.com/Tilps/chess - go get -u github.com/nightlyone/lockfile + - go get -u github.com/shettyh/threadpool + - go get -u github.com/jaypipes/ghw build_script: - go build -o lc0-training-client%NAME% lc0_main.go artifacts: @@ -26,4 +28,3 @@ deploy: secure: USFAdwQKTXqOXQjCYQfzWvzRpUhvqJLBkN4hbOg+j876vDxGZHt9bMYayb5evePp on: appveyor_repo_tag: true - diff --git a/build.sh b/build.sh new file mode 100644 index 0000000..fb6ae28 --- /dev/null +++ b/build.sh @@ -0,0 +1,16 @@ +echo Downloading required dependencies... +echo -ne [0/4] Tilps/chess\\r +go get github.com/Tilps/chess +echo -ne [1/4] nightlyone/lockfile\\r +go get github.com/nightlyone/lockfile +echo -ne [2/4] jaypipes/ghw\\r +go get github.com/jaypipes/ghw +echo -ne [3/4] shettyh/threadpool\\r +go get github.com/shettyh/threadpool +echo -ne [DONE] installed dependencies\\n +echo building windows... +GOOS=windows GOARCH=amd64 go build +echo finished with errno: $? +echo building linux... +GOOS=linux GOARCH=arm go build +echo finished with errno: $? 
diff --git a/lc0_main.go b/lc0_main.go index 6e2ebaa..578c487 100644 --- a/lc0_main.go +++ b/lc0_main.go @@ -33,6 +33,8 @@ import ( "github.com/Tilps/chess" "github.com/nightlyone/lockfile" + "github.com/jaypipes/ghw" + "github.com/shettyh/threadpool" ) var ( @@ -57,6 +59,7 @@ var ( user = flag.String("user", "", "Username") password = flag.String("password", "", "Password") gpu = flag.Int("gpu", -1, "GPU to use (ignored if --backend-opts used)") + quiet = flag.Bool("quiet", false, "force quiet mode or force non quiet mode") // debug = flag.Bool("debug", false, "Enable debug mode to see verbose output and save logs") lc0Args = flag.String("lc0args", "", "") backopts = flag.String("backend-opts", "", @@ -81,6 +84,41 @@ type Settings struct { Localhost string } +// GameTask is one unit of work for the thread pool: a single nextGame call +// made with the shared HTTP client cli and the game counter ctr. +type GameTask struct { + cli *http.Client + ctr int +} + +// isFlagPassed reports whether the flag called name was set explicitly; +// flag.Visit only visits flags that have been set. +func isFlagPassed(name string) bool { + found := false + flag.Visit(func(f *flag.Flag) { + if f.Name == name { + found = true + } + }) + return found +} + +// Run plays the next game, retrying once after one second when nextGame +// returns the "retry" error. Only an error that is still present afterwards +// is logged and followed by the 30 second back-off. +func (t *GameTask) Run() { + err := nextGame(t.cli, t.ctr) + if err != nil && err.Error() == "retry" { + time.Sleep(1 * time.Second) + err = nextGame(t.cli, t.ctr) + } + if err != nil { + log.Print(err) + log.Print("Sleeping for 30 seconds...") + time.Sleep(30 * time.Second) + } +} + const inf = "inf" /* @@ -401,7 +439,9 @@ func (c *cmdWrapper) launch(networkPath string, otherNetPath string, args []stri c.Cmd.Args = append(c.Cmd.Args, "--no-share-trees") } - fmt.Printf("Args: %v\n", c.Cmd.Args) + if !*quiet { + fmt.Printf("Args: %v\n", c.Cmd.Args) + } stdout, err := c.Cmd.StdoutPipe() if err != nil { @@ -1002,7 +1042,9 @@ func nextGame(httpClient *http.Client, count int) error { if err != nil { return err } - log.Printf("serverParams: %s", serverParams) + if !*quiet { + log.Printf("serverParams: %s", serverParams) + } if nextGame.BookUrl != "" { book, err := getBook(&http.Client{}, nextGame.BookUrl, nextGame.BookSha) @@ -1134,6 +1176,18 @@ func maybeSetTrainOnly() { } } +// getGpuNumber returns the number of graphics cards reported by ghw, or 0 +// when the GPU information cannot be read. +func getGpuNumber() int { + gpu, err := ghw.GPU() + if err != nil { + fmt.Printf("Error getting GPU info: %v\n", err) + return 0 + } + + return len(gpu.GraphicsCards) +} + func main() { fmt.Printf("Lc0 client version %v\n", getExtraParams()["version"]) @@ -1238,19 +1292,32 @@ func main() { *localHost = defaultLocalHost } + gpunum := getGpuNumber() + fmt.Printf("Detected %v GPU(s)\n", gpunum) + if gpunum < 1 { + // No GPU detected (or detection failed): run a single worker so the + // pool has a worker and i%gpunum below cannot divide by zero. + gpunum = 1 + } + + if !isFlagPassed("quiet") { + *quiet = gpunum > 1 + } + + if *quiet { + fmt.Println("quiet_mode: on") + } else { + fmt.Println("quiet_mode: off") + } + httpClient := &http.Client{Timeout:300 * time.Second} startTime = time.Now() + pool := threadpool.NewThreadPool(gpunum, 100) for i := 0; ; i++ { - err := nextGame(httpClient, i) - if err != nil { - if err.Error() == "retry" { - time.Sleep(1 * time.Second) - continue - } - log.Print(err) - log.Print("Sleeping for 30 seconds...") - time.Sleep(30 * time.Second) - continue + task := &GameTask{httpClient, i} + pool.Execute(task) + if i%gpunum == 0 { + time.Sleep(time.Second * 10) } } } 
diff --git a/lc0_smain.go b/lc0_smain.go new file mode 100644 index 0000000..e9f96a3 --- /dev/null +++ b/lc0_smain.go @@ -0,0 +1,1322 @@ +// A new client to work with the lc0 binary. 
+// +// +package main + +import ( + "bufio" + "bytes" + "compress/gzip" + "crypto/rand" + "crypto/sha256" + "encoding/json" + "errors" + "flag" + "fmt" + "io" + "io/ioutil" + "log" + "net/http" + "net/url" + "os" + "os/exec" + "path" + "path/filepath" + "regexp" + "runtime" + "strconv" + "strings" + "sync" + "time" + + "client" + + "github.com/Tilps/chess" + "github.com/nightlyone/lockfile" + "github.com/jaypipes/ghw" + "github.com/shettyh/threadpool" +) + +var ( + startTime time.Time + totalGames int + pendingNextGame *client.NextGameResponse + randId int + hasCudnn bool + hasCuda bool + hasOpenCL bool + hasEigen bool + hasDx bool + parallelism32 bool + testedDxNet string + + lc0Exe = "lc0" + defaultLocalHost = "Unknown" + gpuType = "Unknown" + + localHost = flag.String("localhost", "", "Localhost name to send to the server when reporting\n(defaults to Unknown, overridden by the configuration file)") + hostname = flag.String("hostname", "http://api.lczero.org", "Address of the server") + user = flag.String("user", "", "Username") + password = flag.String("password", "", "Password") + gpu = flag.Int("gpu", -1, "GPU to use (ignored if --backend-opts used)") + quiet = flag.Bool("quiet", false, "force quiet mode or force non quiet mode") + // debug = flag.Bool("debug", false, "Enable debug mode to see verbose output and save logs") + lc0Args = flag.String("lc0args", "", "") + backopts = flag.String("backend-opts", "", + `Options for the lc0 mux. backend. 
Example: --backend-opts="cudnn(gpu=1)"`) + parallel = flag.Int("parallelism", -1, "Number of games to play in parallel (-1 for default)") + cacheDir = flag.String("cache", "", "Directory to use for downloaded files cache (if it exists)") + useTestServer = flag.Bool("use-test-server", false, "Set host name to test server.") + runId = flag.Uint("run", 0, "Which training run to contribute to (default 0 to let server decide)") + keep = flag.Bool("keep", false, "Do not delete old network files") + version = flag.Bool("version", false, "Print version and exit.") + trainOnly = flag.Bool("train-only", false, "Do not play match games") + report_host = flag.Bool("report-host", false, "Send hostname to server for more fine-grained statistics") + report_gpu = flag.Bool("report-gpu", false, "Send gpu info to server for more fine-grained statistics") + cudnn = flag.Bool("cudnn", true, "Prefer the cudnn backend (if available)") + settingsPath = flag.String("config", "", "JSON configuration file to use") +) + +// Settings holds the username, password and localhost name. +type Settings struct { + User string + Pass string + Localhost string +} + +type GameTask struct { + cli *http.Client + ctr int + gpuIteration int +} + +func isFlagPassed(name string) bool { + found := false + flag.Visit(func(f *flag.Flag) { + if f.Name == name { + found = true + } + }) + return found +} + +func (t *GameTask) Run() { + time.Sleep(time.Second * 4) + err := nextGame(t.cli, t.ctr, t.gpuIteration) + if err != nil && err.Error() == "retry" { + time.Sleep(1 * time.Second) + err = nextGame(t.cli, t.ctr, t.gpuIteration) + } + // Only back off when an error is still present after the single retry. + if err != nil { + log.Print(err) + log.Print("Sleeping for 30 seconds...") + time.Sleep(30 * time.Second) + } +} + +const inf = "inf" + +/* + Reads the user and password from a config file and returns empty strings if anything went wrong. 
+*/ +func readSettings(path string) (string, string, string) { + settings := Settings{} + file, err := os.Open(path) + if err != nil { + // File was not found + return "", "", "" + } + defer file.Close() + decoder := json.NewDecoder(file) + err = decoder.Decode(&settings) + if err != nil { + log.Fatal("Error decoding JSON ", err) + return "", "", "" + } + return settings.User, settings.Pass, settings.Localhost +} + +/* + Prompts the user for a username and password and creates the config file. +*/ +func createSettings(path string) (string, string) { + settings := Settings{} + + fmt.Printf("Please enter your username and password, an account will be automatically created.\n") + fmt.Printf("Note that this password will be stored in plain text, so avoid a password that is\n") + fmt.Printf("also used for sensitive applications. It also cannot be recovered.\n") + fmt.Printf("Enter username : ") + fmt.Scanf("%s\n", &settings.User) + fmt.Printf("Enter password : ") + fmt.Scanf("%s\n", &settings.Pass) + jsonSettings, err := json.Marshal(settings) + if err != nil { + log.Fatal("Cannot encode settings to JSON ", err) + return "", "" + } + settingsFile, err := os.Create(path) + defer settingsFile.Close() + if err != nil { + log.Fatal("Could not create output file ", err) + return "", "" + } + settingsFile.Write(jsonSettings) + return settings.User, settings.Pass +} + +func getExtraParams() map[string]string { + return map[string]string{ + "user": *user, + "password": *password, + "version": "33", + "token": strconv.Itoa(randId), + "train_only": strconv.FormatBool(*trainOnly), + "hostname": *localHost, + "gpu": gpuType, + "gpu_id": strconv.Itoa(*gpu), + } +} + +func uploadGame(httpClient *http.Client, path string, pgn string, + nextGame client.NextGameResponse, version string, fp_threshold float64) error { + + var retryCount uint32 + + for { + retryCount++ + if retryCount > 3 { + return errors.New("UploadGame failed: Too many retries") + } + + extraParams := getExtraParams() + 
extraParams["training_id"] = strconv.Itoa(int(nextGame.TrainingId)) + extraParams["network_id"] = strconv.Itoa(int(nextGame.NetworkId)) + extraParams["pgn"] = pgn + extraParams["engineVersion"] = version + if fp_threshold >= 0.0 { + extraParams["fp_threshold"] = strconv.FormatFloat(fp_threshold, 'E', -1, 64) + } + request, err := client.BuildUploadRequest(*hostname+"/upload_game", extraParams, "file", path) + if err != nil { + log.Printf("BUR: %v", err) + return err + } + resp, err := httpClient.Do(request) + if err != nil { + log.Printf("http.Do: %v", err) + return err + } + body := &bytes.Buffer{} + _, err = body.ReadFrom(resp.Body) + if err != nil { + log.Print(err) + log.Print("Error uploading, retrying...") + time.Sleep(time.Second * (2 << retryCount)) + continue + } + resp.Body.Close() + if resp.StatusCode != 200 && strings.Contains(body.String(), " upgrade ") { + log.Printf("The lc0 version you are using is not accepted by the server") + if strings.Contains(version, "dev") { + log.Printf("It is an unreleased development version") + } else if strings.Contains(version, "rc") { + log.Printf("It is a release candidate") + } + log.Printf("You probably need the latest release") + os.Exit(5) + } + break + } + + totalGames++ + var duration = time.Since(startTime) + var speed = int(float64(totalGames) / duration.Hours() * 24) + log.Printf("Completed %d games in %s time (%d games/day)", totalGames, duration, speed) + + err := os.Remove(path) + if err != nil { + log.Printf("Failed to remove training file: %v", err) + } + + return nil +} + +type gameInfo struct { + pgn string + fname string + // If >= 0, this is the value which if resign threshold was set + // higher a false positive would have occurred if the game had been + // played with resign. 
+ fp_threshold float64 + player1 string + result string +} + +type cmdWrapper struct { + Cmd *exec.Cmd + Pgn string + Input io.WriteCloser + BestMove chan string + gi chan gameInfo + Version string + Retry chan bool +} + +func (c *cmdWrapper) openInput() { + var err error + c.Input, err = c.Cmd.StdinPipe() + if err != nil { + log.Fatal(err) + } +} + +func convertMovesToPGN(moves []string, result string, start_ply_count int) string { + game := chess.NewGame(chess.UseNotation(chess.LongAlgebraicNotation{})) + if len(moves) > 6 && moves[len(moves)-7] == "from_fen" { + fen := strings.Join(moves[len(moves)-6:], " ") + moves = moves[:len(moves)-7] + pair := &chess.TagPair{ + Key: "FEN", + Value: fen, + } + tagPairs := []*chess.TagPair{pair} + fen_func, _ := chess.FEN(fen) + game = chess.NewGame(chess.UseNotation(chess.LongAlgebraicNotation{}), fen_func, chess.TagPairs(tagPairs)) + } + for _, m := range moves { + err := game.MoveStr(m) + if err != nil { + log.Fatalf("movstr: %v", err) + } + } + if game.Outcome() == chess.NoOutcome && len(game.EligibleDraws()) > 1 { + game.Draw(game.EligibleDraws()[1]) + } + game2 := chess.NewGame() + b, err := game.MarshalText() + if err != nil { + log.Fatalf("MarshalText failed: %v", err) + } + b_str := string(b) + if strings.HasSuffix(b_str, " *") && result != "" { + to_append := "1/2-1/2" + if result == "whitewon" { + to_append = "1-0" + } else if result == "blackwon" { + to_append = "0-1" + } + b = []byte(strings.TrimRight(b_str, "*") + to_append) + } + game2.UnmarshalText(b) + return game2.String() + " {OL: " + strconv.Itoa(start_ply_count) + "}" +} + +func createCmdWrapper() *cmdWrapper { + c := &cmdWrapper{ + gi: make(chan gameInfo), + BestMove: make(chan string), + Version: "v0.10.0", + Retry: make(chan bool), + } + return c +} + +func checkLc0() { + cmd := exec.Command(lc0Exe) + cmd.Args = append(cmd.Args, "--help") + out, err := cmd.CombinedOutput() + if err != nil { + log.Fatal(err) + } + if bytes.Contains(out, []byte("eigen")) 
{ + hasEigen = true + } + if bytes.Contains(out, []byte("dx12")) { + hasDx = true + } + if bytes.Contains(out, []byte("cuda-auto")) { + hasCuda = true + parallelism32 = true + } + if bytes.Contains(out, []byte("cudnn-auto")) && *cudnn { + hasCudnn = true + parallelism32 = true + } + if bytes.Contains(out, []byte("opencl")) { + hasOpenCL = true + } +} + +func checkDx(networkPath string) { + if !hasEigen { + log.Fatalf("Dx12 backend cannot be validated") + } + log.Println("Sanity checking the dx12 driver.") + cmd := exec.Command(lc0Exe) + sGpu := "" + if *gpu >= 0 { + sGpu = fmt.Sprintf(",gpu=%v", *gpu) + } + cmd.Args = append(cmd.Args, "benchmark", "-w", networkPath, "--backend=check") + cmd.Args = append(cmd.Args, fmt.Sprintf("--backend-opts=mode=check,freq=1.0,atol=5e-1,dx12%v", sGpu)) + // Add the startpos fen to get consistent behavior with old and new lc0 benchmark. + cmd.Args = append(cmd.Args, "--fen=rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq - 0 1") + out, err := cmd.CombinedOutput() + if err != nil { + log.Fatal(err) + } + if bytes.Contains(out, []byte("*** ERROR check failed")) { + log.Fatal("The dx12 backend failed the self check - try updating gpu drivers") + } + log.Println("The dx12 driver passed the initial sanity check.") +} + +func (c *cmdWrapper) launch(networkPath string, otherNetPath string, args []string, input bool, gpuIteration ... int) { + c.Cmd = exec.Command(lc0Exe) + // Add the "selfplay" or "uci" part first + mode := args[0] + c.Cmd.Args = append(c.Cmd.Args, mode) + args = args[1:] + + if mode != "selfplay" { + c.Cmd.Args = append(c.Cmd.Args, "--backend=multiplexing") + } + if *lc0Args != "" { + log.Println("WARNING: Option --lc0args is for testing, not production use!") + log.SetPrefix("TESTING: ") + parts := strings.Split(*lc0Args, " ") + c.Cmd.Args = append(c.Cmd.Args, parts...) 
+ } + parallelism := *parallel + sGpu := "" + if *gpu >= 0 { + sGpu = fmt.Sprintf(",gpu=%v", *gpu) + } else if len(gpuIteration) > 0 { // train() calls launch without gpuIteration; leave sGpu empty then + sGpu = fmt.Sprintf(",gpu=%v", gpuIteration[0]) + } + // Check the dx12 backend if it is the first time or we changed net, but only if no higher + // priority backend is available. + if !hasCuda && !hasCudnn && hasDx && testedDxNet != networkPath { + checkDx(networkPath) + testedDxNet = networkPath + } + if *backopts != "" { + // Check against small token blacklist, currently only "mlh", "random" and "recordreplay" + tokens := regexp.MustCompile("[,=().0-9]").Split(*backopts, -1) + for _, token := range tokens { + switch token { + case "mlh", "random", "recordreplay": + log.Fatalf("Not accepted in --backend-opts: %s", token) + } + } + c.Cmd.Args = append(c.Cmd.Args, fmt.Sprintf("--backend-opts=%s", *backopts)) + } else if hasCudnn { + c.Cmd.Args = append(c.Cmd.Args, fmt.Sprintf("--backend-opts=backend=cudnn-auto%v", sGpu)) + if parallelism <= 0 && parallelism32 { + parallelism = 32 + } + } else if hasCuda { + c.Cmd.Args = append(c.Cmd.Args, fmt.Sprintf("--backend-opts=backend=cuda-auto%v", sGpu)) + if parallelism <= 0 && parallelism32 { + parallelism = 32 + } + } else if hasDx { + c.Cmd.Args = append(c.Cmd.Args, fmt.Sprintf("--backend-opts=check(freq=1e-5,atol=5e-1,dx12%v)", sGpu)) + } else if hasOpenCL { + c.Cmd.Args = append(c.Cmd.Args, fmt.Sprintf("--backend-opts=backend=opencl%v", sGpu)) + } + if parallelism > 0 && mode == "selfplay" { + c.Cmd.Args = append(c.Cmd.Args, fmt.Sprintf("--parallelism=%v", parallelism)) + } + c.Cmd.Args = append(c.Cmd.Args, args...) 
+ if otherNetPath == "" { + c.Cmd.Args = append(c.Cmd.Args, fmt.Sprintf("--weights=%s", networkPath)) + } else { + c.Cmd.Args = append(c.Cmd.Args, fmt.Sprintf("--player1.weights=%s", networkPath)) + c.Cmd.Args = append(c.Cmd.Args, fmt.Sprintf("--player2.weights=%s", otherNetPath)) + c.Cmd.Args = append(c.Cmd.Args, "--no-share-trees") + } + + if !*quiet { + fmt.Printf("Args: %v\n", c.Cmd.Args) + } + + stdout, err := c.Cmd.StdoutPipe() + if err != nil { + log.Fatal(err) + } + + c.Cmd.Stderr = c.Cmd.Stdout + + // If the game wasn't played with resign, and the engine supports it, + // this will be populated by the resign_report before the gameready + // with the value which the resign threshold should be kept below to + // avoid a false positive. + last_fp_threshold := -1.0 + go func() { + defer close(c.BestMove) + defer close(c.gi) + stdoutScanner := bufio.NewScanner(stdout) + for stdoutScanner.Scan() { + line := stdoutScanner.Text() + // fmt.Printf("lc0: %s\n", line) + switch { + case strings.HasPrefix(line, "Unknown command line flag"): + fmt.Println(line) + log.Fatal("You probably have an old lc0 version") + case strings.Contains(line, "GPU: GeForce GTX 16"): + fallthrough // Does not contain "fp16" so the following works fine. 
+ case strings.Contains(line, "Switching to"): + fmt.Println(line) + if parallelism == 32 && parallelism32 && !strings.Contains(line, "fp16") { + parallelism32 = false + if mode == "selfplay" && *parallel <= 0 { + log.Println("Restarting with default parallelism") + c.Retry <- true + } + } + case strings.HasPrefix(line, "resign_report "): + args := strings.Split(line, " ") + fp_threshold_idx := -1 + for idx, arg := range args { + if arg == "fp_threshold" { + fp_threshold_idx = idx + 1 + } + } + if fp_threshold_idx >= 0 { + last_fp_threshold, err = strconv.ParseFloat(args[fp_threshold_idx], 64) + if err != nil { + log.Printf("Malformed resign_report: %q", line) + last_fp_threshold = -1.0 + } + } + fmt.Println(line) + case strings.HasPrefix(line, "gameready "): + // filename is between "trainingfile" and "gameid" + idx1 := strings.Index(line, "trainingfile") + idx2 := strings.LastIndex(line, "gameid") + idx3 := strings.LastIndex(line, "moves") + if idx1 < 0 || idx2 < 0 || idx3 < 0 { + log.Printf("Malformed gameready: %q", line) + break + } + idx4 := strings.LastIndex(line, "player1") + idx5 := strings.LastIndex(line, "result") + idx6 := strings.LastIndex(line, "play_start_ply") + result := "" + if idx5 < 0 { + idx5 = idx3 + } else { + result = line[idx5+7 : idx3-1] + } + player := "" + if idx4 >= 0 { + player = line[idx4+8 : idx5-1] + } + start_ply_count := -1 + if idx6 >= 0 { + start_ply_count, err = strconv.Atoi(line[idx6+15 : idx4-1]) + } + file := line[idx1+13 : idx2-1] + pgn := convertMovesToPGN(strings.Split(line[idx3+6:len(line)], " "), result, start_ply_count) + fmt.Printf("PGN: %s\n", pgn) + c.gi <- gameInfo{pgn: pgn, fname: file, fp_threshold: last_fp_threshold, player1: player, result: result} + last_fp_threshold = -1.0 + case strings.HasPrefix(line, "bestmove "): + // fmt.Println(line) + c.BestMove <- strings.Split(line, " ")[1] + case strings.HasPrefix(line, "id name Lc0 "): + c.Version = strings.Split(line, " ")[3] + fmt.Println(line) + case 
strings.HasPrefix(line, "info"): + break + case strings.HasPrefix(line, "GPU: "): + if *report_gpu && *backopts == "" { + gpuType = strings.TrimPrefix(line, "GPU: ") + } + fmt.Println(line) + case strings.HasPrefix(line, "Selected device: "): + if *report_gpu && *backopts == "" { + gpuType = strings.TrimPrefix(line, "Selected device: ") + } + fmt.Println(line) + case strings.HasPrefix(line, "BLAS"): + if *report_gpu && *backopts == "" { + gpuType = "None" + } + fmt.Println(line) + case strings.HasPrefix(line, "*** ERROR check failed"): + fmt.Println(line) + log.Fatal("The dx12 backend failed the self check - try updating gpu drivers") + default: + fmt.Println(line) + } + } + }() + + if input { + c.openInput() + } + + err = c.Cmd.Start() + if err != nil { + log.Fatal(err) + } +} + +func resultToNum(result string) int { + if result == "whitewon" { + return 1 + } + if result == "blackwon" { + return -1 + } + return 0 +} + +func playMatch(httpClient *http.Client, ngr client.NextGameResponse, baselinePath string, candidatePath string, params []string, gpuIteration int) (*client.NextGameResponse, error) { + // lc0 needs selfplay first in the argument list. + params = append([]string{"selfplay"}, params...) + // Training flag used for simplicity for now. + params = append(params, "--training=true") + + hasVisitsParam := false + for i := range params { + if strings.HasPrefix(params[i], "--visits=") || strings.HasPrefix(params[i], "--playouts=") { + hasVisitsParam = true + } + } + if !hasVisitsParam { + params = append(params, "--visits=800") + } + c := createCmdWrapper() + c.launch(candidatePath, baselinePath, params /* input= */, false, gpuIteration) + trainDirHolder := make([]string, 1) + trainDirHolder[0] = "" + defer func() { + // Remove the training dir when we're done training. 
+ trainDir := trainDirHolder[0] + if trainDir != "" { + log.Printf("Removing traindir: %s", trainDir) + err := os.RemoveAll(trainDir) + if err != nil { + log.Printf("Error removing train dir: %v", err) + } + } + }() + doneCh := make(chan bool) + gameInfoCh := make(chan gameInfo) + reverseDoneCh := make(chan bool) + wg := &sync.WaitGroup{} + wg.Add(1) + var pendingNextGame *client.NextGameResponse + go func() { + defer wg.Done() + defer close(doneCh) + errCount := 0 + curng := &ngr + var flipped []gameInfo + var normal []gameInfo + for done := false; !done; { + select { + case <-reverseDoneCh: + log.Println("Match uploader exiting") + return + case gi, _ := <-gameInfoCh: + if gi.player1 == "black" { + flipped = append(flipped, gi) + } else { + normal = append(normal, gi) + } + for true { + if curng != nil { + if curng.Flip && len(flipped) > 0 { + l := len(flipped) + nextgi := flipped[l-1] + flipped = flipped[:l-1] + log.Println("uploading match result") + extraParams := getExtraParams() + extraParams["engineVersion"] = c.Version + client.UploadMatchResult(httpClient, *hostname, curng.MatchGameId, -resultToNum(nextgi.result), nextgi.pgn, extraParams) + log.Println("uploaded") + curng = nil + } else if !curng.Flip && len(normal) > 0 { + l := len(normal) + nextgi := normal[l-1] + normal = normal[:l-1] + log.Println("uploading match result") + extraParams := getExtraParams() + extraParams["engineVersion"] = c.Version + client.UploadMatchResult(httpClient, *hostname, curng.MatchGameId, resultToNum(nextgi.result), nextgi.pgn, extraParams) + log.Println("uploaded") + curng = nil + } + } + if curng != nil { + break + } + ng, err := client.NextGame(httpClient, *hostname, getExtraParams()) + if err != nil { + fmt.Printf("Error talking to server: %v\n", err) + errCount++ + if errCount < 10 { + break + } + return + } + if ng.Type != ngr.Type || ng.Sha != ngr.Sha || ng.CandidateSha != ngr.CandidateSha { + log.Println("Current match finished.") + pendingNextGame = &ng + return + 
} + curng = &ng + errCount = 0 + } + } + } + }() + progressOrKill := false + for done := false; !done; { + select { + case <-c.Retry: + close(reverseDoneCh) + return nil, errors.New("retry") + case <-doneCh: + done = true + progressOrKill = true + log.Println("Received message to end matches, killing lc0") + c.Cmd.Process.Kill() + case _, ok := <-c.BestMove: + // Just swallow the best moves, not actually needed. + if !ok { + log.Printf("BestMove channel closed unexpectedly, exiting match loop") + break + } + case gi, ok := <-c.gi: + if !ok { + log.Printf("GameInfo channel closed, exiting match loop") + done = true + break + } + progressOrKill = true + trainDirHolder[0] = path.Dir(gi.fname) + wg.Add(1) + go func() { + select { + case <-doneCh: + case gameInfoCh <- gi: + } + wg.Done() + }() + } + } + + log.Println("Waiting for lc0 to stop") + err := c.Cmd.Wait() + if err != nil { + fmt.Printf("lc0 exited with: %v", err) + } + log.Println("lc0 stopped") + close(reverseDoneCh) + + log.Println("Waiting for uploads to complete") + wg.Wait() + if !progressOrKill { + return nil, errors.New("Client self-exited without producing any matches.") + } + return pendingNextGame, nil +} + +func train(httpClient *http.Client, ngr client.NextGameResponse, + networkPath string, otherNetPath string, count int, params []string, doneCh chan bool) error { + // lc0 needs selfplay first in the argument list. + params = append([]string{"selfplay"}, params...) + params = append(params, "--training=true") + c := createCmdWrapper() + c.launch(networkPath, otherNetPath, params /* input= */, false) + trainDirHolder := make([]string, 1) + trainDirHolder[0] = "" + defer func() { + // Remove the training dir when we're done training. 
+ trainDir := trainDirHolder[0] + if trainDir != "" { + log.Printf("Removing traindir: %s", trainDir) + err := os.RemoveAll(trainDir) + if err != nil { + log.Printf("Error removing train dir: %v", err) + } + } + }() + wg := &sync.WaitGroup{} + numGames := 1 + progressOrKill := false + for done := false; !done; { + select { + case <-c.Retry: + return errors.New("retry") + case <-doneCh: + done = true + progressOrKill = true + log.Println("Received message to end training, killing lc0") + c.Cmd.Process.Kill() + case _, ok := <-c.BestMove: + // Just swallow the best moves, only needed for match play. + if !ok { + log.Printf("BestMove channel closed unexpectedly, exiting train loop") + break + } + case gi, ok := <-c.gi: + if !ok { + log.Printf("GameInfo channel closed, exiting train loop") + done = true + break + } + fmt.Printf("Uploading game: %d\n", numGames) + numGames++ + progressOrKill = true + trainDirHolder[0] = path.Dir(gi.fname) + log.Printf("trainDir=%s", trainDirHolder[0]) + wg.Add(1) + go func() { + uploadGame(httpClient, gi.fname, gi.pgn, ngr, c.Version, gi.fp_threshold) + wg.Done() + }() + } + } + + log.Println("Waiting for lc0 to stop") + err := c.Cmd.Wait() + if err != nil { + fmt.Printf("lc0 exited with: %v", err) + } + log.Println("lc0 stopped") + + log.Println("Waiting for uploads to complete") + wg.Wait() + if !progressOrKill { + return errors.New("Client self-exited without producing any games.") + } + return nil +} + +func checkValidNetwork(dir string, sha string) (string, error) { + // Sha already exists? 
+ path := filepath.Join(dir, sha) + _, err := os.Stat(path) + if err == nil { + file, _ := os.Open(path) + reader, err := gzip.NewReader(file) + if err == nil { + var bytes []byte + bytes, err = ioutil.ReadAll(reader) + sum := sha256.Sum256(bytes) + got := fmt.Sprintf("%x", sum) + if sha != got { + text := fmt.Sprintf("sha mismatch want:\n%s\ngot\n%s\n", sha, got) + err = errors.New(text) + } + } + file.Close() + if err != nil { + fmt.Printf("Deleting invalid network...\n") + os.Remove(path) + return path, err + } else { + return path, nil + } + } + return path, err +} + +func removeAllExcept(dir string, sha string, keepTime string) error { + files, err := ioutil.ReadDir(dir) + if err != nil { + return err + } + for _, file := range files { + if file.Name() == sha { + continue + } + timeLimit, _ := time.ParseDuration(keepTime) + if time.Since(file.ModTime()) < timeLimit { + continue + } + fmt.Printf("Removing %v\n", file.Name()) + err := os.RemoveAll(filepath.Join(dir, file.Name())) + if err != nil { + return err + } + } + return nil +} + +func acquireLock(dir string, sha string) (lockfile.Lockfile, error) { + lockpath, _ := filepath.Abs(filepath.Join(dir, sha+".lck")) + lock, err := lockfile.New(lockpath) + if err != nil { + // Unknown error. Exit. 
+ log.Fatalf("Cannot init lockfile: %v", err) + } + // Attempt to acquire lock + err = lock.TryLock() + return lock, err +} + +func makeCacheDir(dir string) string { + userCache := *cacheDir + if len(userCache) == 0 { + if runtime.GOOS == "linux" { + userCache = os.Getenv("XDG_CACHE_HOME") + if len(userCache) == 0 { + homeDir := os.Getenv("HOME") + if len(homeDir) != 0 { + userCache = homeDir + "/.cache" + } + } + } else if runtime.GOOS == "darwin" { + homeDir := os.Getenv("HOME") + if len(homeDir) != 0 { + userCache = homeDir + "/Library/Caches" + } + } + } + if len(userCache) != 0 { + _, err := os.Stat(userCache) + if err == nil { + if len(*cacheDir) == 0 { + userCache = filepath.Join(userCache, "lc0") + } + dir = filepath.Join(userCache, dir) + } + } + os.MkdirAll(dir, os.ModePerm) + return dir +} + +func getNetwork(httpClient *http.Client, sha string, keepTime string) (string, error) { + dir := makeCacheDir("client-cache") + if keepTime != inf { + err := removeAllExcept(dir, sha, keepTime) + if err != nil { + log.Printf("Failed to remove old network(s): %v", err) + } + } + path, err := checkValidNetwork(dir, sha) + if err == nil { + // There is already a valid network. Use it. 
+ return path, nil + } + + // Otherwise, let's download it + lock, err := acquireLock(dir, sha) + + if err != nil { + if err == lockfile.ErrBusy { + log.Println("Download initiated by other client - waiting") + for i := 0; i < 60; i++ { + time.Sleep(time.Second) + path, err := checkValidNetwork(dir, sha) + if err == nil { + return path, nil + } + } + return "", errors.New("Timed out") + } else { + log.Fatalf("Unable to lock: %v", err) + } + } + + // Lockfile acquired, download it + defer lock.Unlock() + fmt.Println("Downloading network...") + for i := 0; i < 3; i++ { + if i > 0 { + log.Println("Waiting 10 seconds before retrying") + time.Sleep(10 * time.Second) + } + err = client.DownloadNetwork(httpClient, *hostname, path, sha) + if err == nil { + return checkValidNetwork(dir, sha) + } + log.Printf("Network download failed: %v", err) + } + return "", err +} + +func checkValidBook(path string, sha string) (string, error) { + // File already exists? + _, err := os.Stat(path) + if err == nil { + file, _ := os.Open(path) + sum := sha256.New() + _, err := io.Copy(sum, file) + got := fmt.Sprintf("%x", sum.Sum(nil)) + if sha != got { + text := fmt.Sprintf("book sha mismatch want:\n%s\ngot\n%s\n", sha, got) + err = errors.New(text) + } + file.Close() + if err != nil { + fmt.Printf("Deleting invalid book...\n") + os.Remove(path) + return path, err + } else { + return path, nil + } + } + return path, err +} + +func getBook(httpClient *http.Client, book_url string, sha string) (string, error) { + dir := makeCacheDir("books") + u, err := url.Parse(book_url) + if err != nil { + log.Println("Unable to parse book URL") + return "", err + } + s := strings.Split(u.Path, "/") + book_name := s[len(s)-1] + path := filepath.Join(dir, book_name) + _, err = checkValidBook(path, sha) + if err == nil { + // Book is there, use it. 
+ return path, nil + } + + // Otherwise, let's download it + lock, err := acquireLock(dir, book_name) + + if err != nil { + if err == lockfile.ErrBusy { + log.Println("Book download initiated by other client") + return "", err + } else { + log.Fatalf("Unable to lock: %v", err) + } + } + + // Lockfile acquired, download it + defer lock.Unlock() + fmt.Println("Downloading book...") + + r, err := httpClient.Get(book_url) + if err != nil { + log.Println("Book download failed") + return "", err + } + + out, err := ioutil.TempFile(dir, book_name+"_tmp") + if err != nil { + log.Println("Unable to create temporary file") + return "", err + } + + _, err = io.Copy(out, r.Body) + r.Body.Close() + out.Close() + if err == nil { + err = os.Rename(out.Name(), path) + } + // Ensure tmpfile is erased + os.Remove(out.Name()) + + return checkValidBook(path, sha) +} + +func nextGame(httpClient *http.Client, count int, gpuIteration int) error { + var nextGame client.NextGameResponse + var err error + fmt.Println("using gpu no %v", gpuIteration) + if pendingNextGame != nil { + nextGame = *pendingNextGame + pendingNextGame = nil + err = nil + } else { + nextGame, err = client.NextGame(httpClient, *hostname, getExtraParams()) + if err != nil { + return err + } + } + var serverParams []string + err = json.Unmarshal([]byte(nextGame.Params), &serverParams) + if err != nil { + return err + } + if !*quiet { + log.Printf("serverParams: %s", serverParams) + } + + if nextGame.BookUrl != "" { + book, err := getBook(&http.Client{}, nextGame.BookUrl, nextGame.BookSha) + if err != nil { + return err + } + // Replace the book file with the correct path + for i := range serverParams { + if strings.HasPrefix(serverParams[i], "--openings-pgn=") { + serverParams[i] = "--openings-pgn=" + book + break + } + } + } + + if nextGame.Type == "match" { + log.Println("Getting networks for match") + networkPath, err := getNetwork(httpClient, nextGame.Sha, inf) + if err != nil { + return err + } + candidatePath, err 
:= getNetwork(httpClient, nextGame.CandidateSha, inf) + if err != nil { + return err + } + log.Println("Starting match") + possibleNextGame, err := playMatch(httpClient, nextGame, networkPath, candidatePath, serverParams, gpuIteration) + if err != nil { + log.Printf("playMatch: %v", err) + return err + } + pendingNextGame = possibleNextGame + return nil + } + + if nextGame.Type == "train" { + keepTime := nextGame.KeepTime + if *keep { + keepTime = inf + } else if keepTime == "" { + // Four hours should be enough for clients serving 2 parallel runs in + // the same directory, even after one or two failed failed promotions. + keepTime = "4h" + } + networkPath, err := getNetwork(httpClient, nextGame.Sha, keepTime) + if err != nil { + return err + } + otherNetPath := "" + if nextGame.CandidateSha != "" { + otherNetPath, err = getNetwork(httpClient, nextGame.CandidateSha, inf) + if err != nil { + return err + } + } + doneCh := make(chan bool) + go func() { + defer close(doneCh) + errCount := 0 + for { + time.Sleep(60 * time.Second) + if nextGame.Type == "Done" { + return + } + ng, err := client.NextGame(httpClient, *hostname, getExtraParams()) + if err != nil { + fmt.Printf("Error talking to server: %v\n", err) + errCount++ + if errCount < 10 { + continue + } + return + } + if ng.Type != nextGame.Type || ng.Sha != nextGame.Sha { + // Prefetch the next net before terminating game. + if ng.Type == "match" { + getNetwork(httpClient, ng.CandidateSha, inf) + } else { + getNetwork(httpClient, ng.Sha, inf) + } + pendingNextGame = &ng + return + } + errCount = 0 + } + }() + err = train(httpClient, nextGame, networkPath, otherNetPath, count, serverParams, doneCh) + // Ensure the anonymous function stops retrying. + nextGame.Type = "Done" + if err != nil { + return err + } + return nil + } + + return errors.New("Unknown game type: " + nextGame.Type) +} + +// Ensure Tilps/chess is new enough. 
+func testChessVersion() { + if chess.GetLibraryVersion() < 3 { + log.Fatal("You need a more recent version of package github.com/Tilps/chess") + } +} + +func hideLc0argsFlag() { + shown := new(flag.FlagSet) + flag.VisitAll(func(f *flag.Flag) { + if f.Name != "lc0args" { + shown.Var(f.Value, f.Name, f.Usage) + } + }) + flag.Usage = func() { + fmt.Fprintf(os.Stderr, "Usage of %s:\n", os.Args[0]) + shown.PrintDefaults() + } +} + +func maybeSetTrainOnly() { + found := false + flag.Visit(func(f *flag.Flag) { + if f.Name == "train-only" { + found = true + } + }) + if !found && !hasCudnn && !hasCuda && !hasDx { + *trainOnly = true + log.Println("Will only run training games, use -train-only=false to override") + } +} + +func getGpuNumber() (int) { + gpu, err := ghw.GPU() + if err != nil { + fmt.Printf("Error getting GPU info: %v", err) + return 0 + } + + return len(gpu.GraphicsCards) +} + +func main() { + fmt.Printf("Lc0 client version %v\n", getExtraParams()["version"]) + + testChessVersion() + + hideLc0argsFlag() + flag.Parse() + + if *version { + return + } + + if runtime.GOOS == "windows" { + lc0Exe = "lc0.exe" + } + dir, _ := os.Getwd() + fi, err := os.Stat(path.Join(dir, lc0Exe)) + if err == nil && !fi.Mode().IsDir() { + lc0Exe = path.Join(dir, lc0Exe) + } + checkLc0() + + maybeSetTrainOnly() + + // 640 ought to be enough for anybody. 
+ if *runId > 640 { + log.Fatal("Training run number too large") + } + randBytes := make([]byte, 2) + _, err = rand.Reader.Read(randBytes) + if err != nil { + randId = -1 + } else { + randId = int(*runId)<<16 | int(randBytes[0])<<8 | int(randBytes[1]) + } + + if *useTestServer { + *hostname = "http://testserver.lczero.org" + } + + log.SetFlags(log.LstdFlags | log.Lshortfile) + + if len(*settingsPath) == 0 { + *settingsPath = "lc0-training-client-config.json" + configDir := "" + if runtime.GOOS == "linux" { + configDir = os.Getenv("XDG_CONFIG_HOME") + if len(configDir) == 0 { + homeDir := os.Getenv("HOME") + if len(homeDir) != 0 { + configDir = homeDir + "/.config" + } + } + } else if runtime.GOOS == "darwin" { + homeDir := os.Getenv("HOME") + if len(homeDir) != 0 { + configDir = homeDir + "/Library/Preferences" + } + } + + if len(configDir) != 0 { + configDir = filepath.Join(configDir, "lc0") + _, err = os.Stat(configDir) + if os.IsNotExist(err) { + err = os.Mkdir(configDir, os.ModePerm) + } + if err == nil { + *settingsPath = filepath.Join(configDir, *settingsPath) + } + } + } + + settingsUser, settingsPassword, settingsHost := readSettings(*settingsPath) + if len(*user) == 0 || len(*password) == 0 { + *user = settingsUser + *password = settingsPassword + + if len(*user) == 0 || len(*password) == 0 { + *user, *password = createSettings(*settingsPath) + } + } + + if len(settingsHost) != 0 && len(*localHost) == 0 { + *localHost = settingsHost + } + + if len(*user) == 0 { + log.Fatal("You must specify a username") + } + if len(*password) == 0 { + log.Fatal("You must specify a non-empty password") + } + + if *report_host && len(*localHost) == 0 { + s, err := os.Hostname() + if err == nil { + *localHost = s + } + } + + if len(*localHost) == 0 { + *localHost = defaultLocalHost + } + + var gpunum int + gpunum = getGpuNumber() + fmt.Printf("Detected %v GPU(s)\n", gpunum) + + if !isFlagPassed("quiet") { + *quiet = gpunum > 1 + } + + if *quiet { + fmt.Println("quiet_mode: 
on") + } else { + fmt.Println("quiet_mode: off") + } + + httpClient := &http.Client{Timeout:300 * time.Second} + startTime = time.Now() + gpuIteration := 0 + pool := threadpool.NewThreadPool(gpunum,100) + for i := 0; ; i++ { + task := &GameTask{httpClient, i, gpuIteration} + pool.Execute(task) + if i+1 % gpunum == 0 { + time.Sleep(time.Second * 10) + gpuIteration = 0 + } + gpuIteration++ + } + pool.Close() +}