Skip to content

Commit

Permalink
Got caching and reloading working again after refactor.
Browse files Browse the repository at this point in the history
  • Loading branch information
Jonathan Bloedow committed Nov 22, 2024
1 parent 008c25b commit 2aca074
Show file tree
Hide file tree
Showing 4 changed files with 56 additions and 8 deletions.
4 changes: 2 additions & 2 deletions src/idmlaser_cholera/cholera.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,8 +138,8 @@
model.metrics = []
for tick in tqdm(range(model.params.ticks)):
"""
if tick == 365:
model.population.save( filename="laser_cache/burnin_cholera.h5", initial_populations=initial_populations, age_distribution=age_init.age_distribution, cumulative_deaths=cumulative_deaths)
if tick == 1:
model.save( filename="laser_cache/burnin_cholera.h5" )
"""
#"""
Expand Down
37 changes: 35 additions & 2 deletions src/idmlaser_cholera/mods/age_init.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,36 @@
import laser_core.demographics.pyramid as pyramid
import pdb

# age_init.py
import json # Assuming the age data is stored in a JSON file; adjust as necessary.

# This got way more involved than I wanted when I needed to load the age_distribution from the user-specified path early enough to check whether the
# cached models match the distribution before creating a new model.

class AgeDataManager:
    """Lazily load and cache the age-pyramid data from a configurable path.

    The path is supplied through :meth:`set_path`; the file is only read on
    the first :meth:`get_data` call and the loaded result is reused afterwards.
    """

    def __init__(self):
        self._path = None  # location of the pyramid CSV; unset until set_path()
        self._data = None  # cached result of pyramid.load_pyramid_csv()

    def get_data(self):
        """Load (on first use) and return the age data.

        Returns:
            The value produced by ``pyramid.load_pyramid_csv`` for the
            configured path.

        Raises:
            ValueError: If no path has been set yet.
        """
        if self._data is None:
            if self._path is None:
                raise ValueError("Path not set. Please use set_path to specify a file path.")
            self._data = pyramid.load_pyramid_csv(Path(self._path))
        return self._data

    def set_path(self, path):
        """Set the file path for the age data and discard any cached data.

        Args:
            path: Location of the pyramid CSV, as a ``str`` or a
                ``pathlib.Path``. ``get_data`` wraps the value in ``Path``
                anyway, so both forms are accepted here.

        Raises:
            TypeError: If ``path`` is neither a string nor a ``pathlib.Path``.
        """
        if not isinstance(path, (str, Path)):
            raise TypeError("Path must be a string or a pathlib.Path.")
        self._path = path
        self._data = None  # Clear any previously loaded data

# Singleton instance for module-level access: callers configure it once with
# set_path() and then share the lazily loaded data through get_data().
age_data_manager = AgeDataManager()


# ## Non-Disease Mortality
# ### Part I
#
Expand All @@ -15,11 +45,14 @@
def init( model, manifest ):

print(f"Loading pyramid from '{manifest.age_data}'...")
#with open( manifest.age_data, 'r' ) as pyramid_file:
# Convert it to a string if needed
age_distribution = pyramid.load_pyramid_csv(Path(manifest.age_data))
age_distribution = age_data_manager.get_data()

#initial_populations = model.nodes.population[:,0]
if model.nodes is None:
raise RuntimeError( "nodes does not seem to be initialized in model object." )
if model.nodes.population is None:
raise RuntimeError( "nodes.population does not seem to be initialized in model object." )
initial_populations = model.nodes.population[0]
capacity = model.population.capacity

Expand Down
18 changes: 16 additions & 2 deletions src/idmlaser_cholera/mymodel.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ def __init__(self, params):
self.manifest_path = os.path.join(self.params.input_dir, "manifest.py")
self.manifest = None
self._load_manifest()
age_init.age_data_manager.set_path( self.manifest.age_data )

@classmethod
def get(cls, params):
Expand All @@ -32,11 +33,18 @@ def get(cls, params):
model = cls(params) # Create a Model instance
if model._check_for_cached():
print("*\nFound cached file. Using it.\n*")
# Logic for using cached data can be placed here
else:
model._init_from_data() # Initialize from data if no cache found
return model

def save(self, filename):
    """Save the population to ``filename`` together with its identifying attributes.

    The age distribution is taken from the shared ``age_init.age_data_manager``;
    ``cumulative_deaths`` is a name defined outside this method.

    Args:
        filename: Destination passed straight through to ``self.population.save``.

    Raises:
        ValueError: If the age data manager returns no age distribution.
    """
    age_distribution = age_init.age_data_manager.get_data()
    if age_distribution is None:
        raise ValueError("age_distribution uninitialized while saving")

    # These are the same keyword attributes that _check_for_cached passes to
    # check_hdf5_attributes when looking for a reusable cache file.
    self.population.save(
        filename=filename,
        initial_populations=self.initial_populations,
        age_distribution=age_distribution,
        cumulative_deaths=cumulative_deaths,
    )
def _load_manifest(self):
"""Load the manifest module if it exists."""
if os.path.isfile(self.manifest_path):
Expand Down Expand Up @@ -133,6 +141,10 @@ def _init_from_file(self, filename):
self.population = Population.load(filename)
self._extend_capacity_after_loading()
self._save_pops_in_nodes()
# We aren't yet storing the initial infections in the cached file so recreating on reload
self.nodes.initial_infections = np.uint32(
np.round(np.random.poisson(self.params.prevalence * self.initial_populations))
)

def _extend_capacity_after_loading(self):
"""Extend the population capacity after loading."""
Expand All @@ -148,10 +160,12 @@ def _check_for_cached(self):
for filename in os.listdir(hdf5_directory):
if filename.endswith(".h5"):
hdf5_filepath = os.path.join(hdf5_directory, filename)
if age_init.age_data_manager.get_data() is None:
raise RuntimeError( "age_init.age_distribution seems to None while caching" )
cached = check_hdf5_attributes(
hdf5_filename=hdf5_filepath,
initial_populations=self.initial_populations,
age_distribution=age_init.age_distribution,
age_distribution=age_init.age_data_manager.get_data(),
cumulative_deaths=cumulative_deaths,
)
if cached:
Expand Down
5 changes: 3 additions & 2 deletions src/idmlaser_cholera/numpynumba/population.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,8 @@ def save(self, filename: str, tail_number=0, initial_populations=None, age_distr
with h5py.File(filename, 'w') as hdf:
hdf.attrs['count'] = self._count
hdf.attrs['capacity'] = self._capacity
hdf.attrs['node_count'] = self.node_count
hdf.attrs['node_count'] = len(initial_populations)
print( "TBD: Need to derive node count since we don't have it here." )
if initial_populations is not None:
hdf.attrs['init_pops'] = initial_populations
if age_distribution is not None:
Expand Down Expand Up @@ -163,7 +164,7 @@ def save_npz(self, filename: str, tail_number=0) -> None:
@staticmethod
def load(filename: str) -> None:
def load_hdf5( filename ):
population = Population(0) # We'll do capacity automatically
population = ExtendedLF(0) # We'll do capacity automatically
"""Load the population properties from an HDF5 file"""
with h5py.File(filename, 'r') as hdf:
population._count = hdf.attrs['count']
Expand Down

0 comments on commit 2aca074

Please sign in to comment.