From 0639f0a6367c0c469081e06a2e224c00511c74c5 Mon Sep 17 00:00:00 2001 From: Merlin Nimier-David Date: Tue, 26 Nov 2024 16:20:22 +0100 Subject: [PATCH] Call: allow isolation per (variant, domain, scope) The actual values of `scope` are set to 0 for now. --- include/drjit/autodiff.h | 5 +- include/drjit/call.h | 51 ++++++++++++++++----- include/drjit/extra.h | 8 ++-- include/drjit/jit.h | 12 +++-- src/extra/call.cpp | 98 ++++++++++++++++++++-------------------- src/python/switch.cpp | 34 ++++++++++---- tests/call_ext.cpp | 2 +- 7 files changed, 129 insertions(+), 81 deletions(-) diff --git a/include/drjit/autodiff.h b/include/drjit/autodiff.h index 428233d50..84915a5ee 100644 --- a/include/drjit/autodiff.h +++ b/include/drjit/autodiff.h @@ -761,7 +761,8 @@ struct DRJIT_TRIVIAL_ABI DiffArray if constexpr (!IsClass) return out; else - return (Value) jit_registry_ptr(Backend, CallSupport::Domain, out); + return (Value) jit_registry_ptr( + CallSupport::Variant, CallSupport::Domain, /* TODO: scope */ 0, out); } bool schedule_() const { return jit_var_schedule((uint32_t) m_index); } @@ -929,7 +930,7 @@ void backward_to(const T &value, uint32_t flags = (uint32_t) ADFlag::Default) { template void forward_from(T &value, uint32_t flags = (uint32_t) ADFlag::Default) { detail::check_grad_enabled("forward_from", value); - + if constexpr (is_complex_v) set_grad(value, T(1.f, 1.f)); else if constexpr (is_quaternion_v) diff --git a/include/drjit/call.h b/include/drjit/call.h index 87a24baff..bf0891f5a 100644 --- a/include/drjit/call.h +++ b/include/drjit/call.h @@ -18,6 +18,22 @@ NAMESPACE_BEGIN(drjit) +NAMESPACE_BEGIN(detail) + +template +using has_variant_override = decltype(T::variant_()); + +template +constexpr const char *get_variant(const char *fallback) { + if constexpr (is_detected_v) { + return CallSupport::variant_(); + } else { + return fallback; + } +} + +NAMESPACE_END(detail) + #define DRJIT_CALL_BEGIN(Name) \ namespace drjit { \ template \ @@ -25,7 +41,7 @@ NAMESPACE_BEGIN(drjit) using Base_ = void; \ using Class_ = Name; \ using Mask_ = mask_t; \ - static constexpr const char *Domain = #Name; \ + using CallSupport_ = call_support; \ call_support(const Self &self) : self(self) { } \ const call_support *operator->() const { \ return this; \ @@ -38,7 +54,7 @@ NAMESPACE_BEGIN(drjit) using Base_ = void; \ using Class_ = Name; \ using Mask_ = mask_t; \ - static constexpr const char *Domain = #Name; \ + using CallSupport_ = call_support, Self>; \ call_support(const Self &self) : self(self) { } \ const call_support *operator->() const { \ return this; \ @@ -51,6 +67,7 @@ NAMESPACE_BEGIN(drjit) : call_support, Self> { \ using Base_ = call_support, Self>; \ using Base_::self; \ + using Base_::Variant; \ using Base_::Domain; \ using Class_ = Name; \ using Mask_ = mask_t; \ @@ -64,6 +81,12 @@ NAMESPACE_BEGIN(drjit) } #define DRJIT_CALL_END(Name) \ + public: \ + static constexpr const char *Domain = #Name; \ + /* Define `Variant` at the end so that the optional `variant_()`*/ \ + /* method provided by the user can be detected (if given). */ \ + static constexpr const char *Variant = \ + detail::get_variant(""); \ protected: \ const Self &self; \ }; \ @@ -104,7 +127,8 @@ private: \ }; \ \ return detail::call( \ - self, Domain, #Name "()", false, callback, args...); \ + self, Variant, Domain, /* TODO: scope */ 0, #Name "()", false, \ + callback, args...); \ } #define DRJIT_CALL_GETTER(Name) \ @@ -125,8 +149,9 @@ public: \ state->collect_rv(rv_i); \ }; \ \ - return detail::call(self, Domain, #Name "()", \ - true, callback, mask); \ + return detail::call( \ + self, Variant, Domain, /* TODO: scope */ 0, #Name "()", true, \ + callback, mask); \ } template using vectorize_rv_t = @@ -183,8 +208,9 @@ template struct CallState { }; template -Ret call(const Self &self, const char *domain, const char *name, - bool is_getter, ad_call_func callback, const Args &...args) { +Ret call(const Self &self, const char *variant, const char *domain, + uint32_t scope, const char *name, bool is_getter, + ad_call_func callback, const Args &...args) { using Mask = mask_t; using CallStateT = CallState; CallStateT *state = new CallStateT(args...); @@ -193,9 +219,9 @@ Ret call(const Self &self, const char *domain, const char *name, index64_vector args_i, rv_i; collect_indices(state->args, args_i); - bool done = ad_call(Self::Backend, domain, -1, 0, name, is_getter, - self.index(), mask.index(), args_i, rv_i, state, - callback, &CallStateT::cleanup, true); + bool done = ad_call(Self::Backend, variant, domain, scope, -1, 0, name, + is_getter, self.index(), mask.index(), args_i, rv_i, + state, callback, &CallStateT::cleanup, true); if constexpr (!std::is_same_v) { Ret2 result(std::move(state->rv)); @@ -242,8 +268,9 @@ auto dispatch_impl(std::index_sequence, const Self &self, const Func &fun }; return detail::call( - self, Self::CallSupport::Domain, "drjit::dispatch()", false, callback, - func, args...); + self, Self::CallSupport::Variant, Self::CallSupport::Domain, + /* TODO: scope */ 0, "drjit::dispatch()", false, callback, func, + args...); } NAMESPACE_END(detail) diff --git a/include/drjit/extra.h b/include/drjit/extra.h index 422d1206e..2b344f575 100644 --- a/include/drjit/extra.h +++ b/include/drjit/extra.h @@ -331,11 +331,11 @@ typedef void (*ad_call_cleanup)(void*); * already been destroyed. */ extern DRJIT_EXTRA_EXPORT bool -ad_call(JitBackend backend, const char *domain, int symbolic, size_t callable_count, - const char *name, bool is_getter, uint32_t index, uint32_t mask, +ad_call(JitBackend backend, const char *variant, const char *domain, + uint32_t scope, int symbolic, size_t callable_count, const char *name, + bool is_getter, uint32_t index, uint32_t mask, const drjit::vector &args, drjit::vector &rv, - void *payload, ad_call_func callback, ad_call_cleanup cleanup, - bool ad); + void *payload, ad_call_func callback, ad_call_cleanup cleanup, bool ad); // Callbacks used by \ref ad_loop() below. See the interface for details typedef void (*ad_loop_read)(void *payload, drjit::vector &); diff --git a/include/drjit/jit.h b/include/drjit/jit.h index d89335fd1..620ce6e17 100644 --- a/include/drjit/jit.h +++ b/include/drjit/jit.h @@ -436,7 +436,8 @@ struct DRJIT_TRIVIAL_ABI JitArray jit_memcpy(Backend, temp, data(), size * sizeof(uint32_t)); for (uint32_t i = 0; i < size; i++) ((void **) ptr)[i] = - jit_registry_ptr(Backend, CallSupport::Domain, temp[i]); + jit_registry_ptr(CallSupport::Variant, CallSupport::Domain, + /* TODO: scope */ 0, temp[i]); delete[] temp; } } @@ -555,8 +556,9 @@ struct DRJIT_TRIVIAL_ABI JitArray drjit_fail("Unsupported operand type"); } else { uint32_t bucket_count = 0; - CallBucket *buckets = jit_var_call_reduce( - Backend, CallSupport::Domain, m_index, &bucket_count); + CallBucket *buckets = jit_var_call_reduce( + Backend, CallSupport::Variant, CallSupport::Domain, + /* TODO: scope */ 0, m_index, &bucket_count); return { buckets, bucket_count }; } } @@ -622,7 +624,9 @@ struct DRJIT_TRIVIAL_ABI JitArray if constexpr (!IsClass) return out; else - return (Value) jit_registry_ptr(Backend, CallSupport::Domain, out); + return (Value) jit_registry_ptr(CallSupport::Variant, + CallSupport::Domain, + /* TODO: scope */ 0, out); } template && std::is_same_v> = 0> diff --git a/src/extra/call.cpp b/src/extra/call.cpp index d0d311efb..9fc4b7fe8 100644 --- a/src/extra/call.cpp +++ b/src/extra/call.cpp @@ -64,14 +64,13 @@ static void ad_call_check_rv(JitBackend backend, size_t size, const vector &rv2); // Strategy 1: this is a getter. turn the call into a gather operation -static void ad_call_getter(JitBackend backend, const char *domain, - const char *name, size_t size, uint32_t index, - uint32_t mask_, size_t callable_count, - const vector args, +static void ad_call_getter(JitBackend backend, const char *variant, + const char *domain, uint32_t scope, const char *name, + size_t size, uint32_t index, uint32_t mask_, + size_t callable_count, const vector args, vector &rv, vector &rv_ad, ad_call_func func, void *payload, - dr::vector &implicit_in, - bool ad) { + dr::vector &implicit_in, bool ad) { index64_vector args2; // unused vector rv2; @@ -99,7 +98,7 @@ static void ad_call_getter(JitBackend backend, const char *domain, void *ptr; if (domain) { - ptr = jit_registry_ptr(backend, domain, (uint32_t) i + 1); + ptr = jit_registry_ptr(variant, domain, scope, (uint32_t) i + 1); if (!ptr) continue; } else { @@ -241,13 +240,13 @@ static void ad_call_getter(JitBackend backend, const char *domain, } // Strategy 2: perform indirection symbolically by tracing all callables -static void ad_call_symbolic(JitBackend backend, const char *domain, +static void ad_call_symbolic(JitBackend backend, const char *variant, + const char *domain, uint32_t scope, const char *name, size_t size, uint32_t index, uint32_t mask_, size_t callable_count, - const vector args, - vector &rv, vector &rv_ad, - ad_call_func func, void *payload, - dr::vector &implicit_in, + const vector args, vector &rv, + vector &rv_ad, ad_call_func func, + void *payload, dr::vector &implicit_in, bool ad) { (void) domain; (void) size; @@ -307,7 +306,7 @@ static void ad_call_symbolic(JitBackend backend, const char *domain, void *ptr; if (domain) { - ptr = jit_registry_ptr(backend, domain, (uint32_t) i + 1); + ptr = jit_registry_ptr(variant, domain, scope, (uint32_t) i + 1); if (!ptr) continue; } else { @@ -373,12 +372,12 @@ static void ad_call_symbolic(JitBackend backend, const char *domain, } // Strategy 3: group the arguments and evaluate a kernel per callable -static void ad_call_reduce(JitBackend backend, const char *domain, - const char *name, size_t size, uint32_t index_, - uint32_t mask_, size_t callable_count, - const vector args_, - vector &rv, - ad_call_func func, void *payload) { +static void ad_call_reduce(JitBackend backend, const char *variant, + const char *domain, uint32_t scope, const char *name, + size_t size, uint32_t index_, uint32_t mask_, + size_t callable_count, const vector args_, + vector &rv, ad_call_func func, + void *payload) { (void) name; // unused const char *domain_or_empty = domain ? domain : "", *separator = domain ? "::" : ""; @@ -414,8 +413,8 @@ static void ad_call_reduce(JitBackend backend, const char *domain, } uint32_t n_inst = (uint32_t) callable_count; - CallBucket *buckets = - jit_var_call_reduce(backend, domain, index.index(), &n_inst); + CallBucket *buckets = jit_var_call_reduce(backend, variant, domain, scope, + index.index(), &n_inst); index64_vector args2(args.size(), 0); args2.clear(); @@ -456,7 +455,7 @@ static void ad_call_reduce(JitBackend backend, const char *domain, void *ptr; if (domain) { - ptr = jit_registry_ptr(backend, domain, buckets[i].id); + ptr = jit_registry_ptr(variant, domain, scope, buckets[i].id); if (!ptr) jit_raise( "ad_call_reduce(\"%s%s%s\"): instance %u does not exist (or no longer exists).", @@ -559,13 +558,14 @@ static void ad_call_check_rv(JitBackend backend, size_t size, /// CustomOp that hooks a recorded virtual function call into the AD graph struct CallOp : public dr::detail::CustomOpBase { public: - CallOp(JitBackend backend, std::string &&name, const char *domain, - uint32_t index, uint32_t mask, size_t callable_count, - const vector &args, size_t rv_size, void *payload, - ad_call_func func, ad_call_cleanup cleanup) - : m_name(std::move(name)), m_domain(domain), m_index(index), m_mask(mask), - m_callable_count(callable_count), m_payload(payload), - m_func(func), m_cleanup(cleanup) { + CallOp(JitBackend backend, std::string &&name, const char *variant, + const char *domain, uint32_t scope, uint32_t index, uint32_t mask, + size_t callable_count, const vector &args, size_t rv_size, + void *payload, ad_call_func func, ad_call_cleanup cleanup) + : m_name(std::move(name)), m_variant(variant), m_domain(domain), + m_scope(scope), m_index(index), m_mask(mask), + m_callable_count(callable_count), m_payload(payload), m_func(func), + m_cleanup(cleanup) { m_backend = backend; jit_var_inc_ref(m_index); @@ -611,8 +611,8 @@ struct CallOp : public dr::detail::CustomOpBase { args.push_back_steal(ad_grad(combine(m_input_indices[i]))); ad_call( - m_backend, m_domain, 1, m_callable_count, name.c_str(), false, - m_index, m_mask, args, rv, this, + m_backend, m_variant, m_domain, m_scope, 1, m_callable_count, + name.c_str(), false, m_index, m_mask, args, rv, this, [](void *ptr, void *self, const vector &args, vector &rv) { ((CallOp *) ptr)->forward_cb(self, args, rv); @@ -644,8 +644,8 @@ struct CallOp : public dr::detail::CustomOpBase { args.push_back_steal(ad_grad(combine(m_output_indices[i]))); ad_call( - m_backend, m_domain, 1, m_callable_count, name.c_str(), false, - m_index, m_mask, args, rv, this, + m_backend, m_variant, m_domain, m_scope, 1, m_callable_count, + name.c_str(), false, m_index, m_mask, args, rv, this, [](void *ptr, void *self, const vector &args, vector &rv) { ((CallOp *) ptr)->backward_cb(self, args, rv); @@ -766,7 +766,9 @@ struct CallOp : public dr::detail::CustomOpBase { private: std::string m_name, m_name_op; + const char *m_variant; const char *m_domain; + uint32_t m_scope; uint32_t m_index, m_mask; size_t m_callable_count; index32_vector m_args; @@ -782,11 +784,11 @@ struct CallOp : public dr::detail::CustomOpBase { }; // Generic checks, then forward either to ad_call_symbolic or ad_call_reduce -bool ad_call(JitBackend backend, const char *domain, int symbolic, - size_t callable_count, const char *name, bool is_getter, - uint32_t index, uint32_t mask, const vector &args, - vector &rv, void *payload, ad_call_func func, - ad_call_cleanup cleanup, bool ad) { +bool ad_call(JitBackend backend, const char *variant, const char *domain, + uint32_t scope, int symbolic, size_t callable_count, + const char *name, bool is_getter, uint32_t index, uint32_t mask, + const vector &args, vector &rv, void *payload, + ad_call_func func, ad_call_cleanup cleanup, bool ad) { try { const char *domain_or_empty = domain ? domain : "", *separator = domain ? "::" : ""; @@ -815,7 +817,7 @@ bool ad_call(JitBackend backend, const char *domain, int symbolic, jit_raise("ad_call(): 'symbolic' must be -1, 0, or 1!"); if (domain) - callable_count = jit_registry_id_bound(backend, domain); + callable_count = jit_registry_id_bound(variant, domain, scope); size_t size = jit_var_size(index); if (mask) { @@ -865,13 +867,13 @@ bool ad_call(JitBackend backend, const char *domain, int symbolic, dr::detail::ad_index32_vector implicit_in; if (is_getter) { - ad_call_getter(backend, domain, name, size, index, mask, - callable_count, args, rv, rv_ad, func, payload, + ad_call_getter(backend, variant, domain, scope, name, size, index, + mask, callable_count, args, rv, rv_ad, func, payload, implicit_in, ad); } else if (symbolic) { - ad_call_symbolic(backend, domain, name, size, index, mask, - callable_count, args, rv, rv_ad, func, payload, - implicit_in, ad); + ad_call_symbolic(backend, variant, domain, scope, name, size, index, + mask, callable_count, args, rv, rv_ad, func, + payload, implicit_in, ad); } else { if (jit_flag(JitFlag::SymbolicScope)) jit_raise( @@ -881,8 +883,8 @@ bool ad_call(JitBackend backend, const char *domain, int symbolic, "documentation of drjit.JitFlag.SymbolicCalls and drjit.switch() for general\n" "information on symbolic and evaluated calls, as well as their limitations."); - ad_call_reduce(backend, domain, name, size, index, mask, - callable_count, args, rv, func, payload); + ad_call_reduce(backend, variant, domain, scope, name, size, index, + mask, callable_count, args, rv, func, payload); ad = false; // derivative already tracked, no CustomOp needed } @@ -898,8 +900,8 @@ bool ad_call(JitBackend backend, const char *domain, int symbolic, } nanobind::ref op = new CallOp( - backend, std::move(combined), domain, index, mask, callable_count, - args, rv.size(), payload, func, cleanup); + backend, std::move(combined), variant, domain, scope, index, + mask, callable_count, args, rv.size(), payload, func, cleanup); for (size_t i = 0; i < args.size(); ++i) op->add_input(i, args[i]); diff --git a/src/python/switch.cpp b/src/python/switch.cpp index 47e7b2642..792946a73 100644 --- a/src/python/switch.cpp +++ b/src/python/switch.cpp @@ -166,8 +166,9 @@ nb::object switch_impl(nb::handle index_, nb::sequence targets, } bool done = ad_call( - (JitBackend) s.backend, nullptr, symbolic, nb::len(targets), - label.c_str(), false, (uint32_t) s.index(inst_ptr(index)), + (JitBackend) s.backend, /* variant */ nullptr, /* domain */ nullptr, + /* scope */ 0, symbolic, nb::len(targets), label.c_str(), false, + (uint32_t) s.index(inst_ptr(index)), mask.is_valid() ? ((uint32_t) s.index(inst_ptr(mask))) : 0u, args_i, rv_i, state, func, cleanup, true); @@ -195,7 +196,9 @@ nb::object dispatch_impl(nb::handle_t inst, nb::object target_o; nb::object rv_o; JitBackend backend; + nb::str variant_name; nb::str domain_name; + uint32_t scope; ~State() { if (!nb::is_alive()) @@ -204,6 +207,7 @@ nb::object dispatch_impl(nb::handle_t inst, args_o.reset(); target_o.reset(); rv_o.reset(); + variant_name.reset(); domain_name.reset(); } }; @@ -212,6 +216,11 @@ nb::object dispatch_impl(nb::handle_t inst, if (!s.is_class || s.ndim != 1) nb::raise("drjit.dispatch(): 'inst' parameter must be an instance array."); + nb::object variant_name = nb::getattr(inst.type(), "Variant", nb::handle()); + if (!variant_name.is_valid() || !nb::isinstance(variant_name)) + nb::raise("drjit.dispatch(): The instance array type ('%s') lacks the " + "'Variant' name attribute.", nb::type_name(inst.type()).c_str()); + nb::object domain_name = nb::getattr(inst.type(), "Domain", nb::handle()); if (!domain_name.is_valid() || !nb::isinstance(domain_name)) nb::raise("drjit.dispatch(): The instance array type ('%s') lacks the " @@ -231,8 +240,10 @@ nb::object dispatch_impl(nb::handle_t inst, state.args_o = nb::borrow(update_indices(state.args_o, args_i)); - if (!self) - self = jit_registry_peek(state.backend, state.domain_name.c_str()); + if (!self) { + self = jit_registry_peek(state.variant_name.c_str(), + state.domain_name.c_str(), state.scope); + } nb::object self_o = nb::steal(nb::detail::nb_type_put( state.type, self, nb::rv_policy::reference, nullptr)); @@ -253,7 +264,9 @@ nb::object dispatch_impl(nb::handle_t inst, target, nb::object(), (JitBackend) s.backend, - nb::borrow(domain_name) + nb::borrow(variant_name), + nb::borrow(domain_name), + /* TODO: scope */ 0, }; ad_call_cleanup cleanup = [](void *ptr) { @@ -272,11 +285,12 @@ nb::object dispatch_impl(nb::handle_t inst, mask = mask_tp(mask); } - bool done = ad_call( - (JitBackend) s.backend, state->domain_name.c_str(), symbolic, 0, - label.c_str(), false, (uint32_t) s.index(inst_ptr(inst)), - mask.is_valid() ? ((uint32_t) s.index(inst_ptr(mask))) : 0u, args_i, - rv_i, state, target_cb, cleanup, true); + bool done = + ad_call((JitBackend) s.backend, state->variant_name.c_str(), + state->domain_name.c_str(), state->scope, symbolic, 0, + label.c_str(), false, (uint32_t) s.index(inst_ptr(inst)), + mask.is_valid() ? ((uint32_t) s.index(inst_ptr(mask))) : 0u, + args_i, rv_i, state, target_cb, cleanup, true); nb::object result = ::update_indices(state->rv_o, rv_i); diff --git a/tests/call_ext.cpp b/tests/call_ext.cpp index c629c7f31..43a94fb72 100644 --- a/tests/call_ext.cpp +++ b/tests/call_ext.cpp @@ -50,7 +50,7 @@ template struct Base : nb::intrusive_base { Base() { if constexpr (dr::is_jit_v) - jit_registry_put(dr::backend_v, "Base", this); + jit_registry_put("", "Base", 0, this); } virtual ~Base() { jit_registry_remove(this); }