Merge pull request #55 from choderalab/sample_weight
not re-batching test set
yuanqing-wang authored Oct 30, 2020
2 parents 1c46403 + 8a5d9fd commit d80d280
Showing 1 changed file with 29 additions and 13 deletions.
42 changes: 29 additions & 13 deletions espaloma/app/experiment.py
@@ -183,27 +183,43 @@ def test(self):
         for metric in self.metrics:
             results[metric.__name__] = {}

+        # NOTE: we are not doing this here since this will lead to OOM
+        # from time to time
         # make it just one giant graph
-        g = list(self.data)
-        g = dgl.batch_hetero(g)
-        g = g.to(self.device)
+        # g = list(self.data)
+        # g = dgl.batch_hetero(g)
+        # g = g.to(self.device)

         for state_name, state in self.states.items():  # loop through states
             # load the state dict
             self.net.load_state_dict(state)

-            # local scope
-            with g.local_scope():

-                for metric in self.metrics:

-                    # loop through the metrics
-                    results[metric.__name__][state_name] = (
-                        metric(g_input=self.normalize.unnorm(self.net(g)))
-                        .detach()
-                        .cpu()
-                        .numpy()
-                    )
+            for metric in self.metrics:
+                assert isinstance(metric, esp.metrics.Metric)
+                input_fn, target_fn = metric.between
+
+                inputs = []
+                targets = []
+
+                for g in self.data:
+                    with g.local_scope():
+                        g = g.to(self.device)
+                        g_input = self.normalize.unnorm(self.net(g))
+                        inputs.append(input_fn(g_input))
+                        targets.append(target_fn(g_input))
+
+                inputs = torch.cat(inputs, dim=0)
+                targets = torch.cat(targets, dim=0)
+
+                # loop through the metrics
+                results[metric.__name__][state_name] = (
+                    metric.base_metric(inputs, targets)
+                    .detach()
+                    .cpu()
+                    .numpy()
+                )

         self.ref_g = self.normalize.unnorm(self.net(g))

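For reference, here is a minimal, self-contained sketch of the evaluation pattern introduced above, assuming hypothetical net, graphs, input_fn, target_fn, and base_metric stand-ins rather than espaloma's actual objects: instead of batching every test graph into one giant graph (which can run out of memory), each graph is pushed through the network on its own, the per-graph predictions and targets are accumulated, and the metric is computed once on the concatenated tensors.

import torch

def evaluate_metric(net, graphs, base_metric, input_fn, target_fn, device="cpu"):
    # hypothetical stand-ins: net is any model, graphs any iterable of graphs,
    # input_fn/target_fn pull prediction and reference tensors out of the output
    inputs, targets = [], []
    for g in graphs:
        g = g.to(device)                # move one graph at a time to the device
        out = net(g)                    # forward pass on this graph only
        inputs.append(input_fn(out))    # e.g. predicted energies
        targets.append(target_fn(out))  # e.g. reference energies
    # peak memory is bounded by the largest single graph, not the whole test set
    return base_metric(torch.cat(inputs, dim=0), torch.cat(targets, dim=0))

# usage sketch with a stock PyTorch loss as the base metric:
# rmse = torch.sqrt(evaluate_metric(net, test_graphs, torch.nn.functional.mse_loss, input_fn, target_fn))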
