Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

BlockOperator direct and adjoint methods: can pass out as a DataContainer instead of a (1,1) BlockDataContainer where geometry permits #1926

Open
wants to merge 7 commits into
base: master
Choose a base branch
from
159 changes: 88 additions & 71 deletions Wrappers/Python/cil/optimisation/operators/BlockOperator.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,10 @@ def __init__(self, *args, **kwargs):
raise ValueError(
'Dimension and size do not match: expected {} got {}'
.format(n_elements, len(args)))

self._range_block_shape = (shape[0], 1)
self._domain_block_shape = (shape[1], 1)

# TODO
# until a decent way to check equality of Acquisition/Image geometries
# required to fullfil "Operators in a Block are required to have the same
Expand Down Expand Up @@ -191,41 +195,46 @@ def direct(self, x, out=None):
x_b = BlockDataContainer(x)
else:
x_b = x
shape = self.get_output_shape(x_b.shape)

if x_b.shape != self._domain_block_shape:
raise ValueError(
'We expect the input to be a block data container of shape {}'.format( self._domain_block_shape))

return_data_container = False

if self._range_block_shape[0]==1:
return_data_container = True ##TODO: Does this default make sense?

if out is None:
res = []
for row in range(self.shape[0]):
for col in range(self.shape[1]):
if col == 0:
prod = self.get_item(row, col).direct(
x_b.get_item(col))
else:
prod += self.get_item(row,
col).direct(x_b.get_item(col))
res.append(prod)
if 1 == shape[0] == shape[1]:
# the output is a single DataContainer, so we can take it out
return res[0]
else:
return BlockDataContainer(*res, shape=shape)

# allocate the output blockdatacontainer of the correct shape
res = BlockDataContainer(*[self.get_item(row, 0).range_geometry().allocate(None)
for row in range(self.shape[0])], shape=self._range_block_shape)
elif not isinstance(out, BlockDataContainer):
# Handle datacontainers or sirf datacontainers
if self._range_block_shape[0]==1:
res = BlockDataContainer(out)
else:
raise ValueError(
f'Expected `out` to be `None` or a `BlockDataContainer` of shape {self._range_block_shape}')
else:
res = out
return_data_container = False

for row in range(self.shape[0]):
for col in range(self.shape[1]):

if col == 0:
self.get_item(row, col).direct(x_b.get_item(col), out=res.get_item(row))
else:
# temp_out_row points to the element in res that we are adding to
temp_out_row = res.get_item(row)
temp_out_row += self.get_item(row, col).direct(x_b.get_item(col))

if return_data_container:
return res.get_item(0)
else:
tmp = self.range_geometry().allocate()
for row in range(self.shape[0]):
for col in range(self.shape[1]):
if col == 0:
self.get_item(row,col).direct(
x_b.get_item(col),
out=out.get_item(row))
else:
temp_out_row = out.get_item(row) # temp_out_row points to the element in out that we are adding to
self.get_item(row,col).direct(
x_b.get_item(col),
out=tmp.get_item(row))
temp_out_row += tmp.get_item(row)
return out
return res


def adjoint(self, x, out=None):
'''Adjoint operation for the BlockOperator
Expand All @@ -243,53 +252,61 @@ def adjoint(self, x, out=None):
'''
if not self.is_linear():
raise ValueError('Not all operators in Block are linear.')


if not isinstance(x, BlockDataContainer):
x_b = BlockDataContainer(x)
else:
x_b = x
shape = self.get_output_shape(x_b.shape, adjoint=True)

if x_b.shape != self._range_block_shape:
raise ValueError(
'We expect the input to be a block data container of shape {}'.format( self._range_block_shape))

return_data_container = False

if self._domain_block_shape[0]==1:
return_data_container = True ##TODO: Does this default make sense?


if out is None:
res = []
for col in range(self.shape[1]):
for row in range(self.shape[0]):
if row == 0:
prod = self.get_item(row, col).adjoint(
x_b.get_item(row))
else:
prod += self.get_item(row,
col).adjoint(x_b.get_item(row))
res.append(prod)
if self.shape[1] == 1:
# the output is a single DataContainer, so we can take it out
return res[0]
# allocate the output blockdatacontainer of the correct shape
res = BlockDataContainer(*[self.get_item(0, col).domain_geometry().allocate(0)
for col in range(self.shape[1])], shape=self._domain_block_shape)


