Skip to content

Commit

Permalink
Merge pull request #7 from cy-xu/feature/add_models_to_repo
Browse files Browse the repository at this point in the history
banzai integration, logging, model wrap
  • Loading branch information
cy-xu authored Jan 4, 2022
2 parents ca28740 + 86c0d97 commit a4c5bf0
Show file tree
Hide file tree
Showing 16 changed files with 448 additions and 59 deletions.
3 changes: 0 additions & 3 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
*.fits
*.fz
*.zip
*.pth
*.lprof
*.out
*.pdf
Expand All @@ -12,9 +11,7 @@
*.csv

# ignored dirs
trained_models
cosmic_conn/data
cosmic_conn/trained_models
demo_data
slurm
checkpoints*
Expand Down
9 changes: 9 additions & 0 deletions HISTORY.rst
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,15 @@
History
=======

0.2.8 (2022-01-04)
- added `-c` option to CLI to specify crop size for stamp detection
- stop using memory_check(), which is not robust on server nodes
- moved messages to logger, stdout turned on only for CLI users
- removed trained models' DataParallel wrapper
- new threshold-based plots for BANZAI integration

------------------

0.2.7 (2021-12-03)
- Trained models added to git repository

Expand Down
2 changes: 1 addition & 1 deletion cosmic_conn/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,6 @@

__author__ = """Chengyuan Xu, Curtis McCully, Boning Dong, D. Andrew Howell, and Pradeep Sen"""
__email__ = "[email protected]"
__version__ = "0.2.7"
__version__ = "0.2.8"

from cosmic_conn.inference_cr import init_model, detect_image, detect_FITS
8 changes: 8 additions & 0 deletions cosmic_conn/data_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,14 @@ def console_arguments():
default="SCI",
help="read data from this hdul extension, SCI by default."
)
parser.add_argument(
"-c",
"--crop",
type=int,
default=1024,
help="slice the image to stamps of this size, 1024 by default."
"Set to 0 for full image detection, large memory required."
)

opt = parser.parse_args()
opt = vars(opt)
Expand Down
50 changes: 27 additions & 23 deletions cosmic_conn/dl_framework/cosmic_conn.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import os
import math
import random
import logging

import torch
import torch.nn as nn
Expand Down Expand Up @@ -46,11 +47,12 @@ def initialize(self, opt):

if torch.cuda.is_available():
self.device = torch.device("cuda")
print("...GPU found, yeah!")
logging.info("...GPU found, yeah!")

else:
self.device = torch.device("cpu")
print("...GPU or CUDA not detected, using CPU (slower). ")
print("...training on CPU is not recommended.")
logging.info("...GPU or CUDA not detected, using CPU (slower). ")
logging.info("...training on CPU is not recommended.")

self.model_dir = os.path.join(opt.expr_dir, "models")
self.valid_dir = os.path.join(opt.expr_dir, "validation")
Expand All @@ -69,8 +71,7 @@ def build_models(self, opt):
norm_setting = [opt.n_group, opt.gn_channel, opt.no_affine]

