From 1dbc3534761fd906631f4c8d2c26b3f54aaea342 Mon Sep 17 00:00:00 2001 From: root <26priya11@gmail.com> Date: Fri, 1 Nov 2024 16:25:48 -0700 Subject: [PATCH] print all outputs in sdpa fwd --- csrc/ir/internal_nodes.h | 12 ++++++++++++ csrc/ir/nodes.cpp | 5 ++++- 2 files changed, 16 insertions(+), 1 deletion(-) diff --git a/csrc/ir/internal_nodes.h b/csrc/ir/internal_nodes.h index d33614c3e0c..4d5c41e3ec2 100644 --- a/csrc/ir/internal_nodes.h +++ b/csrc/ir/internal_nodes.h @@ -2261,6 +2261,18 @@ class SdpaFwdOp : public Expr { return output(0)->as(); } + TensorView* logsumexp() const { + return output(1)->as(); + } + + TensorView* philox_seed() const { + return output(2)->as(); + } + + TensorView* philox_offset() const { + return output(3)->as(); + } + TensorView* query() const { return input(0)->as(); } diff --git a/csrc/ir/nodes.cpp b/csrc/ir/nodes.cpp index 3c7be6b6262..ca39d51684a 100644 --- a/csrc/ir/nodes.cpp +++ b/csrc/ir/nodes.cpp @@ -4426,7 +4426,10 @@ NVFUSER_DEFINE_CLONE_AND_CREATE(SdpaFwdOp) std::string SdpaFwdOp::toString(int indent_size) const { std::stringstream ss; - indent(ss, indent_size) << attn_out()->toString() << "\n"; + indent(ss, indent_size) << attn_out()->toString() << ",\n"; + indent(ss, indent_size) << logsumexp()->toString() << ",\n"; + indent(ss, indent_size) << philox_seed()->toString() << ",\n"; + indent(ss, indent_size) << philox_offset()->toString() << "\n"; indent(ss, indent_size + 1) << " = sdpa(" << query()->toString() << ",\n"; indent(ss, indent_size + 1) << " " << key()->toString() << ",\n"; indent(ss, indent_size + 1) << " " << value()->toString() << ",\n";