Skip to content

Commit

Permalink
Pickle serialization for dataclasses
Browse files Browse the repository at this point in the history
  • Loading branch information
austinschneider committed Oct 22, 2024
1 parent 57589d2 commit 5bcb2ef
Showing 1 changed file with 23 additions and 0 deletions.
23 changes: 23 additions & 0 deletions projects/dataclasses/private/pybindings/dataclasses.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,9 @@
#include "../../public/SIREN/dataclasses/InteractionSignature.h"
#include "../../public/SIREN/dataclasses/InteractionRecord.h"
#include "../../public/SIREN/dataclasses/InteractionTree.h"
#include "../../../serialization/public/SIREN/serialization/ByteString.h"

#include "SIREN/dataclasses/serializable.h"

#include <pybind11/pybind11.h>
#include <pybind11/operators.h>
Expand Down Expand Up @@ -52,6 +55,10 @@ PYBIND11_MODULE(dataclasses, m) {
.def_readwrite("length",&Particle::length)
.def_readwrite("helicity",&Particle::helicity)
.def("generate_id",&Particle::GenerateID)
.def(pybind11::pickle(
&(siren::serialization::pickle_save<Particle>),
&(siren::serialization::pickle_load<Particle>)
))
;

py::enum_<ParticleType>(m, "ParticleType", py::arithmetic())
Expand All @@ -67,6 +74,10 @@ PYBIND11_MODULE(dataclasses, m) {
.def_readwrite("primary_type",&InteractionSignature::primary_type)
.def_readwrite("target_type",&InteractionSignature::target_type)
.def_readwrite("secondary_types",&InteractionSignature::secondary_types)
.def(pybind11::pickle(
&(siren::serialization::pickle_save<InteractionSignature>),
&(siren::serialization::pickle_load<InteractionSignature>)
))
;

py::class_<PrimaryDistributionRecord, std::shared_ptr<PrimaryDistributionRecord>>(m, "PrimaryDistributionRecord")
Expand Down Expand Up @@ -174,6 +185,10 @@ PYBIND11_MODULE(dataclasses, m) {
.def_readwrite("secondary_momenta",&InteractionRecord::secondary_momenta)
.def_readwrite("secondary_helicities",&InteractionRecord::secondary_helicities)
.def_readwrite("interaction_parameters",&InteractionRecord::interaction_parameters)
.def(pybind11::pickle(
&(siren::serialization::pickle_save<InteractionRecord>),
&(siren::serialization::pickle_load<InteractionRecord>)
))
;

py::class_<InteractionTreeDatum, std::shared_ptr<InteractionTreeDatum>>(m, "InteractionTreeDatum")
Expand All @@ -182,13 +197,21 @@ PYBIND11_MODULE(dataclasses, m) {
.def_readwrite("parent",&InteractionTreeDatum::parent)
.def_readwrite("daughters",&InteractionTreeDatum::daughters)
.def("depth",&InteractionTreeDatum::depth)
.def(pybind11::pickle(
&(siren::serialization::pickle_save<InteractionTreeDatum>),
&(siren::serialization::pickle_load<InteractionTreeDatum>)
))
;

py::class_<InteractionTree, std::shared_ptr<InteractionTree>>(m, "InteractionTree")
.def(py::init<>())
.def_readwrite("tree",&InteractionTree::tree)
.def("add_entry",static_cast<std::shared_ptr<InteractionTreeDatum> (InteractionTree::*)(InteractionTreeDatum&,std::shared_ptr<InteractionTreeDatum>)>(&InteractionTree::add_entry))
.def("add_entry",static_cast<std::shared_ptr<InteractionTreeDatum> (InteractionTree::*)(InteractionRecord&,std::shared_ptr<InteractionTreeDatum>)>(&InteractionTree::add_entry))
.def(pybind11::pickle(
&(siren::serialization::pickle_save<InteractionTree>),
&(siren::serialization::pickle_load<InteractionTree>)
))
;

m.def("SaveInteractionTrees",&SaveInteractionTrees);
Expand Down

0 comments on commit 5bcb2ef

Please sign in to comment.