Welcome to WAX-ML¶
WAX-ML is a library for machine-learning on streaming data.
For an introduction to WAX-ML, start at the WAX-ML GitHub page.
# Uncomment to run the notebook in Colab
# ! pip install -q "wax-ml[complete]@git+https://github.com/eserie/wax-ml.git"
# ! pip install -q --upgrade jax jaxlib==0.1.70+cuda111 -f https://storage.googleapis.com/jax-releases/jax_releases.html
# check available devices
import jax
print("jax backend {}".format(jax.lib.xla_bridge.get_backend().platform))
jax.devices()
WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
jax backend cpu
[CpuDevice(id=0)]
〰 Compute exponential moving averages with xarray and pandas accessors 〰¶
WAX-ML implements pandas and xarray accessors to ease the usage of machine-learning algorithms with high-level data APIs :
pandas’s
DataFrame
andSeries
xarray’s
Dataset
andDataArray
.
These accessors allow to easily execute any function using Haiku modules on these data containers.
For instance, WAX-ML propose an implementation of the exponential moving average realized with this mechanism.
Let’s show how it works.
Load accessors¶
First you need to load accessors:
from wax.accessors import register_wax_accessors
register_wax_accessors()
EWMA on dataframes¶
Let’s look at a simple example: The exponential moving average (EWMA).
Let’s apply the EWMA algorithm to the NCEP/NCAR ‘s Air temperature data.
🌡 Load temperature dataset 🌡¶
import xarray as xr
dataset = xr.tutorial.open_dataset("air_temperature")
Let’s see what this dataset looks like:
dataset
<xarray.Dataset> Dimensions: (lat: 25, lon: 53, time: 2920) Coordinates: * lat (lat) float32 75.0 72.5 70.0 67.5 65.0 ... 25.0 22.5 20.0 17.5 15.0 * lon (lon) float32 200.0 202.5 205.0 207.5 ... 322.5 325.0 327.5 330.0 * time (time) datetime64[ns] 2013-01-01 ... 2014-12-31T18:00:00 Data variables: air (time, lat, lon) float32 ... Attributes: Conventions: COARDS title: 4x daily NMC reanalysis (1948) description: Data is from NMC initialized reanalysis\n(4x/day). These a... platform: Model references: http://www.esrl.noaa.gov/psd/data/gridded/data.ncep.reanaly...
- lat: 25
- lon: 53
- time: 2920
- lat(lat)float3275.0 72.5 70.0 ... 20.0 17.5 15.0
- standard_name :
- latitude
- long_name :
- Latitude
- units :
- degrees_north
- axis :
- Y
array([75. , 72.5, 70. , 67.5, 65. , 62.5, 60. , 57.5, 55. , 52.5, 50. , 47.5, 45. , 42.5, 40. , 37.5, 35. , 32.5, 30. , 27.5, 25. , 22.5, 20. , 17.5, 15. ], dtype=float32)
- lon(lon)float32200.0 202.5 205.0 ... 327.5 330.0
- standard_name :
- longitude
- long_name :
- Longitude
- units :
- degrees_east
- axis :
- X
array([200. , 202.5, 205. , 207.5, 210. , 212.5, 215. , 217.5, 220. , 222.5, 225. , 227.5, 230. , 232.5, 235. , 237.5, 240. , 242.5, 245. , 247.5, 250. , 252.5, 255. , 257.5, 260. , 262.5, 265. , 267.5, 270. , 272.5, 275. , 277.5, 280. , 282.5, 285. , 287.5, 290. , 292.5, 295. , 297.5, 300. , 302.5, 305. , 307.5, 310. , 312.5, 315. , 317.5, 320. , 322.5, 325. , 327.5, 330. ], dtype=float32)
- time(time)datetime64[ns]2013-01-01 ... 2014-12-31T18:00:00
- standard_name :
- time
- long_name :
- Time
array(['2013-01-01T00:00:00.000000000', '2013-01-01T06:00:00.000000000', '2013-01-01T12:00:00.000000000', ..., '2014-12-31T06:00:00.000000000', '2014-12-31T12:00:00.000000000', '2014-12-31T18:00:00.000000000'], dtype='datetime64[ns]')
- air(time, lat, lon)float32...
- long_name :
- 4xDaily Air temperature at sigma level 995
- units :
- degK
- precision :
- 2
- GRIB_id :
- 11
- GRIB_name :
- TMP
- var_desc :
- Air temperature
- dataset :
- NMC Reanalysis
- level_desc :
- Surface
- statistic :
- Individual Obs
- parent_stat :
- Other
- actual_range :
- [185.16 322.1 ]
[3869000 values with dtype=float32]
- Conventions :
- COARDS
- title :
- 4x daily NMC reanalysis (1948)
- description :
- Data is from NMC initialized reanalysis (4x/day). These are the 0.9950 sigma level values.
- platform :
- Model
- references :
- http://www.esrl.noaa.gov/psd/data/gridded/data.ncep.reanalysis.html
To compute a EWMA on some variables of a dataset, we usually need to convert data in pandas series or dataframe.
So, let’s convert the dataset into a dataframe to illustrate accessors
on a dataframe:
dataframe = dataset.air.to_series().unstack(["lon", "lat"])
EWMA with pandas¶
air_temp_ewma = dataframe.ewm(alpha=1.0 / 10.0).mean()
_ = air_temp_ewma.iloc[:, 0].plot()

EWMA with WAX-ML¶
air_temp_ewma = dataframe.wax.ewm(alpha=1.0 / 10.0).mean()
_ = air_temp_ewma.iloc[:, 0].plot()

On small data, WAX-ML’s EWMA is slower than Pandas’ because of the expensive data conversion steps. WAX-ML’s accessors are interesting to use on large data loads (See our three-steps_workflow)
Apply a custom function to a Dataset¶
Now let’s illustrate how WAX-ML accessors work on xarray datasets.
from wax.modules import EWMA
def my_custom_function(dataset):
return {
"air_10": EWMA(1.0 / 10.0)(dataset["air"]),
"air_100": EWMA(1.0 / 100.0)(dataset["air"]),
}
dataset = xr.tutorial.open_dataset("air_temperature")
output, state = dataset.wax.stream().apply(
my_custom_function, format_dims=dataset.air.dims
)
_ = output.isel(lat=0, lon=0).drop(["lat", "lon"]).to_pandas().plot(figsize=(12, 8))

# Uncomment to run the notebook in Colab
# ! pip install -q "wax-ml[complete]@git+https://github.com/eserie/wax-ml.git"
# ! pip install -q --upgrade jax jaxlib==0.1.70+cuda111 -f https://storage.googleapis.com/jax-releases/jax_releases.html
# check available devices
import jax
print("jax backend {}".format(jax.lib.xla_bridge.get_backend().platform))
jax.devices()
WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
jax backend cpu
[CpuDevice(id=0)]
⏱ Synchronize data streams ⏱¶
Physicists, and not the least 😅, have brought a solution to the synchronization problem. See Poincaré-Einstein synchronization Wikipedia page for more details.
In WAX-ML we strive to follow their recommendations and implement a synchronization mechanism between different data streams. Using the terminology of Henri Poincaré (see link above), we introduce the notion of “local time” to unravel the stream in which the user wants to apply transformations. We call the other streams “secondary streams”. They can work at different frequencies, lower or higher. The data from these secondary streams will be represented in the “local time” either with the use of a forward filling mechanism for lower frequencies or a buffering mechanism for higher frequencies.
We implement a “data tracing” mechanism to optimize access to out-of-sync streams. This mechanism works on in-memory data. We perform the first pass on the data, without actually accessing it, and determine the indices necessary to later access the data. Doing so we are vigilant to not let any “future” information pass through and thus guaranty a data processing that respects causality.
The buffering mechanism used in the case of higher frequencies works with a fixed
buffer size (see the WAX-ML module
wax.modules.Buffer
)
which allows us to use JAX / XLA optimizations and have efficient processing.
Let’s illustrate with a small example how wax.stream.Stream
synchronizes data streams.
Let’s use the dataset “air temperature” with :
An air temperature is defined with hourly resolution.
A “fake” ground temperature is defined with a daily resolution as the air temperature minus 10 degrees.
import xarray as xr
dataset = xr.tutorial.open_dataset("air_temperature")
dataset["ground"] = dataset.air.resample(time="d").last().rename({"time": "day"}) - 10
Let’s see what this dataset looks like:
dataset
<xarray.Dataset> Dimensions: (day: 730, lat: 25, lon: 53, time: 2920) Coordinates: * lat (lat) float32 75.0 72.5 70.0 67.5 65.0 ... 25.0 22.5 20.0 17.5 15.0 * lon (lon) float32 200.0 202.5 205.0 207.5 ... 322.5 325.0 327.5 330.0 * time (time) datetime64[ns] 2013-01-01 ... 2014-12-31T18:00:00 * day (day) datetime64[ns] 2013-01-01 2013-01-02 ... 2014-12-31 Data variables: air (time, lat, lon) float32 ... ground (day, lat, lon) float32 231.9 231.8 231.8 ... 286.5 286.2 285.7 Attributes: Conventions: COARDS title: 4x daily NMC reanalysis (1948) description: Data is from NMC initialized reanalysis\n(4x/day). These a... platform: Model references: http://www.esrl.noaa.gov/psd/data/gridded/data.ncep.reanaly...
- day: 730
- lat: 25
- lon: 53
- time: 2920
- lat(lat)float3275.0 72.5 70.0 ... 20.0 17.5 15.0
- standard_name :
- latitude
- long_name :
- Latitude
- units :
- degrees_north
- axis :
- Y
array([75. , 72.5, 70. , 67.5, 65. , 62.5, 60. , 57.5, 55. , 52.5, 50. , 47.5, 45. , 42.5, 40. , 37.5, 35. , 32.5, 30. , 27.5, 25. , 22.5, 20. , 17.5, 15. ], dtype=float32)
- lon(lon)float32200.0 202.5 205.0 ... 327.5 330.0
- standard_name :
- longitude
- long_name :
- Longitude
- units :
- degrees_east
- axis :
- X
array([200. , 202.5, 205. , 207.5, 210. , 212.5, 215. , 217.5, 220. , 222.5, 225. , 227.5, 230. , 232.5, 235. , 237.5, 240. , 242.5, 245. , 247.5, 250. , 252.5, 255. , 257.5, 260. , 262.5, 265. , 267.5, 270. , 272.5, 275. , 277.5, 280. , 282.5, 285. , 287.5, 290. , 292.5, 295. , 297.5, 300. , 302.5, 305. , 307.5, 310. , 312.5, 315. , 317.5, 320. , 322.5, 325. , 327.5, 330. ], dtype=float32)
- time(time)datetime64[ns]2013-01-01 ... 2014-12-31T18:00:00
- standard_name :
- time
- long_name :
- Time
array(['2013-01-01T00:00:00.000000000', '2013-01-01T06:00:00.000000000', '2013-01-01T12:00:00.000000000', ..., '2014-12-31T06:00:00.000000000', '2014-12-31T12:00:00.000000000', '2014-12-31T18:00:00.000000000'], dtype='datetime64[ns]')
- day(day)datetime64[ns]2013-01-01 ... 2014-12-31
array(['2013-01-01T00:00:00.000000000', '2013-01-02T00:00:00.000000000', '2013-01-03T00:00:00.000000000', ..., '2014-12-29T00:00:00.000000000', '2014-12-30T00:00:00.000000000', '2014-12-31T00:00:00.000000000'], dtype='datetime64[ns]')
- air(time, lat, lon)float32...
- long_name :
- 4xDaily Air temperature at sigma level 995
- units :
- degK
- precision :
- 2
- GRIB_id :
- 11
- GRIB_name :
- TMP
- var_desc :
- Air temperature
- dataset :
- NMC Reanalysis
- level_desc :
- Surface
- statistic :
- Individual Obs
- parent_stat :
- Other
- actual_range :
- [185.16 322.1 ]
[3869000 values with dtype=float32]
- ground(day, lat, lon)float32231.9 231.8 231.8 ... 286.2 285.7
array([[[231.89 , 231.79999, 231.79999, ..., 224.39 , 225.5 , 227.59999], [236.29999, 235.29999, 234.2 , ..., 220.89 , 221.5 , 224.5 ], [246.6 , 244.7 , 242.09999, ..., 220.7 , 221.79999, 226.09999], ..., [286.6 , 286.4 , 286. , ..., 286.5 , 285.79 , 285.29 ], [287. , 287.5 , 287.1 , ..., 286.79 , 286.6 , 286.29 ], [287.5 , 287.69998, 287.5 , ..., 287.79 , 288. , 287.9 ]], [[233.79999, 233.79999, 233.5 , ..., 230.89 , 232.7 , 234.59999], [237.59999, 237.7 , 237.29999, ..., 227.09999, 227.7 , 229.29999], [241.89 , 241.29999, 240.5 , ..., 229.39 , 230.5 , 233.09999], ... [286.59 , 285.88998, 285.29 , ..., 286.88998, 286.29 , 285.38998], [286.69 , 287.49 , 287.29 , ..., 286.69 , 286.29 , 285.59 ], [287.79 , 288.49 , 288.38998, ..., 287.38998, 286.88998, 286.09 ]], [[235.09 , 234.29 , 233.29 , ..., 231.68999, 231.48999, 231.79 ], [239.89 , 239.29 , 238.39 , ..., 229.59 , 230.29 , 231.68999], [252.98999, 252.19 , 251.38998, ..., 229.89 , 232.59 , 236.29 ], ..., [283.79 , 283.69 , 285.09 , ..., 285.29 , 285.09 , 284.69 ], [286.09 , 286.88998, 287.19 , ..., 285.69 , 285.69 , 285.19 ], [287.69 , 288.09 , 288.09 , ..., 286.49 , 286.19 , 285.69 ]]], dtype=float32)
- Conventions :
- COARDS
- title :
- 4x daily NMC reanalysis (1948)
- description :
- Data is from NMC initialized reanalysis (4x/day). These are the 0.9950 sigma level values.
- platform :
- Model
- references :
- http://www.esrl.noaa.gov/psd/data/gridded/data.ncep.reanalysis.html
from wax.accessors import register_wax_accessors
register_wax_accessors()
from wax.modules import EWMA
def my_custom_function(dataset):
return {
"air_10": EWMA(1.0 / 10.0)(dataset["air"]),
"air_100": EWMA(1.0 / 100.0)(dataset["air"]),
"ground_100": EWMA(1.0 / 100.0)(dataset["ground"]),
}
results, state = dataset.wax.stream(
local_time="time", ffills={"day": 1}, pbar=True
).apply(my_custom_function, format_dims=dataset.air.dims)
_ = results.isel(lat=0, lon=0).drop(["lat", "lon"]).to_pandas().plot(figsize=(12, 8))

