Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ENH: xarray grid output #1477

Open
wants to merge 9 commits into
base: main
Choose a base branch
from
140 changes: 111 additions & 29 deletions pyart/core/grid.py
Original file line number Diff line number Diff line change
Expand Up @@ -326,47 +326,129 @@ def to_xarray(self):
y = self.y["data"]
x = self.x["data"]

time = np.array([num2date(self.time["data"][0], self.time["units"])])
time = np.array(
[num2date(self.time["data"][0], units=self.time["units"])],
)

ds = xarray.Dataset()
for field in list(self.fields.keys()):
field_data = self.fields[field]["data"]
for field, field_info in self.fields.items():
field_data = field_info["data"]
data = xarray.DataArray(
np.ma.expand_dims(field_data, 0),
dims=("time", "z", "y", "x"),
coords={
"time": (["time"], time),
"z": (["z"], z),
"time": time,
"z": z,
"lat": (["y", "x"], lat),
"lon": (["y", "x"], lon),
"y": (["y"], y),
"x": (["x"], x),
"y": y,
"x": x,
},
)
for meta in list(self.fields[field].keys()):

for meta, value in field_info.items():
if meta != "data":
data.attrs.update({meta: self.fields[field][meta]})
data.attrs.update({meta: value})

ds[field] = data
ds.lon.attrs = [
("long_name", "longitude of grid cell center"),
("units", "degree_E"),
("standard_name", "Longitude"),
]
ds.lat.attrs = [
("long_name", "latitude of grid cell center"),
("units", "degree_N"),
("standard_name", "Latitude"),
]

ds.z.attrs = get_metadata("z")
ds.y.attrs = get_metadata("y")
ds.x.attrs = get_metadata("x")

ds.z.encoding["_FillValue"] = None
ds.lat.encoding["_FillValue"] = None
ds.lon.encoding["_FillValue"] = None
ds.close()

ds.lon.attrs = [
("long_name", "longitude of grid cell center"),
("units", "degree_E"),
("standard_name", "Longitude"),
]
ds.lat.attrs = [
("long_name", "latitude of grid cell center"),
("units", "degree_N"),
("standard_name", "Latitude"),
]

ds.z.attrs = get_metadata("z")
ds.y.attrs = get_metadata("y")
ds.x.attrs = get_metadata("x")

for attr in [ds.z, ds.lat, ds.lon]:
attr.encoding["_FillValue"] = None

# Delayed import
from ..io.grid_io import _make_coordinatesystem_dict

ds.coords["ProjectionCoordinateSystem"] = xarray.DataArray(
data=np.array(1, dtype="int32"),
attrs=_make_coordinatesystem_dict(self),
)

# write the projection dictionary as a scalar
projection = self.projection.copy()
# NetCDF does not support boolean attribute, covert to string
if "_include_lon_0_lat_0" in projection:
include = projection["_include_lon_0_lat_0"]
projection["_include_lon_0_lat_0"] = ["false", "true"][include]
ds.coords["projection"] = xarray.DataArray(
data=np.array(1, dtype="int32"),
dims=None,
attrs=projection,
)

for attr_name in [
"origin_latitude",
"origin_longitude",
"origin_altitude",
"radar_altitude",
"radar_latitude",
"radar_longitude",
"radar_time",
]:
if hasattr(self, attr_name):
attr_data = getattr(self, attr_name)
if attr_data is not None:
if attr_name in [
"origin_latitude",
"origin_longitude",
"origin_altitude",
]:
# Adjusting the dims to 'time' for the origin attributes
attr_value = np.ma.expand_dims(attr_data["data"][0], 0)
dims = ("time",)
else:
if "radar_time" not in attr_name:
attr_value = np.ma.expand_dims(attr_data["data"][0], 0)
else:
attr_value = [
np.array(
num2date(
attr_data["data"][0],
units=attr_data["units"],
),
dtype="datetime64[ns]",
)
]
dims = ("nradar",)

ds.coords[attr_name] = xarray.DataArray(
attr_value, dims=dims, attrs=get_metadata(attr_name)
)

if "radar_time" in ds.variables:
ds.radar_time.attrs.pop("calendar")

if self.radar_name is not None:
radar_name = self.radar_name["data"]
ds["radar_name"] = xarray.DataArray(
np.array([b"".join(radar_name)]),
mgrover1 marked this conversation as resolved.
Show resolved Hide resolved
dims=("nradar"),
attrs=get_metadata("radar_name"),
)

ds.attrs = self.metadata
for key in ds.attrs:
try:
ds.attrs[key] = ds.attrs[key].decode("utf-8")
except AttributeError:
# If the attribute is not a byte string, just pass
pass

ds.close()
return ds

def add_field(self, field_name, field_dict, replace_existing=False):
Expand All @@ -389,7 +471,7 @@ def add_field(self, field_name, field_dict, replace_existing=False):
if "data" not in field_dict:
raise KeyError('Field dictionary must contain a "data" key')
if field_name in self.fields and replace_existing is False:
raise ValueError(f"A field named {field_name} already exists")
raise ValueError("A field named %s already exists" % (field_name))
if field_dict["data"].shape != (self.nz, self.ny, self.nx):
raise ValueError("Field has invalid shape")

Expand Down