-
Notifications
You must be signed in to change notification settings - Fork 2
/
golem.go
148 lines (116 loc) · 5.23 KB
/
golem.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
package main
import (
"encoding/json"
"fmt"
"os"
"github.com/rs/zerolog"
"github.com/rs/zerolog/log"
"golem/pkg"
"golem/pkg/model"
"github.com/spf13/cobra"
)
// TrainCommand builds the "train" subcommand. It trains a model on the data
// in --train-file (optionally evaluating on --test-file) and saves the result
// to --output-file, predicting the column named by --target-column.
//
// Training hyper-parameters are bound directly into a pkg.TrainingParameters
// value and architecture settings into a model.TabNetConfig value; cobra fills
// both in during flag parsing, before RunE executes. --train-file,
// --output-file and --target-column are required.
func TrainCommand() *cobra.Command {
	var trainFile string
	var testFile string
	var outputFile string
	var targetColumn string
	var trainingParameters pkg.TrainingParameters
	var modelParameters model.TabNetConfig
	var cmd = &cobra.Command{
		// The usage line lists every required flag (including -t) so that
		// `golem train --help` matches what MarkFlagRequired enforces below.
		Use:   "train -i trainFile -o outputFile -t targetColumn [--test-file testFile]",
		Short: "Trains a new model on the provided training data and saves the trained model",
		Args:  cobra.ArbitraryArgs,
		RunE: func(cmd *cobra.Command, args []string) error {
			// NOTE(review): any value returned by pkg.Train is discarded here.
			// If it returns an error, propagate it so a failed training run
			// yields a non-zero exit status — confirm its signature.
			pkg.Train(trainFile, testFile, outputFile, targetColumn, modelParameters, trainingParameters)
			return nil
		},
	}

	// Input / output locations.
	cmd.Flags().StringVarP(&trainFile, "train-file", "i", "", "name of train file")
	cmd.Flags().StringVarP(&testFile, "test-file", "", "", "name of test file")
	cmd.Flags().StringVarP(&outputFile, "output-file", "o", "", "name of the file to save model to.")

	// Training-loop parameters (pkg.TrainingParameters).
	cmd.Flags().IntVarP(&trainingParameters.BatchSize, "batch-size", "b", 16, "batch size")
	cmd.Flags().Float64VarP(&trainingParameters.LearningRate, "learning-rate", "l", 0.01, "learning rate")
	cmd.Flags().IntVarP(&trainingParameters.ReportInterval, "report-interval", "r", 10, "loss report interval")
	cmd.Flags().IntVarP(&trainingParameters.NumEpochs, "num-epochs", "n", 10, "number of epochs to train")
	cmd.Flags().Uint64VarP(&trainingParameters.RndSeed, "random-seed", "x", 42, "random seed")
	cmd.Flags().StringSliceVarP(&trainingParameters.CategoricalColumns, "categorical-columns", "", nil, "list of columns holding categorical data")
	cmd.Flags().Float64VarP(&trainingParameters.InputDropout, "input-dropout-probability", "", 0.0, "probability of input dropout")

	// Model architecture and loss weighting (model.TabNetConfig).
	cmd.Flags().IntVarP(&modelParameters.CategoricalEmbeddingDimension, "categorical-embedding-size", "c", 1, "size of categorical embeddings")
	cmd.Flags().IntVarP(&modelParameters.NumDecisionSteps, "num-decision-steps", "s", 2, "number of decision steps")
	cmd.Flags().IntVarP(&modelParameters.IntermediateFeatureDimension, "feature-dimension", "f", 4, "feature dimension")
	cmd.Flags().IntVarP(&modelParameters.OutputDimension, "output-dimension", "k", 4, "output dimension")
	cmd.Flags().Float64VarP(&modelParameters.RelaxationFactor, "relaxation-factor", "g", 1.5, "relaxation factor")
	cmd.Flags().Float64VarP(&modelParameters.BatchMomentum, "batch-momentum", "", 0.9, "batch momentum")
	cmd.Flags().Float64VarP(&modelParameters.SparsityLossWeight, "sparsity-loss-weight", "", 0.0001, "weight of the sparsity loss in total loss")
	cmd.Flags().Float64VarP(&modelParameters.ReconstructionLossWeight, "reconstruction-loss-weight", "", 0.0000, "weight of the reconstruction loss in total loss")
	cmd.Flags().Float64VarP(&modelParameters.TargetLossWeight, "target-loss-weight", "", 1.0000, "weight of the target loss in total loss")

	cmd.Flags().StringVarP(&targetColumn, "target-column", "t", "", "target column")

	// MarkFlagRequired only fails when the named flag does not exist; all three
	// are defined above, so the errors are safe to ignore.
	_ = cmd.MarkFlagRequired("train-file")
	_ = cmd.MarkFlagRequired("output-file")
	_ = cmd.MarkFlagRequired("target-column")
	return cmd
}
// TestCommand builds the "test" subcommand. It runs the model stored in
// --model on the data from --input, and hands the optional --output and
// --attentionMap destinations to pkg.Test, whose error (if any) is returned to
// cobra. Only --model is required; according to the --input help text, stdin
// is read when --input is omitted.
func TestCommand() *cobra.Command {
	var modelFile string
	var inputFile string
	var outputFile string
	var attentionMapFile string
	var cmd = &cobra.Command{
		// -i is optional (the data input, not a training file), so it is shown
		// in brackets alongside the other optional flags.
		Use:   "test -m modelFile [-i inputFile] [-o outputFile] [-a attentionMapFile]",
		Short: "Runs the provided model on the specified data input and optionally writes the results and attention map",
		Args:  cobra.ArbitraryArgs,
		RunE: func(cmd *cobra.Command, args []string) error {
			return pkg.Test(modelFile, inputFile, outputFile, attentionMapFile)
		},
	}
	cmd.Flags().StringVarP(&modelFile, "model", "m", "", "name of model to test")
	cmd.Flags().StringVarP(&inputFile, "input", "i", "", "name of data input file (optional, uses stdin if not present)")
	cmd.Flags().StringVarP(&outputFile, "output", "o", "", "name of output file (optional)")
	cmd.Flags().StringVarP(&attentionMapFile, "attentionMap", "a", "", "name of attention map output file (optional)")
	// MarkFlagRequired only fails when the named flag does not exist; "model"
	// is defined above, so the error is safe to ignore.
	_ = cmd.MarkFlagRequired("model")
	return cmd
}
// Values of the root command's persistent --log-level and --log-format flags.
// main binds them to the flags; setupLogging reads them before any subcommand
// runs.
var (
	logLevel  string
	logFormat string
)
// main assembles the golem root command, registers the persistent logging
// flags and the train/test subcommands, and runs the CLI.
func main() {
	root := &cobra.Command{Use: "golem", PersistentPreRun: setupLogging}
	root.PersistentFlags().StringVarP(&logLevel, "log-level", "", "info", "Logging level: info error or debug")
	root.PersistentFlags().StringVarP(&logFormat, "log-format", "", "pretty", "Logging format: pretty or json")
	root.AddCommand(TrainCommand())
	root.AddCommand(TestCommand())
	// Unless SilenceErrors is set, cobra's Execute has already printed the
	// error on stderr. Exit with a failure status rather than panicking, which
	// would append a goroutine stack trace to an ordinary user error such as a
	// missing required flag.
	if err := root.Execute(); err != nil {
		os.Exit(1)
	}
}
// setupLogging is the root command's PersistentPreRun hook. It applies the
// --log-level and --log-format flag values (held in logLevel and logFormat) to
// zerolog before any subcommand runs.
//
// Accepted levels are "error", "debug" and "info"; accepted formats are
// "pretty" and "json". Any other value is reported on stderr and the process
// exits with status 1. These are user input errors, so a panic and stack
// trace would be inappropriate.
func setupLogging(cmd *cobra.Command, args []string) {
	switch logLevel {
	case "error":
		zerolog.SetGlobalLevel(zerolog.ErrorLevel)
	case "debug":
		zerolog.SetGlobalLevel(zerolog.DebugLevel)
	case "info":
		zerolog.SetGlobalLevel(zerolog.InfoLevel)
	default:
		fmt.Fprintf(os.Stderr, "invalid log level %q: expected info, error or debug\n", logLevel)
		os.Exit(1)
	}
	switch logFormat {
	case "pretty":
		setupPrettyLogging()
	case "json":
		// Nothing to do: zerolog's default global logger, left untouched here,
		// already emits JSON.
	default:
		fmt.Fprintf(os.Stderr, "invalid log format %q: expected pretty or json\n", logFormat)
		os.Exit(1)
	}
}
// setupPrettyLogging replaces the global zerolog logger with a human-readable
// console writer on stderr. Numeric field values (json.Number) are printed
// with three decimal places; every other value is formatted with %s.
func setupPrettyLogging() {
	// %s (rather than %v) is kept for non-numeric values: besides strings, the
	// console writer may hand over already-marshalled []byte values, which %s
	// prints as text.
	formatValue := func(value interface{}) string {
		num, isNumber := value.(json.Number)
		if !isNumber {
			return fmt.Sprintf("%s", value)
		}
		// The parse error is deliberately ignored.
		asFloat, _ := num.Float64()
		return fmt.Sprintf("%.3f", asFloat)
	}

	console := zerolog.ConsoleWriter{Out: os.Stderr}
	console.FormatFieldValue = formatValue
	log.Logger = log.Output(console)
}