elif not isinstance(out, BlockDataContainer):
# Handle datacontainers or sirf datacontainers
if self._domain_block_shape[0]==1:
res = BlockDataContainer(out)


else:
return BlockDataContainer(*res, shape=shape)
raise ValueError(
f'Expected `out` to be `None` or a `BlockDataContainer` of shape {self._domain_block_shape}')

else:
for col in range(self.shape[1]):
for row in range(self.shape[0]):
if row == 0:
if issubclass(out.__class__, DataContainer) or \
(has_sirf and issubclass(out.__class__, SIRFDataContainer)):
self.get_item(row, col).adjoint(
x_b.get_item(row),
out=out)
else:
op = self.get_item(row, col)
self.get_item(row, col).adjoint(
x_b.get_item(row),
out=out.get_item(col))
else:
if issubclass(out.__class__, DataContainer) or \
(has_sirf and issubclass(out.__class__, SIRFDataContainer)):
out += self.get_item(row, col).adjoint(
x_b.get_item(row))
else:

temp_out_col = out.get_item(col) # out_col_operator points to the column in out that we are updating
temp_out_col += self.get_item(row,col).adjoint(
x_b.get_item(row),
)
return out
res = out
return_data_container = False

for col in range(self.shape[1]):
for row in range(self.shape[0]):

if row == 0:
self.get_item(row, col).adjoint(
x_b.get_item(row),
out=res.get_item(col))
else:
# out_col_operator points to the column in res that we are updating
temp_out_col = res.get_item(col)
temp_out_col += self.get_item(row, col).adjoint(
x_b.get_item(row),
)

if return_data_container:
return res.get_item(0)
else:
return res

def is_linear(self):
'''Returns whether all the elements of the BlockOperator are linear'''
Expand Down
59 changes: 59 additions & 0 deletions Wrappers/Python/test/test_BlockOperator.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,7 +209,66 @@ def test_block_operator_1_1(self):

self.assertEqual(K.domain_geometry(), ig)

def test_blockoperator_out_datacontainer(self):

#test direct
M, N ,W = 3, 4, 5
ig = ImageGeometry(M, N, W)
operator0=IdentityOperator(ig)
operator1=-IdentityOperator(ig)
K = BlockOperator(operator0, operator1, shape = (1,2))
bg=BlockGeometry(ig, ig)
data=bg.allocate('random', seed=2)
out=K.range.allocate(0)
assert not isinstance(out, BlockDataContainer)
ans = K.direct(data)
K.direct(data, out)
self.assertNumpyArrayEqual(ans.array, out.array)

#test direct out is BlockDataContainer
out = BlockDataContainer(out)
assert isinstance(out, BlockDataContainer)
ans = K.direct(data)
K.direct(data, out)
self.assertNumpyArrayEqual(ans.array, out.get_item(0).array)

#test adjoint wrong dimension
out=ig.allocate(0)
data = ig.allocate('random')
print(K.range_geometry)
with self.assertRaises(ValueError):
K.adjoint(data, out)


#test adjoint out not BlockDataContainer
M, N ,W = 3, 4, 5
operator0=IdentityOperator(ig)
operator1=-IdentityOperator(ig)
K = BlockOperator(operator0, operator1, shape = (2,1))
bg=BlockGeometry(ig, ig)
data=bg.allocate('random', seed=2)
out=K.domain.allocate(0)
assert not isinstance(out, BlockDataContainer)
ans = K.adjoint(data)
K.adjoint(data, out)
self.assertNumpyArrayEqual(ans.array, out.array)

#test adjoint out is BlockDataContainer
out = BlockDataContainer(out)
assert isinstance(out, BlockDataContainer)
ans = K.adjoint(data)
K.adjoint(data, out)
self.assertNumpyArrayEqual(ans.array, out.get_item(0).array)

#test direct wrong dimension
out=ig.allocate(0)
data = ig.allocate('random')
print(K.range_geometry)
with self.assertRaises(ValueError):
K.direct(data, out)




@unittest.skipIf(True, 'Skipping time tests')
def test_timedifference(self):
Expand Down
Loading