Skip to content

Commit

Permalink
Update predictor.py
Browse files Browse the repository at this point in the history
  • Loading branch information
RRobert92 committed Dec 28, 2024
1 parent cb518e5 commit d471519
Showing 1 changed file with 2 additions and 1 deletion.
3 changes: 2 additions & 1 deletion tardis_em/utils/predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -805,6 +805,7 @@ def predict_cnn(self, id_i: int, id_name: str, dataloader):

# Pick image['s]
input_, name = dataloader.__getitem__(j)
start, end = 0, 0

if j == 0:
start = time.time()
Expand All @@ -814,7 +815,6 @@ def predict_cnn(self, id_i: int, id_name: str, dataloader):

# Scale progress bar refresh to 10s
end = time.time()
eta_time = str(round(((end - start) * (len(dataloader) - j)) / 60, 1)) + "min"

iter_time = 10 // (end - start)
if iter_time <= 1:
Expand All @@ -823,6 +823,7 @@ def predict_cnn(self, id_i: int, id_name: str, dataloader):
# Predict
input_ = self.cnn.predict(input_[None, :], rotate=self.rotate)

eta_time = str(round(((end - start) * (len(dataloader) - j)) / 60, 1)) + "min"
tif.imwrite(join(self.output, f"{name}.tif"), input_)

def predict_cnn_napari(self, input_t: torch.Tensor, name: str):
Expand Down

0 comments on commit d471519

Please sign in to comment.