From 3f99c060dd84dbd52b11fe191be9116e753a9c49 Mon Sep 17 00:00:00 2001 From: Charlie Meyers Date: Thu, 30 Nov 2023 21:00:12 +0000 Subject: [PATCH] fixed train time bug --- deckard/base/model/model.py | 26 +++++++++++++------------- 1 file changed, 13 insertions(+), 13 deletions(-) diff --git a/deckard/base/model/model.py b/deckard/base/model/model.py index 5e96557a..c4f4d643 100644 --- a/deckard/base/model/model.py +++ b/deckard/base/model/model.py @@ -130,7 +130,7 @@ def __call__(self, data: list, model: object, library=None): start = process_time_ns() start_timestamp = time() model.fit(data[0], data[2], **trainer) - end = process_time_ns() - start + end = process_time_ns() end_timestamp = time() except np.AxisError: # pragma: no cover from art.utils import to_categorical @@ -139,7 +139,7 @@ def __call__(self, data: list, model: object, library=None): start = process_time_ns() start_timestamp = time() model.fit(data[0], data[2], **trainer) - end = process_time_ns() - start + end = process_time_ns() end_timestamp = time() except ValueError as e: # pragma: no cover if "Shape of labels" in str(e): @@ -150,7 +150,7 @@ def __call__(self, data: list, model: object, library=None): start = process_time_ns() start_timestamp = time() model.fit(data[0], data[2], **trainer) - end = process_time_ns() - start + end = process_time_ns() end_timestamp = time() else: raise e @@ -162,7 +162,7 @@ def __call__(self, data: list, model: object, library=None): start = process_time_ns() start_timestamp = time() model.fit(data[0], data[2], **trainer) - end = process_time_ns() - start + end = process_time_ns() end_timestamp = time() except Exception as e: raise e @@ -174,7 +174,7 @@ def __call__(self, data: list, model: object, library=None): start = process_time_ns() start_timestamp = time() model.fit(data[0], data[2], **trainer) - end = process_time_ns() - start + end = process_time_ns() end_timestamp = time() elif "should be the same" in str(e).lower(): import torch @@ -194,7 +194,7 @@ def __call__(self, data: list, model: object, library=None): start = process_time_ns() start_timestamp = time() model.fit(data[0], data[2], **trainer) - end = process_time_ns() - start + end = process_time_ns() end_timestamp = time() else: raise e @@ -561,7 +561,7 @@ def predict(self, data=None, model=None, predictions_file=None): start = process_time_ns() start_timestamp = time() predictions = model.predict(data[1]) - end = process_time_ns() - start + end = process_time_ns() end_timestamp = time() except NotFittedError as e: # pragma: no cover logger.warning(e) @@ -579,7 +579,7 @@ def predict(self, data=None, model=None, predictions_file=None): except Exception as e: # pragma: no cover logger.error(e) raise e - end = process_time_ns() - start + end = process_time_ns() end_timestamp = time() if predictions_file is not None: self.data.save(predictions, predictions_file) @@ -627,13 +627,13 @@ def predict_proba(self, data=None, model=None, probabilities_file=None): start = process_time_ns() start_timestamp = time() predictions = model.predict_proba(data[1]) - end = process_time_ns() - start + end = process_time_ns() end_timestamp = time() else: start = process_time_ns() start_timestamp = time() predictions = model.predict(data[1]) - end = process_time_ns() - start + end = process_time_ns() end_timestamp = time() if probabilities_file is not None: self.data.save(predictions, probabilities_file) @@ -680,19 +680,19 @@ def predict_log_loss(self, data, model, losses_file=None): start = process_time_ns() start_timestamp = time() predictions = model.predict_log_proba(data[1]) - end = process_time_ns() - start + end = process_time_ns() end_timestamp = time() elif hasattr(model, "predict_proba"): start = process_time_ns() start_timestamp = time() predictions = model.predict_proba(data[1]) - end = process_time_ns() - start + end = process_time_ns() end_timestamp = time() elif hasattr(model, "predict"): start = process_time_ns() start_timestamp = time() predictions = model.predict(data[1]) - end = process_time_ns() - start + end = process_time_ns() end_timestamp = time() else: # pragma: no cover raise ValueError(