diff --git a/tinygrad/renderer/assembly.py b/tinygrad/renderer/assembly.py index a084a398d7ed..c94e6a810874 100644 --- a/tinygrad/renderer/assembly.py +++ b/tinygrad/renderer/assembly.py @@ -152,7 +152,7 @@ def cast(a, dtype:DType, atype:DType, bitcast=False, u=None, pred=False): out_lab, lab = lab, ssa(None, "alu_cast", lang.types[dtype]) for i, op in enumerate(operands): operands[i] = ssa(None, "alu_cast", lang.types[dtype]) - kk(*lang.render_cast(operands[i], op, dtype, dtypes.half)) + kk(*lang.render_cast(operands[i], op, dtype, dtypes.half)) # type: ignore if args == BinaryOps.CMPLT or args == BinaryOps.CMPEQ: # pass in the other dtype here kk(lang.asm_for_op[args](lab, *operands, vin[0].dtype, lang.types[vin[0].dtype])) @@ -183,6 +183,8 @@ def cast(a, dtype:DType, atype:DType, bitcast=False, u=None, pred=False): assert vin[1].dtype is not None if dtype.count > 1: r[u] = [ssa(None, 'val', lang.types[dtype.scalar()]) for _ in range(dtype.count)] + if(len(vin)>3): + for v in r[u]: kk(f"mov.{lang.mem_type(dtype.scalar())} {v}, {render_val(0, dtype.scalar())};") kk((f"@{r[vin[2]]}"if len(vin) > 3 else "") + f" ld{u.arg}.v{dtype.count}.{lang.mem_type(dtype.scalar())} {{{', '.join(r[u])}}}, [{r[vin[0]]}+{vin[1].arg}];") else: @@ -255,11 +257,8 @@ class PTXLanguage(AssemblyLanguage): const_requires_mov = [dtypes.half, dtypes.bool] def render_const(self, x:Union[float,int,bool], dtype, mov=None) -> Union[List[str], str]: - # if dtypes.is_float(dtype): val = f"0f{float_to_hex(x)}" if dtype != dtypes.float64 else f"0d{double_to_hex(x)}" - # else: val = str(int(x)) + ("U" if dtypes.is_unsigned(dtype) else "") val = render_val(x, dtype) if dtype == dtypes.bool: return [f"setp.ne.s16 {mov}, {val}, 0;"] - # if dtype == dtypes.half: return [f".reg .f32 {mov}_tmp;", f"mov.f32 {mov}_tmp, {val};", f"cvt.rn.f16.f32 {mov}, {mov}_tmp;"] return [f"mov.b{self.types[dtype][1:]} {mov}, {val};"] if mov else val def render_local(self, dest, name, size, dtype) -> List[str]: @@ -286,11 +285,7 @@ def render_store(self, loc, val, dtype, gate=None, ss="", offset=0) -> List[str] def render_cast(self, d:str, a:str, dtype:DType, atype:DType, bitcast=False, pred=False) -> List[str]: if bitcast: return [f"mov.b{self.types[dtype][1:]} {d}, {a};"] - if atype == dtypes.bool: - st = f".b{self.types[dtype][1:]}" - return[f".reg {st} {d}_bin;", - f"selp{st} {d}_bin, {render_val(1, dtype)}, {render_val(0, dtype)}, {a};", - f"mov{st} {d}, {d}_bin;"] + if atype == dtypes.bool: return[f"selp.b{self.types[dtype][1:]} {d}, {render_val(1, dtype)}, {render_val(0, dtype)}, {a};"] if dtype == dtypes.bool: return [f"setp.ne.b{self.types[atype][1:]} {d}, {a}, {self.render_const(0, atype)};"] rnd = ('.rzi' if dtypes.is_int(dtype) and dtypes.is_float(atype) else '.rn' if dtypes.is_float(dtype) and (dtype.itemsize < atype.itemsize or dtypes.is_int(atype) or atype == dtypes.bool) else '')