Skip to content

Commit

Permalink
pass sense mod flags to the sense calculator
Browse files Browse the repository at this point in the history
  • Loading branch information
esseivaju committed Dec 10, 2024
1 parent 4fedccf commit 2df2f88
Show file tree
Hide file tree
Showing 8 changed files with 78 additions and 9 deletions.
4 changes: 4 additions & 0 deletions src/orange/OrangeData.hh
Original file line number Diff line number Diff line change
Expand Up @@ -473,6 +473,7 @@ struct OrangeStateData

// Scratch space with dimensions {track}{max_faces}
Items<Sense> temp_sense;
Items<SenseModFlags> temp_sense_mod;

// Scratch space with dimensions {track}{max_intersections}
Items<FaceId> temp_face;
Expand Down Expand Up @@ -500,6 +501,7 @@ struct OrangeStateData
&& vol.size() == max_depth * this->size()
&& universe.size() == max_depth * this->size()
&& !temp_sense.empty()
&& !temp_sense_mod.empty()
&& !temp_face.empty()
&& temp_distance.size() == temp_face.size()
&& temp_isect.size() == temp_face.size();
Expand Down Expand Up @@ -533,6 +535,7 @@ struct OrangeStateData
universe = other.universe;

temp_sense = other.temp_sense;
temp_sense_mod = other.temp_sense_mod;

temp_face = other.temp_face;
temp_distance = other.temp_distance;
Expand Down Expand Up @@ -576,6 +579,7 @@ inline void resize(OrangeStateData<Ownership::value, M>* data,

size_type face_states = params.scalars.max_faces * num_tracks;
resize(&data->temp_sense, face_states);
resize(&data->temp_sense_mod, face_states);

size_type isect_states = params.scalars.max_intersections * num_tracks;
resize(&data->temp_face, isect_states);
Expand Down
16 changes: 16 additions & 0 deletions src/orange/OrangeTrackView.hh
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,7 @@ class OrangeTrackView

// Create local sense reference
inline CELER_FUNCTION Span<Sense> make_temp_sense() const;
inline CELER_FUNCTION Span<SenseModFlags> make_temp_sense_mod() const;

// Create local distance
inline CELER_FUNCTION detail::TempNextFace make_temp_next() const;
Expand Down Expand Up @@ -276,6 +277,7 @@ OrangeTrackView::operator=(Initializer_t const& init)
local.volume = {};
local.surface = {};
local.temp_sense = this->make_temp_sense();
local.temp_sense_mod = this->make_temp_sense_mod();

// Helpers for applying parent-to-daughter transformations
TransformVisitor apply_transform{params_};
Expand Down Expand Up @@ -689,6 +691,7 @@ CELER_FUNCTION void OrangeTrackView::cross_boundary()
local.volume = lsa.vol();
local.surface = {this->surf(), this->sense()};
local.temp_sense = this->make_temp_sense();
local.temp_sense_mod = this->make_temp_sense_mod();
}

TrackerVisitor visit_tracker{params_};
Expand Down Expand Up @@ -1047,6 +1050,18 @@ CELER_FUNCTION Span<Sense> OrangeTrackView::make_temp_sense() const
offset, max_faces);
}

//---------------------------------------------------------------------------//
/*!
* Get a reference to the current volume, or to world volume if outside.
*/
CELER_FUNCTION Span<SenseModFlags> OrangeTrackView::make_temp_sense_mod() const
{
auto const max_faces = params_.scalars.max_faces;
auto offset = track_slot_.get() * max_faces;
return states_.temp_sense_mod[AllItems<SenseModFlags, MemSpace::native>{}]
.subspan(offset, max_faces);
}

//---------------------------------------------------------------------------//
/*!
* Set up intersection scratch space.
Expand Down Expand Up @@ -1089,6 +1104,7 @@ OrangeTrackView::make_local_state(LevelId level) const
local.surface = {};
}
local.temp_sense = this->make_temp_sense();
local.temp_sense_mod = this->make_temp_sense_mod();
local.temp_next = this->make_temp_next();
return local;
}
Expand Down
31 changes: 31 additions & 0 deletions src/orange/OrangeTypes.hh
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,17 @@ enum class Sense : bool
outside, //!< Expression is greater than zero
};

//---------------------------------------------------------------------------//
/*!
* Transformations to apply to senses when using lazy sense evaluation.
*/
enum class SenseMod
{
normal = 0,
flipped = 1 << 0,
};
using SenseModFlags = std::underlying_type_t<SenseMod>;

//---------------------------------------------------------------------------//
/*!
* Enumeration for mapping surface classes to integers.
Expand Down Expand Up @@ -378,6 +389,26 @@ CELER_CONSTEXPR_FUNCTION Sense to_sense(bool s)
return static_cast<SignedSense>(-static_cast<IntT>(orig));
}

//---------------------------------------------------------------------------//
/*!
* Check if a sense modifier is set.
*/
[[nodiscard]] CELER_CONSTEXPR_FUNCTION bool
is_sense_mod_set(SenseMod mod, SenseModFlags flags)
{
return (flags & static_cast<SenseModFlags>(mod)) != 0;
}

//---------------------------------------------------------------------------//
/*!
* Set a sense modifier.
*/
[[nodiscard]] CELER_CONSTEXPR_FUNCTION SenseModFlags
set_sense_mod(SenseMod mod, SenseModFlags flags)
{
return flags | static_cast<SenseModFlags>(mod);
}

