Skip to content

Commit

Permalink
Trainer train simplified, tests for load (#168)
Browse files Browse the repository at this point in the history
- the arguments of Trainer.train now are passed to the fit
- unittest for load/restoring from checkpoint
  • Loading branch information
ndem0 authored Jul 25, 2023
1 parent 11da773 commit 1b14b55
Show file tree
Hide file tree
Showing 3 changed files with 44 additions and 7 deletions.
6 changes: 3 additions & 3 deletions pina/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

class Trainer(pl.Trainer):

def __init__(self, solver, kwargs={}):
def __init__(self, solver, **kwargs):
super().__init__(**kwargs)

# get accellerator
Expand All @@ -29,6 +29,6 @@ def __init__(self, solver, kwargs={}):
self._loader = DummyLoader(solver.problem.input_pts, device)


def train(self): # TODO add kwargs and lightining capabilities
return super().fit(self._model, self._loader)
def train(self, **kwargs): # TODO add kwargs and lightining capabilities
return super().fit(self._model, self._loader, **kwargs)

2 changes: 1 addition & 1 deletion tests/test_solvers/test_garom.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ def test_train_cpu():
hidden_dimension=64)
)

trainer = Trainer(solver=solver, kwargs={'max_epochs' : 4, 'accelerator': 'cpu'})
trainer = Trainer(solver=solver, max_epochs=4, accelerator='cpu')
trainer.train()

def test_sample():
Expand Down
43 changes: 40 additions & 3 deletions tests/test_solvers/test_pinn.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,12 +56,12 @@ def poisson_sol(self, pts):

truth_solution = poisson_sol


class myFeature(torch.nn.Module):
"""
Feature: sin(x)
"""


def __init__(self):
super(myFeature, self).__init__()

Expand Down Expand Up @@ -92,8 +92,45 @@ def test_train_cpu():
n = 10
poisson_problem.discretise_domain(n, 'grid', locations=boundaries)
pinn = PINN(problem = poisson_problem, model=model, extra_features=None, loss=LpLoss())
trainer = Trainer(solver=pinn, kwargs={'max_epochs' : 5, 'accelerator':'cpu'})
trainer = Trainer(solver=pinn, max_epochs=5, accelerator='cpu')
trainer.train()

def test_train_restore():
tmpdir = "tests/tmp_restore"
poisson_problem = Poisson()
boundaries = ['gamma1', 'gamma2', 'gamma3', 'gamma4']
n = 10
poisson_problem.discretise_domain(n, 'grid', locations=boundaries)
pinn = PINN(problem = poisson_problem, model=model, extra_features=None, loss=LpLoss())
trainer = Trainer(solver=pinn, max_epochs=5, accelerator='cpu', default_root_dir=tmpdir)
trainer.train()
print('ggg')
ntrainer = Trainer(solver=pinn, max_epochs=15, accelerator='cpu')
t = ntrainer.train(
ckpt_path=f'{tmpdir}/lightning_logs/version_0/checkpoints/epoch=4-step=5.ckpt')
import shutil
shutil.rmtree(tmpdir)

def test_train_load():
tmpdir = "tests/tmp_load"
poisson_problem = Poisson()
boundaries = ['gamma1', 'gamma2', 'gamma3', 'gamma4']
n = 10
poisson_problem.discretise_domain(n, 'grid', locations=boundaries)
pinn = PINN(problem = poisson_problem, model=model, extra_features=None, loss=LpLoss())
trainer = Trainer(solver=pinn, max_epochs=15, accelerator='cpu',
default_root_dir=tmpdir)
trainer.train()
new_pinn = PINN.load_from_checkpoint(
f'{tmpdir}/lightning_logs/version_0/checkpoints/epoch=14-step=15.ckpt',
problem = poisson_problem, model=model)
test_pts = CartesianDomain({'x': [0, 1], 'y': [0, 1]}).sample(10)
assert new_pinn.forward(test_pts).extract(['u']).shape == (10, 1)
assert new_pinn.forward(test_pts).extract(['u']).shape == pinn.forward(test_pts).extract(['u']).shape
torch.testing.assert_close(new_pinn.forward(test_pts).extract(['u']), pinn.forward(test_pts).extract(['u']))
import shutil
shutil.rmtree(tmpdir)


# # TODO fix asap. Basically sampling few variables
# # works only if both variables are in a range.
Expand All @@ -118,7 +155,7 @@ def test_train_extra_feats_cpu():
n = 10
poisson_problem.discretise_domain(n, 'grid', locations=boundaries)
pinn = PINN(problem = poisson_problem, model=model_extra_feats, extra_features=extra_feats)
trainer = Trainer(solver=pinn, kwargs={'max_epochs' : 5, 'accelerator':'cpu'})
trainer = Trainer(solver=pinn, max_epochs=5, accelerator='cpu')
trainer.train()

# TODO, fix GitHub actions to run also on GPU
Expand Down

0 comments on commit 1b14b55

Please sign in to comment.