# Uncomment to run the notebook in Colab
# ! pip install -q "wax-ml[complete]@git+https://github.com/eserie/wax-ml.git"
# ! pip install -q --upgrade jax jaxlib==0.1.67+cuda111 -f https://storage.googleapis.com/jax-releases/jax_releases.html
# check available devices
import jax
print("jax backend {}".format(jax.lib.xla_bridge.get_backend().platform))
jax.devices()
WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
jax backend cpu
[CpuDevice(id=0)]
Binning temperatures
Let's again consider the air temperature dataset. It is sampled at a 6-hourly resolution (four observations per day). We will compute "trailing" air temperature bins within each day and "reset" the bin aggregation process at each day change.
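To make the binning rule concrete, here is a minimal plain NumPy sketch of a trailing open/high/low/close (OHLC) that restarts whenever the calendar day changes. It is not WAX-ML's implementation: `trailing_ohlc` is a hypothetical helper, and we assume that a new bin opens with the first sample of each day (and with the very first sample of the series).

```python
import numpy as np


def trailing_ohlc(values, timestamps):
    """Trailing open/high/low/close bins that restart at every calendar-day change.

    Hypothetical helper for illustration; not part of WAX-ML.
    """
    values = np.asarray(values, dtype=float)
    # Truncate each datetime64 timestamp to its calendar day.
    days = np.asarray(timestamps).astype("datetime64[D]")
    # A bin starts at the first sample and wherever the day differs from the previous sample.
    reset = np.ones(len(values), dtype=bool)
    reset[1:] = days[1:] != days[:-1]

    out = {name: np.empty_like(values) for name in ("open", "high", "low", "close")}
    for i, x in enumerate(values):
        if reset[i]:
            # New day: re-initialise the bin with the current sample.
            open_, high, low = x, x, x
        else:
            # Same day: extend the running extremes.
            high, low = max(high, x), min(low, x)
        out["open"][i] = open_
        out["high"][i] = high
        out["low"][i] = low
        out["close"][i] = x  # the close is always the latest sample
    return out


# Two days sampled every 6 hours, like the air temperature dataset.
ts = np.array(
    ["2013-01-01T00", "2013-01-01T06", "2013-01-01T12", "2013-01-01T18",
     "2013-01-02T00", "2013-01-02T06"],
    dtype="datetime64[h]",
)
temps = [241.2, 242.5, 240.1, 243.0, 250.0, 248.5]
ohlc = trailing_ohlc(temps, ts)
```

Here `ohlc["open"]` stays at 241.2 for the four samples of the first day and switches to 250.0 at the day change, while `high` and `low` track the running extremes within each day only.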
import numpy as onp
import xarray as xr
from wax.accessors import register_wax_accessors
from wax.modules import OHLC, HasChanged
register_wax_accessors()
dataset = xr.tutorial.open_dataset("air_temperature")
# Add a per-sample calendar "date" variable, used below to detect day changes
dataset["date"] = dataset.time.dt.date.astype(onp.datetime64)
dataset
<xarray.Dataset>
Dimensions:  (lat: 25, lon: 53, time: 2920)
Coordinates:
  * lat      (lat) float32 75.0 72.5 70.0 67.5 65.0 ... 25.0 22.5 20.0 17.5 15.0
  * lon      (lon) float32 200.0 202.5 205.0 207.5 ... 322.5 325.0 327.5 330.0
  * time     (time) datetime64[ns] 2013-01-01 ... 2014-12-31T18:00:00
Data variables:
    air      (time, lat, lon) float32 ...
    date     (time) datetime64[ns] 2013-01-01 2013-01-01 ... 2014-12-31
Attributes:
    Conventions:  COARDS
    title:        4x daily NMC reanalysis (1948)
    description:  Data is from NMC initialized reanalysis\n(4x/day).  These a...
    platform:     Model
    references:   http://www.esrl.noaa.gov/psd/data/gridded/data.ncep.reanaly...
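def bin_temperature(da):
    # True at the time steps where the "date" value changes
    day_change = HasChanged()(da["date"])
    # Trailing open/high/low/close of "air", restarted at each day change
    return OHLC()(da["air"], reset_on=day_change)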
output, state = dataset.wax.stream().apply(
bin_temperature, format_dims=onp.array(dataset.air.dims)
)
output = xr.Dataset(output._asdict())
# Trailing OHLC at the first grid point, restricted to January 2013
df = output.isel(lat=0, lon=0).drop_vars(["lat", "lon"]).to_pandas().loc["2013-01"]
_ = df.plot(figsize=(12, 8), title="Trailing Open-High-Low-Close temperatures")
The UpdateOnEvent module
The OHLC module uses the primitive wax.modules.UpdateOnEvent. Implementing it required extending Haiku with a central function, set_params_or_state_dict, which we have integrated into this WAX-ML module for now. We have opened an issue on the Haiku GitHub repository proposing to integrate it into Haiku itself.
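The text above only names `wax.modules.UpdateOnEvent`. As a rough illustration of the idea, the sketch below assumes the semantics "apply a function when an event flag is set, otherwise hold the previous output" and writes it in plain JAX with `jax.lax.scan` rather than as a Haiku module. `update_on_event` is a hypothetical name, and this is not WAX-ML's code. It shows how the "open" of a daily bin can be captured at each day change and held until the next one.

```python
import jax
import jax.numpy as jnp


def update_on_event(fn, events, xs, init):
    """Return fn(x) at steps where events is True; otherwise repeat the previous output.

    Hypothetical sketch of "update on event" semantics; not WAX-ML's code.
    """

    def step(prev, inputs):
        event, x = inputs
        # fn is evaluated at every step and its result is discarded when no
        # event occurs; jax.lax.cond could skip the computation instead.
        new = jnp.where(event, fn(x), prev)
        return new, new

    _, outputs = jax.lax.scan(step, init, (events, xs))
    return outputs


# The "open" of a daily bin: capture the value when a day starts, hold it otherwise.
day_change = jnp.array([True, False, False, False, True, False])
temps = jnp.array([241.0, 242.5, 240.0, 243.0, 250.0, 248.5], dtype=jnp.float32)
opens = update_on_event(lambda x: x, day_change, temps, init=jnp.float32(jnp.nan))
```

Using `jnp.where` keeps the step function branch-free, which is simple and cheap for an identity update; for an expensive `fn`, a `jax.lax.cond` inside `step` would avoid computing discarded values.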