diff --git a/parampy/parameters.pyx b/parampy/parameters.pyx index 263af18..78f198a 100644 --- a/parampy/parameters.pyx +++ b/parampy/parameters.pyx @@ -7,6 +7,7 @@ from .iteration import RangesIterator from .quantities import Quantity from .text import colour_text from .units import Units, Unit +from .utility.compat import str_types import copy import imp @@ -446,7 +447,7 @@ class Parameters(object): def __get_unit(self, unit): - if isinstance(unit, str): + if isinstance(unit, str_types): return self.__units(unit) elif isinstance(unit, Units): @@ -497,7 +498,7 @@ class Parameters(object): ############# PARAMETER RESOLUTION ######################################### def __get_pam_name(self, param): - if isinstance(param, str): + if isinstance(param, str_types): if param[:1] == "_": return param[1:] return param @@ -524,7 +525,7 @@ class Parameters(object): value = self.__parameters[param] if type(value) == types.FunctionType: - self.__cache_deps[param] = map(self.__get_pam_name, self.__function_getargs(value)) + self.__cache_deps[param] = list(map(self.__get_pam_name, self.__function_getargs(value))) else: self.__cache_deps[param] = [] @@ -578,7 +579,7 @@ class Parameters(object): ''' checked = [] for arg in args: - if isinstance(arg, str): + if isinstance(arg, str_types): for pam in self.__get_pam_sups(arg): if pam in self.__parameters_bounds: keys = self.__get_pam_deps(pam) @@ -609,7 +610,7 @@ class Parameters(object): pam_name = self.__get_pam_name(arg) # If the parameter is actually a function or otherwise not directly in the dictionary of stored parameters - if not isinstance(arg, str) or (pam_name not in kwargs and pam_name not in self.__parameters): + if not isinstance(arg, str_types) or (pam_name not in kwargs and pam_name not in self.__parameters): return self.__eval(arg, kwargs, default_scaled) else: scaled = default_scaled if default_scaled is not None else self.__default_scaled @@ -647,7 +648,7 @@ class Parameters(object): return if restrict is None: - restrict = kwargs.keys() + restrict = list(kwargs.keys()) if len(restrict) == 0: return @@ -686,12 +687,12 @@ class Parameters(object): if pam[0] == "_": raise ValueError("Parameter type is autodetected when passed as a keyword argument. Do not use '_' to switch between scaled and unitted parameters.") val = kwargs[pam] - if type(val) is str: + if type(val) in str_types: val = self.__get_function(val) kwargs[pam] = val if type(val) is tuple and type(val[0]) is types.FunctionType: val = val[0] - if type(val) is tuple and type(val[0]) is str: + if type(val) is tuple and type(val[0]) in str_types: val = self.__get_function(val[0]) if type(val) is types.FunctionType: pam = self.__get_pam_name(pam) @@ -728,7 +729,7 @@ class Parameters(object): if len(new) != 0: kwargs.update(new) - self.__process_override(kwargs, restrict=new.keys()) + self.__process_override(kwargs, restrict=list(new.keys())) def __eval_function(self, param, kwargs={}): ''' @@ -902,9 +903,9 @@ class Parameters(object): elif isinstance(arg, Quantity): return self.__get_quantity(arg, scaled=default_scaled) - elif isinstance(arg, str) or arg.__class__.__module__.startswith('sympy'): + elif isinstance(arg, str_types) or arg.__class__.__module__.startswith('sympy'): try: - if isinstance(arg, str): + if isinstance(arg, str_types): # We have a string which cannot be a single parameter. Check to see if it is trying to be. arg = sympy.S(arg, sympy.abc._clash) fs = list(arg.free_symbols) @@ -946,10 +947,10 @@ class Parameters(object): self.__cache_funcs[param] = None if param in self.__cache_scaled: # Clear cache if present. del self.__cache_scaled[param] - if isinstance(val, (types.FunctionType, str)): + if isinstance(val, (types.FunctionType,) + str_types): self.__parameters[param] = self.__check_function(param, self.__get_function(val)) self.__spec({param: self.__get_unit('')}) - elif isinstance(val, (list, tuple)) and isinstance(val[0], (types.FunctionType, str)): + elif isinstance(val, (list, tuple)) and isinstance(val[0], (types.FunctionType,) + str_types): self.__parameters[param] = self.__check_function(param, self.__get_function(val[0])) self.__spec({param: self.__get_unit(val[1])}) else: @@ -1257,7 +1258,7 @@ class Parameters(object): for param in params: bounds[param] = self.__parameters_bounds[param].bounds if param in self.__parameters_bounds else None if not use_dict and len(bounds) == 1: - return bounds[bounds.keys()[0]] + return bounds[list(bounds.keys())[0]] return bounds def set_bounds(self, bounds_dict, error=True, clip=False, inclusive=True): @@ -1463,7 +1464,7 @@ class Parameters(object): return values def __range_sampler(self, sampler): - if isinstance(sampler, str): + if isinstance(sampler, str_types): if sampler == 'linear': return np.linspace elif sampler == 'log': @@ -1498,7 +1499,7 @@ class Parameters(object): sampler = self.__range_sampler(sampler) for i, arg in enumerate(args): - if isinstance(arg, (tuple, str, Quantity)): + if isinstance(arg, (tuple, Quantity) + str_types): pars = {param: arg} if type(params) is dict: pars.update(params) @@ -1559,7 +1560,7 @@ class Parameters(object): for param, value in kwargs.items(): d[param] = self.convert(value, output=self.units(param), value=True) if len(d) == 1: - return d.values()[0] + return list(d.values())[0] return d def asscaled(self, **kwargs): @@ -1587,7 +1588,7 @@ class Parameters(object): for param, value in kwargs.items(): d[param] = self.convert(value) if len(d) == 1: - return d.values()[0] + return list(d.values())[0] return d def units(self, *params): @@ -1695,10 +1696,10 @@ class Parameters(object): ''' - if param is None or isinstance(param, types.FunctionType) or isinstance(param, str) and self.__is_valid_param(param): + if param is None or isinstance(param, types.FunctionType) or isinstance(param, str_types) and self.__is_valid_param(param): return param - elif isinstance(param, str) or type(param).__module__.startswith('sympy'): + elif isinstance(param, str_types) or type(param).__module__.startswith('sympy'): if len(wrt) > 0: subs = {} expr = sympy.S(param, locals=sympy.abc._clash) @@ -1782,7 +1783,7 @@ class Parameters(object): except: raise errors.ParameterInvalidError("This parameters instance has no parameter named '%s', and none was provided. Parameter may or may not be constant." % param) - if isinstance(param_val, types.FunctionType) or isinstance(param_val, str) and isinstance(self.optimise(param_val), (types.FunctionType, str)): + if isinstance(param_val, types.FunctionType) or isinstance(param_val, str_types) and isinstance(self.optimise(param_val), (types.FunctionType,) + str_types): return True return False @@ -1834,7 +1835,7 @@ class Parameters(object): except: raise errors.ParameterInvalidError("This parameters instance has no parameter named '%s', and none was provided. Parameter may or may not be constant." % param) - if isinstance(param_val, str): + if isinstance(param_val, str_types): param_val = self.optimise(param_val) else: param_val = self.__get_quantity(param_val) @@ -1964,7 +1965,7 @@ class Parameters(object): def __rshift__(self, other): - if not isinstance(other, str): + if not isinstance(other, str_types): raise errors.ParametersException("The right shift operator is used to save the parameters to a file. The operand must be a filename.") self.__save__(other) diff --git a/parampy/quantities.pyx b/parampy/quantities.pyx index 240c352..132b76f 100644 --- a/parampy/quantities.pyx +++ b/parampy/quantities.pyx @@ -6,8 +6,10 @@ import numpy as np from .units import UnitDispenser, Units from .text import colour_text +from .utility.compat import UnicodeMixin -class Quantity(object): +@total_ordering +class Quantity(UnicodeMixin): ''' Quantity (value,units=None,absolute=False,dispenser=None) @@ -228,9 +230,6 @@ class Quantity(object): def __unicode__(self): return u"%s %s" % (self.value, unicode(self.units)) + (u" (abs)" if self.absolute else u"") - def __str__(self): - return unicode(self).encode('utf-8') - # Arithmetic def __add__(self, other, reverse=False): if other == 0: @@ -339,6 +338,9 @@ class Quantity(object): def __ne__(self,other): return not self.__eq__(other) + + def __lt__(self,other): # Python 3 and newer ignore __cmp__ + return self.__cmp__(other) == -1 def __cmp__(self, other): if type(other) is tuple and len(other) == 2: diff --git a/parampy/units.pyx b/parampy/units.pyx index 9d4fbd9..41dc01b 100644 --- a/parampy/units.pyx +++ b/parampy/units.pyx @@ -3,7 +3,7 @@ import re, types, inspect from . import errors from .text import colour_text -from .utility.compat import UnicodeMixin +from .utility.compat import UnicodeMixin, str_types class Unit(UnicodeMixin): @@ -635,7 +635,7 @@ class UnitDispenser(UnicodeMixin): :returns: :class:`Unit` object associated with a the string representation. :raises: :class:`UnitInvalidError` if no unit can be found that matches. ''' - if isinstance(unit, str): + if isinstance(unit, str_types): try: return self._units[unit] except: @@ -829,7 +829,7 @@ class Units(UnicodeMixin): units[self.__get_unit(unit)] = Fraction(units.pop(unit)) return units - elif isinstance(units, str): + elif isinstance(units, str_types): d = {} if units == "units": @@ -894,7 +894,7 @@ class Units(UnicodeMixin): if getattr(self, '__scale_cache', None) is None: self.__scale_cache = {} - if isinstance(other, str): + if isinstance(other, str_types): other = self.__dispenser(other) dims = self.dimensions diff --git a/parampy/utility/compat.py b/parampy/utility/compat.py index 642625b..6aa074c 100644 --- a/parampy/utility/compat.py +++ b/parampy/utility/compat.py @@ -8,6 +8,11 @@ def strrep(obj): return unicode(obj).encode('utf-8') return obj.__unicode__() +if sys.version_info[0] >= 3: + str_types = (str,) +else: + str_types = (str,unicode) + class UnicodeMixin(object): """Mixin class to handle defining the proper __str__/__unicode__ diff --git a/tests.py b/tests.py index 259f3e4..030404f 100644 --- a/tests.py +++ b/tests.py @@ -69,7 +69,7 @@ def test_ufunc(self): self.assertRaises( errors.UnitConversionError, np.tan, SIQuantity(1,'m') ) -class TestParameters(): # unittest.TestCase +class TestParameters(unittest.TestCase): def setUp(self): self.p = Parameters(default_scaled=False,constants=True) @@ -204,8 +204,8 @@ def bad(): def test_asvalue(self): self.p(x=(1,'J')) self.p.scaling(mass=(-1000,'kg')) - self.assertEquals( self.p.asvalue(x=np.array([1,2,3])).tolist(),[-1000,-2000,-3000] ) - self.assertEquals( self.p.asvalue(x=np.array([1,2,3]),y=np.array([1,2,3]))['y'].tolist(),[1,2,3] ) + self.assertEquals( self.p.asvalue(x=np.array([1.,2.,3.])).tolist(),[-1000.,-2000.,-3000.] ) + self.assertEquals( self.p.asvalue(x=np.array([1.,2.,3.]),y=np.array([1.,2.,3.]))['y'].tolist(),[1.,2.,3.] ) def test_bounds(self): self.p(x=(1,'J'))