# Uncomment to run the notebook in Colab
# ! pip install -q "wax-ml[complete]@git+https://github.com/eserie/wax-ml.git"
# ! pip install -q --upgrade jax jaxlib==0.1.70+cuda111 -f https://storage.googleapis.com/jax-releases/jax_releases.html
# check available devices
import jax
print("jax backend {}".format(jax.lib.xla_bridge.get_backend().platform))
jax.devices()
WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
jax backend cpu
[CpuDevice(id=0)]
🌡 Binning temperatures 🌡¶
Let’s again considering the air temperatures dataset. It is sampled at an hourly resolution. We will make “trailing” air temperature bins during each day and “reset” the bin aggregation process at each day change.
import numpy as onp
import xarray as xr
from wax.accessors import register_wax_accessors
from wax.modules import OHLC, HasChanged
register_wax_accessors()
dataset = xr.tutorial.open_dataset("air_temperature")
dataset["date"] = dataset.time.dt.date.astype(onp.datetime64)
dataset
<xarray.Dataset> Dimensions: (lat: 25, lon: 53, time: 2920) Coordinates: * lat (lat) float32 75.0 72.5 70.0 67.5 65.0 ... 25.0 22.5 20.0 17.5 15.0 * lon (lon) float32 200.0 202.5 205.0 207.5 ... 322.5 325.0 327.5 330.0 * time (time) datetime64[ns] 2013-01-01 ... 2014-12-31T18:00:00 Data variables: air (time, lat, lon) float32 ... date (time) datetime64[ns] 2013-01-01 2013-01-01 ... 2014-12-31 Attributes: Conventions: COARDS title: 4x daily NMC reanalysis (1948) description: Data is from NMC initialized reanalysis\n(4x/day). These a... platform: Model references: http://www.esrl.noaa.gov/psd/data/gridded/data.ncep.reanaly...
- lat: 25
- lon: 53
- time: 2920
- lat(lat)float3275.0 72.5 70.0 ... 20.0 17.5 15.0
- standard_name :
- latitude
- long_name :
- Latitude
- units :
- degrees_north
- axis :
- Y
array([75. , 72.5, 70. , 67.5, 65. , 62.5, 60. , 57.5, 55. , 52.5, 50. , 47.5, 45. , 42.5, 40. , 37.5, 35. , 32.5, 30. , 27.5, 25. , 22.5, 20. , 17.5, 15. ], dtype=float32)
- lon(lon)float32200.0 202.5 205.0 ... 327.5 330.0
- standard_name :
- longitude
- long_name :
- Longitude
- units :
- degrees_east
- axis :
- X
array([200. , 202.5, 205. , 207.5, 210. , 212.5, 215. , 217.5, 220. , 222.5, 225. , 227.5, 230. , 232.5, 235. , 237.5, 240. , 242.5, 245. , 247.5, 250. , 252.5, 255. , 257.5, 260. , 262.5, 265. , 267.5, 270. , 272.5, 275. , 277.5, 280. , 282.5, 285. , 287.5, 290. , 292.5, 295. , 297.5, 300. , 302.5, 305. , 307.5, 310. , 312.5, 315. , 317.5, 320. , 322.5, 325. , 327.5, 330. ], dtype=float32)
- time(time)datetime64[ns]2013-01-01 ... 2014-12-31T18:00:00
- standard_name :
- time
- long_name :
- Time
array(['2013-01-01T00:00:00.000000000', '2013-01-01T06:00:00.000000000', '2013-01-01T12:00:00.000000000', ..., '2014-12-31T06:00:00.000000000', '2014-12-31T12:00:00.000000000', '2014-12-31T18:00:00.000000000'], dtype='datetime64[ns]')
- air(time, lat, lon)float32...
- long_name :
- 4xDaily Air temperature at sigma level 995
- units :
- degK
- precision :
- 2
- GRIB_id :
- 11
- GRIB_name :
- TMP
- var_desc :
- Air temperature
- dataset :
- NMC Reanalysis
- level_desc :
- Surface
- statistic :
- Individual Obs
- parent_stat :
- Other
- actual_range :
- [185.16 322.1 ]
[3869000 values with dtype=float32]
- date(time)datetime64[ns]2013-01-01 ... 2014-12-31
array(['2013-01-01T00:00:00.000000000', '2013-01-01T00:00:00.000000000', '2013-01-01T00:00:00.000000000', ..., '2014-12-31T00:00:00.000000000', '2014-12-31T00:00:00.000000000', '2014-12-31T00:00:00.000000000'], dtype='datetime64[ns]')
- Conventions :
- COARDS
- title :
- 4x daily NMC reanalysis (1948)
- description :
- Data is from NMC initialized reanalysis (4x/day). These are the 0.9950 sigma level values.
- platform :
- Model
- references :
- http://www.esrl.noaa.gov/psd/data/gridded/data.ncep.reanalysis.html
def bin_temperature(da):
day_change = HasChanged()(da["date"])
return OHLC()(da["air"], reset_on=day_change)
output, state = dataset.wax.stream().apply(
bin_temperature, format_dims=onp.array(dataset.air.dims)
)
output = xr.Dataset(output._asdict())
df = output.isel(lat=0, lon=0).drop(["lat", "lon"]).to_pandas().loc["2013-01"]
_ = df.plot(figsize=(12, 8), title="Trailing Open-High-Low-Close temperatures")

