Skip to content

Commit

Permalink
Call: allow isolation per (variant, domain, scope)
Browse files Browse the repository at this point in the history
The actual values of `scope` are set to 0 for now.
  • Loading branch information
merlinND committed Nov 27, 2024
1 parent b1eebb7 commit 0639f0a
Show file tree
Hide file tree
Showing 7 changed files with 129 additions and 81 deletions.
5 changes: 3 additions & 2 deletions include/drjit/autodiff.h
Original file line number Diff line number Diff line change
Expand Up @@ -761,7 +761,8 @@ struct DRJIT_TRIVIAL_ABI DiffArray
if constexpr (!IsClass)
return out;
else
return (Value) jit_registry_ptr(Backend, CallSupport::Domain, out);
return (Value) jit_registry_ptr(
CallSupport::Variant, CallSupport::Domain, /* TODO: scope */ 0, out);
}

bool schedule_() const { return jit_var_schedule((uint32_t) m_index); }
Expand Down Expand Up @@ -929,7 +930,7 @@ void backward_to(const T &value, uint32_t flags = (uint32_t) ADFlag::Default) {
template <typename T>
void forward_from(T &value, uint32_t flags = (uint32_t) ADFlag::Default) {
detail::check_grad_enabled("forward_from", value);

if constexpr (is_complex_v<T>)
set_grad(value, T(1.f, 1.f));
else if constexpr (is_quaternion_v<T>)
Expand Down
51 changes: 39 additions & 12 deletions include/drjit/call.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,30 @@

NAMESPACE_BEGIN(drjit)

NAMESPACE_BEGIN(detail)

template <typename T>
using has_variant_override = decltype(T::variant_());

template <typename CallSupport>
constexpr const char *get_variant(const char *fallback) {
if constexpr (is_detected_v<has_variant_override, CallSupport>) {
return CallSupport::variant_();
} else {
return fallback;
}
}

NAMESPACE_END(detail)

#define DRJIT_CALL_BEGIN(Name) \
namespace drjit { \
template <typename Self> \
struct call_support<Name, Self> { \
using Base_ = void; \
using Class_ = Name; \
using Mask_ = mask_t<Self>; \
static constexpr const char *Domain = #Name; \
using CallSupport_ = call_support<Name, Self>; \
call_support(const Self &self) : self(self) { } \
const call_support *operator->() const { \
return this; \
Expand All @@ -38,7 +54,7 @@ NAMESPACE_BEGIN(drjit)
using Base_ = void; \
using Class_ = Name<Ts...>; \
using Mask_ = mask_t<Self>; \
static constexpr const char *Domain = #Name; \
using CallSupport_ = call_support<Name<Ts...>, Self>; \
call_support(const Self &self) : self(self) { } \
const call_support *operator->() const { \
return this; \
Expand All @@ -51,6 +67,7 @@ NAMESPACE_BEGIN(drjit)
: call_support<Parent<Ts...>, Self> { \
using Base_ = call_support<Parent<Ts...>, Self>; \
using Base_::self; \
using Base_::Variant; \
using Base_::Domain; \
using Class_ = Name<Ts...>; \
using Mask_ = mask_t<Self>; \
Expand All @@ -64,6 +81,12 @@ NAMESPACE_BEGIN(drjit)
}

#define DRJIT_CALL_END(Name) \
public: \
static constexpr const char *Domain = #Name; \
/* Define `Variant` at the end so that the optional `variant_()`*/ \
/* method provided by the user can be detected (if given). */ \
static constexpr const char *Variant = \
detail::get_variant<CallSupport_>(""); \
protected: \
const Self &self; \
}; \
Expand Down Expand Up @@ -104,7 +127,8 @@ private: \
}; \
\
return detail::call<Self, Ret, Ret2, Args...>( \
self, Domain, #Name "()", false, callback, args...); \
self, Variant, Domain, /* TODO: scope */ 0, #Name "()", false, \
callback, args...); \
}

#define DRJIT_CALL_GETTER(Name) \
Expand All @@ -125,8 +149,9 @@ public: \
state->collect_rv(rv_i); \
}; \
\
return detail::call<Self, Ret, Ret, Mask_>(self, Domain, #Name "()", \
true, callback, mask); \
return detail::call<Self, Ret, Ret, Mask_>( \
self, Variant, Domain, /* TODO: scope */ 0, #Name "()", true, \
callback, mask); \
}
template <typename Guide, typename T>
using vectorize_rv_t =
Expand Down Expand Up @@ -183,8 +208,9 @@ template <typename Ret, typename... Args> struct CallState {
};

template <typename Self, typename Ret, typename Ret2, typename... Args>
Ret call(const Self &self, const char *domain, const char *name,
bool is_getter, ad_call_func callback, const Args &...args) {
Ret call(const Self &self, const char *variant, const char *domain,
uint32_t scope, const char *name, bool is_getter,
ad_call_func callback, const Args &...args) {
using Mask = mask_t<Self>;
using CallStateT = CallState<Ret2, Args...>;
CallStateT *state = new CallStateT(args...);
Expand All @@ -193,9 +219,9 @@ Ret call(const Self &self, const char *domain, const char *name,

index64_vector args_i, rv_i;
collect_indices<true>(state->args, args_i);
bool done = ad_call(Self::Backend, domain, -1, 0, name, is_getter,
self.index(), mask.index(), args_i, rv_i, state,
callback, &CallStateT::cleanup, true);
bool done = ad_call(Self::Backend, variant, domain, scope, -1, 0, name,
is_getter, self.index(), mask.index(), args_i, rv_i,
state, callback, &CallStateT::cleanup, true);

if constexpr (!std::is_same_v<Ret, void>) {
Ret2 result(std::move(state->rv));
Expand Down Expand Up @@ -242,8 +268,9 @@ auto dispatch_impl(std::index_sequence<Is...>, const Self &self, const Func &fun
};

return detail::call<Self, Ret, Ret2, Func, Args...>(
self, Self::CallSupport::Domain, "drjit::dispatch()", false, callback,
func, args...);
self, Self::CallSupport::Variant, Self::CallSupport::Domain,
/* TODO: scope */ 0, "drjit::dispatch()", false, callback, func,
args...);
}

NAMESPACE_END(detail)
Expand Down
8 changes: 4 additions & 4 deletions include/drjit/extra.h
Original file line number Diff line number Diff line change
Expand Up @@ -331,11 +331,11 @@ typedef void (*ad_call_cleanup)(void*);
* already been destroyed.
*/
extern DRJIT_EXTRA_EXPORT bool
ad_call(JitBackend backend, const char *domain, int symbolic, size_t callable_count,
const char *name, bool is_getter, uint32_t index, uint32_t mask,
ad_call(JitBackend backend, const char *variant, const char *domain,
uint32_t scope, int symbolic, size_t callable_count, const char *name,
bool is_getter, uint32_t index, uint32_t mask,
const drjit::vector<uint64_t> &args, drjit::vector<uint64_t> &rv,
void *payload, ad_call_func callback, ad_call_cleanup cleanup,
bool ad);
void *payload, ad_call_func callback, ad_call_cleanup cleanup, bool ad);

// Callbacks used by \ref ad_loop() below. See the interface for details
typedef void (*ad_loop_read)(void *payload, drjit::vector<uint64_t> &);
Expand Down
12 changes: 8 additions & 4 deletions include/drjit/jit.h
Original file line number Diff line number Diff line change
Expand Up @@ -436,7 +436,8 @@ struct DRJIT_TRIVIAL_ABI JitArray
jit_memcpy(Backend, temp, data(), size * sizeof(uint32_t));
for (uint32_t i = 0; i < size; i++)
((void **) ptr)[i] =
jit_registry_ptr(Backend, CallSupport::Domain, temp[i]);
jit_registry_ptr(CallSupport::Variant, CallSupport::Domain,
/* TODO: scope */ 0, temp[i]);
delete[] temp;
}
}
Expand Down Expand Up @@ -555,8 +556,9 @@ struct DRJIT_TRIVIAL_ABI JitArray
drjit_fail("Unsupported operand type");
} else {
uint32_t bucket_count = 0;
CallBucket *buckets = jit_var_call_reduce(
Backend, CallSupport::Domain, m_index, &bucket_count);
CallBucket *buckets = jit_var_call_reduce(
Backend, CallSupport::Variant, CallSupport::Domain,
/* TODO: scope */ 0, m_index, &bucket_count);
return { buckets, bucket_count };
}
}
Expand Down Expand Up @@ -622,7 +624,9 @@ struct DRJIT_TRIVIAL_ABI JitArray
if constexpr (!IsClass)
return out;
else
return (Value) jit_registry_ptr(Backend, CallSupport::Domain, out);
return (Value) jit_registry_ptr(CallSupport::Variant,
CallSupport::Domain,
/* TODO: scope */ 0, out);
}

template <typename T, enable_if_t<!std::is_void_v<T> && std::is_same_v<T, Value>> = 0>
Expand Down
Loading

0 comments on commit 0639f0a

Please sign in to comment.