Skip to content

Commit

Permalink
fix codacy
Browse files Browse the repository at this point in the history
  • Loading branch information
Dario Coscia committed Jan 31, 2024
1 parent 59e4beb commit 97d0c0a
Showing 1 changed file with 10 additions and 5 deletions.
15 changes: 10 additions & 5 deletions pina/plotter.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@ def plot_samples(self, problem, variables=None, filename=None, **kwargs):
"""
Plot the training grid samples.
:param AbstractProblem problem: The PINA problem from where to plot the domain.
:param AbstractProblem problem: The PINA problem from where to plot
the domain.
:param list(str) variables: Variables to plot. If None, all variables
are plotted. If 'spatial', only spatial variables are plotted. If
'temporal', only temporal variables are plotted. Defaults to None.
Expand All @@ -39,7 +40,8 @@ def plot_samples(self, problem, variables=None, filename=None, **kwargs):
variables = problem.temporal_domain.variables

if len(variables) not in [1, 2, 3]:
raise ValueError('Samples can be plotted only in dimensions 1, 2 and 3.')
raise ValueError('Samples can be plotted only in '
'dimensions 1, 2 and 3.')

fig = plt.figure()
proj = '3d' if len(variables) == 3 else None
Expand Down Expand Up @@ -96,7 +98,8 @@ def _1d_plot(self, pts, pred, v, method, truth_solution=None, **kwargs):

if truth_solution:
truth_output = truth_solution(pts).detach()
ax.plot(pts.extract(v), truth_output, label='True solution', **kwargs)
ax.plot(pts.extract(v), truth_output,
label='True solution', **kwargs)

# TODO: pred is a torch.Tensor, so no labels is available
# extra variable for labels should be
Expand Down Expand Up @@ -189,7 +192,8 @@ def plot(self,

if len(components) > 1:
raise NotImplementedError('Multidimensional plots are not implemented, '
'set components to an available components of the problem.')
'set components to an available components of'
' the problem.')
v = [
var for var in solver.problem.input_variables
if var not in fixed_variables.keys()
Expand All @@ -205,7 +209,8 @@ def plot(self,
pts = pts.to(device=solver.device)

# computing soluting and sending to cpu
predicted_output = solver.forward(pts).extract(components).as_subclass(torch.Tensor).cpu().detach()
predicted_output = solver.forward(pts).extract(components)
predicted_output = predicted_output.as_subclass(torch.Tensor).cpu().detach()
pts = pts.cpu()
truth_solution = getattr(solver.problem, 'truth_solution', None)

Expand Down

0 comments on commit 97d0c0a

Please sign in to comment.