The UpdateOnEvent
module¶
The OHLC
module uses the primitive wax.modules.UpdateOnEvent
.
Its implementation required to complete Haiku with a central function
set_params_or_state_dict
which we have actually integrated in this WAX-ML module.
We have opened an issue on the Haiku github to integrate it in Haiku.
# Uncomment to run the notebook in Colab
# ! pip install -q "wax-ml[complete]@git+https://github.com/eserie/wax-ml.git"
# ! pip install -q --upgrade jax jaxlib==0.1.70+cuda111 -f https://storage.googleapis.com/jax-releases/jax_releases.html
# check available devices
import jax
print("jax backend {}".format(jax.lib.xla_bridge.get_backend().platform))
jax.devices()
WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
jax backend cpu
[<jaxlib.xla_extension.Device at 0x17678ceb0>]
🎛 The 3-steps workflow 🎛¶
It is already very useful to be able to execute a JAX function on a dataframe in a single work step and with a single command line thanks to WAX-ML accessors.
The 1-step WAX-ML’s stream API works like that:
<data-container>.stream(...).apply(...)
But this is not optimal because, under the hood, there are mainly three costly steps:
(1) (synchronize | data tracing | encode): make the data “JAX ready”
(2) (compile | code tracing | execution): compile and optimize a function for XLA, execute it.
(3) (format): convert data back to pandas/xarray/numpy format.
With the wax.stream
primitives, it is quite easy to explicitly split the 1-step workflow
into a 3-step workflow.
This will allow the user to have full control over each step and iterate on each one.
It is actually very useful to iterate on step (2), the “calculation step” when
you are doing research.
You can then take full advantage of the JAX primitives, especially the jit
primitive.
Let’s illustrate how to reimplement WAX-ML EWMA yourself with the WAX-ML 3-step workflow.
Imports¶
import numpy as onp
import pandas as pd
import xarray as xr
from wax.accessors import register_wax_accessors
from wax.external.eagerpy import convert_to_tensors
from wax.format import format_dataframe
from wax.modules import EWMA
from wax.stream import tree_access_data
from wax.unroll import unroll
register_wax_accessors()
Performance on big dataframes¶
Generate data¶
T = 1.0e5
N = 1000
T, N = map(int, (T, N))
dataframe = pd.DataFrame(
onp.random.normal(size=(T, N)), index=pd.date_range("1970", periods=T, freq="s")
)
pandas EWMA¶
%%time
df_ewma_pandas = dataframe.ewm(alpha=1.0 / 10.0).mean()
CPU times: user 2.03 s, sys: 167 ms, total: 2.19 s
Wall time: 2.19 s
WAX-ML EWMA¶
%%time
df_ewma_wax = dataframe.wax.ewm(alpha=1.0 / 10.0).mean()
CPU times: user 1.8 s, sys: 876 ms, total: 2.68 s
Wall time: 2.67 s
It’s a little faster, but not that much faster…
WAX-ML EWMA (without format step)¶
Let’s disable the final formatting step (the output is now in raw JAX format):
%%time
df_ewma_wax_no_format = dataframe.wax.ewm(alpha=1.0 / 10.0, format_outputs=False).mean()
df_ewma_wax_no_format.block_until_ready()
CPU times: user 1.8 s, sys: 1 s, total: 2.8 s
Wall time: 3.55 s
DeviceArray([[-2.9661492e-01, -3.2103235e-01, -9.6144844e-03, ...,
4.4276214e-01, 1.1568004e+00, -1.0724162e+00],
[-9.5860586e-02, -2.0366390e-01, -3.7967765e-01, ...,
-6.1594141e-01, 7.7942121e-01, 1.8111229e-02],
[ 2.8211397e-01, 1.7749734e-01, -2.0034584e-01, ...,
-7.6095390e-01, 5.3778893e-01, 3.1442198e-01],
...,
[-1.1917123e-01, 1.3764068e-01, 2.4761766e-01, ...,
2.0842913e-01, 2.5283977e-01, -1.1205430e-01],
[-1.0947308e-01, 3.6484647e-01, 2.4164049e-01, ...,
2.7038181e-01, 2.4539444e-01, 6.2920153e-05],
[ 4.8219025e-02, 1.5648599e-01, 1.2161890e-01, ...,
2.0765728e-01, 8.9837506e-02, 1.0943251e-01]], dtype=float32)
type(df_ewma_wax_no_format)
jaxlib.xla_extension.DeviceArray
Let’s check the device on which the calculation was performed (if you have GPU available, this should be GpuDevice
otherwise it will be CpuDevice
):
df_ewma_wax_no_format.device()
<jaxlib.xla_extension.Device at 0x17678ceb0>
Now we will see how to break down WAX-ML one-liners <dataset>.ewm(...).mean()
or <dataset>.stream(...).apply(...)
into 3 steps:
a preparation step where we prepare JAX-ready data and functions.
a processing step where we execute the JAX program
a post-processing step where we format the results in pandas or xarray format.
Generate data (in dataset format)¶
WAX-ML Sream
object works on datasets.
So let’s transform the DataFrame
into a xarray Dataset
:
dataset = xr.DataArray(dataframe).to_dataset(name="dataarray")
del dataframe
Step (1) (synchronize | data tracing | encode)¶
In this step, WAX-ML do:
“data tracing” : prepare the indices for fast access tin the JAX function
access_data
synchronize streams if there is multiple ones. This functionality have options :
freq
,ffills
encode and convert data from numpy to JAX: use encoders for
datetimes64
andstring_
dtypes. Be aware that by default JAX works in float32 (see JAX’s Common Gotchas to work in float64).
We have a function Stream.prepare
that implement this Step (1).
It prepares a function that wraps the input function with the actual data and indices
in a pair of pure functions (TransformedWithState
Haiku tuple).
%%time
stream = dataset.wax.stream()
CPU times: user 244 µs, sys: 36 µs, total: 280 µs
Wall time: 287 µs
Define our custom function to be applied on a dict of arrays having the same structure than the original dataset:
def my_ewma_on_dataset(dataset):
return EWMA(alpha=1.0 / 10.0, adjust=True)(dataset["dataarray"])
transform_dataset, jxs = stream.prepare(dataset, my_ewma_on_dataset)
Let’s definite the init parameters and state of the transformation we will apply.
Step (2) (compile | code tracing | execution)¶
In this step we:
prepare a pure function (with Haiku’s transform mechanism) Define a “transformation” function which:
access to the data
apply another transformation, here: EWMA
compile it with
jax.jit
perform code tracing and execution (the last line):
Unroll the transformation on “steps”
xs
(anp.arange
vector).
outputs = unroll(transform_dataset)(jxs)
outputs.device()
<jaxlib.xla_extension.Device at 0x17678ceb0>
Once it has been compiled and “traced” by JAX, the function is much faster to execute:
%%timeit
outputs = unroll(transform_dataset)(jxs)
_ = outputs.block_until_ready()
619 ms ± 9.49 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
This is 3x faster than pandas implementation!
Manually prepare the data and manage the device¶
In order to manage the device on which the computations take place,
we need to have even more control over the execution flow.
Instead of calling stream.prepare
to build the transform_dataset
function,
we can do it ourselves by :
using the
stream.trace_dataset
functionconverting the numpy data in jax ourself
puting the data on the device we want.
np_data, np_index, xs = stream.trace_dataset(dataset)
jnp_data, jnp_index, jxs = convert_to_tensors((np_data, np_index, xs), "jax")
We explicitly set data on CPUs (the is not needed if you only have CPUs):
from jax.tree_util import tree_leaves, tree_map
cpus = jax.devices("cpu")
jnp_data, jnp_index, jxs = tree_map(
lambda x: jax.device_put(x, cpus[0]), (jnp_data, jnp_index, jxs)
)
print("data copied to CPU device.")
data copied to CPU device.
We have now “JAX-ready” data for later fast access.
Let’s define the transformation that wrap the actual data and indices in a pair of pure functions:
@jax.jit
@unroll
def transform_dataset(step):
dataset = tree_access_data(jnp_data, jnp_index, step)
return EWMA(alpha=1.0 / 10.0, adjust=True)(dataset["dataarray"])
And we can call it as before:
%%time
outputs = transform_dataset(jxs)
_ = outputs.block_until_ready()
CPU times: user 1.72 s, sys: 1.06 s, total: 2.78 s
Wall time: 3.31 s
%%time
outputs = transform_dataset(jxs)
_ = outputs.block_until_ready()
CPU times: user 546 ms, sys: 52.8 ms, total: 599 ms
Wall time: 567 ms
outputs.device()
<jaxlib.xla_extension.Device at 0x17678ceb0>
Step(3) (format)¶
Let’s come back to pandas/xarray:
%%time
y = format_dataframe(
dataset.coords, onp.array(outputs), format_dims=dataset.dataarray.dims
)
CPU times: user 27.9 ms, sys: 70.4 ms, total: 98.2 ms
Wall time: 144 ms
It’s quite slow (see WEP3 enhancement proposal).
GPU execution¶
Let’s look with execution on GPU
try:
gpus = jax.devices("gpu")
jnp_data, jnp_index, jxs = tree_map(
lambda x: jax.device_put(x, gpus[0]), (jnp_data, jnp_index, jxs)
)
print("data copied to GPU device.")
GPU_AVAILABLE = True
except RuntimeError as err:
print(err)
GPU_AVAILABLE = False
Requested backend gpu, but it failed to initialize: Not found: Could not find registered platform with name: "cuda". Available platform names are: Interpreter Host
Let’s check that our data is on the GPUs:
tree_leaves(jnp_data)[0].device()
<jaxlib.xla_extension.Device at 0x17678ceb0>
tree_leaves(jnp_index)[0].device()
<jaxlib.xla_extension.Device at 0x17678ceb0>
jxs.device()
<jaxlib.xla_extension.Device at 0x17678ceb0>
%%time
if GPU_AVAILABLE:
outputs = unroll(transform_dataset)(jxs)
CPU times: user 3 µs, sys: 2 µs, total: 5 µs
Wall time: 11.2 µs
Let’s redefine our function transform_dataset
by explicitly specify to jax.jit
the device
option.
%%time
from functools import partial
if GPU_AVAILABLE:
@partial(jax.jit, device=gpus[0])
@unroll
def transform_dataset(step):
dataset = tree_access_data(jnp_data, jnp_index, step)
return EWMA(alpha=1.0 / 10.0, adjust=True)(dataset["dataarray"])
outputs = transform_dataset(jxs)
CPU times: user 8 µs, sys: 2 µs, total: 10 µs
Wall time: 15 µs
outputs.device()
<jaxlib.xla_extension.Device at 0x17678ceb0>
%%timeit
if GPU_AVAILABLE:
outputs = unroll(transform_dataset)(jxs)
_ = outputs.block_until_ready()
12 ns ± 0.0839 ns per loop (mean ± std. dev. of 7 runs, 100000000 loops each)
# Uncomment to run the notebook in Colab
# ! pip install -q "wax-ml[complete]@git+https://github.com/eserie/wax-ml.git"
# ! pip install -q --upgrade jax jaxlib==0.1.70+cuda111 -f https://storage.googleapis.com/jax-releases/jax_releases.html
%matplotlib inline
import io
import warnings
from collections import defaultdict
from pathlib import Path
from typing import Any, Callable, NamedTuple, Optional, TypeVar
import haiku as hk
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
import numpy as onp
import optax
import pandas as pd
import plotnine as gg
import requests
from sklearn.preprocessing import MinMaxScaler
from tqdm.auto import tqdm
from wax.accessors import register_wax_accessors
from wax.compile import jit_init_apply
from wax.encode import Encoder
from wax.modules import Buffer, FillNanInf, Lag, RollingMean
from wax.unroll import unroll
print("jax backend {}".format(jax.lib.xla_bridge.get_backend().platform))
jax.devices()
WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
jax backend cpu
[<jaxlib.xla_extension.Device at 0x14ade9f70>]
🔭 Reconstructing the light curve of stars with LSTM 🔭¶
Let’s take a walk through the stars…
This notebook is based on the study done in this post by Christophe Pere and the notebook available on the authors’s github.
We will repeat this study on starlight using the LSTM architecture to predict the observed light flux through time.
Our LSTM implementation is based on this notebook from Haiku’s github repository.
We’ll see how to use WAX-ML to ease the preparation of time series data stored in dataframes and having Nans before calling a “standard” deep-learning workflow.
Disclaimer¶
Despite the fact that this code works with real data, the results presented here should not be considered as scientific knowledge insights, to the knowledge of the authors of WAX-ML, neither the results nor the data source have been reviewed by an astrophysics pair.
The purpose of this notebook is only to demonstrate how WAX-ML can be used when applying a “standard” machine learning workflow, here LSTM, to analyze time series.
Download the data¶
register_wax_accessors()
# Parameters
STAR = "007609553"
SEQ_LEN = 64
BATCH_SIZE = 8
TRAIN_SIZE = 2 ** 16
NUM_EPOCHS = 10
NUM_STARS = None
RECORD_FREQ = 100
TOTAL_LEN = None
TRAIN_DATE = "2016"
CACHE_DIR = Path("./cached_data/")
%%time
filename = CACHE_DIR / "kep_lightcurves.parquet"
try:
raw_dataframe = pd.read_parquet(open(filename, "rb"))
print(f"data read from {filename}")
except FileNotFoundError:
# Downloading the csv file from Chrustioge Pere GitHub account
download = requests.get(
"https://raw.github.com/Christophe-pere/Time_series_RNN/master/kep_lightcurves.csv"
).content
raw_dataframe = pd.read_csv(io.StringIO(download.decode("utf-8")))
# set date index
raw_dataframe.index = pd.Index(
pd.date_range("2009-03-07", periods=len(raw_dataframe.index), freq="h"),
name="time",
)
# save dataframe locally in CACHE_DIR
CACHE_DIR.mkdir(exist_ok=True)
raw_dataframe.to_parquet(filename)
print(f"data saved in {filename}")
data read from cached_data/kep_lightcurves.parquet
CPU times: user 55.3 ms, sys: 32.2 ms, total: 87.5 ms
Wall time: 40.8 ms
# shortening of data to speed up the execution of the notebook in the CI
if TOTAL_LEN:
raw_dataframe = raw_dataframe.iloc[:TOTAL_LEN]
Let’s visualize the description of this dataset:
raw_dataframe.describe().T.to_xarray()
<xarray.Dataset> Dimensions: (index: 52) Coordinates: * index (index) object '001430305_orig' ... '011611275_res' Data variables: count (index) float64 6.48e+04 5.674e+04 ... 5.673e+04 5.673e+04 mean (index) float64 6.776e+04 -0.2265 0.01231 ... 0.001437 0.004351 std (index) float64 1.363e+03 15.42 15.27 12.45 ... 4.648 6.415 4.904 min (index) float64 6.529e+04 -123.3 -75.59 ... -20.32 -31.97 -20.89 25% (index) float64 6.619e+04 -9.488 -9.875 ... -3.269 -4.281 -3.279 50% (index) float64 6.806e+04 -0.3476 0.007812 ... 0.007812 -0.06529 75% (index) float64 6.882e+04 8.988 10.02 8.092 ... 2.872 4.277 3.213 max (index) float64 7.021e+04 128.7 72.31 69.34 ... 26.53 30.94 29.45
- index: 52
- index(index)object'001430305_orig' ... '011611275_...
array(['001430305_orig', '001430305_rscl', '001430305_diff', '001430305_res', '001724719_orig', '001724719_rscl', '001724719_diff', '001724719_res', '005209845_orig', '005209845_rscl', '005209845_diff', '005209845_res', '007596240_orig', '007596240_rscl', '007596240_diff', '007596240_res', '007609553_orig', '007609553_rscl', '007609553_diff', '007609553_res', '008241079_orig', '008241079_rscl', '008241079_diff', '008241079_res', '008247770_orig', '008247770_rscl', '008247770_diff', '008247770_res', '009345933_orig', '009345933_rscl', '009345933_diff', '009345933_res', '009347009_orig', '009347009_rscl', '009347009_diff', '009347009_res', '009349482_orig', '009349482_rscl', '009349482_diff', '009349482_res', '009349757_orig', '009349757_rscl', '009349757_diff', '009349757_res', '010024701_orig', '010024701_rscl', '010024701_diff', '010024701_res', '011611275_orig', '011611275_rscl', '011611275_diff', '011611275_res'], dtype=object)
- count(index)float646.48e+04 5.674e+04 ... 5.673e+04
array([64795., 56736., 50906., 50906., 64795., 59620., 55595., 55595., 54969., 49817., 45968., 45968., 64792., 60090., 56475., 56475., 64793., 50829., 45782., 45782., 64794., 54544., 50508., 50508., 64793., 58882., 54562., 54562., 64797., 60963., 58136., 58136., 64793., 60579., 57363., 57363., 64793., 61021., 58217., 58217., 64793., 60789., 57793., 57793., 51732., 47598., 44189., 44189., 64792., 60245., 56728., 56728.])
- mean(index)float646.776e+04 -0.2265 ... 0.004351
array([ 6.77637640e+04, -2.26476394e-01, 1.23063833e-02, 4.76521496e-02, 3.48858739e+04, -2.60434079e-01, -6.42200288e-04, 8.80227557e-03, 5.39107577e+03, 7.03697184e-03, -1.60903054e-02, -2.92293746e-02, 1.63334514e+04, -1.51960427e-01, -8.43860668e-03, -1.19774483e-02, 2.13926602e+04, -4.33682072e-01, 8.27861864e-03, 9.57967429e-03, 5.67090880e+04, -3.64663429e-01, -2.67114244e-03, 6.54392308e-03, 2.91100849e+04, -8.32231802e-02, -3.71574138e-03, -2.05330738e-02, 7.19740760e+03, -3.04019017e-02, -7.65547327e-04, 1.39132462e-03, 7.52192300e+03, -1.21187104e-01, 9.12380197e-03, 1.61813028e-02, 8.94735741e+03, -9.96904951e-02, 4.80096959e-03, 1.14104920e-02, 1.17224306e+04, -1.00470553e-01, 9.67554394e-03, 2.29960238e-02, 6.68149138e+04, -1.87700299e-01, -1.28047280e-02, -6.32266111e-03, 1.29621460e+04, -1.57874003e-01, 1.43721396e-03, 4.35098594e-03])
- std(index)float641.363e+03 15.42 ... 6.415 4.904
array([1362.5144182 , 15.42150384, 15.27052111, 12.44643235, 1528.19752763, 10.91976669, 11.1893542 , 8.62113287, 204.78780872, 20.19896883, 5.84919971, 4.6135537 , 535.20007192, 4.91307619, 6.8001509 , 5.21149265, 660.12491961, 153.47838423, 7.95784752, 6.83305069, 999.65408025, 15.75699475, 11.27259006, 8.82762192, 893.28598443, 11.31868477, 9.1138139 , 7.4161203 , 432.80513443, 21.70344223, 5.70210104, 4.54982055, 179.82292187, 19.17226962, 5.43930877, 4.37342724, 177.04591928, 4.67562213, 6.00262878, 4.61455226, 272.81579415, 11.3153307 , 6.40128184, 5.00633304, 2252.26388443, 89.94131559, 11.75971874, 10.21930995, 325.0546655 , 4.64757018, 6.4154822 , 4.9035335 ])
- min(index)float646.529e+04 -123.3 ... -31.97 -20.89
array([ 6.52850664e+04, -1.23261443e+02, -7.55937500e+01, -5.81310419e+01, 3.19670938e+04, -6.66089800e+01, -5.55937500e+01, -3.80387838e+01, 4.95582812e+03, -9.06898032e+01, -3.17832031e+01, -2.36347987e+01, 1.54332500e+04, -3.66425336e+01, -3.57548828e+01, -3.61389287e+01, 2.00126504e+04, -6.94903224e+02, -3.80410156e+01, -3.14954085e+01, 5.48705469e+04, -8.74110917e+01, -5.62773438e+01, -5.71788510e+01, 2.72978945e+04, -6.94690785e+01, -4.73476562e+01, -3.98079267e+01, 6.30219580e+03, -1.03889194e+02, -3.24965820e+01, -2.49301202e+01, 7.22198877e+03, -8.07118370e+01, -2.76318359e+01, -1.87084745e+01, 8.58008008e+03, -2.36194950e+01, -2.88779297e+01, -2.23368082e+01, 1.11858164e+04, -6.24034405e+01, -3.32412109e+01, -2.65777562e+01, 6.34323672e+04, -3.55021741e+02, -5.77968750e+01, -7.86537031e+01, 1.23135605e+04, -2.03195390e+01, -3.19687500e+01, -2.08865991e+01])
- 25%(index)float646.619e+04 -9.488 ... -4.281 -3.279
array([ 6.61889219e+04, -9.48833642e+00, -9.87500000e+00, -8.08656653e+00, 3.32582266e+04, -7.31995398e+00, -7.44335938e+00, -5.79592451e+00, 5.18955225e+03, -9.56397223e+00, -3.89270020e+00, -3.11812362e+00, 1.58437983e+04, -3.44825902e+00, -4.50830078e+00, -3.52359963e+00, 2.07200391e+04, -7.54642556e+01, -5.27343750e+00, -4.56884892e+00, 5.59803770e+04, -1.02040155e+01, -7.57812500e+00, -5.93617727e+00, 2.81871973e+04, -7.12718836e+00, -6.01953125e+00, -4.90452721e+00, 6.65427344e+03, -1.23264574e+01, -3.74560547e+00, -3.03522067e+00, 7.42203711e+03, -1.13228045e+01, -3.61962891e+00, -2.91987469e+00, 8.78827148e+03, -3.17652362e+00, -3.98632812e+00, -3.08491497e+00, 1.15134297e+04, -6.10318459e+00, -4.21875000e+00, -3.29996354e+00, 6.49962314e+04, -5.62104609e+01, -7.89062500e+00, -6.86189833e+00, 1.28410632e+04, -3.26949598e+00, -4.28125000e+00, -3.27908045e+00])
- 50%(index)float646.806e+04 -0.3476 ... -0.06529
array([ 6.80613828e+04, -3.47636218e-01, 7.81250000e-03, -1.30725999e-02, 3.53564023e+04, -9.16894810e-02, -3.51562500e-02, -2.89893356e-02, 5.41613721e+03, 1.46219949e-01, -5.78613281e-02, -9.43597722e-02, 1.65516709e+04, -2.64416689e-01, 2.73437500e-02, -9.33119760e-02, 2.14401523e+04, 1.14545608e+01, 2.14843750e-02, -5.92623226e-02, 5.67837734e+04, 8.92652675e-02, 1.95312500e-03, -7.25704027e-02, 2.95458281e+04, 4.87577712e-02, -4.49218750e-02, -1.00367284e-01, 7.25929248e+03, -2.94537715e-01, 3.34472656e-02, -2.26939366e-02, 7.45421924e+03, -3.65075272e-01, -1.66015625e-02, -4.78969335e-02, 8.95055371e+03, -1.40979344e-01, 7.81250000e-03, -1.58822255e-02, 1.16602783e+04, 2.74758582e-01, 2.73437500e-02, -4.15869687e-02, 6.58943555e+04, 3.32494254e+00, 4.68750000e-02, -5.63888776e-02, 1.29597114e+04, -2.24671520e-01, 7.81250000e-03, -6.52851501e-02])
- 75%(index)float646.882e+04 8.988 ... 4.277 3.213
array([6.88189727e+04, 8.98839136e+00, 1.00234375e+01, 8.09175077e+00, 3.61145820e+04, 7.01429305e+00, 7.48437500e+00, 5.69826563e+00, 5.53317725e+03, 1.00475793e+01, 3.87048340e+00, 2.98923215e+00, 1.68310957e+04, 3.02806213e+00, 4.49414062e+00, 3.38630625e+00, 2.20209531e+04, 9.08799796e+01, 5.26367188e+00, 4.52717869e+00, 5.74483965e+04, 1.01837346e+01, 7.51171875e+00, 5.81439607e+00, 2.99188320e+04, 7.15369426e+00, 6.00732422e+00, 4.81928542e+00, 7.52007275e+03, 1.29233935e+01, 3.76867676e+00, 2.95097962e+00, 7.65798096e+03, 1.08133591e+01, 3.61523438e+00, 2.87012749e+00, 9.08583008e+03, 2.94619654e+00, 3.98046875e+00, 3.03713233e+00, 1.19144434e+04, 6.53949167e+00, 4.25976562e+00, 3.31653085e+00, 6.81768262e+04, 5.91546280e+01, 7.77343750e+00, 6.74722831e+00, 1.31409404e+04, 2.87210558e+00, 4.27661133e+00, 3.21300974e+00])
- max(index)float647.021e+04 128.7 ... 30.94 29.45
array([7.02122031e+04, 1.28660432e+02, 7.23125000e+01, 6.93417611e+01, 3.75573242e+04, 4.79886762e+01, 8.00234375e+01, 6.16484304e+01, 6.00489502e+03, 9.13358716e+01, 2.45039062e+01, 2.61494406e+01, 1.73261172e+04, 2.35996818e+01, 3.72900391e+01, 2.62028292e+01, 2.24849180e+04, 3.82495214e+02, 3.95625000e+01, 3.61686737e+01, 5.90490898e+04, 7.62124556e+01, 5.41562500e+01, 5.42089099e+01, 3.03895098e+04, 6.19293590e+01, 4.28691406e+01, 4.45521063e+01, 7.83280371e+03, 7.73476225e+01, 2.73188477e+01, 2.46408888e+01, 8.00235156e+03, 6.99670019e+01, 3.19765625e+01, 2.73471557e+01, 9.59054297e+03, 2.72015177e+01, 2.51455078e+01, 2.51618206e+01, 1.22767207e+04, 4.38635401e+01, 3.68515625e+01, 3.04471056e+01, 7.25150000e+04, 3.21070047e+02, 6.51328125e+01, 5.87010158e+01, 1.39579814e+04, 2.65332384e+01, 3.09394531e+01, 2.94483002e+01])
stars = raw_dataframe.columns
stars = sorted(list(set([i.split("_")[0] for i in stars])))
print(f"The number of stars available is: {len(stars)}")
print(f"star identifiers: {stars}")
The number of stars available is: 13
star identifiers: ['001430305', '001724719', '005209845', '007596240', '007609553', '008241079', '008247770', '009345933', '009347009', '009349482', '009349757', '010024701', '011611275']
dataframe = raw_dataframe[[i + "_rscl" for i in stars]].rename(
columns=lambda c: c.replace("_rscl", "")
)
dataframe.columns.names = ["star"]
dataframe.shape
(71427, 13)
if NUM_STARS:
columns = dataframe.columns.tolist()
columns.remove(STAR)
dataframe = dataframe[[STAR] + columns[: NUM_STARS - 1]]
Rolling mean¶
We will smooth the data by applying a rolling mean with a window of 100 periods.
Count nan values¶
But before since the dataset has some nan values, we will extract few statistics about the density of nan values in windows of size 100.
It will be the occasion to show a usage of the wax.modules.Buffer
module with the format_outputs=False
option for the dataframe accessor .wax.stream
.
Let’s apply the Buffer
module to the data:
buffer, _ = dataframe.wax.stream(format_outputs=False).apply(lambda x: Buffer(100)(x))
assert isinstance(buffer, jnp.ndarray)
Equivalently, we can use wax unroll
function.
buffer = unroll(lambda x: Buffer(100)(x))(jax.device_put(dataframe.values))
Let’s describe the statistic of nans with pandas:
count_nan = jnp.isnan(buffer).sum(axis=1)
pd.DataFrame(onp.array(count_nan)).stack().describe().astype(int)
count 928551
mean 20
std 27
min 0
25% 5
50% 8
75% 19
max 100
dtype: int64
Computing the rolling mean¶
We will choose a min_periods
of 5 in order to keep at leas 75% of the points.
%%time
dataframe_mean, _ = dataframe.wax.stream().apply(
lambda x: RollingMean(100, min_periods=5)(x)
)
CPU times: user 262 ms, sys: 9.48 ms, total: 272 ms
Wall time: 270 ms
dataframe.iloc[:, :2].plot()
<AxesSubplot:xlabel='time'>

Forecasting with Machine Learning¶
We need two forecast in this data, if you look with attention you’ll see micro holes and big holes.
T = TypeVar("T")
class Pair(NamedTuple):
x: T
y: T
class TrainSplit(NamedTuple):
train: T
validation: T
gg.theme_set(gg.theme_bw())
warnings.filterwarnings("ignore")
plt.rcParams["figure.figsize"] = 18, 8
fig, (ax, lax) = plt.subplots(ncols=2, gridspec_kw={"width_ratios": [4, 1]})
dataframe.plot(ax=ax, title="raw data")
ax.legend(bbox_to_anchor=(0, 0, 1, 1), bbox_transform=lax.transAxes)
lax.axis("off")
(0.0, 1.0, 0.0, 1.0)

