diff --git a/projects/interactions/private/InteractionCollection.cxx b/projects/interactions/private/InteractionCollection.cxx index 9a3464a7..af6f0e41 100644 --- a/projects/interactions/private/InteractionCollection.cxx +++ b/projects/interactions/private/InteractionCollection.cxx @@ -116,6 +116,14 @@ bool InteractionCollection::MatchesPrimary(dataclasses::InteractionRecord const return primary_type == record.signature.primary_type; } +siren::dataclasses::ParticleType InteractionCollection::GetPrimaryType() const { + return primary_type; +} + +void InteractionCollection::SetPrimaryType(siren::dataclasses::ParticleType primary_type) { + this->primary_type = primary_type; +} + std::map InteractionCollection::TotalCrossSectionByTarget(siren::dataclasses::InteractionRecord const & record) const { std::map result; for(siren::dataclasses::ParticleType target : target_types) { diff --git a/projects/interactions/private/pybindings/InteractionCollection.h b/projects/interactions/private/pybindings/InteractionCollection.h index 5d8314ac..1df42439 100644 --- a/projects/interactions/private/pybindings/InteractionCollection.h +++ b/projects/interactions/private/pybindings/InteractionCollection.h @@ -35,5 +35,7 @@ void register_InteractionCollection(pybind11::module_ & m) { .def("TotalDecayWidth",&InteractionCollection::TotalDecayWidth) .def("TotalDecayLength",&InteractionCollection::TotalDecayLength) .def("MatchesPrimary",&InteractionCollection::MatchesPrimary) + .def("GetPrimaryType",&InteractionCollection::GetPrimaryType) + .def("SetPrimaryType",&InteractionCollection::SetPrimaryType) ; } diff --git a/projects/interactions/public/SIREN/interactions/InteractionCollection.h b/projects/interactions/public/SIREN/interactions/InteractionCollection.h index 243b4114..ecab5465 100644 --- a/projects/interactions/public/SIREN/interactions/InteractionCollection.h +++ b/projects/interactions/public/SIREN/interactions/InteractionCollection.h @@ -63,6 +63,8 @@ class InteractionCollection { }; double TotalDecayWidth(siren::dataclasses::InteractionRecord const & record) const; double TotalDecayLength(siren::dataclasses::InteractionRecord const & record) const; + siren::dataclasses::ParticleType GetPrimaryType() const; + void SetPrimaryType(siren::dataclasses::ParticleType primary_type); virtual bool MatchesPrimary(dataclasses::InteractionRecord const & record) const; std::map TotalCrossSectionByTarget(siren::dataclasses::InteractionRecord const & record) const; std::map TotalCrossSectionByTargetAllFinalStates(siren::dataclasses::InteractionRecord const & record) const;