Skip to content

Commit

Permalink
eta time
Browse files Browse the repository at this point in the history
  • Loading branch information
RRobert92 committed Dec 28, 2024
1 parent 550a977 commit cb518e5
Showing 1 changed file with 15 additions and 7 deletions.
22 changes: 15 additions & 7 deletions tardis_em/utils/predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,7 @@ def __init__(
self.device = get_device(device_s)
self.debug = debug
self.semantic_header, self.instance_header, self.log_prediction = [], [], []
self.eta_predict = "NA"

"""Initial Setup"""
if debug:
Expand Down Expand Up @@ -786,15 +787,16 @@ def predict_cnn(self, id_i: int, id_name: str, dataloader):
pred_title = f"CNN prediction with four 90 degree rotations with {self.convolution_nn}"
else:
pred_title = f"CNN prediction with {self.convolution_nn}"
eta_time = "NA"

for j in range(len(dataloader)):
if j % iter_time == 0 and self.tardis_logo:
# Tardis progress bar update
self.tardis_progress(
title=self.title,
text_1=f"Found {len(self.predict_list)} images to predict!",
text_1=f"Found {len(self.predict_list)} images to predict! [{self.eta_predict} ETA]",
text_2=f"Device: {self.device}",
text_3=f"Image {id_i + 1}/{len(self.predict_list)}: {id_name}",
text_3=f"Image {id_i + 1}/{len(self.predict_list)} [{eta_time} ETA]: {id_name}",
text_4=f"Org. Pixel size: {self.px} A | Norm. Pixel size: {self.normalize_px}",
text_5=pred_title,
text_7="Current Task: CNN prediction...",
Expand All @@ -812,6 +814,8 @@ def predict_cnn(self, id_i: int, id_name: str, dataloader):

# Scale progress bar refresh to 10s
end = time.time()
eta_time = str(round(((end - start) * (len(dataloader) - j)) / 60, 1)) + "min"

iter_time = 10 // (end - start)
if iter_time <= 1:
iter_time = 1
Expand Down Expand Up @@ -965,7 +969,7 @@ def predict_DIST(self, id_i: int, id_name: str):
if id_dist % iter_time == 0 and self.tardis_logo:
self.tardis_progress(
title=self.title,
text_1=f"Found {len(self.predict_list)} images to predict!",
text_1=f"Found {len(self.predict_list)} images to predict! [{self.eta_predict} ETA]",
text_2=f"Device: {self.device}",
text_3=f"Image {id_i + 1}/{len(self.predict_list)}: {id_name}",
text_4=f"Org. Pixel size: {self.px} A | Norm. Pixel size: {self.normalize_px}",
Expand All @@ -981,7 +985,7 @@ def predict_DIST(self, id_i: int, id_name: str):
if self.tardis_logo:
self.tardis_progress(
title=self.title,
text_1=f"Found {len(self.predict_list)} images to predict!",
text_1=f"Found {len(self.predict_list)} images to predict! [{self.eta_predict} ETA]",
text_2=f"Device: {self.device}",
text_3=f"Image {id_i + 1}/{len(self.predict_list)}: {id_name}",
text_4=f"Org. Pixel size: {self.px} A | Norm. Pixel size: {self.normalize_px}",
Expand Down Expand Up @@ -1064,7 +1068,7 @@ def postprocess_DIST(self, id_i, i):

self.tardis_progress(
title=self.title,
text_1=f"Found {len(self.predict_list)} images to predict!",
text_1=f"Found {len(self.predict_list)} images to predict! [{self.eta_predict} ETA]",
text_2=f"Device: {self.device}",
text_3=f"Image {id_i + 1}/{len(self.predict_list)}: {i}",
text_4=f"Org. Pixel size: {self.px} A | Norm. Pixel size: {self.normalize_px}",
Expand Down Expand Up @@ -1155,7 +1159,7 @@ def get_file_list(self):
else:
self.tardis_progress(
title=self.title,
text_1=f"Found {len(self.predict_list)} images to predict!",
text_1=f"Found {len(self.predict_list)} images to predict! [{self.eta_predict} ETA]",
text_2=f"Device: {self.device}",
text_7="Current Task: Setting-up environment...",
)
Expand Down Expand Up @@ -1190,7 +1194,7 @@ def log_tardis(self, id_i: int, i: Union[str, np.ndarray], log_id: float):

# Common text for all configurations
common_text = {
"text_1": f"Found {len(self.predict_list)} images to predict!",
"text_1": f"Found {len(self.predict_list)} images to predict! [{self.eta_predict} ETA]",
"text_2": f"Device: {self.device}",
"text_3": f"Image {id_i + 1}/{len(self.predict_list)}: {i}",
}
Expand Down Expand Up @@ -1645,6 +1649,8 @@ def __call__(self, save_progres=False):

semantic_output, instance_output, instance_filter_output = [], [], []
for id_, i in enumerate(self.predict_list):
start_predict = time.time()

"""CNN Pre-Processing"""
if isinstance(i, str):
# Find a file format
Expand Down Expand Up @@ -1817,6 +1823,8 @@ def __call__(self, save_progres=False):

"""Clean-up temp dir"""
clean_up(dir_s=self.dir)
end_predict = time.time()
self.eta_predict = str(round((end_predict - start_predict) / 60, 2)) + " min"

"""Optional return"""
if self.output_format.startswith("return"):
Expand Down

0 comments on commit cb518e5

Please sign in to comment.