# Uncomment to run the notebook in Colab
# ! pip install -q "wax-ml[complete]@git+https://github.com/eserie/wax-ml.git"
# ! pip install -q --upgrade jax jaxlib==0.1.67+cuda111 -f https://storage.googleapis.com/jax-releases/jax_releases.html
# check available devices
import jax
print("jax backend {}".format(jax.lib.xla_bridge.get_backend().platform))
jax.devices()
WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
jax backend cpu
[<jaxlib.xla_extension.Device at 0x7f8d889385b0>]

🌑 Binning temperatures 🌑


Let's consider the air temperature dataset again. It is sampled at an hourly resolution. We will build "trailing" air temperature bins within each day and "reset" the bin aggregation at each day change.
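To make the intended semantics concrete, here is a plain-NumPy sketch of trailing open/high/low/close bins that restart whenever the day changes. It is not WAX-ML's implementation. The helper names `has_changed` and `trailing_ohlc` are hypothetical, and we assume the very first step counts as a change (the first-step behaviour of `HasChanged` may differ).

```python
import numpy as np


def has_changed(x):
    """Hypothetical helper: True where x differs from the previous element.

    Assumption: the first element is flagged as a change.
    """
    x = np.asarray(x)
    flags = np.ones(len(x), dtype=bool)
    flags[1:] = x[1:] != x[:-1]
    return flags


def trailing_ohlc(values, reset_on):
    """Hypothetical helper: trailing open/high/low/close, restarted where reset_on is True."""
    values = np.asarray(values, dtype=float)
    n = len(values)
    opens, highs, lows, closes = (np.empty(n) for _ in range(4))
    for i, v in enumerate(values):
        if i == 0 or reset_on[i]:
            # Start a new bin: open, high and low all equal the current value.
            o, h, lo = v, v, v
        else:
            # Extend the current bin with the new observation.
            h, lo = max(h, v), min(lo, v)
        opens[i], highs[i], lows[i], closes[i] = o, h, lo, v
    return opens, highs, lows, closes


# Toy data: three hourly readings on one day, two on the next.
days = np.array(["2013-01-01"] * 3 + ["2013-01-02"] * 2)
temps = np.array([270.0, 272.0, 268.0, 275.0, 274.0])
opens, highs, lows, closes = trailing_ohlc(temps, has_changed(days))
print(opens, highs, lows, closes)
```

The bins are "trailing" because each row only aggregates observations seen so far in the current day; the close is always the latest value.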

import numpy as onp
import xarray as xr
from wax.accessors import register_wax_accessors
from wax.modules import OHLC, HasChanged

register_wax_accessors()
da = xr.tutorial.open_dataset("air_temperature")
da["date"] = da.time.dt.date.astype(onp.datetime64)
def bin_temperature(da):
    day_change = HasChanged()(da["date"])
    return OHLC()(da["air"], reset_on=day_change)


output, state = da.wax.stream().apply(
    bin_temperature, format_dims=onp.array(da.air.dims)
)
output = xr.Dataset(output._asdict())
df = output.isel(lat=0, lon=0).drop_vars(["lat", "lon"]).to_pandas().loc["2013-01"]
_ = df.plot(figsize=(12, 8))
[Figure: trailing OHLC air temperature bins at the first grid point, January 2013]

The UpdateOnEvent module

The OHLC module uses the primitive wax.modules.UpdateOnEvent.
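As an illustration of the idea behind such a primitive, the following JAX sketch holds the last output of a function and refreshes it only on steps where an event flag is True. The helper `update_on_event` and its arguments are hypothetical and assume these semantics; they do not reproduce the signature of `wax.modules.UpdateOnEvent`.

```python
import jax
import jax.numpy as jnp


def update_on_event(fn, xs, events, init):
    """Hypothetical helper: emit fn(x) on event steps, otherwise repeat the held value."""

    def step(carry, inputs):
        x, event = inputs
        new = fn(x)  # computed at every step, kept only when event is True
        out = jnp.where(event, new, carry)
        return out, out

    _, outs = jax.lax.scan(step, init, (xs, events))
    return outs


xs = jnp.array([1.0, 2.0, 3.0, 4.0, 5.0])
events = jnp.array([True, False, False, True, False])
held = update_on_event(lambda x: 10.0 * x, xs, events, jnp.array(0.0))
print(held)
```

Using `jax.lax.scan` keeps the loop compilable by XLA, and `jnp.where` selects between the fresh and the held value without Python control flow.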

Its implementation required extending Haiku with a central function, set_params_or_state_dict, which we have integrated into this WAX-ML module.

We have opened an issue on the Haiku GitHub repository proposing to integrate it into Haiku itself.