Stream selected channels #128

Open
robmarkcole opened this issue May 13, 2024 · 0 comments
Labels
enhancement New feature or request help wanted Extra attention is needed

Comments

@robmarkcole
Contributor

robmarkcole commented May 13, 2024

🚀 Feature

Streaming subsets of channels

Motivation

My GeoTIFF data is typically multispectral, and I run experiments using subsets of the channels. I would like to stream only the required channels in order to save bandwidth.

Pitch

Select, e.g., channels 1, 3 and 5 to stream.

Alternatives

I could store each channel as a separate file and then access only those I require.
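The workaround above can be sketched as follows. It assumes each band of a sample is written as its own object under a hypothetical `<sample>/band_<i>.tif` key. A plain dict stands in for S3 so the example runs offline, and `write_per_band` and `fetch_bands` are hypothetical helpers, not part of any library. Band numbers in the sketch are zero-based.

```python
from typing import Dict, List, Tuple

# Hypothetical in-memory stand-in for an object store such as S3:
# each key holds the encoded bytes of ONE band of one sample.
Store = Dict[str, bytes]


def band_key(sample: str, band: int) -> str:
    """Hypothetical key scheme: one object per (sample, band)."""
    return f"{sample}/band_{band}.tif"


def write_per_band(store: Store, sample: str, bands: List[bytes]) -> None:
    """Store every band of a sample as a separate object."""
    for i, payload in enumerate(bands):
        store[band_key(sample, i)] = payload


def fetch_bands(store: Store, sample: str, wanted: List[int]) -> Tuple[List[bytes], int]:
    """Fetch only the requested bands and report how many bytes were transferred."""
    payloads = [store[band_key(sample, b)] for b in wanted]
    return payloads, sum(len(p) for p in payloads)


if __name__ == "__main__":
    store: Store = {}
    # Six fake bands of 1000 bytes each (placeholders for encoded GeoTIFF bytes)
    write_per_band(store, "tile_0001", [bytes([i]) * 1000 for i in range(6)])
    payloads, transferred = fetch_bands(store, "tile_0001", [1, 3, 5])
    print(transferred)  # 3000 (out of 6000 bytes stored for this sample)
```

Splitting bands this way trades bandwidth for more objects and more requests per sample. As a result, the per-request overhead of the store starts to matter.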

Additional context

The equivalent using Rasterio:

import rasterio

# URL of a raster file in an S3 bucket
s3_url = 's3://your-bucket-name/path-to-your-raster-file.tif'

# Open the raster file
with rasterio.open(s3_url) as src:
    # Read a single band; rasterio band indexes are 1-based
    band1 = src.read(1)  # 2D array (rows, cols)

    # Read several specific bands by passing a list of indexes
    band1, band3 = src.read([1, 3])  # 3D array (2, rows, cols), unpacked along the band axis

    # Process or analyze the bands as needed
    print(band1, band3)
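Whether reading a subset of bands actually saves bandwidth depends on the file layout. With pixel interleaving (GDAL's `INTERLEAVE=PIXEL` creation option for GeoTIFF), the values of all bands are stored together inside each tile or strip, so fetching one band still pulls in the bytes of the others. With band interleaving (`INTERLEAVE=BAND`), each band has its own tiles or strips, so a reader that uses range requests can, in principle, fetch only the requested bands.

The sketch below makes a simplifying assumption: an uncompressed, band-sequential layout, whereas real GeoTIFFs are usually tiled, often compressed, and take their offsets from the file header. Under that assumption, it computes the byte ranges a reader would request for a subset of bands and merges adjacent bands into one range. `band_byte_ranges` is a hypothetical helper.

```python
from typing import List, Tuple


def band_byte_ranges(
    band_indices: List[int],
    n_bands: int,
    rows: int,
    cols: int,
    bytes_per_sample: int,
    data_offset: int = 0,
) -> List[Tuple[int, int]]:
    """Return inclusive (start, end) byte ranges covering the requested bands.

    Assumes a simplified, uncompressed band-sequential layout: all of band 0,
    then all of band 1, and so on, starting at ``data_offset``.
    Band indices are zero-based. Adjacent bands are merged into one range,
    as an HTTP ``Range: bytes=start-end`` request would use.
    """
    band_size = rows * cols * bytes_per_sample
    ranges: List[Tuple[int, int]] = []
    for b in sorted(set(band_indices)):
        if not 0 <= b < n_bands:
            raise IndexError(f"band {b} out of range for {n_bands} bands")
        start = data_offset + b * band_size
        end = start + band_size - 1  # HTTP byte ranges are inclusive
        if ranges and ranges[-1][1] + 1 == start:
            ranges[-1] = (ranges[-1][0], end)  # contiguous with the previous band: extend it
        else:
            ranges.append((start, end))
    return ranges


if __name__ == "__main__":
    # 6 bands of 256 x 256 float32 pixels; select zero-based bands 0, 2 and 3
    for start, end in band_byte_ranges([0, 2, 3], 6, 256, 256, 4):
        print(f"Range: bytes={start}-{end}")
    # Range: bytes=0-262143
    # Range: bytes=524288-1048575
```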

My current solution:

from typing import Callable, List, Optional

import torch
from litdata import StreamingDataset  # assumption: the StreamingDataset used here is litdata's
from rasterio.io import MemoryFile


class SegmentationStreamingDataset(StreamingDataset):
    """
    Segmentation dataset with streaming.

    Args:
        input_dir (str): Local directory or S3 location of the dataset.
        transforms (Optional[Callable]): A transform that takes a sample dict and returns a transformed version.
        band_indices (Optional[List[int]]): Zero-based indices of the bands to keep from each image.
    """

    def __init__(self, *args, transforms: Optional[Callable] = None, band_indices: Optional[List[int]] = None, **kwargs):
        super().__init__(*args, **kwargs)
        self.transforms = transforms
        self.band_indices = band_indices

    def __getitem__(self, index) -> dict:
        data = super().__getitem__(index)
        image_name = data["name"]
        image = data["image"]  # serialized GeoTIFF, decoded below via MemoryFile
        mask = data["mask"]

        # Every band has already been streamed at this point; all bands are read,
        # and the unwanted ones are only discarded afterwards
        with MemoryFile(image) as memfile:
            with memfile.open() as dataset:
                image = torch.from_numpy(dataset.read()).float()
                if self.band_indices:
                    image = image[self.band_indices]

        with MemoryFile(mask) as memfile:
            with memfile.open() as dataset:
                mask = torch.from_numpy(dataset.read()).long()

        sample = {"image": image, "mask": mask, "image_name": image_name}
        if self.transforms is not None:
            sample = self.transforms(sample)
        return sample
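In the current solution, the complete GeoTIFF for each sample is streamed, and every band is read into a tensor before `image[self.band_indices]` drops the unused ones. This shrinks the tensor but does not save bandwidth. One possible refinement, assuming the same `MemoryFile` setup, is to pass the selection to rasterio's `read(indexes)` so that only the requested bands are returned. How much decoding work this saves depends on the file's interleaving.

Rasterio numbers bands from 1, while `band_indices` here index the tensor from 0. The hypothetical helper `to_rasterio_indexes` below converts and validates them. Unlike tensor indexing, it rejects negative indices. It is plain Python, so it runs without rasterio installed.

```python
from typing import List, Optional, Sequence


def to_rasterio_indexes(band_indices: Optional[Sequence[int]], band_count: int) -> Optional[List[int]]:
    """Convert zero-based tensor band indices to rasterio's one-based band indexes.

    Returns None when no subset is requested (None or empty), mirroring the
    `if self.band_indices:` check; `dataset.read(None)` reads every band.
    """
    if not band_indices:
        return None
    indexes = []
    for i in band_indices:
        if not 0 <= i < band_count:
            raise IndexError(f"band index {i} out of range for {band_count} bands")
        indexes.append(i + 1)  # rasterio band numbering starts at 1
    return indexes


# Hypothetical use inside __getitem__ (requires rasterio and torch):
#     with memfile.open() as dataset:
#         indexes = to_rasterio_indexes(self.band_indices, dataset.count)
#         image = torch.from_numpy(dataset.read(indexes)).float()
```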
@robmarkcole robmarkcole added enhancement New feature or request help wanted Extra attention is needed labels May 13, 2024