Skip to content

Commit

Permalink
move to_python_const from onnx_ops to onnx (tinygrad#8158)
Browse files Browse the repository at this point in the history
* move to_python_const out

* move more over

* try deleting alternative gather implementation

* Revert "try deleting alternative gather implementation"

This reverts commit d46b30b.

* add types to onnx ops

* better debug msg

* improve some com.microsoft too

---------

Co-authored-by: chenyu <[email protected]>
  • Loading branch information
geohotstan and chenyuxyz authored Dec 17, 2024
1 parent 21b085b commit 32c995a
Show file tree
Hide file tree
Showing 2 changed files with 160 additions and 143 deletions.
21 changes: 17 additions & 4 deletions extra/onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,15 @@ def get_run_onnx(onnx_model: ModelProto):
"Softsign", "Asinh", "Acosh", "Atanh", "Elu", "Celu", "Selu", "Xor", "Round", "Erf")
}

# these values are expected to be python consts
required_input_python_consts: Dict[str, tuple[int, ...]] = {
"Tile": (1,), "Range": (0,1,2), "Expand": (1,), "Reshape": (1,), "Squeeze": (1,), "Unsqueeze": (1,), "Trilu": (1,), "ConstantOfShape": (0,),
"CumSum": (1,), "Pad": (1,2,3), "MaxUnpool": (2,), "Dropout": (1,2), "CenterCropPad": (1,), "OneHot": (1,), "Compress": (1,),
"ImageDecoder": (0,), "AffineGrid": (1,), "Resize": (1,2,3), "Upsample": (1,), "Split": (1,), "Slice": (1,2,3,4),
**{"Reduce"+r: (1,) for r in ("Max", "Min", "Sum", "Mean", "SumSquare", "Prod", "L1", "L2", "LogSum", "LogSumExp")},
**{optim: (1,) for optim in ("Adam", "Adagrad", "Momentum")}
}

# src: https://onnx.ai/onnx/repo-docs/IR.html#input-output-data-types
# parses and validates inputs based on their shape and dtype specified by model
def prepare_input(user_input:Any, model_input:ValueInfoProto):
Expand Down Expand Up @@ -121,11 +130,15 @@ def run_onnx(inputs={}, debug=0):
model_tensors[name] = prepare_input(inputs[name], value_info)

for num,n in enumerate(onnx_model.graph.node):
inp = [model_tensors.get(x) for x in n.input]
inp_tensors = [model_tensors.get(x) for x in n.input]
required_consts = required_input_python_consts.get(n.op_type, ())
inp = [to_python_const(t) if i in required_consts else t for i,t in enumerate(inp_tensors)]
opt = model_attributes[num]

if debug >= 1: print(f"{num}: op \"{n.op_type}\" input shapes {[x.shape if isinstance(x, Tensor) else x for x in inp]} opt {opt}")
if debug >= 3: print("\tinputs:\n" + "\n".join(f"\t\t{x} - {t}" for i,(x,t) in enumerate(zip(n.input, inp))))
if debug >= 1: print(f"{num}: op \"{n.op_type}\" input shapes {[x.shape if isinstance(x, Tensor) else x for x in inp_tensors]} opt {opt}")
if debug >= 3:
print("\tinputs:")
print("\n".join(f"\t\t{x} - {t}" + (" (to_python_const)" if i in required_consts else "") for i,(x,t) in enumerate(zip(n.input, inp))))

if n.op_type in tensor_methods:
ret = getattr(Tensor, tensor_methods[n.op_type])(*inp, **opt)
Expand All @@ -134,7 +147,7 @@ def run_onnx(inputs={}, debug=0):
elif n.op_type == "Split":
axis, n_outputs = opt.get('axis', 0), opt.get('num_outputs') or len(n.output)
sz = inp[0].shape[axis]
sizes = to_python_const(inp[1]) if len(inp) == 2 else [sz // n_outputs + (1 if i < sz % n_outputs else 0) for i in range(n_outputs)]
sizes = inp[1] if len(inp) == 2 else [sz // n_outputs + (1 if i < sz % n_outputs else 0) for i in range(n_outputs)]
ret = inp[0].split(sizes, axis)
elif n.op_type == "Gradient":
assert len(opt["xs"]) == len(inp), f"len(opt['xs']):{len(opt['xs'])}, len(inp):{len(inp)} output and input has to match"
Expand Down
Loading

0 comments on commit 32c995a

Please sign in to comment.