diff --git a/lc0_main.go b/lc0_main.go index 222e38b..91eb6fd 100644 --- a/lc0_main.go +++ b/lc0_main.go @@ -282,7 +282,7 @@ func checkLc0() { if bytes.Contains(out, []byte("blas")) { hasBlas = true } - if bytes.Contains(out, []byte("dx")) { + if bytes.Contains(out, []byte("dx12")) { hasDx = true } if bytes.Contains(out, []byte("cudnn-fp16")) { @@ -339,9 +339,9 @@ func (c *cmdWrapper) launch(networkPath string, otherNetPath string, args []stri c.Cmd.Args = append(c.Cmd.Args, fmt.Sprintf("--backend-opts=backend=cudnn%v", sGpu)) } else if hasDx { if !hasBlas { - log.Fatalf("Dx backend cannot be validated") + log.Fatalf("Dx12 backend cannot be validated") } - c.Cmd.Args = append(c.Cmd.Args, fmt.Sprintf("--backend-opts=check(freq=.01,atol=5e-1,dx%v)", sGpu)) + c.Cmd.Args = append(c.Cmd.Args, fmt.Sprintf("--backend-opts=check(freq=.01,atol=5e-1,dx12%v)", sGpu)) } else if hasOpenCL { c.Cmd.Args = append(c.Cmd.Args, fmt.Sprintf("--backend-opts=backend=opencl%v", sGpu)) } @@ -463,7 +463,7 @@ func (c *cmdWrapper) launch(networkPath string, otherNetPath string, args []stri fmt.Println(line) case strings.HasPrefix(line, "*** ERROR check failed"): fmt.Println(line) - log.Fatal("The dx backend failed the self check - try updating gpu drivers") + log.Fatal("The dx12 backend failed the self check - try updating gpu drivers") case strings.HasPrefix(line, "GPU compute capability:"): cc, _ := strconv.ParseFloat(strings.Split(line, " ")[3], 32) if cc >= 7.0 {