Skip to content

Commit

Permalink
Require h and v data order at the end
Browse files Browse the repository at this point in the history
  • Loading branch information
hrobarts committed Oct 15, 2024
1 parent db715e6 commit 69e4466
Show file tree
Hide file tree
Showing 2 changed files with 144 additions and 87 deletions.
188 changes: 119 additions & 69 deletions Wrappers/Python/cil/processors/FluxNormaliser.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,9 @@ class FluxNormaliser(Processor):
target: string or float
The value to scale the normalised data with. If float, the data is scaled
to the float value. If string 'mean' the data is scaled to the mean value
of the input flux or flux in the roi, if 'first' the data is scaled to
of the input flux or flux across all rois, if 'first' the data is scaled to
the first input flux value or the flux in the roi of the first projection.
Default is 'mean'
Returns:
--------
Expand Down Expand Up @@ -83,12 +84,7 @@ class FluxNormaliser(Processor):
in the roi dictionary
'''

def __init__(self, flux=None, roi=None, target=None):

if roi is not None and flux is not None:
raise ValueError("Please specify either flux or roi, not both")
if roi is None and flux is None:
raise ValueError("Please specify either flux or roi, found None")
def __init__(self, flux=None, roi=None, target='mean'):

kwargs = {
'flux' : flux,
Expand All @@ -97,22 +93,54 @@ def __init__(self, flux=None, roi=None, target=None):
'roi_axes' : None,
'target' : target,
'target_value' : None,
'v_size' : 1,
'v_axis' : None,
'h_size' : 1,
'h_axis' : None
}
super(FluxNormaliser, self).__init__(**kwargs)

def check_input(self, dataset):

if self.roi is not None and self.flux is not None:
raise ValueError("Please specify either flux or roi, not both")
if self.roi is None and self.flux is None:
raise ValueError("Please specify either flux or roi, found None")

if not (type(dataset), AcquisitionData):
raise TypeError("Expected AcquistionData, found {}"
.format(type(dataset)))

image_axes = 0
if 'vertical' in dataset.dimension_labels:
v_axis = dataset.get_dimension_axis('vertical')
self.v_size = dataset.get_dimension_size('vertical')
image_axes += 1

if 'horizontal' in dataset.dimension_labels:
self.h_axis = dataset.get_dimension_axis('horizontal')
self.h_size = dataset.get_dimension_size('horizontal')
image_axes += 1

if (( self.h_axis is not None) and (self. h_axis < (len(dataset.shape)-image_axes))) or \
((self.v_axis is not None) and self.v_axis < (len(dataset.shape)-image_axes)):
raise ValueError('Projections must be the last two axes of the dataset')

return True


def _calculate_flux(self):
dataset = self.get_input()
# convert flux to float32
'''
Function to calculate flux from a region of interest in the data. If the
flux is already provided as an array, convert the array to float 32 and
check the size matches the number of projections
'''

dataset = self.get_input()
if dataset is None:
raise ValueError('Data not found, please run `set_input(data)`')

# Calculate the flux from the roi in the data
if self.flux is None:

if isinstance(self.roi, dict):
Expand All @@ -128,11 +156,8 @@ def _calculate_flux(self):
if (r != 'horizontal' and r != 'vertical'):
raise ValueError("roi must be 'horizontal' or 'vertical', found '{}'"
.format(str(r)))

dimension_label_list = list(self.roi.keys())

for d in dimension_label_list:
# check the dimension is in the user specified roi

for d in ['horizontal', 'vertical']:
if d in self.roi:
# check indices are ints
if not all(isinstance(i, int) for i in self.roi[d]):
Expand All @@ -147,15 +172,16 @@ def _calculate_flux(self):
ax = dataset.get_dimension_axis(d)
slc[ax] = slice(self.roi[d][0], self.roi[d][1])
axes.append(ax)

# if the dimension is not in roi, average across the whole dimension
# if a projection dimension isn't in the roi, use the whole axis
else:
ax = dataset.get_dimension_axis(d)
axes.append(ax)
self.roi.update({d:(0,dataset.get_dimension_size(d))})
if d in dataset.dimension_labels:
ax = dataset.get_dimension_axis(d)
axes.append(ax)
self.roi.update({d:(0,dataset.get_dimension_size(d))})

self.flux = numpy.mean(dataset.array[tuple(slc)], axis=tuple(axes))

# Warn if the flux is more than 10% of the dataset range
dataset_range = numpy.max(dataset.array, axis=tuple(axes)) - numpy.min(dataset.array, axis=tuple(axes))

if (numpy.mean(self.flux) > dataset.mean()):
Expand All @@ -171,36 +197,41 @@ def _calculate_flux(self):
else:
raise TypeError("roi must be a dictionary, found {}"
.format(str(type(self.roi))))



flux_size = (numpy.shape(self.flux))
if len(flux_size) > 0:
# check this
data_size = numpy.shape(dataset.geometry.angles) # make this also account for channels
if data_size != flux_size:
# convert flux array to float32
self.flux = numpy.array(self.flux, dtype=numpy.float32, ndmin=1)

# check flux array is the right size
flux_size_flat = len(self.flux.ravel())
if flux_size_flat > 1:
data_size_flat = len(dataset.geometry.angles)*dataset.geometry.channels
if data_size_flat != flux_size_flat:
raise ValueError("Flux must be a scalar or array with length \
\n = number of projections, found {} and {}"
.format(flux_size, data_size))
if numpy.any(self.flux==0):
raise ValueError('Flux value can\'t be 0, provide a different flux\
or region of interest with non-zero values')
else:
if self.flux==0:
raise ValueError('Flux value can\'t be 0, provide a different flux\
or region of interest with non-zero values')
.format(flux_size_flat, data_size_flat))

self.flux = numpy.float32(self.flux)
# check if flux array contains 0s
if 0 in self.flux:
raise ValueError('Flux value can\'t be 0, provide a different flux\
or region of interest with non-zero values')


def _calculate_target(self):
'''
Calculate the target value for the normalisation
'''
if isinstance(self.target, float):
self.target_value = self.target
elif isinstance(self.target, str):
if self.target == 'first':
if len(numpy.shape(self.flux)) > 0 :
self.target_value = self.flux[0]
self.target_value = self.flux.flat[0]
else:
self.target_value = self.flux
elif self.target == 'mean':
self.target_value = numpy.mean(self.flux)
self.target_value = numpy.mean(self.flux.ravel())
else:
raise ValueError("Target string not recognised, found {}, expected 'first' or 'mean'"
.format(self.target))
Expand Down Expand Up @@ -231,15 +262,29 @@ def preview_configuration(self, angle=None, channel=None, log=False):
log: bool (optional)
If True, plot the image with a log scale, default is False
'''
self._calculate_flux()
if self.roi_slice is None:
raise ValueError('Preview available with roi, run `processor= FluxNormaliser(roi=roi)` then `set_input(data)`')
else:
self._calculate_flux()
data = self.get_input()

