Skip to content

Commit

Permalink
#59 fixed notebook (and related files) to create figure 1b
Browse files Browse the repository at this point in the history
  • Loading branch information
MarionBWeinzierl committed Aug 18, 2023
1 parent c4ae572 commit 04ad24b
Show file tree
Hide file tree
Showing 6 changed files with 80 additions and 170 deletions.
60 changes: 32 additions & 28 deletions examples/jupyter-notebooks/generate-paper-figure-1.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"# Code to generate Figure 2 "
"# Code to generate Figure 1 "
]
},
{
Expand All @@ -17,6 +17,7 @@
"sys.path.insert(1, os.path.join(os.getcwd() , '../../src/gz21_ocean_momentum'))\n",
"from utils import select_experiment, select_run\n",
"from analysis.utils import plot_dataset, GlobalPlotter, plot_training_subdomains\n",
"from data.utils import load_training_datasets\n",
"import mlflow\n",
"from mlflow.tracking import MlflowClient\n",
"import xarray as xr\n",
Expand Down Expand Up @@ -63,7 +64,7 @@
"ml_client = MlflowClient()\n",
"data_fname = ml_client.download_artifacts(run.run_id, 'forcing')\n",
"data = xr.open_zarr(data_fname)\n",
"data = data.rename(dict(xu_ocean='longitude', yu_ocean='latitude'))\n"
"#data = data.rename(dict(xu_ocean='longitude', yu_ocean='latitude'))"
]
},
{
Expand All @@ -74,15 +75,24 @@
"source": [
"from data.pangeo_catalog import get_patch\n",
"\n",
"%matplotlib notebook\n",
"run = select_run(experiment_ids=('497746281881301089'))\n",
"#import IPython \n",
"#from IPython.display import display, Javascript\n",
"#display(Javascript(\"window.IPython = window.IPython || {}\"))\n",
"\n",
"#run = select_run(experiment_ids=('497746281881301089'))\n",
"run_id = run.run_id\n",
"\n",
"from cartopy.crs import PlateCarree\n",
"from data.pangeo_catalog import get_patch, get_whole_data\n",
"from scipy.ndimage import gaussian_filter\n",
"from matplotlib import colors\n",
"import matplotlib.pyplot as plt\n",
"\n",
"from importlib import reload\n",
"reload(plt)\n",
"\n",
"%matplotlib widget\n",
"#%matplotlib inline\n",
"\n",
"CATALOG_URL = 'https://raw.githubusercontent.com/pangeo-data/pangeo-datastore\\\n",
"/master/intake-catalogs/master.yaml'\n",
Expand Down Expand Up @@ -163,17 +173,18 @@
" None.\n",
"\n",
" \"\"\"\n",
"\n",
" fig = plt.figure()\n",
" projection = projection_cls(lon)\n",
" if ax is None:\n",
" ax = plt.axes(projection=projection)\n",
" mesh_x, mesh_y = np.meshgrid(u['longitude'], u['latitude'])\n",
" mesh_x, mesh_y = np.meshgrid(u['xu_ocean'], u['yu_ocean'])\n",
" if u is not None:\n",
" extra = self.mask.isel(longitude=slice(0, 10))\n",
" extra['longitude'] = extra['longitude'] + 360\n",
" mask = xr.concat((self.mask, extra), dim='longitude')\n",
" mask = mask.interp({k: u.coords[k] for k in ('longitude',\n",
" 'latitude')})\n",
" extra = self.mask.isel(xu_ocean=slice(0, 10))\n",
" extra['xu_ocean'] = extra['xu_ocean'] + 360\n",
" mask = xr.concat((self.mask, extra), dim='xu_ocean')\n",
" mask = mask.interp({k: u.coords[k] for k in ('xu_ocean',\n",
" 'yu_ocean')})\n",
" u = u * mask\n",
" im = ax.pcolormesh(mesh_x, mesh_y, u.values,\n",
" transform=PlateCarree(),\n",
Expand All @@ -187,10 +198,10 @@
" # \"Gray-out\" near continental locations\n",
" if self.margin > 0:\n",
" extra = self.borders.isel(longitude=slice(0, 10))\n",
" extra['longitude'] = extra['longitude'] + 360\n",
" borders = xr.concat((self.borders, extra), dim='longitude')\n",
" extra['xu_ocean'] = extra['xu_ocean'] + 360\n",
" borders = xr.concat((self.borders, extra), dim='xu_ocean')\n",
" borders = borders.interp({k: u.coords[k]\n",
" for k in ('longitude', 'latitude')})\n",
" for k in ('xu_ocean', 'yu_ocean')})\n",
" borders_cmap = colors.ListedColormap([borders_color, ])\n",
" ax.pcolormesh(mesh_x, mesh_y, borders, animated=animated,\n",
" transform=PlateCarree(), alpha=borders_alpha,\n",
Expand All @@ -199,10 +210,10 @@
" if self.ice:\n",
" ice = self._get_ice_border()\n",
" ice = xr.where(ice, 1., 0.)\n",
" ice = ice.interp({k: u.coords[k] for k in ('longitude',\n",
" 'latitude')})\n",
" ice = ice.interp({k: u.coords[k] for k in ('xu_ocean',\n",
" 'yu_ocean')})\n",
" ice = xr.where(ice != 0, 1., 0.)\n",
" ice = abs(ice.diff(dim='longitude')) + abs(ice.diff(dim='latitude'))\n",
" ice = abs(ice.diff(dim='xu_ocean')) + abs(ice.diff(dim='yu_ocean'))\n",
" ice = xr.where(ice != 0., 1, np.nan)\n",
" ice_cmap = colors.ListedColormap(['black', ])\n",
" ax.pcolormesh(mesh_x, mesh_y, ice, animated=animated,\n",
Expand Down Expand Up @@ -245,7 +256,7 @@
" mask = mask.coarsen(dict(xt_ocean=factor, yt_ocean=factor))\n",
" mask_ = mask.max()\n",
" mask_ = mask_.where(mask_ > 0.1)\n",
" mask_ = mask_.rename(dict(xt_ocean='longitude', yt_ocean='latitude'))\n",
" mask_ = mask_.rename(dict(xt_ocean='xu_ocean', yt_ocean='yu_ocean'))\n",
" return mask_.compute()\n",
"\n",
" @staticmethod\n",
Expand All @@ -254,8 +265,8 @@
" in the oceans. \"\"\"\n",
" temperature, _ = get_patch(CATALOG_URL, 1, None, 0,\n",
" 'surface_temp')\n",
" temperature = temperature.rename(dict(xt_ocean='longitude',\n",
" yt_ocean='latitude'))\n",
" temperature = temperature.rename(dict(xt_ocean='xu_ocean',\n",
" yt_ocean='yu_ocean'))\n",
" temperature = temperature['surface_temp'].isel(time=0)\n",
" ice = xr.where(temperature <= 0., True, False)\n",
" return ice\n",
Expand Down Expand Up @@ -296,7 +307,7 @@
"plotter = GlobalPlotter(cbar=True, margin=0)\n",
"plotter.x_ticks = np.arange(-150., 151., 50)\n",
"plotter.y_ticks = np.arange(-80., 81., 20)\n",
"plot_training_subdomains(run_id, plotter, bg_variable=data['usurf'].isel(time=0), facecolor='green', edgecolor='black', linewidth=2, fill=False, vmin=-0.5, vmax=0.5, lon=0., cmap=cmap_balance)\n"
"plot_training_subdomains(plotter, bg_variable=data['usurf'].isel(time=0), facecolor='green', edgecolor='black', linewidth=2, fill=False, vmin=-0.5, vmax=0.5, lon=0., cmap=cmap_balance)\n"
]
},
{
Expand All @@ -305,15 +316,8 @@
"metadata": {},
"outputs": [],
"source": [
"plt.savefig('figure2.jpg', dpi=250)"
"plt.savefig('figure1b.jpg', dpi=250)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
Expand Down
4 changes: 2 additions & 2 deletions examples/jupyter-notebooks/generate-paper-figure-6.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"# Code to generate Figure 1 "
"# Code to generate Figure 6 "
]
},
{
Expand Down Expand Up @@ -198,7 +198,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.2"
"version": "3.11.4"
}
},
"nbformat": 4,
Expand Down
108 changes: 7 additions & 101 deletions examples/jupyter-notebooks/offline_test_SWM.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -2,19 +2,10 @@
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"execution_count": null,
"id": "stuffed-ratio",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"To load the net from the paper, use the function load_paper_net().\n",
"env: MLFLOW_TRACKING_URI=/home/marion/workspace/gz21_ocean_momentum/examples/jupyter-notebooks/../../mlruns\n"
]
}
],
"outputs": [],
"source": [
"import mlflow\n",
"import xarray as xr\n",
Expand All @@ -29,7 +20,7 @@
},
{
"cell_type": "code",
"execution_count": 2,
"execution_count": null,
"id": "failing-occasion",
"metadata": {},
"outputs": [],
Expand All @@ -48,106 +39,21 @@
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": null,
"id": "liable-amendment",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"568040818937566888 : test2\n",
"302083951441703666 : testrun\n",
"0 : Default\n"
]
},
{
"name": "stdin",
"output_type": "stream",
"text": [
"Select the id of an experiment: 568040818937566888\n"
]
}
],
"outputs": [],
"source": [
"exp_id, _ = select_experiment(default_selection='22')\n",
"runs=mlflow.search_runs(experiment_ids=(exp_id,))"
]
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": null,
"id": "familiar-lucas",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"run_id 5766a8e8547e4785829b507784493848\n",
"experiment_id 568040818937566888\n",
"status FINISHED\n",
"artifact_uri file:///home/marion/workspace/gz21_ocean_momen...\n",
"start_time 2023-08-14 14:56:47.615000+00:00\n",
"end_time 2023-08-14 14:59:22.214000+00:00\n",
"metrics.test loss -0.90139\n",
"metrics.R2 0.447976\n",
"metrics.train loss -1.170045\n",
"metrics.Inf Norm 10.040836\n",
"params.print_every 20\n",
"params.source.experiment_id 568040818937566888\n",
"params.test_split 0.85\n",
"params.batchsize 4\n",
"params.n_epochs 200\n",
"params.run_id 2308027e434047899cef078149f98edb\n",
"params.exp_id 568040818937566888\n",
"params.weight_decay 0.00\n",
"params.n_epochs_actual 34\n",
"params.transformation_cls_name SoftPlusTransform\n",
"params.model_cls_name FullyCNN\n",
"params.submodel transform3\n",
"params.source.run_id 2308027e434047899cef078149f98edb\n",
"params.loss_cls_name HeteroskedasticGaussianLossV2\n",
"params.train_split 0.8\n",
"params.features_transform_cls_name None\n",
"params.targets_transform_cls_name None\n",
"params.learning_rate 0/5e-4/15/5e-5/30/5e-6\n",
"params.model_module_name models.models1\n",
"params.time_indices 0\n",
"params.long_max None\n",
"params.chunk_size None\n",
"params.lat_min None\n",
"params.global None\n",
"params.lat_max None\n",
"params.ntimes None\n",
"params.long_min None\n",
"params.factor None\n",
"params.CO2 None\n",
"tags.mlflow.source.type PROJECT\n",
"tags.mlflow.runName amazing-squirrel-195\n",
"tags.mlflow.project.backend local\n",
"tags.mlflow.source.git.repoURL [email protected]:m2lines/gz21_ocean_momentum.git\n",
"tags.mlflow.source.git.commit f83a5c8f81cea6ffe84f38f12ae31689918b18e6\n",
"tags.mlflow.gitRepoURL [email protected]:m2lines/gz21_ocean_momentum.git\n",
"tags.mlflow.source.name file:///home/marion/workspace/gz21_ocean_momentum\n",
"tags.mlflow.user marion\n",
"tags.mlflow.project.entryPoint train\n",
"Name: 0, dtype: object\n"
]
},
{
"ename": "FileNotFoundError",
"evalue": "[Errno 2] No such file or directory: 'file:///home/marion/workspace/gz21_ocean_momentum/mlruns/568040818937566888/5766a8e8547e4785829b507784493848/artifacts'",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mFileNotFoundError\u001b[0m Traceback (most recent call last)",
"Cell \u001b[0;32mIn[4], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m datasets \u001b[38;5;241m=\u001b[39m \u001b[43mload_data_from_run\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m0\u001b[39;49m\u001b[43m)\u001b[49m\n",
"Cell \u001b[0;32mIn[2], line 4\u001b[0m, in \u001b[0;36mload_data_from_run\u001b[0;34m(i_run)\u001b[0m\n\u001b[1;32m 2\u001b[0m run \u001b[38;5;241m=\u001b[39m runs\u001b[38;5;241m.\u001b[39miloc[i_run]\n\u001b[1;32m 3\u001b[0m \u001b[38;5;28mprint\u001b[39m(run)\n\u001b[0;32m----> 4\u001b[0m filenames \u001b[38;5;241m=\u001b[39m \u001b[43mos\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mlistdir\u001b[49m\u001b[43m(\u001b[49m\u001b[43mrun\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[38;5;124;43martifact_uri\u001b[39;49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[43m]\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 5\u001b[0m datasets \u001b[38;5;241m=\u001b[39m []\n\u001b[1;32m 6\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m fn \u001b[38;5;129;01min\u001b[39;00m filenames:\n",
"\u001b[0;31mFileNotFoundError\u001b[0m: [Errno 2] No such file or directory: 'file:///home/marion/workspace/gz21_ocean_momentum/mlruns/568040818937566888/5766a8e8547e4785829b507784493848/artifacts'"
]
}
],
"outputs": [],
"source": [
"datasets = load_data_from_run(0)"
]
Expand Down
4 changes: 2 additions & 2 deletions examples/jupyter-notebooks/train_results.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
"time = 50\n",
"\n",
"# Prompts the user to select a trained model\n",
"mlflow.set_tracking_uri('/scratch/ag7531/mlruns')\n",
"mlflow.set_tracking_uri(os.path.join(os.getcwd(), '../../mlruns'))\n",
"cols = ['params.model_cls_name', 'params.loss_cls_name']\n",
"exp_id, _ =select_experiment()\n",
"run = select_run(sort_by='metrics.test loss', cols=cols, experiment_ids=[exp_id,])\n",
Expand Down Expand Up @@ -582,7 +582,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.2"
"version": "3.11.4"
}
},
"nbformat": 4,
Expand Down
Loading

0 comments on commit 04ad24b

Please sign in to comment.