diff --git a/test/inductor/test_cpu_repro.py b/test/inductor/test_cpu_repro.py index 03c2cc9f6d5a4..aeafb54021e0b 100644 --- a/test/inductor/test_cpu_repro.py +++ b/test/inductor/test_cpu_repro.py @@ -1830,7 +1830,7 @@ def get_index(): itervars = [sympy.Symbol("i"), sympy.Symbol("j"), sympy.Symbol("k")] tiling_factor = codecache.pick_vec_isa().nelements(dtype=torch.float) - # The moset inner loop variable is used in the index_expr + # The most inner loop variable is used in the index_expr with CppVecKernelChecker( args=None, num_threads=1, tiling_factor=tiling_factor ) as vec_checker: @@ -1843,7 +1843,7 @@ def get_index(): vec_checker.ranges = ranges[:2] submodules = {"get_index": get_index} InterpreterShim(_graph, submodules).run(V.get_ops_handler()) - self.assertFalse(vec_checker.simd_vec) + self.assertTrue(vec_checker.simd_vec) # Most inner loop variable irrevalant with CppVecKernelChecker( @@ -2719,6 +2719,31 @@ def forward(self, idx, x): self.assertTrue("cvt_lowp_fp_to_fp32" not in code) self.assertTrue("cvt_fp32_to_lowp_fp" not in code) + def test_concat_inner_vec(self): + def fn(x, y): + return F.relu(torch.cat([x, y], dim=1)) + + x = torch.randn(32, 35) + y = torch.randn(32, 120) + metrics.reset() + self.common(fn, (x, y)) + assert metrics.generated_cpp_vec_kernel_count == 1 + + def test_expr_vec_non_contiguous(self): + def fn(x): + # the pattern from sebotnet33ts_256 + y = torch.nn.functional.pad(x, (0, 31)).reshape(-1, 33, 63) + y = y[:, :32, 31:].reshape(4, 32, 1, 32, 32).expand(-1, -1, 32, -1, -1) + y = y.permute(0, 3, 1, 4, 2).clone(memory_format=torch.contiguous_format) + y = y.view(4, 1024, 1024) + return y.softmax(dim=-1) + + x = torch.randn(128, 2048) + metrics.reset() + self.common(fn, (x,)) + # 4 kernels for max, exp, sum and div + assert metrics.generated_cpp_vec_kernel_count == 4 + if __name__ == "__main__": from torch._dynamo.test_case import run_tests diff --git a/torch/_inductor/codegen/common.py b/torch/_inductor/codegen/common.py index f927f9e5de58d..54aae140dd5ce 100644 --- a/torch/_inductor/codegen/common.py +++ b/torch/_inductor/codegen/common.py @@ -823,7 +823,7 @@ def clone(self): def generate( self, buffer: IndentedBuffer, - expr: Union[str, CSEVariable, OpsValue], + expr: Union[str, CSEVariable, OpsValue, IndentedBuffer], *, bounds: ValueRanges = ValueRanges.unknown(), write=True, @@ -832,7 +832,7 @@ def generate( if isinstance(expr, OpsValue): expr = expr.value - assert isinstance(expr, (str, CSEVariable)), type(expr) + assert isinstance(expr, (str, CSEVariable, IndentedBuffer)), type(expr) assert write or assignment if isinstance(expr, CSEVariable): # If the expressions were always created with all the information, we could @@ -840,7 +840,7 @@ def generate( # with the loose ValueRanges.unknown(), so we need to tighten the bounds expr.bounds = expr.bounds.tighten(bounds) return expr - cache_key = expr + cache_key = expr.getvalue() if isinstance(expr, IndentedBuffer) else expr var = self.cache.get(cache_key, None) if not var: var = self.newvar(bounds) if assignment else None @@ -850,11 +850,17 @@ def generate( V.kernel.current_node.codegen_originating_info( buffer, only_once=True ) - if assignment: - line = f"{self.prefix}{var} = {expr}{self.suffix}" + if isinstance(expr, IndentedBuffer): + if assignment: + buffer.writeline(f"{self.prefix}{var} =") + buffer.splice(expr) + buffer.writeline(self.suffix) else: - line = f"{expr}{self.suffix}" - buffer.writeline(line) + if assignment: + line = f"{self.prefix}{var} = {expr}{self.suffix}" + else: + line = 
f"{expr}{self.suffix}" + buffer.writeline(line) else: var.bounds = var.bounds.tighten(bounds) @@ -1237,7 +1243,6 @@ class OptimizationContext: dtype: Optional[torch.dtype] = None ops_name: str = "" - is_most_inner_loop_irrevelant: bool = False # Load uint8 value as float32 is_load_uint8_as_float: bool = False diff --git a/torch/_inductor/codegen/cpp.py b/torch/_inductor/codegen/cpp.py index 0d730fc5ee00b..faa8cbf7c3b66 100644 --- a/torch/_inductor/codegen/cpp.py +++ b/torch/_inductor/codegen/cpp.py @@ -32,7 +32,7 @@ sympy_symbol, ) -from ..virtualized import ops, V +from ..virtualized import ops, OpsValue, V from .common import ( BracesBuffer, CppWrapperKernelArgs, @@ -489,6 +489,14 @@ def update_on_args(self, name, args, kwargs): self._set_dependent_itervars(args[0]) if any(arg.is_vec for arg in args if isinstance(arg, CppCSEVariable)): self.is_vec = True + # NOTE [dtype of CppCSEVariable] + # Deciding dtype according to the current optimization context is not + # always accurate since the dtypes are initialized during dtype propagation + # at the beginning of the codegen. It is possible that some ops are invoked + # during the codegen of the current op and take different dtypes from the + # current op. + # TODO(jgong5): A more accurate way of deciding the dtype of the variables is to + # propagate the dtypes here inside `update_on_args`. if ( hasattr(V.interpreter, "current_node") and get_current_node_opt_ctx() is not None @@ -874,12 +882,16 @@ def __new__(cls, *args, **kargs): def wrap(func): # `CppVecKernel` generates both scalar ops and vector ops according to # whether the inputs are scalars or vectors while all ops in `CppVecOverrides` - # (except for "masked") assume the inputs are vectors. We wrap the ops in + # (except for some ops explained below) assume the inputs are vectors. We wrap the ops in # `CppVecOverrides` to broadcast scalar inputs to vectors if needed or fallback to # `CppOverrides` when all inputs are scalars. # - # Inputs to ops.masked are handled separately in its own function due to - # the need of recurive handling of masked body. + # Notes on ops handled separately in their own functions: + # `ops.masked`: + # needs recursive handling of masked body. + # `ops.index_expr`: + # needs to further analyze the dependency of the index expression on + # the tiling itervar. 
def wrapper(*args, **kwargs): has_scalar = any( not arg.is_vec for arg in args if isinstance(arg, CppCSEVariable) @@ -911,8 +923,11 @@ def wrapper(*args, **kwargs): return wrapper - for name, method in vars(cls).items(): - if getattr(method, "__class__", None) == staticmethod and name != "masked": + for name, method in vars(CppVecOverrides).items(): + if getattr(method, "__class__", None) == staticmethod and name not in [ + "masked", + "index_expr", + ]: setattr(self, name, wrap(method.__func__)) return self @@ -1209,6 +1224,7 @@ def to_dtype(x, dtype, src_dtype=None): torch.bfloat16, torch.float16, torch.uint8, + torch.int32, ], f"{__name__} does not support {dtype}" node: torch.fx.Node = V.interpreter.current_node assert node and isinstance(node, torch.fx.Node) @@ -1259,25 +1275,74 @@ def masked(mask, body, other): code.writeline(";") V.kernel.compute.splice(code) + body_code = f"{var}()" + body_code_vec = ( + body_code if result.is_vec else f"at::vec::Vectorized({body_code})" + ) other_code = value_to_cpp(other, "float") other_code_vec = f"at::vec::Vectorized({other_code})" - - if result.is_vec: - type = f"decltype({var}())" + assert isinstance(new_mask, CppCSEVariable), new_mask + if new_mask.is_vec or result.is_vec: + type = f"decltype({body_code_vec})" float_mask = f"to_float_mask({new_mask})" + code = BracesBuffer() + code.writeline("[&]") + with V.kernel.swap_buffers(code), code.indent(): + code.writeline(f"if (all_zero({float_mask}))") + with code.indent(): + code.writeline(f"return {other_code_vec};") + code.writeline("else") + with code.indent(): + code.writeline( + f"return {type}::blendv({other_code_vec}, {body_code_vec}, {float_mask});" + ) + code.writeline("()") csevar = V.kernel.cse.generate( V.kernel.compute, - f"{type}::blendv({other_code_vec}, {var}(), {float_mask})", + code, ) else: csevar = V.kernel.cse.generate( - V.kernel.compute, f"{mask} ? {var}() : {other_code}" + V.kernel.compute, f"{mask} ? {body_code} : {other_code}" ) # `result` is explicitly added to the args for correct propagation # of relevant itervars and vectorization status. 
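For context on the `masked` rewrite above: a minimal plain-Python model of the per-lane select that the generated lambda performs, assuming a float mask whose active lanes carry an all-ones bit pattern (as produced by `to_float_mask`). This sketches the semantics only; it is not the emitted C++.

```python
def masked_select(float_mask, body, other):
    # models: all_zero(mask) ? other_vec : blendv(other_vec, body_vec, mask)
    if all(m == 0.0 for m in float_mask):   # every lane inactive: skip the body entirely
        return [other] * len(float_mask)
    return [b if m != 0.0 else other for m, b in zip(float_mask, body)]

mask = [0.0, -1.0, 0.0, -1.0]   # non-zero (all-ones) bit patterns mark active lanes
body = [10.0, 20.0, 30.0, 40.0]
print(masked_select(mask, body, 0.0))        # [0.0, 20.0, 0.0, 40.0]
print(masked_select([0.0] * 4, body, 0.0))   # fast path: [0.0, 0.0, 0.0, 0.0]
```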
csevar.update_on_args("masked", (mask, body, other, result), {}) return csevar + @staticmethod + def index_expr(expr, dtype): + opt_ctx: OptimizationContext = get_current_node_opt_ctx() + assert opt_ctx and opt_ctx.dtype is not None + dtype = opt_ctx.dtype + assert dtype == torch.int32 + assert isinstance(V.kernel, CppVecKernel) + index = V.kernel.rename_indexing(expr) + tiling_var = V.kernel.itervars[V.kernel.tiling_idx] + if not V.kernel.index_depends_on(index, tiling_var): + # if index doesn't depend on tiling_var, it is fine to use a scalar index + return CppOverrides.index_expr(expr, dtype) + if stride_at( + tiling_var, index + ).is_number and not V.kernel.index_indirect_depends_on(index, tiling_var): + stride = stride_at(tiling_var, index) + value = ops.to_dtype(cexpr(index), dtype) + if isinstance(value, OpsValue): + value = value.value + csevar = V.kernel.arange(value, stride) + else: + csevar = V.kernel.load_non_contiguous(None, index, dtype, V.kernel.compute) + csevar.update_on_args("index_expr", (expr, dtype), {}) + return csevar + + +class CppTile2DOverrides(CppVecOverrides): + @staticmethod + def index_expr(expr, dtype): + assert isinstance(V.kernel, CppTile2DKernel) + expr = V.kernel.transform_indexing(expr) + return CppVecOverrides.index_expr(expr, dtype) + class CppKernel(Kernel): overrides = CppOverrides # type: ignore[assignment] @@ -1305,7 +1370,13 @@ def masked(self, mask): """Context manager to add an additional mask to loads and stores.""" prior = self._load_mask if prior: - mask = self.cse.generate(self.compute, f"{mask} & {prior}") + mask = ops.and_(mask, prior) + if isinstance(mask, OpsValue): + mask = mask.value + assert isinstance(mask, CppCSEVariable) + # see NOTE [dtype of CppCSEVariable] + # mask's dtype should be bool + mask.dtype = torch.bool self._load_mask = mask try: @@ -1328,6 +1399,22 @@ def index_to_str(self, index: sympy.Expr) -> str: """ return cexpr(self.rename_indexing(index)) + def index_indirect_depends_on(self, index: sympy.Expr, itervar: sympy.Symbol): + """ + Check if an index has free symbol CppCSEVariable that depends on `itervar`. + """ + return any( + self.cse.varname_map[s.name].depends_on(itervar) + for s in index.free_symbols + if s.name in self.cse.varname_map + and isinstance(self.cse.varname_map[s.name], CppCSEVariable) + ) + + def index_depends_on(self, index: sympy.Expr, itervar: sympy.Symbol): + return itervar in index.free_symbols or self.index_indirect_depends_on( + index, itervar + ) + def load(self, name: str, index: sympy.Expr): var = self.args.input(name) index = self.rename_indexing(index) @@ -1589,84 +1676,189 @@ def __init__( self.tiling_idx = tiling_idx metrics.generated_cpp_vec_kernel_count += 1 - def load(self, name: str, index: sympy.Expr): + def _get_vec_load_line( + self, + var: str, + index: sympy.Expr, + dtype: torch.dtype, + load_mask: Optional[CppCSEVariable] = None, + ): + """ + Get a load line str that loads a vector from `var` at `index` of type `dtype`. + If `load_mask` is not None, we do a masked load accordingly. + Notes on the `dtype`: + 1. We always load `self.tiling_factor` number of elements regardless of the `dtype`. + It means we load half of the vector lanes for 16-bit data types and quarter of the + vector lanes for 8-bit data types. + 2. `torch.bool` and `torch.uint8` could mean masks and we load them as float mask vectors. 
+ """ opt_ctx: OptimizationContext = get_current_node_opt_ctx() - var = self.args.input(name) - index = self.rename_indexing(index) - dtype = V.graph.get_dtype(name) - tiling_var = self.itervars[self.tiling_idx] - is_broadcast = not index.has(tiling_var) - is_mask = ( - dtype in [torch.bool, torch.uint8] and not opt_ctx.is_load_uint8_as_float - ) - load_mask = f"to_float_mask({self._load_mask})" if self._load_mask else None - non_contiguous = ( - not is_broadcast - and stride_at(tiling_var, index) != 1 - or any( - self.cse.varname_map[s.name].depends_on(tiling_var) - for s in index.free_symbols - if s.name.startswith("tmp") - ) - ) - var_expr = ( - f"{var}[{cexpr_index(index)}]" - if is_broadcast - else f"{var} + {cexpr_index(index)}" - ) - loadbuf = "tmpbuf" if non_contiguous else var_expr - if is_broadcast: - csevar = super().load(name, index) - csevar.dtype = dtype - return csevar - elif dtype in [torch.uint8] and opt_ctx.is_load_uint8_as_float: + assert opt_ctx is not None + load_mask_str = f"to_float_mask({load_mask})" if load_mask else None + loadbuf = f"{var} + {cexpr_index(index)}" if index != 0 else var + if dtype == torch.uint8 and opt_ctx.is_load_uint8_as_float: line = ( - f"masked_load({loadbuf}, {load_mask})" - if load_mask + f"masked_load({loadbuf}, {load_mask_str})" + if load_mask_str else f"at::vec::Vectorized::loadu_one_fourth({loadbuf})" ) - elif is_mask: + elif opt_ctx.is_load_as_mask: line = f"flag_to_float_vec({loadbuf})" elif dtype in DTYPE_LOWP_FP: line = ( - f"masked_load({loadbuf}, {load_mask})" - if load_mask + f"masked_load({loadbuf}, {load_mask_str})" + if load_mask_str else f"at::vec::Vectorized<{DTYPE_TO_CPP[dtype]}>::loadu({loadbuf}, {self.tiling_factor})" ) else: line = ( - f"masked_load({loadbuf}, {load_mask})" - if load_mask - else f"at::vec::Vectorized::loadu({loadbuf})" + f"masked_load({loadbuf}, {load_mask_str})" + if load_mask_str + else f"at::vec::Vectorized<{DTYPE_TO_CPP[dtype]}>::loadu({loadbuf})" ) + return line - if non_contiguous: - # TODO: support masked_load for non_contiguous path? - tmpbuftype = "float" if is_mask else f"{DTYPE_TO_CPP[dtype]}" - tmpbufsize = f"{self.tiling_factor}" - if dtype in DTYPE_LOWP_FP: - tmpbufsize += " * 2" - tmpbufdeclare = f"__at_align__ {tmpbuftype} tmpbuf[{tmpbufsize}];" - inner = sympy_symbol(f"{tiling_var}_inner") - new_index = self.scale_index_with_offset( - index, itervar_idx=self.tiling_idx, offset=inner + def load_non_contiguous( + self, + var: Optional[str], + index: sympy.Expr, + dtype: torch.dtype, + buffer: Optional[IndentedBuffer] = None, + ) -> CppCSEVariable: + """ + Load a vector in a non-contiguous way. The vector is initialized from an array that is + filled in an inner loop over the tiling factor. + :param var: buffer to load from, i.e. `var[transformed(index)]`. If None, we load the index + as index expression, i.e. `transformed(index)`. + :param index: index into the `var` or the index expression by its own if `var` is None. + The `index` could contain indirect indexing or the tiling itervar. When used in + the inner loop, the index is transformed as follows: + 1. the index is linearized along the tiling dim. + 2. the indirect indexing vector variables are transformed into arrays over the tiling dim. + :param dtype: data type of `var` or `index` if `var` is None. + :param buffer: the code buffer to write the generated code to. If None, we write to `self.loads`. + :return: a CppCSEVariable that represents the loaded vector. 
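A plain-Python sketch of the gather strategy this docstring describes: fill a temporary buffer one lane at a time over the tiling factor, honouring an optional per-lane mask, then treat the buffer as the loaded vector. The names below (`gather_load`, `index_fn`, `lane_mask`) are illustrative, not inductor functions.

```python
def gather_load(src, index_fn, tiling_factor, lane_mask=None, fill=0.0):
    tmpbuf = [fill] * tiling_factor
    for inner in range(tiling_factor):             # the generated inner loop
        if lane_mask is None or lane_mask[inner]:  # analogue of vector_lane_mask_check
            tmpbuf[inner] = src[index_fn(inner)]   # index linearized along the tiling dim
    return tmpbuf                                  # stands in for loadu(tmpbuf.data())

data = list(range(100))
# stride-3 access starting at 7: not expressible as a contiguous vector load
print(gather_load(data, lambda lane: 7 + 3 * lane, 8))
# masked variant: inactive lanes keep the fill value
print(gather_load(data, lambda lane: 7 + 3 * lane, 8, lane_mask=[1, 0, 1, 0, 1, 0, 1, 0]))
```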
+ """ + if buffer is None: + buffer = self.loads + + def get_result_size(dtype: torch.dtype) -> int: + assert dtype.itemsize <= 4 + return self.tiling_factor * (4 // dtype.itemsize) + + def vec_to_array(vec_var: CppCSEVariable) -> CppCSEVariable: + assert vec_var.is_vec + code = BracesBuffer() + code.writeline("[&]") + with self.swap_buffers(code), code.indent(): + vec_dtype = vec_var.dtype + assert vec_dtype is not None + if vec_dtype == torch.bool: + vec_dtype = torch.float + result_size = get_result_size(vec_dtype) + code.writeline( + f"__at_align__ std::array<{DTYPE_TO_CPP[vec_dtype]}, {result_size}> tmpbuf;" + ) + line = f"{vec_var}.store(tmpbuf.data());" + code.writeline(line) + code.writeline("return tmpbuf;") + code.writeline("()") + csevar = self.cse.generate(buffer, code) + assert isinstance(csevar, CppCSEVariable) + return csevar + + opt_ctx: OptimizationContext = get_current_node_opt_ctx() + assert opt_ctx is not None + is_mask = opt_ctx.is_load_as_mask + code = BracesBuffer() + code.writeline("[&]") + with self.swap_buffers(code), code.indent(): + result_type = "float" if is_mask else f"{DTYPE_TO_CPP[dtype]}" + result_size = get_result_size(dtype) + result_declare = ( + f"__at_align__ std::array<{result_type}, {result_size}> tmpbuf;" ) - tmpbufdefine = ( - f"for (long {inner} = 0; {inner} < {self.tiling_factor}; {inner}++) " + code.writeline(result_declare) + itervar_inner = sympy_symbol(f"{self.itervars[self.tiling_idx]}_inner") + replacements = {} + for indirect_var in ( + self.cse.varname_map[s.name] + for s in index.free_symbols + if s.name.startswith("tmp") + ): + assert isinstance(indirect_var, CppCSEVariable) + if indirect_var.is_vec: + array_var = vec_to_array(indirect_var) + replacements[indirect_var] = f"{array_var}[{itervar_inner}]" + load_mask = None + if self._load_mask is not None: + assert isinstance(self._load_mask, CppCSEVariable), self._load_mask + if self._load_mask.is_vec: + load_mask = ( + f"vector_lane_mask_check({self._load_mask}, {itervar_inner})" + ) + else: + load_mask = f"{self._load_mask} != 0" + index = sympy_subs(index, replacements) + index = self.scale_index_with_offset( + index, itervar_idx=self.tiling_idx, offset=itervar_inner + ) + if codecache.is_gcc(): + code.writeline(f"#pragma GCC unroll {self.tiling_factor}") + else: + code.writeline(f"#pragma unroll {self.tiling_factor}") + code.writeline( + f"for (long {itervar_inner} = 0; {itervar_inner} < {self.tiling_factor}; {itervar_inner}++)" ) - rhs = f"{var}[{cexpr_index(new_index)}]" - if is_mask: - rhs = f"flag_to_float_scalar({rhs})" - tmpbufdefine += f"tmpbuf[{inner}] = {rhs};" - line = f"([&]() {{ {tmpbufdeclare} {tmpbufdefine} return {line}; }})()" + with code.indent(), contextlib.ExitStack() as stack: + rhs = ( + f"{var}[{cexpr_index(index)}]" + if var is not None + else f"{cexpr_index(index)}" + ) + if is_mask: + rhs = f"flag_to_float_scalar({rhs})" + if load_mask: + code.writeline(f"if ({load_mask})") + stack.enter_context(code.indent()) + code.writeline(f"tmpbuf[{itervar_inner}] = {rhs};") + load_line = self._get_vec_load_line("tmpbuf.data()", 0, dtype) + code.writeline(f"return {load_line};") + code.writeline("()") + csevar = self.cse.generate(buffer, code) + assert isinstance(csevar, CppCSEVariable) + csevar.is_vec = True + return csevar - csevar = self.cse.generate(self.loads, line) - csevar.update_on_args("load", (name, index), {}) + def load(self, name: str, index: sympy.Expr): + opt_ctx: OptimizationContext = get_current_node_opt_ctx() + var = self.args.input(name) + index = 
self.rename_indexing(index) + dtype = V.graph.get_dtype(name) + tiling_var = self.itervars[self.tiling_idx] + if not self.index_depends_on(index, tiling_var): + # load scalar and lazily broadcast it on demand + return super().load(name, index) + non_contiguous = stride_at( + tiling_var, index + ) != 1 or self.index_indirect_depends_on(index, tiling_var) + if non_contiguous: + csevar = self.load_non_contiguous(var, index, dtype) + else: + line = self._get_vec_load_line(var, index, dtype, self._load_mask) + csevar = self.cse.generate(self.loads, line) # type: ignore[assignment] assert isinstance(csevar, CppCSEVariable) + csevar.update_on_args("load", (name, index), {}) csevar.is_vec = True return csevar - def get_vec_store_line(self, value, var, index, dtype): + def _get_vec_store_line( + self, + value: Union[str, CppCSEVariable], + var: str, + index: sympy.Expr, + dtype: torch.dtype, + ): """ Get a store line str that stores `value` into `var` at `index` of `dtype`. :param value: Vectorized type templaterized on `dtype`. @@ -1716,7 +1908,7 @@ def store(self, name, index, value, mode=None): self.stores.writeline( DeferredLine( name, - self.get_vec_store_line(value, var, index, V.graph.get_dtype(name)), + self._get_vec_store_line(value, var, index, V.graph.get_dtype(name)), ) ) @@ -1846,12 +2038,12 @@ def store_reduction(self, name, index, value): store_lines += [ DeferredLine( name, - self.get_vec_store_line(value, var, index, out_dtype), + self._get_vec_store_line(value, var, index, out_dtype), ) ] self.reduction_suffix.writelines(store_lines) - def broadcast(self, scalar_var: CppCSEVariable): + def broadcast(self, scalar_var: CppCSEVariable) -> CppCSEVariable: assert ( not scalar_var.is_vec and self.itervars[self.tiling_idx] not in scalar_var.dependent_itervars @@ -1872,6 +2064,22 @@ def broadcast(self, scalar_var: CppCSEVariable): vec_var.is_vec = True return vec_var + def arange( + self, index: Union[sympy.Expr, CppCSEVariable], stride: sympy.Symbol + ) -> CppCSEVariable: + if isinstance(index, sympy.Expr): + index = cexpr(index) + else: + assert isinstance(index, CppCSEVariable) + assert not index.is_vec + csevar = self.cse.generate( + self.compute, f"at::vec::Vectorized::arange({index}, {stride})" + ) + assert isinstance(csevar, CppCSEVariable) + csevar.dtype = torch.int32 + csevar.is_vec = True + return csevar + class CppTile2DKernel(CppVecKernel): """ @@ -1904,6 +2112,8 @@ class CppTile2DKernel(CppVecKernel): ... 
""" + overrides = CppTile2DOverrides # type: ignore[assignment] + def __init__(self, args, num_threads, tiling_factor, tiling_indices, tiling_dtype): super().__init__( args, num_threads, tiling_factor, tiling_indices[1], tiling_dtype @@ -1915,7 +2125,8 @@ def inner_itervar(self): def need_vec_transpose(self, index): return ( - stride_at(self.itervars[self.outer_idx], index) == 1 + self._load_mask is None # TODO: support transposition with mask + and stride_at(self.itervars[self.outer_idx], index) == 1 and index.has(self.itervars[self.tiling_idx]) and not stride_at(self.itervars[self.tiling_idx], index).has( self.itervars[self.tiling_idx] @@ -1972,26 +2183,14 @@ def load(self, name: str, index: sympy.Expr): # vector load inside the kernel inner loop loadbuf = f"{tile_var} + {cexpr_index(inner * self.tiling_factor)}" dtype = V.graph.get_dtype(name) - if dtype in DTYPE_LOWP_FP: - line = f"at::vec::Vectorized<{DTYPE_TO_CPP[dtype]}>::loadu({loadbuf}, {self.tiling_factor})" - elif ( - V.graph.get_dtype(name) in [torch.uint8] - and opt_ctx.is_load_uint8_as_float - ): - line = f"at::vec::Vectorized::loadu_one_fourth({loadbuf})" - else: - line = f"at::vec::Vectorized::loadu({loadbuf})" + line = self._get_vec_load_line(loadbuf, 0, dtype) csevar = self.cse.generate(self.loads, line) csevar.update_on_args("load", (name, index), {}) assert isinstance(csevar, CppCSEVariable) csevar.is_vec = True return csevar else: - new_index = self.scale_index_with_offset( - index, - itervar_idx=self.outer_idx, - offset=inner, - ) + new_index = self.transform_indexing(index) return super().load(name, new_index) def store(self, name, index, value, mode=None): @@ -2016,11 +2215,7 @@ def store(self, name, index, value, mode=None): line = f"{value}.store({storebuf});" self.stores.writeline(DeferredLine(name, line)) else: - new_index = self.scale_index_with_offset( - index, - itervar_idx=self.outer_idx, - offset=inner, - ) + new_index = self.transform_indexing(index) super().store(name, new_index, value, mode) def codegen_inner_loops(self, code): @@ -2039,6 +2234,13 @@ def set_ranges(self, group, reduction_group): ) return vars + def transform_indexing(self, index: sympy.Expr) -> sympy.Expr: + return self.scale_index_with_offset( + index, + itervar_idx=self.outer_idx, + offset=self.inner_itervar(), + ) + class CppVecKernelChecker(CppVecKernel): def __init__(self, args, num_threads, tiling_factor, tiling_idx=-1): @@ -2425,13 +2627,6 @@ def can_use_int32(): opt_ctx.dtype = dtype self.disable_vec(f"index_expr: {expr}, dtype {dtype}") - tiling_var = self.itervars[self.tiling_idx] - tiling_var_irrelevant = not expr.has(tiling_var) - if not tiling_var_irrelevant: - self.disable_vec( - f"index_expr (tiling var relevant): {expr}, dtype {dtype}" - ) - opt_ctx.is_most_inner_loop_irrevelant = tiling_var_irrelevant tmp_var = self.cse.newvar() return tmp_var diff --git a/torch/_inductor/codegen/cpp_prefix.h b/torch/_inductor/codegen/cpp_prefix.h index 23f72218a0cc1..98e3f58245d4d 100644 --- a/torch/_inductor/codegen/cpp_prefix.h +++ b/torch/_inductor/codegen/cpp_prefix.h @@ -407,4 +407,36 @@ inline at::vec::Vectorized to_float_mask(int src) { *(uint32_t*)&mask = src ? 
0xFFFFFFFF : 0;
   return at::vec::Vectorized<float>(mask);
 }
+
+inline bool all_zero(at::vec::Vectorized<float> src) {
+# if defined(CPU_CAPABILITY_AVX512)
+  auto src_int = _mm512_castps_si512(src);
+  __mmask16 mask = _mm512_test_epi32_mask(src_int, src_int);
+  return mask == 0;
+# elif defined(CPU_CAPABILITY_AVX2)
+  return _mm256_testz_ps(src, src);
+# else
+  __at_align__ int mask[at::vec::Vectorized<float>::size()];
+  src.store(mask);
+  for (int i = 0; i < at::vec::Vectorized<float>::size(); i++) {
+    if (mask[i] != 0) {
+      return false;
+    }
+  }
+  return true;
+# endif
+}
+
+inline bool vector_lane_mask_check(at::vec::Vectorized<float> src, int lane) {
+# if defined(CPU_CAPABILITY_AVX512)
+  return _mm512_movepi32_mask(_mm512_castps_si512(src)) & (1 << lane);
+# elif defined(CPU_CAPABILITY_AVX2)
+  return _mm256_movemask_ps(src) & (1 << lane);
+# else
+  __at_align__ int mask[at::vec::Vectorized<float>::size()];
+  src.store(mask);
+  return mask[lane] != 0;
+# endif
+}
+
 #endif
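The two new helpers, together with the existing `to_float_mask`, define how the generated kernels treat boolean masks as float vectors. Below is a scalar Python model of that convention, keeping each lane as its raw 32-bit pattern (all-ones for true, zero for false); it is illustrative only, not the C++ above.

```python
import struct

def to_float_mask(flag: int) -> int:
    # The C++ helper reinterprets this pattern as a float lane; only the bit
    # pattern matters, so the model keeps it as a 32-bit integer.
    return 0xFFFFFFFF if flag else 0x00000000

def all_zero(mask_lanes) -> bool:
    return all(m == 0 for m in mask_lanes)

def vector_lane_mask_check(mask_lanes, lane: int) -> bool:
    return mask_lanes[lane] != 0

mask = [to_float_mask(f) for f in (1, 0, 0, 1)]
print(all_zero(mask))                    # False
print(vector_lane_mask_check(mask, 1))   # False
print(vector_lane_mask_check(mask, 3))   # True
print(struct.pack("<I", mask[0]).hex())  # ffffffff -> an all-ones float lane (a NaN bit pattern)
```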