Skip to content

Commit

Permalink
Add support for plotting subclasses of cftime.datetime (#42)
Browse files Browse the repository at this point in the history
  • Loading branch information
spencerkclark authored and lbdreyer committed Jan 25, 2019
1 parent 46aee1a commit 531dd0d
Show file tree
Hide file tree
Showing 3 changed files with 84 additions and 22 deletions.
44 changes: 33 additions & 11 deletions nc_time_axis/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,14 +209,20 @@ def axisinfo(unit, axis):
*unit* is a tzinfo instance or None.
The *axis* argument is required but not used.
"""
calendar, date_unit = unit
calendar, date_unit, date_type = unit

majloc = NetCDFTimeDateLocator(4, calendar=calendar,
date_unit=date_unit)
majfmt = NetCDFTimeDateFormatter(majloc, calendar=calendar,
time_units=date_unit)
datemin = CalendarDateTime(cftime.datetime(2000, 1, 1), calendar)
datemax = CalendarDateTime(cftime.datetime(2010, 1, 1), calendar)
if date_type is CalendarDateTime:
datemin = CalendarDateTime(cftime.datetime(2000, 1, 1),
calendar=calendar)
datemax = CalendarDateTime(cftime.datetime(2010, 1, 1),
calendar=calendar)
else:
datemin = date_type(2000, 1, 1)
datemax = date_type(2010, 1, 1)
return munits.AxisInfo(majloc=majloc, majfmt=majfmt, label='',
default_limits=(datemin, datemax))

Expand All @@ -235,6 +241,7 @@ def default_units(cls, sample_point, axis):
calendar = calendars[0]
else:
raise ValueError('Calendar units are not all equal.')
date_type = type(sample_point[0])
else:
# Deal with a single `sample_point` value.
if not hasattr(sample_point, 'calendar'):
Expand All @@ -243,7 +250,8 @@ def default_units(cls, sample_point, axis):
raise ValueError(msg)
else:
calendar = sample_point.calendar
return calendar, cls.standard_unit
date_type = type(sample_point)
return calendar, cls.standard_unit, date_type

@classmethod
def convert(cls, value, unit, axis):
Expand All @@ -266,20 +274,27 @@ def convert(cls, value, unit, axis):
return value
first_value = value

if not isinstance(first_value, CalendarDateTime):
if not isinstance(first_value, (CalendarDateTime, cftime.datetime)):
raise ValueError('The values must be numbers or instances of '
'"nc_time_axis.CalendarDateTime".')
'"nc_time_axis.CalendarDateTime" or '
'"cftime.datetime".')

if not isinstance(first_value.datetime, cftime.datetime):
raise ValueError('The datetime attribute of the CalendarDateTime '
'object must be of type `cftime.datetime`.')
if isinstance(first_value, CalendarDateTime):
if not isinstance(first_value.datetime, cftime.datetime):
raise ValueError('The datetime attribute of the '
'CalendarDateTime object must be of type '
'`cftime.datetime`.')

ut = cftime.utime(cls.standard_unit, calendar=first_value.calendar)

if isinstance(value, CalendarDateTime):
if isinstance(value, (CalendarDateTime, cftime.datetime)):
value = [value]

result = ut.date2num([v.datetime for v in value])
if isinstance(first_value, CalendarDateTime):
result = ut.date2num([v.datetime for v in value])
else:
result = ut.date2num(value)

if shape is not None:
result = result.reshape(shape)

Expand All @@ -290,3 +305,10 @@ def convert(cls, value, unit, axis):
# dictionary.
if CalendarDateTime not in munits.registry:
munits.registry[CalendarDateTime] = NetCDFTimeConverter()

CFTIME_TYPES = [cftime.DatetimeNoLeap, cftime.DatetimeAllLeap,
cftime.DatetimeProlepticGregorian, cftime.DatetimeGregorian,
cftime.Datetime360Day, cftime.DatetimeJulian]
for date_type in CFTIME_TYPES:
if date_type not in munits.registry:
munits.registry[date_type] = NetCDFTimeConverter()
9 changes: 8 additions & 1 deletion nc_time_axis/tests/integration/test_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def tearDown(self):
# in an odd state, so we make sure it's been disposed of.
plt.close('all')

def test_360_day_calendar(self):
def test_360_day_calendar_CalendarDateTime(self):
datetimes = [cftime.datetime(1986, month, 30)
for month in range(1, 6)]
cal_datetimes = [nc_time_axis.CalendarDateTime(dt, '360_day')
Expand All @@ -34,6 +34,13 @@ def test_360_day_calendar(self):
result_ydata = line1.get_ydata()
np.testing.assert_array_equal(result_ydata, cal_datetimes)

def test_360_day_calendar_raw_dates(self):
datetimes = [cftime.Datetime360Day(1986, month, 30)
for month in range(1, 6)]
line1, = plt.plot(datetimes)
result_ydata = line1.get_ydata()
np.testing.assert_array_equal(result_ydata, datetimes)


if __name__ == "__main__":
unittest.main()
53 changes: 43 additions & 10 deletions nc_time_axis/tests/unit/test_NetCDFTimeConverter.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
class Test_axisinfo(unittest.TestCase):
def test_axis_default_limits(self):
cal = '360_day'
unit = (cal, 'days since 2000-02-25 00:00:00')
unit = (cal, 'days since 2000-02-25 00:00:00', CalendarDateTime)
result = NetCDFTimeConverter().axisinfo(unit, None)
expected_dt = [cftime.datetime(2000, 1, 1),
cftime.datetime(2010, 1, 1)]
Expand All @@ -25,21 +25,21 @@ def test_axis_default_limits(self):


class Test_default_units(unittest.TestCase):
def test_360_day_calendar_point(self):
def test_360_day_calendar_point_CalendarDateTime(self):
calendar = '360_day'
unit = 'days since 2000-01-01'
val = CalendarDateTime(cftime.datetime(2014, 8, 12), calendar)
result = NetCDFTimeConverter().default_units(val, None)
self.assertEqual(result, (calendar, unit))
self.assertEqual(result, (calendar, unit, CalendarDateTime))

def test_360_day_calendar_list(self):
def test_360_day_calendar_list_CalendarDateTime(self):
calendar = '360_day'
unit = 'days since 2000-01-01'
val = [CalendarDateTime(cftime.datetime(2014, 8, 12), calendar)]
result = NetCDFTimeConverter().default_units(val, None)
self.assertEqual(result, (calendar, unit))
self.assertEqual(result, (calendar, unit, CalendarDateTime))

def test_360_day_calendar_nd(self):
def test_360_day_calendar_nd_CalendarDateTime(self):
# Test the case where the input is an nd-array.
calendar = '360_day'
unit = 'days since 2000-01-01'
Expand All @@ -48,7 +48,30 @@ def test_360_day_calendar_nd(self):
[CalendarDateTime(cftime.datetime(2014, 8, 13),
calendar)]])
result = NetCDFTimeConverter().default_units(val, None)
self.assertEqual(result, (calendar, unit))
self.assertEqual(result, (calendar, unit, CalendarDateTime))

def test_360_day_calendar_point_raw_date(self):
calendar = '360_day'
unit = 'days since 2000-01-01'
val = cftime.Datetime360Day(2014, 8, 12)
result = NetCDFTimeConverter().default_units(val, None)
self.assertEqual(result, (calendar, unit, cftime.Datetime360Day))

def test_360_day_calendar_list_raw_date(self):
calendar = '360_day'
unit = 'days since 2000-01-01'
val = [cftime.Datetime360Day(2014, 8, 12)]
result = NetCDFTimeConverter().default_units(val, None)
self.assertEqual(result, (calendar, unit, cftime.Datetime360Day))

def test_360_day_calendar_nd_raw_date(self):
# Test the case where the input is an nd-array.
calendar = '360_day'
unit = 'days since 2000-01-01'
val = np.array([[cftime.Datetime360Day(2014, 8, 12)],
[cftime.Datetime360Day(2014, 8, 13)]])
result = NetCDFTimeConverter().default_units(val, None)
self.assertEqual(result, (calendar, unit, cftime.Datetime360Day))

def test_nonequal_calendars(self):
# Test that different supplied calendars causes an error.
Expand Down Expand Up @@ -84,17 +107,27 @@ def test_numeric_iterable(self):
result = NetCDFTimeConverter().convert(val, None, None)
np.testing.assert_array_equal(result, val)

def test_cftime(self):
def test_cftime_CalendarDateTime(self):
val = CalendarDateTime(cftime.datetime(2014, 8, 12), '365_day')
result = NetCDFTimeConverter().convert(val, None, None)
np.testing.assert_array_equal(result, 5333.)

def test_cftime_np_array(self):
def test_cftime_raw_date(self):
val = cftime.DatetimeNoLeap(2014, 8, 12)
result = NetCDFTimeConverter().convert(val, None, None)
np.testing.assert_array_equal(result, 5333.)

def test_cftime_np_array_CalendarDateTime(self):
val = np.array([CalendarDateTime(cftime.datetime(2012, 6, 4),
'360_day')], dtype=np.object)
result = NetCDFTimeConverter().convert(val, None, None)
self.assertEqual(result, np.array([4473.]))

def test_cftime_np_array_raw_date(self):
val = np.array([cftime.Datetime360Day(2012, 6, 4)], dtype=np.object)
result = NetCDFTimeConverter().convert(val, None, None)
self.assertEqual(result, np.array([4473.]))

def test_non_cftime_datetime(self):
val = CalendarDateTime(4, '360_day')
msg = 'The datetime attribute of the CalendarDateTime object must ' \
Expand All @@ -103,7 +136,7 @@ def test_non_cftime_datetime(self):
result = NetCDFTimeConverter().convert(val, None, None)

def test_non_CalendarDateTime(self):
val = cftime.datetime(1988, 5, 6)
val = 'test'
msg = 'The values must be numbers or instances of ' \
'"nc_time_axis.CalendarDateTime".'
with assertRaisesRegex(self, ValueError, msg):
Expand Down

0 comments on commit 531dd0d

Please sign in to comment.