# Uncomment to run the notebook in Colab
# ! pip install -q "wax-ml[complete]@git+https://github.com/eserie/wax-ml.git"
# ! pip install -q --upgrade jax jaxlib==0.1.67+cuda111 -f https://storage.googleapis.com/jax-releases/jax_releases.html
# check available devices
import jax
print("jax backend {}".format(jax.lib.xla_bridge.get_backend().platform))
jax.devices()
WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
jax backend cpu
[<jaxlib.xla_extension.Device at 0x7f8d889385b0>]
🌡 Binning temperatures 🌡
Let’s consider the air temperatures dataset again. It is sampled at an hourly resolution. We will compute “trailing” air temperature bins within each day and “reset” the bin aggregation at each day change.
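To make the “trailing” and “reset” semantics concrete, here is a minimal NumPy sketch of what such a binning computes on a single time series. The helpers `has_changed` and `trailing_ohlc` are hypothetical and written only for illustration. They are not WAX-ML's `HasChanged` or `OHLC`. Two behaviours are assumptions: the first step is never flagged as a change, and a reset step opens the new bin with its own value.

```python
import numpy as np


def has_changed(keys):
    """True where keys[t] differs from keys[t - 1] (assumed False at t == 0)."""
    keys = np.asarray(keys)
    changed = np.zeros(len(keys), dtype=bool)
    changed[1:] = keys[1:] != keys[:-1]
    return changed


def trailing_ohlc(values, reset_on):
    """Trailing open/high/low/close of `values`, restarted wherever `reset_on` is True."""
    values = np.asarray(values, dtype=float)
    reset_on = np.asarray(reset_on, dtype=bool)
    n = len(values)
    open_, high, low = np.empty(n), np.empty(n), np.empty(n)
    close = values.copy()  # the close of a trailing bin is always the latest value
    for t in range(n):
        if t == 0 or reset_on[t]:
            # Start a new bin: open, high and low all restart at the current value.
            open_[t] = high[t] = low[t] = values[t]
        else:
            # Extend the current bin with the current value.
            open_[t] = open_[t - 1]
            high[t] = max(high[t - 1], values[t])
            low[t] = min(low[t - 1], values[t])
    return open_, high, low, close


# Two days of three (made-up) hourly temperatures each.
days = np.array(["2013-01-01"] * 3 + ["2013-01-02"] * 3, dtype="datetime64[D]")
temps = np.array([2.0, 5.0, 3.0, 1.0, 4.0, 0.0])
o, h, l, c = trailing_ohlc(temps, has_changed(days))
# o -> [2, 2, 2, 1, 1, 1]
# h -> [2, 5, 5, 1, 4, 4]
# l -> [2, 2, 2, 1, 1, 0]
```

The explicit loop mirrors the streaming view: each step only needs the previous bin state and the current value, which is what lets a streaming library process the series one time step at a time.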
import numpy as onp
import xarray as xr
from wax.accessors import register_wax_accessors
from wax.modules import OHLC, HasChanged
register_wax_accessors()
da = xr.tutorial.open_dataset("air_temperature")
da["date"] = da.time.dt.date.astype(onp.datetime64)
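def bin_temperature(da):
    # Flag the time steps at which the calendar date changes.
    day_change = HasChanged()(da["date"])
    # Trailing open/high/low/close of the air temperature, restarted at each new day.
    return OHLC()(da["air"], reset_on=day_change)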
output, state = da.wax.stream().apply(
bin_temperature, format_dims=onp.array(da.air.dims)
)
output = xr.Dataset(output._asdict())
df = output.isel(lat=0, lon=0).drop(["lat", "lon"]).to_pandas().loc["2013-01"]
_ = df.plot(figsize=(12, 8))
The UpdateOnEvent module
The OHLC module uses the primitive wax.modules.UpdateOnEvent.
Its implementation required extending Haiku with a central function, set_params_or_state_dict, which we have currently integrated in this WAX-ML module.
We have opened an issue on the Haiku GitHub repository to integrate it into Haiku.
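The exact semantics of `UpdateOnEvent` are not detailed here. As a hedged illustration of the general pattern its name suggests (advance a state only when an event fires, otherwise hold the previous value), here is a minimal JAX sketch built on `jax.lax.scan` and `jnp.where`. The function `update_on_event` and its parameters are hypothetical. It is not WAX-ML's API and does not use Haiku.

```python
import jax
import jax.numpy as jnp


def update_on_event(step_fn, init_state, xs, events):
    """Scan over (xs, events), applying step_fn only at steps where the event is True."""

    def body(state, inputs):
        x, event = inputs
        candidate = step_fn(state, x)
        # Keep the candidate where the event fired, else keep the previous state.
        new_state = jax.tree_util.tree_map(
            lambda new, old: jnp.where(event, new, old), candidate, state
        )
        return new_state, new_state

    _, outputs = jax.lax.scan(body, init_state, (xs, events))
    return outputs


# Example: a running sum that only accumulates on event steps.
xs = jnp.array([1.0, 2.0, 3.0, 4.0, 5.0], dtype=jnp.float32)
events = jnp.array([True, False, True, False, True])
outputs = update_on_event(lambda s, x: s + x, jnp.float32(0.0), xs, events)
# outputs -> [1., 1., 4., 4., 9.]
```

Using `jnp.where` instead of Python control flow keeps the step function traceable, so the whole loop can be compiled with `jax.jit`.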