Skip to content

Commit

Permalink
added freeze_encoder param
Browse files Browse the repository at this point in the history
  • Loading branch information
valhassan committed Dec 9, 2024
1 parent eafca1e commit d572465
Showing 1 changed file with 2 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ def __init__(self,
std: List[float],
data_type_max: float,
loss: Callable,
freeze_encoder: bool = False,
weights: str = None,
class_labels: List[str] = None,
class_colors: List[str] = None,
Expand All @@ -38,7 +39,7 @@ def __init__(self,
self.data_type_max = data_type_max
self.class_colors = class_colors
self.num_classes = num_classes
self.model = SegFormer(encoder, in_channels, weights, self.num_classes)
self.model = SegFormer(encoder, in_channels, weights, freeze_encoder, self.num_classes)
if weights_from_checkpoint_path:
print(f"Loading weights from checkpoint: {weights_from_checkpoint_path}")
checkpoint = torch.load(weights_from_checkpoint_path)
Expand Down

0 comments on commit d572465

Please sign in to comment.