From 279c1cf0e519b22afb616316e6d0887e0c975177 Mon Sep 17 00:00:00 2001 From: Filipe Maia Date: Sun, 23 Oct 2016 23:16:50 +0200 Subject: [PATCH] Implement dot for multidimensional arrays --- afnumpy/linalg/linalg.py | 57 ++++++++++++++++++++++++++++++++++++++-- tests/test_linalg.py | 14 ++++++++-- 2 files changed, 67 insertions(+), 4 deletions(-) diff --git a/afnumpy/linalg/linalg.py b/afnumpy/linalg/linalg.py index 759f998..9cd68d7 100644 --- a/afnumpy/linalg/linalg.py +++ b/afnumpy/linalg/linalg.py @@ -1,8 +1,10 @@ import afnumpy import arrayfire +import numpy from numpy.core import complexfloating, Inf, longdouble from afnumpy import asarray, sqrt, abs from afnumpy.lib import asfarray +from .. import private_utils as pu def isComplexType(t): return issubclass(t, complexfloating) @@ -13,8 +15,59 @@ def vdot(a, b): # TODO: Implement multidimensional dot def dot(a, b): - s = arrayfire.dot((a.flat.d_array), b.flat.d_array) - return afnumpy.ndarray((), dtype=a.dtype, af_array=s)[()] + # Arrayfire requires that the types match for dot and matmul + res_dtype = numpy.result_type(a,b) + a = a.astype(res_dtype, copy=False) + b = b.astype(res_dtype, copy=False) + if a.ndim == 1 and b.ndim == 1: + s = arrayfire.dot((a.flat.d_array), b.flat.d_array) + return afnumpy.ndarray((), dtype=a.dtype, af_array=s)[()] + + a_shape = a.shape + b_shape = b.shape + if a.ndim == 1: + a = a.reshape((a.shape[0],1)) + if b.ndim == 1: + b = b.reshape((b.shape[0],1)) + + if a.ndim == 2 and b.ndim == 2: + # Notice the order of the arguments to matmul. It's not a bug! + s = arrayfire.matmul(b.d_array, a.d_array) + return afnumpy.ndarray(pu.af_shape(s), dtype=pu.typemap(s.dtype()), af_array=s) + # Multidimensional dot is done with loops + + # Calculate the shape of the result array + a_shape = list(a_shape) + a_shape.pop(-1) + b_shape = list(b_shape) + b_shape.pop(-2) + res_shape = a_shape + b_shape + + # Initialize the output array + res = afnumpy.empty(res_shape, dtype=res_dtype) + + # Make sure the arrays are at least 3D + if a.ndim < 3: + a = a.reshape((1,)+a.shape) + if b.ndim < 3: + b = b.reshape((1,)+b.shape) + + # We're going to flatten the regions over which we're going to loop over + # to make our life easier and reduce the amount of indexing code + a = a.reshape((-1,a.shape[-2],a.shape[-1])) + b = b.reshape((-1,b.shape[-2],b.shape[-1])) + + # Initialize the output array. The shape matches the reshape a and b. + res = afnumpy.empty((a.shape[0], a.shape[-2], b.shape[0], + b.shape[-1]), dtype=a.dtype) + + # Loop through the flattened indices and calculate the matmuls + for i in range(0,a.shape[0]): + for j in range(0,b.shape[0]): + res[i,:,j,:] = afnumpy.dot(a[i],b[j]) + + # Finally appropriately reshape the result array + return res.reshape(res_shape) def norm(x, ord=None, axis=None, keepdims=False): x = asarray(x) diff --git a/tests/test_linalg.py b/tests/test_linalg.py index 6b2f27e..5c99e72 100644 --- a/tests/test_linalg.py +++ b/tests/test_linalg.py @@ -33,15 +33,25 @@ def test_dot_1D(): a = afnumpy.array(b) fassert(afnumpy.dot(a,a), numpy.dot(b,b)) -@xfail + a = numpy.random.random(3)+numpy.random.random(3)*1.0j + b = numpy.random.random(3) + fassert(afnumpy.dot(afnumpy.array(a),afnumpy.array(b)), numpy.dot(a,b)) + def test_dot_2D(): b = numpy.random.random((3,3))+numpy.random.random((3,3))*1.0j a = afnumpy.array(b) fassert(afnumpy.dot(a,a), numpy.dot(b,b)) -@xfail def test_dot_3D(): b = numpy.random.random((3,3,3))+numpy.random.random((3,3,3))*1.0j a = afnumpy.array(b) fassert(afnumpy.dot(a,a), numpy.dot(b,b)) + a = numpy.random.random((3,2,4))+numpy.random.random((3,2,4))*1.0j + b = numpy.random.random((3,4,1)) + fassert(afnumpy.dot(afnumpy.array(a),afnumpy.array(b)), numpy.dot(a,b)) + + b = numpy.random.random((3,3,3)) + a = numpy.random.random((3,3)) + fassert(afnumpy.dot(afnumpy.array(a),afnumpy.array(b)), numpy.dot(a,b)) +