Skip to content

Commit

Permalink
zero out registers when accessing by pred + cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
SzymonOzog committed Mar 16, 2024
1 parent a575e50 commit a5281b0
Showing 1 changed file with 4 additions and 9 deletions.
13 changes: 4 additions & 9 deletions tinygrad/renderer/assembly.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,7 @@ def cast(a, dtype:DType, atype:DType, bitcast=False, u=None, pred=False):
out_lab, lab = lab, ssa(None, "alu_cast", lang.types[dtype])
for i, op in enumerate(operands):
operands[i] = ssa(None, "alu_cast", lang.types[dtype])
kk(*lang.render_cast(operands[i], op, dtype, dtypes.half))
kk(*lang.render_cast(operands[i], op, dtype, dtypes.half)) # type: ignore
if args == BinaryOps.CMPLT or args == BinaryOps.CMPEQ:
# pass in the other dtype here
kk(lang.asm_for_op[args](lab, *operands, vin[0].dtype, lang.types[vin[0].dtype]))
Expand Down Expand Up @@ -183,6 +183,8 @@ def cast(a, dtype:DType, atype:DType, bitcast=False, u=None, pred=False):
assert vin[1].dtype is not None
if dtype.count > 1:
r[u] = [ssa(None, 'val', lang.types[dtype.scalar()]) for _ in range(dtype.count)]
if(len(vin)>3):
for v in r[u]: kk(f"mov.{lang.mem_type(dtype.scalar())} {v}, {render_val(0, dtype.scalar())};")
kk((f"@{r[vin[2]]}"if len(vin) > 3 else "")
+ f" ld{u.arg}.v{dtype.count}.{lang.mem_type(dtype.scalar())} {{{', '.join(r[u])}}}, [{r[vin[0]]}+{vin[1].arg}];")
else:
Expand Down Expand Up @@ -255,11 +257,8 @@ class PTXLanguage(AssemblyLanguage):
const_requires_mov = [dtypes.half, dtypes.bool]

def render_const(self, x:Union[float,int,bool], dtype, mov=None) -> Union[List[str], str]:
  """Render the constant `x` of type `dtype` as PTX text.

  The literal itself is produced by `render_val(x, dtype)`. The return value depends on
  `dtype` and `mov`:
    * `dtype` is `dtypes.bool`: a one-element list holding a `setp.ne.s16` instruction
      that compares the literal against 0 and writes to `mov`.
    * otherwise, `mov` is truthy: a one-element list holding a `mov.b<N>` instruction,
      where `<N>` is `self.types[dtype]` with its first character dropped.
    * otherwise: the literal string on its own.
  """
  literal = render_val(x, dtype)
  # The bool check comes first, so this branch is taken regardless of whether `mov` was given.
  if dtype == dtypes.bool:
    return [f"setp.ne.s16 {mov}, {literal}, 0;"]
  if not mov:
    return literal
  width = self.types[dtype][1:]
  return [f"mov.b{width} {mov}, {literal};"]

def render_local(self, dest, name, size, dtype) -> List[str]:
Expand All @@ -286,11 +285,7 @@ def render_store(self, loc, val, dtype, gate=None, ss="", offset=0) -> List[str]

def render_cast(self, d:str, a:str, dtype:DType, atype:DType, bitcast=False, pred=False) -> List[str]:
if bitcast: return [f"mov.b{self.types[dtype][1:]} {d}, {a};"]
if atype == dtypes.bool:
st = f".b{self.types[dtype][1:]}"
return[f".reg {st} {d}_bin;",
f"selp{st} {d}_bin, {render_val(1, dtype)}, {render_val(0, dtype)}, {a};",
f"mov{st} {d}, {d}_bin;"]
if atype == dtypes.bool: return[f"selp.b{self.types[dtype][1:]} {d}, {render_val(1, dtype)}, {render_val(0, dtype)}, {a};"]
if dtype == dtypes.bool: return [f"setp.ne.b{self.types[atype][1:]} {d}, {a}, {self.render_const(0, atype)};"]
rnd = ('.rzi' if dtypes.is_int(dtype) and dtypes.is_float(atype) else
'.rn' if dtypes.is_float(dtype) and (dtype.itemsize < atype.itemsize or dtypes.is_int(atype) or atype == dtypes.bool) else '')
Expand Down

0 comments on commit a5281b0

Please sign in to comment.