wandb_logging.py
import numpy as np
import tensorflow as tf
import wandb

# Logging functions borrowed from wandb.


def log_gradients(model, gradients):
    """Build a dict of per-weight gradient histograms for wandb.log."""
    metrics = {}
    weights = model.trainable_weights
    for weight, gradient in zip(weights, gradients):
        metrics[
            "gradients/" + weight.name.split(":")[0] + ".gradient"
        ] = wandb.Histogram(tf.convert_to_tensor(gradient))
    return metrics


def log_weights(model):
    """Build a dict of per-weight value histograms for wandb.log."""
    metrics = {}
    weights = model.trainable_weights
    for weight in weights:
        metrics[
            "parameters/" + weight.name.split(":")[0] + ".weights"
        ] = wandb.Histogram(tf.convert_to_tensor(weight))
    return metrics


def _array_has_dtype(array):
    return hasattr(array, "dtype")


def _update_if_numeric(metrics, key, values):
    # Only log values that expose a numeric dtype; warn and skip otherwise.
    if not _array_has_dtype(values):
        _warn_not_logging(key)
        return
    if not is_numeric_array(values):
        _warn_not_logging_non_numeric(key)
        return
    metrics[key] = wandb.Histogram(values)


def is_numeric_array(array):
    return np.issubdtype(array.dtype, np.number)


def _warn_not_logging_non_numeric(name):
    wandb.termwarn(
        "Non-numeric values found in layer: {}, not logging this layer".format(name),
        repeat=False,
    )


def _warn_not_logging(name):
    wandb.termwarn(
        "Layer {} has an undetermined datatype, not logging this layer".format(name),
        repeat=False,
    )


def log(epoch, model, metrics, gradients=None):
    """Log weight histograms, optional gradient histograms, and scalar metrics."""
    wandb.log(log_weights(model), commit=False)
    if gradients:
        wandb.log(log_gradients(model, gradients), commit=False)
    wandb.log({"epoch": epoch}, commit=False)
    wandb.log(metrics, commit=True)
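

if __name__ == "__main__":
    # Usage sketch (an addition, not part of the original helpers): a minimal
    # custom training loop showing how `log` might be called once per epoch.
    # The tiny model, synthetic data, and project name are assumptions made
    # for illustration; W&B runs in offline mode so no account is required.
    wandb.init(project="wandb-logging-demo", mode="offline")

    model = tf.keras.Sequential(
        [
            tf.keras.Input(shape=(4,)),
            tf.keras.layers.Dense(8, activation="relu"),
            tf.keras.layers.Dense(1),
        ]
    )
    optimizer = tf.keras.optimizers.Adam()
    loss_fn = tf.keras.losses.MeanSquaredError()

    x = np.random.rand(32, 4).astype("float32")
    y = np.random.rand(32, 1).astype("float32")

    for epoch in range(3):
        with tf.GradientTape() as tape:
            preds = model(x, training=True)
            loss = loss_fn(y, preds)
        grads = tape.gradient(loss, model.trainable_weights)
        optimizer.apply_gradients(zip(grads, model.trainable_weights))
        # Log weight/gradient histograms plus the scalar loss for this epoch.
        log(epoch, model, {"loss": float(loss)}, gradients=grads)

    wandb.finish()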