# Uncomment to run the notebook in Colab
# ! pip install -q "wax-ml[complete]@git+https://github.com/eserie/wax-ml.git"
# ! pip install -q --upgrade jax jaxlib==0.1.70+cuda111 -f https://storage.googleapis.com/jax-releases/jax_releases.html
# check available devices
import jax
print("jax backend {}".format(jax.default_backend()))
jax.devices()
WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
jax backend cpu
[CpuDevice(id=0)]
〰 Compute exponential moving averages with xarray and pandas accessors 〰
WAX-ML implements pandas and xarray accessors to ease the usage of machine-learning algorithms with high-level data APIs: pandas’s DataFrame and Series, and xarray’s Dataset and DataArray.
These accessors make it easy to run any function that uses Haiku modules on these data containers.
For instance, WAX-ML provides an implementation of the exponential moving average built on this mechanism.
Let’s show how it works.
Load accessors
First you need to load accessors:
from wax.accessors import register_wax_accessors
register_wax_accessors()
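For readers unfamiliar with the accessor mechanism, here is a minimal sketch of how a custom namespace can be attached to a DataFrame using pandas’ public `register_dataframe_accessor` decorator. The accessor name `demo` and the class `DemoAccessor` are hypothetical and only illustrate the idea; they are not WAX-ML’s implementation.

```python
import pandas as pd


# Hypothetical accessor for illustration only; this is not WAX-ML's code.
@pd.api.extensions.register_dataframe_accessor("demo")
class DemoAccessor:
    def __init__(self, pandas_obj):
        # pandas passes the DataFrame the accessor is called on.
        self._obj = pandas_obj

    def ewm_mean(self, alpha):
        # This sketch simply delegates to pandas; a library such as WAX-ML
        # could run its own implementation here instead.
        return self._obj.ewm(alpha=alpha).mean()


df = pd.DataFrame({"x": [1.0, 2.0, 3.0]})
result = df.demo.ewm_mean(alpha=0.5)
```

Once registered, the namespace is available on every DataFrame, which is presumably why a single registration call is enough before using `.wax`.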
EWMA on dataframes
Let’s look at a simple example: the exponential moving average (EWMA).
Let’s apply the EWMA algorithm to the NCEP/NCAR air temperature data.
🌡 Load temperature dataset 🌡
import xarray as xr
dataset = xr.tutorial.open_dataset("air_temperature")
Let’s see what this dataset looks like:
dataset
<xarray.Dataset>
Dimensions:  (lat: 25, lon: 53, time: 2920)
Coordinates:
  * lat      (lat) float32 75.0 72.5 70.0 67.5 65.0 ... 25.0 22.5 20.0 17.5 15.0
  * lon      (lon) float32 200.0 202.5 205.0 207.5 ... 322.5 325.0 327.5 330.0
  * time     (time) datetime64[ns] 2013-01-01 ... 2014-12-31T18:00:00
Data variables:
    air      (time, lat, lon) float32 ...
Attributes:
    Conventions:  COARDS
    title:        4x daily NMC reanalysis (1948)
    description:  Data is from NMC initialized reanalysis\n(4x/day).  These a...
    platform:     Model
    references:   http://www.esrl.noaa.gov/psd/data/gridded/data.ncep.reanaly...
To compute an EWMA on some variables of a dataset, we usually need to convert the data into a pandas series or dataframe.
So, let’s convert the dataset into a dataframe to illustrate accessors on a dataframe:
dataframe = dataset.air.to_series().unstack(["lon", "lat"])
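To make the reshaping step concrete, the sketch below applies the same `unstack` call to a small synthetic series. The series is a stand-in for what `dataset.air.to_series()` returns (a Series indexed by `time`, `lat`, `lon`); the values and coordinates are made up for illustration.

```python
import numpy as np
import pandas as pd

# Synthetic stand-in for dataset.air.to_series(): one value per (time, lat, lon).
times = pd.to_datetime(
    ["2013-01-01 00:00", "2013-01-01 06:00", "2013-01-01 12:00", "2013-01-01 18:00"]
)
index = pd.MultiIndex.from_product(
    [times, [75.0, 72.5], [200.0, 202.5, 205.0]],
    names=["time", "lat", "lon"],
)
series = pd.Series(np.arange(len(index), dtype="float32"), index=index)

# Moving "lon" and "lat" to the columns leaves one row per timestamp
# and one column per (lon, lat) grid point.
frame = series.unstack(["lon", "lat"])
print(frame.shape)  # (4, 6)
```

Each column of the resulting dataframe is then a plain time series for one grid point, which is the layout that `ewm` operates on column by column.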
EWMA with pandas
air_temp_ewma = dataframe.ewm(alpha=1.0 / 10.0).mean()
_ = air_temp_ewma.iloc[:, 0].plot()
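As a reminder of what `ewm(alpha=...).mean()` computes, the sketch below reproduces pandas’ default (`adjust=True`) result by hand on a short series. The sample values are made up for illustration.

```python
import numpy as np
import pandas as pd

alpha = 1.0 / 10.0
x = pd.Series([280.0, 282.0, 281.0, 285.0])

# With pandas' default adjust=True, each output is a weighted average of all
# values seen so far, with weight (1 - alpha)**i for the value i steps back.
manual = []
for t in range(len(x)):
    weights = (1.0 - alpha) ** np.arange(t + 1)  # weight for x_t, x_{t-1}, ..., x_0
    window = x.iloc[t::-1].to_numpy()            # x_t, x_{t-1}, ..., x_0
    manual.append(np.sum(weights * window) / np.sum(weights))

expected = x.ewm(alpha=alpha).mean()
print(np.allclose(manual, expected.to_numpy()))
```

With `alpha = 1/10`, the most recent observation carries the largest weight and older observations decay geometrically by a factor of 0.9 per step.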
EWMA with WAX-ML
air_temp_ewma = dataframe.wax.ewm(alpha=1.0 / 10.0).mean()
_ = air_temp_ewma.iloc[:, 0].plot()
On small data, WAX-ML’s EWMA is slower than pandas’ because of the expensive data conversion steps. WAX-ML’s accessors become worthwhile on large data loads (see our three-steps workflow).
Apply a custom function to a Dataset
Now let’s illustrate how WAX-ML accessors work on xarray datasets.
from wax.modules import EWMA
def my_custom_function(dataset):
    return {
        "air_10": EWMA(1.0 / 10.0)(dataset["air"]),
        "air_100": EWMA(1.0 / 100.0)(dataset["air"]),
    }
dataset = xr.tutorial.open_dataset("air_temperature")
output, state = dataset.wax.stream().apply(
my_custom_function, format_dims=dataset.air.dims
)
_ = output.isel(lat=0, lon=0).drop_vars(["lat", "lon"]).to_pandas().plot(figsize=(12, 8))
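The call above returns both an output and a state. To illustrate why a streaming computation carries a state, here is a minimal NumPy sketch of a recursive EWMA update. The helper `ewma_step` is hypothetical, and the initialization (first observation used as the initial mean) is an assumption; WAX-ML’s `EWMA` module may initialize or adjust differently.

```python
import numpy as np
import pandas as pd


def ewma_step(state, x, alpha):
    """One streaming update: returns (output, new_state)."""
    # Assumption: the first observation initializes the running mean.
    mean = x if state is None else (1.0 - alpha) * state + alpha * x
    return mean, mean


values = np.array([280.0, 282.0, 281.0, 285.0])
state, outputs = None, []
for x in values:
    out, state = ewma_step(state, x, alpha=0.1)
    outputs.append(out)

# This recursion matches pandas' non-adjusted EWMA.
reference = pd.Series(values).ewm(alpha=0.1, adjust=False).mean()
print(np.allclose(outputs, reference.to_numpy()))
```

Because the state summarizes everything seen so far, the same update can be resumed later on new data without reprocessing the history.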