-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
1 changed file
with
399 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,399 @@ | ||
{ | ||
"cells": [ | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 1, | ||
"metadata": { | ||
"ExecuteTime": { | ||
"end_time": "2023-10-13T14:07:06.963098Z", | ||
"start_time": "2023-10-13T14:07:04.994889Z" | ||
} | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"import torch\n", | ||
"from gnn_tracking.training.tc import TCModule\n", | ||
"from pathlib import Path\n", | ||
"from torch.profiler import profile, record_function, ProfilerActivity\n", | ||
"\n", | ||
"# from object_condensation.pytorch.losses import condensation_loss" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 2, | ||
"outputs": [], | ||
"source": [ | ||
"chkpt_home = Path(\n", | ||
" \"/home/kl5675/Documents/23/git_sync/hyperparameter_optimization2/scripts/pixel/lightning_logs/\"\n", | ||
")\n", | ||
"assert chkpt_home.is_dir()\n", | ||
"chkpt_path = (\n", | ||
" chkpt_home\n", | ||
" / \"vagabond-tasteful-hyrax/checkpoints_persist/epoch=451-step=406800.ckpt\"\n", | ||
")\n", | ||
"assert chkpt_path.is_file()\n", | ||
"data_home = Path(\n", | ||
" \"/scratch/gpfs/IOJALVO/gnn-tracking/object_condensation/point_clouds_v8/\"\n", | ||
")\n", | ||
"assert data_home.is_dir()\n", | ||
"data_path = data_home / \"part_1\" / \"data21000_s0.pt\"\n", | ||
"assert data_path.is_file()" | ||
], | ||
"metadata": { | ||
"collapsed": false, | ||
"ExecuteTime": { | ||
"end_time": "2023-10-13T14:07:06.963646Z", | ||
"start_time": "2023-10-13T14:07:06.955942Z" | ||
} | ||
} | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 3, | ||
"outputs": [], | ||
"source": [ | ||
"device = \"cuda\" if torch.cuda.is_available() else \"cpu\"" | ||
], | ||
"metadata": { | ||
"collapsed": false, | ||
"ExecuteTime": { | ||
"end_time": "2023-10-13T14:07:07.010300Z", | ||
"start_time": "2023-10-13T14:07:06.980576Z" | ||
} | ||
} | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 4, | ||
"outputs": [ | ||
{ | ||
"name": "stderr", | ||
"output_type": "stream", | ||
"text": [ | ||
"\u001b[36m[10:07:07] DEBUG: Getting class PreTrainedECGraphTCN from module gnn_tracking.models.track_condensation_networks\u001b[0m\n", | ||
"/scratch/gpfs/kl5675/micromamba/envs/gnn/lib/python3.10/site-packages/pytorch_lightning/utilities/parsing.py:196: UserWarning: Attribute 'hc_in' is an instance of `nn.Module` and is already saved during checkpointing. It is recommended to ignore them using `self.save_hyperparameters(ignore=['hc_in'])`.\n", | ||
" rank_zero_warn(\n", | ||
"\u001b[36m[10:07:07] DEBUG: Getting class MLGraphConstruction from module gnn_tracking.models.graph_construction\u001b[0m\n", | ||
"\u001b[36m[10:07:07] DEBUG: Getting class GraphConstructionFCNN from module gnn_tracking.models.graph_construction\u001b[0m\n", | ||
"\u001b[36m[10:07:07] DEBUG: Getting class PotentialLoss from module gnn_tracking.metrics.losses\u001b[0m\n", | ||
"\u001b[36m[10:07:07] DEBUG: Getting class DBSCANHyperParamScanner from module gnn_tracking.postprocessing.dbscanscanner\u001b[0m\n" | ||
] | ||
} | ||
], | ||
"source": [ | ||
"lmodel = TCModule.load_from_checkpoint(chkpt_path, map_location=device)" | ||
], | ||
"metadata": { | ||
"collapsed": false, | ||
"ExecuteTime": { | ||
"end_time": "2023-10-13T14:07:07.848961Z", | ||
"start_time": "2023-10-13T14:07:07.051270Z" | ||
} | ||
} | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 5, | ||
"outputs": [], | ||
"source": [ | ||
"data = torch.load(data_path)\n", | ||
"data.to(device)\n", | ||
"assert data\n", | ||
"model = lmodel.model" | ||
], | ||
"metadata": { | ||
"collapsed": false, | ||
"ExecuteTime": { | ||
"end_time": "2023-10-13T14:07:07.874688Z", | ||
"start_time": "2023-10-13T14:07:07.861655Z" | ||
} | ||
} | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 6, | ||
"outputs": [], | ||
"source": [ | ||
"dp = lmodel.preproc(data)" | ||
], | ||
"metadata": { | ||
"collapsed": false, | ||
"ExecuteTime": { | ||
"end_time": "2023-10-13T14:07:08.110693Z", | ||
"start_time": "2023-10-13T14:07:07.863722Z" | ||
} | ||
} | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 7, | ||
"outputs": [], | ||
"source": [ | ||
"class MemLogger:\n", | ||
" def __init__(self):\n", | ||
" self.mem = 0\n", | ||
"\n", | ||
" def log(self, desc=\"\"):\n", | ||
" current = torch.cuda.memory_allocated() / 1e9\n", | ||
" added = current - self.mem\n", | ||
" print(f\"{desc:<30} added {added:>8.2f} GB, total {current:>8.2f} GB\")\n", | ||
" self.mem = current" | ||
], | ||
"metadata": { | ||
"collapsed": false, | ||
"ExecuteTime": { | ||
"end_time": "2023-10-13T14:07:08.115666Z", | ||
"start_time": "2023-10-13T14:07:08.108381Z" | ||
} | ||
} | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 8, | ||
"outputs": [], | ||
"source": [ | ||
"from torch import Tensor as T\n", | ||
"from torch.nn.functional import relu\n", | ||
"\n", | ||
"\n", | ||
"def condensation_loss(\n", | ||
" *,\n", | ||
" beta: T,\n", | ||
" x: T,\n", | ||
" object_id: T,\n", | ||
" weights: T,\n", | ||
" q_min: float,\n", | ||
" noise_threshold: int,\n", | ||
") -> dict[str, T]:\n", | ||
" # To protect against nan in divisions\n", | ||
" eps = 1e-9\n", | ||
"\n", | ||
" # x: n_nodes x n_outdim\n", | ||
" not_noise = object_id > noise_threshold\n", | ||
" unique_oids = torch.unique(object_id[not_noise])\n", | ||
" assert len(unique_oids) > 0, \"No particles found, cannot evaluate loss\"\n", | ||
" # n_nodes x n_pids\n", | ||
" # The nodes in every column correspond to the hits of a single particle and\n", | ||
" # should attract each other\n", | ||
" attractive_mask = object_id.view(-1, 1) == unique_oids.view(1, -1)\n", | ||
"\n", | ||
" q = torch.arctanh(beta) ** 2 + q_min\n", | ||
" assert not torch.isnan(q).any(), \"q contains NaNs\"\n", | ||
" # n_objs\n", | ||
" alphas = torch.argmax(q.view(-1, 1) * attractive_mask, dim=0)\n", | ||
"\n", | ||
" # _j means indexed by hits\n", | ||
" # _k means indexed by objects\n", | ||
"\n", | ||
" # n_objs x n_outdim\n", | ||
" x_k = x[alphas]\n", | ||
" # 1 x n_objs\n", | ||
" q_k = q[alphas].view(1, -1)\n", | ||
"\n", | ||
" dist_j_k = torch.cdist(x, x_k)\n", | ||
"\n", | ||
" qw_j_k = weights.view(-1, 1) * q.view(-1, 1) * q_k\n", | ||
"\n", | ||
" repulsive_mask = (~attractive_mask) & (dist_j_k < 1)\n", | ||
" # We have to include the hits-per-object normalization factor here, because\n", | ||
" # after applying the mask we only have a 1D tensor anymore\n", | ||
" qw_att_j_k = (qw_j_k / (attractive_mask.sum(dim=0) + eps))[attractive_mask]\n", | ||
" qw_rep_j_k = (qw_j_k / ((~attractive_mask).sum(dim=0) + eps))[repulsive_mask]\n", | ||
"\n", | ||
" # Attractive potential/loss\n", | ||
" v_att_j_k = qw_att_j_k * torch.square(dist_j_k)[attractive_mask]\n", | ||
" # It's important to directly do the .mean here so we don't keep these large\n", | ||
" # matrices in memory longer than we need them\n", | ||
" # Attractive potential per object normalized over number of hits in object\n", | ||
" v_att_k = torch.sum(v_att_j_k, dim=0)\n", | ||
" v_att = torch.sum(v_att_k) / len(unique_oids)\n", | ||
"\n", | ||
" # Repulsive potential/loss\n", | ||
" v_rep_j_k = qw_rep_j_k * (1 - dist_j_k[repulsive_mask])\n", | ||
" v_rep_k = torch.sum(v_rep_j_k, dim=0)\n", | ||
" v_rep = torch.sum(v_rep_k) / len(unique_oids)\n", | ||
"\n", | ||
" l_coward = torch.mean(1 - beta[alphas])\n", | ||
" l_noise = torch.mean(beta[~not_noise])\n", | ||
"\n", | ||
" return {\n", | ||
" \"attractive\": v_att,\n", | ||
" \"repulsive\": v_rep,\n", | ||
" \"coward\": l_coward,\n", | ||
" \"noise\": l_noise,\n", | ||
" }" | ||
], | ||
"metadata": { | ||
"collapsed": false, | ||
"ExecuteTime": { | ||
"end_time": "2023-10-13T14:07:08.117687Z", | ||
"start_time": "2023-10-13T14:07:08.112548Z" | ||
} | ||
} | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 11, | ||
"outputs": [ | ||
{ | ||
"data": { | ||
"text/plain": "{'attractive': tensor(1.7951, dtype=torch.float64),\n 'repulsive': tensor(1.9509, dtype=torch.float64),\n 'coward': tensor(0.2157, dtype=torch.float64),\n 'noise': tensor(0.7748, dtype=torch.float64)}" | ||
}, | ||
"execution_count": 11, | ||
"metadata": {}, | ||
"output_type": "execute_result" | ||
} | ||
], | ||
"source": [ | ||
"import sys\n", | ||
"\n", | ||
"sys.path.append(\"/home/kl5675/Documents/23/git_sync/object_condensation\")\n", | ||
"from tests.loss_test_cases import generate_test_data\n", | ||
"from tests.test_losses_torch import TorchCondensationMockData\n", | ||
"\n", | ||
"td = generate_test_data()\n", | ||
"\n", | ||
"td = TorchCondensationMockData.from_numpy(td)\n", | ||
"cl = condensation_loss(\n", | ||
" beta=td.beta.squeeze(),\n", | ||
" x=td.x,\n", | ||
" object_id=td.object_id.squeeze(),\n", | ||
" weights=td.weights.squeeze(),\n", | ||
" q_min=td.q_min,\n", | ||
" noise_threshold=0,\n", | ||
")\n", | ||
"cl" | ||
], | ||
"metadata": { | ||
"collapsed": false, | ||
"ExecuteTime": { | ||
"end_time": "2023-10-13T14:07:21.611605Z", | ||
"start_time": "2023-10-13T14:07:21.516071Z" | ||
} | ||
} | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 12, | ||
"outputs": [ | ||
{ | ||
"name": "stdout", | ||
"output_type": "stream", | ||
"text": [ | ||
"{'attractive': tensor(1.7951, dtype=torch.float64), 'repulsive': tensor(1.9509, dtype=torch.float64), 'coward': tensor(0.2157, dtype=torch.float64), 'noise': tensor(0.7748, dtype=torch.float64)}\n" | ||
] | ||
} | ||
], | ||
"source": [ | ||
"print(cl, flush=True)" | ||
], | ||
"metadata": { | ||
"collapsed": false, | ||
"ExecuteTime": { | ||
"end_time": "2023-10-13T14:07:22.476753Z", | ||
"start_time": "2023-10-13T14:07:22.361876Z" | ||
} | ||
} | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 14, | ||
"outputs": [ | ||
{ | ||
"name": "stdout", | ||
"output_type": "stream", | ||
"text": [ | ||
"empty added 0.17 GB, total 0.17 GB\n", | ||
"model evaluated added 14.75 GB, total 14.92 GB\n", | ||
"loss evaluated added 3.99 GB, total 18.91 GB\n", | ||
"backward done evaluated added -18.74 GB, total 0.17 GB\n", | ||
"step done added 0.02 GB, total 0.18 GB\n" | ||
] | ||
} | ||
], | ||
"source": [ | ||
"ml = MemLogger()\n", | ||
"optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)\n", | ||
"ml.log(\"empty\")\n", | ||
"out = model(dp)\n", | ||
"ml.log(\"model evaluated\")\n", | ||
"loss = condensation_loss(\n", | ||
" beta=out[\"B\"],\n", | ||
" x=out[\"H\"],\n", | ||
" object_id=data.particle_id,\n", | ||
" q_min=0.1,\n", | ||
" noise_threshold=0,\n", | ||
" weights=torch.ones_like(data.particle_id),\n", | ||
")\n", | ||
"total_loss = loss[\"attractive\"] + loss[\"repulsive\"] + loss[\"noise\"] + loss[\"coward\"]\n", | ||
"ml.log(\"loss evaluated\")\n", | ||
"optimizer.zero_grad()\n", | ||
"total_loss.backward()\n", | ||
"ml.log(\"backward done evaluated\")\n", | ||
"optimizer.step()\n", | ||
"ml.log(\"step done\")" | ||
], | ||
"metadata": { | ||
"collapsed": false, | ||
"ExecuteTime": { | ||
"end_time": "2023-10-13T14:07:29.241361Z", | ||
"start_time": "2023-10-13T14:07:28.933839Z" | ||
} | ||
} | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"source": [ | ||
"empty added 0.14 GB, total 0.14 GB\n", | ||
"model evaluated added 14.76 GB, total 14.90 GB\n", | ||
"loss evaluated added 3.87 GB, total 18.77 GB\n", | ||
"backward done evaluated added -18.60 GB, total 0.17 GB\n", | ||
"step done added 0.02 GB, total 0.18 GB\n", | ||
"\n", | ||
"\n", | ||
"empty added 25.88 GB, total 25.88 GB\n", | ||
"model evaluated added 14.76 GB, total 40.63 GB\n", | ||
"loss evaluated added 14.18 GB, total 54.82 GB\n", | ||
"backward done evaluated added -28.93 GB, total 25.88 GB\n", | ||
"step done added 0.02 GB, total 25.90 GB" | ||
], | ||
"metadata": { | ||
"collapsed": false | ||
} | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"outputs": [], | ||
"source": [], | ||
"metadata": { | ||
"collapsed": false, | ||
"ExecuteTime": { | ||
"start_time": "2023-10-13T14:07:08.514827Z" | ||
} | ||
} | ||
} | ||
], | ||
"metadata": { | ||
"kernelspec": { | ||
"display_name": "Python 3", | ||
"language": "python", | ||
"name": "python3" | ||
}, | ||
"language_info": { | ||
"codemirror_mode": { | ||
"name": "ipython", | ||
"version": 3 | ||
}, | ||
"file_extension": ".py", | ||
"mimetype": "text/x-python", | ||
"name": "python", | ||
"nbconvert_exporter": "python", | ||
"pygments_lexer": "ipython3", | ||
"version": "3.11.3" | ||
} | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 2 | ||
} |