Revert "[dtensor][MTPG] make sharding prop lru cache not shared among…
pytorchmergebot committed Aug 25, 2024
1 parent 268092d commit e5563f7
Showing 2 changed files with 3 additions and 18 deletions.
1 change: 1 addition & 0 deletions test/distributed/_tensor/test_dtensor_ops.py
@@ -353,6 +353,7 @@ def wrapped(fn):
     xfail("nn.functional.pdist"),
     xfail("nn.functional.pixel_shuffle"),
     xfail("nn.functional.pixel_unshuffle"),
+    xfail("nn.functional.poisson_nll_loss"),
     xfail("nn.functional.prelu"),
     xfail("nn.functional.relu6"),
     xfail("nn.functional.rrelu"),

20 changes: 2 additions & 18 deletions torch/distributed/_tensor/_sharding_prop.py
@@ -1,5 +1,4 @@
 # mypy: allow-untyped-defs
-import threading
 from functools import lru_cache
 from itertools import chain
 from typing import Callable, cast, Dict, List, Optional, Sequence, Tuple, Union
@@ -38,17 +37,6 @@ def _length(obj) -> int:
     return len(obj)
 
 
-class LocalLRUCache(threading.local):
-    def __init__(self, user_function: Callable) -> None:
-        self.cache = lru_cache(None)(user_function)
-
-    def __call__(self, *args, **kwargs) -> object:
-        return self.cache(*args, **kwargs)
-
-    def cache_info(self):
-        return self.cache.cache_info()
-
-
 class ShardingPropagator:
     def __init__(self) -> None:
         self.op_to_rules: Dict[OpOverload, Callable[[OpSchema], OutputSharding]] = {}
@@ -58,9 +46,7 @@ def __init__(self) -> None:
         ] = {}
         # op map to save static argnum to decide to reuse sharding prop cache or re-run sharding prop
         self.op_to_schema_info: Dict[OpOverload, RuntimeSchemaInfo] = {}
-        self.propagate_op_sharding = LocalLRUCache(
-            self.propagate_op_sharding_non_cached
-        )
+        self.propagate_op_sharding = lru_cache(None)(self.propagate_op_sharding_non_cached)  # type: ignore[method-assign]
         # op map to save indices of shape (and stride) args which may need to be modified in sharding prop
         self.op_to_shape_and_stride_idx: Dict[
             OpOverload, Union[int, Tuple[int, int]]
@@ -197,9 +183,7 @@ def propagate(self, op_info: OpInfo) -> None:
         if op_info.schema.has_symints:
             output_sharding = self.propagate_op_sharding_non_cached(op_info.schema)
         else:
-            output_sharding = cast(
-                OutputSharding, self.propagate_op_sharding(op_info.schema)
-            )
+            output_sharding = self.propagate_op_sharding(op_info.schema)
         op_info.output_sharding = output_sharding
 
     def propagate_op_sharding_non_cached(self, op_schema: OpSchema) -> OutputSharding:
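
For context, the LocalLRUCache removed by this revert wraps functools.lru_cache in a threading.local subclass, so each thread that touches the propagator builds its own cache, whereas the restored plain lru_cache is shared by every thread in the process. The following is a minimal standalone sketch of that pattern, not code from the PyTorch sources; the expensive and worker names are illustrative only.

import threading
from functools import lru_cache
from typing import Callable


class LocalLRUCache(threading.local):
    """Per-thread LRU cache: threading.local re-runs __init__ in each new
    thread that first touches the object, so every thread gets a fresh cache."""

    def __init__(self, user_function: Callable) -> None:
        self.cache = lru_cache(None)(user_function)

    def __call__(self, *args, **kwargs) -> object:
        return self.cache(*args, **kwargs)

    def cache_info(self):
        return self.cache.cache_info()


def expensive(x: int) -> int:
    # stand-in for a costly computation such as sharding propagation
    return x * x


shared = lru_cache(None)(expensive)   # one cache shared across all threads
local = LocalLRUCache(expensive)      # an independent cache per thread


def worker() -> None:
    shared(2)  # hit: the shared cache already saw this call in the main thread
    local(2)   # miss: this thread's local cache starts empty
    print(shared.cache_info(), local.cache_info())


shared(2)  # main thread primes the shared cache
t = threading.Thread(target=worker)
t.start()
t.join()
print(local.cache_info())  # main thread's local cache has recorded no calls

The cast around the cached call in propagate was only needed because LocalLRUCache.__call__ is annotated to return object; the plain lru_cache wrapper preserves the wrapped function's return type, so the revert drops the cast along with the class.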