data = self.get_input()

min = numpy.min(data.array[tuple(self.roi_slice)], axis=tuple(self.roi_axes))
max = numpy.max(data.array[tuple(self.roi_slice)], axis=tuple(self.roi_axes))
plt.figure()
if 'channel' in data.dimension_labels:
if channel is None:
channel = int(data.get_dimension_size('channel')/2)
channel_axis = data.get_dimension_axis('channel')
flux_array = self.flux.take(indices=channel, axis=channel_axis)
min = min.take(indices=channel, axis=channel_axis)
max = max.take(indices=channel, axis=channel_axis)
else:
if channel is not None:
raise ValueError("Channel not found")
else:
flux_array = self.flux

plt.figure(figsize=(8,8))
if data.geometry.dimension == '3D':
if angle is None:
if 'angle' in data.dimension_labels:
Expand All @@ -263,13 +308,14 @@ def preview_configuration(self, angle=None, channel=None, log=False):

plt.subplot(212)
if len(data.geometry.angles)==1:
plt.plot(data.geometry.angles, self.flux, '.r', label='Mean')
plt.plot(data.geometry.angles, flux_array, '.r', label='Mean')
plt.plot(data.geometry.angles, min,'.k', label='Minimum')
plt.plot(data.geometry.angles, max,'.k', label='Maximum')
else:
plt.plot(data.geometry.angles, self.flux, 'r', label='Mean')
plt.plot(data.geometry.angles, flux_array, 'r', label='Mean')
plt.plot(data.geometry.angles, min,'--k', label='Minimum')
plt.plot(data.geometry.angles, max,'--k', label='Maximum')

