Skip to content

Commit

Permalink
allow parsing function to add_py_function (#549)
Browse files Browse the repository at this point in the history
  • Loading branch information
martinvonk authored Nov 15, 2024
1 parent d2e85b3 commit 6080911
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 18 deletions.
2 changes: 1 addition & 1 deletion autotest/pst_from_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,7 +264,7 @@ def freyberg_test(tmp_path):
# (generated by pyemu.gw_utils.setup_hds_obs())
f, fdf = _gen_dummy_obs_file(pf.new_d)
pf.add_observations(f, index_cols='idx', use_cols='yes')
pf.add_py_function(__file__, '_gen_dummy_obs_file()',
pf.add_py_function(_gen_dummy_obs_file, '_gen_dummy_obs_file()',
is_pre_cmd=False)
pf.add_observations('freyberg.hds.dat', insfile='freyberg.hds.dat.ins2',
index_cols='obsnme', use_cols='obsval', prefix='hds')
Expand Down
45 changes: 28 additions & 17 deletions pyemu/utils/pst_from.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,22 @@
from __future__ import print_function, division
from __future__ import division, print_function

import copy
import os
from pathlib import Path
import warnings
import platform
import string
import warnings
from inspect import getsource
from pathlib import Path
from typing import Callable, Union

import numpy as np
import pandas as pd
import pyemu
from ..pyemu_warnings import PyemuWarning
import copy
import string

import pyemu
from pyemu.utils.helpers import _try_pdcol_numeric

from ..pyemu_warnings import PyemuWarning

# the tolerable percent difference (100 * (max - min)/mean)
# used when checking that constant and zone type parameters are in fact constant (within
# a given zone)
Expand Down Expand Up @@ -1187,13 +1192,16 @@ def _next_count(self, prefix):
return self._prefix_count[prefix]

def add_py_function(
self, file_name, call_str=None, is_pre_cmd=True,
function_name=None
self,
file_name: Union[str, Callable],
call_str: Union[None, str] = None,
is_pre_cmd: Union[bool, None] = True,
function_name=None,
):
"""add a python function to the forward run script
Args:
file_name (`str`): a python source file
file_name (`str` or `callable`): a python source file or function/callable
call_str (`str`): the call string for python function in
`file_name`.
`call_str` will be added to the forward run script, as is.
Expand All @@ -1213,7 +1221,7 @@ def add_py_function(
`PstFrom.extra_py_imports` list.
This function adds the `call_str` call to the forward
run script (either as a pre or post command or function not
run script (either as a pre or post command or function not
directly called by main). It is up to users
to make sure `call_str` is a valid python function call
that includes the parentheses and requisite arguments
Expand Down Expand Up @@ -1249,12 +1257,6 @@ def add_py_function(
self.logger.lraise(
"add_py_function(): No function call string passed in arg " "'call_str'"
)
if not os.path.exists(file_name):
self.logger.lraise(
"add_py_function(): couldnt find python source file '{0}'".format(
file_name
)
)
if "(" not in call_str or ")" not in call_str:
self.logger.lraise(
"add_py_function(): call_str '{0}' missing paretheses".format(call_str)
Expand All @@ -1270,7 +1272,16 @@ def add_py_function(
f"original will be maintained",
PyemuWarning,
)
if callable(file_name):
func_lines = getsource(file_name).splitlines(keepends=True)
self._function_lines_list.append(func_lines)
else:
if not os.path.exists(file_name):
self.logger.lraise(
"add_py_function(): couldnt find python source file '{0}'".format(
file_name
)
)
func_lines = []
search_str = "def " + function_name + "("
abet_set = set(string.ascii_uppercase)
Expand Down

0 comments on commit 6080911

Please sign in to comment.