Skip to content

Commit

Permalink
add configs for box models with different losses
Browse files Browse the repository at this point in the history
  • Loading branch information
adelmemariani committed Dec 12, 2023
1 parent 0dce68f commit 29fc9d4
Show file tree
Hide file tree
Showing 5 changed files with 344 additions and 98 deletions.
2 changes: 1 addition & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -160,4 +160,4 @@ cython_debug/
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/

configs/
#configs/
23 changes: 9 additions & 14 deletions chebai/models/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,21 +15,19 @@ class ChebaiBaseNet(LightningModule):
def __init__(
self,
criterion: torch.nn.Module = None,
out_dim: Optional[int] = None,
train_metrics: Optional[torch.nn.Module] = None,
val_metrics: Optional[torch.nn.Module] = None,
test_metrics: Optional[torch.nn.Module] = None,
out_dim=None,
metrics: Optional[Dict[str, torch.nn.Module]] = None,
pass_loss_kwargs=True,
**kwargs,
):
super().__init__()
self.criterion = criterion
self.save_hyperparameters(ignore=["criterion", "train_metrics", "val_metrics", "test_metrics"])
self.save_hyperparameters(ignore=["criterion"])
self.out_dim = out_dim
self.optimizer_kwargs = kwargs.get("optimizer_kwargs", dict())
self.train_metrics = train_metrics
self.validation_metrics = val_metrics
self.test_metrics = test_metrics
self.train_metrics = metrics["train"]
self.validation_metrics = metrics["validation"]
self.test_metrics = metrics["test"]
self.pass_loss_kwargs = pass_loss_kwargs

def __init_subclass__(cls, **kwargs):
Expand All @@ -41,13 +39,10 @@ def __init_subclass__(cls, **kwargs):
def _get_prediction_and_labels(self, data, labels, output):
return output, labels

def _process_labels_in_batch(self, batch):
return batch.y.float()

def _process_batch(self, batch, batch_idx):
return dict(
features=batch.x,
labels=self._process_labels_in_batch(batch),
labels=batch.y.float(),
model_kwargs=batch.additional_fields["model_kwargs"],
loss_kwargs=batch.additional_fields["loss_kwargs"],
idents=batch.additional_fields["idents"],
Expand Down Expand Up @@ -87,13 +82,13 @@ def _execute(self, batch, batch_idx, metrics, prefix="", log=True, sync_dist=Fal
loss_kwargs = dict()
if self.pass_loss_kwargs:
loss_kwargs = loss_kwargs_candidates
loss = self.criterion(loss_data, loss_labels, **loss_kwargs)
loss = self.criterion(loss_data, loss_labels, model=self, **loss_kwargs)
d["loss"] = loss
self.log(
f"{prefix}loss",
loss.item(),
batch_size=batch.x.shape[0],
on_step=True,
on_step=False,
on_epoch=True,
prog_bar=True,
logger=True,
Expand Down
Loading

0 comments on commit 29fc9d4

Please sign in to comment.