Skip to content

Commit

Permalink
fix elementwise with const
Browse files — browse the repository at this point in the history
  • Loading branch information
chenyuxyz committed Jan 24, 2024
1 parent af49245 commit 16ae9bc
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 6 deletions.
15 changes: 10 additions & 5 deletions tinygrad/features/multi.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,9 +57,9 @@ def is_unrealized_contiguous_const(self): return False

# passthroughs
def schedule(self, seen=None): return create_schedule(self.real_lbs, seen)
def cast(self, dtype:DType, bitcast:bool=False): return MultiLazyBuffer([x.cast(dtype, bitcast) for x in self.real_lbs], self.axis)
def const(self, val:Union[float, int]) -> MultiLazyBuffer: return MultiLazyBuffer([x.const(val) for x in self.real_lbs], self.axis)
def contiguous(self): return MultiLazyBuffer([x.contiguous() for x in self.real_lbs], self.axis)
def cast(self, dtype:DType, bitcast:bool=False): return MultiLazyBuffer([x.cast(dtype, bitcast) for x in self.real_lbs], self.axis, self.real)
def const(self, val:Union[float, int]) -> MultiLazyBuffer: return MultiLazyBuffer([x.const(val) for x in self.real_lbs], self.axis, self.real)
def contiguous(self): return MultiLazyBuffer([x.contiguous() for x in self.real_lbs], self.axis, self.real)

# elementwise is simple
def e(self, op:Union[LoadOps, UnaryOps, BinaryOps, TernaryOps], *in_srcs:MultiLazyBuffer, arg:Optional[Any]=None) -> MultiLazyBuffer:
Expand All @@ -68,12 +68,17 @@ def e(self, op:Union[LoadOps, UnaryOps, BinaryOps, TernaryOps], *in_srcs:MultiLa
print(f"{msrcs=}")
if not all(self.real):
out_real = [False] * len(self.real)
out_lbs = self.lbs
out_lbs = [[] for _ in self.real]
print(f"{msrcs=}")
for src in msrcs:
print(f"{src.real=}")
for i, (lb, r) in enumerate(zip(src.lbs, src.real)):
if r:
out_real[i] = True
out_lbs[i] = lb
out_lbs[i].append(lb)
for i, lbs in enumerate(out_lbs):
out_lbs[i] = self.lbs[i] if not lbs else lbs[0] if len(lbs)==1 else lbs[0].e(op, *lbs[1:])
print(f"{out_lbs=}, {out_real=}")
return MultiLazyBuffer(out_lbs, self.axis, out_real)
assert all_same([x.device for x in msrcs]), f"all buffers must have the same device {[x.device for x in msrcs]}"

Expand Down
2 changes: 1 addition & 1 deletion tinygrad/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -781,7 +781,7 @@ def _broadcasted(self, y:Union[Tensor, Scalar], reverse:bool=False, match_dtype:
if 0 in self.shape: return self, self.full_like(y)
if isinstance(self.dtype, ImageDType) or dtypes.is_float(x.dtype) or (dtypes.is_int(x.dtype) and isinstance(y, int)): y_dtype = x.dtype
else: y_dtype = dtypes.from_py(y)
y = Tensor(y, self.device, y_dtype, requires_grad=False)
y = Tensor(self.lazydata.const(y), self.device, requires_grad=False)

if match_dtype:
output_dtype = least_upper_dtype(x.dtype, y.dtype)
Expand Down

0 comments on commit 16ae9bc

Please sign in to comment.