//---------------------------------------------------------------------------//
/*!
* Change whether a boundary crossing is reentrant or exiting.
Expand Down
10 changes: 5 additions & 5 deletions src/orange/univ/SimpleUnitTracker.hh
Original file line number Diff line number Diff line change
Expand Up @@ -167,8 +167,8 @@ SimpleUnitTracker::initialize(LocalState const& state) const -> Initialization
CELER_EXPECT(params_);
CELER_EXPECT(!state.surface && !state.volume);

detail::LazySenseCalculator calc_senses(this->make_surface_visitor(),
state.pos);
detail::LazySenseCalculator calc_senses(
this->make_surface_visitor(), state.pos, state.temp_sense_mod);

// Use the BIH to locate a position that's inside, and save whether it's on
// a surface in the found volume
Expand Down Expand Up @@ -206,8 +206,8 @@ SimpleUnitTracker::cross_boundary(LocalState const& state) const -> Initializati
detail::OnLocalSurface on_surface;
auto is_inside
= [this, &state, &on_surface](LocalVolumeId const& id) -> bool {
detail::LazySenseCalculator calc_senses(this->make_surface_visitor(),
state.pos);
detail::LazySenseCalculator calc_senses(
this->make_surface_visitor(), state.pos, state.temp_sense_mod);
if (id == state.volume)
{
// Cannot cross surface into the same volume
Expand Down Expand Up @@ -622,7 +622,7 @@ CELER_FUNCTION auto SimpleUnitTracker::background_intersect(
CELER_ASSERT(vid != state.volume);
VolumeView vol = this->make_local_volume(vid);
auto calc_senses = detail::LazySenseCalculator{
this->make_surface_visitor(), pos};
this->make_surface_visitor(), pos, state.temp_sense_mod};
auto bind_calc_sense
= [&](FaceId face_id) { return calc_senses(vol, face_id); };

Expand Down
21 changes: 17 additions & 4 deletions src/orange/univ/detail/LazySenseCalculator.hh
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,9 @@ class LazySenseCalculator
{
public:
// Construct from persistent, current, and temporary data
inline CELER_FUNCTION
LazySenseCalculator(LocalSurfaceVisitor const& visit, Real3 const& pos);
inline CELER_FUNCTION LazySenseCalculator(LocalSurfaceVisitor const& visit,
Real3 const& pos,
Span<SenseModFlags> sense_mod);

// Calculate senses for a single face of the given volume, possibly on a
// face
Expand All @@ -52,6 +53,9 @@ class LazySenseCalculator

//! Local position
Real3 pos_;

//! Temporary senses
Span<SenseModFlags> sense_storage_;
};

//---------------------------------------------------------------------------//
Expand All @@ -62,9 +66,14 @@ class LazySenseCalculator
*/
CELER_FUNCTION
LazySenseCalculator::LazySenseCalculator(LocalSurfaceVisitor const& visit,
Real3 const& pos)
: visit_{visit}, pos_(pos)
Real3 const& pos,
Span<SenseModFlags> sense_mod)
: visit_{visit}, pos_(pos), sense_storage_{sense_mod}
{
for (auto& sense : sense_storage_)
{
sense = static_cast<SenseModFlags>(SenseMod::normal);
}
}

//---------------------------------------------------------------------------//
Expand Down Expand Up @@ -102,6 +111,10 @@ CELER_FUNCTION auto LazySenseCalculator::operator()(VolumeView const& vol,
// Sense is known a priori
sense = face.sense();
}
if (is_sense_mod_set(SenseMod::flipped, sense_storage_[face_id.get()]))
{
sense = flip_sense(sense);
}

return sense;
}
Expand Down
1 change: 1 addition & 0 deletions src/orange/univ/detail/Types.hh
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,7 @@ struct LocalState
LocalVolumeId volume;
OnLocalSurface surface;
Span<Sense> temp_sense;
Span<SenseModFlags> temp_sense_mod;
TempNextFace temp_next;
};

Expand Down
1 change: 1 addition & 0 deletions test/orange/univ/SimpleUnitTracker.test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,7 @@ LocalState SimpleUnitTrackerTest::make_state(Real3 pos, Real3 dir)
auto const& hsref = this->host_state();
auto face_storage = hsref.temp_face[AllItems<FaceId>{}];
state.temp_sense = hsref.temp_sense[AllItems<Sense>{}];
state.temp_sense_mod = hsref.temp_sense_mod[AllItems<SenseModFlags>{}];
state.temp_next.face = face_storage.data();
state.temp_next.distance
= hsref.temp_distance[AllItems<real_type>{}].data();
Expand Down
3 changes: 3 additions & 0 deletions test/orange/univ/SimpleUnitTracker.test.hh
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
#include "corecel/Macros.hh"
#include "corecel/Types.hh"
#include "orange/OrangeData.hh"
#include "orange/OrangeTypes.hh"
#include "orange/detail/LevelStateAccessor.hh"
#include "orange/univ/SimpleUnitTracker.hh"
#include "orange/univ/detail/Types.hh"
Expand Down Expand Up @@ -60,6 +61,8 @@ inline CELER_FUNCTION LocalState build_local_state(ParamsRef<M> params,

size_type const max_faces = params.scalars.max_faces;
lstate.temp_sense = states.temp_sense[build_range<Sense>(max_faces, tid)];
lstate.temp_sense_mod
= states.temp_sense_mod[build_range<SenseModFlags>(max_faces, tid)];

size_type const max_isect = params.scalars.max_intersections;
lstate.temp_next.face
Expand Down

0 comments on commit 2df2f88

Please sign in to comment.