plt.rcParams["figure.figsize"] = 18, 8
fig, (ax, lax) = plt.subplots(ncols=2, gridspec_kw={"width_ratios": [4, 1]})
dataframe_mean.plot(ax=ax, title="Smoothed data")
ax.legend(bbox_to_anchor=(0, 0, 1, 1), bbox_transform=lax.transAxes)
lax.axis("off")
# -
(0.0, 1.0, 0.0, 1.0)

Normalize data¶
dataframe_mean.stack().hist(bins=100, log=True)
<AxesSubplot:>

def min_max_scaler(values: pd.DataFrame, output_format: str = "dataframe") -> Encoder:
scaler = MinMaxScaler(feature_range=(0, 1))
scaler.fit(values)
index = values.index
columns = values.columns
def encode(dataframe: pd.DataFrame):
nonlocal index
nonlocal columns
index = dataframe.index
columns = dataframe.columns
array_normed = scaler.transform(dataframe)
if output_format == "dataframe":
return pd.DataFrame(array_normed, index, columns)
elif output_format == "jax":
return jnp.array(array_normed)
else:
return array_normed
def decode(array_scaled):
value = scaler.inverse_transform(array_scaled)
if output_format == "dataframe":
return pd.DataFrame(value, index, columns)
else:
return value
return Encoder(encode, decode)
scaler = min_max_scaler(dataframe_mean)
dataframe_normed = scaler.encode(dataframe_mean)
assert (scaler.decode(dataframe_normed) - dataframe_mean).stack().abs().max() < 1.0e-4
dataframe_normed.stack().hist(bins=100)
<AxesSubplot:>

Prepare train / validation datasets¶
def split_feature_target(
dataframe,
look_back=SEQ_LEN,
shuffle=True,
stack=True,
min_periods_ratio: float = 0.8,
rng=None,
) -> Pair:
def prepare_xy(data):
buffer = Buffer(look_back + 1)(data)
x = buffer[:-1]
y = buffer[-1]
return x, y
def prepare_xy(data):
y = Buffer(look_back)(data)
x = Lag(1)(y)
return x, y
x, y = unroll(prepare_xy)(jax.device_put(dataframe.values))
if shuffle:
if rng is None:
rng = jax.random.PRNGKey(42)
B = x.shape[0]
idx = jnp.arange(B)
idx = jax.random.shuffle(rng, idx)
x = x[idx]
y = y[idx]
if stack:
B, T, F = x.shape
x = x.transpose(1, 0, 2).reshape(T, B * F, 1).transpose(1, 0, 2)
y = y.transpose(1, 0, 2).reshape(T, B * F, 1).transpose(1, 0, 2)
if min_periods_ratio:
T = x.shape[1]
count_nan = jnp.isnan(x).sum(axis=1)
mask = count_nan < min_periods_ratio * T
idx = jnp.where(mask)
x = x[idx[0]]
y = y[idx[0]]
# round Batch size to a power of to
B = x.shape[0]
B_round = int(2 ** jnp.floor(jnp.log2(B)))
print(f"{B} batches rounded to {B_round} batches.")
x = x[:B_round]
y = y[:B_round]
# fillnan by zeros
x, y = hk.testing.transform_and_run(lambda x: FillNanInf()(x))((x, y))
return Pair(x, y)
# split_feature_target(dataframe)
def split_train_validation(
dataframe, train_size, look_back, scaler: Optional[Callable] = None
) -> TrainSplit:
# prepare scaler
train_df = dataframe.iloc[:train_size]
if scaler:
scaler = scaler(train_df)
# prepare train data
if scaler:
train_df = scaler.encode(train_df)
train_xy = split_feature_target(train_df, look_back)
# prepare validation data
valid_size = len(dataframe) - train_size
valid_size = int(2 ** jnp.floor(jnp.log2(valid_size)))
valid_end = int(train_size + valid_size)
valid_df = dataframe.iloc[train_size:valid_end]
if scaler:
valid_df = scaler.encode(valid_df)
valid_xy = split_feature_target(valid_df, look_back)
return TrainSplit(train_xy, valid_xy)
TRAIN_SIZE
65536
print(f"Look at star: {STAR}")
train, valid = split_train_validation(dataframe_normed[[STAR]], TRAIN_SIZE, SEQ_LEN)
Look at star: 007609553
63871 batches rounded to 32768 batches.
3597 batches rounded to 2048 batches.
train[0].shape, train[1].shape, valid[0].shape, valid[1].shape
((32768, 64, 1), (32768, 64, 1), (2048, 64, 1), (2048, 64, 1))
# TRAIN_SIZE, VALID_SIZE = len(train.x), len(valid.x)
print(
f"effective train_size = {len(train.x)}, " f"effective valid size= {len(valid.x)}"
)
effective train_size = 32768, effective valid size= 2048
# Plot an observation/target pair.
rng = jax.random.PRNGKey(42)
batch_plot = jax.random.choice(rng, len(train[0]))
df = pd.DataFrame(
{"x": train.x[batch_plot, :, 0], "y": train.y[batch_plot, :, 0]}
).reset_index()
df = pd.melt(df, id_vars=["index"], value_vars=["x", "y"])
plot = (
gg.ggplot(df)
+ gg.aes(x="index", y="value", color="variable")
+ gg.geom_line()
+ gg.scales.scale_y_log10()
)
_ = plot.draw()

Dataset iterator¶
class Dataset:
"""An iterator over a numpy array, revealing batch_size elements at a time."""
def __init__(self, xy: Pair, batch_size: int):
self._x, self._y = xy
self._batch_size = batch_size
self._length = self._x.shape[0]
self._idx = 0
if self._length % batch_size != 0:
msg = "dataset size {} must be divisible by batch_size {}."
raise ValueError(msg.format(self._length, batch_size))
def __next__(self) -> Pair:
start = self._idx
end = start + self._batch_size
x, y = self._x[start:end], self._y[start:end]
if end >= self._length:
print(f"End of the data set (size={end}). Return to the beginning.")
end = end % self._length
assert end == 0 # Guaranteed by ctor assertion.
self._idx = end
return Pair(x, y)
Training an LSTM¶
To train the LSTM, we define a Haiku function which unrolls the LSTM over the input sequence, generating predictions for all output values. The LSTM always starts with its initial state at the start of the sequence.
The Haiku function is then transformed into a pure function through hk.transform
, and is trained with Adam on an L2 prediction loss.
def unroll_net(seqs: jnp.ndarray):
"""Unrolls an LSTM over seqs, mapping each output to a scalar."""
# seqs is [T, B, F].
core = hk.LSTM(32)
batch_size = seqs.shape[0]
outs, state = hk.dynamic_unroll(
core, seqs, core.initial_state(batch_size), time_major=False
)
# We could include this Linear as part of the recurrent core!
# However, it's more efficient on modern accelerators to run the linear once
# over the entire sequence than once per sequence element.
return hk.BatchApply(hk.Linear(1))(outs), state
model = jit_init_apply(hk.transform(unroll_net))
@jax.jit
def loss(pred, y):
return jnp.mean(jnp.square(pred - y))
def model_with_loss(x, y):
pred, _ = unroll_net(x)
return loss(pred, y)
class TrainState(NamedTuple):
step: int
params: Any
opt_state: Any
rng: jnp.ndarray
loss: float
def train_model(
model_with_loss: Callable,
train_ds: Dataset,
valid_ds: Dataset,
max_iterations: int = -1,
rng=None,
record_freq=100,
) -> hk.Params:
"""Initializes and trains a model on train_ds, returning the final params."""
opt = optax.adam(1e-3)
model_with_loss = jit_init_apply(hk.transform(model_with_loss))
@jax.jit
def update(train_state, x, y):
step, params, opt_state, rng, _ = train_state
if rng is not None:
(rng,) = jax.random.split(rng, 1)
l, grads = jax.value_and_grad(model_with_loss.apply)(params, rng, x, y)
grads, opt_state = opt.update(grads, opt_state)
params = optax.apply_updates(params, grads)
return TrainState(step + 1, params, opt_state, rng, l)
# Initialize state.
def init():
x, y = next(train_ds)
params = model_with_loss.init(rng, x, y)
opt_state = opt.init(params)
return TrainState(0, params, opt_state, rng, jnp.inf)
def _format_results(records):
records = {key: jnp.stack(l) for key, l in records.items()}
return records
records = defaultdict(list)
train_state = init()
with tqdm(total=max_iterations if max_iterations > 0 else None) as pbar:
while True:
try:
x, y = next(train_ds)
except StopIteration:
return train_state, _format_results(records)
train_state = update(train_state, x, y)
if train_state.step % record_freq == 0:
x, y = next(valid_ds)
if rng is not None:
(rng,) = jax.random.split(rng, 1)
valid_loss = model_with_loss.apply(train_state.params, rng, x, y)
records["step"].append(train_state.step)
records["valid_loss"].append(valid_loss)
records["train_loss"].append(train_state.loss)
pbar.update()
if max_iterations > 0 and train_state.step >= max_iterations:
return train_state, _format_results(records)
%%time
train, valid = split_train_validation(dataframe_normed[[STAR]], TRAIN_SIZE, SEQ_LEN)
train_ds = Dataset(train, BATCH_SIZE)
valid_ds = Dataset(valid, BATCH_SIZE)
train_state, records = train_model(
model_with_loss,
train_ds,
valid_ds,
len(train.x) // BATCH_SIZE * NUM_EPOCHS,
rng=jax.random.PRNGKey(42),
record_freq=RECORD_FREQ,
)
63871 batches rounded to 32768 batches.
3597 batches rounded to 2048 batches.
End of the data set (size=32768). Return to the beginning.
End of the data set (size=32768). Return to the beginning.
End of the data set (size=32768). Return to the beginning.
End of the data set (size=32768). Return to the beginning.
End of the data set (size=32768). Return to the beginning.
End of the data set (size=32768). Return to the beginning.
End of the data set (size=2048). Return to the beginning.
End of the data set (size=32768). Return to the beginning.
End of the data set (size=32768). Return to the beginning.
End of the data set (size=32768). Return to the beginning.
End of the data set (size=32768). Return to the beginning.
CPU times: user 58.6 s, sys: 371 ms, total: 59 s
Wall time: 58.8 s
# train_state.params
# Plot losses
losses = pd.DataFrame(records)
df = pd.melt(losses, id_vars=["step"], value_vars=["train_loss", "valid_loss"])
plot = (
gg.ggplot(df)
+ gg.aes(x="step", y="value", color="variable")
+ gg.geom_line()
+ gg.scales.scale_y_log10()
)
_ = plot.draw()

Sampling¶
The point of training models is so that they can make predictions! How can we generate predictions with the trained model?
If we’re allowed to feed in the ground truth, we can just run the original model’s apply
function.
def plot_samples(truth: np.ndarray, prediction: np.ndarray) -> gg.ggplot:
assert truth.shape == prediction.shape
df = pd.DataFrame(
{"truth": truth.squeeze(), "predicted": prediction.squeeze()}
).reset_index()
df = pd.melt(df, id_vars=["index"], value_vars=["truth", "predicted"])
plot = (
gg.ggplot(df) + gg.aes(x="index", y="value", color="variable") + gg.geom_line()
)
return plot
# Grab a sample from the validation set.
sample_x, sample_y = next(valid_ds)
sample_x = sample_x[:1] # Shrink to batch-size 1.
sample_y = sample_y[:1]
# Generate a prediction, feeding in ground truth at each point as input.
predicted, _ = model.apply(train_state.params, None, sample_x)
plot = plot_samples(sample_y, predicted)
plot.draw()
del sample_x, predicted

Run autoregressively¶
If we can’t feed in the ground truth (because we don’t have it), we can also run the model autoregressively.
def autoregressive_predict(
trained_params: hk.Params,
context: jnp.ndarray,
seq_len: int,
pbar=False,
):
"""Given a context, autoregressively generate the rest of a sine wave."""
ar_outs = []
context = jax.device_put(context)
times = onp.arange(seq_len - context.shape[1] + 1)
if pbar:
times = tqdm(times)
for _ in times:
full_context = jnp.concatenate([context] + ar_outs, axis=1)
outs, _ = model.apply(trained_params, None, full_context)
# Append the newest prediction to ar_outs.
ar_outs.append(outs[:, -1:, :])
# Return the final full prediction.
return outs
sample_x, sample_y = next(valid_ds)
sample_x = sample_x[:1] # Shrink to batch-size 1.
sample_y = sample_y[:1] # Shrink to batch-size 1.
context_length = SEQ_LEN // 8
print(f"context_length = {context_length}")
# Cut the batch-size 1 context from the start of the sequence.
context = sample_x[:, :context_length]
context_length = 8
%%time
# We can reuse params we got from training for inference - as long as the
# declaration order is the same.
predicted = autoregressive_predict(train_state.params, context, SEQ_LEN, pbar=True)
CPU times: user 7.5 s, sys: 123 ms, total: 7.63 s
Wall time: 7.56 s
sample_y.shape, predicted.shape
((1, 64, 1), (1, 64, 1))
plot = plot_samples(sample_y, predicted)
plot += gg.geom_vline(xintercept=context.shape[1], linetype="dashed")
_ = plot.draw()

Sharing parameters with a different function.¶
Unfortunately, this is a bit slow - we’re doing O(N^2) computation for a sequence of length N.
It’d be better if we could do the autoregressive sampling all at once - but we need to write a new Haiku function for that.
We’re in luck - if the Haiku module names match, the same parameters can be used for multiple Haiku functions.
This can be achieved through a combination of two techniques:
If we manually give a unique name to a module, we can ensure that the parameters are directed to the right places.
If modules are instantiated in the same order, they’ll have the same names in different functions.
Here, we rely on method #2 to create a fast autoregressive prediction.
@hk.transform
def fast_autoregressive_predict_fn(context, seq_len):
"""Given a context, autoregressively generate the rest of a sine wave."""
core = hk.LSTM(32)
dense = hk.Linear(1)
state = core.initial_state(context.shape[0])
# Unroll over the context using `hk.dynamic_unroll`.
# As before, we `hk.BatchApply` the Linear for efficiency.
context_outs, state = hk.dynamic_unroll(
core,
context,
state,
time_major=False,
)
context_outs = hk.BatchApply(dense)(context_outs)
# Now, unroll one step at a time using the running recurrent state.
ar_outs = []
x = context_outs[:, -1, :]
times = range(seq_len - context.shape[1])
for _ in times:
x, state = core(x, state)
x = dense(x)
ar_outs.append(x)
ar_outs = jnp.stack(ar_outs)
ar_outs = ar_outs.transpose(1, 0, 2)
return jnp.concatenate([context_outs, ar_outs], axis=1)
fast_autoregressive_predict = jax.jit(
fast_autoregressive_predict_fn.apply, static_argnums=(3,)
)
%%time
# Reuse the same context from the previous cell.
predicted = fast_autoregressive_predict(train_state.params, None, context, SEQ_LEN)
CPU times: user 24.6 s, sys: 54.8 ms, total: 24.6 s
Wall time: 24.6 s
# The plots should be equivalent!
plot = plot_samples(sample_y, predicted)
plot += gg.geom_vline(xintercept=context.shape[1], linetype="dashed")
_ = plot.draw()

Sample trajectories¶
sample_x, sample_y = next(valid_ds)
sample_x = sample_x[:1] # Shrink to batch-size 1.
sample_y = sample_y[:1] # Shrink to batch-size 1.
context_length = SEQ_LEN // 8
print(f"context_length = {context_length}")
# Cut the batch-size 1 context from the start of the sequence.
context = sample_x[:, :context_length]
# Reuse the same context from the previous cell.
predicted = fast_autoregressive_predict(train_state.params, None, context, SEQ_LEN)
# The plots should be equivalent!
plot = plot_samples(sample_y, predicted)
plot += gg.geom_vline(xintercept=context.shape[1], linetype="dashed")
_ = plot.draw()
context_length = 8

