Skip to content

Commit

Permalink
CI: black 35, 36
Browse files Browse the repository at this point in the history
  • Loading branch information
mbackenkoehler committed Feb 28, 2024
1 parent b6ff928 commit 3538cba
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 15 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -369,26 +369,25 @@
"outputs": [],
"source": [
"class SelectTarget(BaseTransform):\n",
" \n",
" def __init__(self, target_index: int) -> None:\n",
" super().__init__()\n",
" self.target_index = target_index\n",
" \n",
"\n",
" def forward(self, data):\n",
" data.y = data.y[:, self.target_index]\n",
" return data\n",
"\n",
"\n",
"class TargetNormalization(BaseTransform):\n",
" \n",
" def __init__(self, mean: torch.Tensor, std: torch.Tensor) -> None:\n",
" super().__init__()\n",
" self.mean = mean\n",
" self.std = std\n",
" \n",
"\n",
" def forward(self, data):\n",
" data.y = (data.y - self.mean) / self.std\n",
" return data\n",
" \n",
"\n",
" @classmethod\n",
" def fit(cls, dataset, dim: int = 0, cat_dim: int = 0) -> \"TargetNormalization\":\n",
" y = torch.cat([dataset.get(i).y for i in range(len(dataset))], dim=cat_dim)\n",
Expand Down
21 changes: 11 additions & 10 deletions teachopencadd/talktorials/T036_e3_equivariant_gnn/talktorial.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -803,26 +803,25 @@
"outputs": [],
"source": [
"class SelectTarget(BaseTransform):\n",
" \n",
" def __init__(self, target_index: int) -> None:\n",
" super().__init__()\n",
" self.target_index = target_index\n",
" \n",
"\n",
" def forward(self, data):\n",
" data.y = data.y[:, self.target_index]\n",
" return data\n",
"\n",
"\n",
"class TargetNormalization(BaseTransform):\n",
" \n",
" def __init__(self, mean: torch.Tensor, std: torch.Tensor) -> None:\n",
" super().__init__()\n",
" self.mean = mean\n",
" self.std = std\n",
" \n",
"\n",
" def forward(self, data):\n",
" data.y = (data.y - self.mean) / self.std\n",
" return data\n",
" \n",
"\n",
" @classmethod\n",
" def fit(cls, dataset, dim: int = 0, cat_dim: int = 0) -> \"TargetNormalization\":\n",
" y = torch.cat([dataset.get(i).y for i in range(len(dataset))], dim=cat_dim)\n",
Expand Down Expand Up @@ -876,7 +875,7 @@
" self.test_split = self.shuffled_index[\n",
" int(self.num_examples * (train_ratio + val_ratio)) : self.num_examples\n",
" ]\n",
" \n",
"\n",
" def _dataset(self):\n",
" return QM9(\n",
" DATA,\n",
Expand All @@ -893,12 +892,14 @@
" DATA,\n",
" pre_filter=lambda data: num_heavy_atoms(data) < 9,\n",
" pre_transform=add_complete_graph_edge_index,\n",
" transform=Compose([normalize_target, select_target])\n",
" transform=Compose([normalize_target, select_target]),\n",
" )\n",
" \n",
" @cached_property \n",
"\n",
" @cached_property\n",
" def target(self):\n",
" return torch.cat([self.dataset.get(i).y[:, self.target_idx] for i in range(len(self.dataset))], dim=0)\n",
" return torch.cat(\n",
" [self.dataset.get(i).y[:, self.target_idx] for i in range(len(self.dataset))], dim=0\n",
" )\n",
"\n",
" def loader(self, split, **loader_kwargs) -> DataLoader:\n",
" dataset = self.dataset[split]\n",
Expand Down

0 comments on commit 3538cba

Please sign in to comment.