Skip to content

Commit

Permalink
Add plot logging for each experiment
Browse files Browse the repository at this point in the history
  • Loading branch information
SylviaWhittle committed May 31, 2024
1 parent 64eccdc commit de8c0dc
Show file tree
Hide file tree
Showing 5 changed files with 19 additions and 23 deletions.
2 changes: 2 additions & 0 deletions .dvc/config
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
[core]
remote = mygdrive
[hydra]
enabled = False
['remote "myremote"']
url = /Users/sylvi/Documents/dvc_remotes/remote-catsnet/
['remote "mygdrive"']
Expand Down
27 changes: 11 additions & 16 deletions dvc.lock
Original file line number Diff line number Diff line change
Expand Up @@ -53,20 +53,20 @@ stages:
train:
train_data_dir: data/train/
model_save_dir: models/
batch_size: 6
epochs: 5
batch_size: 2
epochs: 50
norm_upper_bound: 7
norm_lower_bound: -1
validation_split: 0.2
outs:
- path: models/catsnet_model.keras
hash: md5
md5: 9c5dd66ef65490711024563e8d6c604b
md5: 343a08cec07288294b247e83355a616f
size: 23454939
- path: results/train/
hash: md5
md5: 60281b02178ac14a7f481c3780f8af2b.dir
size: 688
md5: 828dbc21d8b42f881955251ca233b4b0.dir
size: 4736
nfiles: 5
evaluate:
cmd: python src/evaluate.py
Expand All @@ -78,12 +78,12 @@ stages:
nfiles: 32
- path: models/catsnet_model.keras
hash: md5
md5: 9c5dd66ef65490711024563e8d6c604b
md5: 343a08cec07288294b247e83355a616f
size: 23454939
- path: src/evaluate.py
hash: md5
md5: 25a98179bf20b7d153d5b31025f8a39f
size: 5865
md5: 925f98789798e9920bf396e614d4b29b
size: 5764
params:
params.yaml:
base:
Expand All @@ -95,11 +95,6 @@ stages:
outs:
- path: results/evaluate/
hash: md5
md5: 96b1905ca851ac5f36b0c0679fd19630.dir
size: 42
nfiles: 1
- path: results/evaluate_plots/
hash: md5
md5: 5ed50b6928005885d5f6fa0ae159c421.dir
size: 1714068
nfiles: 8
md5: 42e330087e2dd0b2297e3df41be8d2c2.dir
size: 1700653
nfiles: 9
2 changes: 1 addition & 1 deletion dvc.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -32,13 +32,13 @@ stages:
- evaluate
outs:
- results/evaluate/
- results/evaluate_plots/
metrics:
- results/train/metrics.json
- results/evaluate/metrics.json
plots:
- results/train/plots/metrics:
x: step
- results/evaluate/plots/images
artifacts:
catsnet_model:
path: models/catsnet_model.keras
Expand Down
4 changes: 2 additions & 2 deletions params.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@ data_split:
train:
train_data_dir: data/train/
model_save_dir: models/
batch_size: 6
epochs: 5
batch_size: 2
epochs: 50
norm_upper_bound: 7
norm_lower_bound: -1
validation_split: 0.2
Expand Down
7 changes: 3 additions & 4 deletions src/evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,9 +101,7 @@ def evaluate(
dice_score = dice(mask_predicted, mask)
dice_multi += dice_score / len(image_indexes)

# Plot the image, mask and predicted mask and save it to the results/evaluate/image_plots directory
plot_save_dir = Path("results/evaluate_plots")
plot_save_dir.mkdir(parents=True, exist_ok=True)
# Plot the image, mask and predicted mask and log it
num_channels = mask_predicted.shape[-1]
fig, ax = plt.subplots(num_channels, 3, figsize=(15, 5))
if num_channels == 1:
Expand All @@ -121,7 +119,8 @@ def evaluate(
ax[i, 1].set_title(f"Ground Truth Mask Channel {i}")
ax[i, 2].imshow(mask_predicted[:, :, i], cmap="binary")
ax[i, 2].set_title(f"Predicted Mask Channel {i}")
plt.savefig(f"{plot_save_dir}/test_image_{index}.png")
# plt.savefig(f"{plot_save_dir}/test_image_{index}.png")
live.log_image(f"test_image_plot_{index}.png", fig)

live.summary["dice_multi"] = dice_multi

Expand Down

0 comments on commit de8c0dc

Please sign in to comment.