Skip to content

Commit

Permalink
Complete support for Python 3 by adding small changes to the main Par…
Browse files Browse the repository at this point in the history
…ameters class.
  • Loading branch information
matthewwardrop committed Dec 8, 2015
1 parent aa7af52 commit 6ccbe32
Show file tree
Hide file tree
Showing 5 changed files with 42 additions and 34 deletions.
47 changes: 24 additions & 23 deletions parampy/parameters.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ from .iteration import RangesIterator
from .quantities import Quantity
from .text import colour_text
from .units import Units, Unit
from .utility.compat import str_types

import copy
import imp
Expand Down Expand Up @@ -446,7 +447,7 @@ class Parameters(object):

def __get_unit(self, unit):

if isinstance(unit, str):
if isinstance(unit, str_types):
return self.__units(unit)

elif isinstance(unit, Units):
Expand Down Expand Up @@ -497,7 +498,7 @@ class Parameters(object):

############# PARAMETER RESOLUTION #########################################
def __get_pam_name(self, param):
if isinstance(param, str):
if isinstance(param, str_types):
if param[:1] == "_":
return param[1:]
return param
Expand All @@ -524,7 +525,7 @@ class Parameters(object):

value = self.__parameters[param]
if type(value) == types.FunctionType:
self.__cache_deps[param] = map(self.__get_pam_name, self.__function_getargs(value))
self.__cache_deps[param] = list(map(self.__get_pam_name, self.__function_getargs(value)))
else:
self.__cache_deps[param] = []

Expand Down Expand Up @@ -578,7 +579,7 @@ class Parameters(object):
'''
checked = []
for arg in args:
if isinstance(arg, str):
if isinstance(arg, str_types):
for pam in self.__get_pam_sups(arg):
if pam in self.__parameters_bounds:
keys = self.__get_pam_deps(pam)
Expand Down Expand Up @@ -609,7 +610,7 @@ class Parameters(object):
pam_name = self.__get_pam_name(arg)

# If the parameter is actually a function or otherwise not directly in the dictionary of stored parameters
if not isinstance(arg, str) or (pam_name not in kwargs and pam_name not in self.__parameters):
if not isinstance(arg, str_types) or (pam_name not in kwargs and pam_name not in self.__parameters):
return self.__eval(arg, kwargs, default_scaled)
else:
scaled = default_scaled if default_scaled is not None else self.__default_scaled
Expand Down Expand Up @@ -647,7 +648,7 @@ class Parameters(object):
return

if restrict is None:
restrict = kwargs.keys()
restrict = list(kwargs.keys())

if len(restrict) == 0:
return
Expand Down Expand Up @@ -686,12 +687,12 @@ class Parameters(object):
if pam[0] == "_":
raise ValueError("Parameter type is autodetected when passed as a keyword argument. Do not use '_' to switch between scaled and unitted parameters.")
val = kwargs[pam]
if type(val) is str:
if type(val) in str_types:
val = self.__get_function(val)
kwargs[pam] = val
if type(val) is tuple and type(val[0]) is types.FunctionType:
val = val[0]
if type(val) is tuple and type(val[0]) is str:
if type(val) is tuple and type(val[0]) in str_types:
val = self.__get_function(val[0])
if type(val) is types.FunctionType:
pam = self.__get_pam_name(pam)
Expand Down Expand Up @@ -728,7 +729,7 @@ class Parameters(object):

if len(new) != 0:
kwargs.update(new)
self.__process_override(kwargs, restrict=new.keys())
self.__process_override(kwargs, restrict=list(new.keys()))

def __eval_function(self, param, kwargs={}):
'''
Expand Down Expand Up @@ -902,9 +903,9 @@ class Parameters(object):
elif isinstance(arg, Quantity):
return self.__get_quantity(arg, scaled=default_scaled)

