diff --git a/ush/SpatialTemporalStatsTool/README.md b/ush/SpatialTemporalStatsTool/README.md
index e890f4a..40b7d02 100644
--- a/ush/SpatialTemporalStatsTool/README.md
+++ b/ush/SpatialTemporalStatsTool/README.md
@@ -1,18 +1,23 @@
-### April 2024
+### November 2024
 ### Azadeh Gholoubi
-# Python Tool for Time/Space (2D) Evaluation
+# Spatial and Temporal Analysis Tool for Satellite Observation Data
 ## Overview
-This tool provides functionalities for processing and analyzing data over time and space.
+**Purpose**: This tool performs spatial and temporal analysis of satellite observation data, allowing users to create customizable grids, filter data by time and region, and generate statistical and summary plots.
 
-The `SpatialTemporalStats` class is designed to perform spatial and temporal statistics of data stored in NetCDF files. It includes features for generating grids, reading observational values, filtering data, plotting observations, and creating summary plots based on user settings.
+### Key Functionalities:
+- Grid-based Data Summaries: Creates spatial grids for data aggregation.
+- Data Filtering: Processes data across specified time frames and geographical regions.
+- Visualization: Generates evaluation plots for different data attributes and regions.
+
+The `SpatialTemporalStats` class is central to this tool, with methods for creating grids, reading observational data, filtering, plotting, and producing summary statistics.
 
 ### Important Methods of the SpatialTemporalStats Class
-- `generate_grid(resolution=1)`: Generates a grid for spatial averaging based on the specified resolution. (default resolution is 1X1)
-- `read_obs_values()`: Reads observational values from NetCDF files, filters them based on various criteria, performs spatial averaging, and returns the averaged values.
-- `plot_obs()`: Plots observational data on a map, showing different regions and their corresponding data values.
-- `list_variable_names(file_path)`: Lists variable names from a NetCDF file.
-- `make_summary_plots()`: Generates summary plots of observational data, including scatter plots of counts, means, and standard deviations.
+- `generate_grid(resolution=1)`: Generates a spatial grid with the specified resolution (default: 1x1 degree).
+- `read_obs_values()`: Reads and filters observational data from NetCDF files, performs spatial averaging, and returns the averaged values.
+- `plot_obs()`: Plots observation data on a map, with options for different regions and grid sizes.
+- `list_variable_names(file_path)`: Lists the variable names available in a specified NetCDF file.
+- `make_summary_plots()`: Generates scatter plots of counts, means, and standard deviations of observational data.
 
 ## Requirements
 Users need to load the EVA environment when working on Hera; use the following commands:
@@ -23,91 +28,61 @@
 module load EVA/hera
 ```
 
 ## Usage
-`user_Analysis.py` contains the `SpatialTemporalStats` class, which encapsulates the functionalities of the tool. Here's how to use it:
-
-1. Import the `SpatialTemporalStats` class:
-
-   ```python
-   from SpatialTemporalStats import SpatialTemporalStats
-2. Create an instance of the SpatialTemporalStats class:
-
-   ```python
-   my_tool = SpatialTemporalStats()
-
-3. Specify the parameters based on the type of plots that you want:
-
-   - `input_path`: Directory for input .nc files
-   - `output_path`: Path to output plots
-   - `sensor`: Sensor name
-   - `channel_no`: Channel number (e.g., 1, 2, 3, 5)
-   - `var_name`: variable name
-   - `start_date, end_date`: Start and End date of the input files for evaluations
-   - `region`: Insert a number to select Global or Regional ouput plots (1: global (default), 2: polar region, 3: mid-latitudes region, 4: tropics region, 5: southern mid-latitudes region, 6: southern polar region)
-   - `resolution`: Resolution for grid generation (1: 1X1 degree(default), 2:2X2 degree, 3:3X3 degree)
-   - `filter_by_vars`: Filter by variable to generate plots based on surface type (land, water, snow, seaice) or can be an empty list for no filtering.
-
-4. Call `read_obs_values` to Read observational values and perform analysis:
-
-```python
-o_minus_f_gdf = my_tool.read_obs_values(
-    input_path,
-    sensor,
-    var_name,
-    channel_no,
-    start_date,
-    end_date,
-    filter_by_vars,
-    QC_filter)
-```
-5. Call `plot_obs` to plot evaluation plots based on your setting for grid size, channel, region, surface type, and filtering values:
-
+To get started, run the following command to see all available options and argument formats:
 ```
-my_tool.plot_obs(o_minus_f_gdf, var_name, region, resolution, output_path)
+python SpatialTemporalStats.py -h
 ```
-6. Call `make_summary_plots` to generate summary plots:
+This command displays detailed information on how to enter your settings. Key parameters include:
+
+- input: Path to the input data files
+- output: Path for saving the results
+- sensor: Satellite sensor name (e.g., "atms_n20")
+- var: Variable to analyze (e.g., "Obs_Minus_Forecast_adjusted")
+- ch: Channel number for the analysis (e.g., 1)
+- grid: Grid resolution for the spatial analysis (choices: 0.5, 1, 2; default: 1)
+- region: Region code for the map plot:
+  - 1: Global
+  - 2: Polar region (+60° latitude and above)
+  - 3: Northern mid-latitudes (20° to 60° latitude)
+  - 4: Tropics (-20° to 20° latitude)
+  - 5: Southern mid-latitudes (-60° to -20° latitude)
+  - 6: Southern polar region (below -60° latitude)
+- sdate / edate: Start and end dates for the evaluation period (e.g., "2023-01-27" to "2023-01-28")
+- no_qc_flag: Optional; disables the default filtering that keeps only observations with QC_Flag equal to 0
+- filter_by_vars: Optional filtering criteria in the format var_name,comparison,threshold (e.g., Land_Fraction,lt,0.9)
+
+These parameters allow you to customize the spatial and temporal analysis to suit specific data and regions.
 
-```python
-summary_results = my_tool.make_summary_plots(
-    input_path, sensor, var_name, start_date, end_date, QC_filter, output_path
-)
-```
 ## Notes
 Ensure that the `-input` and `-output` arguments point to the observational diag files and the output directory, respectively. Adjust method parameters and plotting settings as needed for your specific use case. Use the `-filter_by_vars` argument as needed to filter observational data based on variable values.
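+
+### Programmatic Use
+For scripted workflows, the `SpatialTemporalStats` class can also be driven directly from Python, using the same methods the CLI calls. The following is a minimal sketch; the input path, output path, dates, and filter values are placeholders to replace with your own:
+
+```python
+from SpatialTemporalStats import SpatialTemporalStats
+
+my_tool = SpatialTemporalStats()
+my_tool.generate_grid(1)  # 1x1-degree grid
+
+# Read, filter, and spatially average the observations.
+o_minus_f_gdf = my_tool.read_obs_values(
+    "/PATH/TO/INPUT/DIAG/FILES",     # obs_files_path (placeholder)
+    "atms_n20",                      # sensor
+    "Obs_Minus_Forecast_adjusted",   # var_name
+    1,                               # channel_no
+    "2023-01-27",                    # start_date
+    "2023-01-28",                    # end_date
+    [("Land_Fraction", "lt", 0.9)],  # filter_by_vars ([] for no filtering)
+    True,                            # QC_filter (keep only QC_Flag == 0)
+)
+
+# Map plots (region 1 = global, resolution matching the grid) and
+# per-channel summary plots.
+my_tool.plot_obs(o_minus_f_gdf, "Obs_Minus_Forecast_adjusted", 1, 1, "./Results")
+my_tool.make_summary_plots(
+    "/PATH/TO/INPUT/DIAG/FILES", "atms_n20", "Obs_Minus_Forecast_adjusted",
+    "2023-01-27", "2023-01-28", True, "./Results",
+)
+```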
-To run the tool:
-
-```
-python user_Analysis.py
+## Example Usage
+```
+python SpatialTemporalStats.py -input /PATH/TO/INPUT/DIAG/FILES -output ./Results -sensor "atms_n20" -var "Obs_Minus_Forecast_adjusted" -ch 1 -grid 2 -region 1 -sdate "2023-01-27" -edate "2023-01-28"
 ```
-## Example Usage
-Here's a sample script demonstrating how to use the`SpatialTemporalStats` tool:
-![image](https://github.com/NOAA-EMC/PyGSI/assets/51101867/4379cb6e-e1a7-4167-8859-ae881f2c61c1)
-
 ## Example output plots using different settings
 ```
-var_name = "Obs_Minus_Forecast_adjusted"
-region = 1
-resolution = 2
-filter_by_vars=[]
+-sensor "atms_n20" -var "Obs_Minus_Forecast_adjusted" -ch 1 -grid 2 -region 1 -sdate "2023-01-27" -edate "2023-01-28"
 ```
-Calling `read_obs_values` and then `my_tool.plot_obs()` method will produce three plots for ave,count, rms as shown below:
-![atms_n20_ch1_Obs_Minus_Forecast_adjusted_Average_region_1](https://github.com/NOAA-EMC/PyGSI/assets/51101867/b838ae92-3303-45ca-b7ba-35b11c01213c)
-![atms_n20_ch1_Obs_Minus_Forecast_adjusted_Count_region_1](https://github.com/NOAA-EMC/PyGSI/assets/51101867/113ef427-9771-462a-b543-f36166ed978e)
-![atms_n20_ch1_Obs_Minus_Forecast_adjusted_RMS_region_1](https://github.com/NOAA-EMC/PyGSI/assets/51101867/ed4bc44c-6364-451b-811e-b2c8a0ce5d2a)
+These settings produce three plots (average, count, and RMS), as shown below:
+![atms_n20_ch1_Obs_Minus_Forecast_adjusted_Average_region_1](https://github.com/user-attachments/assets/e0ddcf64-8ce1-4175-b646-71d1d38ec3d4)
+![atms_n20_ch1_Obs_Minus_Forecast_adjusted_Count_region_1](https://github.com/user-attachments/assets/a33dd6c4-bfb0-4ae9-a46d-02086f7dc960)
+![atms_n20_ch1_Obs_Minus_Forecast_adjusted_RMS_region_1](https://github.com/user-attachments/assets/f9b34e74-7511-464d-a27d-82f08cfa5c6b)
 
 Example plot for filtering out the locations where the land fraction is less than 0.9:
 ```
-filter_by_vars = [("Land_Fraction", "lt", 0.9),]
+-filter_by_vars Land_Fraction,lt,0.9
 ```
-![atms_n20_ch1_Obs_Minus_Forecast_adjusted_Average_region_1](https://github.com/NOAA-EMC/PyGSI/assets/51101867/978e2677-4a7b-45b3-a2e2-67674bf0803e)
+![atms_n20_ch1_Obs_Minus_Forecast_adjusted_Average_region_1](https://github.com/user-attachments/assets/bc6b7215-9d26-41c8-b51d-0f51d42238c3)
+
+Example of the two summary plots generated by `make_summary_plots`:
+![atms_n20_Obs_Minus_Forecast_adjusted_mean_std](https://github.com/user-attachments/assets/99b09315-1faa-4fd1-9c26-e7b591dba2fc)
+![atms_n20_Obs_Minus_Forecast_adjusted_sumamryCounts](https://github.com/user-attachments/assets/449cd174-f50d-4521-ab9f-e0d4b6f5ad9b)
 
-Calling read_obs_values and then my_tool.make_summary_plots() method will generate two summary plots:
-![atms_n20_Obs_Minus_Forecast_adjusted_mean_std](https://github.com/NOAA-EMC/PyGSI/assets/51101867/28cc26f4-c024-4713-82e1-b9a7ed5f5d1b)
-![atms_n20_Obs_Minus_Forecast_adjusted_sumamryCounts](https://github.com/NOAA-EMC/PyGSI/assets/51101867/fd835f41-5b9c-4a14-be85-4c74d49571f6)
diff --git a/ush/SpatialTemporalStatsTool/Results/atms_n20_ch1_Obs_Minus_Forecast_adjusted_Average_region_1.gpkg b/ush/SpatialTemporalStatsTool/Results/atms_n20_ch1_Obs_Minus_Forecast_adjusted_Average_region_1.gpkg
new file mode 100644
index 0000000..a8f4ee6
Binary files /dev/null and b/ush/SpatialTemporalStatsTool/Results/atms_n20_ch1_Obs_Minus_Forecast_adjusted_Average_region_1.gpkg differ
diff --git a/ush/SpatialTemporalStatsTool/Results/atms_n20_ch1_Obs_Minus_Forecast_adjusted_Count_region_1.gpkg b/ush/SpatialTemporalStatsTool/Results/atms_n20_ch1_Obs_Minus_Forecast_adjusted_Count_region_1.gpkg
new file mode 100644
index 
0000000..ea5bff3 Binary files /dev/null and b/ush/SpatialTemporalStatsTool/Results/atms_n20_ch1_Obs_Minus_Forecast_adjusted_Count_region_1.gpkg differ diff --git a/ush/SpatialTemporalStatsTool/Results/atms_n20_ch1_Obs_Minus_Forecast_adjusted_RMS_region_1.gpkg b/ush/SpatialTemporalStatsTool/Results/atms_n20_ch1_Obs_Minus_Forecast_adjusted_RMS_region_1.gpkg new file mode 100644 index 0000000..69ffb32 Binary files /dev/null and b/ush/SpatialTemporalStatsTool/Results/atms_n20_ch1_Obs_Minus_Forecast_adjusted_RMS_region_1.gpkg differ diff --git a/ush/SpatialTemporalStatsTool/Results/atms_n20_summary.csv b/ush/SpatialTemporalStatsTool/Results/atms_n20_summary.csv new file mode 100644 index 0000000..98cd758 --- /dev/null +++ b/ush/SpatialTemporalStatsTool/Results/atms_n20_summary.csv @@ -0,0 +1,23 @@ +channel,count,std,mean,rms +1,45262,2.6988273643037384,0.19821635086344594,2.7060966102607162 +2,44462,2.885330599655865,0.44061785540288817,2.918780012918158 +3,45276,2.3529656988789394,0.13292025534906543,2.3567170755911815 +4,45672,1.4469203953076863,0.0448353262127089,1.4476148786310565 +5,45700,0.4795219456469957,0.017871514061625335,0.4798548607359764 +6,45744,0.15752856235399282,0.0006352198961456739,0.15752984308261478 +7,52929,0.11603420279762479,0.0021500649409356577,0.11605412098728322 +8,84112,0.11362915543792347,0.007189899951947341,0.1138563991475909 +9,84122,0.1242098780941416,0.005015079193266936,0.12431108090382417 +10,84119,0.1678040669750801,0.0012762595843600758,0.1678089202989674 +11,84063,0.20609831023729452,0.0049344418110122325,0.20615737240917775 +12,83916,0.25259764753168185,0.0011790454284033073,0.25260039922111327 +13,82352,0.358631419062375,-0.01701693121436197,0.35903491569296797 +14,83415,0.6214590289702845,-0.08582188340839891,0.6273569321849265 +15,83772,1.0779297883385832,-0.9148381735994018,1.4138109889452872 +16,40843,3.4015874021960975,0.12291585633857981,3.4038074508583764 +17,42105,2.0949495712374997,-0.0925232459163654,2.0969917160216016 +18,42368,1.3821771839342785,-0.010803419354404722,1.38221940431261 +19,45173,1.368377491943836,0.0239483131097599,1.3685870385764136 +20,45107,1.4126293284373275,0.09488016782235158,1.4158120870395723 +21,44935,1.4810928233580214,0.134929813972193,1.4872262793876698 +22,44728,1.5653715176881486,0.1405545452597037,1.5716690391372394 diff --git a/ush/SpatialTemporalStatsTool/SpatialTemporalStats.py b/ush/SpatialTemporalStatsTool/SpatialTemporalStats.py index 2ba6b46..53280cd 100644 --- a/ush/SpatialTemporalStatsTool/SpatialTemporalStats.py +++ b/ush/SpatialTemporalStatsTool/SpatialTemporalStats.py @@ -1,444 +1,691 @@ -import os -from datetime import datetime - -import geopandas as gpd -import matplotlib.pyplot as plt -import numpy as np -import pandas as pd -import xarray -from shapely.geometry import Point, Polygon - - -class SpatialTemporalStats: - def __init__(self): - self.grid_gdf = None - self.obs_gdf = None - - def generate_grid(self, resolution=1): - self.resolution = resolution - # Generate the latitude and longitude values using meshgrid - grid_lons, grid_lats = np.meshgrid( - np.arange(-180, 181, resolution), np.arange(-90, 91, resolution) - ) - - # Flatten the arrays to get coordinates - grid_coords = np.vstack([grid_lons.flatten(), grid_lats.flatten()]).T - - # Create a GeoDataFrame from the coordinates - self.grid_gdf = gpd.GeoDataFrame( - geometry=[ - Polygon( - [ - (lon, lat), - (lon + resolution, lat), - (lon + resolution, lat + resolution), - (lon, lat + resolution), - ] - ) - for lon, lat in grid_coords - ], - 
crs="EPSG:4326", - ) # CRS for WGS84 - self.grid_gdf["grid_id"] = np.arange(1, len(self.grid_gdf) + 1) - - def _extract_date_times(self, filenames): - date_times = [] - for filename in filenames: - # Split the filename by '.' to get the parts - parts = filename.split(".") - - # Extract the last part which contains the date/time information - date_time_part = parts[-2] - - # date/time format in filename is 'YYYYMMDDHH', can parse it accordingly - year = int(date_time_part[:4]) - month = int(date_time_part[4:6]) - day = int(date_time_part[6:8]) - hour = int(date_time_part[8:10]) - - # Construct the datetime object - date_time = datetime(year, month, day, hour) - - date_times.append(date_time) - - return date_times - - def read_obs_values( - self, - obs_files_path, - sensor, - var_name, - channel_no, - start_date, - end_date, - filter_by_vars, - QC_filter, - ): - self.sensor = sensor - self.channel_no = channel_no - # read all obs files - all_files = os.listdir(obs_files_path) - obs_files = [ - os.path.join(obs_files_path, file) - for file in all_files - if file.endswith(".nc4") and "diag_%s_ges" % sensor in file - ] - - # get date time from file names - files_date_times_df = pd.DataFrame() - - files_date_times = self._extract_date_times(obs_files) - files_date_times_df["file_name"] = obs_files - files_date_times_df["date_time"] = files_date_times - files_date_times_df["date"] = pd.to_datetime( - files_date_times_df["date_time"].dt.date - ) - - # read start date - start_date = datetime.strptime(start_date, "%Y-%m-%d") - end_date = datetime.strptime(end_date, "%Y-%m-%d") - - studied_cycle_files = files_date_times_df[ - ( - (files_date_times_df["date"] >= start_date) - & ((files_date_times_df["date"] <= end_date)) - ) - ]["file_name"] - - studied_gdf_list = [] - for this_cycle_obs_file in studied_cycle_files: - ds = xarray.open_dataset(this_cycle_obs_file) - - Combined_bool = ds["Channel_Index"].data == channel_no - - if QC_filter: - QC_bool = ds["QC_Flag"].data == 0 - Combined_bool = Combined_bool * QC_bool - - # apply filters by variable - for this_filter in filter_by_vars: - filter_var_name, filter_operation, filter_value = this_filter - if filter_operation == "lt": - this_filter_bool = ds[filter_var_name].data <= filter_value - else: - this_filter_bool = ds[filter_var_name].data >= filter_value - Combined_bool = ( - Combined_bool * ~this_filter_bool - ) # here we have to negate the above bool to make it right - - this_cycle_var_values = ds[var_name].data[Combined_bool] - this_cycle_lat_values = ds["Latitude"].data[Combined_bool] - this_cycle_long_values = ds["Longitude"].data[Combined_bool] - this_cycle_long_values = np.where( - this_cycle_long_values <= 180, - this_cycle_long_values, - this_cycle_long_values - 360, - ) - - geometry = [ - Point(xy) for xy in zip(this_cycle_long_values, this_cycle_lat_values) - ] - - # Create a GeoDataFrame - this_cycle_gdf = gpd.GeoDataFrame(geometry=geometry, crs="EPSG:4326") - this_cycle_gdf["value"] = this_cycle_var_values - - studied_gdf_list.append(this_cycle_gdf) - - studied_gdf = pd.concat(studied_gdf_list) - - # Perform spatial join - joined_gdf = gpd.sjoin(studied_gdf, self.grid_gdf, op="within", how="right") - - # Calculate average values of points in each polygon - self.obs_gdf = self.grid_gdf.copy() - self.obs_gdf[var_name + "_Average"] = joined_gdf.groupby("grid_id")[ - "value" - ].mean() - self.obs_gdf[var_name + "_RMS"] = joined_gdf.groupby("grid_id")["value"].apply( - lambda x: np.sqrt((x**2).mean()) - ) - self.obs_gdf[var_name + "_Count"] = 
joined_gdf.groupby("grid_id")[ - "value" - ].count() - - # convert count of zero to null. This will help also for plotting - self.obs_gdf[var_name + "_Count"] = np.where( - self.obs_gdf[var_name + "_Count"].values == 0, - np.nan, - self.obs_gdf[var_name + "_Count"].values, - ) - - return self.obs_gdf - - def plot_obs(self, selected_var_gdf, var_name, region, resolution, output_path): - self.resolution = resolution - var_names = [var_name + "_Average", var_name + "_Count", var_name + "_RMS"] - - for _, item in enumerate(var_names): - plt.figure(figsize=(12, 8)) - ax = plt.subplot(1, 1, 1) - - if region == 1: - # Plotting global region (no need for filtering) - title = "Global Region" - filtered_gdf = selected_var_gdf - - elif region == 2: - # Plotting polar region (+60 latitude and above) - title = "Polar Region (+60 latitude and above)" - filtered_gdf = selected_var_gdf[ - selected_var_gdf.geometry.apply( - lambda geom: self.is_polygon_in_polar_region(geom, 60) - ) - ] - - elif region == 3: - # Plotting northern mid-latitudes region (20 to 60 latitude) - title = "Northern Mid-latitudes Region (20 to 60 latitude)" - filtered_gdf = selected_var_gdf[ - selected_var_gdf.geometry.apply( - lambda geom: self.is_polygon_in_latitude_range(geom, 20, 60) - ) - ] - - elif region == 4: - # Plotting tropics region (-20 to 20 latitude) - title = "Tropics Region (-20 to 20 latitude)" - filtered_gdf = selected_var_gdf[ - selected_var_gdf.geometry.apply( - lambda geom: self.is_polygon_in_latitude_range(geom, -20, 20) - ) - ] - - elif region == 5: - # Plotting southern mid-latitudes region (-60 to -20 latitude) - title = "Southern Mid-latitudes Region (-60 to -20 latitude)" - filtered_gdf = selected_var_gdf[ - selected_var_gdf.geometry.apply( - lambda geom: self.is_polygon_in_latitude_range(geom, -60, -20) - ) - ] - - elif region == 6: - # Plotting southern polar region (less than -60 latitude) - title = "Southern Polar Region (less than -60 latitude)" - filtered_gdf = selected_var_gdf[ - selected_var_gdf.geometry.apply(lambda geom: geom.centroid.y < -60) - ] - - min_val, max_val, std_val, avg_val = ( - filtered_gdf[item].min(), - filtered_gdf[item].max(), - filtered_gdf[item].std(), - filtered_gdf[item].mean(), - ) - - if item == "Obs_Minus_Forecast_adjusted_Average": - max_val_cbar = 5.0 * std_val - min_val_cbar = -5.0 * std_val - cmap = "bwr" - else: - max_val_cbar = max_val - min_val_cbar = min_val - cmap = "jet" - - cbar_label = ( - "grid=%dx%d, min=%.3lf, max=%.3lf, bias=%.3lf, std=%.3lf\n" - % ( - resolution, - resolution, - min_val, - max_val, - avg_val, - std_val, - ) - ) - - filtered_gdf.plot( - ax=ax, - cmap=cmap, - vmin=min_val_cbar, - vmax=max_val_cbar, - column=item, - legend=True, - missing_kwds={"color": "lightgrey"}, - legend_kwds={ - "orientation": "horizontal", - "shrink": 0.5, - "label": cbar_label, - }, - ) - - plt.title("%s\n%s ch:%d %s" % (title, self.sensor, self.channel_no, item)) - plt.savefig( - os.path.join( - output_path, - "%s_ch%d_%s_region_%d.png" - % (self.sensor, self.channel_no, item, region), - ) - ) - plt.close() - - def is_polygon_in_polar_region(self, polygon, latitude_threshold): - """ - Check if a polygon is in the polar region based on a latitude threshold. 
- """ - # Get the centroid of the polygon - centroid = polygon.centroid - - # Extract the latitude of the centroid - centroid_latitude = centroid.y - - # Check if the latitude is above the threshold - return centroid_latitude >= latitude_threshold - - def is_polygon_in_latitude_range(self, polygon, min_latitude, max_latitude): - """ - Check if a polygon is in the specified latitude range. - """ - # Get the centroid of the polygon - centroid = polygon.centroid - - # Extract the latitude of the centroid - centroid_latitude = centroid.y - - # Check if the latitude is within the specified range - return min_latitude <= centroid_latitude <= max_latitude - - def list_variable_names(self, file_path): - ds = xarray.open_dataset(file_path) - print(ds.info()) - - def make_summary_plots( - self, - obs_files_path, - sensor, - var_name, - start_date, - end_date, - QC_filter, - output_path, - ): - self.sensor = sensor - # read all obs files - all_files = os.listdir(obs_files_path) - obs_files = [ - os.path.join(obs_files_path, file) - for file in all_files - if file.endswith(".nc4") and "diag_%s_ges" % sensor in file - ] - - # get date time from file names. - # alternatively could get from attribute but that needs reading the entire nc4 - files_date_times_df = pd.DataFrame() - - files_date_times = self._extract_date_times(obs_files) - files_date_times_df["file_name"] = obs_files - files_date_times_df["date_time"] = files_date_times - files_date_times_df["date"] = pd.to_datetime( - files_date_times_df["date_time"].dt.date - ) - - # read start date - start_date = datetime.strptime(start_date, "%Y-%m-%d") - end_date = datetime.strptime(end_date, "%Y-%m-%d") - - studied_cycle_files = files_date_times_df[ - ( - (files_date_times_df["date"] >= start_date) - & ((files_date_times_df["date"] <= end_date)) - ) - ]["file_name"] - index = studied_cycle_files.index - - Summary_results = [] - - # get unique channels from one of the files - ds = xarray.open_dataset(studied_cycle_files[index[0]]) - unique_channels = np.unique(ds["Channel_Index"].data).tolist() - print("Total Number of Channels ", len(unique_channels)) - Allchannels_data = {} - for this_channel in unique_channels: - Allchannels_data[this_channel] = np.empty(shape=(0,)) - for this_cycle_obs_file in studied_cycle_files: - ds = xarray.open_dataset(this_cycle_obs_file) - if QC_filter: - QC_bool = ds["QC_Flag"].data == 0 - for this_channel in unique_channels: - channel_bool = ds["Channel_Index"].data == this_channel - - this_cycle_channel_var_values = ds[var_name].data[ - channel_bool * QC_bool - ] - Allchannels_data[this_channel] = np.append( - Allchannels_data[this_channel], this_cycle_channel_var_values - ) - - for this_channel in unique_channels: - this_channel_values = Allchannels_data[this_channel] - squared_values = [x**2 for x in this_channel_values] - mean_of_squares = sum(squared_values) / len(squared_values) - rms_value = mean_of_squares ** 0.5 - Summary_results.append( - [ - this_channel, - np.size(this_channel_values), - np.std(this_channel_values), - np.mean(this_channel_values), - rms_value, - ] - ) - - Summary_resultsDF = pd.DataFrame( - Summary_results, columns=["channel", "count", "std", "mean", "rms"]) - # Plotting - plt.figure(figsize=(10, 6)) - plt.scatter(Summary_resultsDF["channel"], Summary_resultsDF["count"], s=50) - plt.xlabel("Channel") - plt.ylabel("Count") - plt.title("%s %s" % ((self.sensor, var_name))) - plt.grid(True) - plt.tight_layout() - plt.savefig( - os.path.join( - output_path, "%s_%s_sumamryCounts.png" % (self.sensor, 
var_name) - ) - ) - plt.close() - - # Plotting scatter plot for mean and std - plt.figure(figsize=(15, 6)) - plt.scatter( - Summary_resultsDF["channel"], - Summary_resultsDF["mean"], - s=50, - c="green", - label="Mean", - ) - plt.scatter( - Summary_resultsDF["channel"], - Summary_resultsDF["std"], - s=50, - c="red", - label="Std", - ) - plt.scatter( - Summary_resultsDF["channel"], - Summary_resultsDF["rms"], - s=50, - label="Rms", - facecolors="none", - edgecolors="blue", - ) - plt.xlabel("Channel") - plt.ylabel("Statistics") - plt.title("%s %s" % ((self.sensor, var_name))) - plt.grid(True) - plt.tight_layout() - plt.legend() - plt.savefig( - os.path.join(output_path, "%s_%s_mean_std.png" % (self.sensor, var_name)) - ) - - return Summary_resultsDF +import argparse +import os +from datetime import datetime + +import cartopy.crs as ccrs +import cartopy.feature as cfeature +import geopandas as gpd +import matplotlib.pyplot as plt +import numpy as np +import pandas as pd +import xarray +from shapely.geometry import Point, Polygon + + +class SpatialTemporalStats: + def __init__(self): + self.grid_gdf = None + self.obs_gdf = None + + def generate_grid(self, resolution=1): + self.resolution = resolution + # Generate the latitude and longitude values using meshgrid + grid_lons, grid_lats = np.meshgrid( + np.arange(-180, 181, resolution), np.arange(-90, 91, resolution) + ) + + # Flatten the arrays to get coordinates + grid_coords = np.vstack([grid_lons.flatten(), grid_lats.flatten()]).T + + # Create a GeoDataFrame from the coordinates + self.grid_gdf = gpd.GeoDataFrame( + geometry=[ + Polygon( + [ + (lon, lat), + (lon + resolution, lat), + (lon + resolution, lat + resolution), + (lon, lat + resolution), + ] + ) + for lon, lat in grid_coords + ], + crs="EPSG:4326", + ) # CRS for WGS84 + self.grid_gdf["grid_id"] = np.arange(1, len(self.grid_gdf) + 1) + + def _extract_date_times(self, filenames): + date_times = [] + for filename in filenames: + # Split the filename by '.' 
to get the parts + parts = filename.split(".") + + # Extract the last part which contains the date/time information + date_time_part = parts[-2] + + # date/time format in filename is 'YYYYMMDDHH', can parse it accordingly + year = int(date_time_part[:4]) + month = int(date_time_part[4:6]) + day = int(date_time_part[6:8]) + hour = int(date_time_part[8:10]) + + # Construct the datetime object + date_time = datetime(year, month, day, hour) + + date_times.append(date_time) + + return date_times + + def read_obs_values( + self, + obs_files_path, + sensor, + var_name, + channel_no, + start_date, + end_date, + filter_by_vars, + QC_filter, + ): + self.sensor = sensor + self.channel_no = channel_no + # read all obs files + all_files = os.listdir(obs_files_path) + obs_files = [ + os.path.join(obs_files_path, file) + for file in all_files + if file.endswith(".nc4") and "diag_%s_ges" % sensor in file + ] + + # get date time from file names + files_date_times_df = pd.DataFrame() + + files_date_times = self._extract_date_times(obs_files) + files_date_times_df["file_name"] = obs_files + files_date_times_df["date_time"] = files_date_times + files_date_times_df["date"] = pd.to_datetime( + files_date_times_df["date_time"].dt.date + ) + + # read start date + start_date = datetime.strptime(start_date, "%Y-%m-%d") + end_date = datetime.strptime(end_date, "%Y-%m-%d") + + studied_cycle_files = files_date_times_df[ + ( + (files_date_times_df["date"] >= start_date) + & ((files_date_times_df["date"] <= end_date)) + ) + ]["file_name"] + + studied_gdf_list = [] + for this_cycle_obs_file in studied_cycle_files: + ds = xarray.open_dataset(this_cycle_obs_file) + + Combined_bool = ds["Channel_Index"].data == channel_no + + if QC_filter: + QC_bool = ds["QC_Flag"].data == 0 + Combined_bool = Combined_bool * QC_bool + + # apply filters by variable + for this_filter in filter_by_vars: + filter_var_name, filter_operation, filter_value = this_filter + if filter_operation == "lt": + this_filter_bool = ds[filter_var_name].data <= filter_value + else: + this_filter_bool = ds[filter_var_name].data >= filter_value + Combined_bool = ( + Combined_bool * ~this_filter_bool + ) # here we have to negate the above bool to make it right + + this_cycle_var_values = ds[var_name].data[Combined_bool] + this_cycle_lat_values = ds["Latitude"].data[Combined_bool] + this_cycle_long_values = ds["Longitude"].data[Combined_bool] + this_cycle_long_values = np.where( + this_cycle_long_values <= 180, + this_cycle_long_values, + this_cycle_long_values - 360, + ) + + geometry = [ + Point(xy) for xy in zip(this_cycle_long_values, this_cycle_lat_values) + ] + + # Create a GeoDataFrame + this_cycle_gdf = gpd.GeoDataFrame(geometry=geometry, crs="EPSG:4326") + this_cycle_gdf["value"] = this_cycle_var_values + + studied_gdf_list.append(this_cycle_gdf) + + studied_gdf = pd.concat(studied_gdf_list) + + # Perform spatial join + joined_gdf = gpd.sjoin(studied_gdf, self.grid_gdf, op="within", how="right") + + # Calculate average values of points in each polygon + self.obs_gdf = self.grid_gdf.copy() + self.obs_gdf[var_name + "_Average"] = joined_gdf.groupby("grid_id")[ + "value" + ].mean() + self.obs_gdf[var_name + "_RMS"] = joined_gdf.groupby("grid_id")["value"].apply( + lambda x: np.sqrt((x**2).mean()) + ) + self.obs_gdf[var_name + "_Count"] = joined_gdf.groupby("grid_id")[ + "value" + ].count() + + # convert count of zero to null. 
This will help also for plotting + self.obs_gdf[var_name + "_Count"] = np.where( + self.obs_gdf[var_name + "_Count"].values == 0, + np.nan, + self.obs_gdf[var_name + "_Count"].values, + ) + + return self.obs_gdf + + def plot_obs(self, selected_var_gdf, var_name, region, resolution, output_path): + self.resolution = resolution + var_names = [var_name + "_Average", var_name + "_Count", var_name + "_RMS"] + + for _, item in enumerate(var_names): + plt.figure(figsize=(12, 8)) + ax = plt.subplot(1, 1, 1, projection=ccrs.PlateCarree()) + # Add global map coastlines + ax.add_feature(cfeature.GSHHSFeature(scale="auto")) + filtered_gdf = selected_var_gdf.copy() + + if region == 1: + # Plotting global region (no need for filtering) + title = "Global" + # filtered_gdf = selected_var_gdf + + elif region == 2: + # Plotting polar region (+60 latitude and above) + title = "Polar Region (+60 latitude and above)" + filtered_gdf[item] = np.where( + filtered_gdf.geometry.apply( + lambda geom: self.is_polygon_in_polar_region(geom, 60) + ), + filtered_gdf[item], + np.nan, + ) + + elif region == 3: + # Plotting northern mid-latitudes region (20 to 60 latitude) + title = "Northern Mid-latitudes Region (20 to 60 latitude)" + filtered_gdf[item] = np.where( + filtered_gdf.geometry.apply( + lambda geom: self.is_polygon_in_latitude_range(geom, 20, 60) + ), + filtered_gdf[item], + np.nan, + ) + + elif region == 4: + # Plotting tropics region (-20 to 20 latitude) + title = "Tropics Region (-20 to 20 latitude)" + + filtered_gdf[item] = np.where( + filtered_gdf.geometry.apply( + lambda geom: self.is_polygon_in_latitude_range(geom, -20, 20) + ), + filtered_gdf[item], + np.nan, + ) + + elif region == 5: + # Plotting southern mid-latitudes region (-60 to -20 latitude) + title = "Southern Mid-latitudes Region (-60 to -20 latitude)" + filtered_gdf[item] = np.where( + filtered_gdf.geometry.apply( + lambda geom: self.is_polygon_in_latitude_range(geom, -60, -20) + ), + filtered_gdf[item], + np.nan, + ) + + elif region == 6: + # Plotting southern polar region (less than -60 latitude) + title = "Southern Polar Region (less than -60 latitude)" + filtered_gdf[item] = np.where( + filtered_gdf.geometry.apply(lambda geom: geom.centroid.y < -60), + filtered_gdf[item], + np.nan, + ) + # filtered_gdf = selected_var_gdf[ + # selected_var_gdf.geometry.apply(lambda geom: geom.centroid.y < -60) + # ] + + min_val, max_val, std_val, avg_val = ( + filtered_gdf[item].min(), + filtered_gdf[item].max(), + filtered_gdf[item].std(), + filtered_gdf[item].mean(), + ) + + if item == "Obs_Minus_Forecast_adjusted_Average": + max_val_cbar = 5.0 * std_val + min_val_cbar = -5.0 * std_val + cmap = "bwr" + else: + max_val_cbar = max_val + min_val_cbar = min_val + cmap = "jet" + + if item == "Obs_Minus_Forecast_adjusted_Count": + cbar_label = "grid=%dx%d, min=%.3lf, max=%.3lf\n" % ( + resolution, + resolution, + min_val, + max_val, + ) + else: + cbar_label = ( + "grid=%dx%d, min=%.3lf, max=%.3lf, bias=%.3lf, std=%.3lf\n" + % ( + resolution, + resolution, + min_val, + max_val, + avg_val, + std_val, + ) + ) + + filtered_gdf.plot( + ax=ax, + cmap=cmap, + vmin=min_val_cbar, + vmax=max_val_cbar, + column=item, + legend=True, + missing_kwds={"color": "lightgrey"}, + legend_kwds={ + "orientation": "horizontal", + "shrink": 0.5, + "label": cbar_label, + }, + ) + + filtered_gdf.to_file( + os.path.join( + output_path, + "%s_ch%d_%s_region_%d.gpkg" + % (self.sensor, self.channel_no, item, region), + ) + ) + + plt.title("%s\n%s ch:%d %s" % (title, self.sensor, 
self.channel_no, item)) + plt.savefig( + os.path.join( + output_path, + "%s_ch%d_%s_region_%d.png" + % (self.sensor, self.channel_no, item, region), + ) + ) + plt.close() + + def is_polygon_in_polar_region(self, polygon, latitude_threshold): + """ + Check if a polygon is in the polar region based on a latitude threshold. + """ + # Get the centroid of the polygon + centroid = polygon.centroid + + # Extract the latitude of the centroid + centroid_latitude = centroid.y + + # Check if the latitude is above the threshold + return centroid_latitude >= latitude_threshold + + def is_polygon_in_latitude_range(self, polygon, min_latitude, max_latitude): + """ + Check if a polygon is in the specified latitude range. + """ + # Get the centroid of the polygon + centroid = polygon.centroid + + # Extract the latitude of the centroid + centroid_latitude = centroid.y + + # Check if the latitude is within the specified range + return min_latitude <= centroid_latitude <= max_latitude + + def list_variable_names(self, file_path): + ds = xarray.open_dataset(file_path) + print(ds.info()) + + def make_summary_plots( + self, + obs_files_path, + sensor, + var_name, + start_date, + end_date, + QC_filter, + output_path, + ): + self.sensor = sensor + # read all obs files + all_files = os.listdir(obs_files_path) + obs_files = [ + os.path.join(obs_files_path, file) + for file in all_files + if file.endswith(".nc4") and "diag_%s_ges" % sensor in file + ] + + # get date time from file names. + # alternatively could get from attribute but that needs reading the entire nc4 + files_date_times_df = pd.DataFrame() + + files_date_times = self._extract_date_times(obs_files) + files_date_times_df["file_name"] = obs_files + files_date_times_df["date_time"] = files_date_times + files_date_times_df["date"] = pd.to_datetime( + files_date_times_df["date_time"].dt.date + ) + + # read start date + start_date = datetime.strptime(start_date, "%Y-%m-%d") + end_date = datetime.strptime(end_date, "%Y-%m-%d") + + studied_cycle_files = files_date_times_df[ + ( + (files_date_times_df["date"] >= start_date) + & ((files_date_times_df["date"] <= end_date)) + ) + ]["file_name"] + index = studied_cycle_files.index + + Summary_results = [] + + # get unique channels from one of the files + ds = xarray.open_dataset(studied_cycle_files[index[0]]) + unique_channels = np.unique(ds["Channel_Index"].data).tolist() + print("Total Number of Channels ", len(unique_channels)) + Allchannels_data = {} + for this_channel in unique_channels: + Allchannels_data[this_channel] = np.empty(shape=(0,)) + for this_cycle_obs_file in studied_cycle_files: + ds = xarray.open_dataset(this_cycle_obs_file) + if QC_filter: + QC_bool = ds["QC_Flag"].data == 0.0 + else: + QC_bool = np.ones( + ds["QC_Flag"].data.shape, dtype=bool + ) # this selects all obs as True + for this_channel in unique_channels: + channel_bool = ds["Channel_Index"].data == this_channel + + this_cycle_channel_var_values = ds[var_name].data[ + channel_bool * QC_bool + ] + Allchannels_data[this_channel] = np.append( + Allchannels_data[this_channel], this_cycle_channel_var_values + ) + + for this_channel in unique_channels: + this_channel_values = Allchannels_data[this_channel] + squared_values = [x**2 for x in this_channel_values] + mean_of_squares = sum(squared_values) / len(squared_values) + rms_value = mean_of_squares**0.5 + Summary_results.append( + [ + this_channel, + np.size(this_channel_values), + np.std(this_channel_values), + np.mean(this_channel_values), + rms_value, + ] + ) + + Summary_resultsDF = 
pd.DataFrame( + Summary_results, columns=["channel", "count", "std", "mean", "rms"] + ) + # Plotting + plt.figure(figsize=(10, 6)) + plt.scatter(Summary_resultsDF["channel"], Summary_resultsDF["count"], s=50) + plt.xlabel("Channel") + plt.ylabel("Count") + plt.title("%s %s" % ((self.sensor, var_name))) + plt.grid(True) + plt.tight_layout() + plt.savefig( + os.path.join( + output_path, "%s_%s_sumamryCounts.png" % (self.sensor, var_name) + ) + ) + plt.close() + + # Plotting scatter plot for mean and std + plt.figure(figsize=(15, 6)) + plt.scatter( + Summary_resultsDF["channel"], + Summary_resultsDF["mean"], + s=50, + c="green", + label="Mean", + ) + plt.scatter( + Summary_resultsDF["channel"], + Summary_resultsDF["std"], + s=50, + c="red", + label="Std", + ) + plt.scatter( + Summary_resultsDF["channel"], + Summary_resultsDF["rms"], + s=50, + label="Rms", + facecolors="none", + edgecolors="blue", + ) + plt.xlabel("Channel") + plt.ylabel("Statistics") + plt.title("%s %s" % ((self.sensor, var_name))) + plt.grid(True) + plt.tight_layout() + plt.legend() + plt.savefig( + os.path.join(output_path, "%s_%s_mean_std.png" % (self.sensor, var_name)) + ) + + return Summary_resultsDF + + +def main( + input_path, + output_path, + sensor, + var_name, + channel, + grid_size, + qc_flag, + region, + start_date, + end_date, + filter_by_vars, +): + # Initialize SpatialTemporalStats object + my_tool = SpatialTemporalStats() + + # Generate grid + my_tool.generate_grid(grid_size) # Call generate_grid method) + print("grid created!") + + # Read observational values and perform analysis + o_minus_f_gdf = my_tool.read_obs_values( + input_path, + sensor, + var_name, + channel, + start_date, + end_date, + filter_by_vars, + qc_flag, + ) + + print("read obs values!") + + # Plot observations + print("creating plots...") + + my_tool.plot_obs(o_minus_f_gdf, var_name, region, grid_size, output_path) + print("Time/Area stats plots created!") + + # Make summary plots + print("Creating summary plots...") + summary_results = my_tool.make_summary_plots( + input_path, sensor, var_name, start_date, end_date, qc_flag, output_path + ) + summary_results.to_csv( + os.path.join(output_path, "%s_summary.csv" % sensor), index=False + ) + print("Summary plots created!") + + +def parse_filter(s): + try: + var_name, comparison, threshold = s.split(",") + if comparison not in ("lt", "gt"): + raise ValueError("Comparison must be 'lt' or 'gt'") + return (var_name, comparison, float(threshold)) + except ValueError: + raise argparse.ArgumentTypeError( + "Filter must be in format 'var_name,comparison,threshold'" + ) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description="Python Tool for Spatial and Temporal Analysis" + ) + + # Add arguments + parser.add_argument( + "-input", + dest="input_path", + help=r"REQUIRED: path to input config nc files", + required=True, + metavar="DIR", + type=str, + ) + parser.add_argument( + "-output", + dest="output_path", + help=r"REQUIRED: path to output files", + required=True, + metavar="DIR", + type=str, + ) + + parser.add_argument( + "-sensor", + dest="sensor", + help=r"REQUIRED: sensor name", + required=True, + metavar="string", + type=str, + ) + parser.add_argument( + "-var", + dest="var_name", + help=r"REQUIRED: variable name", + required=True, + metavar="string", + type=str, + ) + parser.add_argument( + "-ch", + dest="channel", + help=r"REQUIRED the channel number", + required=True, + metavar="integer", + type=int, + ) + parser.add_argument( + "-grid", + dest="grid_size", + 
help=r"optional: size of grid for plotting (choices: 0.5, 1, 2)", + required=False, + default=1, + metavar="float", + type=float, + ) + parser.add_argument( + "-no_qc_flag", + dest="no_qc_flag", + help=r"Optional: qc flag for filtering", + action="store_true", + ) + parser.add_argument( + "-region", + dest="region", + help="REQUIRED: region for mapplot. 1: global, 2: polar region, 3: mid-latitudes region," + "4: tropics region, 5:southern mid-latitudes region, 6: southern polar region", + required=False, + default=0, + metavar="integer", + type=int, + ) + parser.add_argument( + "-sdate", + dest="start_date", + help=r"REQUIRED: start date of evaluation", + required=False, + default=0, + metavar="string", + type=str, + ) + parser.add_argument( + "-edate", + dest="end_date", + help=r"REQUIRED: end date of evaluation", + required=False, + default=0, + metavar="string", + type=str, + ) + + # New argument for filter criteria + parser.add_argument( + "-filter_by_vars", + dest="filter_by_vars", + help="Optional: Filtering criteria in format 'var_name,comparison," + "threshold'. Example: Land_Fraction,lt,0.9", + nargs="+", + type=parse_filter, + default=[], + ) + + args = vars(parser.parse_args()) + + input_path = args["input_path"] + output_path = args["output_path"] + sensor = args["sensor"] + var_name = args["var_name"] + channel = args["channel"] + grid_size = args["grid_size"] + region = args["region"] + start_date = args["start_date"] + end_date = args["end_date"] + + if args["no_qc_flag"]: + qc_flag = False + else: + qc_flag = True + + # Accessing and printing the parsed filter criteria + if args["filter_by_vars"]: + for filter_criteria in args["filter_by_vars"]: + print( + f"Variable: {filter_criteria[0]}," + f"Comparison: {filter_criteria[1]}," + f"Threshold: {filter_criteria[2]}" + ) + + main( + input_path, + output_path, + sensor, + var_name, + channel, + grid_size, + qc_flag, + region, + start_date, + end_date, + args["filter_by_vars"], + ) diff --git a/ush/SpatialTemporalStatsTool/user_Analysis.py b/ush/SpatialTemporalStatsTool/user_Analysis.py deleted file mode 100644 index 7f69c82..0000000 --- a/ush/SpatialTemporalStatsTool/user_Analysis.py +++ /dev/null @@ -1,70 +0,0 @@ -from SpatialTemporalStats import SpatialTemporalStats - -# Set input and output paths -input_path = "/PATH/TO/Input/Files" -output_path = r'./Results' - -# Set sensor name -sensor = "iasi_metop-c" - -# Set variable name and channel number -var_name = "Obs_Minus_Forecast_adjusted" -channel_no = 1 - -# Set start and end dates -start_date, end_date = "2024-01-01", "2024-01-31" - -# Set region -# 1: global, 2: polar region, 3: mid-latitudes region, -# 4: tropics region, 5:southern mid-latitudes region, 6: southern polar region -region = 1 - -# Initialize SpatialTemporalStats object -my_tool = SpatialTemporalStats() - -# Set resolution for grid generation -resolution = 2 - -# Generate grid -my_tool.generate_grid(resolution) # Call generate_grid method) -print("grid created!") - -# Set QC filter -QC_filter = True # should be always False or true - -# Set filter by variables -# can be an empty list -filter_by_vars = [] - -# filter_by_vars = [("Land_Fraction", "lt", 0.9),] -# list each case in a separate tuple inside this list. 
-# options are 'lt' or 'gt' for 'less than' and 'greater than' - -# Read observational values and perform analysis -o_minus_f_gdf = my_tool.read_obs_values( - input_path, - sensor, - var_name, - channel_no, - start_date, - end_date, - filter_by_vars, - QC_filter, -) - -print("read obs values!") -# Can save the results in a gpkg file -# o_minus_f_gdf.to_file("filename.gpkg", driver='GPKG') - -# Plot observations -print("creating plots...") -my_tool.plot_obs(o_minus_f_gdf, var_name, region, resolution, output_path) -print("Time/Area stats plots created!") - -# Make summary plots -print("Creating summary plots...") -summary_results = my_tool.make_summary_plots( - input_path, sensor, var_name, start_date, end_date, QC_filter, output_path -) -print("Summary plots created!") -# Print summary results
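
As a usage note on the new `-filter_by_vars` option: each CLI token is converted by the module-level `parse_filter` helper in `SpatialTemporalStats.py` into the `(var_name, comparison, threshold)` tuple that `read_obs_values` consumes. A quick illustration, assuming `SpatialTemporalStats.py` is importable and using a placeholder variable name:

```python
from SpatialTemporalStats import parse_filter

# Each 'var_name,comparison,threshold' token becomes a tuple;
# the comparison must be 'lt' or 'gt'.
filters = [parse_filter("Land_Fraction,lt,0.9")]
print(filters)  # [('Land_Fraction', 'lt', 0.9)]
```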