From cb518e5347b9c6d04e585686230b2a250ef4439d Mon Sep 17 00:00:00 2001 From: Robert Kiewisz <56911280+RRobert92@users.noreply.github.com> Date: Sat, 28 Dec 2024 18:12:07 +0100 Subject: [PATCH] eta time --- tardis_em/utils/predictor.py | 22 +++++++++++++++------- 1 file changed, 15 insertions(+), 7 deletions(-) diff --git a/tardis_em/utils/predictor.py b/tardis_em/utils/predictor.py index 44563275..d1a33615 100644 --- a/tardis_em/utils/predictor.py +++ b/tardis_em/utils/predictor.py @@ -188,6 +188,7 @@ def __init__( self.device = get_device(device_s) self.debug = debug self.semantic_header, self.instance_header, self.log_prediction = [], [], [] + self.eta_predict = "NA" """Initial Setup""" if debug: @@ -786,15 +787,16 @@ def predict_cnn(self, id_i: int, id_name: str, dataloader): pred_title = f"CNN prediction with four 90 degree rotations with {self.convolution_nn}" else: pred_title = f"CNN prediction with {self.convolution_nn}" + eta_time = "NA" for j in range(len(dataloader)): if j % iter_time == 0 and self.tardis_logo: # Tardis progress bar update self.tardis_progress( title=self.title, - text_1=f"Found {len(self.predict_list)} images to predict!", + text_1=f"Found {len(self.predict_list)} images to predict! [{self.eta_predict} ETA]", text_2=f"Device: {self.device}", - text_3=f"Image {id_i + 1}/{len(self.predict_list)}: {id_name}", + text_3=f"Image {id_i + 1}/{len(self.predict_list)} [{eta_time} ETA]: {id_name}", text_4=f"Org. Pixel size: {self.px} A | Norm. Pixel size: {self.normalize_px}", text_5=pred_title, text_7="Current Task: CNN prediction...", @@ -812,6 +814,8 @@ def predict_cnn(self, id_i: int, id_name: str, dataloader): # Scale progress bar refresh to 10s end = time.time() + eta_time = str(round(((end - start) * (len(dataloader) - j)) / 60, 1)) + "min" + iter_time = 10 // (end - start) if iter_time <= 1: iter_time = 1 @@ -965,7 +969,7 @@ def predict_DIST(self, id_i: int, id_name: str): if id_dist % iter_time == 0 and self.tardis_logo: self.tardis_progress( title=self.title, - text_1=f"Found {len(self.predict_list)} images to predict!", + text_1=f"Found {len(self.predict_list)} images to predict! [{self.eta_predict} ETA]", text_2=f"Device: {self.device}", text_3=f"Image {id_i + 1}/{len(self.predict_list)}: {id_name}", text_4=f"Org. Pixel size: {self.px} A | Norm. Pixel size: {self.normalize_px}", @@ -981,7 +985,7 @@ def predict_DIST(self, id_i: int, id_name: str): if self.tardis_logo: self.tardis_progress( title=self.title, - text_1=f"Found {len(self.predict_list)} images to predict!", + text_1=f"Found {len(self.predict_list)} images to predict! [{self.eta_predict} ETA]", text_2=f"Device: {self.device}", text_3=f"Image {id_i + 1}/{len(self.predict_list)}: {id_name}", text_4=f"Org. Pixel size: {self.px} A | Norm. Pixel size: {self.normalize_px}", @@ -1064,7 +1068,7 @@ def postprocess_DIST(self, id_i, i): self.tardis_progress( title=self.title, - text_1=f"Found {len(self.predict_list)} images to predict!", + text_1=f"Found {len(self.predict_list)} images to predict! [{self.eta_predict} ETA]", text_2=f"Device: {self.device}", text_3=f"Image {id_i + 1}/{len(self.predict_list)}: {i}", text_4=f"Org. Pixel size: {self.px} A | Norm. Pixel size: {self.normalize_px}", @@ -1155,7 +1159,7 @@ def get_file_list(self): else: self.tardis_progress( title=self.title, - text_1=f"Found {len(self.predict_list)} images to predict!", + text_1=f"Found {len(self.predict_list)} images to predict! [{self.eta_predict} ETA]", text_2=f"Device: {self.device}", text_7="Current Task: Setting-up environment...", ) @@ -1190,7 +1194,7 @@ def log_tardis(self, id_i: int, i: Union[str, np.ndarray], log_id: float): # Common text for all configurations common_text = { - "text_1": f"Found {len(self.predict_list)} images to predict!", + "text_1": f"Found {len(self.predict_list)} images to predict! [{self.eta_predict} ETA]", "text_2": f"Device: {self.device}", "text_3": f"Image {id_i + 1}/{len(self.predict_list)}: {i}", } @@ -1645,6 +1649,8 @@ def __call__(self, save_progres=False): semantic_output, instance_output, instance_filter_output = [], [], [] for id_, i in enumerate(self.predict_list): + start_predict = time.time() + """CNN Pre-Processing""" if isinstance(i, str): # Find a file format @@ -1817,6 +1823,8 @@ def __call__(self, save_progres=False): """Clean-up temp dir""" clean_up(dir_s=self.dir) + end_predict = time.time() + self.eta_predict = str(round((end_predict - start_predict) / 60, 2)) + " min" """Optional return""" if self.output_format.startswith("return"):