Skip to content

Commit

Permalink
Implement dot for multidimensional arrays
Browse files Browse the repository at this point in the history
  • Loading branch information
FilipeMaia committed Oct 23, 2016
1 parent 09a0b92 commit 279c1cf
Show file tree
Hide file tree
Showing 2 changed files with 67 additions and 4 deletions.
57 changes: 55 additions & 2 deletions afnumpy/linalg/linalg.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
import afnumpy
import arrayfire
import numpy
from numpy.core import complexfloating, Inf, longdouble
from afnumpy import asarray, sqrt, abs
from afnumpy.lib import asfarray
from .. import private_utils as pu

def isComplexType(t):
return issubclass(t, complexfloating)
Expand All @@ -13,8 +15,59 @@ def vdot(a, b):

# TODO: Implement multidimensional dot
def dot(a, b):
s = arrayfire.dot((a.flat.d_array), b.flat.d_array)
return afnumpy.ndarray((), dtype=a.dtype, af_array=s)[()]
# Arrayfire requires that the types match for dot and matmul
res_dtype = numpy.result_type(a,b)
a = a.astype(res_dtype, copy=False)
b = b.astype(res_dtype, copy=False)
if a.ndim == 1 and b.ndim == 1:
s = arrayfire.dot((a.flat.d_array), b.flat.d_array)
return afnumpy.ndarray((), dtype=a.dtype, af_array=s)[()]

a_shape = a.shape
b_shape = b.shape
if a.ndim == 1:
a = a.reshape((a.shape[0],1))
if b.ndim == 1:
b = b.reshape((b.shape[0],1))

if a.ndim == 2 and b.ndim == 2:
# Notice the order of the arguments to matmul. It's not a bug!
s = arrayfire.matmul(b.d_array, a.d_array)
return afnumpy.ndarray(pu.af_shape(s), dtype=pu.typemap(s.dtype()), af_array=s)
# Multidimensional dot is done with loops

# Calculate the shape of the result array
a_shape = list(a_shape)
a_shape.pop(-1)
b_shape = list(b_shape)
b_shape.pop(-2)
res_shape = a_shape + b_shape

# Initialize the output array
res = afnumpy.empty(res_shape, dtype=res_dtype)

# Make sure the arrays are at least 3D
if a.ndim < 3:
a = a.reshape((1,)+a.shape)
if b.ndim < 3:
b = b.reshape((1,)+b.shape)

# We're going to flatten the regions over which we're going to loop over
# to make our life easier and reduce the amount of indexing code
a = a.reshape((-1,a.shape[-2],a.shape[-1]))
b = b.reshape((-1,b.shape[-2],b.shape[-1]))

# Initialize the output array. The shape matches the reshape a and b.
res = afnumpy.empty((a.shape[0], a.shape[-2], b.shape[0],
b.shape[-1]), dtype=a.dtype)

# Loop through the flattened indices and calculate the matmuls
for i in range(0,a.shape[0]):
for j in range(0,b.shape[0]):
res[i,:,j,:] = afnumpy.dot(a[i],b[j])

# Finally appropriately reshape the result array
return res.reshape(res_shape)

def norm(x, ord=None, axis=None, keepdims=False):
x = asarray(x)
Expand Down
14 changes: 12 additions & 2 deletions tests/test_linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,15 +33,25 @@ def test_dot_1D():
a = afnumpy.array(b)
fassert(afnumpy.dot(a,a), numpy.dot(b,b))

@xfail
a = numpy.random.random(3)+numpy.random.random(3)*1.0j
b = numpy.random.random(3)
fassert(afnumpy.dot(afnumpy.array(a),afnumpy.array(b)), numpy.dot(a,b))

def test_dot_2D():
b = numpy.random.random((3,3))+numpy.random.random((3,3))*1.0j
a = afnumpy.array(b)
fassert(afnumpy.dot(a,a), numpy.dot(b,b))

@xfail
def test_dot_3D():
b = numpy.random.random((3,3,3))+numpy.random.random((3,3,3))*1.0j
a = afnumpy.array(b)
fassert(afnumpy.dot(a,a), numpy.dot(b,b))

a = numpy.random.random((3,2,4))+numpy.random.random((3,2,4))*1.0j
b = numpy.random.random((3,4,1))
fassert(afnumpy.dot(afnumpy.array(a),afnumpy.array(b)), numpy.dot(a,b))

b = numpy.random.random((3,3,3))
a = numpy.random.random((3,3))
fassert(afnumpy.dot(afnumpy.array(a),afnumpy.array(b)), numpy.dot(a,b))

0 comments on commit 279c1cf

Please sign in to comment.