diff --git a/cmd/launch/main.go b/cmd/launch/main.go
new file mode 100644
index 00000000..5e936a04
--- /dev/null
+++ b/cmd/launch/main.go
@@ -0,0 +1,79 @@
+// Command launch starts the worker processes of one node of a distributed
+// training job. It runs trainingCmd nprocPerNode times, appending each
+// process's -rank, the total -size and the rendezvous flags.
+package main
+
+import (
+	"flag"
+	"fmt"
+	"log"
+	"os"
+	"os/exec"
+	"strings"
+)
+
+var numNodes = flag.Int("numNodes", 1, "The number of nodes for distributed training")
+var nodeRank = flag.Int("nodeRank", 0, "The rank of the node")
+var nprocPerNode = flag.Int("nprocPerNode", 1, "The number of processes on each node")
+var masterAddr = flag.String("masterAddr", "127.0.0.1", "The address of master node(rank 0)")
+var masterPort = flag.Int("masterPort", 11111, "The port of master node")
+var sharedFile = flag.String("sharedFile", "", "The shared file which could be access by all processes")
+var trainingCmd = flag.String("trainingCmd", "", "The training command")
+
+func main() {
+	flag.Parse()
+
+	// trainingCmd is split on whitespace only; shell quoting is not interpreted.
+	base := strings.Fields(*trainingCmd)
+	if len(base) == 0 {
+		log.Fatal("Must set value for trainingCmd")
+	}
+
+	// The rendezvous flags are the same for every local process.
+	// NOTE(review): masterAddr has a non-empty default, so sharedFile is only
+	// used when -masterAddr="" is passed explicitly.
+	var rendezvous []string
+	switch {
+	case *masterAddr != "":
+		rendezvous = []string{"-masterAddr=" + *masterAddr, fmt.Sprintf("-masterPort=%d", *masterPort)}
+	case *sharedFile != "":
+		rendezvous = []string{"-sharedFile=" + *sharedFile}
+	default:
+		log.Fatal("Must set value for masterAddr or sharedFile")
+	}
+
+	size := (*numNodes) * (*nprocPerNode)
+	var procs []*exec.Cmd
+	for i := 0; i < *nprocPerNode; i++ {
+		rank := (*nprocPerNode)*(*nodeRank) + i
+		args := append([]string{}, base[1:]...)
+		args = append(args, fmt.Sprintf("-rank=%d", rank), fmt.Sprintf("-size=%d", size))
+		args = append(args, rendezvous...)
+
+		proc := exec.Command(base[0], args...)
+		proc.Stdout = os.Stdout
+		proc.Stderr = os.Stderr
+		if err := proc.Start(); err != nil {
+			// Stop the workers already started rather than leaving them
+			// running without this rank.
+			for _, p := range procs {
+				_ = p.Process.Kill() // best effort; the launcher exits right after
+			}
+			log.Fatalf("starting rank %d: %v", rank, err)
+		}
+		procs = append(procs, proc)
+	}
+
+	// Wait for every worker so that failures are reported and the launcher's
+	// exit status reflects the whole job.
+	failed := false
+	for _, proc := range procs {
+		if err := proc.Wait(); err != nil {
+			log.Printf("%s: %v", strings.Join(proc.Args, " "), err)
+			failed = true
+		}
+	}
+	if failed {
+		os.Exit(1)
+	}
+}
diff --git a/example/allreduce/main.go b/example/allreduce/main.go
new file mode 100644
index 00000000..7f129ab9
--- /dev/null
+++ b/example/allreduce/main.go
@@ -0,0 +1,57 @@
+package main
+
+import (
+	"flag"
+
+	torch "github.com/wangkuiyi/gotorch"
+	F "github.com/wangkuiyi/gotorch/nn/functional"
+	"github.com/wangkuiyi/gotorch/vision/models"
+)
+
+var masterAddr = flag.String("masterAddr", "127.0.0.1", "The address of master node(rank 0)")
+var masterPort = flag.Int("masterPort", 11111, "The port of master node")
+var rank = flag.Int("rank", 0, "The rank of the current process")
+var size = flag.Int("size", 1, "The size of the processes")
+
+// getGrads collects p.Grad() for every parameter, preserving the order of params.
+func getGrads(params []torch.Tensor) (grads []torch.Tensor) {
+	for _, p := range params {
+		grads = append(grads, p.Grad())
+	}
+	return
+}
+
+// main joins the process group described by the flags and runs 10 SGD steps on
+// random data, calling pg.AllReduceCoalesced on the gradients before each step.
+func main() {
+	flag.Parse()
+
+	ts := torch.NewTCPStore(*masterAddr, int64(*masterPort), int64(*size), *rank == 0)
+	defer ts.Close()
+	pg := torch.NewProcessGroupGloo(ts, int64(*rank), int64(*size))
+	defer pg.Close()
+
+	net := models.MLP()
+	opt := torch.SGD(0.01, 0.5, 0, 0, false)
+	params := net.Parameters()
+	opt.AddParameters(params)
+
+	for _, p := range params {
+		pg.Broadcast([]torch.Tensor{p})
+	}
+
+	for i := 0; i < 10; i++ {
+		data := torch.Rand([]int64{16, 28, 28}, false)
+		label := torch.Ones([]int64{16}, false).CastTo(torch.Long)
+
+		opt.ZeroGrad()
+		pred := net.Forward(data)
+		loss := F.NllLoss(pred, label, torch.Tensor{}, -100, "mean")
+		loss.Backward()
+
+		grads := getGrads(params)
+		pg.AllReduceCoalesced(grads)
+
+		opt.Step()
+	}
+}
diff --git a/nn/module.go b/nn/module.go
index 5f21caf3..fd9e0031 100644
--- a/nn/module.go
+++ b/nn/module.go
@@ -4,6 +4,7 @@ import (
 	"fmt"
 	"log"
 	"reflect"
+	"sort"
 
 	torch "github.com/wangkuiyi/gotorch"
 )
@@ -163,12 +164,23 @@ func (m *Module) NamedBuffers() map[string]torch.Tensor {
 	return r
 }
 
+// sortKeys returns the keys of ts in ascending lexical order, so callers can
+// walk the map deterministically (Go map iteration order is unspecified).
+func sortKeys(ts map[string]torch.Tensor) []string {
+	keys := make([]string, 0, len(ts))
+	for k := range ts {
+		keys = append(keys, k)
+	}
+	sort.Strings(keys)
+	return keys
+}
+
 // Parameters returns trainable parameters (recursively)
 func (m *Module) Parameters() []torch.Tensor {
 	result := make([]torch.Tensor, 0)
 	n := m.NamedParameters()
-	for _, v := range n {
-		result = append(result, v)
+	for _, k := range sortKeys(n) {
+		result = append(result, n[k])
 	}
 	return result
 }
@@ -177,8 +189,8 @@ func (m *Module) Parameters() []torch.Tensor {
 func (m *Module) Buffers() []torch.Tensor {
 	result := make([]torch.Tensor, 0)
 	n := m.NamedBuffers()
-	for _, v := range n {
-		result = append(result, v)
+	for _, k := range sortKeys(n) {
+		result = append(result, n[k])
 	}
 	return result
 }