Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Out unit tests for operators and functions no longer pass if something is not implemented #1939

Merged
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
166 changes: 90 additions & 76 deletions Wrappers/Python/test/test_out_in_place.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,36 +87,38 @@ def setUp(self):
b_ig = ig.allocate('random')
c = numpy.float64(0.3)
bg = BlockGeometry(ig, ig)

# [(function, geometry, test_proximal, test_proximal_conjugate, test_gradient), ...]
self.func_geom_test_list = [
(IndicatorBox(), ag),
(KullbackLeibler(b=b, backend='numba'), ag),
(KullbackLeibler(b=b, backend='numpy'), ag),
(L1Norm(), ag),
(L1Norm(), ig),
(L1Norm(b=b), ag),
(L1Norm(b=b, weight=b), ag),
(TranslateFunction(L1Norm(), b), ag),
(TranslateFunction(L2NormSquared(), b), ag),
(L2NormSquared(), ag),
(scalar * L2NormSquared(), ag),
(SumFunction(L2NormSquared(), scalar * L2NormSquared()), ag),
(SumScalarFunction(L2NormSquared(), 3), ag),
(ConstantFunction(3), ag),
(ZeroFunction(), ag),
(L2NormSquared(b=b), ag),
(L2NormSquared(), ag),
(LeastSquares(A, b_ig, c, weight_ls), ig),
(LeastSquares(A, b_ig, c), ig),
(WeightedL2NormSquared(weight=b_ig), ig),
(TotalVariation(backend='c', warm_start=False, max_iteration=100), ig),
(TotalVariation(backend='numpy', warm_start=False, max_iteration=100), ig),
(OperatorCompositionFunction(L2NormSquared(), A), ig),
(MixedL21Norm(), bg),
(SmoothMixedL21Norm(epsilon=0.3), bg),
(MixedL11Norm(), bg),
(BlockFunction(L2NormSquared(),L2NormSquared()), bg),
(L1Sparsity(WaveletOperator(ig)), ig)
(IndicatorBox(), ag, True, True, False),
(KullbackLeibler(b=b, backend='numba'), ag, True, True, True),
(KullbackLeibler(b=b, backend='numpy'), ag, True, True, True),
(L1Norm(), ag, True, True, False),
(L1Norm(), ig, True, True, False),
(L1Norm(b=b), ag, True, True, False),
(L1Norm(b=b, weight=b), ag, True, True, False),
(TranslateFunction(L1Norm(), b), ag, True, True, False),
(TranslateFunction(L2NormSquared(), b), ag, True, True, True),
(L2NormSquared(), ag, True, True, True),
(scalar * L2NormSquared(), ag, True, True, True),
(SumFunction(L2NormSquared(), scalar * L2NormSquared()), ag, False, False, True),
(SumScalarFunction(L2NormSquared(), 3), ag, True, True, True),
(ConstantFunction(3), ag, True, True, True),
(ZeroFunction(), ag, True, True, True),
(L2NormSquared(b=b), ag, True, True, True),
(L2NormSquared(), ag, True, True, True),
(LeastSquares(A, b_ig, c, weight_ls), ig, False, False, True),
(LeastSquares(A, b_ig, c), ig, False, False, True),
(WeightedL2NormSquared(weight=b_ig), ig, True, True, True),
(TotalVariation(backend='c', warm_start=False, max_iteration=100), ig, True, True, False),
(TotalVariation(backend='numpy', warm_start=False, max_iteration=100), ig, True, True, False),
(OperatorCompositionFunction(L2NormSquared(), A), ig, False, False, True),
(MixedL21Norm(), bg, True, True, False),
(SmoothMixedL21Norm(epsilon=0.3), bg, False, False, True),
(MixedL11Norm(), bg, True, True, False),
(BlockFunction(L1Norm(),L2NormSquared()), bg, True, True, False),
(BlockFunction(L2NormSquared(),L2NormSquared()), bg, True, True, True),
(L1Sparsity(WaveletOperator(ig)), ig, True, True, False)


]

Expand All @@ -135,62 +137,71 @@ def get_result(self, function, method, x, *args):
self.assertDataArraysInContainerAllClose(input, x, rtol=1e-5, msg= "In case func."+method+'(data, *args) where func is ' + function.__class__.__name__+ 'the input data has been incorrectly affected by the calculation. ')
return out
except NotImplementedError:
return None
raise NotImplementedError(function.__class__.__name__+" raises a NotImplementedError for "+method)


def in_place_test(self,desired_result, function, method, x, *args, ):
out3 = x.copy()
try:
if method == 'proximal':
function.proximal(out3, *args, out=out3)
elif method == 'proximal_conjugate':
function.proximal_conjugate(out3, *args, out=out3)
elif method == 'gradient':
function.gradient(out3, *args, out=out3)
self.assertDataArraysInContainerAllClose(desired_result, out3, rtol=1e-5, msg= "In place calculation failed for func."+method+'(data, *args, out=data) where func is ' + function.__class__.__name__+ '. ')

except (InPlaceError, NotImplementedError):
pass

try:
if method == 'proximal':
function.proximal(out3, *args, out=out3)
elif method == 'proximal_conjugate':
function.proximal_conjugate(out3, *args, out=out3)
elif method == 'gradient':
function.gradient(out3, *args, out=out3)
self.assertDataArraysInContainerAllClose(desired_result, out3, rtol=1e-5, msg= "In place calculation failed for func."+method+'(data, *args, out=data) where func is ' + function.__class__.__name__+ '. ')

except InPlaceError:
pass
except NotImplementedError:
raise NotImplementedError(function.__class__.__name__+" raises a NotImplementedError for "+method)