plt.legend()
plt.xlabel('angle')
plt.ylabel('Intensity in roi')
Expand All @@ -286,12 +332,7 @@ def _plot_slice_roi(self, angle_index=None, channel_index=None, log=False, ax=11
data_slice = data

if 'channel' in data.dimension_labels:
if channel_index is None:
channel_index = int(data_slice.get_dimension_size('channel')/2)
data_slice = data_slice.get_slice(channel=channel_index)
else:
if channel_index is not None:
raise ValueError("Channel not found")

if len(data_slice.shape) != 2:
raise ValueError("Data shape not compatible with preview_configuration(), data must have at least two of 'horizontal', 'vertical' and 'angle'")
Expand All @@ -305,13 +346,13 @@ def _plot_slice_roi(self, angle_index=None, channel_index=None, log=False, ax=11
extent[i*2]=min_angle
extent[i*2+1]=max_angle

plt.subplot(ax)
ax1 = plt.subplot(ax)
if log:
im = plt.imshow(numpy.log(data_slice.array), cmap='gray',aspect='equal', origin='lower', extent=extent)
plt.gcf().colorbar(im, ax=plt.gca())
im = ax1.imshow(numpy.log(data_slice.array), cmap='gray',aspect='equal', origin='lower', extent=extent)
plt.gcf().colorbar(im, ax=ax1)
else:
im = plt.imshow(data_slice.array, cmap='gray',aspect='equal', origin='lower', extent=extent)
plt.gcf().colorbar(im, ax=plt.gca())
im = ax1.imshow(data_slice.array, cmap='gray',aspect='equal', origin='lower', extent=extent)
plt.gcf().colorbar(im, ax=ax1)

h = data_slice.dimension_labels[1]
v = data_slice.dimension_labels[0]
Expand All @@ -330,21 +371,21 @@ def _plot_slice_roi(self, angle_index=None, channel_index=None, log=False, ax=11
v_min = self.roi[v][0]
v_max = self.roi[v][1]

plt.plot([h_min, h_max],[v_min, v_min],'--r')
plt.plot([h_min, h_max],[v_max, v_max],'--r')
ax1.plot([h_min, h_max],[v_min, v_min],'--r')
ax1.plot([h_min, h_max],[v_max, v_max],'--r')

plt.plot([h_min, h_min],[v_min, v_max],'--r')
plt.plot([h_max, h_max],[v_min, v_max],'--r')
ax1.plot([h_min, h_min],[v_min, v_max],'--r')
ax1.plot([h_max, h_max],[v_min, v_max],'--r')

title = 'ROI'
if angle_index is not None:
title += ' angle = ' + str(data.geometry.angles[angle_index])
if channel_index is not None:
title += ' channel = ' + str(channel_index)
plt.title(title)
ax1.set_title(title)

plt.xlabel(h)
plt.ylabel(v)
ax1.set_xlabel(h)
ax1.set_ylabel(v)

def process(self, out=None):
self._calculate_flux()
Expand All @@ -354,20 +395,29 @@ def process(self, out=None):
if out is None:
out = data.copy()

flux_size = (numpy.shape(self.flux))
proj_size = self.v_size*self.h_size
num_proj = int(data.array.size / proj_size)

f = self.flux
for i in range(num_proj):
arr_proj = data.array.flat[i*proj_size:(i+1)*proj_size]
if len(self.flux.flat) > 1:
f = self.flux.flat[i]
arr_proj *= self.target_value/f
out.array.flat[i*proj_size:(i+1)*proj_size] = arr_proj


if 'angle' in data.dimension_labels:
proj_axis = data.get_dimension_axis('angle')
slice_proj = [slice(None)]*len(data.shape)
slice_proj[proj_axis] = 0
# if 'angle' in data.dimension_labels:
# proj_axis = data.get_dimension_axis('angle')
# slice_proj = [slice(None)]*len(data.shape)
# slice_proj[proj_axis] = 0

for i in range(len(data.geometry.angles)):
if len(flux_size) > 0:
f = self.flux[i]
slice_proj[proj_axis] = i
out.array[tuple(slice_proj)] = data.array[tuple(slice_proj)]*self.target_value/f
else:
out.array = data.array*self.target_value/f
# for i in range(len(data.geometry.angles)):
# if len(flux_size) > 0:
# f = self.flux[i]
# slice_proj[proj_axis] = i
# out.array[tuple(slice_proj)] = data.array[tuple(slice_proj)]*self.target_value/f
# else:
# out.array = data.array*self.target_value/f

return out
Loading

0 comments on commit 69e4466

Please sign in to comment.