Skip to content

Commit

Permalink
Revert 'Refactor torch training test to stop using metrics from check…
Browse files Browse the repository at this point in the history
…point'
  • Loading branch information
daniil-lyakhov committed Sep 18, 2023
1 parent 390b010 commit 9618b49
Showing 1 changed file with 6 additions and 2 deletions.
8 changes: 6 additions & 2 deletions tests/torch/test_compression_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,9 @@ def finalize(self, dataset_dir, tmp_path_factory, weekly_models_path) -> "Compre
return self

def get_metric(self):
return self.expected_accuracy_
return self.sample_handler.get_metric_value_from_checkpoint(
self.checkpoint_save_dir, self.checkpoint_name, self.config_path
)

def _get_weight_path(self, weekly_models_path):
if self.weights_filename_ is None:
Expand Down Expand Up @@ -247,7 +249,9 @@ def subnet_expected_accuracy(self, subnet_expected_accuracy: float):
return self

def get_subnet_metric(self):
return self.subnet_expected_accuracy_
return self.sample_handler.get_metric_value_from_checkpoint(
self.checkpoint_save_dir, self.subnet_checkpoint_name
)

def _get_weight_path(self, weekly_models_path):
return os.path.join(
Expand Down

0 comments on commit 9618b49

Please sign in to comment.