Skip to content

Commit

Permalink
roll back wandb_reporter and a few typing fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
yc7z committed Nov 13, 2024
1 parent 022e76c commit d8e83ff
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 11 deletions.
6 changes: 3 additions & 3 deletions fl4health/model_bases/masked_layers/masked_conv.py
Original file line number Diff line number Diff line change
Expand Up @@ -476,7 +476,7 @@ def from_pretrained(cls, conv_module: nn.ConvTranspose1d) -> MaskedConvTranspose
# type Union[int, Tuple[int]] and kernel_size_ has type Tuple[int]
kernel_size_ = _single(conv_module.kernel_size)
stride_ = _single(conv_module.stride)
padding_ = conv_module.padding if isinstance(conv_module.padding, str) else _single(conv_module.padding)
padding_ = _single(conv_module.padding)
dilation_ = _single(conv_module.dilation)
output_padding_ = _single(conv_module.output_padding)
masked_conv_module = cls(
Expand Down Expand Up @@ -616,7 +616,7 @@ def from_pretrained(cls, conv_module: nn.ConvTranspose2d) -> MaskedConvTranspose
# type Union[int, Tuple[int]] and kernel_size_ has type Tuple[int]
kernel_size_ = _pair(conv_module.kernel_size)
stride_ = _pair(conv_module.stride)
padding_ = conv_module.padding if isinstance(conv_module.padding, str) else _pair(conv_module.padding)
padding_ = _pair(conv_module.padding)
dilation_ = _pair(conv_module.dilation)
output_padding_ = _pair(conv_module.output_padding)
masked_conv_module = cls(
Expand Down Expand Up @@ -756,7 +756,7 @@ def from_pretrained(cls, conv_module: nn.ConvTranspose3d) -> MaskedConvTranspose
# type Union[int, Tuple[int]] and kernel_size_ has type Tuple[int]
kernel_size_ = _triple(conv_module.kernel_size)
stride_ = _triple(conv_module.stride)
padding_ = conv_module.padding if isinstance(conv_module.padding, str) else _triple(conv_module.padding)
padding_ = _triple(conv_module.padding)
dilation_ = _triple(conv_module.dilation)
output_padding_ = _triple(conv_module.output_padding)
masked_conv_module = cls(
Expand Down
9 changes: 1 addition & 8 deletions fl4health/reporting/wandb_reporter.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ def start_run(self, wandb_init_kwargs: dict[str, Any]) -> None:
if not self.initialized:
self.initialize()

wandb_run = wandb.init(
self.run = wandb.init(
project=self.project,
entity=self.entity,
config=self.config,
Expand All @@ -110,13 +110,6 @@ def start_run(self, wandb_init_kwargs: dict[str, Any]) -> None:
id=self.id,
**wandb_init_kwargs # Other less commonly used kwargs
)

if isinstance(wandb_run, wandb.wandb_run.Run):
self.run = wandb_run
else:
# Handle the case when wandb_run is None or RunDisabled
raise TypeError("wandb.init did not return the expected type `wandb.wandb_run.Run`")

self.run_id = self.run._run_id # If run_id was None, we need to reset run id
self.run_started = True

Expand Down

0 comments on commit d8e83ff

Please sign in to comment.