Skip to content

Commit

Permalink
Remove outlet
Browse files Browse the repository at this point in the history
  • Loading branch information
s9latimm committed Oct 19, 2024
1 parent e242438 commit 251cb81
Showing 1 changed file with 4 additions and 12 deletions.
16 changes: 4 additions & 12 deletions src/nse/controller/simulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,22 +20,18 @@ def __init__(self, experiment: Experiment, device: str, steps: int, layers: list
knowledge = self.__experiment.knowledge.detach()
learning = self.__experiment.learning.detach()
inlet = self.__experiment.inlet.detach()
outlet = self.__experiment.outlet.detach()

self.__u = torch.tensor([[i.u] for _, i in inlet + knowledge], dtype=torch.float64, device=self._device)
self.__v = torch.tensor([[i.v] for _, i in inlet + knowledge], dtype=torch.float64, device=self._device)

self.__null = torch.zeros(len(learning), 1, dtype=torch.float64, device=self._device)

self.__out = len(outlet)
self.__u_out = torch.zeros(len(outlet), 1, dtype=torch.float64, device=self._device)

self.__knowledge = (
torch.tensor([[k.x] for k, _ in outlet + inlet + knowledge],
torch.tensor([[k.x] for k, _ in inlet + knowledge],
dtype=torch.float64,
requires_grad=True,
device=self._device),
torch.tensor([[k.y] for k, _ in outlet + inlet + knowledge],
torch.tensor([[k.y] for k, _ in inlet + knowledge],
dtype=torch.float64,
requires_grad=True,
device=self._device),
Expand Down Expand Up @@ -91,12 +87,8 @@ def __loss(self):

u, v, *_ = self.__forward(self.__knowledge)

u_loss = self._mse(u[self.__out:], self.__u)
v_loss = self._mse(v[self.__out:], self.__v)

# # prohibits the model from hallucinating an incoming flow from right
# if self.__out > 0:
# u_loss += self._mse(torch.clamp(u[:self.__out], max=0), self.__u_out)
u_loss = self._mse(u, self.__u)
v_loss = self._mse(v, self.__v)

*_, f, g = self.__forward(self.__learning, True)

Expand Down

0 comments on commit 251cb81

Please sign in to comment.