diff --git a/autotest/test_benchmark.py b/autotest/test_benchmark.py index 79c22ea..59f6caa 100644 --- a/autotest/test_benchmark.py +++ b/autotest/test_benchmark.py @@ -1087,7 +1087,6 @@ def test_yaml(prefix): testdf = pd.read_csv(os.path.join(cwd, wd,f"sout.csv"), index_col = 0) compare_results(benchmarkdf, testdf) - @pytest.mark.skip def test01_yaml(prefix = 'test01'): diff --git a/src/mf6rtm/mf6rtm.py b/src/mf6rtm/mf6rtm.py index f867519..025918f 100644 --- a/src/mf6rtm/mf6rtm.py +++ b/src/mf6rtm/mf6rtm.py @@ -116,6 +116,7 @@ class PhreeqcBMI(phreeqcrm.BMIPhreeqcRM): def __init__(self, yaml="mf6rtm.yaml"): phreeqcrm.BMIPhreeqcRM.__init__(self) self.initialize(yaml) + self.sat_now = None def get_grid_to_map(self): '''Function to get grid to map @@ -158,14 +159,18 @@ def set_scalar(self, var_name, value): def _solve_phreeqcrm(self, dt, diffmask): '''Function to solve phreeqc bmi ''' - # status = phreeqc_rm.SetTemperature([self.init_temp[0]] * self.ncpl) # status = phreeqc_rm.SetPressure([2.0] * nxyz) self.SetTimeStep(dt*1.0/self.GetTimeConversion()) - # update which cells to run depending on conc change between tsteps - sat = [1]*self.GetGridCellCount() + if self.sat_now is not None: + sat = self.sat_now + else: + sat = [1]*self.GetGridCellCount() + self.SetSaturation(sat) + + # update which cells to run depending on conc change between tsteps if diffmask is not None: # get idx where diffmask is 0 inact = get_indices(0, diffmask) @@ -211,17 +216,29 @@ def __init__(self, wd, dll): modflowapi.ModflowApi.__init__(self, dll, working_directory=wd) self.initialize() self.sim = flopy.mf6.MFSimulation.load(sim_ws=wd, verbosity_level=0) + self.fmi = False def _prepare_mf6(self): '''Prepare mf6 bmi for transport calculations ''' - self.modelnmes = ['Flow'] + [nme.capitalize() for nme in self.sim.model_names if nme != 'gwf'] + self.modelnmes = [nme.capitalize() for nme in self.sim.model_names] self.components = [nme.capitalize() for nme in self.sim.model_names[1:]] self.nsln = self.get_subcomponent_count() self.sim_start = datetime.now() self.ctimes = [0.0] self.num_fails = 0 + def _check_fmi(self): + '''Check if fmi is in the nam file + ''' + ... + return + + def _set_simtype_gwt(self): + '''Set the gwt sim type as sequential or flow interface + ''' + ... + def _solve_gwt(self): '''Function to solve the transport loop ''' @@ -291,6 +308,26 @@ def __init__(self, wd, mf6api, phreeqcbmi): # set time conversion factor self.set_time_conversion() + def get_saturation_from_mf6(self): + """ + Get the saturation + + Parameters + ---------- + mf6 (modflowapi): the modflow api object + + Returns + ------- + array: the saturation + """ + sat = {component: self.mf6api.get_value( + self.mf6api.get_var_address("FMI/GWFSAT", f'{component}') + ) for component in self.phreeqcbmi.components} + # select the first component to get the length of the array + sat = sat[self.phreeqcbmi.components[0]] # saturation is the same for all components + self.phreeqcbmi.sat_now = sat # set phreeqcmbi saturation + return sat + def get_time_units_from_mf6(self): '''Function to get the time units from mf6 ''' @@ -467,8 +504,6 @@ def _update_selected_output(self): def __replace_inactive_cells_in_sout(self, sout, diffmask): '''Function to replace inactive cells in the selected output dataframe ''' - components = self.mf6api.modelnmes[1:] - headers = self.phreeqcbmi.sout_headers # match headers in components closest string inactive_idx = get_indices(0, diffmask) @@ -545,7 +580,7 @@ def _solve(self): assert self._check_sout_exist(), f'{self.sout_fname} not found' print("Starting transport solution at {0}".format(sim_start.strftime(DT_FMT))) - print(f"Solving the following components: {', '.join([nme for nme in self.mf6api.modelnmes])}") + # print(f"Solving the following components: {', '.join([nme for nme in self.mf6api.modelnmes])}") ctime = self._set_ctime() etime = self._set_etime() while ctime < etime: @@ -556,6 +591,9 @@ def _solve(self): self.mf6api.prepare_time_step(dt) self.mf6api._solve_gwt() + # get saturation + self.get_saturation_from_mf6() + if self.reactive: # reaction block c_dbl_vect = self._transfer_array_to_phreeqcrm()