This repository contains code for training and testing models for super-resolution and segmentation of wildfire-burned areas, as presented in the manuscript "FireSR: A Dataset for Super-Resolution and Segmentation of Burned Areas", submitted to the NeurIPS 2024 Datasets and Benchmarks Track. The repository includes three main directories:
- FiRes-DDPM: A multitask adaptation of the Image-Super-Resolution-via-Iterative-Refinement (SR3) model, generating both super-resolved images and segmentation masks.
- Image-Super-Resolution-via-Iterative-Refinement: An implementation of the SR3 model for single-image super-resolution. Original repo
- Pytorch-UNet: A U-Net implementation for semantic segmentation tasks. Original repo
Before using the models, make sure the required dependencies are installed by running the following command in each model directory:
```bash
pip install -r requirements.txt
```
To train or test the models, you need to prepare your dataset by tiling the GeoTIFF files into smaller patches. This step is crucial for handling high-resolution imagery efficiently. Follow the steps below to tile your images:
Run the Tiling Script
Use the Python script below to tile the GeoTIFF images. This script divides each image into smaller patches of a specified size (e.g., 128x128 pixels).
```python
import os

import rasterio
from rasterio.windows import Window


def tile_image(image_path, output_dir, tile_size=128):
    """Split one GeoTIFF into tile_size x tile_size patches, keeping georeferencing."""
    stem = os.path.basename(image_path).split('.')[0]
    with rasterio.open(image_path) as src:
        for i in range(0, src.height, tile_size):
            for j in range(0, src.width, tile_size):
                # Clip edge windows so they never extend past the image bounds.
                height = min(tile_size, src.height - i)
                width = min(tile_size, src.width - j)
                window = Window(j, i, width, height)  # (col_off, row_off, width, height)
                transform = src.window_transform(window)
                outpath = os.path.join(output_dir, f"{stem}_{i}_{j}.tif")
                with rasterio.open(
                    outpath, 'w', driver='GTiff',
                    height=height, width=width,
                    count=src.count, dtype=src.dtypes[0],
                    crs=src.crs, transform=transform,
                ) as dst:
                    dst.write(src.read(window=window))


# Example usage
image_dir = 'FireSR/dataset/S2/post/'
output_dir = 'FireSR_tiled/hr_128/'
os.makedirs(output_dir, exist_ok=True)
for image_file in os.listdir(image_dir):
    if image_file.endswith('.tif'):
        tile_image(os.path.join(image_dir, image_file), output_dir)
```
Make sure to adjust the input and output directory paths in the example usage to point to your own directories.
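Because the loop advances in steps of `tile_size`, an image whose height or width is not a multiple of `tile_size` has windows along its bottom and right edges that cover fewer than `tile_size` pixels. The window arithmetic can be checked without rasterio using the minimal sketch below. `tile_windows` and `full_tiles_only` are hypothetical helpers, not part of the repository, and discarding partial edge tiles is only one possible policy, assuming your model expects fixed-size inputs.

```python
from typing import List, Tuple

TileWindow = Tuple[int, int, int, int]  # (row_off, col_off, height, width)


def tile_windows(height: int, width: int, tile_size: int = 128) -> List[TileWindow]:
    """Enumerate windows in the same order as the tiling loop, clipping edge tiles."""
    windows = []
    for i in range(0, height, tile_size):
        for j in range(0, width, tile_size):
            windows.append((i, j, min(tile_size, height - i), min(tile_size, width - j)))
    return windows


def full_tiles_only(windows: List[TileWindow], tile_size: int = 128) -> List[TileWindow]:
    """Keep only windows that are exactly tile_size x tile_size."""
    return [w for w in windows if w[2] == tile_size and w[3] == tile_size]


# A 300 x 260 image yields 3 x 3 windows, of which only the top-left 2 x 2 are full tiles.
windows = tile_windows(300, 260)
print(len(windows), len(full_tiles_only(windows)))  # 9 4
```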
Dataset structure
Your dataset should be structured in the following format:
```text
FireSR_train/
│
├── Daymet_128/
├── LULC_128/
├── hr_128/
├── hr_mask_128/    # Only for FiRes-DDPM
├── lr_16/
├── pre_fire_128/
└── sr_16_128/
```
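Before training, it can help to confirm that every expected subfolder exists and that each tile appears in all of them. The sketch below is a hypothetical helper, not part of the repository; it takes the folder names from the layout above and assumes that corresponding tiles share the same filename across subfolders, which you should verify for your own data.

```python
import os
from typing import Dict, List

# Subfolder names from the layout above; hr_mask_128 is only needed for FiRes-DDPM.
COMMON_DIRS = ["Daymet_128", "LULC_128", "hr_128", "lr_16", "pre_fire_128", "sr_16_128"]


def check_dataset(root: str, with_masks: bool = False) -> Dict[str, List[str]]:
    """Report missing subfolders and filenames absent from at least one subfolder.

    Assumes (hypothetically) that matching tiles use identical filenames in every subfolder.
    """
    expected = COMMON_DIRS + (["hr_mask_128"] if with_masks else [])
    missing_dirs = [d for d in expected if not os.path.isdir(os.path.join(root, d))]
    present = [d for d in expected if d not in missing_dirs]
    names = {d: set(os.listdir(os.path.join(root, d))) for d in present}
    all_names = set().union(*names.values())
    # A filename is unmatched if any present subfolder lacks it.
    unmatched = sorted(n for n in all_names if any(n not in s for s in names.values()))
    return {"missing_dirs": missing_dirs, "unmatched_files": unmatched}
```

For example, `check_dataset("FireSR_train", with_masks=True)` checks a FiRes-DDPM training set, and `with_masks=False` skips the mask folder for the other two models.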
Navigate to the corresponding model directory.
Follow the instructions in the model's README to set up the configuration files and initiate training:
```bash
python -p train -c config/train_S2_MODIS.json
```
Navigate to the corresponding model directory.
Edit the dataset configuration in the JSON files to match your data paths and resolutions.
Start training:
```bash
python -p train -c config/train_S2_MODIS.json
```
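Instead of editing the JSON by hand, the data paths can be updated with a short script. This is a hedged sketch that assumes the FireSR configs follow the upstream SR3 layout (a `datasets` object with `train` and `val` entries holding `dataroot`, `l_resolution` and `r_resolution` keys) and that, like upstream SR3 configs, they may contain `//` line comments. The helper names are hypothetical; check the keys against your own config file.

```python
import json
from pathlib import Path


def load_commented_json(path: Path) -> dict:
    """Parse a JSON config that may contain `//` line comments."""
    # Drop everything after "//" on each line; this also truncates any value containing "//".
    text = "\n".join(line.split("//")[0] for line in path.read_text().splitlines())
    return json.loads(text)


def set_dataset_paths(config_path: str, dataroot: str, phase: str = "train",
                      l_resolution: int = 16, r_resolution: int = 128) -> dict:
    """Point one dataset entry of an SR3-style config at a new data root.

    Assumes config["datasets"][phase] exists with "dataroot", "l_resolution"
    and "r_resolution" keys (upstream SR3 layout, not confirmed for FireSR).
    """
    path = Path(config_path)
    config = load_commented_json(path)
    entry = config["datasets"][phase]
    entry["dataroot"] = dataroot
    entry["l_resolution"] = l_resolution
    entry["r_resolution"] = r_resolution
    # Writing the file back removes any comments it contained.
    path.write_text(json.dumps(config, indent=4))
    return config
```

For example, `set_dataset_paths("config/train_S2_MODIS.json", "FireSR_train")` would update the training entry, and `phase="val"` would target the validation entry, assuming those key names.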
Navigate to the corresponding model directory.
To train the model, run:
```bash
python --epochs 5 --batch-size 16 --learning-rate 0.001 --amp
```
Navigate to the corresponding model directory.
Follow the instructions in the model's README to set up the configuration files and run testing:
```bash
python -p val -c config/test_S2_MODIS.json
```
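FiRes-DDPM and the U-Net both produce burned-area masks, which can be compared against the `hr_mask_128` tiles. Intersection over union (IoU) is a common metric for this; the sketch below is a generic implementation, not the repository's evaluation code. The 0.5 binarisation threshold and the convention that two empty masks score 1.0 are assumptions, and loading the tiles (for example with rasterio) is left out.

```python
import numpy as np


def binary_iou(pred: np.ndarray, target: np.ndarray, threshold: float = 0.5) -> float:
    """IoU of the burned class for two same-shaped masks.

    Values above `threshold` are treated as burned (assumed convention).
    """
    p = pred > threshold
    t = target > threshold
    union = np.logical_or(p, t).sum()
    if union == 0:
        # Neither mask contains burned pixels: count this as perfect agreement.
        return 1.0
    return float(np.logical_and(p, t).sum() / union)
```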
Navigate to the corresponding model directory.
Edit the dataset configuration in the JSON files to match your data paths and resolutions.
Start testing:
```bash
python -p val -c config/test_S2_MODIS.json
```
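Super-resolved outputs are typically compared against the `hr_128` tiles with peak signal-to-noise ratio, $\mathrm{PSNR} = 10 \log_{10}(R^2 / \mathrm{MSE})$, where $R$ is the data range. The sketch below is a generic PSNR computation, not the repository's own metric code. `data_range` is an assumption you must set to the actual value range of your imagery (Sentinel-2 reflectance values, for example, are not 8-bit).

```python
import numpy as np


def psnr(sr: np.ndarray, hr: np.ndarray, data_range: float) -> float:
    """Peak signal-to-noise ratio in dB between a super-resolved tile and its HR reference."""
    mse = np.mean((sr.astype(np.float64) - hr.astype(np.float64)) ** 2)
    if mse == 0:
        # Identical images: PSNR is unbounded.
        return float("inf")
    return float(10.0 * np.log10(data_range ** 2 / mse))
```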
Navigate to the corresponding model directory.
To test the model, run:
```bash
python -i1 /path/to/post/imgs -i2 /path/to/pre/imgs -o /path/to/output/dir
```
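The command takes a post-fire image directory (`-i1`) and a pre-fire image directory (`-i2`), so each post-fire tile presumably needs a pre-fire counterpart. How the script matches the two is not documented here; the hypothetical helper below assumes matching tiles share a filename and can be used to spot missing partners before running inference.

```python
import os
from typing import List, Tuple


def pair_post_pre(post_dir: str, pre_dir: str,
                  ext: str = ".tif") -> Tuple[List[Tuple[str, str]], List[str]]:
    """Match post-fire tiles to pre-fire tiles with the same filename (assumed convention).

    Returns the matched (post_path, pre_path) pairs and the post-fire filenames
    that have no pre-fire partner.
    """
    post = sorted(f for f in os.listdir(post_dir) if f.endswith(ext))
    pre = {f for f in os.listdir(pre_dir) if f.endswith(ext)}
    pairs = [(os.path.join(post_dir, f), os.path.join(pre_dir, f)) for f in post if f in pre]
    unmatched = [f for f in post if f not in pre]
    return pairs, unmatched
```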
Pre-trained model weights can be downloaded from the following link:
This repository is licensed under the MIT License.
For any questions or further information, please contact:
- Name: Eric Brune
- Email: [email protected]