# the network is defined here
self.network = nn.DataParallel(
UNet_module(
self.network = UNet_module(
n_channels=1,
n_classes=1,
hidden=opt.hidden,
Expand All @@ -81,7 +82,6 @@ def build_models(self, opt):
up_type=opt.up_type,
deeper=opt.deeper,
)
)

self.network.to(self.device)

Expand Down Expand Up @@ -153,7 +153,7 @@ def build_models(self, opt):
self.network.load_state_dict(checkpoint["state_dict_mask"])

# use available memory to determine detection method
self.full_image_detection = memory_check(self.device)
# self.full_image_detection = memory_check(self.device)

elif opt.mode == "train" and opt.continue_train:
# initialize with previously saved model, but new optimizer
Expand Down Expand Up @@ -212,31 +212,37 @@ def detect_full_image(self, image):
mask[: pdt.shape[0], : pdt.shape[1]] = pdt
return mask

def detect_image_stamps(self, image, verbose=True):
def detect_image_stamps(self, image, crop=1024):
# if not enough memory, detect in smaller stamps
mask = None
stamp_sizes = [1024, 512, 256]

# by defualt we use smaller stamp size as memory safeguard
stamp_sizes = [1024, 512, 256] if crop==1024 else [crop]

for stamp in stamp_sizes:
try:
if verbose:
print(
f"Slicing image to {stamp}x{stamp} stamps...",
)
msg = f"Slicing image to {stamp}x{stamp} stamps..."
logging.info(msg)

torch.cuda.empty_cache()

mask = clean_large(
image, self.network, patch=stamp, overlap=0
)
except:
print(f"...{stamp}x{stamp} stamp won't fit into memory.")
logging.warning(f"...{stamp}x{stamp} stamp won't fit into memory.")

if mask is not None:
break

if mask is None:
msg = "...detection failed. Memory too small?"
logging.error(msg)
raise ValueError(msg)

assert mask is not None, f"...detection failed.\n Memory too small?"
return mask

# @torch.jit.export
def detect_cr(self, image, ret_numpy=True):
# replace NaN with 0.0 if exist
image = remove_nan(image)
Expand All @@ -247,14 +253,12 @@ def detect_cr(self, image, ret_numpy=True):

# no gradient saved during inference
with torch.no_grad():
if self.full_image_detection:
try:
# on CPU, instance might be killed without raising error
mask = self.detect_full_image(image)
except:
mask = self.detect_image_stamps(image)
if self.opt.crop == 0:
# full image detection as user requested
mask = self.detect_full_image(image)
else:
mask = self.detect_image_stamps(image)
# slice image into smaller stamps
mask = self.detect_image_stamps(image, self.opt.crop)

if ret_numpy:
return tensor2np(mask)
Expand Down Expand Up @@ -416,7 +420,7 @@ def validate(self, valid_loader, epoch, source):
pdt_masks = torch.zeros_like(frames)

for j in range(b):
pdt = self.detect_image_stamps(frames[j].squeeze(), False)
pdt = self.detect_image_stamps(frames[j].squeeze())
pdt_masks[j] = pdt.unsqueeze(dim=0)

# pdt_masks = self.detect_image_stamps(frames)
Expand Down
1 change: 1 addition & 0 deletions cosmic_conn/dl_framework/options.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ def __init__(self):
self.__dict__["app"] = False
self.__dict__["input"] = "./"
self.__dict__["ext"] = "SCI"
self.__dict__["verbose"] = False


def __setitem__(self, key, value):
Expand Down
15 changes: 10 additions & 5 deletions cosmic_conn/dl_framework/utils_ml.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
"""
import os
import math
import logging
import datetime
import random
import psutil
Expand Down Expand Up @@ -40,9 +41,11 @@ def memory_check(device):
GPU_THRESHOLD = 8 * (1024**3) # 8 GB free memory

if str(device) == 'cpu':
# on CPU, available memory
available_memory = psutil.virtual_memory()[1]
full_image_detection = available_memory > CPU_THRESHOLD
# available memory on CPU is not reliable, defult to 1024 stamps
# available_memory = psutil.virtual_memory()[1]
# full_image_detection = available_memory > CPU_THRESHOLD
full_image_detection = False

else:
# GPU available memory
t = torch.cuda.get_device_properties(device).total_memory
Expand All @@ -52,8 +55,10 @@ def memory_check(device):
full_image_detection = t > GPU_THRESHOLD

if not full_image_detection:
print(f"...available memory not sufficient for whole image detection.")
print(f"...image will be sliced into stamps.")
msg = f"...available memory not sufficient for whole image detection."
msg2 = f"...image will be sliced into stamps."
logging.warning(msg)
logging.warning(msg2)

return full_image_detection

Expand Down
Loading

0 comments on commit a4c5bf0

Please sign in to comment.