diff --git a/src/idmlaser_cholera/cholera.py b/src/idmlaser_cholera/cholera.py index 74190f1..abc10cb 100644 --- a/src/idmlaser_cholera/cholera.py +++ b/src/idmlaser_cholera/cholera.py @@ -138,8 +138,8 @@ model.metrics = [] for tick in tqdm(range(model.params.ticks)): """ - if tick == 365: - model.population.save( filename="laser_cache/burnin_cholera.h5", initial_populations=initial_populations, age_distribution=age_init.age_distribution, cumulative_deaths=cumulative_deaths) + if tick == 1: + model.save( filename="laser_cache/burnin_cholera.h5" ) """ #""" diff --git a/src/idmlaser_cholera/mods/age_init.py b/src/idmlaser_cholera/mods/age_init.py index 3597df3..607cbeb 100644 --- a/src/idmlaser_cholera/mods/age_init.py +++ b/src/idmlaser_cholera/mods/age_init.py @@ -4,6 +4,36 @@ import laser_core.demographics.pyramid as pyramid import pdb +# age_init.py +import json # Assuming the age data is stored in a JSON file; adjust as necessary. + +# This get way more involved than I wanted when I needed to load the age_distribution from the user-specifiec path early enough to check whether the +# cached models match the distribution before creating a new model. + +class AgeDataManager: + def __init__(self): + self._path = None + self._data = None + + def get_data(self): + """Load and return the age data.""" + if self._data is None: + if self._path is None: + raise ValueError("Path not set. Please use set_path to specify a file path.") + self._data = pyramid.load_pyramid_csv(Path(self._path)) + return self._data + + def set_path(self, path): + """Set the file path for the age data.""" + if not isinstance(path, str): + raise TypeError("Path must be a string.") + self._path = path + self._data = None # Clear any previously loaded data + +# Singleton instance for module-level access +age_data_manager = AgeDataManager() + + # ## Non-Disease Mortality # ### Part I # @@ -15,11 +45,14 @@ def init( model, manifest ): print(f"Loading pyramid from '{manifest.age_data}'...") - #with open( manifest.age_data, 'r' ) as pyramid_file: # Convert it to a string if needed - age_distribution = pyramid.load_pyramid_csv(Path(manifest.age_data)) + age_distribution = age_data_manager.get_data() #initial_populations = model.nodes.population[:,0] + if model.nodes is None: + raise RuntimeError( "nodes does not seem to be initialized in model object." ) + if model.nodes.population is None: + raise RuntimeError( "nodes.population does not seem to be initialized in model object." ) initial_populations = model.nodes.population[0] capacity = model.population.capacity diff --git a/src/idmlaser_cholera/mymodel.py b/src/idmlaser_cholera/mymodel.py index 4a0f24a..0dd8349 100644 --- a/src/idmlaser_cholera/mymodel.py +++ b/src/idmlaser_cholera/mymodel.py @@ -22,6 +22,7 @@ def __init__(self, params): self.manifest_path = os.path.join(self.params.input_dir, "manifest.py") self.manifest = None self._load_manifest() + age_init.age_data_manager.set_path( self.manifest.age_data ) @classmethod def get(cls, params): @@ -32,11 +33,18 @@ def get(cls, params): model = cls(params) # Create a Model instance if model._check_for_cached(): print("*\nFound cached file. Using it.\n*") - # Logic for using cached data can be placed here else: model._init_from_data() # Initialize from data if no cache found return model + def save( self, filename ): + if age_init.age_data_manager.get_data() is None: + raise ValueError( f"age_distribution uninitialized while saving" ) + self.population.save( filename=filename, + initial_populations=self.initial_populations, + age_distribution=age_init.age_data_manager.get_data(), + cumulative_deaths=cumulative_deaths + ) def _load_manifest(self): """Load the manifest module if it exists.""" if os.path.isfile(self.manifest_path): @@ -133,6 +141,10 @@ def _init_from_file(self, filename): self.population = Population.load(filename) self._extend_capacity_after_loading() self._save_pops_in_nodes() + # We aren't yet storing the initial infections in the cached file so recreating on reload + self.nodes.initial_infections = np.uint32( + np.round(np.random.poisson(self.params.prevalence * self.initial_populations)) + ) def _extend_capacity_after_loading(self): """Extend the population capacity after loading.""" @@ -148,10 +160,12 @@ def _check_for_cached(self): for filename in os.listdir(hdf5_directory): if filename.endswith(".h5"): hdf5_filepath = os.path.join(hdf5_directory, filename) + if age_init.age_data_manager.get_data() is None: + raise RuntimeError( "age_init.age_distribution seems to None while caching" ) cached = check_hdf5_attributes( hdf5_filename=hdf5_filepath, initial_populations=self.initial_populations, - age_distribution=age_init.age_distribution, + age_distribution=age_init.age_data_manager.get_data(), cumulative_deaths=cumulative_deaths, ) if cached: diff --git a/src/idmlaser_cholera/numpynumba/population.py b/src/idmlaser_cholera/numpynumba/population.py index e01d1a6..2304a65 100644 --- a/src/idmlaser_cholera/numpynumba/population.py +++ b/src/idmlaser_cholera/numpynumba/population.py @@ -124,7 +124,8 @@ def save(self, filename: str, tail_number=0, initial_populations=None, age_distr with h5py.File(filename, 'w') as hdf: hdf.attrs['count'] = self._count hdf.attrs['capacity'] = self._capacity - hdf.attrs['node_count'] = self.node_count + hdf.attrs['node_count'] = len(initial_populations) + print( "TBD: Need to derive node count since we don't have it here." ) if initial_populations is not None: hdf.attrs['init_pops'] = initial_populations if age_distribution is not None: @@ -163,7 +164,7 @@ def save_npz(self, filename: str, tail_number=0) -> None: @staticmethod def load(filename: str) -> None: def load_hdf5( filename ): - population = Population(0) # We'll do capacity automatically + population = ExtendedLF(0) # We'll do capacity automatically """Load the population properties from an HDF5 file""" with h5py.File(filename, 'r') as hdf: population._count = hdf.attrs['count']