-
Notifications
You must be signed in to change notification settings - Fork 54
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Minor cleanup in executor.cpp #2750
Changes from 1 commit
d74a8bf
11df23a
78ca610
75543ce
549d150
3f1036b
f23b40b
453e44f
9943152
97ef121
5826645
9bc2af8
8029b1b
8ecc5e7
1db8e06
8f3ff98
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -65,7 +65,7 @@ class FusionExecutor : public NonCopyable { | |
//! Notes: 1. This API should ignore aliased outputs instead of | ||
//! pushing scalar int 0 as a place-holder. | ||
//! 2. This API does not allocate output in memory, but only returns the | ||
//! inferred output sizes. | ||
//! inferred output sizes. Used in kernel_cache.cpp. | ||
KernelArgumentHolder inferOutputSizes( | ||
Fusion* fusion, | ||
const KernelArgumentHolder& args, | ||
|
@@ -118,10 +118,14 @@ class FusionExecutor : public NonCopyable { | |
|
||
//! Computes fusion outputs through expression evaluator. | ||
std::vector<at::Tensor> evaluateFusionOutputs( | ||
KernelArgumentHolder& args, | ||
std::vector<at::Tensor> outputs, | ||
ExpressionEvaluator& expr_eval); | ||
|
||
// TODO: args shouldn't come in a reference here because we will append the | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @jjsjann123 see the todo here. We don't ever use args as they're updated with the outputs. We always pass it back as an array of tensors. So it is different behavior with |
||
// outputs to be able to send it to the kernel. For now none of the users are | ||
// reconsuming the args, so it is okay. It isn't done now because changing it | ||
// from a reference makes a call as runFusion({}) ambiguous, and that is used | ||
// in some places in the codebase. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Would be really interesting to see how we can resolve this cleanly, but I don't know any way to do that without changing the call site to not use braced-init-list. This gives me headache: https://en.cppreference.com/w/cpp/language/overload_resolution#Implicit_conversion_sequence_in_list-initialization I don't know anything that can help overload resolution? wondering if @zasdfgbnm knows any dark magic? |
||
NVF_API std::vector<at::Tensor> runFusion( | ||
KernelArgumentHolder& args, | ||
const LaunchParams& launch_constraints = LaunchParams(), | ||
|
Original file line number | Diff line number | Diff line change | ||
---|---|---|---|---|
|
@@ -705,7 +705,18 @@ ExpressionEvaluator bindInputs( | |||
// NOTE: we bind all inputs here, including at::Tensors. This means that | ||||
// expr_eval will create a PolymorphicValue containing *args[i], which means | ||||
// that at::Tensor's lifetime will be at least as long as that of expr_eval. | ||||
expr_eval.bind(inputs[i], *args[i], true); | ||||
try { | ||||
expr_eval.bind(inputs[i], *args[i], true); | ||||
} catch (const nvfError& e) { | ||||
std::stringstream ss; | ||||
ss << "When trying to run the provided host program," | ||||
<< " there was an error with the provided input " << i | ||||
<< ". Provided input was:\n "; | ||||
ss << PolymorphicValue_functions::toString(*args[i]); | ||||
ss << "\n which does not match the expected input:\n "; | ||||
ss << inputs[i]->toString() << "\n"; | ||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Is There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. In this instance I didn't think it was particularly helpful, as it's the caller that we would typically want to point at, not the inside of expression evaluator which is the first to find it. I think we'd actually want the error to be thrown higher where WDYT? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I trust your judgement -- I haven't ran into enough errors to have a strong opinion. AFAICT, Fuser/csrc/evaluator_common.cpp Line 369 in 346e51c
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I agree, that's why I lifted the error in the common place we bind a bunch of inputs to expression evaluator provided from someplace not generated by nvFuser (developers in tests and Thunder in integration). |
||||
NVF_THROW(ss.str()); | ||||
} | ||||
} | ||||
|
||||
return expr_eval; | ||||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Wanted a throw to be able to get some better error messages in bindInputs.