Skip to content

Commit

Permalink
fixed train time bug (#151)
Browse files Browse the repository at this point in the history
  • Loading branch information
simplymathematics authored Nov 30, 2023
1 parent ecf3bf5 commit aef248e
Showing 1 changed file with 13 additions and 13 deletions.
26 changes: 13 additions & 13 deletions deckard/base/model/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ def __call__(self, data: list, model: object, library=None):
start = process_time_ns()
start_timestamp = time()
model.fit(data[0], data[2], **trainer)
end = process_time_ns() - start
end = process_time_ns()
end_timestamp = time()
except np.AxisError: # pragma: no cover
from art.utils import to_categorical
Expand All @@ -139,7 +139,7 @@ def __call__(self, data: list, model: object, library=None):
start = process_time_ns()
start_timestamp = time()
model.fit(data[0], data[2], **trainer)
end = process_time_ns() - start
end = process_time_ns()
end_timestamp = time()
except ValueError as e: # pragma: no cover
if "Shape of labels" in str(e):
Expand All @@ -150,7 +150,7 @@ def __call__(self, data: list, model: object, library=None):
start = process_time_ns()
start_timestamp = time()
model.fit(data[0], data[2], **trainer)
end = process_time_ns() - start
end = process_time_ns()
end_timestamp = time()
else:
raise e
Expand All @@ -162,7 +162,7 @@ def __call__(self, data: list, model: object, library=None):
start = process_time_ns()
start_timestamp = time()
model.fit(data[0], data[2], **trainer)
end = process_time_ns() - start
end = process_time_ns()
end_timestamp = time()
except Exception as e:
raise e
Expand All @@ -174,7 +174,7 @@ def __call__(self, data: list, model: object, library=None):
start = process_time_ns()
start_timestamp = time()
model.fit(data[0], data[2], **trainer)
end = process_time_ns() - start
end = process_time_ns()
end_timestamp = time()
elif "should be the same" in str(e).lower():
import torch
Expand All @@ -194,7 +194,7 @@ def __call__(self, data: list, model: object, library=None):
start = process_time_ns()
start_timestamp = time()
model.fit(data[0], data[2], **trainer)
end = process_time_ns() - start
end = process_time_ns()
end_timestamp = time()
else:
raise e
Expand Down Expand Up @@ -561,7 +561,7 @@ def predict(self, data=None, model=None, predictions_file=None):
start = process_time_ns()
start_timestamp = time()
predictions = model.predict(data[1])
end = process_time_ns() - start
end = process_time_ns()
end_timestamp = time()
except NotFittedError as e: # pragma: no cover
logger.warning(e)
Expand All @@ -579,7 +579,7 @@ def predict(self, data=None, model=None, predictions_file=None):
except Exception as e: # pragma: no cover
logger.error(e)
raise e
end = process_time_ns() - start
end = process_time_ns()
end_timestamp = time()
if predictions_file is not None:
self.data.save(predictions, predictions_file)
Expand Down Expand Up @@ -627,13 +627,13 @@ def predict_proba(self, data=None, model=None, probabilities_file=None):
start = process_time_ns()
start_timestamp = time()
predictions = model.predict_proba(data[1])
end = process_time_ns() - start
end = process_time_ns()
end_timestamp = time()
else:
start = process_time_ns()
start_timestamp = time()
predictions = model.predict(data[1])
end = process_time_ns() - start
end = process_time_ns()
end_timestamp = time()
if probabilities_file is not None:
self.data.save(predictions, probabilities_file)
Expand Down Expand Up @@ -680,19 +680,19 @@ def predict_log_loss(self, data, model, losses_file=None):
start = process_time_ns()
start_timestamp = time()
predictions = model.predict_log_proba(data[1])
end = process_time_ns() - start
end = process_time_ns()
end_timestamp = time()
elif hasattr(model, "predict_proba"):
start = process_time_ns()
start_timestamp = time()
predictions = model.predict_proba(data[1])
end = process_time_ns() - start
end = process_time_ns()
end_timestamp = time()
elif hasattr(model, "predict"):
start = process_time_ns()
start_timestamp = time()
predictions = model.predict(data[1])
end = process_time_ns() - start
end = process_time_ns()
end_timestamp = time()
else: # pragma: no cover
raise ValueError(
Expand Down

0 comments on commit aef248e

Please sign in to comment.