diff --git a/src/orange/OrangeData.hh b/src/orange/OrangeData.hh index ffd2444547..f310a79fdd 100644 --- a/src/orange/OrangeData.hh +++ b/src/orange/OrangeData.hh @@ -473,6 +473,7 @@ struct OrangeStateData // Scratch space with dimensions {track}{max_faces} Items temp_sense; + Items temp_sense_mod; // Scratch space with dimensions {track}{max_intersections} Items temp_face; @@ -500,6 +501,7 @@ struct OrangeStateData && vol.size() == max_depth * this->size() && universe.size() == max_depth * this->size() && !temp_sense.empty() + && !temp_sense_mod.empty() && !temp_face.empty() && temp_distance.size() == temp_face.size() && temp_isect.size() == temp_face.size(); @@ -533,6 +535,7 @@ struct OrangeStateData universe = other.universe; temp_sense = other.temp_sense; + temp_sense_mod = other.temp_sense_mod; temp_face = other.temp_face; temp_distance = other.temp_distance; @@ -576,6 +579,7 @@ inline void resize(OrangeStateData* data, size_type face_states = params.scalars.max_faces * num_tracks; resize(&data->temp_sense, face_states); + resize(&data->temp_sense_mod, face_states); size_type isect_states = params.scalars.max_intersections * num_tracks; resize(&data->temp_face, isect_states); diff --git a/src/orange/OrangeTrackView.hh b/src/orange/OrangeTrackView.hh index 74bae70e17..2c714961dd 100644 --- a/src/orange/OrangeTrackView.hh +++ b/src/orange/OrangeTrackView.hh @@ -198,6 +198,7 @@ class OrangeTrackView // Create local sense reference inline CELER_FUNCTION Span make_temp_sense() const; + inline CELER_FUNCTION Span make_temp_sense_mod() const; // Create local distance inline CELER_FUNCTION detail::TempNextFace make_temp_next() const; @@ -276,6 +277,7 @@ OrangeTrackView::operator=(Initializer_t const& init) local.volume = {}; local.surface = {}; local.temp_sense = this->make_temp_sense(); + local.temp_sense_mod = this->make_temp_sense_mod(); // Helpers for applying parent-to-daughter transformations TransformVisitor apply_transform{params_}; @@ -689,6 +691,7 @@ CELER_FUNCTION void OrangeTrackView::cross_boundary() local.volume = lsa.vol(); local.surface = {this->surf(), this->sense()}; local.temp_sense = this->make_temp_sense(); + local.temp_sense_mod = this->make_temp_sense_mod(); } TrackerVisitor visit_tracker{params_}; @@ -1047,6 +1050,18 @@ CELER_FUNCTION Span OrangeTrackView::make_temp_sense() const offset, max_faces); } +//---------------------------------------------------------------------------// +/*! + * Get a reference to the current volume, or to world volume if outside. + */ +CELER_FUNCTION Span OrangeTrackView::make_temp_sense_mod() const +{ + auto const max_faces = params_.scalars.max_faces; + auto offset = track_slot_.get() * max_faces; + return states_.temp_sense_mod[AllItems{}] + .subspan(offset, max_faces); +} + //---------------------------------------------------------------------------// /*! * Set up intersection scratch space. @@ -1089,6 +1104,7 @@ OrangeTrackView::make_local_state(LevelId level) const local.surface = {}; } local.temp_sense = this->make_temp_sense(); + local.temp_sense_mod = this->make_temp_sense_mod(); local.temp_next = this->make_temp_next(); return local; } diff --git a/src/orange/OrangeTypes.hh b/src/orange/OrangeTypes.hh index c6625d5ec1..1079174136 100644 --- a/src/orange/OrangeTypes.hh +++ b/src/orange/OrangeTypes.hh @@ -101,6 +101,18 @@ enum class Sense : bool outside, //!< Expression is greater than zero }; +//---------------------------------------------------------------------------// +/*! + * Transformations to apply to senses when using lazy sense evaluation. + */ +enum class SenseMod : unsigned char +{ + normal = 0, + flipped = 1 << 0, + cached = 1 << 1, +}; +using SenseModFlags = std::underlying_type_t; + //---------------------------------------------------------------------------// /*! * Enumeration for mapping surface classes to integers. @@ -378,6 +390,46 @@ CELER_CONSTEXPR_FUNCTION Sense to_sense(bool s) return static_cast(-static_cast(orig)); } +//---------------------------------------------------------------------------// +/*! + * Check if a sense modifier is set. + */ +[[nodiscard]] CELER_CONSTEXPR_FUNCTION bool +is_sense_mod_set(SenseMod mod, SenseModFlags flags) +{ + return (flags & static_cast(mod)) != 0; +} + +//---------------------------------------------------------------------------// +/*! + * Set a sense modifier. + */ +[[nodiscard]] CELER_CONSTEXPR_FUNCTION SenseModFlags +set_sense_mod(SenseMod mod, SenseModFlags flags) +{ + return flags | static_cast(mod); +} + +//---------------------------------------------------------------------------// +/*! + * Unset a sense modifier. + */ +[[nodiscard]] CELER_CONSTEXPR_FUNCTION SenseModFlags +unset_sense_mod(SenseMod mod, SenseModFlags flags) +{ + return flags & ~static_cast(mod); +} + +//---------------------------------------------------------------------------// +/*! + * Flip a sense modifier. + */ +[[nodiscard]] CELER_CONSTEXPR_FUNCTION SenseModFlags +flip_sense_mod(SenseMod mod, SenseModFlags flags) +{ + return flags ^ static_cast(mod); +} + //---------------------------------------------------------------------------// /*! * Change whether a boundary crossing is reentrant or exiting. diff --git a/src/orange/univ/SimpleUnitTracker.hh b/src/orange/univ/SimpleUnitTracker.hh index b1ae943666..4c91157941 100644 --- a/src/orange/univ/SimpleUnitTracker.hh +++ b/src/orange/univ/SimpleUnitTracker.hh @@ -10,10 +10,12 @@ #include "corecel/Assert.hh" #include "corecel/math/Algorithms.hh" #include "orange/OrangeData.hh" +#include "orange/OrangeTypes.hh" #include "orange/detail/BIHEnclosingVolFinder.hh" #include "orange/surf/LocalSurfaceVisitor.hh" #include "detail/InfixEvaluator.hh" +#include "detail/LazySenseCalculator.hh" #include "detail/LogicEvaluator.hh" #include "detail/SenseCalculator.hh" #include "detail/SurfaceFunctors.hh" @@ -165,22 +167,24 @@ SimpleUnitTracker::initialize(LocalState const& state) const -> Initialization CELER_EXPECT(params_); CELER_EXPECT(!state.surface && !state.volume); - detail::SenseCalculator calc_senses( - this->make_surface_visitor(), state.pos, state.temp_sense); + detail::LazySenseCalculator calc_senses(this->make_surface_visitor(), + state.pos, + state.temp_sense, + state.temp_sense_mod); // Use the BIH to locate a position that's inside, and save whether it's on // a surface in the found volume - bool on_surface{false}; - auto is_inside - = [this, &calc_senses, &on_surface](LocalVolumeId id) -> bool { + auto is_inside = [this, &calc_senses](LocalVolumeId id) -> bool { VolumeView vol = this->make_local_volume(id); - auto logic_state = calc_senses(vol); - on_surface = static_cast(logic_state.face); - return detail::LogicEvaluator(vol.logic())(logic_state.senses); + auto bind_calc_sense + = [&](FaceId face_id) { return calc_senses(vol, face_id); }; + auto inside = detail::LogicEvaluator(vol.logic())(bind_calc_sense); + calc_senses.invalidate_cache(); + return inside; }; LocalVolumeId id = this->find_volume_where(state.pos, is_inside); - if (on_surface) + if (static_cast(calc_senses.on_face())) { // Prohibit initialization on a surface id = {}; @@ -202,12 +206,14 @@ CELER_FUNCTION auto SimpleUnitTracker::cross_boundary(LocalState const& state) const -> Initialization { CELER_EXPECT(state.surface && state.volume); - detail::SenseCalculator calc_senses( - this->make_surface_visitor(), state.pos, state.temp_sense); detail::OnLocalSurface on_surface; - auto is_inside = [this, &state, &calc_senses, &on_surface]( - LocalVolumeId const& id) -> bool { + auto is_inside + = [this, &state, &on_surface](LocalVolumeId const& id) -> bool { + detail::LazySenseCalculator calc_senses(this->make_surface_visitor(), + state.pos, + state.temp_sense, + state.temp_sense_mod); if (id == state.volume) { // Cannot cross surface into the same volume @@ -215,13 +221,14 @@ SimpleUnitTracker::cross_boundary(LocalState const& state) const -> Initializati } VolumeView vol = this->make_local_volume(id); - auto logic_state - = calc_senses(vol, detail::find_face(vol, state.surface)); - - if (detail::LogicEvaluator(vol.logic())(logic_state.senses)) + auto on_face = detail::find_face(vol, state.surface); + auto bind_calc_sense = [&](FaceId face_id) { + return calc_senses(vol, face_id, on_face); + }; + if (detail::LogicEvaluator(vol.logic())(bind_calc_sense)) { // Inside: find and save the local surface ID, and end the search - on_surface = get_surface(vol, logic_state.face); + on_surface = get_surface(vol, calc_senses.on_face()); return true; } return false; @@ -528,13 +535,16 @@ SimpleUnitTracker::complex_intersect(LocalState const& state, CELER_ASSERT(num_isect > 0); // Calculate local senses, taking current face into account - auto logic_state = detail::SenseCalculator( - this->make_surface_visitor(), state.pos, state.temp_sense)( - vol, detail::find_face(vol, state.surface)); - + auto calc_senses = detail::LazySenseCalculator(this->make_surface_visitor(), + state.pos, + state.temp_sense, + state.temp_sense_mod); + auto bind_calc_sense = [&](FaceId face_id) { + return calc_senses(vol, face_id, detail::find_face(vol, state.surface)); + }; // Current senses should put us inside the volume detail::LogicEvaluator is_inside(vol.logic()); - CELER_ASSERT(is_inside(logic_state.senses)); + CELER_ASSERT(is_inside(bind_calc_sense)); // Loop over distances and surface indices to cross by iterating over // temp_next.isect[:num_isect]. @@ -547,16 +557,17 @@ SimpleUnitTracker::complex_intersect(LocalState const& state, // Face being crossed in this ordered intersection FaceId face = state.temp_next.face[isect]; // Flip the sense of the face being crossed - Sense new_sense = flip_sense(logic_state.senses[face.get()]); - logic_state.senses[face.unchecked_get()] = new_sense; - if (!is_inside(logic_state.senses)) + calc_senses.flip_sense(face); + if (!is_inside(bind_calc_sense)) { // Flipping this sense puts us outside the current volume: in // other words, only after crossing all the internal surfaces along // this direction do we hit a surface that actually puts us // outside. Intersection result; - result.surface = {vol.get_surface(face), flip_sense(new_sense)}; + + result.surface + = {vol.get_surface(face), flip_sense(bind_calc_sense(face))}; result.distance = state.temp_next.distance[isect]; CELER_ENSURE(result.distance > 0 && !std::isinf(result.distance)); return result; @@ -620,10 +631,15 @@ CELER_FUNCTION auto SimpleUnitTracker::background_intersect( { CELER_ASSERT(vid != state.volume); VolumeView vol = this->make_local_volume(vid); - auto logic_state = detail::SenseCalculator{ - this->make_surface_visitor(), pos, state.temp_sense}(vol); - - if (detail::LogicEvaluator{vol.logic()}(logic_state.senses)) + auto calc_senses + = detail::LazySenseCalculator{this->make_surface_visitor(), + pos, + state.temp_sense, + state.temp_sense_mod}; + auto bind_calc_sense + = [&](FaceId face_id) { return calc_senses(vol, face_id); }; + + if (detail::LogicEvaluator{vol.logic()}(bind_calc_sense)) { // We are in this new volume by crossing the tested surface. // Get the sense corresponding to this "crossed" surface. @@ -633,8 +649,7 @@ CELER_FUNCTION auto SimpleUnitTracker::background_intersect( Intersection result; result.distance = state.temp_next.distance[isect]; result.surface = detail::OnLocalSurface{ - surface, - flip_sense(logic_state.senses[face.unchecked_get()])}; + surface, flip_sense(bind_calc_sense(face))}; return result; } } diff --git a/src/orange/univ/detail/LazySenseCalculator.hh b/src/orange/univ/detail/LazySenseCalculator.hh new file mode 100644 index 0000000000..ca3cbd7c79 --- /dev/null +++ b/src/orange/univ/detail/LazySenseCalculator.hh @@ -0,0 +1,163 @@ +//----------------------------------*-C++-*----------------------------------// +// Copyright 2021-2024 UT-Battelle, LLC, and other Celeritas developers. +// See the top-level COPYRIGHT file for details. +// SPDX-License-Identifier: (Apache-2.0 OR MIT) +//---------------------------------------------------------------------------// +//! \file orange/univ/detail/LazySenseCalculator.hh +//---------------------------------------------------------------------------// +#pragma once + +#include "corecel/Assert.hh" +#include "corecel/cont/Range.hh" +#include "corecel/cont/Span.hh" +#include "orange/OrangeTypes.hh" +#include "orange/surf/LocalSurfaceVisitor.hh" +#include "orange/univ/detail/Types.hh" + +#include "SurfaceFunctors.hh" +#include "../VolumeView.hh" + +namespace celeritas +{ +namespace detail +{ +//---------------------------------------------------------------------------// +/*! + * Calculate senses with a fixed particle position. + * + * This is an implementation detail used in initialization *and* complex + * intersection. + */ +class LazySenseCalculator +{ + public: + // Construct from persistent, current, and temporary data + inline CELER_FUNCTION LazySenseCalculator(LocalSurfaceVisitor const& visit, + Real3 const& pos, + Span sense_cache, + Span sense_flags); + + // Calculate senses for a single face of the given volume, possibly on a + // face + inline CELER_FUNCTION Sense operator()(VolumeView const& vol, + FaceId face_id, + OnFace face = {}); + + //! Clear the cached sense values + CELER_FUNCTION void invalidate_cache() + { + for (auto& flags : sense_flags_) + { + flags = unset_sense_mod(SenseMod::cached, flags); + } + } + + //! The first face encountered that we are "on" + CELER_FUNCTION OnFace& on_face() { return face_; } + + //! Flip the sense of a face + CELER_FUNCTION void flip_sense(FaceId face_id) + { + sense_flags_[face_id.get()] + = flip_sense_mod(SenseMod::flipped, sense_flags_[face_id.get()]); + + // If the sense is cached, flip it, otherwise it will be flipped when + // we calculate it + if (is_sense_mod_set(SenseMod::cached, sense_flags_[face_id.get()])) + { + sense_cache_[face_id.get()] + = celeritas::flip_sense(sense_cache_[face_id.get()]); + } + } + + private: + //! The first face encountered that we are "on" + OnFace face_; + + //! Apply a function to a local surface + LocalSurfaceVisitor visit_; + + //! Local position + Real3 pos_; + + //! Temporary senses + Span sense_cache_; + Span sense_flags_; +}; + +//---------------------------------------------------------------------------// +// INLINE DEFINITIONS +//---------------------------------------------------------------------------// +/*! + * Construct from persistent, current, and temporary data. + */ +CELER_FUNCTION +LazySenseCalculator::LazySenseCalculator(LocalSurfaceVisitor const& visit, + Real3 const& pos, + Span sense_cache, + Span sense_flags) + : visit_{visit} + , pos_{pos} + , sense_cache_{sense_cache} + , sense_flags_{sense_flags} +{ + for (auto& sense : sense_flags_) + { + sense = static_cast(SenseMod::normal); + } +} + +//---------------------------------------------------------------------------// +/*! + * Calculate senses for the given volume. + * + * If the point is exactly on one of the volume's surfaces, the \c face value + * of the return will be set. + */ +CELER_FUNCTION auto LazySenseCalculator::operator()(VolumeView const& vol, + FaceId face_id, + OnFace face) -> Sense +{ + CELER_EXPECT(!face || face.id() < vol.num_faces()); + + if (!face_ && face) + { + face_ = face; + } + + if (is_sense_mod_set(SenseMod::cached, sense_flags_[face_id.get()])) + { + return sense_cache_[face_id.get()]; + } + + Sense sense; + if (face_id != face.id()) + { + // Calculate sense + SignedSense ss = visit_(CalcSense{pos_}, vol.get_surface(face_id)); + sense = to_sense(ss); + if (ss == SignedSense::on && !face_) + { + // This is the first face that we're exactly on: save it + face_ = {face_id, sense}; + } + } + else + { + // Sense is known a priori + sense = face.sense(); + } + if (is_sense_mod_set(SenseMod::flipped, sense_flags_[face_id.get()])) + { + sense = celeritas::flip_sense(sense); + } + + sense_flags_[face_id.get()] + = set_sense_mod(SenseMod::cached, sense_flags_[face_id.get()]); + sense_cache_[face_id.get()] = sense; + return sense; +} + +//---------------------------------------------------------------------------// +} // namespace detail +} // namespace celeritas diff --git a/src/orange/univ/detail/LogicEvaluator.hh b/src/orange/univ/detail/LogicEvaluator.hh index 1790f9005a..d9cf32a54b 100644 --- a/src/orange/univ/detail/LogicEvaluator.hh +++ b/src/orange/univ/detail/LogicEvaluator.hh @@ -7,6 +7,8 @@ //---------------------------------------------------------------------------// #pragma once +#include + #include "corecel/Assert.hh" #include "corecel/cont/Span.hh" #include "corecel/data/LdgIterator.hh" @@ -38,6 +40,10 @@ class LogicEvaluator // Evaluate a logical expression, substituting bools from the vector inline CELER_FUNCTION bool operator()(SpanConstSense values) const; + // Evaluate a logical expression, with on-the-fly sense evaluation + template, bool> = true> + inline CELER_FUNCTION bool operator()(F&& eval_sense) const; + private: //// DATA //// @@ -59,6 +65,17 @@ CELER_FUNCTION LogicEvaluator::LogicEvaluator(SpanConstLogic logic) * Evaluate a logical expression, substituting bools from the sense view. */ CELER_FUNCTION bool LogicEvaluator::operator()(SpanConstSense values) const +{ + auto calc_sense = [&](FaceId face_id) { return values[face_id.get()]; }; + return (*this)(calc_sense); +} + +//---------------------------------------------------------------------------// +/*! + * Evaluate a logical expression, with on-the-fly sense evaluation. + */ +template, bool>> +CELER_FUNCTION bool LogicEvaluator::operator()(F&& eval_sense) const { LogicStack stack; @@ -67,8 +84,7 @@ CELER_FUNCTION bool LogicEvaluator::operator()(SpanConstSense values) const if (!logic::is_operator_token(lgc)) { // Push a boolean from the senses onto the stack - CELER_EXPECT(lgc < values.size()); - stack.push(static_cast(values[lgc])); + stack.push(static_cast(eval_sense(FaceId{lgc}))); continue; } diff --git a/src/orange/univ/detail/Types.hh b/src/orange/univ/detail/Types.hh index 35a275511b..d28b280cd5 100644 --- a/src/orange/univ/detail/Types.hh +++ b/src/orange/univ/detail/Types.hh @@ -165,6 +165,7 @@ struct LocalState LocalVolumeId volume; OnLocalSurface surface; Span temp_sense; + Span temp_sense_mod; TempNextFace temp_next; }; diff --git a/test/orange/univ/SimpleUnitTracker.test.cc b/test/orange/univ/SimpleUnitTracker.test.cc index 8af6ab12c2..4e90261ed9 100644 --- a/test/orange/univ/SimpleUnitTracker.test.cc +++ b/test/orange/univ/SimpleUnitTracker.test.cc @@ -146,6 +146,7 @@ LocalState SimpleUnitTrackerTest::make_state(Real3 pos, Real3 dir) auto const& hsref = this->host_state(); auto face_storage = hsref.temp_face[AllItems{}]; state.temp_sense = hsref.temp_sense[AllItems{}]; + state.temp_sense_mod = hsref.temp_sense_mod[AllItems{}]; state.temp_next.face = face_storage.data(); state.temp_next.distance = hsref.temp_distance[AllItems{}].data(); diff --git a/test/orange/univ/SimpleUnitTracker.test.hh b/test/orange/univ/SimpleUnitTracker.test.hh index a79e28effc..cac6f3ff33 100644 --- a/test/orange/univ/SimpleUnitTracker.test.hh +++ b/test/orange/univ/SimpleUnitTracker.test.hh @@ -9,6 +9,7 @@ #include "corecel/Macros.hh" #include "corecel/Types.hh" #include "orange/OrangeData.hh" +#include "orange/OrangeTypes.hh" #include "orange/detail/LevelStateAccessor.hh" #include "orange/univ/SimpleUnitTracker.hh" #include "orange/univ/detail/Types.hh" @@ -60,6 +61,8 @@ inline CELER_FUNCTION LocalState build_local_state(ParamsRef params, size_type const max_faces = params.scalars.max_faces; lstate.temp_sense = states.temp_sense[build_range(max_faces, tid)]; + lstate.temp_sense_mod + = states.temp_sense_mod[build_range(max_faces, tid)]; size_type const max_isect = params.scalars.max_intersections; lstate.temp_next.face diff --git a/test/orange/univ/TrackerVisitor.test.cc b/test/orange/univ/TrackerVisitor.test.cc index 42ad1f751f..d0a81d1a6b 100644 --- a/test/orange/univ/TrackerVisitor.test.cc +++ b/test/orange/univ/TrackerVisitor.test.cc @@ -9,6 +9,7 @@ #include "orange/univ/TrackerVisitor.hh" #include "orange/OrangeGeoTestBase.hh" +#include "orange/OrangeTypes.hh" #include "orange/univ/detail/Types.hh" #include "celeritas_test.hh" @@ -48,6 +49,7 @@ detail::LocalState TrackerVisitorTest::make_state(Real3 pos, Real3 dir) auto const& hsref = this->host_state(); auto face_storage = hsref.temp_face[AllItems{}]; state.temp_sense = hsref.temp_sense[AllItems{}]; + state.temp_sense_mod = hsref.temp_sense_mod[AllItems{}]; state.temp_next.face = face_storage.data(); state.temp_next.distance = hsref.temp_distance[AllItems{}].data();