From d471519046032f45d1670b1f972e39a73630f8a2 Mon Sep 17 00:00:00 2001 From: Robert Kiewisz <56911280+RRobert92@users.noreply.github.com> Date: Sat, 28 Dec 2024 18:14:13 +0100 Subject: [PATCH] Update predictor.py --- tardis_em/utils/predictor.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tardis_em/utils/predictor.py b/tardis_em/utils/predictor.py index d1a3361..dcece6e 100644 --- a/tardis_em/utils/predictor.py +++ b/tardis_em/utils/predictor.py @@ -805,6 +805,7 @@ def predict_cnn(self, id_i: int, id_name: str, dataloader): # Pick image['s] input_, name = dataloader.__getitem__(j) + start, end = 0, 0 if j == 0: start = time.time() @@ -814,7 +815,6 @@ def predict_cnn(self, id_i: int, id_name: str, dataloader): # Scale progress bar refresh to 10s end = time.time() - eta_time = str(round(((end - start) * (len(dataloader) - j)) / 60, 1)) + "min" iter_time = 10 // (end - start) if iter_time <= 1: @@ -823,6 +823,7 @@ def predict_cnn(self, id_i: int, id_name: str, dataloader): # Predict input_ = self.cnn.predict(input_[None, :], rotate=self.rotate) + eta_time = str(round(((end - start) * (len(dataloader) - j)) / 60, 1)) + "min" tif.imwrite(join(self.output, f"{name}.tif"), input_) def predict_cnn_napari(self, input_t: torch.Tensor, name: str):