elif isinstance(arg, str) or arg.__class__.__module__.startswith('sympy'):
elif isinstance(arg, str_types) or arg.__class__.__module__.startswith('sympy'):
try:
if isinstance(arg, str):
if isinstance(arg, str_types):
# We have a string which cannot be a single parameter. Check to see if it is trying to be.
arg = sympy.S(arg, sympy.abc._clash)
fs = list(arg.free_symbols)
Expand Down Expand Up @@ -946,10 +947,10 @@ class Parameters(object):
self.__cache_funcs[param] = None
if param in self.__cache_scaled: # Clear cache if present.
del self.__cache_scaled[param]
if isinstance(val, (types.FunctionType, str)):
if isinstance(val, (types.FunctionType,) + str_types):
self.__parameters[param] = self.__check_function(param, self.__get_function(val))
self.__spec({param: self.__get_unit('')})
elif isinstance(val, (list, tuple)) and isinstance(val[0], (types.FunctionType, str)):
elif isinstance(val, (list, tuple)) and isinstance(val[0], (types.FunctionType,) + str_types):
self.__parameters[param] = self.__check_function(param, self.__get_function(val[0]))
self.__spec({param: self.__get_unit(val[1])})
else:
Expand Down Expand Up @@ -1257,7 +1258,7 @@ class Parameters(object):
for param in params:
bounds[param] = self.__parameters_bounds[param].bounds if param in self.__parameters_bounds else None
if not use_dict and len(bounds) == 1:
return bounds[bounds.keys()[0]]
return bounds[list(bounds.keys())[0]]
return bounds

def set_bounds(self, bounds_dict, error=True, clip=False, inclusive=True):
Expand Down Expand Up @@ -1463,7 +1464,7 @@ class Parameters(object):
return values

def __range_sampler(self, sampler):
if isinstance(sampler, str):
if isinstance(sampler, str_types):
if sampler == 'linear':
return np.linspace
elif sampler == 'log':
Expand Down Expand Up @@ -1498,7 +1499,7 @@ class Parameters(object):
sampler = self.__range_sampler(sampler)

for i, arg in enumerate(args):
if isinstance(arg, (tuple, str, Quantity)):
if isinstance(arg, (tuple, Quantity) + str_types):
pars = {param: arg}
if type(params) is dict:
pars.update(params)
Expand Down Expand Up @@ -1559,7 +1560,7 @@ class Parameters(object):
for param, value in kwargs.items():
d[param] = self.convert(value, output=self.units(param), value=True)
if len(d) == 1:
return d.values()[0]
return list(d.values())[0]
return d

def asscaled(self, **kwargs):
Expand Down Expand Up @@ -1587,7 +1588,7 @@ class Parameters(object):
for param, value in kwargs.items():
d[param] = self.convert(value)
if len(d) == 1:
return d.values()[0]
return list(d.values())[0]
return d

def units(self, *params):
Expand Down Expand Up @@ -1695,10 +1696,10 @@ class Parameters(object):
<function with argument t, with x evaluated to 1>
'''

if param is None or isinstance(param, types.FunctionType) or isinstance(param, str) and self.__is_valid_param(param):
if param is None or isinstance(param, types.FunctionType) or isinstance(param, str_types) and self.__is_valid_param(param):
return param

elif isinstance(param, str) or type(param).__module__.startswith('sympy'):
elif isinstance(param, str_types) or type(param).__module__.startswith('sympy'):
if len(wrt) > 0:
subs = {}
expr = sympy.S(param, locals=sympy.abc._clash)
Expand Down Expand Up @@ -1782,7 +1783,7 @@ class Parameters(object):
except:
raise errors.ParameterInvalidError("This parameters instance has no parameter named '%s', and none was provided. Parameter may or may not be constant." % param)

if isinstance(param_val, types.FunctionType) or isinstance(param_val, str) and isinstance(self.optimise(param_val), (types.FunctionType, str)):
if isinstance(param_val, types.FunctionType) or isinstance(param_val, str_types) and isinstance(self.optimise(param_val), (types.FunctionType,) + str_types):
return True
return False

Expand Down Expand Up @@ -1834,7 +1835,7 @@ class Parameters(object):
except:
raise errors.ParameterInvalidError("This parameters instance has no parameter named '%s', and none was provided. Parameter may or may not be constant." % param)

if isinstance(param_val, str):
if isinstance(param_val, str_types):
param_val = self.optimise(param_val)
else:
param_val = self.__get_quantity(param_val)
Expand Down Expand Up @@ -1964,7 +1965,7 @@ class Parameters(object):

def __rshift__(self, other):

if not isinstance(other, str):
if not isinstance(other, str_types):
raise errors.ParametersException("The right shift operator is used to save the parameters to a file. The operand must be a filename.")

self.__save__(other)
Expand Down
10 changes: 6 additions & 4 deletions parampy/quantities.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,10 @@ import numpy as np

from .units import UnitDispenser, Units
from .text import colour_text
from .utility.compat import UnicodeMixin

class Quantity(object):
@total_ordering
class Quantity(UnicodeMixin):
'''
Quantity (value,units=None,absolute=False,dispenser=None)
Expand Down Expand Up @@ -228,9 +230,6 @@ class Quantity(object):
def __unicode__(self):
return u"%s %s" % (self.value, unicode(self.units)) + (u" (abs)" if self.absolute else u"")