timeit¶
%timeit autoregressive_predict(train_state.params, context, SEQ_LEN)
%timeit fast_autoregressive_predict(train_state.params, None, context, SEQ_LEN)
32.4 ms ± 331 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
25 µs ± 90.7 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
Train all stars¶
Training¶
def split_train_validation_date(dataframe, date, look_back) -> TrainSplit:
train_size = len(dataframe.loc[:date])
return split_train_validation(dataframe, train_size, look_back)
%%time
train, valid = split_train_validation_date(dataframe_normed, TRAIN_DATE, SEQ_LEN)
print(f"effective train size = {train[0].shape[1]}")
838194 batches rounded to 524288 batches.
26455 batches rounded to 16384 batches.
effective train size = 64
CPU times: user 2.65 s, sys: 743 ms, total: 3.39 s
Wall time: 2.82 s
train[0].shape, train[1].shape, valid[0].shape, valid[1].shape
((524288, 64, 1), (524288, 64, 1), (16384, 64, 1), (16384, 64, 1))
train_ds = Dataset(train, BATCH_SIZE)
valid_ds = Dataset(valid, BATCH_SIZE)
# del train, valid # Don't leak temporaries.
%%time
train_state, records = train_model(
model_with_loss,
train_ds,
valid_ds,
len(train.x) // BATCH_SIZE * 1,
jax.random.PRNGKey(42),
record_freq=RECORD_FREQ,
)
End of the data set (size=524288). Return to the beginning.
CPU times: user 1min 30s, sys: 408 ms, total: 1min 30s
Wall time: 1min 30s
# Plot losses
losses = pd.DataFrame(records)
df = pd.melt(losses, id_vars=["step"], value_vars=["train_loss", "valid_loss"])
plot = (
gg.ggplot(df)
+ gg.aes(x="step", y="value", color="variable")
+ gg.geom_line()
+ gg.scales.scale_y_log10()
)
_ = plot.draw()

Sampling¶
# Grab a sample from the validation set.
sample_x, sample_y = next(valid_ds)
sample_x = sample_x[:1] # Shrink to batch-size 1.
sample_y = sample_y[:1] # Shrink to batch-size 1.
# Generate a prediction, feeding in ground truth at each point as input.
predicted, _ = model.apply(train_state.params, None, sample_x)
plot = plot_samples(sample_y, predicted)
_ = plot.draw()

Run autoregressively¶
%%time
sample_x, sample_y = next(valid_ds)
sample_x = sample_x[:1] # Shrink to batch-size 1.
sample_y = sample_y[:1] # Shrink to batch-size 1.
context_length = SEQ_LEN // 8
# Cut the batch-size 1 context from the start of the sequence.
context = sample_x[:, :context_length]
# Reuse the same context from the previous cell.
predicted = fast_autoregressive_predict(train_state.params, None, context, SEQ_LEN)
# The plots should be equivalent!
plot = plot_samples(sample_y, predicted)
plot += gg.geom_vline(xintercept=len(context), linetype="dashed")
_ = plot.draw()
CPU times: user 64.6 ms, sys: 2.56 ms, total: 67.2 ms
Wall time: 65.8 ms

