Vt 514 dask usage via karabo is pit of success (#516)
* Made Karabo a "pit of success"; added a parameter indicating whether running inside a container, used for Dask client creation.

* Fixed missing type for mypy.

* Fixed line too long and improved description.

* Imported at the beginning.

* Fixed failing tests

* Added test for dask.

* Added test for dask

* Improved documentation.
kenfus authored Sep 27, 2023
1 parent 19c379b commit eaaeb53
Showing 9 changed files with 417 additions and 151 deletions.
61 changes: 61 additions & 0 deletions doc/src/examples/example_structure.md
@@ -17,6 +17,67 @@ Please look at the karabo.package documentation for specifics on the individual

![Image](../images/telescope.png)

## Parallel processing with Karabo

Karabo streamlines the setup of an environment for parallel processing. Through its utility function `parallelize_with_dask`, Karabo nudges the user towards a seamless parallelization experience: by adhering to its interface, users find themselves in a "pit of success" with parallel processing. This ensures efficient task distribution across multiple cores or even entire cluster nodes, especially when handling large datasets or computationally demanding tasks.

### Points to Consider When Using `parallelize_with_dask` and Dask in General

When leveraging the `parallelize_with_dask` function for parallel processing in Karabo, users should keep in mind the following best practices:

1. **Avoid Infinite Tasks**: Ensure that the tasks you're parallelizing have a defined end. Infinite or extremely long-running tasks can clog the parallelization pipeline.

2. **Beware of Massive Tasks**: Large tasks can monopolize resources, potentially causing an imbalance in the workload distribution. It's often more efficient to break massive tasks into smaller, more manageable chunks.

3. **No Open h5 Connections**: Objects holding open h5 (HDF5) connections are not picklable, which means they cannot be serialized and sent to other processes. If you need to pass such an object to a function, close the connection first, e.g. by calling `h5file.close()` or, for Karabo objects, `.compute()` (which loads the data into memory and closes the underlying file).

4. **Use `.compute()` on Dask Arrays**: Before passing Dask arrays to the function, call `.compute()` on them to realize their values. This avoids potential issues and ensures efficient processing (see the sketch below this list).

5. **Refer to Dask's Best Practices**: For a more comprehensive understanding and to avoid common pitfalls, consult [Dask's official best practices guide](https://docs.dask.org/en/stable/best-practices.html).

Following these guidelines will help ensure that you get the most out of Karabo's parallel processing capabilities.
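
For illustration, the following sketch applies points 3 and 4: the data is read out of an HDF5 file and the handle is closed, and a lazy Dask array is realized with `.compute()`, before anything is handed to `parallelize_with_dask`. The file name `sources.h5`, the dataset name `flux`, and the `scale_flux` task are hypothetical and serve only as an example.

```python
import dask.array as da
import h5py

from karabo.util.dask import parallelize_with_dask

# Point 3: read what you need from the HDF5 file and close the handle, so the
# objects sent to the workers stay picklable ("sources.h5" and "flux" are
# hypothetical names).
with h5py.File("sources.h5", "r") as h5file:
    fluxes = h5file["flux"][:]  # copy the dataset into memory before the handle closes

# Point 4: realize lazy Dask arrays with .compute() before passing them along.
weights = da.ones_like(fluxes).compute()


def scale_flux(flux, weights):
    # Hypothetical per-element task; `flux` is the current element of the iterable.
    return flux * weights.mean()


results = parallelize_with_dask(scale_flux, fluxes, weights=weights)
```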


### Parameters
- `iterate_function` (Callable): The function applied to each element of the iterable. It should take the current element of the iterable as its first argument, followed by any specified positional and keyword arguments.

- `iterable` (Iterable): The collection of elements over which `iterate_function` is applied.

- `args` (tuple): Positional arguments passed to `iterate_function` after the current element of the iterable.

- `kwargs` (dict): Keyword arguments passed to `iterate_function`.

### Returns
- tuple: A tuple containing the results of `iterate_function` for each element in the iterable. Results are gathered using Dask's `compute` function.

### Additional Notes
When working on a SLURM cluster, it is important to call `DaskHandler.setup()` at the beginning of your script.

If `verbose` is passed in `kwargs` and set to `True`, progress messages are printed during processing.

Internally, the function uses Dask's distributed scheduler.

Leverage the `parallelize_with_dask` utility in Karabo to harness the power of parallel processing and speed up your data-intensive operations.
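
As a minimal sketch of these notes (the `analyse` task below is hypothetical), `DaskHandler.setup()` is called first, as recommended on a SLURM cluster, and `verbose=True` is passed to request progress messages:

```python
from karabo.util.dask import DaskHandler, parallelize_with_dask

# On a SLURM cluster, set up the Dask environment once, at the start of the script.
DaskHandler.setup()


def analyse(item, verbose=False):
    # Hypothetical per-element task; it accepts `verbose` in case the flag is
    # forwarded to the function together with the other keyword arguments.
    return item**2


# verbose=True requests progress messages during processing.
results = parallelize_with_dask(analyse, range(100), verbose=True)
print(results)
```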

### Function Signature

```python
from typing import Any, Callable, Iterable, List, Tuple, Union


def parallelize_with_dask(
    iterate_function: Callable[..., Any],
    iterable: Iterable[Any],
    *args: Any,
    **kwargs: Any,
) -> Union[Any, Tuple[Any, ...], List[Any]]:
    ...


# Example
def my_function(element, *args, **kwargs):
    # Do something with the current element of the iterable
    result = ...
    return result


# The current element of the iterable is passed as the first argument to my_function
parallelize_with_dask(my_function, my_iterable, *args, **kwargs)
>>> (result1, result2, result3, ...)
```

## Use Karabo on a SLURM cluster

Karabo manages all available nodes through Dask, making the computational power conveniently accessible to the user. The `DaskHandler` class streamlines the creation of a Dask client and offers a user-friendly interface for interaction. It contains static variables which, when altered, modify the behavior of the Dask client.
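
As a rough sketch of this pattern, a static variable is changed before the client is requested. The attribute name `n_workers_scheduler_node` is an assumption used purely for illustration; consult the `DaskHandler` documentation for the variables available in your Karabo version.

```python
from karabo.util.dask import DaskHandler

# Adjust class-level (static) variables before any client exists; the attribute
# name below is hypothetical and only illustrates the pattern.
DaskHandler.n_workers_scheduler_node = 1

# Create (or retrieve) the Dask client managed by Karabo.
client = DaskHandler.get_dask_client()
print(client)
```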
1 change: 1 addition & 0 deletions docker/dev/Dockerfile
@@ -1,6 +1,7 @@
# LEGACY-FILE, has to be checked before usage
# Create build container to not have copied filed in real container afterwards
FROM --platform=amd64 continuumio/miniconda3:4.12.0 as build
ARG IS_DOCKER_CONTAINER=true
COPY environment.yaml environment.yaml
COPY requirements.txt requirements.txt

1 change: 1 addition & 0 deletions docker/user/Dockerfile
@@ -1,6 +1,7 @@
# Create build container to not have copied filed in real container afterwards
FROM nvidia/cuda:11.7.1-cudnn8-devel-ubuntu22.04 as build
ARG KARABO_TAG
ARG IS_DOCKER_CONTAINER=true
RUN apt-get update && apt-get install -y git
RUN git clone --branch ${KARABO_TAG} --depth=1 https://github.com/i4Ds/Karabo-Pipeline.git

2 changes: 2 additions & 0 deletions environment.yaml
@@ -10,6 +10,8 @@ dependencies:
- bluebild
- cuda-cudart
- dask=2022.12.1
- dask-mpi
- mpi4py
- distributed
- eidos=1.1.0
- healpy
266 changes: 193 additions & 73 deletions karabo/examples/HIIM_Img_Recovery.ipynb

Large diffs are not rendered by default.

101 changes: 58 additions & 43 deletions karabo/simulation/line_emission.py
@@ -13,10 +13,6 @@
from astropy.convolution import Gaussian2DKernel
from astropy.io import fits
from astropy.wcs import WCS

# from dask.delayed import Delayed
from dask import compute, delayed # type: ignore[attr-defined]
from dask.distributed import Client
from numpy.typing import NDArray

from karabo.imaging.imager import Imager
@@ -26,7 +22,9 @@
from karabo.simulation.telescope import Telescope
from karabo.simulation.visibility import Visibility
from karabo.util._types import DirPathType, FilePathType, IntFloat, NPFloatLikeStrict
from karabo.util.dask import DaskHandler

# from dask.delayed import Delayed
from karabo.util.dask import parallelize_with_dask
from karabo.util.plotting_util import get_slices


@@ -226,7 +224,7 @@ def plot_scatter_recon(
plt.savefig(outfile)


def sky_slice(sky: SkyModel, z_min: np.float_, z_max: np.float_) -> SkyModel:
def sky_slice(sky: SkyModel, z_min: IntFloat, z_max: IntFloat) -> SkyModel:
"""
Extracting a slice from the sky which includes only sources between redshift z_min
and z_max.
@@ -240,15 +238,7 @@ def sky_slice(sky: SkyModel, z_min: np.float_, z_max: np.float_) -> SkyModel:
:return: Sky model only including the sources with redshifts between z_min and
z_max.
"""
sky_bin = SkyModel.copy_sky(sky)
if sky_bin.sources is None:
raise TypeError("`sky.sources` is None which is not allowed.")

z_obs = sky_bin.sources[:, 13]
sky_bin_idx = np.where((z_obs > z_min) & (z_obs < z_max))
sky_bin.sources = sky_bin.sources[sky_bin_idx]

return sky_bin
return sky.filter_by_column(13, z_min, z_max)


T = TypeVar("T", NDArray[np.float_], xr.DataArray, IntFloat)
@@ -462,8 +452,6 @@ def run_one_channel_simulation(
path: FilePathType,
sky: SkyModel,
telescope: Telescope,
z_min: np.float_,
z_max: np.float_,
freq_bin_start: float,
freq_bin_width: float,
ra_deg: IntFloat,
@@ -488,8 +476,6 @@
of each source.
:param telescope: Telescope used. If None, the MEERKAT telescope will be used as a
default.
:param z_min: Smallest redshift in this bin.
:param z_max: Largest redshift in this bin.
:param freq_bin_start: Starting frequency in this bin
(i.e., largest frequency in the bin).
:param freq_bin_width: Size of the sky frequency bin which is simulated.
@@ -513,15 +499,13 @@
:return: Reconstruction of one bin slice of the sky and its header.
"""

sky_bin = sky_slice(sky, z_min, z_max)

if verbose:
print("Starting simulation...")

freq_bin_middle = freq_bin_start - freq_bin_width / 2
dirty_image, header = karabo_reconstruction(
path,
sky=sky_bin,
sky=sky,
telescope=telescope,
ra_deg=ra_deg,
dec_deg=dec_deg,
@@ -594,7 +578,6 @@ def line_emission_pointing(
img_size: int = 4096,
circle: bool = True,
rascil: bool = True,
client: Optional[Client] = None,
verbose: bool = False,
) -> Tuple[NDArray[np.float_], List[NDArray[np.float_]], fits.header.Header, np.float_]:
"""
@@ -629,7 +612,6 @@
:param circle: If set to True, the pointing has a round shape of size cut.
:param rascil: If True we use the Imager Rascil otherwise the Imager from Oskar is
used.
:param client: Setting a dask client is optional.
:param verbose: If True you get more print statements.
:return: Total line emission reconstruction, 3D line emission reconstruction,
Header of reconstruction and mean frequency.
@@ -647,18 +629,12 @@

os.makedirs(outpath)

# Load sky into memory and close connection to h5
sky.compute()

if sky.sources is None:
raise TypeError(
"`sources` None is not allowed! Please set them in"
" the `SkyModel` before calling this function."
)

if not client:
client = DaskHandler.get_dask_client()

redshift_channel, freq_channel, freq_bin, freq_mid = freq_channels(
z_obs=sky.sources[:, 13],
channel_num=num_bins,
@@ -669,20 +645,39 @@
n_jobs = num_bins
print(f"Submitting {n_jobs} jobs to the cluster.")

delayed_results = []
# Load the sky into memory
sky.compute()

for bin_idx in range(num_bins):
if verbose:
print(
f"Channel {bin_idx} is being processed...\n"
"Extracting the corresponding frequency slice from the sky model..."
)
delayed_ = delayed(run_one_channel_simulation)(
# Helper function to parallelize with Dask
def process_channel( # type: ignore[no-untyped-def]
bin_idx,
outpath,
sky,
telescope,
redshift_channel,
freq_channel,
ra_deg,
dec_deg,
beam_type,
gaussian_fwhm,
gaussian_ref_freq,
start_time,
obs_length,
cut,
img_size,
circle,
rascil,
verbose,
):
# Do the sky slicing here, so that less data is sent to each worker
z_min = redshift_channel[bin_idx]
z_max = redshift_channel[bin_idx + 1]
sky_bin = sky_slice(sky, z_min, z_max)

return run_one_channel_simulation(
path=outpath / f"slice_{bin_idx}",
sky=sky,
sky=sky_bin,
telescope=telescope,
z_min=redshift_channel[bin_idx],
z_max=redshift_channel[bin_idx + 1],
freq_bin_start=freq_channel[bin_idx],
freq_bin_width=freq_bin[bin_idx],
ra_deg=ra_deg,
Expand All @@ -698,9 +693,29 @@ def line_emission_pointing(
rascil=rascil,
verbose=verbose,
)
delayed_results.append(delayed_)

result = compute(*delayed_results, scheduler="distributed")
result = parallelize_with_dask(
process_channel,
range(num_bins),
outpath=outpath,
sky=sky,
telescope=telescope,
redshift_channel=redshift_channel,
freq_channel=freq_channel,
ra_deg=ra_deg,
dec_deg=dec_deg,
beam_type=beam_type,
gaussian_fwhm=gaussian_fwhm,
gaussian_ref_freq=gaussian_ref_freq,
start_time=start_time,
obs_length=obs_length,
cut=cut,
img_size=img_size,
circle=circle,
rascil=rascil,
verbose=verbose,
)

dirty_images = [x[0] for x in result]
headers = [x[1] for x in result]
header = headers[0]
53 changes: 19 additions & 34 deletions karabo/simulation/sky_model.py
@@ -618,29 +618,29 @@ def filter_by_radius_euclidean_flat_approximation(
else:
return copied_sky

def filter_by_flux(
def filter_by_column(
self,
min_flux_jy: IntFloat,
max_flux_jy: IntFloat,
col_idx: int,
min_val: IntFloat,
max_val: IntFloat,
) -> SkyModel:
"""
Filters the sky using the Stokes-I-flux
Values outside the range are removed
Filters the sky based on a specific column index
:param min_flux_jy: Minimum flux in Jy
:param max_flux_jy: Maximum flux in Jy
:param col_idx: Column index to filter by
:param min_val: Minimum value for the column
:param max_val: Maximum value for the column
:return sky: Filtered copy of the sky
"""
copied_sky = SkyModel.copy_sky(self)
if copied_sky.sources is None:
raise KaraboSkyModelError(
"`sources` None is not allowed. "
+ "Add sources before calling `filter_by_flux`."
"`sources` is None, add sources before filtering."
)

# Create mask
filter_mask = (copied_sky[:, 2] >= min_flux_jy) & (
copied_sky[:, 2] <= max_flux_jy
filter_mask = (copied_sky.sources[:, col_idx] >= min_val) & (
copied_sky.sources[:, col_idx] <= max_val
)
filter_mask = self.rechunk_array_based_on_self(filter_mask)

@@ -649,34 +649,19 @@ def filter_by_flux(

return copied_sky

def filter_by_flux(
self,
min_flux_jy: IntFloat,
max_flux_jy: IntFloat,
) -> SkyModel:
return self.filter_by_column(2, min_flux_jy, max_flux_jy)

def filter_by_frequency(
self,
min_freq: IntFloat,
max_freq: IntFloat,
) -> SkyModel:
"""
Filters the sky using the reference frequency in Hz
:param min_freq: Minimum frequency in Hz
:param max_freq: Maximum frequency in Hz
:return sky: Filtered copy of the sky
"""
copied_sky = SkyModel.copy_sky(self)
if copied_sky.sources is None:
raise KaraboSkyModelError(
"`sources` is None, add sources before calling `filter_by_frequency`."
)

# Create mask
filter_mask = (copied_sky.sources[:, 6] >= min_freq) & (
copied_sky.sources[:, 6] <= max_freq
)
filter_mask = self.rechunk_array_based_on_self(filter_mask)

# Apply the filter mask and drop the unmatched rows
copied_sky.sources = copied_sky.sources.where(filter_mask, drop=True)

return copied_sky
return self.filter_by_column(6, min_freq, max_freq)

def get_wcs(self) -> WCS:
"""