diff --git a/Wrappers/Python/cil/optimisation/operators/BlockOperator.py b/Wrappers/Python/cil/optimisation/operators/BlockOperator.py index 66ede257b..08ec0e029 100644 --- a/Wrappers/Python/cil/optimisation/operators/BlockOperator.py +++ b/Wrappers/Python/cil/optimisation/operators/BlockOperator.py @@ -87,6 +87,10 @@ def __init__(self, *args, **kwargs): raise ValueError( 'Dimension and size do not match: expected {} got {}' .format(n_elements, len(args))) + + self._range_block_shape = (shape[0], 1) + self._domain_block_shape = (shape[1], 1) + # TODO # until a decent way to check equality of Acquisition/Image geometries # required to fullfil "Operators in a Block are required to have the same @@ -191,41 +195,46 @@ def direct(self, x, out=None): x_b = BlockDataContainer(x) else: x_b = x - shape = self.get_output_shape(x_b.shape) + + if x_b.shape != self._domain_block_shape: + raise ValueError( + 'We expect the input to be a block data container of shape {}'.format( self._domain_block_shape)) + + return_data_container = False + + if self._range_block_shape[0]==1: + return_data_container = True ##TODO: Does this default make sense? if out is None: - res = [] - for row in range(self.shape[0]): - for col in range(self.shape[1]): - if col == 0: - prod = self.get_item(row, col).direct( - x_b.get_item(col)) - else: - prod += self.get_item(row, - col).direct(x_b.get_item(col)) - res.append(prod) - if 1 == shape[0] == shape[1]: - # the output is a single DataContainer, so we can take it out - return res[0] - else: - return BlockDataContainer(*res, shape=shape) - + # allocate the output blockdatacontainer of the correct shape + res = BlockDataContainer(*[self.get_item(row, 0).range_geometry().allocate(None) + for row in range(self.shape[0])], shape=self._range_block_shape) + elif not isinstance(out, BlockDataContainer): + # Handle datacontainers or sirf datacontainers + if self._range_block_shape[0]==1: + res = BlockDataContainer(out) + else: + raise ValueError( + f'Expected `out` to be `None` or a `BlockDataContainer` of shape {self._range_block_shape}') + else: + res = out + return_data_container = False + + for row in range(self.shape[0]): + for col in range(self.shape[1]): + + if col == 0: + self.get_item(row, col).direct(x_b.get_item(col), out=res.get_item(row)) + else: + # temp_out_row points to the element in res that we are adding to + temp_out_row = res.get_item(row) + temp_out_row += self.get_item(row, col).direct(x_b.get_item(col)) + if return_data_container: + return res.get_item(0) else: - tmp = self.range_geometry().allocate() - for row in range(self.shape[0]): - for col in range(self.shape[1]): - if col == 0: - self.get_item(row,col).direct( - x_b.get_item(col), - out=out.get_item(row)) - else: - temp_out_row = out.get_item(row) # temp_out_row points to the element in out that we are adding to - self.get_item(row,col).direct( - x_b.get_item(col), - out=tmp.get_item(row)) - temp_out_row += tmp.get_item(row) - return out + return res + def adjoint(self, x, out=None): '''Adjoint operation for the BlockOperator @@ -243,53 +252,61 @@ def adjoint(self, x, out=None): ''' if not self.is_linear(): raise ValueError('Not all operators in Block are linear.') + + if not isinstance(x, BlockDataContainer): x_b = BlockDataContainer(x) else: x_b = x - shape = self.get_output_shape(x_b.shape, adjoint=True) + + if x_b.shape != self._range_block_shape: + raise ValueError( + 'We expect the input to be a block data container of shape {}'.format( self._range_block_shape)) + + return_data_container = False + + if self._domain_block_shape[0]==1: + return_data_container = True ##TODO: Does this default make sense? + + if out is None: - res = [] - for col in range(self.shape[1]): - for row in range(self.shape[0]): - if row == 0: - prod = self.get_item(row, col).adjoint( - x_b.get_item(row)) - else: - prod += self.get_item(row, - col).adjoint(x_b.get_item(row)) - res.append(prod) - if self.shape[1] == 1: - # the output is a single DataContainer, so we can take it out - return res[0] + # allocate the output blockdatacontainer of the correct shape + res = BlockDataContainer(*[self.get_item(0, col).domain_geometry().allocate(0) + for col in range(self.shape[1])], shape=self._domain_block_shape) + + + elif not isinstance(out, BlockDataContainer): + # Handle datacontainers or sirf datacontainers + if self._domain_block_shape[0]==1: + res = BlockDataContainer(out) + + else: - return BlockDataContainer(*res, shape=shape) + raise ValueError( + f'Expected `out` to be `None` or a `BlockDataContainer` of shape {self._domain_block_shape}') + else: - for col in range(self.shape[1]): - for row in range(self.shape[0]): - if row == 0: - if issubclass(out.__class__, DataContainer) or \ - (has_sirf and issubclass(out.__class__, SIRFDataContainer)): - self.get_item(row, col).adjoint( - x_b.get_item(row), - out=out) - else: - op = self.get_item(row, col) - self.get_item(row, col).adjoint( - x_b.get_item(row), - out=out.get_item(col)) - else: - if issubclass(out.__class__, DataContainer) or \ - (has_sirf and issubclass(out.__class__, SIRFDataContainer)): - out += self.get_item(row, col).adjoint( - x_b.get_item(row)) - else: - - temp_out_col = out.get_item(col) # out_col_operator points to the column in out that we are updating - temp_out_col += self.get_item(row,col).adjoint( - x_b.get_item(row), - ) - return out + res = out + return_data_container = False + + for col in range(self.shape[1]): + for row in range(self.shape[0]): + + if row == 0: + self.get_item(row, col).adjoint( + x_b.get_item(row), + out=res.get_item(col)) + else: + # out_col_operator points to the column in res that we are updating + temp_out_col = res.get_item(col) + temp_out_col += self.get_item(row, col).adjoint( + x_b.get_item(row), + ) + + if return_data_container: + return res.get_item(0) + else: + return res def is_linear(self): '''Returns whether all the elements of the BlockOperator are linear''' diff --git a/Wrappers/Python/test/test_BlockOperator.py b/Wrappers/Python/test/test_BlockOperator.py index 910a711fd..0aae44799 100644 --- a/Wrappers/Python/test/test_BlockOperator.py +++ b/Wrappers/Python/test/test_BlockOperator.py @@ -209,7 +209,66 @@ def test_block_operator_1_1(self): self.assertEqual(K.domain_geometry(), ig) + def test_blockoperator_out_datacontainer(self): + + #test direct + M, N ,W = 3, 4, 5 + ig = ImageGeometry(M, N, W) + operator0=IdentityOperator(ig) + operator1=-IdentityOperator(ig) + K = BlockOperator(operator0, operator1, shape = (1,2)) + bg=BlockGeometry(ig, ig) + data=bg.allocate('random', seed=2) + out=K.range.allocate(0) + assert not isinstance(out, BlockDataContainer) + ans = K.direct(data) + K.direct(data, out) + self.assertNumpyArrayEqual(ans.array, out.array) + + #test direct out is BlockDataContainer + out = BlockDataContainer(out) + assert isinstance(out, BlockDataContainer) + ans = K.direct(data) + K.direct(data, out) + self.assertNumpyArrayEqual(ans.array, out.get_item(0).array) + + #test adjoint wrong dimension + out=ig.allocate(0) + data = ig.allocate('random') + print(K.range_geometry) + with self.assertRaises(ValueError): + K.adjoint(data, out) + + + #test adjoint out not BlockDataContainer + M, N ,W = 3, 4, 5 + operator0=IdentityOperator(ig) + operator1=-IdentityOperator(ig) + K = BlockOperator(operator0, operator1, shape = (2,1)) + bg=BlockGeometry(ig, ig) + data=bg.allocate('random', seed=2) + out=K.domain.allocate(0) + assert not isinstance(out, BlockDataContainer) + ans = K.adjoint(data) + K.adjoint(data, out) + self.assertNumpyArrayEqual(ans.array, out.array) + + #test adjoint out is BlockDataContainer + out = BlockDataContainer(out) + assert isinstance(out, BlockDataContainer) + ans = K.adjoint(data) + K.adjoint(data, out) + self.assertNumpyArrayEqual(ans.array, out.get_item(0).array) + + #test direct wrong dimension + out=ig.allocate(0) + data = ig.allocate('random') + print(K.range_geometry) + with self.assertRaises(ValueError): + K.direct(data, out) + + @unittest.skipIf(True, 'Skipping time tests') def test_timedifference(self):