def __str__(self):
return unicode(self).encode('utf-8')

# Arithmetic
def __add__(self, other, reverse=False):
if other == 0:
Expand Down Expand Up @@ -339,6 +338,9 @@ class Quantity(object):

def __ne__(self,other):
return not self.__eq__(other)

def __lt__(self,other): # Python 3 and newer ignore __cmp__
return self.__cmp__(other) == -1

def __cmp__(self, other):
if type(other) is tuple and len(other) == 2:
Expand Down
8 changes: 4 additions & 4 deletions parampy/units.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ import re, types, inspect

from . import errors
from .text import colour_text
from .utility.compat import UnicodeMixin
from .utility.compat import UnicodeMixin, str_types


class Unit(UnicodeMixin):
Expand Down Expand Up @@ -635,7 +635,7 @@ class UnitDispenser(UnicodeMixin):
:returns: :class:`Unit` object associated with a the string representation.
:raises: :class:`UnitInvalidError` if no unit can be found that matches.
'''
if isinstance(unit, str):
if isinstance(unit, str_types):
try:
return self._units[unit]
except:
Expand Down Expand Up @@ -829,7 +829,7 @@ class Units(UnicodeMixin):
units[self.__get_unit(unit)] = Fraction(units.pop(unit))
return units

elif isinstance(units, str):
elif isinstance(units, str_types):
d = {}

if units == "units":
Expand Down Expand Up @@ -894,7 +894,7 @@ class Units(UnicodeMixin):
if getattr(self, '__scale_cache', None) is None:
self.__scale_cache = {}

if isinstance(other, str):
if isinstance(other, str_types):
other = self.__dispenser(other)

dims = self.dimensions
Expand Down
5 changes: 5 additions & 0 deletions parampy/utility/compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,11 @@ def strrep(obj):
return unicode(obj).encode('utf-8')
return obj.__unicode__()

if sys.version_info[0] >= 3:
str_types = (str,)
else:
str_types = (str,unicode)

class UnicodeMixin(object):

"""Mixin class to handle defining the proper __str__/__unicode__
Expand Down
6 changes: 3 additions & 3 deletions tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ def test_ufunc(self):

self.assertRaises( errors.UnitConversionError, np.tan, SIQuantity(1,'m') )

class TestParameters(): # unittest.TestCase
class TestParameters(unittest.TestCase):

def setUp(self):
self.p = Parameters(default_scaled=False,constants=True)
Expand Down Expand Up @@ -204,8 +204,8 @@ def bad():
def test_asvalue(self):
self.p(x=(1,'J'))
self.p.scaling(mass=(-1000,'kg'))
self.assertEquals( self.p.asvalue(x=np.array([1,2,3])).tolist(),[-1000,-2000,-3000] )
self.assertEquals( self.p.asvalue(x=np.array([1,2,3]),y=np.array([1,2,3]))['y'].tolist(),[1,2,3] )
self.assertEquals( self.p.asvalue(x=np.array([1.,2.,3.])).tolist(),[-1000.,-2000.,-3000.] )
self.assertEquals( self.p.asvalue(x=np.array([1.,2.,3.]),y=np.array([1.,2.,3.]))['y'].tolist(),[1.,2.,3.] )

def test_bounds(self):
self.p(x=(1,'J'))
Expand Down

0 comments on commit 6ccbe32

Please sign in to comment.