I want to use xarray to reduce a dataset over a named dimension with a custom/external function.
Here is a dataset that demonstrates the problem:
import xarray as xr
import numpy as np
import pandas as pd
time = pd.date_range("2000-01-01", "2001-01-01", freq="D")
sids = np.arange(4)
obs = np.random.random(size=(len(time), len(sids)))
sim = np.random.random(size=(len(time), len(sids)))
original = xr.Dataset({"obs": (("time", "station_id"), obs), "sim": (("time", "station_id"), sim)}, coords={"time": time, "station_id": sids})
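I want to compute mean_squared_error from the two variables in original, collapsing the "time" dimension to get one value per station. This should return an xr.Dataset like the following: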
<xarray.Dataset>
Dimensions: (station_id: 4)
Coordinates:
* station_id (station_id) int64 0 1 2 3
Data variables:
mean_squared_error (station_id) float64 0.4411 0.183 0.06754 0.9662
I have tried using the reduce function:
from sklearn.metrics import mean_squared_error
original.reduce(mean_squared_error, dim="time")
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
<ipython-input-243-51111f05437b> in <module>
----> 1 original.reduce(mean_squared_error, dim="time")
~/miniconda3/envs/ml/lib/python3.8/site-packages/xarray/core/dataset.py in reduce(self, func, dim, keep_attrs, keepdims, numeric_only, **kwargs)
4915 # the former is often more efficient
4916 reduce_dims = None # type: ignore[assignment]
-> 4917 variables[name] = var.reduce(
4918 func,
4919 dim=reduce_dims,
~/miniconda3/envs/ml/lib/python3.8/site-packages/xarray/core/variable.py in reduce(self, func, dim, axis, keep_attrs, keepdims, **kwargs)
1721 )
1722 if axis is not None:
-> 1723 data = func(self.data, axis=axis, **kwargs)
1724 else:
1725 data = func(self.data, **kwargs)
~/miniconda3/envs/ml/lib/python3.8/site-packages/sklearn/utils/validation.py in inner_f(*args, **kwargs)
70 FutureWarning)
71 kwargs.update({k: arg for k, arg in zip(sig.parameters, args)})
---> 72 return f(**kwargs)
73 return inner_f
74
TypeError: mean_squared_error() got an unexpected keyword argument 'axis'
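The traceback shows why this call fails: reduce invokes the function once per variable as func(self.data, axis=axis, **kwargs), so the function must accept a single array plus an axis keyword, whereas mean_squared_error expects two arrays and has no axis argument. As a minimal sketch of the kind of function reduce can work with, here is a hypothetical sum_of_squares helper (not part of xarray or sklearn) applied to a dataset built like the one above, with a fixed seed:

```python
import numpy as np
import pandas as pd
import xarray as xr


def sum_of_squares(values, axis=None):
    """Hypothetical reducer: one array in, reduced along `axis`."""
    return np.sum(values ** 2, axis=axis)


time = pd.date_range("2000-01-01", "2001-01-01", freq="D")
sids = np.arange(4)
rng = np.random.default_rng(0)
original = xr.Dataset(
    {
        "obs": (("time", "station_id"), rng.random((len(time), len(sids)))),
        "sim": (("time", "station_id"), rng.random((len(time), len(sids)))),
    },
    coords={"time": time, "station_id": sids},
)

# reduce applies the function to "obs" and "sim" independently,
# so it cannot combine the two variables into a single metric.
reduced = original.reduce(sum_of_squares, dim="time")
print(reduced)
```

The result still contains separate obs and sim variables, each reduced over time, which is why a two-argument metric needs a different approach.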
There is a package called xskillscore that has a function for computing the MSE.
pip install xskillscore
import xskillscore
xskillscore.mse(original.obs, original.sim, 'time')
I believe this works (note that np.sqrt makes this the RMSE; drop it to get the plain MSE):
np.sqrt(np.square(original["sim"] - original["obs"]).mean(dim="time"))
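To get from that expression to the Dataset layout asked for in the question, something like the following should do. It is a sketch that rebuilds the example dataset with a fixed seed; the names mse, rmse and result are only illustrative:

```python
import numpy as np
import pandas as pd
import xarray as xr

time = pd.date_range("2000-01-01", "2001-01-01", freq="D")
sids = np.arange(4)
rng = np.random.default_rng(0)
original = xr.Dataset(
    {
        "obs": (("time", "station_id"), rng.random((len(time), len(sids)))),
        "sim": (("time", "station_id"), rng.random((len(time), len(sids)))),
    },
    coords={"time": time, "station_id": sids},
)

# Squared error per (time, station), averaged over the time dimension.
mse = ((original["sim"] - original["obs"]) ** 2).mean(dim="time")
# Taking the square root turns the MSE into the RMSE.
rmse = np.sqrt(mse)

# Wrap the per-station values in a Dataset with the requested variable name.
result = xr.Dataset({"mean_squared_error": mse})
print(result)
```

Because the arithmetic is done on whole arrays, this stays vectorized over station_id and any other dimensions the data might have.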
One solution does not use xarray's built-in functions and instead loops over every value of the station_id dimension.
from collections import defaultdict
# calculate error metric
out = defaultdict(list)
for sid in original.station_id.values:
    data = original.sel(station_id=sid)
    orig_err = np.sqrt(mean_squared_error(data["obs"], data["sim"]))
    out["original"].append(orig_err)
    out["station_id"].append(sid)
rmse = pd.DataFrame(out).set_index("station_id").to_xarray()
This gives you the answer, but it does not use xarray's built-in broadcasting, so it will scale poorly to larger datasets.
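If the goal is to keep an external two-argument function but let xarray handle the looping, xr.apply_ufunc with input_core_dims and vectorize=True is one option. The sketch below uses a hypothetical mse_1d function as a stand-in for the external metric; presumably any function that takes two 1-D arrays and returns a scalar (such as sklearn's mean_squared_error) could be swapped in, though that is an assumption rather than something tested here:

```python
import numpy as np
import pandas as pd
import xarray as xr


def mse_1d(y_true, y_pred):
    """Hypothetical stand-in for an external metric: two 1-D arrays -> scalar."""
    return np.mean((y_true - y_pred) ** 2)


time = pd.date_range("2000-01-01", "2001-01-01", freq="D")
sids = np.arange(4)
rng = np.random.default_rng(0)
original = xr.Dataset(
    {
        "obs": (("time", "station_id"), rng.random((len(time), len(sids)))),
        "sim": (("time", "station_id"), rng.random((len(time), len(sids)))),
    },
    coords={"time": time, "station_id": sids},
)

# "time" is declared as the core dimension of both inputs, so the function
# receives one 1-D time series per station and "time" is absent from the output.
mse = xr.apply_ufunc(
    mse_1d,
    original["obs"],
    original["sim"],
    input_core_dims=[["time"], ["time"]],
    vectorize=True,  # loop over the remaining dims via np.vectorize
)
result = mse.to_dataset(name="mean_squared_error")
print(result)
```

Note that vectorize=True still loops in Python under the hood, so it is mainly a convenience; for large datasets the plain array arithmetic is likely to be faster.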