Skip to content

Commit

Permalink
remove duplicated UOp in Tensor init types [pr] (tinygrad#8177)
Browse files Browse the repository at this point in the history
and a small comment
  • Loading branch information
chenyuxyz authored Dec 12, 2024
1 parent d240bdd commit 97aaa50
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 4 deletions.
3 changes: 1 addition & 2 deletions tinygrad/engine/realize.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,7 @@

logkerns, logkerns_level = open(getenv("LOGKERNS", ""), "a") if getenv("LOGKERNS", "") else None, getenv("LOGKERNS_LEVEL", 1)
def get_kernel(renderer:Renderer, ast:UOp) -> Kernel:
if DEBUG >= 5:
print(ast)
if DEBUG >= 5: print(ast)
k = Kernel(ast, opts=renderer).required_optimizations()
if not NOOPT:
if not (used_tensor_cores:=k.apply_tensor_cores(getenv("TC", 1))): k.hand_coded_optimizations()
Expand Down
4 changes: 2 additions & 2 deletions tinygrad/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ class Tensor(SimpleMathTrait):
training: ClassVar[bool] = False
no_grad: ClassVar[bool] = False

def __init__(self, data:Union[None, ConstType, UOp, bytes, List, Tuple, UOp, MultiLazyBuffer, 'np.ndarray', pathlib.Path], # type: ignore [name-defined] # noqa: F821
def __init__(self, data:Union[None, ConstType, bytes, List, Tuple, UOp, MultiLazyBuffer, 'np.ndarray', pathlib.Path], # type: ignore [name-defined] # noqa: F821
device:Optional[Union[str, tuple, list]]=None, dtype:Optional[DTypeLike]=None, requires_grad:Optional[bool]=None):
if dtype is not None: dtype = to_dtype(dtype)
assert dtype is None or isinstance(dtype, DType), f"invalid dtype {dtype}"
Expand Down Expand Up @@ -150,7 +150,7 @@ def __init__(self, data:Union[None, ConstType, UOp, bytes, List, Tuple, UOp, Mul
elif isinstance(data, (list, tuple)):
if dtype is None:
if (d := fully_flatten(data)) and all(isinstance(s, bool) for s in d): dtype = dtypes.bool
else: dtype = dtypes.default_int if d and all_int(d) else dtypes.default_float
else: dtype = dtypes.default_int if d and all_int(d) else dtypes.default_float # NOTE: this works because all_int([True, False]) is True
if dtype == dtypes.bfloat16: data = Tensor(_frompy(data, dtypes.float32), device=device).cast(dtypes.bfloat16).lazydata
else: data = _frompy(data, dtype)
elif str(type(data)) == "<class 'numpy.ndarray'>":
Expand Down

0 comments on commit 97aaa50

Please sign in to comment.