From df90ce568cd5483f761bbe3a0a0e3f3b1ebff397 Mon Sep 17 00:00:00 2001 From: Austin Schneider Date: Mon, 21 Oct 2024 22:31:22 -0400 Subject: [PATCH] Pickle weighter --- projects/injection/private/pybindings/injection.cxx | 4 ++++ projects/injection/public/SIREN/injection/Weighter.h | 7 ++++++- 2 files changed, 10 insertions(+), 1 deletion(-) diff --git a/projects/injection/private/pybindings/injection.cxx b/projects/injection/private/pybindings/injection.cxx index 5b8b904f..8cb3232b 100644 --- a/projects/injection/private/pybindings/injection.cxx +++ b/projects/injection/private/pybindings/injection.cxx @@ -155,6 +155,10 @@ PYBIND11_MODULE(injection,m) { .def("EventWeight",&Weighter::EventWeight) .def("SaveWeighter",&Weighter::SaveWeighter) .def("LoadWeighter",&Weighter::LoadWeighter) + .def(pybind11::pickle( + &(siren::serialization::pickle_save), + &(siren::serialization::pickle_load) + )) ; } diff --git a/projects/injection/public/SIREN/injection/Weighter.h b/projects/injection/public/SIREN/injection/Weighter.h index 93e845cb..4b8daf85 100644 --- a/projects/injection/public/SIREN/injection/Weighter.h +++ b/projects/injection/public/SIREN/injection/Weighter.h @@ -102,12 +102,17 @@ class Weighter { } template - void load(Archive & archive, std::uint32_t const version) const { + static void load_and_construct(Archive & archive, cereal::construct & construct, std::uint32_t const version) { if(version == 0) { + std::vector> injectors; + std::shared_ptr detector_model; + std::shared_ptr primary_physical_process; + std::vector> secondary_physical_processes; archive(::cereal::make_nvp("Injectors", injectors)); archive(::cereal::make_nvp("DetectorModel", detector_model)); archive(::cereal::make_nvp("PrimaryPhysicalProcess", primary_physical_process)); archive(::cereal::make_nvp("SecondaryPhysicalProcesses", secondary_physical_processes)); + construct(injectors, detector_model, primary_physical_process, secondary_physical_processes); } else { throw std::runtime_error("Weighter only supports version <= 0!"); }