def out_test(self, desired_result, function, method, x, *args, ):
input = x.copy()
out2=0*(x.copy())
try:
if method == 'proximal':
ret = function.proximal(input, *args, out=out2)
elif method == 'proximal_conjugate':
ret = function.proximal_conjugate(input, *args, out=out2)
elif method == 'gradient':
ret = function.gradient(input, *args, out=out2)
self.assertDataArraysInContainerAllClose(desired_result, out2, rtol=1e-5, msg= "Calculation failed using `out` in func."+method+'(x, *args, out=data) where func is ' + function.__class__.__name__+ '. ')
self.assertDataArraysInContainerAllClose(input, x, rtol=1e-5, msg= "In case func."+method+'(data, *args, out=out) where func is ' + function.__class__.__name__+ 'the input data has been incorrectly affected by the calculation. ')
self.assertDataArraysInContainerAllClose(desired_result, ret, rtol=1e-5, msg= f"Calculation failed returning with `out` in ret = func.{method}(x, *args, out=data) where func is {function.__class__.__name__}")

except (InPlaceError, NotImplementedError):
pass
try:
if method == 'proximal':
ret = function.proximal(input, *args, out=out2)
elif method == 'proximal_conjugate':
ret = function.proximal_conjugate(input, *args, out=out2)
elif method == 'gradient':
ret = function.gradient(input, *args, out=out2)
self.assertDataArraysInContainerAllClose(desired_result, out2, rtol=1e-5, msg= "Calculation failed using `out` in func."+method+'(x, *args, out=data) where func is ' + function.__class__.__name__+ '. ')
self.assertDataArraysInContainerAllClose(input, x, rtol=1e-5, msg= "In case func."+method+'(data, *args, out=out) where func is ' + function.__class__.__name__+ 'the input data has been incorrectly affected by the calculation. ')
self.assertDataArraysInContainerAllClose(desired_result, ret, rtol=1e-5, msg= f"Calculation failed returning with `out` in ret = func.{method}(x, *args, out=data) where func is {function.__class__.__name__}")

except InPlaceError:
pass
except NotImplementedError:
raise NotImplementedError(function.__class__.__name__+" raises a NotImplementedError for "+method)



def test_proximal_conjugate_out(self):
for func, geom in self.func_geom_test_list:
for data_array in self.data_arrays:
data=geom.allocate(None)
data.fill(data_array)
result=self.get_result(func, 'proximal_conjugate', data, 0.5)
self.out_test(result, func, 'proximal_conjugate', data, 0.5)
self.in_place_test(result, func, 'proximal_conjugate', data, 0.5)
for func, geom, _, test_proximal_conj, _ in self.func_geom_test_list:
if test_proximal_conj:
for data_array in self.data_arrays:
data=geom.allocate(None)
data.fill(data_array)
result=self.get_result(func, 'proximal_conjugate', data, 0.5)
self.out_test(result, func, 'proximal_conjugate', data, 0.5)
self.in_place_test(result, func, 'proximal_conjugate', data, 0.5)

def test_proximal_out(self):
for func, geom in self.func_geom_test_list:
for data_array in self.data_arrays:
data=geom.allocate(None)
data.fill(data_array)
result=self.get_result(func, 'proximal', data, 0.5)
self.out_test(result, func, 'proximal', data, 0.5)
self.in_place_test(result,func, 'proximal', data, 0.5)
for func, geom, test_proximal, _, _ in self.func_geom_test_list:
if test_proximal:
for data_array in self.data_arrays:
data=geom.allocate(None)
data.fill(data_array)
result=self.get_result(func, 'proximal', data, 0.5)
self.out_test(result, func, 'proximal', data, 0.5)
self.in_place_test(result,func, 'proximal', data, 0.5)

def test_gradient_out(self):
for func, geom in self.func_geom_test_list:
for func, geom, _, _, test_gradient in self.func_geom_test_list:
if test_gradient:
for data_array in self.data_arrays:
print(func.__class__.__name__)
data=geom.allocate(None)
Expand Down Expand Up @@ -263,20 +274,23 @@ def get_result(self, operator, method, x, *args):
self.assertDataArraysInContainerAllClose(input, x, rtol=1e-5, msg= "In case operator."+method+'(data, *args) where operator is ' + operator.__class__.__name__+ 'the input data has been incorrectly affected by the calculation. ')
return out
except NotImplementedError:
return None
raise NotImplementedError(operator.__class__.__name__+" raises a NotImplementedError for "+method)

def in_place_test(self,desired_result, operator, method, x, *args, ):
out3 = x.copy()
try:
if method == 'direct':
operator.direct(out3, *args, out=out3)
elif method == 'adjoint':
operator.adjoint(out3, *args, out=out3)
try:
if method == 'direct':
operator.direct(out3, *args, out=out3)
elif method == 'adjoint':
operator.adjoint(out3, *args, out=out3)

self.assertDataArraysInContainerAllClose(desired_result, out3, rtol=1e-5, msg= "In place calculation failed for operator."+method+'(data, *args, out=data) where operator is ' + operator.__class__.__name__+ '. ')
self.assertDataArraysInContainerAllClose(desired_result, out3, rtol=1e-5, msg= "In place calculation failed for operator."+method+'(data, *args, out=data) where operator is ' + operator.__class__.__name__+ '. ')

except (InPlaceError, NotImplementedError):
pass
except InPlaceError:
pass
except NotImplementedError:
raise NotImplementedError(operator.__class__.__name__+" raises a NotImplementedError for "+method)


def out_test(self, desired_result, operator, method, x, *args):
Expand Down
Loading