# Uncomment to run the notebook in Colab
# ! pip install -q "wax-ml[complete]@git+https://github.com/eserie/wax-ml.git"
# ! pip install -q --upgrade jax jaxlib==0.1.70+cuda111 -f https://storage.googleapis.com/jax-releases/jax_releases.html
# check available devices
import jax
print("jax backend {}".format(jax.lib.xla_bridge.get_backend().platform))
jax.devices()
WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
jax backend cpu
[CpuDevice(id=0)]
🦎 Online linear regression with a non-stationary environment 🦎¶
We implement an online learning non-stationary linear regression problem.
We go there progressively by showing how a linear regression problem can be cast
into an online learning problem thanks to the OnlineSupervisedLearner
module.
Then, to tackle a non-stationary linear regression problem (i.e. with a weight that can vary in time)
we reformulate the problem into a reinforcement learning problem that we implement with the GymFeedBack
module of WAX-ML.
We then need to define an “agent” and an “environment” using simple functions implemented with modules:
The agent is responsible for learning the weights of its internal linear model.
The environment is responsible for generating labels and evaluating the agent’s reward metric.
We experiment with a non-stationary environment that returns the sign of the linear regression parameters at a given time step, known only to the environment.
This example shows that it is quite simple to implement this online-learning task with WAX-ML tools. In particular, the functional workflow adopted here allows reusing the functions implemented for a task for each new task of increasing complexity,
In this journey, we will use:
Haiku basic linear module
hk.Linear
.Optax stochastic gradient descent optimizer:
sgd
.WAX-ML modules:
OnlineSupervisedLearner
,Lag
,GymFeedBack
WAX-ML helper functions:
unroll
,jit_init_apply
%pylab inline
Populating the interactive namespace from numpy and matplotlib
import haiku as hk
import jax
import jax.numpy as jnp
import optax
from matplotlib import pyplot as plt
from wax.compile import jit_init_apply
from wax.modules import OnlineSupervisedLearner
Static Linear Regression¶
First, let’s implement a simple linear regression
Generate data¶
Let’s generate a batch of data:
seq = hk.PRNGSequence(42)
X = jax.random.normal(next(seq), (100, 3))
w_true = jnp.ones(3)
Define the model¶
We use the basic module hk.Linear
which is a linear layer.
By default, it initializes the weights with random values from the truncated normal,
with a standard deviation of \(1 / \sqrt{N}\) (See https://arxiv.org/abs/1502.03167v3)
where \(N\) is the size of the inputs.
@jit_init_apply
@hk.transform_with_state
def linear_model(x):
return hk.Linear(output_size=1, with_bias=False)(x)
Run the model¶
Let’s run the model using WAX-ML unroll
on the batch of data.
from wax.unroll import unroll
params, state = linear_model.init(next(seq), X[0])
linear_model.apply(params, state, None, X[0])
(DeviceArray([-0.2070887], dtype=float32), FlatMapping({}))
Y_pred = unroll(linear_model, rng=next(seq))(X)
Y_pred.shape
(100, 1)
Check cost¶
Let’s look at the mean squared error for this non-trained model.
noise = jax.random.normal(next(seq), (100,))
Y = X.dot(w_true) + noise
L = ((Y - Y_pred) ** 2).sum(axis=1)
mean_loss = L.mean()
assert mean_loss > 0
Let’s look at the regret (cumulative sum of the loss) for the non-trained model.
plt.plot(L.cumsum())
plt.title("Regret")
Text(0.5, 1.0, 'Regret')

As expected, we have a linear regret when we did not train the model!
Online Linear Regression¶
We will now start training the model online. For a review on online-learning methods see [1]
[1] Elad Hazan, Introduction to Online Convex Optimization
Define an optimizer¶
opt = optax.sgd(1e-3)
Define a loss¶
Since we are doing online learning, we need to define a local loss function: $\( \ell_t(y, w, x) = \lVert y_t - w \cdot x_t \rVert^2 \)$
@jax.jit
def loss(y_pred, y):
return jnp.mean(jnp.square(y_pred - y))
Define a learning strategy¶
@jit_init_apply
@hk.transform_with_state
def learner(x, y):
return OnlineSupervisedLearner(linear_model, opt, loss)(x, y)
Generate data¶
def generate_many_observations(T=300, sigma=1.0e-2, rng=None):
rng = jax.random.PRNGKey(42) if rng is None else rng
X = jax.random.normal(rng, (T, 3))
noise = sigma * jax.random.normal(rng, (T,))
w_true = jnp.ones(3)
noise = sigma * jax.random.normal(rng, (T,))
Y = X.dot(w_true) + noise
return (X, Y)
T = 3000
X, Y = generate_many_observations(T)
Unroll the learner¶
(output, info) = unroll(learner, rng=next(seq))(X, Y)
Plot the regret¶
Let’s look at the loss and regret over time.
fig, axs = plt.subplots(1, 2, figsize=(9, 3))
axs[0].plot(info.loss.cumsum())
axs[0].set_title("Regret")
axs[1].plot(info.params["linear"]["w"][:, 0, 0])
axs[1].set_title("Weight[0,0]")
Text(0.5, 1.0, 'Weight[0,0]')

We have sub-linear regret!
Online learning with Gym¶
Now we will recast the online linear regression learning task as a reinforcement learning task
implemented with the GymFeedback
module of WAX-ML.
For that, we define:
obserbations (
obs
) : pairs(x, y)
of features and labelsraw observations (
raw_obs
): pairs(x, noise)
of features and noise.
Linear regression agent¶
In WAX-ML, an agent is a simple function with the following API:

Let’s define a simple linear regression agent with the elements we have defined so far.
def linear_regression_agent(obs):
x, y = obs
@jit_init_apply
@hk.transform_with_state
def model(x):
return hk.Linear(output_size=1, with_bias=False)(x)
opt = optax.sgd(1e-3)
@jax.jit
def loss(y_pred, y):
return jnp.mean(jnp.square(y_pred - y))
return OnlineSupervisedLearner(model, opt, loss)(x, y)
Linear regression environment¶
In WAX-ML, an environment is a simple function with the following API:

Let’s now define a linear regression environment that, for the moment, have static weights.
It is responsible for generating the real labels and evaluating the agent’s reward.
For the evaluation of the reward, we need the Lag
module to evaluate the action of
the agent with the labels generated in the previous time step.
from wax.modules import Lag
def stationary_linear_regression_env(action, raw_obs):
# Only the environment now the true value of the parameters
w_true = -jnp.ones(3)
# The environment has its proper loss definition
@jax.jit
def loss(y_pred, y):
return jnp.mean(jnp.square(y_pred - y))
# raw observation contains features and generative noise
x, noise = raw_obs
# generate targets
y = x @ w_true + noise
obs = (x, y)
y_previous = Lag(1)(y)
# evaluate the prediction made by the agent
y_pred = action
reward = loss(y_pred, y_previous)
return reward, obs, {}
Generate raw observation¶
Let’s define a function that generate the raw observation:
def generate_many_raw_observations(T=300, sigma=1.0e-2, rng=None):
rng = jax.random.PRNGKey(42) if rng is None else rng
X = jax.random.normal(rng, (T, 3))
noise = sigma * jax.random.normal(rng, (T,))
return (X, noise)
Implement Feedback¶
We are now ready to set things up with the GymFeedback
module implemented in WAX-ML.
It implements the following feedback loop:

Equivalently, it can be described with the pair of init
and apply
functions:

from wax.modules import GymFeedback
@hk.transform_with_state
def gym_fun(raw_obs):
return GymFeedback(
linear_regression_agent, stationary_linear_regression_env, return_action=True
)(raw_obs)
And now we can unroll it on a sequence of raw observations!
seq = hk.PRNGSequence(42)
T = 3000
raw_observations = generate_many_raw_observations(T)
rng = next(seq)
(gym_output, gym_info) = unroll(gym_fun, rng=rng, skip_first=True)(raw_observations)
Let’s visualize the outputs.
We now use pd.Series
to represent the reward sequence since its first value is Nan due to the use of the lag operator.
import pandas as pd
fig, axs = plt.subplots(1, 2, figsize=(9, 3))
pd.Series(gym_output.reward).cumsum().plot(ax=axs[0], title="Regret")
axs[1].plot(info.params["linear"]["w"][:, 0, 0])
axs[1].set_title("Weight[0,0]")
Text(0.5, 1.0, 'Weight[0,0]')

Non-stationary environment¶
Now, let’s implement a non-stationary environment.
We implement it so that the sign of the weight is reversed after 2000$ steps.
class NonStationaryEnvironment(hk.Module):
def __call__(self, action, raw_obs):
step = hk.get_state("step", [], init=lambda *_: 0)
# Only the environment now the true value of the parameters
# at step 2000 we flip the sign of the true parameters !
w_true = hk.cond(
step < 2000,
step,
lambda step: -jnp.ones(3),
step,
lambda step: jnp.ones(3),
)
# The environment has its proper loss definition
@jax.jit
def loss(y_pred, y):
return jnp.mean(jnp.square(y_pred - y))
# raw observation contains features and generative noise
x, noise = raw_obs
# generate targets
y = x @ w_true + noise
obs = (x, y)
# evaluate the prediction made by the agent
y_previous = Lag(1)(y)
y_pred = action
reward = loss(y_pred, y_previous)
step += 1
hk.set_state("step", step)
return reward, obs, {}
Now let’s run a gym simulation to see how the agent adapt to the change of environment.
@hk.transform_with_state
def gym_fun(raw_obs):
return GymFeedback(
linear_regression_agent, NonStationaryEnvironment(), return_action=True
)(raw_obs)
T = 6000
raw_observations = generate_many_raw_observations(T)
rng = jax.random.PRNGKey(42)
(gym_output, gym_info) = unroll(gym_fun, rng=rng, skip_first=True)(
raw_observations,
)
import pandas as pd
fig, axs = plt.subplots(1, 2, figsize=(9, 3))
pd.Series(gym_output.reward).cumsum().plot(ax=axs[0], title="Regret")
axs[1].plot(gym_info.agent.params["linear"]["w"][:, 0, 0])
axs[1].set_title("Weight[0,0]")
# plt.savefig("../_static/online_linear_regression_regret.png")
Text(0.5, 1.0, 'Weight[0,0]')

It adapts!
The regret first converges, then jumps on step 2000 and finally readjusts to the new regime.
We see that the weights converge to the correct values in both regimes.
🔄 Online learning for time series prediction 🔄¶
In [1], the authors develop an online learning method to predict time-series generated by and ARMA (autoregressive moving average) model.
They develop an effective online learning algorithm based on an improper learning approach which consists to use an AR model for prediction with sufficiently long horizon, together with an online update of the prediction model parameters using either an “online Newton” algorithm [2] or a stochastic gradient descent algorithm.
This effective approach adresses the prediction problem, without assuming that the noise terms are Gaussian, identically distributed or even independent. Furthermore, they show that their algorithm’s performances asymptotically approaches the performance of the best ARMA model in hindsight.
We use WAX-ML to reproduce their empirical results.
We first focus in the reproduction of the “setting 1” for sanity checks and show how to setup a training environment with WAX-ML to study improper learning of ARMA time-series models.
We then study the behavior the method with different optimizers in the non-stationary environements proposed in [1] (settings 2, 3, and 4).
We use the following modules from WAX-ML:
ARMA : to generate a modeled time-series
SNARIMAX : to adaptively learn to predict the generated time-series.
GymFeedback: To setup a training loop.
VMap: to add batch dimensions to the training loop
optim.newton: a newton algorithm as used in [1] and developped in [2]. It extends
optax
optimizers.OnlineOptimizer: A wrapper for a model with loss and and optimizer for online learning.
References¶
%pylab inline
%load_ext autoreload
%autoreload 2
Populating the interactive namespace from numpy and matplotlib
from typing import Any, NamedTuple
import haiku as hk
import jax
import jax.numpy as jnp
import numpy as onp
import optax
import pandas as pd
import seaborn as sns
from matplotlib import pyplot as plt
from tqdm.auto import tqdm
from wax.modules import (
ARMA,
SNARIMAX,
GymFeedback,
Lag,
OnlineOptimizer,
UpdateParams,
VMap,
)
from wax.optim import newton
from wax.unroll import unroll_transform_with_state
T = 10000
N_BATCH = 20
N_STEP_SIZE = 10
N_EPS = 5
ARMA¶
Let’s generate a sample of the “setting 1” of [1]:
alpha = jnp.array([0.6, -0.5, 0.4, -0.4, 0.3])
beta = jnp.array([0.3, -0.2])
rng = jax.random.PRNGKey(42)
eps = jax.random.normal(rng, (1000,))
sim = unroll_transform_with_state(lambda eps: ARMA(alpha, beta)(eps))
params, state = sim.init(rng, eps)
y, state = sim.apply(params, state, rng, eps)
pd.Series(y).plot()

SNARIMAX¶
Let’s setup an online model to try to learn the dynamic of the time-series.
First let’s run the filter with it’s initial random weights.
def predict(y, X=None):
return SNARIMAX(10, 0, 0)(y, X)
sim = unroll_transform_with_state(predict)
rng = jax.random.PRNGKey(42)
eps = jax.random.normal(rng, (10,))
params, state = sim.init(rng, y)
(y_pred, _), state = sim.apply(params, state, rng, y)
pd.Series((y - y_pred)).plot()

def evaluate(y_pred, y):
return jnp.linalg.norm(y_pred - y) ** 2, {}
def lag(shift=1):
def __call__(y, X=None):
yp = Lag(shift)(y)
Xp = Lag(shift)(X) if X is not None else None
return yp, Xp
return __call__
def predict_and_evaluate(y, X=None):
# predict with lagged data
y_pred, pred_info = predict(*lag(1)(y, X))
# evaluate loss with actual data
loss, loss_info = evaluate(y_pred, y)
return loss, dict(pred_info=pred_info, loss_info=loss_info)
sim = unroll_transform_with_state(predict_and_evaluate)
rng = jax.random.PRNGKey(42)
eps = jax.random.normal(rng, (1000,))
params, state = sim.init(rng, y)
(loss, _), state = sim.apply(params, state, rng, y)
pd.Series(loss).expanding().mean().plot()

Since the model is not trained and the coefficient of the SNARIMAX filter a choosen randomly, the loss may diverge
params
FlatMapping({
'snarimax/~/linear': FlatMapping({
'w': DeviceArray([[-0.47023755],
[ 0.07070494],
[-0.07388116],
[ 0.13453043],
[ 0.07728617],
[ 0.0851969 ],
[-0.03324771],
[ 0.17115903],
[ 0.1023274 ],
[-0.4804019 ]], dtype=float32),
'b': DeviceArray([0.], dtype=float32),
}),
})
Learn¶
To learn the model parameters we will use the OnlineOptimizer
of WAX-ML.
Setup projection¶
We can setup a projection for the parameters:
def project_params(params, opt_state=None):
w = params["snarimax/~/linear"]["w"]
w = jnp.clip(w, -1, 1)
params["snarimax/~/linear"]["w"] = w
return params
project_params(hk.data_structures.to_mutable_dict(params))
{'snarimax/~/linear': {'w': DeviceArray([[-0.47023755],
[ 0.07070494],
[-0.07388116],
[ 0.13453043],
[ 0.07728617],
[ 0.0851969 ],
[-0.03324771],
[ 0.17115903],
[ 0.1023274 ],
[-0.4804019 ]], dtype=float32),
'b': DeviceArray([0.], dtype=float32)}}
def learn(y, X=None):
optim_res = OnlineOptimizer(
predict_and_evaluate, optax.sgd(1.0e-2), project_params=project_params
)(y, X)
return optim_res
Let’s train:
sim = unroll_transform_with_state(learn)
rng = jax.random.PRNGKey(42)
eps = jax.random.normal(rng, (1000,))
params, state = sim.init(rng, eps)
optim_res, state = sim.apply(params, state, rng, eps)
pd.Series(optim_res.loss).expanding().mean().plot()

Let’s look at the latest weights:
jax.tree_map(lambda x: x[-1], optim_res.updated_params)
FlatMapping({
'snarimax/~/linear': FlatMapping({
'b': DeviceArray([-0.14335798], dtype=float32),
'w': DeviceArray([[-0.06957585],
[ 0.11199051],
[-0.14902674],
[-0.02122167],
[-0.03742805],
[ 0.00436568],
[-0.10548005],
[-0.05702855],
[-0.01185708],
[-0.01168777]], dtype=float32),
}),
})
Learn and Forecast¶
class ForecastInfo(NamedTuple):
optim: Any
forecast: Any
def learn_and_forecast(y, X=None):
optim_res = OnlineOptimizer(
predict_and_evaluate, optax.sgd(1.0e-3), project_params=project_params
)(*lag(1)(y, X))
predict_params = optim_res.updated_params
forecast, forecast_info = UpdateParams(predict)(predict_params, y, X)
return forecast, ForecastInfo(optim_res, forecast_info)
sim = unroll_transform_with_state(learn_and_forecast)
rng = jax.random.PRNGKey(42)
eps = jax.random.normal(rng, (1000,))
params, state = sim.init(rng, y)
(forecast, info), state = sim.apply(params, state, rng, y)
pd.Series(info.optim.loss).expanding().mean().plot()
<AxesSubplot:>

Gym simulation¶
Now let’s wrapup the training loop in a gym feedback loop.
Environment¶
Let’s build an environment corresponding to “setting 1” in [1]
def build_env():
def env(action, obs):
y_pred, eps = action, obs
ar_coefs = jnp.array([0.6, -0.5, 0.4, -0.4, 0.3])
ma_coefs = jnp.array([0.3, -0.2])
y = ARMA(ar_coefs, ma_coefs)(eps)
# prediction used on a fresh y observation.
rw = -((y - y_pred) ** 2)
env_info = {"y": y, "y_pred": y_pred}
obs = y
return rw, obs, env_info
return env
env = build_env()
Agent¶
Let’s build an agent:
from optax._src.base import OptState
def build_agent(time_series_model=None, opt=None):
if time_series_model is None:
time_series_model = lambda y, X: SNARIMAX(10)(y, X)
if opt is None:
opt = optax.sgd(1.0e-3)
class AgentInfo(NamedTuple):
optim: Any
forecast: Any
class ModelWithLossInfo(NamedTuple):
pred: Any
loss: Any
def agent(obs):
if isinstance(obs, tuple):
y, X = obs
else:
y = obs
X = None
def evaluate(y_pred, y):
return jnp.linalg.norm(y_pred - y) ** 2, {}
def model_with_loss(y, X=None):
# predict with lagged data
y_pred, pred_info = time_series_model(*lag(1)(y, X))
# evaluate loss with actual data
loss, loss_info = evaluate(y_pred, y)
return loss, ModelWithLossInfo(pred_info, loss_info)
def project_params(params: Any, opt_state: OptState = None):
del opt_state
return jax.tree_map(lambda w: jnp.clip(w, -1, 1), params)
def split_params(params):
def filter_params(m, n, p):
# print(m, n, p)
return m.endswith("snarimax/~/linear") and n == "w"
return hk.data_structures.partition(filter_params, params)
def learn_and_forecast(y, X=None):
optim_res = OnlineOptimizer(
model_with_loss,
opt,
project_params=project_params,
split_params=split_params,
)(*lag(1)(y, X))
predict_params = optim_res.updated_params
y_pred, forecast_info = UpdateParams(time_series_model)(
predict_params, y, X
)
return y_pred, AgentInfo(optim_res, forecast_info)
return learn_and_forecast(y, X)
return agent
agent = build_agent()
Gym loop¶
def gym_loop(eps):
return GymFeedback(agent, env)(eps)
rng = jax.random.PRNGKey(42)
eps = jax.random.normal(rng, (T,))
sim = unroll_transform_with_state(gym_loop)
params, state = sim.init(rng, eps)
(gym, info), state = sim.apply(params, state, rng, eps)
pd.Series(info.agent.optim.loss).expanding().mean().plot(label="agent loss")
pd.Series(-gym.reward).expanding().mean().plot(xlim=(0, 100), label="env loss")
plt.legend()
<matplotlib.legend.Legend at 0x167c42490>

pd.Series(-gym.reward).expanding().mean().plot() # ylim=(0.09, 0.15))
<AxesSubplot:>

We see that the agent suffers the same loss as the environment but with a time lag.
Batch simulations¶
Average over 20 experiments¶
Slow version¶
First, let’s do it “naively” by doing a simple python “for loop”.
%%time
rng = jax.random.PRNGKey(42)
sim = unroll_transform_with_state(gym_loop)
res = {}
for i in tqdm(onp.arange(N_BATCH)):
rng, _ = jax.random.split(rng)
eps = jax.random.normal(rng, (T,)) * 0.3
params, state = sim.init(rng, eps)
(gym_output, gym_info), final_state = sim.apply(params, state, rng, eps)
res[i] = gym_info
CPU times: user 13.7 s, sys: 182 ms, total: 13.9 s
Wall time: 13.8 s
pd.DataFrame({k: pd.Series(v.agent.optim.loss) for k, v in res.items()}).mean(
1
).expanding().mean().plot(ylim=(0.09, 0.15))

Fast version with vmap¶
Instead of using a “for loop” we can use jax’s vmap transformation function!
%%time
rng = jax.random.PRNGKey(42)
eps = jax.random.normal(rng, (N_BATCH, T)) * 0.3
rng = jax.random.PRNGKey(42)
rng = jax.random.split(rng, num=N_BATCH)
sim = unroll_transform_with_state(gym_loop)
params, state = jax.vmap(sim.init)(rng, eps)
(gym_output, gym_info), final_state = jax.vmap(sim.apply)(params, state, rng, eps)
CPU times: user 2.09 s, sys: 61 ms, total: 2.15 s
Wall time: 2.13 s
This is much faster!
pd.DataFrame(gym_info.agent.optim.loss).mean().expanding().mean().plot(
ylim=(0.09, 0.15)
)

i_batch = 0
w = gym_info.agent.optim.updated_params["snarimax/~/linear"]["w"][i_batch, :, :, 0]
w = pd.DataFrame(w)
ax = w.plot(title=f"weights on batch {i_batch}")
ax.legend(bbox_to_anchor=(1.0, 1.0))
ax.set_xlabel("time")
ax.set_ylabel("weight value")
plt.figure()
w.iloc[-1][::-1].plot(kind="bar")


w = gym_info.agent.optim.updated_params["snarimax/~/linear"]["w"].mean(axis=0)[:, :, 0]
w = pd.DataFrame(w)
ax = w.plot(title="averaged weights (over batches)")
ax.legend(bbox_to_anchor=(1.0, 1.0))
ax.set_xlabel("time")
ax.set_ylabel("weight")
plt.figure()
w.iloc[-1][::-1].plot(kind="bar")


With VMap module¶
We can use the wrapper module VMap
of WAX-ML. It permits to have an ever simpler syntax.
Note: we have to swap the position of time and batch dimensions in the generation of the noise variable
eps
.
%%time
rng = jax.random.PRNGKey(42)
eps = jax.random.normal(rng, (T, N_BATCH)) * 0.3
def batched_gym_loop(eps):
return VMap(gym_loop)(eps)
sim = unroll_transform_with_state(batched_gym_loop)
rng = jax.random.PRNGKey(43)
params, state = sim.init(rng, eps)
(gym_output, gym_info), final_state = sim.apply(params, state, rng, eps)
CPU times: user 1.53 s, sys: 15.5 ms, total: 1.54 s
Wall time: 1.53 s
pd.DataFrame(gym_info.agent.optim.loss).shape
(10000, 20)
pd.DataFrame(gym_info.agent.optim.loss).mean(1).expanding().mean().plot(
ylim=(0.09, 0.15)
)

i_batch = 0
w = gym_info.agent.optim.updated_params["snarimax/~/linear"]["w"][:, i_batch, :, 0]
w = pd.DataFrame(w)
ax = w.plot(title=f"weights on batch {i_batch}")
ax.legend(bbox_to_anchor=(1.0, 1.0))
ax.set_xlabel("time")
ax.set_ylabel("weight value")
plt.figure()
w.iloc[-1][::-1].plot(kind="bar")


w = gym_info.agent.optim.updated_params["snarimax/~/linear"]["w"].mean(axis=1)[:, :, 0]
w = pd.DataFrame(w)
ax = w.plot(title="averaged weights (over batches)")
ax.legend(bbox_to_anchor=(1.0, 1.0))
ax.set_xlabel("time")
ax.set_ylabel("weight")
plt.figure()
w.iloc[-1][::-1].plot(kind="bar")


Taking mean inside simulation¶
%%time
def add_batch(fun, take_mean=True):
def fun_batch(*args, **kwargs):
res = VMap(fun)(*args, **kwargs)
if take_mean:
res = jax.tree_map(lambda x: x.mean(axis=0), res)
return res
return fun_batch
gym_loop_batch = add_batch(gym_loop)
sim = unroll_transform_with_state(gym_loop_batch)
rng = jax.random.PRNGKey(42)
eps = jax.random.normal(rng, (T, N_BATCH)) * 0.3
params, state = sim.init(rng, eps)
(gym_output, gym_info), final_state = sim.apply(params, state, rng, eps)
CPU times: user 1.53 s, sys: 18.8 ms, total: 1.54 s
Wall time: 1.54 s
pd.Series(-gym_output.reward).expanding().mean().plot(ylim=(0.09, 0.15))

w = gym_info.agent.optim.updated_params["snarimax/~/linear"]["w"][:, :, 0]
w = pd.DataFrame(w)
ax = w.plot(title="averaged weights (over batches)")
ax.legend(bbox_to_anchor=(1.0, 1.0))
ax.set_xlabel("time")
ax.set_ylabel("weight")
plt.figure()
w.iloc[-1][::-1].plot(kind="bar")


Hyper parameter tuning¶
First order optimizers¶
We will consider different first order optimizers, namely:
SGD
ADAM
ADAGRAD
For each of them, we will scan the “step_size” parameter \(\eta\).
We will average results over batches of size 40.
We will consider trajectories of size 10.000.
Finally, we will pickup the best parameter based of the minimum averaged loss for the last 5000 time steps.
%%time
STEP_SIZE_idx = pd.Index(onp.logspace(-4, 1, 30), name="step_size")
STEP_SIZE = jax.device_put(STEP_SIZE_idx.values)
OPTIMIZERS = [optax.sgd, optax.adagrad, optax.rmsprop, optax.adam]
res = {}
for optimizer in tqdm(OPTIMIZERS):
def gym_loop_scan_hparams(eps):
def scan_params(step_size):
return GymFeedback(build_agent(opt=optimizer(step_size)), env)(eps)
res = VMap(scan_params)(STEP_SIZE)
return res
sim = unroll_transform_with_state(add_batch(gym_loop_scan_hparams))
rng = jax.random.PRNGKey(42)
eps = jax.random.normal(rng, (T, N_BATCH)) * 0.3
params, state = sim.init(rng, eps)
_res, state = sim.apply(params, state, rng, eps)
res[optimizer.__name__] = _res
CPU times: user 11.5 s, sys: 111 ms, total: 11.6 s
Wall time: 11.6 s
ax = None
BEST_STEP_SIZE = {}
BEST_GYM = {}
for name, (gym, info) in res.items():
loss = pd.DataFrame(-gym.reward, columns=STEP_SIZE).iloc[-5000:].mean()
BEST_STEP_SIZE[name] = loss.idxmin()
best_idx = loss.reset_index(drop=True).idxmin()
BEST_GYM[name] = jax.tree_map(lambda x: x[:, best_idx], gym)
ax = loss[loss < 0.15].plot(logx=True, logy=False, ax=ax, label=name)
plt.legend()

for name, gym in BEST_GYM.items():
ax = (
pd.Series(-gym.reward)
.expanding()
.mean()
.plot(
label=f"{name} - $\eta$={BEST_STEP_SIZE[name]:.2e}", ylim=(0.09, 0.15)
)
)
ax.legend(bbox_to_anchor=(1.0, 1.0))

pd.Series(BEST_STEP_SIZE).plot(kind="bar", logy=True)

Newton algorithm¶
Now let’s consider the newton algorithm.
First let’s test it with one set of parameter with average over N_BATCH batches.
%%time
@add_batch
def gym_loop_newton(eps):
return GymFeedback(build_agent(opt=newton(0.05, eps=20.0)), env)(eps)
sim = unroll_transform_with_state(gym_loop_newton)
rng = jax.random.PRNGKey(42)
eps = jax.random.normal(rng, (T, N_BATCH)) * 0.3
params, state = sim.init(rng, eps)
(gym, info), state = sim.apply(params, state, rng, eps)
CPU times: user 1.88 s, sys: 43.5 ms, total: 1.92 s
Wall time: 1.73 s
pd.Series(-gym.reward).expanding().mean().plot()

%%time
STEP_SIZE = pd.Index(onp.logspace(-2, 3, 10), name="step_size")
EPS = pd.Index(onp.logspace(-4, 3, 5), name="eps")
HPARAMS_idx = pd.MultiIndex.from_product([STEP_SIZE, EPS])
HPARAMS = jnp.stack(list(map(onp.array, HPARAMS_idx)))
@add_batch
def gym_loop_scan_hparams(eps):
def scan_params(hparams):
step_size, newton_eps = hparams
agent = build_agent(opt=newton(step_size, eps=newton_eps))
return GymFeedback(agent, env)(eps)
return VMap(scan_params)(HPARAMS)
sim = unroll_transform_with_state(gym_loop_scan_hparams)
rng = jax.random.PRNGKey(42)
eps = jax.random.normal(rng, (T, N_BATCH)) * 0.3
params, state = sim.init(rng, eps)
res_newton, state = sim.apply(params, state, rng, eps)
CPU times: user 16.1 s, sys: 118 ms, total: 16.2 s
Wall time: 16 s
gym_newton, info_newton = res_newton
loss_newton = pd.DataFrame(-gym_newton.reward, columns=HPARAMS_idx).mean().unstack()
loss_newton = (
pd.DataFrame(-gym_newton.reward, columns=HPARAMS_idx).iloc[-5000:].mean().unstack()
)
sns.heatmap(loss_newton[loss_newton < 0.4], annot=True, cmap="YlGnBu")
STEP_SIZE, NEWTON_EPS = loss_newton.stack().idxmin()
x = -gym_newton.reward[-5000:].mean(axis=0)
x = jax.ops.index_update(x, jnp.isnan(x), jnp.inf)
I_BEST_PARAM = jnp.argmin(x)
BEST_NEWTON_GYM = jax.tree_map(lambda x: x[:, I_BEST_PARAM], gym_newton)
print("Best newton parameters: ", STEP_SIZE, NEWTON_EPS)
Best newton parameters: 0.464158883361278 0.31622776601683794

for name, gym in BEST_GYM.items():
pd.Series(-gym.reward).rolling(5000, min_periods=5000).mean().plot(
label=f"{name} - $\eta$={BEST_STEP_SIZE[name]:.2e}", ylim=(0.09, 0.1)
)
gym = BEST_NEWTON_GYM
ax = (
pd.Series(-gym.reward)
.rolling(5000, min_periods=5000)
.mean()
.plot(
label=f"Newton - $\eta$={STEP_SIZE:.2e}, $\epsilon$={NEWTON_EPS:.2e}"
)
)
ax.legend(bbox_to_anchor=(1.0, 1.0))
plt.title("Rolling mean of loss (5000) time-steps")

for name, gym in BEST_GYM.items():
pd.Series(-gym.reward).expanding().mean().plot(
label=f"{name} - $\eta$={BEST_STEP_SIZE[name]:.2e}", ylim=(0.09, 0.15)
)
gym = BEST_NEWTON_GYM
ax = (
pd.Series(-gym.reward)
.expanding()
.mean()
.plot(
label=f"Newton - $\eta$={STEP_SIZE:.2e}, $\epsilon$={NEWTON_EPS:.2e}"
)
)
ax.legend(bbox_to_anchor=(1.0, 1.0))

In agreement with results in [1], we see that Newton’s algorithm performs much better than SGD.
In addition, we note that:
ADAGRAD performormance is between newton and sgd.
RMSPROPR and ADAM does not perform well in this online setting.
🔄 Online learning in non-stationary environments 🔄¶
We reproduce the empirical results of [1].
References¶
%pylab inline
%load_ext autoreload
%autoreload 2
Populating the interactive namespace from numpy and matplotlib
from typing import Any, NamedTuple
import haiku as hk
import jax
import jax.numpy as jnp
import numpy as onp
import optax
import pandas as pd
import seaborn as sns
from matplotlib import pyplot as plt
from tqdm.auto import tqdm
from wax.modules import ARMA, SNARIMAX, GymFeedback, OnlineOptimizer, UpdateParams, VMap
from wax.modules.lag import tree_lag
from wax.modules.vmap import add_batch
from wax.optim import newton
from wax.unroll import unroll_transform_with_state
T = 10000
N_BATCH = 20
N_STEP_SIZE = 30
N_STEP_SIZE_NEWTON = 10
N_EPS = 5
Agent¶
OPTIMIZERS = [optax.sgd, optax.adagrad, optax.rmsprop, optax.adam]
from optax._src.base import OptState
def build_agent(time_series_model=None, opt=None):
if time_series_model is None:
time_series_model = lambda y, X: SNARIMAX(10)(y, X)
if opt is None:
opt = optax.sgd(1.0e-3)
class AgentInfo(NamedTuple):
optim: Any
forecast: Any
class ModelWithLossInfo(NamedTuple):
pred: Any
loss: Any
def agent(obs):
if isinstance(obs, tuple):
y, X = obs
else:
y = obs
X = None
def evaluate(y_pred, y):
return jnp.linalg.norm(y_pred - y) ** 2, {}
def model_with_loss(y, X=None):
# predict with lagged data
y_pred, pred_info = time_series_model(*tree_lag(1)(y, X))
# evaluate loss with actual data
loss, loss_info = evaluate(y_pred, y)
return loss, ModelWithLossInfo(pred_info, loss_info)
def project_params(params: Any, opt_state: OptState = None):
del opt_state
return jax.tree_map(lambda w: jnp.clip(w, -1, 1), params)
def split_params(params):
def filter_params(m, n, p):
# print(m, n, p)
return m.endswith("snarimax/~/linear") and n == "w"
return hk.data_structures.partition(filter_params, params)
def learn_and_forecast(y, X=None):
# use lagged data for the optimizer
optim_res = OnlineOptimizer(
model_with_loss,
opt,
project_params=project_params,
split_params=split_params,
)(*tree_lag(1)(y, X))
# use updated params to forecast with actual data
predict_params = optim_res.updated_params
y_pred, forecast_info = UpdateParams(time_series_model)(
predict_params, y, X
)
return y_pred, AgentInfo(optim_res, forecast_info)
return learn_and_forecast(y, X)
return agent
Non-stationary environments¶
We will now wrapup the study of an environment + agent in few analysis functions.
We will then use them to perform the same analysis in the non-stationary setting proposed in [1], namely:
setting 1 : sanity check (stationary ARMA environment).
setting 2 : slowly varying parameters.
setting 3 : brutal variation of parameters.
setting 4 : non-stationary (random walk) noise.
Analysis functions¶
For each solver, we will select the best hyper parameters (step size \(\eta\), \(\epsilon\)) by measuring the average loss between the 5000 and 10000 steps.
First order solvers¶
def scan_hparams_first_order():
STEP_SIZE_idx = pd.Index(onp.logspace(-4, 1, N_STEP_SIZE), name="step_size")
STEP_SIZE = jax.device_put(STEP_SIZE_idx.values)
rng = jax.random.PRNGKey(42)
eps = sample_noise(rng)
res = {}
for optimizer in tqdm(OPTIMIZERS):
def gym_loop_scan_hparams(eps):
def scan_params(step_size):
return GymFeedback(build_agent(opt=optimizer(step_size)), env)(eps)
res = VMap(scan_params)(STEP_SIZE)
return res
sim = unroll_transform_with_state(add_batch(gym_loop_scan_hparams))
params, state = sim.init(rng, eps)
_res, state = sim.apply(params, state, rng, eps)
res[optimizer.__name__] = _res
ax = None
BEST_STEP_SIZE = {}
BEST_GYM = {}
for name, (gym, info) in res.items():
loss = (
pd.DataFrame(-gym.reward, columns=STEP_SIZE).iloc[LEARN_TIME_SLICE].mean()
)
BEST_STEP_SIZE[name] = loss.idxmin()
best_idx = jnp.argmax(gym.reward[LEARN_TIME_SLICE].mean(axis=0))
BEST_GYM[name] = jax.tree_map(lambda x: x[:, best_idx], gym)
ax = loss.plot(
logx=True, logy=False, ax=ax, label=name, ylim=(MIN_ERR, MAX_ERR)
)
plt.legend()
return BEST_STEP_SIZE, BEST_GYM
We will “cross-validate” the result by running the agent on new samples.
CROSS_VAL_RNG = jax.random.PRNGKey(44)
WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
COLORS = sns.color_palette("hls")
def cross_validate_first_order(BEST_STEP_SIZE, BEST_GYM):
plt.figure()
eps = sample_noise(CROSS_VAL_RNG)
CROSS_VAL_GYM = {}
ax = None
# def measure(reward):
# return pd.Series(-reward).rolling(T/2, min_periods=T/2).mean()
def measure(reward):
return pd.Series(-reward).expanding().mean()
for i, (name, gym) in enumerate(BEST_GYM.items()):
ax = measure(gym.reward).plot(
ax=ax,
color=COLORS[i],
label=(f"(TRAIN) - {name} " f"- $\eta$={BEST_STEP_SIZE[name]:.2e}"),
style="--",
)
for i, optimizer in enumerate(tqdm(OPTIMIZERS)):
name = optimizer.__name__
def gym_loop(eps):
return GymFeedback(build_agent(opt=optimizer(BEST_STEP_SIZE[name])), env)(
eps
)
sim = unroll_transform_with_state(add_batch(gym_loop))
rng = jax.random.PRNGKey(42)
params, state = sim.init(rng, eps)
(gym, info), state = sim.apply(params, state, rng, eps)
CROSS_VAL_GYM[name] = gym
ax = measure(gym.reward).plot(
ax=ax,
color=COLORS[i],
ylim=(MIN_ERR, MAX_ERR),
label=(
f"(VALIDATE) - {name} " f"- $\eta$={BEST_STEP_SIZE[name]:.2e}"
),
)
plt.legend()
return CROSS_VAL_GYM
Newton solver¶
def scan_hparams_newton():
STEP_SIZE = pd.Index(onp.logspace(-2, 3, N_STEP_SIZE_NEWTON), name="step_size")
EPS = pd.Index(onp.logspace(-4, 3, N_EPS), name="eps")
HPARAMS_idx = pd.MultiIndex.from_product([STEP_SIZE, EPS])
HPARAMS = jnp.stack(list(map(onp.array, HPARAMS_idx)))
@add_batch
def gym_loop_scan_hparams(eps):
def scan_params(hparams):
step_size, newton_eps = hparams
agent = build_agent(opt=newton(step_size, eps=newton_eps))
return GymFeedback(agent, env)(eps)
return VMap(scan_params)(HPARAMS)
sim = unroll_transform_with_state(gym_loop_scan_hparams)
rng = jax.random.PRNGKey(42)
eps = sample_noise(rng)
params, state = sim.init(rng, eps)
res_newton, state = sim.apply(params, state, rng, eps)
gym_newton, info_newton = res_newton
loss_newton = (
pd.DataFrame(-gym_newton.reward, columns=HPARAMS_idx)
.iloc[LEARN_TIME_SLICE]
.mean()
.unstack()
)
sns.heatmap(loss_newton[loss_newton < 0.4], annot=True, cmap="YlGnBu")
STEP_SIZE, NEWTON_EPS = loss_newton.stack().idxmin()
x = -gym_newton.reward[LEARN_TIME_SLICE].mean(axis=0)
x = jax.ops.index_update(x, jnp.isnan(x), jnp.inf)
I_BEST_PARAM = jnp.argmin(x)
BEST_NEWTON_GYM = jax.tree_map(lambda x: x[:, I_BEST_PARAM], gym_newton)
print("Best newton parameters: ", STEP_SIZE, NEWTON_EPS)
return (STEP_SIZE, NEWTON_EPS), BEST_NEWTON_GYM
def cross_validate_newton(BEST_HPARAMS, BEST_NEWTON_GYM):
(STEP_SIZE, NEWTON_EPS) = BEST_HPARAMS
plt.figure()
# def measure(reward):
# return pd.Series(-reward).rolling(T/2, min_periods=T/2).mean()
def measure(reward):
return pd.Series(-reward).expanding().mean()
@add_batch
def gym_loop(eps):
agent = build_agent(opt=newton(STEP_SIZE, eps=NEWTON_EPS))
return GymFeedback(agent, env)(eps)
sim = unroll_transform_with_state(gym_loop)
rng = jax.random.PRNGKey(44)
eps = sample_noise(rng)
params, state = sim.init(rng, eps)
(gym, info), state = sim.apply(params, state, rng, eps)
ax = None
i = 4
ax = measure(BEST_NEWTON_GYM.reward).plot(
ax=ax,
color=COLORS[i],
label=f"(TRAIN) - Newton - $\eta$={STEP_SIZE:.2e}, $\epsilon$={NEWTON_EPS:.2e}",
ylim=(MIN_ERR, MAX_ERR),
style="--",
)
ax = measure(gym.reward).plot(
ax=ax,
color=COLORS[i],
ylim=(MIN_ERR, MAX_ERR),
label=f"(VALIDATE) - Newton - $\eta$={STEP_SIZE:.2e}, $\epsilon$={NEWTON_EPS:.2e}",
)
plt.legend()
return gym
Plot everithing¶
def plot_everything(BEST_STEP_SIZE, BEST_GYM, BEST_HPARAMS, BEST_NEWTON_GYM):
MESURES = []
def measure(reward):
return pd.Series(-reward).rolling(int(T / 2), min_periods=int(T / 2)).mean()
MESURES.append(("Rolling mean of loss (5000) time-steps", measure))
def measure(reward):
return pd.Series(-reward).expanding().mean()
MESURES.append(("Expanding means", measure))
for MEASURE_NAME, MEASUR_FUNC in MESURES:
plt.figure()
for i, (name, gym) in enumerate(BEST_GYM.items()):
MEASUR_FUNC(gym.reward).plot(
label=f"{name} - $\eta$={BEST_STEP_SIZE[name]:.2e}",
ylim=(MIN_ERR, MAX_ERR),
color=COLORS[i],
)
i = 4
(STEP_SIZE, NEWTON_EPS) = BEST_HPARAMS
gym = BEST_NEWTON_GYM
ax = MEASUR_FUNC(gym.reward).plot(
label=f"Newton - $\eta$={STEP_SIZE:.2e}, $\epsilon$={NEWTON_EPS:.2e}",
ylim=(MIN_ERR, MAX_ERR),
color=COLORS[i],
)
ax.legend(bbox_to_anchor=(1.0, 1.0))
plt.title(MEASURE_NAME)
Setting 1¶
Environment¶
let’s wrapup the results for the “setting 1” in [1]
from wax.modules import Counter
def build_env():
def env(action, obs):
y_pred, eps = action, obs
ar_coefs = jnp.array([0.6, -0.5, 0.4, -0.4, 0.3])
ma_coefs = jnp.array([0.3, -0.2])
y = ARMA(ar_coefs, ma_coefs)(eps)
rw = -((y - y_pred) ** 2)
env_info = {"y": y, "y_pred": y_pred}
obs = y
return rw, obs, env_info
return env
def sample_noise(rng):
eps = jax.random.normal(rng, (T, 20)) * 0.3
return eps
MIN_ERR = 0.09
MAX_ERR = 0.15
LEARN_TIME_SLICE = slice(int(T / 2), T)
env = build_env()
BEST_STEP_SIZE, BEST_GYM = scan_hparams_first_order()

CROSS_VAL_GYM = cross_validate_first_order(BEST_STEP_SIZE, BEST_GYM)

BEST_HPARAMS, BEST_NEWTON_GYM = scan_hparams_newton()
Best newton parameters: 0.464158883361278 0.31622776601683794

CROSS_VAL_GYM = cross_validate_newton(BEST_HPARAMS, BEST_NEWTON_GYM)

plot_everything(BEST_STEP_SIZE, BEST_GYM, BEST_HPARAMS, BEST_NEWTON_GYM)


Conclusions¶
The NEWTON and ADAGRAD optimizers are the faster to converge.
The SGD and ADAM optimizers have the worst performance.
Fixed setting¶
@add_batch
def gym_loop_newton(eps):
return GymFeedback(build_agent(opt=newton(0.1, eps=0.3)), env)(eps)
def run_fixed_setting():
rng = jax.random.PRNGKey(42)
eps = sample_noise(rng)
sim = unroll_transform_with_state(gym_loop_newton)
params, state = sim.init(rng, eps)
(gym, info), state = sim.apply(params, state, rng, eps)
pd.Series(-gym.reward).expanding().mean().plot() # ylim=(MIN_ERR, MAX_ERR))
%%time
run_fixed_setting()
CPU times: user 1.69 s, sys: 19.4 ms, total: 1.71 s
Wall time: 1.7 s

Setting 2¶
Environment¶
let’s build an environment corresponding to “setting 2” in [1]
from wax.modules import Counter
def build_env():
def env(action, obs):
y_pred, eps = action, obs
t = Counter()()
ar_coefs_1 = jnp.array([-0.4, -0.5, 0.4, 0.4, 0.1])
ar_coefs_2 = jnp.array([0.6, -0.4, 0.4, -0.5, 0.5])
ar_coefs = ar_coefs_1 * t / T + ar_coefs_2 * (1 - t / T)
ma_coefs = jnp.array([0.32, -0.2])
y = ARMA(ar_coefs, ma_coefs)(eps)
# prediction used on a fresh y observation.
rw = -((y - y_pred) ** 2)
env_info = {"y": y, "y_pred": y_pred}
obs = y
return rw, obs, env_info
return env
def sample_noise(rng):
eps = jax.random.uniform(rng, (T, 20), minval=-0.5, maxval=0.5)
return eps
MIN_ERR = 0.0833
MAX_ERR = 0.15
LEARN_TIME_SLICE = slice(int(T / 2), T)
env = build_env()
BEST_STEP_SIZE, BEST_GYM = scan_hparams_first_order()

CROSS_VAL_GYM = cross_validate_first_order(BEST_STEP_SIZE, BEST_GYM)

BEST_HPARAMS, BEST_NEWTON_GYM = scan_hparams_newton()
Best newton parameters: 1.6681005372000592 0.31622776601683794

CROSS_VAL_GYM = cross_validate_newton(BEST_HPARAMS, BEST_NEWTON_GYM)

plot_everything(BEST_STEP_SIZE, BEST_GYM, BEST_HPARAMS, BEST_NEWTON_GYM)


Conclusions¶
The NEWTON and ADAGRAD optimizers are more efficient to adapt to slowly changing environments.
The SGD and ADAM optimizers seem to have the worst performance.
Fixed setting¶
%%time
run_fixed_setting()
CPU times: user 1.77 s, sys: 22.5 ms, total: 1.79 s
Wall time: 1.76 s

Setting 3¶
Environment¶
Let us build an environment corresponding to the “setting 3” of [1]. We modify it slightly by adding 10000 steps. We intentionally use use the 5000 to 10000 steps to optimize the hyper parameters. This allows us to evaluate how the models “over-optimize”.
from wax.modules import Counter
def build_env():
def env(action, obs):
y_pred, eps = action, obs
t = Counter()()
ar_coefs_1 = jnp.array([0.6, -0.5, 0.4, -0.4, 0.3])
ar_coefs_2 = jnp.array([-0.4, -0.5, 0.4, 0.4, 0.1])
ar_coefs = jnp.where(t < int(Tlong / 2), ar_coefs_1, ar_coefs_2)
ma_coefs_1 = jnp.array([0.3, -0.2])
ma_coefs_2 = jnp.array([-0.3, 0.2])
ma_coefs = jnp.where(t < int(Tlong / 2), ma_coefs_1, ma_coefs_2)
y = ARMA(ar_coefs, ma_coefs)(eps)
# prediction used on a fresh y observation.
rw = -((y - y_pred) ** 2)
env_info = {"y": y, "y_pred": y_pred}
obs = y
return rw, obs, env_info
return env
def sample_noise(rng):
eps = jax.random.uniform(rng, (Tlong, N_BATCH), minval=-0.5, maxval=0.5)
return eps
Tlong = 2 * T
MIN_ERR = 0.0833
MAX_ERR = 0.12
LEARN_TIME_SLICE = slice(int(T / 2), T)
env = build_env()
BEST_STEP_SIZE, BEST_GYM = scan_hparams_first_order()

CROSS_VAL_GYM = cross_validate_first_order(BEST_STEP_SIZE, BEST_GYM)

BEST_HPARAMS, BEST_NEWTON_GYM = scan_hparams_newton()
Best newton parameters: 0.464158883361278 17.78279410038923

CROSS_VAL_GYM = cross_validate_newton(BEST_HPARAMS, BEST_NEWTON_GYM)

plot_everything(BEST_STEP_SIZE, BEST_GYM, BEST_HPARAMS, BEST_NEWTON_GYM)


Choosing hyper parameters on the whole period¶
It seems that Newton solver is more prone to overfitting (recall that we chose its hyper parameters to optimize the average loss between steps 5000 and 1000, thus only in the first regime).
However, as stated in [1], Newton algorithm can have better performances if we choose its hyper parameters in order to obtain the best performances for both regimes.
Let us check this:
LEARN_TIME_SLICE = slice(int(Tlong / 2), None)
env = build_env()
BEST_STEP_SIZE, BEST_GYM = scan_hparams_first_order()

CROSS_VAL_GYM = cross_validate_first_order(BEST_STEP_SIZE, BEST_GYM)

BEST_HPARAMS, BEST_NEWTON_GYM = scan_hparams_newton()
Best newton parameters: 5.994842503189409 17.78279410038923

CROSS_VAL_GYM = cross_validate_newton(BEST_HPARAMS, BEST_NEWTON_GYM)

plot_everything(BEST_STEP_SIZE, BEST_GYM, BEST_HPARAMS, BEST_NEWTON_GYM)


Conclusion¶
The ADAGRAD optimizers seems to be best suited for abrupt regime switching.
The SGD and NEWTON optimizers seem to behave similarly if their parameters are correctly chosen.
The ADAM optimizer seems to have the worst performance.
Fixed setting¶
%%time
run_fixed_setting()
CPU times: user 2.2 s, sys: 32.6 ms, total: 2.23 s
Wall time: 2.23 s

Setting 4¶
Environment¶
let’s build an environment corresponding to “setting 4” in [1]
from wax.modules import Counter
def build_env():
def env(action, obs):
y_pred, eps = action, obs
t = Counter()()
ar_coefs = jnp.array([0.11, -0.5])
ma_coefs = jnp.array([0.41, -0.39, -0.685, 0.1])
# rng = hk.next_rng_key()
prev_eps = hk.get_state("prev_eps", (1,), init=lambda *_: jnp.zeros_like(eps))
eps = prev_eps + eps # jax.random.normal(rng, (1, N_BATCH))
hk.set_state("prev_eps", eps)
y = ARMA(ar_coefs, ma_coefs)(eps)
# prediction used on a fresh y observation.
rw = -((y - y_pred) ** 2)
env_info = {"y": y, "y_pred": y_pred}
obs = y
return rw, obs, env_info
return env
def sample_noise(rng):
eps = jax.random.normal(rng, (T, N_BATCH)) * 0.3
return eps
MIN_ERR = 0.09
MAX_ERR = 0.3
LEARN_TIME_SLICE = slice(int(T / 2), T)
env = build_env()
BEST_STEP_SIZE, BEST_GYM = scan_hparams_first_order()

BEST_STEP_SIZE
{'sgd': 9.999999747378752e-05,
'adagrad': 0.05736152455210686,
'rmsprop': 0.0016102619701996446,
'adam': 0.0016102619701996446}
CROSS_VAL_GYM = cross_validate_first_order(BEST_STEP_SIZE, BEST_GYM)

BEST_HPARAMS, BEST_NEWTON_GYM = scan_hparams_newton()
Best newton parameters: 0.464158883361278 0.31622776601683794

CROSS_VAL_GYM = cross_validate_newton(BEST_HPARAMS, BEST_NEWTON_GYM)

plot_everything(BEST_STEP_SIZE, BEST_GYM, BEST_HPARAMS, BEST_NEWTON_GYM)


As noted in [1], the newton algorithm seems to be the only one to achieve an average error rate that converges to the variance of the noise (0.09).
Conclusion¶
In this environment with noise auto-correlations:
The NEWTON optimizer achieve to realize the minimum theoretical average loss
The other optimizers struggle to converge to the minimum theoretical loss and thus seems to suffer a linear regret.
The SGD optimizer is the worst in this setting.
Fixed setting¶
%%time
run_fixed_setting()
CPU times: user 1.84 s, sys: 16.5 ms, total: 1.86 s
Wall time: 1.85 s

Change log¶
Best viewed here.
wax 0.2.0 (October 20 2021)¶
Documentation:
New notebook : 07_Online_Time_Series_Prediction
New notebook : 08_Online_learning_in_non_stationary_environments
API modifications:
refactor accessors and stream
GymFeedback now assumes that agent and env return info object
OnlineSupervisedLearner action is y_pred, loss and params are returned as info
Improvements:
introduce general unroll transformation.
dynamic_unroll can handle Callable objects
UpdateOnEvent can handle any signature for functions
EWMCov can handle the x and y arguments explicitly
add initial action option to GymFeedback
New Features:
New module UpdateParams
New module SNARIMAX, ARMA
New module OnlineOptimizer
New module VMap
add grads_fill_nan_inf option to OnlineSupervisedLearner
Introduce
unroll_transform_with_state
following Haiku API.New function auto_format_with_shape and tree_auto_format_with_shape
New module Ffill
New module Counter
Deprecate:
deprecate dynamic_unroll and static_unroll, refactor their usages.
Fixes:
Simplify Buffer to work only on ndarrays (implementation on pytrees were too complex)
EWMA behave corectly with gradient
MaskStd behave correctly with gradient
correct encode_int64 when working on int32
update notebook 06_Online_Linear_Regression and add it to run-notebooks rule
correct pct_change to behave correctly when input data has nan values.
correct eagerpy test for update of tensorflow, pytorch and jax
remove duplicate license comments
use numpy.allclose instsead of jax.numpy.allclose for comparaison of non Jax objects
update comment in notebooks : jaxlib==0.1.67+cuda111 to jaxlib==0.1.70+cuda111
fix jupytext dependency
add seaborn as optional dependency
wax 0.1.0 (June 14 2021)¶
First realease.
Installing wax
¶
First, obtain the WAX-ML source code:
git clone https://github.com/eserie/wax-ml
cd wax
You can install wax
by running:
pip install -e .[complete] # install wax
To upgrade to the latest version from GitHub, just run git pull
from the WAX-ML
repository root. You shouldn’t have to reinstall wax
because pip install -e
sets up symbolic links from site-packages into the repository.
You can install wax
development tools by running:
pip install -e .[dev] # install wax-development-tools
Running the tests¶
To run all the WAX-ML tests, we recommend using pytest-xdist
, which can run tests in
parallel. First, install pytest-xdist
and pytest-benchmark
by running
ip install -r build/test-requirements.txt
.
Then, from the repository root directory run:
pytest -n auto .
You can run a more specific set of tests using pytest’s built-in selection mechanisms, or alternatively you can run a specific test file directly to see more detailed information about the cases being run:
pytest -v wax/accessors_test.py
The Colab notebooks are tested for errors as part of the documentation build and Github actions.
Type checking¶
We use mypy
to check the type hints. To check types locally the same way
as Github actions checks, you can run:
mypy wax
or
make mypy
Flake8¶
We use flake8
to check that the code follow the pep8 standard.
To check the code, you can run
make flake8
Formatting code¶
We use isort
and black
to format the code.
When you are in the root directory of the project, to format code in the package, you can run:
make format-package
To format notebooks in the documentation, you can use:
make format-notebooks
To format all files you can run:
make format
Note that the CI running with actions will verify that formatting all source code does not affect the files. You can check this locally by running :
make check-format
Check actions¶
You can check that everything is ok by running:
make act
This will check flake8, mypy, isort and black formatting, licenses headers and run tests and coverage.
Update documentation¶
To rebuild the documentation, install several packages:
pip install -r docs/requirements.txt
And then run:
sphinx-build -b html docs docs/build/html
or run
make docs
This can take a long time because it executes many of the notebooks in the documentation source; if you’d prefer to build the docs without executing the notebooks, you can run:
sphinx-build -b html -D jupyter_execute_notebooks=off docs docs/build/html
or run
make docs-fast
You can then see the generated documentation in docs/_build/html/index.html
.
Update notebooks¶
We use jupytext to maintain three synced copies of the notebooks
in docs/notebooks
: one in ipynb
format, one in py
and one in md
format.
The advantage of the former is that it can be opened and executed directly in Colab;
the advantage of the second is that it makes easier to refactor and format python code;
the advantage of the latter is that it makes it much easier to track diffs within version control.
Editing ipynb¶
For making large changes that substantially modify code and outputs, it is easiest to
edit the notebooks in Jupyter or in Colab. To edit notebooks in the Colab interface,
open http://colab.research.google.com and Upload
from your local repo.
Update it as needed, Run all cells
then Download ipynb
.
You may want to test that it executes properly, using sphinx-build
as explained above.
You could format the python code in your notebooks by running make format
in the docs/notebooks
directory
or
make format-notebooks
in the root directory.
Editing md¶
For making smaller changes to the text content of the notebooks, it is easiest to edit the
.md
versions using a text editor.
Syncing notebooks¶
After editing either the ipynb or md versions of the notebooks, you can sync the two versions using jupytext by running:
jupytext --sync docs/notebooks/*
or:
cd docs/notebooks/
make sync
Alternatively, you can run this command via the pre-commit framework by executing the following in the main WAX-ML directory:
pre-commit run --all
See the pre-commit framework documentation for information on how to set your local git environment to execute this automatically.
Creating new notebooks¶
If you are adding a new notebook to the documentation and would like to use the jupytext --sync
command discussed here, you can set up your notebook for jupytext by using the following command:
jupytext --set-formats ipynb,py,md:myst path/to/the/notebook.ipynb
This works by adding a "jupytext"
metadata field to the notebook file which specifies the
desired formats, and which the jupytext --sync
command recognizes when invoked.
Notebooks within the sphinx build¶
Some of the notebooks are built automatically as part of the Travis pre-submit checks and
as part of the Read the docs build.
The build will fail if cells raise errors. If the errors are intentional,
you can either catch them, or tag the cell with raises-exceptions
metadata (example PR).
You have to add this metadata by hand in the .ipynb
file. It will be preserved when somebody else
re-saves the notebook.
We exclude some notebooks from the build, e.g., because they contain long computations.
See exclude_patterns
in conf.py.
Documentation building on readthedocs.io¶
WAX-ML’s auto-generated documentations is at https://wax-ml.readthedocs.io/.
The documentation building is controlled for the entire project by the
readthedocs WAX-ML settings. The current settings
trigger a documentation build as soon as code is pushed to the GitHub main
branch.
For each code version, the building process is driven by the
.readthedocs.yml
and the docs/conf.py
configuration files.
For each automated documentation build you can see the documentation build logs.
If you want to test the documentation generation on Readthedocs, you can push code to the test-docs
branch. That branch is also built automatically, and you can
see the generated documentation here. If the documentation build
fails you may want to wipe the build environment for test-docs.
For a local test, you can do it in a fresh directory by replaying the commands executed by Readthedocs and written in their logs:
mkvirtualenv wax-ml-docs # A new virtualenv
mkdir wax-ml-docs # A new directory
cd wax-ml-docs
git clone --no-single-branch --depth 50 https://github.com/eserie/wax-ml
cd wax-ml-docs
git checkout --force origin/test-docs
git clean -d -f
workon wax-ml-docs
python -m pip install --upgrade --no-cache-dir pip
python -m pip install --upgrade --no-cache-dir -I Pygments==2.3.1 setuptools==41.0.1 docutils==0.14 mock==1.0.1 pillow==5.4.1 alabaster>=0.7,<0.8,!=0.7.5 commonmark==0.8.1 recommonmark==0.5.0 'sphinx<2' 'sphinx-rtd-theme<0.5' 'readthedocs-sphinx-ext<1.1'
python -m pip install --exists-action=w --no-cache-dir -r docs/requirements.txt
cd docs
python `which sphinx-build` -T -E -b html -d _build/doctrees-readthedocs -D language=en . _build/html
Public API: wax package¶
Subpackages¶
wax.modules package¶
Gym modules¶
In WAX-ML, an agent and environments are simple functions:
A Gym feedback loops can be represented with the diagram:
Equivalently, it can be described with the pair of init and apply functions:
Gym feedback between an agent and a Gym environment. |
Online Learning¶
WAX-ML contains a module to perform online learning for supervised problems.
Online supervised learner. |
Other Haiku modules¶
Implement buffering mechanism. |
|
Implement difference of values on sequential data. |
|
Compute exponentioal moving average. |
|
Compute exponentially weighted covariance. |
|
Compute exponentially weighted variance. |
|
Fill nan, posinf and neginf values. |
|
Detect if something has changed. |
|
Delay operator. |
|
Open-High-Low-Close binning. |
|
Relative change between the current and a prior element. |
|
Rolling mean. |
|
Apply a module when an event occur otherwise return last computed output. |
wax.gym package¶
|
Define accessors for xarray and pandas data containers. |
|
Compilation helper for Haiku Transformed and TransformedWithState pairs of pure funcions. |
|
Encoding schemes to encode/decode numpy data types non supported by JAX, e.g. |
|
Format nested data structures to numpy/xarray/pandas containers. |
|
Define Stream object used to synchronize in-memory data streams and unroll data transformations on it. |
|
Transformation functions to work on batches of data. |
|
Unroll modules on data along first axis. |
|
Some utils functions used in WAX-ML. |