Skip to content

Commit

Permalink
[pre-commit.ci] Add auto fixes from pre-commit.com hooks
Browse files Browse the repository at this point in the history
for more information, see https://pre-commit.ci
  • Loading branch information
pre-commit-ci[bot] committed Dec 2, 2024
1 parent 90d645c commit d862a1f
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 17 deletions.
13 changes: 8 additions & 5 deletions cyclops/data/features/medical_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,11 +209,14 @@ def decode_example(
use_auth_token = token_per_repo_id.get(repo_id)
except ValueError:
use_auth_token = None
with xopen(
path,
"rb",
use_auth_token=use_auth_token,
) as file_obj, BytesIO(file_obj.read()) as buffer:
with (
xopen(
path,
"rb",
use_auth_token=use_auth_token,
) as file_obj,
BytesIO(file_obj.read()) as buffer,
):
image, metadata = self._read_file_from_bytes(buffer)
metadata["filename_or_obj"] = path

Expand Down
8 changes: 5 additions & 3 deletions cyclops/evaluate/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,9 +159,11 @@ def get_columns_as_array(
if isinstance(columns, str):
columns = [columns]

with dataset.formatted_as("arrow", columns=columns, output_all_columns=True) if (
isinstance(dataset, Dataset) and dataset.format != "arrow"
) else nullcontext():
with (
dataset.formatted_as("arrow", columns=columns, output_all_columns=True)
if (isinstance(dataset, Dataset) and dataset.format != "arrow")
else nullcontext()
):
out_arr = squeeze_all(
xp.stack(
[xp.asarray(dataset[col].to_pylist()) for col in columns], axis=-1
Expand Down
21 changes: 12 additions & 9 deletions cyclops/models/wrappers/pt_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -968,14 +968,17 @@ def fit(
splits_mapping["validation"] = val_split

format_kwargs = {} if transforms is None else {"transform": transforms}
with X[train_split].formatted_as(
"custom" if transforms is not None else "torch",
columns=feature_columns + target_columns,
**format_kwargs,
), X[val_split].formatted_as(
"custom" if transforms is not None else "torch",
columns=feature_columns + target_columns,
**format_kwargs,
with (
X[train_split].formatted_as(
"custom" if transforms is not None else "torch",
columns=feature_columns + target_columns,
**format_kwargs,
),
X[val_split].formatted_as(
"custom" if transforms is not None else "torch",
columns=feature_columns + target_columns,
**format_kwargs,
),
):
self.partial_fit(
X,
Expand Down Expand Up @@ -1309,7 +1312,7 @@ def save_model(self, filepath: str, overwrite: bool = True, **kwargs):
if include_lr_scheduler:
state_dict["lr_scheduler"] = self.lr_scheduler_.state_dict() # type: ignore[attr-defined]

epoch = kwargs.get("epoch", None)
epoch = kwargs.get("epoch")
if epoch is not None:
filename, extension = os.path.basename(filepath).split(".")
filepath = join(
Expand Down

0 comments on commit d862a1f

Please sign in to comment.