Skip to content

Commit

Permalink
hacked up
Browse files Browse the repository at this point in the history
stack-info: PR: #88, branch: drisspg/stack/1
  • Loading branch information
drisspg committed Dec 20, 2024
1 parent 4ff3caf commit 1bfe7b7
Show file tree
Hide file tree
Showing 2 changed files with 64 additions and 2 deletions.
6 changes: 4 additions & 2 deletions torchfix/visitors/deprecated_symbols/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from .chain_matmul import call_replacement_chain_matmul
from .cholesky import call_replacement_cholesky
from .qr import call_replacement_qr
from .size_average import call_replacement_loss

from .range import call_replacement_range

Expand Down Expand Up @@ -54,6 +55,7 @@ def _call_replacement(
"torch.qr": call_replacement_qr,
"torch.cuda.amp.autocast": call_replacement_cuda_amp_autocast,
"torch.cpu.amp.autocast": call_replacement_cpu_amp_autocast,
"torch.nn.functional.soft_margin_loss": call_replacement_loss
}
replacement = None

Expand Down Expand Up @@ -103,7 +105,8 @@ def visit_Call(self, node) -> None:
qualified_name = self.get_qualified_name_for_call(node)
if qualified_name is None:
return

self.deprecated_config["torch.nn.functional.soft_margin_loss"] = {}
self.deprecated_config["torch.nn.functional.soft_margin_loss"]["remove_pr"] = None
if qualified_name in self.deprecated_config:
if self.deprecated_config[qualified_name]["remove_pr"] is None:
error_code = self.ERRORS[1].error_code
Expand All @@ -112,7 +115,6 @@ def visit_Call(self, node) -> None:
error_code = self.ERRORS[0].error_code
message = self.ERRORS[0].message(old_name=qualified_name)
replacement = self._call_replacement(node, qualified_name)

reference = self.deprecated_config[qualified_name].get("reference")
if reference is not None:
message = f"{message}: {reference}"
Expand Down
60 changes: 60 additions & 0 deletions torchfix/visitors/deprecated_symbols/size_average.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
"""size_average and reduce are deprecated, please use reduction='mean' instead."""

import libcst as cst
from ...common import TorchVisitor, get_module_name
from torch.nn._reduction import legacy_get_string

def call_replacement_loss(node: cst.Call) -> cst.CSTNode:
"""
Replace loss function that contains size_average / reduce with a new loss function
that uses reduction='mean' instead. Uses the logic from torch.nn._reduction to
determine the correct reduction value.
Args:
node: The CST Call node representing the loss function call
Returns:
A new CST node with updated reduction parameter
"""
# Extract existing arguments
input_arg = TorchVisitor.get_specific_arg(node, "input", 0)
target_arg = TorchVisitor.get_specific_arg(node, "target", 1)

size_average_arg = TorchVisitor.get_specific_arg(node, "size_average", 2)
reduce_arg = TorchVisitor.get_specific_arg(node, "reduce", 3)

# Ensure input and target args maintain their commas
input_arg = cst.ensure_type(input_arg, cst.Arg).with_changes(
comma=cst.MaybeSentinel.DEFAULT
)

target_arg = cst.ensure_type(target_arg, cst.Arg).with_changes(
comma=cst.MaybeSentinel.DEFAULT
)

# Extract size_average and reduce values
size_average_value = None
reduce_value = None

if size_average_arg:
size_average_value = getattr(size_average_arg.value, "value", True)
if reduce_arg:
reduce_value = getattr(reduce_arg.value, "value", True)

if size_average_value is None and reduce_value is None:
# We want to return the original call as is
return node
# Use legacy_get_string to determine the correct reduction value
reduction = legacy_get_string(size_average_value, reduce_value, emit_warning=False)

# Create new reduction argument
reduction_arg = cst.Arg(
value=cst.SimpleString(f"'{reduction}'"),
keyword=cst.Name("reduction"),
comma=cst.MaybeSentinel.DEFAULT,
)

# Build new arguments list
new_args = [input_arg, target_arg, reduction_arg]
replacement = node.with_changes(args=new_args)
return replacement

0 comments on commit 1bfe7b7

Please sign in to comment.