Skip to content

Commit

Permalink
Merge pull request #9 from p-ortega/hotfix/saturation
Browse files Browse the repository at this point in the history
Hotfix/saturation add saturation calcs

YOLO
  • Loading branch information
p-ortega authored Sep 20, 2024
2 parents 13b5471 + ec9e049 commit 76bb212
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 8 deletions.
1 change: 0 additions & 1 deletion autotest/test_benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -1087,7 +1087,6 @@ def test_yaml(prefix):
testdf = pd.read_csv(os.path.join(cwd, wd,f"sout.csv"), index_col = 0)
compare_results(benchmarkdf, testdf)


@pytest.mark.skip
def test01_yaml(prefix = 'test01'):

Expand Down
52 changes: 45 additions & 7 deletions src/mf6rtm/mf6rtm.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,7 @@ class PhreeqcBMI(phreeqcrm.BMIPhreeqcRM):
def __init__(self, yaml="mf6rtm.yaml"):
phreeqcrm.BMIPhreeqcRM.__init__(self)
self.initialize(yaml)
self.sat_now = None

def get_grid_to_map(self):
'''Function to get grid to map
Expand Down Expand Up @@ -158,14 +159,18 @@ def set_scalar(self, var_name, value):
def _solve_phreeqcrm(self, dt, diffmask):
'''Function to solve phreeqc bmi
'''

# status = phreeqc_rm.SetTemperature([self.init_temp[0]] * self.ncpl)
# status = phreeqc_rm.SetPressure([2.0] * nxyz)
self.SetTimeStep(dt*1.0/self.GetTimeConversion())

# update which cells to run depending on conc change between tsteps
sat = [1]*self.GetGridCellCount()
if self.sat_now is not None:
sat = self.sat_now
else:
sat = [1]*self.GetGridCellCount()

self.SetSaturation(sat)

# update which cells to run depending on conc change between tsteps
if diffmask is not None:
# get idx where diffmask is 0
inact = get_indices(0, diffmask)
Expand Down Expand Up @@ -211,17 +216,29 @@ def __init__(self, wd, dll):
modflowapi.ModflowApi.__init__(self, dll, working_directory=wd)
self.initialize()
self.sim = flopy.mf6.MFSimulation.load(sim_ws=wd, verbosity_level=0)
self.fmi = False

def _prepare_mf6(self):
'''Prepare mf6 bmi for transport calculations
'''
self.modelnmes = ['Flow'] + [nme.capitalize() for nme in self.sim.model_names if nme != 'gwf']
self.modelnmes = [nme.capitalize() for nme in self.sim.model_names]
self.components = [nme.capitalize() for nme in self.sim.model_names[1:]]
self.nsln = self.get_subcomponent_count()
self.sim_start = datetime.now()
self.ctimes = [0.0]
self.num_fails = 0

def _check_fmi(self):
'''Check if fmi is in the nam file
'''
...
return

def _set_simtype_gwt(self):
'''Set the gwt sim type as sequential or flow interface
'''
...

def _solve_gwt(self):
'''Function to solve the transport loop
'''
Expand Down Expand Up @@ -291,6 +308,26 @@ def __init__(self, wd, mf6api, phreeqcbmi):
# set time conversion factor
self.set_time_conversion()

def get_saturation_from_mf6(self):
"""
Get the saturation
Parameters
----------
mf6 (modflowapi): the modflow api object
Returns
-------
array: the saturation
"""
sat = {component: self.mf6api.get_value(
self.mf6api.get_var_address("FMI/GWFSAT", f'{component}')
) for component in self.phreeqcbmi.components}
# select the first component to get the length of the array
sat = sat[self.phreeqcbmi.components[0]] # saturation is the same for all components
self.phreeqcbmi.sat_now = sat # set phreeqcmbi saturation
return sat

def get_time_units_from_mf6(self):
'''Function to get the time units from mf6
'''
Expand Down Expand Up @@ -467,8 +504,6 @@ def _update_selected_output(self):
def __replace_inactive_cells_in_sout(self, sout, diffmask):
'''Function to replace inactive cells in the selected output dataframe
'''
components = self.mf6api.modelnmes[1:]
headers = self.phreeqcbmi.sout_headers
# match headers in components closest string

inactive_idx = get_indices(0, diffmask)
Expand Down Expand Up @@ -545,7 +580,7 @@ def _solve(self):
assert self._check_sout_exist(), f'{self.sout_fname} not found'

print("Starting transport solution at {0}".format(sim_start.strftime(DT_FMT)))
print(f"Solving the following components: {', '.join([nme for nme in self.mf6api.modelnmes])}")
# print(f"Solving the following components: {', '.join([nme for nme in self.mf6api.modelnmes])}")
ctime = self._set_ctime()
etime = self._set_etime()
while ctime < etime:
Expand All @@ -556,6 +591,9 @@ def _solve(self):
self.mf6api.prepare_time_step(dt)
self.mf6api._solve_gwt()

# get saturation
self.get_saturation_from_mf6()

if self.reactive:
# reaction block
c_dbl_vect = self._transfer_array_to_phreeqcrm()
Expand Down

0 comments on commit 76bb212

Please sign in to comment.