From 29b2c99a355728e10d70ba1082626d6bbc2fdea6 Mon Sep 17 00:00:00 2001 From: Matthew Brett Date: Tue, 31 May 2016 17:13:57 -0700 Subject: [PATCH] NF: add string and slice slicing of AxesManager After this PR, you can do the following two slicing operations, that previously raised errors: >>> da = DataArray(np.ones((2, 3, 4)), 'abc') >>> da.axes['b'] Axis(name='b', index=1, labels=None) >>> da.axes[1:] (Axis(name='b', index=1, labels=None), Axis(name='c', index=2, labels=None)) Note that, as currently written, these two both return an Axis: >>> da.axes['b'] Axis(name='b', index=1, labels=None) >>> da.axes('b') Axis(name='b', index=1, labels=None) --- datarray/datarray.py | 52 +++++++++++++++++++++++-------- datarray/tests/test_data_array.py | 43 ++++++++++++++++--------- 2 files changed, 67 insertions(+), 28 deletions(-) diff --git a/datarray/datarray.py b/datarray/datarray.py index 336b1c4..37132b6 100644 --- a/datarray/datarray.py +++ b/datarray/datarray.py @@ -99,17 +99,26 @@ class AxesManager(object): DataArray(array(True, dtype=bool), ('date', ('stocks', ('aapl', 'ibm')), 'metric')) - - Axes can also be accessed numerically: + Axes can be accessed numerically: >>> A.axes[1] is A.axes.stocks True - Calling the AxesManager with string arguments will return an + The axis name can be used as an index, as well as an attribute: + + >>> A.axes['stocks'] is A.axes.stocks + True + + Axes can also be sliced: + + >>> A.axes[1:] + (Axis(name='stocks', index=1, labels=('aapl', 'ibm', 'goog', 'msft')), Axis(name='metric', index=2, labels=None)) + + *Calling* the AxesManager with string arguments will return an :py:class:`AxisIndexer` object which can be used to restrict slices to specified axes: - >>> Ai = A.axes('stocks', 'date') + >>> Ai = A.axes('stocks', 'date') # Note the parens >>> np.all(Ai['aapl':'goog', 100] == A[100, 0:2]) DataArray(array(True, dtype=bool), (('stocks', ('aapl', 'ibm')), 'metric')) @@ -156,20 +165,37 @@ def __getitem__(self, n): Parameters ---------- - n : int - Index of axis to be returned. + n : int or string or slice + If int, index of axis to be returned. If string, name of axis to + be returned. If slice object, slice from AxesManager to return. Returns ------- - The requested :py:class:`Axis`. - + ax : Axis or AxesManager + The requested :py:class:`Axis` if `n` is an int or string. A new + AxesManager object if `n` is a slice. """ - if not isinstance(n, int): - raise TypeError("AxesManager expects integer index") + axes = object.__getattribute__(self, '_axes') + if isinstance(n, int): # Integer slicing retuns Axis + try: + return axes[n] + except IndexError: + raise IndexError("Requested axis %i out of bounds" % n) + # Indexing by name returns Axis + namemap = object.__getattribute__(self, '_namemap') + try: + n = namemap[n] + except TypeError: + pass + else: + return axes[n] + # Indexing with slice object returns new AxesManager try: - return object.__getattribute__(self, '_axes')[n] - except IndexError: - raise IndexError("Requested axis %i out of bounds" % n) + new_axes = axes[n] + except TypeError: + raise TypeError("Invalid axis index {0}".format(n)) + arr = object.__getattribute__(self, '_arr') + return type(self)(arr, new_axes) def __eq__(self, other): """Test for equality between two axes managers. Two axes managers are diff --git a/datarray/tests/test_data_array.py b/datarray/tests/test_data_array.py index e8f6900..dfbef5a 100644 --- a/datarray/tests/test_data_array.py +++ b/datarray/tests/test_data_array.py @@ -5,8 +5,8 @@ import numpy as np -from datarray.datarray import Axis, DataArray, NamedAxisError, \ - _pull_axis, _reordered_axes +from datarray.datarray import (Axis, AxesManager, DataArray, NamedAxisError, + _pull_axis, _reordered_axes) import nose.tools as nt import numpy.testing as npt @@ -461,6 +461,14 @@ class TestAxesManager: def setUp(self): self.axes_spec = ('date', ('stocks', ('aapl', 'ibm', 'goog', 'msft')), 'metric') self.A = DataArray(np.random.randn(200, 4, 10), axes=self.axes_spec) + self.axes = [] + for i, spec in enumerate(self.axes_spec): + try: + name, labels = spec + except ValueError: + name, labels = spec, None + self.axes.append( + Axis(name=name, index=i, parent_arr=self.A, labels=labels)) def test_axes_name_collision(self): "Test .axes object for attribute collisions with axis names" @@ -475,21 +483,26 @@ def test_axes_name_collision(self): nt.assert_equal(B.shape, (1,1,2,3)) nt.assert_true(np.all(A + A == 2*A)) - def test_axes_numeric_access(self): - for i,spec in enumerate(self.axes_spec): - try: - name,labels = spec - except ValueError: - name,labels = spec,None - nt.assert_true(self.A.axes[i] == Axis(name=name, index=i, - parent_arr=self.A, labels=labels)) + def test_axes_indexing(self): + n_axes = len(self.axes) + for i, exp_axis in enumerate(self.axes): + # Index with integer + nt.assert_equal(self.A.axes[i], exp_axis) + # Negative integer + nt.assert_equal(self.A.axes[i - n_axes], exp_axis) + # Name + nt.assert_equal(self.A.axes[exp_axis.name], exp_axis) + # Single element slice + one_axis = self.A.axes[i:i + 1] + nt.assert_equal(len(one_axis), 1) + nt.assert_equal(one_axis[0], exp_axis) + # Slice with more than one element + nt.assert_equal(self.A.axes[1:], + AxesManager(np.array(self.A), self.axes[1:])) def test_axes_attribute_access(self): - for spec in self.axes_spec: - try: - name,labels = spec - except ValueError: - name,labels = spec,None + for axis in self.axes: + name = axis.name nt.assert_true(getattr(self.A.axes, name) is self.A.axes(name)) def test_equality(self):