From ed30fa74ab5ba5a523c2c247776ce55f2174e36f Mon Sep 17 00:00:00 2001 From: Jason Ansel Date: Sun, 3 Nov 2024 15:22:30 -0800 Subject: [PATCH] [inductor] sympy.Integer([01]) -> sympy.S.(Zero|One) (#139523) Pull Request resolved: https://github.com/pytorch/pytorch/pull/139523 Approved by: https://github.com/ezyang ghstack dependencies: #139364, #139365, #139370, #139452 --- torch/_inductor/codegen/common.py | 2 +- torch/_inductor/codegen/cpp.py | 6 ++-- .../_inductor/codegen/cpp_template_kernel.py | 2 +- torch/_inductor/codegen/halide.py | 18 +++++------ torch/_inductor/codegen/simd.py | 16 +++++----- .../_inductor/codegen/simd_kernel_features.py | 2 +- torch/_inductor/codegen/triton.py | 22 ++++++-------- torch/_inductor/dependencies.py | 4 +-- torch/_inductor/index_propagation.py | 4 +-- torch/_inductor/ir.py | 30 +++++++++---------- torch/_inductor/lowering.py | 12 ++++---- torch/_inductor/ops_handler.py | 2 +- torch/_inductor/scheduler.py | 2 +- torch/_inductor/select_algorithm.py | 8 ++--- torch/_inductor/sizevars.py | 16 +++++----- torch/_inductor/utils.py | 2 +- torch/fx/experimental/sym_node.py | 8 ++--- torch/fx/experimental/symbolic_shapes.py | 4 +-- torch/utils/_sympy/functions.py | 2 +- 19 files changed, 77 insertions(+), 85 deletions(-) diff --git a/torch/_inductor/codegen/common.py b/torch/_inductor/codegen/common.py index 0fb35f8824f645..b2beaa2a0e4ee6 100644 --- a/torch/_inductor/codegen/common.py +++ b/torch/_inductor/codegen/common.py @@ -170,7 +170,7 @@ class TensorArg: name: str buffer: str dtype: torch.dtype - offset: sympy.Expr = sympy.Integer(0) # c++ only + offset: sympy.Expr = sympy.S.Zero # c++ only alias_of: Optional[str] = None # halide only diff --git a/torch/_inductor/codegen/cpp.py b/torch/_inductor/codegen/cpp.py index ceaa9c8cdb1cf2..41fcaa86b303c8 100644 --- a/torch/_inductor/codegen/cpp.py +++ b/torch/_inductor/codegen/cpp.py @@ -246,7 +246,7 @@ def stride_at(index: sympy.Expr, var: sympy.Symbol): # see test_torchinductor_dynamic_shapes.py::test_full_boolean_dynamic_shapes_cpu # which has tmp0 = ops.index_expr(s0 >= 1024, torch.bool) and fails below calculation. # in this case, there is no dependencies between index and var. - return sympy.Integer(0) + return sympy.S.Zero replacement = {var: var + 1} new_index = sympy_subs(index, replacement) # type: ignore[arg-type] return sympy.simplify(new_index - index) @@ -4711,8 +4711,8 @@ def __exit__(self, exc_type, exc_val, exc_tb): class LoopLevel: var: Optional[sympy.Expr] = None size: Optional[sympy.Expr] = None - offset: sympy.Expr = sympy.Integer(0) - steps: sympy.Expr = sympy.Integer(1) + offset: sympy.Expr = sympy.S.Zero + steps: sympy.Expr = sympy.S.One parallel: int = 0 simd_omp: bool = False simd_vec: bool = False diff --git a/torch/_inductor/codegen/cpp_template_kernel.py b/torch/_inductor/codegen/cpp_template_kernel.py index 453e4b37375e98..768aadc563e095 100644 --- a/torch/_inductor/codegen/cpp_template_kernel.py +++ b/torch/_inductor/codegen/cpp_template_kernel.py @@ -216,7 +216,7 @@ def store_pointwise_nodes( for i, sz in enumerate(var_sizes[0]) } if not offsets: - offsets = [sympy.Integer(0)] * len(var_sizes[0]) + offsets = [sympy.S.Zero] * len(var_sizes[0]) if not reindexers: reindexers = [None] * len(nodes) assert len(offsets) == len(var_sizes[0]) diff --git a/torch/_inductor/codegen/halide.py b/torch/_inductor/codegen/halide.py index da23c5e81f7b15..584c5a5393a631 100644 --- a/torch/_inductor/codegen/halide.py +++ b/torch/_inductor/codegen/halide.py @@ -788,7 +788,7 @@ def visit_floor_div(base, divisor): if not nodes: nodes.append(tree.lookup(1, tree.numel)) handled_count = 0 - divisor = sympy.Integer(1) + divisor = sympy.S.One added_sym_size = [] # decide on a minimal set of symbols and put them in self.halide_vars while handled_count < len(nodes) and not eq(tree.numel, divisor): @@ -846,7 +846,7 @@ def visit_floor_div(base, divisor): idx += 1 divisor *= size length = 1 - expr = sympy.Integer(0) + expr = sympy.S.Zero while not eq(node.length, length): sym, size = added_sym_size[idx] idx += 1 @@ -855,8 +855,8 @@ def visit_floor_div(base, divisor): self.index_replacements[node.symbol()] = expr except IndexError: assert had_fallback - full_index = sympy.Integer(0) - stride = sympy.Integer(1) + full_index = sympy.S.Zero + stride = sympy.S.One for sym, size in added_sym_size: full_index += stride * sym stride *= size @@ -937,8 +937,8 @@ def indexing_to_dimensions(self, var: str, index: sympy.Expr, is_store: bool): ), sym # group the expression by variables used - offset = sympy.Integer(0) - split_expr = {s: sympy.Integer(0) for s in symbols} + offset = sympy.S.Zero + split_expr = {s: sympy.S.Zero for s in symbols} split_failed: List[Tuple[List[sympy.Symbol], sympy.Expr]] = [] index = sympy.expand(self.rename_indexing(index)) for part in index.args if isinstance(index, sympy.Add) else [index]: @@ -972,7 +972,7 @@ def expr_to_dimension(expr, syms): length = sympy.simplify( sympy_subs(expr, {sym: self.sym_size(sym) - 1 for sym in syms}) + 1 ) - stride = sympy.Integer(1) + stride = sympy.S.One if isinstance(expr, sympy.Mul): for term in expr.args: if isinstance(term, sympy.Integer): @@ -994,11 +994,11 @@ def expr_to_dimension(expr, syms): if not dims: # scalar load/store if self.has_indirect_indexing: # workaround https://github.com/halide/Halide/issues/8338 - dims.append(DimensionInfo(sympy.Integer(0), 1, 1)) + dims.append(DimensionInfo(sympy.S.Zero, 1, 1)) elif not V.graph.sizevars.statically_known_equals(dims[0].stride, 1): # Halide assumes dimension 0 is stride == 1, so add a dummy dimension dims.insert( - 0, DimensionInfo(sympy.Integer(0), 1 if is_store else dims[0].stride, 1) + 0, DimensionInfo(sympy.S.Zero, 1 if is_store else dims[0].stride, 1) ) if dims and not is_store: diff --git a/torch/_inductor/codegen/simd.py b/torch/_inductor/codegen/simd.py index c255908e5dd77f..5239954f6fbf13 100644 --- a/torch/_inductor/codegen/simd.py +++ b/torch/_inductor/codegen/simd.py @@ -101,8 +101,8 @@ def __init__( prefix: str, *, kernel: SIMDKernel, - divisor=sympy.Integer(1), - length=sympy.Integer(1), + divisor=sympy.S.One, + length=sympy.S.One, root: IterationRangesRoot, ) -> None: super().__init__() @@ -205,7 +205,7 @@ def lookup(self, divisor, length): return self.nodes[expr] def construct_entries(self, lengths: List[sympy.Expr]): - divisor = sympy.Integer(1) + divisor = sympy.S.One itervars = [] for length in reversed(lengths): itervars.append(self.lookup(divisor, length)) @@ -224,7 +224,7 @@ def vars_and_sizes(self, index: sympy.Expr): x.divisor, fallback=config.unbacked_symint_fallback ) ) - divisor = sympy.Integer(1) + divisor = sympy.S.One index_vars = [] sizes = [] @@ -481,7 +481,7 @@ def combine_modular_indexing_pairs(self, index): new_index, { tree_node.root.index_sym(): tree_node.root.lookup( - sympy.Integer(1), tree_node.root.numel + sympy.S.One, tree_node.root.numel ).symbol() }, ) @@ -572,7 +572,7 @@ def getter(flat_vars): return_getters = [] for size in length_group: if sv.statically_known_equals(size, 1): # type: ignore[arg-type] - return_getters.append(lambda _: sympy.Integer(0)) + return_getters.append(lambda _: sympy.S.Zero) continue while current_group < len(remaining) and sv.statically_known_equals( @@ -635,7 +635,7 @@ def split_and_set_ranges(self, lengths: List[List[sympy.Expr]]): """ groups = [rt.numel for rt in self.range_trees] if not self.inside_reduction: - groups[-1] = sympy.Integer(1) + groups[-1] = sympy.S.One if len(lengths) == len(self.range_trees) and all( V.graph.sizevars.simplify(sympy_product(x) - g) == 0 @@ -1564,7 +1564,7 @@ def candidate_tilings(node): return tilings @classmethod - def select_tiling(cls, node_schedule, numel, reduction_numel=sympy.Integer(1)): + def select_tiling(cls, node_schedule, numel, reduction_numel=sympy.S.One): """ Heuristics to decide how to tile kernels. Currently, we tile based on stride-1 dimensions. diff --git a/torch/_inductor/codegen/simd_kernel_features.py b/torch/_inductor/codegen/simd_kernel_features.py index e6c7d2fa9290fc..4b278cfb70851a 100644 --- a/torch/_inductor/codegen/simd_kernel_features.py +++ b/torch/_inductor/codegen/simd_kernel_features.py @@ -70,7 +70,7 @@ def __init__( self, node_schedule: List[NodeScheduleEntry], numel: sympy.Expr, - reduction_numel: sympy.Expr = sympy.Integer(1), + reduction_numel: sympy.Expr = sympy.S.One, ): self.node_schedule = node_schedule self.numel = V.graph.sizevars.simplify(numel) # numel excludes reduction_numel diff --git a/torch/_inductor/codegen/triton.py b/torch/_inductor/codegen/triton.py index 117aa491e60b05..92684733bc85ed 100644 --- a/torch/_inductor/codegen/triton.py +++ b/torch/_inductor/codegen/triton.py @@ -241,7 +241,7 @@ def codegen_broadcast_and_reshape( # Reshape to add singletons. pre_broadcast_shape = [ - sympy.Integer(1) if is_broadcasting else dim + sympy.S.One if is_broadcasting else dim for dim, is_broadcasting in zip( self.broadcast_shape, self.broadcasting_dims ) @@ -342,7 +342,7 @@ def remove_dims(it): and V.kernel.numels[-1] != 1 ): # Need to expand rank by 1 to match rank when self.inside_reduction=True - final_shape.append(sympy.Integer(1)) + final_shape.append(sympy.S.One) return BlockPtrOptions( params=params, @@ -375,9 +375,7 @@ def format(self, name: str, roffset=True) -> str: f = V.kernel.index_to_str offsets = [*self.offsets] if not roffset: - offsets = [ - self.replace_roffset(offset, sympy.Integer(0)) for offset in offsets - ] + offsets = [self.replace_roffset(offset, sympy.S.Zero) for offset in offsets] args = [ ( f"{name} + ({f(self.constant_offset)})" @@ -408,9 +406,7 @@ def boundary_check(self) -> List[int]: idx for idx in range(len(self.shape)) if ( - not sizevars.statically_known_equals( - self.strides[idx], sympy.Integer(0) - ) + not sizevars.statically_known_equals(self.strides[idx], sympy.S.Zero) and not sizevars.statically_known_multiple_of( self.shape[idx], self.block_shape[idx] ) @@ -437,7 +433,7 @@ def advance_roffset(self): advance = [ ( self.replace_roffset(offset, rblock) - - self.replace_roffset(offset, sympy.Integer(0)) + - self.replace_roffset(offset, sympy.S.Zero) ) for offset in self.offsets ] @@ -1655,7 +1651,7 @@ def get_slice_numels(dims: List[Any]) -> List[Any]: Compute the cumulative size of each dimension's slice. This proceeds from the last dim up to the second. """ - numels = [sympy.Integer(1)] + numels = [sympy.S.One] for dim in dims[:0:-1]: numel = dim * numels[0] numels.insert(0, numel) @@ -1680,10 +1676,10 @@ def get_slice_numels(dims: List[Any]) -> List[Any]: # Provide default values for unmatched dims and strides. for dim in dims[1:]: if dim not in match: - match[dim] = sympy.Integer(1) + match[dim] = sympy.S.One for stride in strides[1:]: if stride not in match: - match[stride] = sympy.Integer(0) + match[stride] = sympy.S.Zero sizevars = V.graph.sizevars @@ -1786,7 +1782,7 @@ def match_block_pointer() -> Optional[BlockPtrOptions]: # For example xindex * 5 + rindex * 3 is partitioned to # (xindex * 5, rindex * 3). symbol = tree.symbol() - subexpr = sympy.Integer(0) + sum( + subexpr = sympy.S.Zero + sum( expr for expr in index_terms if symbol in expr.free_symbols ) diff --git a/torch/_inductor/dependencies.py b/torch/_inductor/dependencies.py index 94591381623750..75f2fdb62ebf4c 100644 --- a/torch/_inductor/dependencies.py +++ b/torch/_inductor/dependencies.py @@ -204,7 +204,7 @@ def get_numel(self) -> sympy.Expr: numel = V.graph.get_numel(self.name) else: vars: OrderedSet[sympy.Basic] = OrderedSet(self.index.free_symbols) - numel = sympy.Integer(1) + numel = sympy.S.One for var, size in zip(self.var_names, self.size): if var in vars: numel = numel * size @@ -328,7 +328,7 @@ def index(self): raise NotImplementedError("WeakDep does not have an index") def get_numel(self) -> sympy.Expr: - return sympy.Integer(1) + return sympy.S.One def rename(self, renames: Dict[str, str]) -> "WeakDep": if self.name in renames: diff --git a/torch/_inductor/index_propagation.py b/torch/_inductor/index_propagation.py index f4384d51b7d4ad..46793c1dd87a31 100644 --- a/torch/_inductor/index_propagation.py +++ b/torch/_inductor/index_propagation.py @@ -135,7 +135,7 @@ def mod(x: TypedExpr, y: TypedExpr) -> Optional[TypedExpr]: if not is_integer_dtype(result_type): return NotImplemented - result_expr = ModularIndexing(x.expr, sympy.Integer(1), y.expr) + result_expr = ModularIndexing(x.expr, sympy.S.One, y.expr) return TypedExpr(result_expr, result_type) @staticmethod @@ -152,7 +152,7 @@ def remainder(x: TypedExpr, y: TypedExpr) -> Optional[TypedExpr]: x_expr.is_nonnegative is not None and x_expr.is_nonnegative == y_expr.is_positive ): - result_expr = ModularIndexing(x.expr, sympy.Integer(1), y.expr) + result_expr = ModularIndexing(x.expr, sympy.S.One, y.expr) return TypedExpr(result_expr, result_type) return NotImplemented diff --git a/torch/_inductor/ir.py b/torch/_inductor/ir.py index 8b547560e24767..83bad00258a3f8 100644 --- a/torch/_inductor/ir.py +++ b/torch/_inductor/ir.py @@ -587,7 +587,7 @@ def create(cls, *args, **kwargs): @staticmethod def _index(ranges, prefix=SymT.INDEX): return [ - sympy.Integer(0) if s == 1 else sympy_index_symbol_with_prefix(prefix, n) + sympy.S.Zero if s == 1 else sympy_index_symbol_with_prefix(prefix, n) for n, s in enumerate(ranges) ] @@ -1199,7 +1199,7 @@ def fn(index): else: def fn(index): - reduction_index = [sympy.Integer(0) for _ in reduction_ranges] + reduction_index = [sympy.S.Zero for _ in reduction_ranges] return inner_fn(index, reduction_index) return Pointwise.create( @@ -1619,7 +1619,7 @@ def inner_fn(idx): def copy(loader): def inner_fn(idx): - reduction_index = [sympy.Integer(0) for _ in reduction_ranges] + reduction_index = [sympy.S.Zero for _ in reduction_ranges] return loader(idx, reduction_index) return Pointwise.create( @@ -2345,14 +2345,14 @@ def create(cls, x, new_size): storage, old_layout = as_storage_and_layout(x) skip = len(new_size) - len(old_layout.size) assert skip >= 0 - new_stride = [sympy.Integer(0)] * skip + new_stride = [sympy.S.Zero] * skip for stride, size in zip(old_layout.stride, old_layout.size): new_stride.append( stride if not V.graph.sizevars.shape_env.evaluate_expr( sympy.Eq(size, 1), size_oblivious=True ) - else sympy.Integer(0) + else sympy.S.Zero ) new_layout = FixedLayout( old_layout.device, @@ -2379,7 +2379,7 @@ def reindex(index): for i in range(len(actual)): if actual[i] == 1: # zero out broadcast dimension - index[i] = sympy.Integer(0) + index[i] = sympy.S.Zero return index return reindex @@ -2477,7 +2477,7 @@ def squeezer(size: Tuple[sympy.Expr, ...]): def reindex(index: List[sympy.Expr]) -> Tuple[sympy.Expr, ...]: assert len(index) == len(not_one), f"{index} {not_one}" - new_index = [sympy.Integer(0)] * length + new_index = [sympy.S.Zero] * length for idx, s in zip(not_one, index): new_index[idx] = s return tuple(new_index) @@ -2579,7 +2579,7 @@ def resolve_negative_size(old_size, new_size): new_size = list(new_size) for i in range(len(new_size)): if new_size[i] == -1: - new_size[i] = sympy.Integer(1) + new_size[i] = sympy.S.One new_size[i] = CleanDiv(sympy_product(old_size), sympy_product(new_size)) break @@ -2618,7 +2618,7 @@ def _dynamic_reshape_indexer(old_size, new_size): size_old = stack_old.pop() var, size_new = stack_new.pop() if size_old == 1: - view_expr.append(sympy.Integer(0)) + view_expr.append(sympy.S.Zero) stack_new.append((var, size_new)) # re-add elif size_new == 1: stack_old.append(size_old) # re-add @@ -2633,7 +2633,7 @@ def _dynamic_reshape_indexer(old_size, new_size): view_expr.append(var) V.graph.sizevars.guard_equals(size_new, size_old) elif size_hint(size_new) > size_hint(size_old): - divisor = sympy.Integer(1) + divisor = sympy.S.One modulus = size_old view_expr.append(ModularIndexing(var, divisor, modulus)) divisor = divisor * modulus @@ -2649,7 +2649,7 @@ def _dynamic_reshape_indexer(old_size, new_size): while stack_old: size_old = stack_old.pop() V.graph.sizevars.guard_equals(size_old, 1) # type: ignore[arg-type] - view_expr.append(sympy.Integer(0)) + view_expr.append(sympy.S.Zero) while stack_new: var, size_new = stack_new.pop() @@ -3190,7 +3190,7 @@ class FlexibleLayout(Layout): def contiguous_strides(sizes): if len(sizes) == 0: return [] - reversed_strides = [sympy.Integer(1)] + reversed_strides = [sympy.S.One] for size in reversed(sizes[1:]): reversed_strides.append(size * reversed_strides[-1]) return list(reversed(reversed_strides)) @@ -3204,7 +3204,7 @@ def fill_ordered(sizes, order): [1, 3, 2, 0] """ assert OrderedSet(range(len(sizes))) == OrderedSet(order), (sizes, order) - next_stride = sympy.Integer(1) + next_stride = sympy.S.One strides = [None] * len(order) for i in order: @@ -3764,9 +3764,7 @@ def get_fill_order(self): for r in reads ) reads = [ - sympy_subs( - r.index, {v: sympy.Integer(0) for v in reduction_vars if v != 0} - ) + sympy_subs(r.index, {v: sympy.S.Zero for v in reduction_vars if v != 0}) for r in reads if isinstance(r, dependencies.MemoryDep) ] diff --git a/torch/_inductor/lowering.py b/torch/_inductor/lowering.py index 784f11131539e8..88706728345784 100644 --- a/torch/_inductor/lowering.py +++ b/torch/_inductor/lowering.py @@ -437,9 +437,7 @@ def broadcast_symbolic_shapes(a, b): are symbolic sympy formulas. """ output = [] - for x, y in itertools.zip_longest( - reversed(a), reversed(b), fillvalue=sympy.Integer(1) - ): + for x, y in itertools.zip_longest(reversed(a), reversed(b), fillvalue=sympy.S.One): if V.graph.sizevars.shape_env.evaluate_expr( sympy.Eq(y, 1), size_oblivious=True ): @@ -1037,7 +1035,7 @@ def expand_as(x, y): def repeat(x, repeats): old_size = list(x.get_size()) if len(repeats) > len(old_size): - old_size = [sympy.Integer(1)] * (len(repeats) - len(old_size)) + old_size + old_size = [sympy.S.One] * (len(repeats) - len(old_size)) + old_size x = view(x, list(old_size)) assert len(repeats) == len(x.get_size()) @@ -1062,7 +1060,7 @@ def inner_fn(index): for i in range(len(repeats)): if repeats[i] != 1: if old_size[i] == 1: - index[i] = sympy.Integer(0) + index[i] = sympy.S.Zero else: index[i] = ModularIndexing(index[i], 1, old_size[i]) return x_loader(index) @@ -1730,7 +1728,7 @@ def reindexer(idx): def unsqueeze(x, dim): dim = _validate_dim(x, dim, 1) new_shape = list(x.get_size()) - new_shape.insert(dim, sympy.Integer(1)) + new_shape.insert(dim, sympy.S.One) return view(x, new_shape) @@ -5264,7 +5262,7 @@ def loader(index, reduction_index): if keepdims: new_size = list(size) for i in reduced_idx: - new_size[i] = sympy.Integer(1) + new_size[i] = sympy.S.One else: new_size = kept_sizes diff --git a/torch/_inductor/ops_handler.py b/torch/_inductor/ops_handler.py index 713dbd1558b3b5..b31f64872d8d30 100644 --- a/torch/_inductor/ops_handler.py +++ b/torch/_inductor/ops_handler.py @@ -774,7 +774,7 @@ def sort(dtypes, values, stable, descending) -> Tuple[None, ...]: @staticmethod def indirect_indexing(index_var, size, check=True, wrap_neg=True) -> sympy.Symbol: - return sympy.Integer(0) + return sympy.S.Zero # Use mypy to check protocol implemented correctly diff --git a/torch/_inductor/scheduler.py b/torch/_inductor/scheduler.py index 50fc7f4a9bf8f2..f1584d431e0309 100644 --- a/torch/_inductor/scheduler.py +++ b/torch/_inductor/scheduler.py @@ -1030,7 +1030,7 @@ def pointwise_read_writes(self) -> dependencies.ReadWrites: """ sizes, reduction_sizes = self._sizes return dependencies.extract_read_writes( - self._body, sizes, hidden_args=[[sympy.Integer(0)] * len(reduction_sizes)] + self._body, sizes, hidden_args=[[sympy.S.Zero] * len(reduction_sizes)] ) def can_inplace(self, read_dep: dependencies.Dep) -> bool: diff --git a/torch/_inductor/select_algorithm.py b/torch/_inductor/select_algorithm.py index 1c43ab62c738e8..16e3073d08cf55 100644 --- a/torch/_inductor/select_algorithm.py +++ b/torch/_inductor/select_algorithm.py @@ -199,7 +199,7 @@ def __init__( numel = sympy_product(output_node.get_size()) super().__init__( numel, - sympy.Integer(1), + sympy.S.One, features=SIMDKernelFeatures([], numel), ) self.input_nodes = input_nodes @@ -519,9 +519,9 @@ def store_output( ) contiguous_index = self.rename_indexing(contiguous_index) self.body.writeline("xindex = " + texpr(contiguous_index)) - self.range_trees[0].lookup( - sympy.Integer(1), sympy_product(lengths) - ).set_name("xindex") + self.range_trees[0].lookup(sympy.S.One, sympy_product(lengths)).set_name( + "xindex" + ) self.template_mask = mask self.template_out = val self.template_indices = indices diff --git a/torch/_inductor/sizevars.py b/torch/_inductor/sizevars.py index 8775036cf10593..8dbdb9b00722df 100644 --- a/torch/_inductor/sizevars.py +++ b/torch/_inductor/sizevars.py @@ -298,7 +298,7 @@ def reindex(index): new_index = [] for size in sizes: if size is None: - new_index.append(sympy.Integer(0)) + new_index.append(sympy.S.Zero) else: new_index.append(it.pop()) assert not it @@ -617,26 +617,26 @@ def _stride_vars( index = self.simplify(index) # remove any offset index = index - sympy_subs( - index, {v: sympy.Integer(0) for v in support_vars if v != 0} + index, {v: sympy.S.Zero for v in support_vars if v != 0} ) for i in range(len(vars)): # drop all the other dims index_dim = sympy_subs( index, { - support_vars[j]: sympy.Integer(0) + support_vars[j]: sympy.S.Zero for j in range(len(support_vars)) if vars[i] != support_vars[j] and support_vars[j] != 0 }, ) v = vars[i] if v == 0: - strides.append(sympy.Integer(0)) + strides.append(sympy.S.Zero) else: # TODO(jansel): should we use sympy.diff here? strides.append( - sympy_subs(index_dim, {v: sympy.Integer(1)}) - - sympy_subs(index_dim, {v: sympy.Integer(0)}) + sympy_subs(index_dim, {v: sympy.S.One}) + - sympy_subs(index_dim, {v: sympy.S.Zero}) ) return strides @@ -661,7 +661,7 @@ def atomically_apply_size_hint( def offset_var(self, index: Expr, vars: List[sympy.Symbol]) -> Expr: """Extract offset part of an indexing expression""" index = self.simplify(index) - return sympy_subs(index, {v: sympy.Integer(0) for v in vars if v != 0}) + return sympy_subs(index, {v: sympy.S.Zero for v in vars if v != 0}) def stride_hints( self, @@ -826,7 +826,7 @@ def expand_floor_div( # Construct the new expression and remember the denominator denominator = factorlist[floor_div_index] - new_index = sympy.Integer(0) + new_index = sympy.S.Zero for var, factor, idx in zip(varlist, factorlist, itertools.count()): if idx == floor_div_index: diff --git a/torch/_inductor/utils.py b/torch/_inductor/utils.py index 5bf0550e8a95e8..bf426ea5901018 100644 --- a/torch/_inductor/utils.py +++ b/torch/_inductor/utils.py @@ -234,7 +234,7 @@ def decode_device(device: Union[Optional[torch.device], str]) -> torch.device: def sympy_product(it): - return functools.reduce(operator.mul, it, sympy.Integer(1)) + return functools.reduce(operator.mul, it, sympy.S.One) def sympy_dot(seq1, seq2): diff --git a/torch/fx/experimental/sym_node.py b/torch/fx/experimental/sym_node.py index 176a72dd5d1e7d..c041b1131308ec 100644 --- a/torch/fx/experimental/sym_node.py +++ b/torch/fx/experimental/sym_node.py @@ -961,14 +961,14 @@ def sympy_is_contiguous_generic(sizes, strides, dim_order): return sympy.false is_contiguous = sympy.true - z = sympy.Integer(1) + z = sympy.S.One # Contiguous if the strides make sense (or the dim is size 1) for d in dim_order: - is_contiguous &= sympy.Eq(sizes[d], sympy.Integer(1)) | sympy.Eq(strides[d], z) + is_contiguous &= sympy.Eq(sizes[d], sympy.S.One) | sympy.Eq(strides[d], z) z *= sizes[d] # OR if any size is zero for d in range(dim): - is_contiguous |= sympy.Eq(sizes[d], sympy.Integer(0)) + is_contiguous |= sympy.Eq(sizes[d], sympy.S.Zero) return is_contiguous @@ -994,7 +994,7 @@ def sympy_is_channels_last_strides_generic(sizes, strides, dim_order): if dim != len(dim_order): return sympy.false - m = sympy.Integer(0) + m = sympy.S.Zero r = sympy.true # special case for trivial C dimension. default to NCHW diff --git a/torch/fx/experimental/symbolic_shapes.py b/torch/fx/experimental/symbolic_shapes.py index eda3ef4d476435..5a22ba40a66fa7 100644 --- a/torch/fx/experimental/symbolic_shapes.py +++ b/torch/fx/experimental/symbolic_shapes.py @@ -3054,7 +3054,7 @@ def _init( # they get assigned the same symbolic variable self.val_to_var: Dict[int, sympy.Symbol] = {} if specialize_zero_one: - self.val_to_var = {0: sympy.Integer(0), 1: sympy.Integer(1)} + self.val_to_var = {0: sympy.S.Zero, 1: sympy.S.One} self.unbacked_symfloat_counter = itertools.count() self.unbacked_symint_counter = itertools.count() # Similar to guards, but these MUST evaluate to true and can @@ -6599,7 +6599,7 @@ def _suggest_torch_checks( f"torch._check({printer.doprint(sympy.Not(cond))})", ] for i, fix in enumerate(suggested_fixes): - msg += f"\n {i+1}. {fix}" + msg += f"\n {i + 1}. {fix}" src_mapped = ", ".join( f"`{s}` with {' or '.join(src_map[s])}" for s in sorted(s.name for s in cond.free_symbols) diff --git a/torch/utils/_sympy/functions.py b/torch/utils/_sympy/functions.py index 480be602efa6d2..920b097b632605 100644 --- a/torch/utils/_sympy/functions.py +++ b/torch/utils/_sympy/functions.py @@ -286,7 +286,7 @@ def eval( cls, base: sympy.Integer, divisor: sympy.Integer, modulus: sympy.Integer ) -> Optional[sympy.Basic]: if base == 0 or modulus == 1: - return sympy.Integer(0) + return sympy.S.Zero if ( isinstance(base, sympy.Integer)