Welcome to WAX-ML¶
WAX-ML is a library for machine-learning on streaming data.
For an introduction to WAX-ML, start at the WAX-ML GitHub page.
# Uncomment to run the notebook in Colab
# ! pip install -q "wax-ml[complete]@git+https://github.com/eserie/wax-ml.git"
# ! pip install -q --upgrade jax jaxlib==0.1.67+cuda111 -f https://storage.googleapis.com/jax-releases/jax_releases.html
# check available devices
import jax
print("jax backend {}".format(jax.lib.xla_bridge.get_backend().platform))
jax.devices()
WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
jax backend cpu
[<jaxlib.xla_extension.Device at 0x7ffb501244b0>]
〰 Compute exponential moving averages with xarray and pandas accessors 〰¶
WAX-ML implements pandas and xarray accessors to ease the usage of machine-learning algorithms with high-level data APIs:
- pandas's DataFrame and Series
- xarray's Dataset and DataArray
These accessors let you easily execute any function built from Haiku modules on these data containers.
For instance, WAX-ML provides an implementation of the exponential moving average realized with this mechanism.
Let’s show how it works.
Load accessors¶
First you need to load accessors:
from wax.accessors import register_wax_accessors
register_wax_accessors()
EWMA on dataframes¶
Let’s look at a simple example: The exponential moving average (EWMA).
Let’s apply the EWMA algorithm to the NCEP/NCAR air temperature data.
🌡 Load temperature dataset 🌡¶
import xarray as xr
da = xr.tutorial.open_dataset("air_temperature")
Let’s see what this dataset looks like:
da
<xarray.Dataset> Dimensions: (lat: 25, time: 2920, lon: 53) Coordinates: * lat (lat) float32 75.0 72.5 70.0 67.5 65.0 ... 25.0 22.5 20.0 17.5 15.0 * lon (lon) float32 200.0 202.5 205.0 207.5 ... 322.5 325.0 327.5 330.0 * time (time) datetime64[ns] 2013-01-01 ... 2014-12-31T18:00:00 Data variables: air (time, lat, lon) float32 ... Attributes: Conventions: COARDS title: 4x daily NMC reanalysis (1948) description: Data is from NMC initialized reanalysis\n(4x/day). These a... platform: Model references: http://www.esrl.noaa.gov/psd/data/gridded/data.ncep.reanaly...
To compute an EWMA on some variables of a dataset, we usually need to convert the data into a pandas Series or DataFrame.
So, let’s convert the dataset into a dataframe to illustrate the accessors on a dataframe:
dataframe = da.air.to_series().unstack(["lon", "lat"])
EWMA with pandas¶
%%time
air_temp_ewma = dataframe.ewm(alpha=1.0 / 10.0).mean()
_ = air_temp_ewma.iloc[:, 0].plot()
CPU times: user 406 ms, sys: 20.5 ms, total: 426 ms
Wall time: 426 ms

EWMA with WAX-ML¶
%%time
air_temp_ewma = dataframe.wax.ewm(alpha=1.0 / 10.0).mean()
_ = air_temp_ewma.iloc[:, 0].plot()
CPU times: user 1.78 s, sys: 452 ms, total: 2.23 s
Wall time: 2.22 s

On small data, WAX-ML’s EWMA is slower than pandas’ because of the expensive data-conversion steps. WAX-ML’s accessors become interesting on large workloads (see our three-steps_workflow).
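If the conversion back to pandas is what dominates, the accessor can keep its output in raw JAX format. A minimal sketch, reusing the dataframe built above; the format_outputs option is the same one demonstrated in the three-steps workflow below:
# Keep the result as a raw JAX array and skip the pandas formatting step.
ewma_raw = dataframe.wax.ewm(alpha=1.0 / 10.0, format_outputs=False).mean()
type(ewma_raw)  # a JAX array rather than a pandas DataFrame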
Apply a custom function to a Dataset¶
Now let’s illustrate how WAX-ML accessors work on xarray datasets.
from wax.modules import EWMA
def my_custom_function(da):
return {
"air_10": EWMA(1.0 / 10.0)(da["air"]),
"air_100": EWMA(1.0 / 100.0)(da["air"]),
}
da = xr.tutorial.open_dataset("air_temperature")
output, state = da.wax.stream().apply(my_custom_function, format_dims=da.air.dims)
_ = output.isel(lat=0, lon=0).drop(["lat", "lon"]).to_pandas().plot(figsize=(12, 8))

# Uncomment to run the notebook in Colab
# ! pip install -q "wax-ml[complete]@git+https://github.com/eserie/wax-ml.git"
# ! pip install -q --upgrade jax jaxlib==0.1.67+cuda111 -f https://storage.googleapis.com/jax-releases/jax_releases.html
# check available devices
import jax
print("jax backend {}".format(jax.lib.xla_bridge.get_backend().platform))
jax.devices()
WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
jax backend cpu
[<jaxlib.xla_extension.Device at 0x7f68731083f0>]
⏱ Synchronize data streams ⏱¶
Physicists, and not the least of them 😅, have brought a solution to the synchronization problem. See the Poincaré-Einstein synchronization Wikipedia page for more details.
In WAX-ML we strive to follow their recommendations and implement a synchronization mechanism between different data streams. Using the terminology of Henri Poincaré (see link above), we introduce the notion of “local time” to unravel the stream in which the user wants to apply transformations. We call the other streams “secondary streams”. They can work at different frequencies, lower or higher. The data from these secondary streams will be represented in the “local time” either with the use of a forward filling mechanism for lower frequencies or a buffering mechanism for higher frequencies.
We implement a “data tracing” mechanism to optimize access to out-of-sync streams. This mechanism works on in-memory data. We perform a first pass on the data, without actually accessing it, and determine the indices necessary to later access the data. In doing so, we are careful not to let any “future” information pass through, thus guaranteeing a data processing that respects causality.
The buffering mechanism used in the case of higher frequencies works with a fixed buffer size (see the WAX-ML module wax.modules.Buffer), which allows us to use JAX / XLA optimizations and have efficient processing.
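For a feel of how the Buffer module behaves on its own, here is a minimal, self-contained sketch on toy data. The toy DataFrame and the buffer size of 10 are illustrative choices; the accessor pattern is the same one used in the star-light notebook further below.
import numpy as onp
import pandas as pd
from wax.accessors import register_wax_accessors
from wax.modules import Buffer
register_wax_accessors()
# A toy data stream: 100 time steps, 3 series.
toy = pd.DataFrame(
    onp.random.normal(size=(100, 3)),
    index=pd.date_range("2020", periods=100, freq="h"),
)
# At each step, Buffer(10) exposes the last 10 observations;
# format_outputs=False keeps the output in raw JAX format.
buffer, _ = toy.wax.stream(format_outputs=False).apply(lambda x: Buffer(10)(x))
buffer.shape  # expected to be (time steps, buffer size, columns)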
Let’s illustrate with a small example how wax.stream.Stream synchronizes data streams.
Let’s use the “air temperature” dataset with:
- an air temperature defined at a 6-hourly resolution (4x daily),
- a “fake” ground temperature defined at a daily resolution as the air temperature minus 10 degrees.
import xarray as xr
da = xr.tutorial.open_dataset("air_temperature")
da["ground"] = da.air.resample(time="d").last().rename({"time": "day"}) - 10
Let’s see what this dataset looks like:
da
<xarray.Dataset> Dimensions: (lat: 25, time: 2920, lon: 53, day: 730) Coordinates: * lat (lat) float32 75.0 72.5 70.0 67.5 65.0 ... 25.0 22.5 20.0 17.5 15.0 * lon (lon) float32 200.0 202.5 205.0 207.5 ... 322.5 325.0 327.5 330.0 * time (time) datetime64[ns] 2013-01-01 ... 2014-12-31T18:00:00 * day (day) datetime64[ns] 2013-01-01 2013-01-02 ... 2014-12-31 Data variables: air (time, lat, lon) float32 ... ground (day, lat, lon) float32 231.9 231.8 231.8 ... 286.5 286.2 285.7 Attributes: Conventions: COARDS title: 4x daily NMC reanalysis (1948) description: Data is from NMC initialized reanalysis\n(4x/day). These a... platform: Model references: http://www.esrl.noaa.gov/psd/data/gridded/data.ncep.reanaly...
from wax.accessors import register_wax_accessors
register_wax_accessors()
from wax.modules import EWMA
def my_custom_function(da):
return {
"air_10": EWMA(1.0 / 10.0)(da["air"]),
"air_100": EWMA(1.0 / 100.0)(da["air"]),
"ground_100": EWMA(1.0 / 100.0)(da["ground"]),
}
results, state = da.wax.stream(local_time="time", pbar=True).apply(
my_custom_function, format_dims=da.air.dims
)
_ = results.isel(lat=0, lon=0).drop(["lat", "lon"]).to_pandas().plot(figsize=(12, 8))

# Uncomment to run the notebook in Colab
# ! pip install -q "wax-ml[complete]@git+https://github.com/eserie/wax-ml.git"
# ! pip install -q --upgrade jax jaxlib==0.1.67+cuda111 -f https://storage.googleapis.com/jax-releases/jax_releases.html
# check available devices
import jax
print("jax backend {}".format(jax.lib.xla_bridge.get_backend().platform))
jax.devices()
WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
jax backend cpu
[<jaxlib.xla_extension.Device at 0x7f8d889385b0>]
🌡 Binning temperatures 🌡¶
Let’s consider the air temperature dataset again. It is sampled at a 6-hourly resolution (4x daily). We will build “trailing” air temperature bins during each day and “reset” the bin aggregation process at each day change.
import numpy as onp
import xarray as xr
from wax.accessors import register_wax_accessors
from wax.modules import OHLC, HasChanged
register_wax_accessors()
da = xr.tutorial.open_dataset("air_temperature")
da["date"] = da.time.dt.date.astype(onp.datetime64)
def bin_temperature(da):
day_change = HasChanged()(da["date"])
return OHLC()(da["air"], reset_on=day_change)
output, state = da.wax.stream().apply(
bin_temperature, format_dims=onp.array(da.air.dims)
)
output = xr.Dataset(output._asdict())
df = output.isel(lat=0, lon=0).drop(["lat", "lon"]).to_pandas().loc["2013-01"]
_ = df.plot(figsize=(12, 8))

The UpdateOnEvent module¶
The OHLC module uses the primitive wax.modules.UpdateOnEvent.
Its implementation required completing Haiku with a central function, set_params_or_state_dict, which we have for now integrated directly in this WAX-ML module.
We have opened an issue on the Haiku GitHub to integrate it into Haiku.
# Uncomment to run the notebook in Colab
# ! pip install -q "wax-ml[complete]@git+https://github.com/eserie/wax-ml.git"
# ! pip install -q --upgrade jax jaxlib==0.1.67+cuda111 -f https://storage.googleapis.com/jax-releases/jax_releases.html
# check available devices
import jax
print("jax backend {}".format(jax.lib.xla_bridge.get_backend().platform))
jax.devices()
jax backend gpu
[GpuDevice(id=0, process_index=0)]
🎛 The 3-steps workflow 🎛¶
It is already very useful to be able to execute a JAX function on a dataframe in a single step, with a single command line, thanks to the WAX-ML accessors.
The 1-step WAX-ML stream API works as follows:
<data-container>.stream(...).apply(...)
But this is not optimal because, under the hood, there are mainly three costly steps:
(1) (synchronize | data tracing | encode): make the data “JAX ready”
(2) (compile | code tracing | execution): compile and optimize a function for XLA, execute it.
(3) (format): convert data back to pandas/xarray/numpy format.
With the wax.stream primitives, it is quite easy to explicitly split the 1-step workflow into a 3-step workflow.
This gives the user full control over each step and lets them iterate on each one.
It is especially useful to iterate on step (2), the “calculation step”, when you are doing research.
You can then take full advantage of the JAX primitives, especially the jit primitive.
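In code, the three steps take roughly the following shape. This is only a condensed sketch using the same helpers that are detailed in the walk-through below; the toy data and the my_ewma function are illustrative, not part of the walk-through itself.
import jax
import numpy as onp
import pandas as pd
import xarray as xr
from wax.accessors import register_wax_accessors
from wax.format import format_dataframe
from wax.modules import EWMA
from wax.unroll import dynamic_unroll, init_params_state
register_wax_accessors()
# Toy data: 1000 time steps, 3 columns (illustrative only).
toy_dataframe = pd.DataFrame(
    onp.random.normal(size=(1000, 3)),
    index=pd.date_range("1970", periods=1000, freq="s"),
)
toy_dataset = xr.DataArray(toy_dataframe).to_dataset(name="dataarray")
def my_ewma(dataset):
    return EWMA(alpha=1.0 / 10.0, adjust=True)(dataset["dataarray"])
# Step (1): synchronize | data tracing | encode -> a pair of pure functions and the JAX steps.
stream = toy_dataframe.wax.stream()
transform_dataset, jxs = stream.prepare(toy_dataset, my_ewma)
# Step (2): compile | code tracing | execution with JAX.
rng = jax.random.PRNGKey(42)
params, state = init_params_state(transform_dataset, rng, jxs)
outputs, state = dynamic_unroll(transform_dataset, params, state, rng, False, jxs)
# Step (3): format back to pandas.
result = format_dataframe(
    toy_dataset.coords, onp.array(outputs), format_dims=toy_dataset.dataarray.dims
)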
Let’s illustrate how to reimplement WAX-ML EWMA yourself with the WAX-ML 3-step workflow.
Imports¶
import haiku as hk
import numpy as onp
import pandas as pd
import xarray as xr
from wax.accessors import register_wax_accessors
from wax.compile import jit_init_apply
from wax.external.eagerpy import convert_to_tensors
from wax.format import format_dataframe
from wax.modules import EWMA
from wax.stream import tree_access_data
from wax.unroll import dynamic_unroll
register_wax_accessors()
Performance on big dataframes¶
Generate data¶
T = 1.0e5
N = 1000
%%time
T, N = map(int, (T, N))
dataframe = pd.DataFrame(
onp.random.normal(size=(T, N)), index=pd.date_range("1970", periods=T, freq="s")
)
CPU times: user 3.87 s, sys: 143 ms, total: 4.01 s
Wall time: 3.99 s
Pandas EWMA¶
%%time
df_ewma_pandas = dataframe.ewm(alpha=1.0 / 10.0).mean()
CPU times: user 1.95 s, sys: 480 ms, total: 2.42 s
Wall time: 2.42 s
WAX-ML EWMA¶
%%time
df_ewma_wax = dataframe.wax.ewm(alpha=1.0 / 10.0).mean()
CPU times: user 2.21 s, sys: 679 ms, total: 2.89 s
Wall time: 3.11 s
Not much of an improvement over pandas so far…
WAX-ML EWMA (without format step)¶
Let’s disable the final formatting step (the output is now in raw JAX format):
%%time
df_ewma_wax_no_format = dataframe.wax.ewm(alpha=1.0 / 10.0, format_outputs=False).mean()
CPU times: user 1.77 s, sys: 563 ms, total: 2.33 s
Wall time: 2.22 s
type(df_ewma_wax_no_format)
jaxlib.xla_extension.DeviceArray
Let’s check the device on which the calculation was performed (if you have a GPU available, this should be GpuDevice, otherwise CpuDevice):
df_ewma_wax_no_format.device()
GpuDevice(id=0, process_index=0)
That’s better! In fact (see below) there is a performance problem in the final formatting step. See WEP3 for a proposal to improve the formatting step.
Generate data (in dataset format)¶
The WAX-ML Stream object works on datasets.
So let’s transform the DataFrame into an xarray Dataset:
dataset = xr.DataArray(dataframe).to_dataset(name="dataarray")
Step (1) (synchronize | data tracing | encode)¶
In this step, WAX-ML does:
- “data tracing”: prepare the indices for fast access to the data in the JAX function access_data
- synchronize streams if there are multiple ones. This functionality has the options freq and ffills
- encode and convert data from numpy to JAX: use encoders for datetime64 and string_ dtypes. Be aware that by default JAX works in float32 (see JAX’s Common Gotchas to work in float64).
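If you do need float64 precision, JAX must be told so before any array is created. This is plain JAX configuration, not a WAX-ML API; a minimal sketch, equivalent to setting the JAX_ENABLE_X64 environment variable mentioned in the warning further below:
from jax.config import config
# Enable 64-bit types globally; do this before creating any JAX array.
config.update("jax_enable_x64", True)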
We have a function Stream.prepare that implements this step (1).
It prepares a function that wraps the input function with the actual data and indices in a pair of pure functions (a TransformedWithState Haiku tuple).
%%time
stream = dataframe.wax.stream()
CPU times: user 20 µs, sys: 5 µs, total: 25 µs
Wall time: 29.8 µs
Define our custom function to be applied on a dict of arrays having the same structure as the original dataset:
def my_ewma_on_dataset(dataset):
return EWMA(alpha=1.0 / 10.0, adjust=True)(dataset["dataarray"])
transform_dataset, jxs = stream.prepare(dataset, my_ewma_on_dataset)
Let’s define the init parameters and state of the transformation we will apply.
Init params and state¶
from wax.unroll import init_params_state
rng = jax.random.PRNGKey(42)
params, state = init_params_state(transform_dataset, rng, jxs)
params
FlatMapping({'ewma': FlatMapping({'alpha': DeviceArray(0.1, dtype=float32)})})
assert state["ewma"]["count"].shape == (N,)
assert state["ewma"]["mean"].shape == (N,)
Step (2) (compile | code tracing | execution)¶
In this step we:
- prepare a pure function (with Haiku’s transform mechanism): define a “transformation” function which accesses the data and applies another transformation, here: EWMA
- compile it with jax.jit
- perform code tracing and execution (the last line): unroll the transformation on the “steps” xs (an np.arange vector).
rng = next(hk.PRNGSequence(42))
outputs, state = dynamic_unroll(transform_dataset, params, state, rng, False, jxs)
outputs.device()
GpuDevice(id=0, process_index=0)
Once it has been compiled and “traced” by JAX, the function is much faster to execute:
%%timeit
outputs, _ = dynamic_unroll(transform_dataset, params, state, rng, False, jxs)
1 loop, best of 5: 1.18 s per loop
%%time
outputs, _ = dynamic_unroll(transform_dataset, params, state, rng, False, jxs)
CPU times: user 1.29 s, sys: 6.21 ms, total: 1.29 s
Wall time: 1.24 s
This is 3x faster than the pandas implementation!
(The 3x factor is obtained by measuring the execution with %timeit. We don’t know exactly why, but when executing a code cell one run at a time, the measured time can vary a lot, and we sometimes observe executions with a speed-up of 100x.)
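One possible contributor to this variability is JAX’s asynchronous dispatch: a call can return before the device computation has finished, so a single-run wall time may not reflect the actual work. A minimal sketch of a fairer single-run measurement, blocking on the result (it reuses the objects defined above):
%%time
outputs, _ = dynamic_unroll(transform_dataset, params, state, rng, False, jxs)
_ = outputs.block_until_ready()  # wait for the device computation to complete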
Manually prepare the data and manage the device¶
In order to manage the device on which the computations take place, we need even more control over the execution flow.
Instead of calling stream.prepare to build the transform_dataset function, we can do it ourselves by:
- using the stream.trace_dataset function
- converting the numpy data to JAX ourselves
- putting the data on the device we want.
np_data, np_index, xs = stream.trace_dataset(dataset)
jnp_data, jnp_index, jxs = convert_to_tensors((np_data, np_index, xs), "jax")
/usr/local/lib/python3.7/dist-packages/jax/_src/numpy/lax_numpy.py:2983: UserWarning: Explicitly requested dtype float64 requested in asarray is not available, and will be truncated to dtype float32. To enable more dtypes, set the jax_enable_x64 configuration option or the JAX_ENABLE_X64 shell environment variable. See https://github.com/google/jax#current-gotchas for more.
lax._check_user_dtype_supported(dtype, "asarray")
/usr/local/lib/python3.7/dist-packages/jax/_src/numpy/lax_numpy.py:2983: UserWarning: Explicitly requested dtype int64 requested in asarray is not available, and will be truncated to dtype int32. To enable more dtypes, set the jax_enable_x64 configuration option or the JAX_ENABLE_X64 shell environment variable. See https://github.com/google/jax#current-gotchas for more.
lax._check_user_dtype_supported(dtype, "asarray")
We explicitly put the data on the CPU (this is not needed if you only have CPUs):
from jax.tree_util import tree_leaves, tree_map
cpus = jax.devices("cpu")
jnp_data, jnp_index, jxs = tree_map(
lambda x: jax.device_put(x, cpus[0]), (jnp_data, jnp_index, jxs)
)
print("data copied to CPU device.")
data copied to CPU device.
We now have “JAX-ready” data for later fast access.
Let’s define the transformation that wraps the actual data and indices in a pair of pure functions:
%%time
@jit_init_apply
@hk.transform_with_state
def transform_dataset(step):
dataset = tree_access_data(jnp_data, jnp_index, step)
return EWMA(alpha=1.0 / 10.0, adjust=True)(dataset["dataarray"])
CPU times: user 277 µs, sys: 0 ns, total: 277 µs
Wall time: 289 µs
And we can call it as before:
%%time
outputs, state = dynamic_unroll(transform_dataset, None, None, rng, False, jxs)
CPU times: user 373 ms, sys: 163 ms, total: 537 ms
Wall time: 533 ms
outputs.device()
CpuDevice(id=0)
Step (3) (format)¶
Let’s come back to pandas/xarray:
%%time
y = format_dataframe(
dataset.coords, onp.array(outputs), format_dims=dataset.dataarray.dims
)
CPU times: user 2.14 s, sys: 9.66 ms, total: 2.15 s
Wall time: 2.13 s
It’s quite slow (see WEP3 enhancement proposal).
GPU execution¶
Let’s look at execution on a GPU:
try:
gpus = jax.devices("gpu")
jnp_data, jnp_index, jxs = tree_map(
lambda x: jax.device_put(x, gpus[0]), (jnp_data, jnp_index, jxs)
)
print("data copied to GPU device.")
GPU_AVAILABLE = True
except RuntimeError as err:
print(err)
GPU_AVAILABLE = False
data copied to GPU device.
Let’s check that our data is on the GPU:
tree_leaves(jnp_data)[0].device()
GpuDevice(id=0, process_index=0)
tree_leaves(jnp_index)[0].device()
GpuDevice(id=0, process_index=0)
jxs.device()
GpuDevice(id=0, process_index=0)
%%time
if GPU_AVAILABLE:
rng = next(hk.PRNGSequence(42))
outputs, state = dynamic_unroll(transform_dataset, None, None, rng, False, jxs)
CPU times: user 1.7 s, sys: 555 ms, total: 2.25 s
Wall time: 2.14 s
Let’s redefine our function transform_dataset by explicitly specifying the device option of jax.jit.
%%time
if GPU_AVAILABLE:
@hk.transform_with_state
def transform_dataset(step):
dataset = tree_access_data(jnp_data, jnp_index, step)
return EWMA(alpha=1.0 / 10.0, adjust=True)(dataset["dataarray"])
transform_dataset = type(transform_dataset)(
transform_dataset.init, jax.jit(transform_dataset.apply, device=gpus[0])
)
rng = next(hk.PRNGSequence(42))
outputs, state = dynamic_unroll(transform_dataset, None, None, rng, False, jxs)
CPU times: user 1.48 s, sys: 17.6 ms, total: 1.5 s
Wall time: 2.13 s
outputs.device()
GpuDevice(id=0, process_index=0)
%%timeit
if GPU_AVAILABLE:
outputs, state = dynamic_unroll(transform_dataset, None, None, rng, False, jxs)
1 loop, best of 5: 1.18 s per loop
# Uncomment to run the notebook in Colab
# ! pip install -q "wax-ml[complete]@git+https://github.com/eserie/wax-ml.git"
# ! pip install -q --upgrade jax jaxlib==0.1.67+cuda111 -f https://storage.googleapis.com/jax-releases/jax_releases.html
# check available devices
import jax
print("jax backend {}".format(jax.lib.xla_bridge.get_backend().platform))
jax.devices()
🔭 Reconstructing the light curve of stars with LSTM 🔭¶
Let’s take a walk through the stars…
This notebook is based on the study done in this post by Christophe Pere and the notebook available on the author’s GitHub.
We will repeat this study on starlight using the LSTM architecture to predict the observed light flux through time.
Our LSTM implementation is based on this notebook from Haiku’s GitHub repository.
We’ll see how to use WAX-ML to ease the preparation of time-series data stored in dataframes containing NaNs before calling a “standard” deep-learning workflow.
Disclaimer¶
Although this code works with real data, the results presented here should not be considered scientific insights: to the knowledge of the authors of WAX-ML, neither the results nor the data source have been reviewed by an astrophysics peer.
The purpose of this notebook is only to demonstrate how WAX-ML can be used when applying a “standard” machine learning workflow, here LSTM, to analyze time series.
Download the data¶
%matplotlib inline
import io
from pathlib import Path
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import requests
from sklearn.preprocessing import MinMaxScaler
from tqdm.auto import tqdm
from wax.accessors import register_wax_accessors
from wax.modules import RollingMean
register_wax_accessors()
# Parameters
STAR = "007609553"
SEQ_LEN = 64
BATCH_SIZE = 8
TRAIN_STEPS = 2 ** 16
TRAIN_SIZE = 2 ** 16
TOTAL_LEN = None
TRAIN_DATE = "2016"
CACHE_DIR = Path("./cached_data/")
%%time
filename = CACHE_DIR / "kep_lightcurves.parquet"
try:
raw_dataframe = pd.read_parquet(open(filename, "rb"))
print(f"data read from {filename}")
except FileNotFoundError:
# Downloading the csv file from Christophe Pere's GitHub account
download = requests.get(
"https://raw.github.com/Christophe-pere/Time_series_RNN/master/kep_lightcurves.csv"
).content
raw_dataframe = pd.read_csv(io.StringIO(download.decode("utf-8")))
# set date index
raw_dataframe.index = pd.Index(
pd.date_range("2009-03-07", periods=len(raw_dataframe.index), freq="h"),
name="time",
)
# save dataframe locally in CACHE_DIR
CACHE_DIR.mkdir(exist_ok=True)
raw_dataframe.to_parquet(filename)
print(f"data saved in {filename}")
data saved in cached_data/kep_lightcurves.parquet
CPU times: user 1.2 s, sys: 197 ms, total: 1.39 s
Wall time: 3.73 s
# shortening of data to speed up the execution of the notebook in the CI
if TOTAL_LEN:
raw_dataframe = raw_dataframe.iloc[:TOTAL_LEN]
Let’s visualize the description of this dataset:
raw_dataframe.describe().T.to_xarray()
<xarray.Dataset> Dimensions: (index: 52) Coordinates: * index (index) object '001430305_orig' ... '011611275_res' Data variables: count (index) float64 6.48e+04 5.674e+04 ... 5.673e+04 5.673e+04 mean (index) float64 6.776e+04 -0.2265 0.01231 ... 0.001437 0.004351 std (index) float64 1.363e+03 15.42 15.27 12.45 ... 4.648 6.415 4.904 min (index) float64 6.529e+04 -123.3 -75.59 ... -20.32 -31.97 -20.89 25% (index) float64 6.619e+04 -9.488 -9.875 ... -3.269 -4.281 -3.279 50% (index) float64 6.806e+04 -0.3476 0.007812 ... 0.007812 -0.06529 75% (index) float64 6.882e+04 8.988 10.02 8.092 ... 2.872 4.277 3.213 max (index) float64 7.021e+04 128.7 72.31 69.34 ... 26.53 30.94 29.45
stars = raw_dataframe.columns
stars = sorted(list(set([i.split("_")[0] for i in stars])))
print(f"The number of stars available is: {len(stars)}")
print(f"star identifiers: {stars}")
The number of stars available is: 13
star identifiers: ['001430305', '001724719', '005209845', '007596240', '007609553', '008241079', '008247770', '009345933', '009347009', '009349482', '009349757', '010024701', '011611275']
dataframe = raw_dataframe[[i + "_rscl" for i in stars]].rename(
columns=lambda c: c.replace("_rscl", "")
)
dataframe.columns.names = ["star"]
dataframe.shape
(71427, 13)
Rolling mean¶
We will smooth the data by applying a rolling mean with a window of 100 periods.
Count nan values¶
But first, since the dataset has some NaN values, we will extract a few statistics about the density of NaN values in windows of size 100.
This will be the occasion to show a usage of the wax.modules.Buffer module with the format_outputs=False option of the dataframe accessor .wax.stream.
import jax.numpy as jnp
import numpy as onp
from wax.modules import Buffer
Let’s apply the Buffer module to the data:
buffer, _ = dataframe.wax.stream(format_outputs=False).apply(lambda x: Buffer(100)(x))
WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
assert isinstance(buffer, jnp.ndarray)
Let’s describe the statistics of the NaNs with pandas:
count_nan = jnp.isnan(buffer).sum(axis=1)
pd.DataFrame(onp.array(count_nan)).stack().describe().astype(int)
count 928551
mean 20
std 27
min 0
25% 5
50% 8
75% 19
max 100
dtype: int64
Computing the rolling mean¶
We will choose a min_periods of 5 in order to keep at least 75% of the points.
%%time
dataframe_mean, _ = dataframe.wax.stream().apply(
lambda x: RollingMean(100, min_periods=5)(x)
)
CPU times: user 446 ms, sys: 11.5 ms, total: 458 ms
Wall time: 456 ms
dataframe.loc[:, "008241079"].plot()
dataframe_mean.loc[:, "008241079"].plot()
<AxesSubplot:xlabel='time'>

With Dataset API¶
Let’s illustrate how to do the same rolling-mean operation, but using the WAX-ML accessors on an xarray Dataset.
from functools import partial
from jax.tree_util import tree_map
dataset = dataframe.to_xarray()
dataset
<xarray.Dataset> Dimensions: (time: 71427) Coordinates: * time (time) datetime64[ns] 2009-03-07 ... 2017-04-30T02:00:00 Data variables: (12/13) 001430305 (time) float64 -4.943 2.338 nan -0.9275 ... 13.92 7.728 nan 1.33 001724719 (time) float64 -9.95 -19.69 -6.298 nan ... 7.535 nan 0.8825 005209845 (time) float64 nan nan nan nan ... -5.528 -6.864 -25.47 -25.75 007596240 (time) float64 -1.353 -1.534 -9.497 -3.48 ... 4.234 -3.83 -0.6448 007609553 (time) float64 38.36 19.7 19.08 27.18 ... 172.2 178.2 163.8 169.9 008241079 (time) float64 -22.68 -12.27 nan -10.34 ... 3.236 8.97 18.74 ... ... 009345933 (time) float64 -9.756 -9.812 -8.808 ... 5.486 -10.21 -1.196 009347009 (time) float64 2.219 -3.694 -3.056 -3.843 ... -3.58 11.99 -1.917 009349482 (time) float64 -4.975 -11.5 -7.711 -9.017 ... 2.338 1.825 0.9793 009349757 (time) float64 nan -16.64 -20.52 -15.6 ... nan 19.15 17.15 18.18 010024701 (time) float64 -45.78 -54.53 -35.46 ... -292.7 -301.2 -283.9 011611275 (time) float64 -1.308 -4.728 -5.136 2.284 ... 12.73 -6.223 -8.024
%%time
dataset_mean, _ = dataset.wax.stream().apply(
partial(tree_map, lambda x: RollingMean(100, min_periods=5)(x)),
format_dims=["time"],
)
CPU times: user 5.53 s, sys: 132 ms, total: 5.67 s
Wall time: 5.67 s
(It’s much slower than with the dataframe.)
TODO: This is an issue that we should solve.
dataset_mean
<xarray.Dataset> Dimensions: (time: 71427) Coordinates: * time (time) datetime64[ns] 2009-03-07 ... 2017-04-30T02:00:00 Data variables: (12/13) 001430305 (time) float32 nan nan nan nan ... -0.1169 0.1599 0.2577 0.3024 001724719 (time) float32 nan nan nan nan ... -6.384 -6.214 -6.223 -6.147 005209845 (time) float32 nan nan nan nan ... -9.909 -9.939 -10.13 -10.29 007596240 (time) float32 nan nan nan nan nan ... 1.165 1.19 1.14 1.134 007609553 (time) float32 nan nan nan nan 25.99 ... 145.5 146.0 146.4 146.9 008241079 (time) float32 nan nan nan nan nan ... 13.57 13.57 13.4 13.47 ... ... 009345933 (time) float32 nan nan nan nan ... -14.8 -14.53 -14.48 -14.31 009347009 (time) float32 nan nan nan nan ... -3.367 -3.462 -3.263 -3.25 009349482 (time) float32 nan nan nan nan -8.398 ... 1.861 1.858 1.817 1.825 009349757 (time) float32 nan nan nan nan nan ... 10.3 10.61 10.91 11.08 010024701 (time) float32 nan nan nan nan ... -322.8 -323.0 -322.8 -322.9 011611275 (time) float32 nan nan nan nan ... -4.214 -4.037 -4.106 -4.192
dataset["008241079"].plot()
dataset_mean["008241079"].plot()
[<matplotlib.lines.Line2D at 0x7fad38d28b70>]

With dataarray¶
dataarray = dataframe.to_xarray().to_array("star").transpose("time", "star")
dataarray
<xarray.DataArray (time: 71427, star: 13)> array([[ -4.94312846, -9.94989465, nan, ..., nan, -45.78391029, -1.30840132], [ 2.33812154, -19.6881759 , nan, ..., -16.6422077 , -54.53000404, -4.7283232 ], [ nan, -6.2975509 , nan, ..., -20.52306708, -35.45969154, -5.13554976], ..., [ 7.7280167 , 7.53484594, -6.86435205, ..., 19.15124702, -292.7287725 , 12.73319537], [ nan, nan, -25.47275049, ..., 17.15027046, -301.15455375, -6.22285932], [ 1.3295792 , 0.88250219, -25.74862939, ..., 18.18347358, -283.88111625, -8.02364057]]) Coordinates: * time (time) datetime64[ns] 2009-03-07 ... 2017-04-30T02:00:00 * star (star) <U9 '001430305' '001724719' ... '010024701' '011611275'
%%time
dataarray_mean, _ = dataarray.wax.stream().apply(
lambda x: RollingMean(100, min_periods=5)(x)
)
CPU times: user 426 ms, sys: 7.23 ms, total: 433 ms
Wall time: 431 ms
(This is about as fast as with the dataframe.)
dataarray_mean
<xarray.DataArray (time: 71427, star: 13)> array([[ nan, nan, nan, ..., nan, nan, nan], [ nan, nan, nan, ..., nan, nan, nan], [ nan, nan, nan, ..., nan, nan, nan], ..., [ 1.5992440e-01, -6.2136836e+00, -9.9386721e+00, ..., 1.0611248e+01, -3.2304874e+02, -4.0370173e+00], [ 2.5768742e-01, -6.2231779e+00, -1.0134243e+01, ..., 1.0905440e+01, -3.2282303e+02, -4.1056514e+00], [ 3.0240160e-01, -6.1467724e+00, -1.0286570e+01, ..., 1.1077212e+01, -3.2285059e+02, -4.1921792e+00]], dtype=float32) Coordinates: * time (time) datetime64[ns] 2009-03-07 ... 2017-04-30T02:00:00 * star (star) <U9 '001430305' '001724719' ... '010024701' '011611275'
dataarray.sel(star="008241079").plot()
dataarray_mean.sel(star="008241079").plot()
[<matplotlib.lines.Line2D at 0x7fadbac4e518>]

Forecasting with Machine Learning¶
We need to forecast in this data: if you look carefully, you will see small holes and big holes.
import warnings
from typing import NamedTuple, Tuple, TypeVar
import haiku as hk
import jax
import jax.numpy as jnp
import numpy as np
import optax
import pandas as pd
import plotnine as gg
T = TypeVar("T")
Pair = Tuple[T, T]
class Pair(NamedTuple):
x: T
y: T
class TrainSplit(NamedTuple):
train: T
validation: T
gg.theme_set(gg.theme_bw())
warnings.filterwarnings("ignore")
import matplotlib.pyplot as plt
plt.rcParams["figure.figsize"] = 18, 8
fig, (ax, lax) = plt.subplots(ncols=2, gridspec_kw={"width_ratios": [4, 1]})
dataframe.plot(ax=ax, title="raw data")
ax.legend(bbox_to_anchor=(0, 0, 1, 1), bbox_transform=lax.transAxes)
lax.axis("off")
(0.0, 1.0, 0.0, 1.0)

import matplotlib.pyplot as plt
plt.rcParams["figure.figsize"] = 18, 8
fig, (ax, lax) = plt.subplots(ncols=2, gridspec_kw={"width_ratios": [4, 1]})
dataframe_mean.plot(ax=ax, title="Smoothed data")
ax.legend(bbox_to_anchor=(0, 0, 1, 1), bbox_transform=lax.transAxes)
lax.axis("off")
(0.0, 1.0, 0.0, 1.0)

Normalize data¶
dataframe_mean.stack().hist(bins=100)
<AxesSubplot:>

from wax.encode import Encoder
def min_max_scaler(values: pd.DataFrame, output_format: str = "dataframe") -> Encoder:
scaler = MinMaxScaler(feature_range=(0, 1))
scaler.fit(values)
index = values.index
columns = values.columns
def encode(dataframe: pd.DataFrame):
nonlocal index
nonlocal columns
index = dataframe.index
columns = dataframe.columns
array_normed = scaler.transform(dataframe)
if output_format == "dataframe":
return pd.DataFrame(array_normed, index, columns)
elif output_format == "jax":
return jnp.array(array_normed)
else:
return array_normed
def decode(array_scaled):
value = scaler.inverse_transform(array_scaled)
if output_format == "dataframe":
return pd.DataFrame(value, index, columns)
else:
return value
return Encoder(encode, decode)
scaler = min_max_scaler(dataframe_mean)
dataframe_normed = scaler.encode(dataframe_mean)
assert (scaler.decode(dataframe_normed) - dataframe_mean).stack().abs().max() < 1.0e-4
dataframe_normed.stack().hist(bins=100)
<AxesSubplot:>

Prepare train / validation datasets¶
from wax.modules import FillNanInf, Lag
def split_feature_target(dataframe, look_back=SEQ_LEN) -> Pair:
x, _ = dataframe.wax.stream(format_outputs=False).apply(
lambda x: FillNanInf()(Lag(1)(Buffer(look_back)(x)))
)
B, T, F = x.shape
x = x.transpose(1, 0, 2)
y, _ = dataframe.wax.stream(format_outputs=False).apply(
lambda x: FillNanInf()(Buffer(look_back)(x))
)
y = y.transpose(1, 0, 2)
return Pair(x, y)
def split_feature_target(
dataframe,
look_back=SEQ_LEN,
stack=True,
shuffle=False,
min_periods_ratio: float = 0.8,
) -> Pair:
x, _ = dataframe.wax.stream(format_outputs=False).apply(
lambda x: Lag(1)(Buffer(look_back)(x))
)
x = x.transpose(1, 0, 2)
y, _ = dataframe.wax.stream(format_outputs=False).apply(
lambda x: Buffer(look_back)(x)
)
y = y.transpose(1, 0, 2)
T, B, F = x.shape
if stack:
x = x.reshape(T, B * F, 1)
y = y.reshape(T, B * F, 1)
if shuffle:
rng = jax.random.PRNGKey(42)
idx = jnp.arange(x.shape[1])
idx = jax.random.shuffle(rng, idx)
x = x[:, idx]
y = y[:, idx]
if min_periods_ratio:
count_nan = jnp.isnan(x).sum(axis=0)
mask = count_nan < min_periods_ratio * T
idx = jnp.where(mask)
# print("count_nan = ", count_nan)
# print("B = ", B)
x = x[:, idx[0], :]
y = y[:, idx[0], :]
T, B, F = x.shape
# print("B = ", B)
# round batch size to a power of two
B_round = int(2 ** jnp.floor(jnp.log2(B)))
x = x[:, :B_round, :]
y = y[:, :B_round, :]
# fillnan by zeros
fill_nan_inf = hk.transform(lambda x: FillNanInf()(x))
params = fill_nan_inf.init(None, jnp.full(x.shape, jnp.nan, x.dtype))
x = fill_nan_inf.apply(params, None, x)
y = fill_nan_inf.apply(params, None, y)
return Pair(x, y)
def split_train_validation(dataframe, stars, train_size, look_back) -> TrainSplit:
# prepare scaler
dataframe_train = dataframe[stars].iloc[:train_size]
scaler = min_max_scaler(dataframe_train)
# prepare train data
dataframe_train_normed = scaler.encode(dataframe_train)
train = split_feature_target(dataframe_train_normed, look_back)
# prepare validation data
valid_size = len(dataframe[stars]) - train_size
valid_size = int(2 ** jnp.floor(jnp.log2(valid_size)))
valid_end = int(train_size + valid_size)
dataframe_valid = dataframe[stars].iloc[train_size:valid_end]
dataframe_valid_normed = scaler.encode(dataframe_valid)
valid = split_feature_target(dataframe_valid_normed, look_back)
return TrainSplit(train, valid)
print(f"Look at star: {STAR}")
train, valid = split_train_validation(dataframe_normed, [STAR], TRAIN_SIZE, SEQ_LEN)
Look at star: 007609553
train[0].shape, train[1].shape, valid[0].shape, valid[1].shape
((64, 32768, 1), (64, 32768, 1), (64, 2048, 1), (64, 2048, 1))
TRAIN_SIZE, VALID_SIZE = len(train.x), len(valid.x)
seq = hk.PRNGSequence(42)
# Plot an observation/target pair.
batch_plot = jax.random.choice(next(seq), len(train[0]))
df = pd.DataFrame(
{"x": train[0][:, batch_plot, 0], "y": train[1][:, batch_plot, 0]}
).reset_index()
df = pd.melt(df, id_vars=["index"], value_vars=["x", "y"])
plot = (
gg.ggplot(df)
+ gg.aes(x="index", y="value", color="variable")
+ gg.geom_line()
+ gg.scales.scale_y_log10()
)
_ = plot.draw()

Dataset iterator¶
class Dataset:
"""An iterator over a numpy array, revealing batch_size elements at a time."""
def __init__(self, xy: Pair, batch_size: int):
self._x, self._y = xy
self._batch_size = batch_size
self._length = self._x.shape[1]
self._idx = 0
if self._length % batch_size != 0:
msg = "dataset size {} must be divisible by batch_size {}."
raise ValueError(msg.format(self._length, batch_size))
def __next__(self) -> Pair:
start = self._idx
end = start + self._batch_size
x, y = self._x[:, start:end], self._y[:, start:end]
if end >= self._length:
end = end % self._length
assert end == 0 # Guaranteed by the divisibility check in the constructor.
self._idx = end
return x, y
train_ds = Dataset(train, BATCH_SIZE)
valid_ds = Dataset(valid, BATCH_SIZE)
del train, valid # Don't leak temporaries.
Training an LSTM¶
To train the LSTM, we define a Haiku function which unrolls the LSTM over the input sequence, generating predictions for all output values. The LSTM always starts with its initial state at the start of the sequence.
The Haiku function is then transformed into a pure function through hk.transform
, and is trained with Adam on an L2 prediction loss.
from wax.compile import jit_init_apply
x, y = next(train_ds)
x.shape, y.shape
((64, 8, 1), (64, 8, 1))
from collections import defaultdict
def unroll_net(seqs: jnp.ndarray):
"""Unrolls an LSTM over seqs, mapping each output to a scalar."""
# seqs is [T, B, F].
core = hk.LSTM(32)
batch_size = seqs.shape[1]
outs, state = hk.dynamic_unroll(core, seqs, core.initial_state(batch_size))
# We could include this Linear as part of the recurrent core!
# However, it's more efficient on modern accelerators to run the linear once
# over the entire sequence than once per sequence element.
return hk.BatchApply(hk.Linear(1))(outs), state
model = jit_init_apply(hk.transform(unroll_net))
def train_model(
train_ds: Dataset, valid_ds: Dataset, max_iterations: int = -1
) -> hk.Params:
"""Initializes and trains a model on train_ds, returning the final params."""
rng = jax.random.PRNGKey(428)
opt = optax.adam(1e-3)
@jax.jit
def loss(params, x, y):
pred, _ = model.apply(params, None, x)
return jnp.mean(jnp.square(pred - y))
@jax.jit
def update(step, params, opt_state, x, y):
l, grads = jax.value_and_grad(loss)(params, x, y)
grads, opt_state = opt.update(grads, opt_state)
params = optax.apply_updates(params, grads)
return l, params, opt_state
# Initialize state.
sample_x, _ = next(train_ds)
params = model.init(rng, sample_x)
opt_state = opt.init(params)
step = 0
records = defaultdict(list)
def _format_results(records):
records = {key: jnp.stack(l) for key, l in records.items()}
return records
with tqdm() as pbar:
while True:
if step % 100 == 0:
x, y = next(valid_ds)
valid_loss = loss(params, x, y)
# print("Step {}: valid loss {}".format(step, valid_loss))
records["step"].append(step)
records["valid_loss"].append(valid_loss)
try:
x, y = next(train_ds)
except StopIteration:
return params, _format_results(records)
train_loss, params, opt_state = update(step, params, opt_state, x, y)
if step % 100 == 0:
# print("Step {}: train loss {}".format(step, train_loss))
records["train_loss"].append(train_loss)
step += 1
pbar.update()
if max_iterations > 0 and step >= max_iterations:
return params, _format_results(records)
%%time
trained_params, records = train_model(train_ds, valid_ds, TRAIN_STEPS)
CPU times: user 2min 36s, sys: 6.9 s, total: 2min 42s
Wall time: 1min 23s
# Plot losses
losses = pd.DataFrame(records)
df = pd.melt(losses, id_vars=["step"], value_vars=["train_loss", "valid_loss"])
plot = (
gg.ggplot(df)
+ gg.aes(x="step", y="value", color="variable")
+ gg.geom_line()
+ gg.scales.scale_y_log10()
)
_ = plot.draw()

Sampling¶
The point of training models is so that they can make predictions! How can we generate predictions with the trained model?
If we’re allowed to feed in the ground truth, we can just run the original model’s apply
function.
def plot_samples(truth: np.ndarray, prediction: np.ndarray) -> gg.ggplot:
assert truth.shape == prediction.shape
df = pd.DataFrame(
{"truth": truth.squeeze(), "predicted": prediction.squeeze()}
).reset_index()
df = pd.melt(df, id_vars=["index"], value_vars=["truth", "predicted"])
plot = (
gg.ggplot(df) + gg.aes(x="index", y="value", color="variable") + gg.geom_line()
)
return plot
# Grab a sample from the validation set.
sample_x, _ = next(valid_ds)
sample_x = sample_x[:, :1] # Shrink to batch-size 1.
# Generate a prediction, feeding in ground truth at each point as input.
predicted, _ = model.apply(trained_params, None, sample_x)
plot = plot_samples(sample_x[1:], predicted[:-1])
plot.draw()
del sample_x, predicted

Run autoregressively¶
If we can’t feed in the ground truth (because we don’t have it), we can also run the model autoregressively.
def autoregressive_predict(
trained_params: hk.Params,
context: jnp.ndarray,
seq_len: int,
):
"""Given a context, autoregressively generate the rest of a sine wave."""
ar_outs = []
context = jax.device_put(context)
times = range(seq_len - context.shape[0])
for _ in times:
full_context = jnp.concatenate([context] + ar_outs)
outs, _ = jax.jit(model.apply)(trained_params, None, full_context)
# Append the newest prediction to ar_outs.
ar_outs.append(outs[-1:])
# Return the final full prediction.
return outs
sample_x, _ = next(valid_ds)
context_length = SEQ_LEN // 8
# Cut the batch-size 1 context from the start of the sequence.
context = sample_x[:context_length, :1]
%%time
# We can reuse params we got from training for inference - as long as the
# declaration order is the same.
predicted = autoregressive_predict(trained_params, context, SEQ_LEN)
plot = plot_samples(sample_x[1:, :1], predicted)
plot += gg.geom_vline(xintercept=len(context), linetype="dashed")
plot.draw()
del predicted
CPU times: user 9.71 s, sys: 194 ms, total: 9.91 s
Wall time: 9.82 s

Sharing parameters with a different function¶
Unfortunately, this is a bit slow - we’re doing O(N^2) computation for a sequence of length N.
It’d be better if we could do the autoregressive sampling all at once - but we need to write a new Haiku function for that.
We’re in luck - if the Haiku module names match, the same parameters can be used for multiple Haiku functions.
This can be achieved through a combination of two techniques:
If we manually give a unique name to a module, we can ensure that the parameters are directed to the right places.
If modules are instantiated in the same order, they’ll have the same names in different functions.
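As a quick aside on technique #1 (a minimal sketch, not part of the original example; the function and module names here are hypothetical): giving a module the same explicit name in two different Haiku functions directs its parameters to the same place, so parameters initialized with one function can be reused with the other.
import haiku as hk
import jax
import jax.numpy as jnp
def predict_mean(x):
    # Explicitly named module: its parameters live under the key "readout".
    return hk.Linear(1, name="readout")(x)
def predict_cumsum(x):
    # A different function, but the module has the same explicit name,
    # so the same parameter tree {"readout": {"w", "b"}} applies.
    return hk.Linear(1, name="readout")(jnp.cumsum(x, axis=0))
f = hk.transform(predict_mean)
g = hk.transform(predict_cumsum)
x = jnp.ones((3, 2))
params = f.init(jax.random.PRNGKey(0), x)
out = g.apply(params, None, x)  # reuses the parameters initialized by f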
Here, we rely on method #2 to create a fast autoregressive prediction.
def fast_autoregressive_predict_fn(context, seq_len):
"""Given a context, autoregressively generate the rest of a sine wave."""
core = hk.LSTM(32)
dense = hk.Linear(1)
state = core.initial_state(context.shape[1])
# Unroll over the context using `hk.dynamic_unroll`.
# As before, we `hk.BatchApply` the Linear for efficiency.
context_outs, state = hk.dynamic_unroll(core, context, state)
context_outs = hk.BatchApply(dense)(context_outs)
# Now, unroll one step at a time using the running recurrent state.
ar_outs = []
x = context_outs[-1]
times = range(seq_len - context.shape[0])
for _ in times:
x, state = core(x, state)
x = dense(x)
ar_outs.append(x)
return jnp.concatenate([context_outs, jnp.stack(ar_outs)])
fast_ar_predict = hk.transform(fast_autoregressive_predict_fn)
fast_ar_predict = jax.jit(fast_ar_predict.apply, static_argnums=3)
%%time
# Reuse the same context from the previous cell.
predicted = fast_ar_predict(trained_params, None, context, SEQ_LEN)
# The plots should be equivalent!
plot = plot_samples(sample_x[1:, :1], predicted[:-1])
plot += gg.geom_vline(xintercept=len(context), linetype="dashed")
_ = plot.draw()
CPU times: user 6.67 s, sys: 144 ms, total: 6.82 s
Wall time: 6.75 s

%timeit autoregressive_predict(trained_params, context, SEQ_LEN)
%timeit fast_ar_predict(trained_params, None, context, SEQ_LEN)
86.3 ms ± 1.63 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
34.2 µs ± 549 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
Train all stars¶
Training¶
def split_train_validation_date(dataframe, stars, date, look_back) -> TrainSplit:
train_size = len(dataframe.loc[:date])
return split_train_validation(dataframe, stars, train_size, look_back)
%%time
train, valid = split_train_validation_date(dataframe_normed, stars, TRAIN_DATE, SEQ_LEN)
TRAIN_SIZE = train[0].shape[1]
print(f"TRAIN_SIZE = {TRAIN_SIZE}")
TRAIN_SIZE = 524288
CPU times: user 5.45 s, sys: 1.75 s, total: 7.2 s
Wall time: 4.42 s
train[0].shape, train[1].shape, valid[0].shape, valid[1].shape
((64, 524288, 1), (64, 524288, 1), (64, 16384, 1), (64, 16384, 1))
train_ds = Dataset(train, BATCH_SIZE)
valid_ds = Dataset(valid, BATCH_SIZE)
del train, valid # Don't leak temporaries.
%%time
trained_params, records = train_model(train_ds, valid_ds, TRAIN_STEPS)
CPU times: user 2min 36s, sys: 7.03 s, total: 2min 43s
Wall time: 1min 24s
# Plot losses
losses = pd.DataFrame(records)
df = pd.melt(losses, id_vars=["step"], value_vars=["train_loss", "valid_loss"])
plot = (
gg.ggplot(df)
+ gg.aes(x="step", y="value", color="variable")
+ gg.geom_line()
+ gg.scales.scale_y_log10()
)
_ = plot.draw()

Sampling¶
# Grab a sample from the validation set.
sample_x, _ = next(valid_ds)
sample_x = sample_x[:, :1] # Shrink to batch-size 1.
# Generate a prediction, feeding in ground truth at each point as input.
predicted, _ = model.apply(trained_params, None, sample_x)
plot = plot_samples(sample_x[1:], predicted[:-1])
plot.draw()
del sample_x, predicted

Run autoregressively¶
%%time
sample_x, _ = next(valid_ds)
context_length = SEQ_LEN // 8
# Cut the batch-size 1 context from the start of the sequence.
context = sample_x[:context_length, :1]
# Use the context cut from the start of the sequence above.
predicted = fast_ar_predict(trained_params, None, context, SEQ_LEN)
# The plots should be equivalent!
plot = plot_samples(sample_x[1:, :1], predicted[:-1])
plot += gg.geom_vline(xintercept=len(context), linetype="dashed")
_ = plot.draw()
CPU times: user 195 ms, sys: 18.2 ms, total: 213 ms
Wall time: 144 ms

# Uncomment to run the notebook in Colab
# ! pip install -q "wax-ml[complete]@git+https://github.com/eserie/wax-ml.git"
# ! pip install -q --upgrade jax jaxlib==0.1.67+cuda111 -f https://storage.googleapis.com/jax-releases/jax_releases.html
# check available devices
import jax
print("jax backend {}".format(jax.lib.xla_bridge.get_backend().platform))
jax.devices()
WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
jax backend cpu
[<jaxlib.xla_extension.Device at 0x7f4da67afcb0>]
🦎 Online linear regression with a non-stationary environment 🦎¶
We implement an online-learning, non-stationary linear regression problem.
We build up to it progressively, first showing how a linear regression problem can be cast
into an online learning problem thanks to the OnlineSupervisedLearner
module.
Then, to tackle a non-stationary linear regression problem (i.e. with weights that can vary in time),
we reformulate the problem as a reinforcement learning problem that we implement with the GymFeedback
module of WAX-ML.
We then need to define an “agent” and an “environment” using simple functions implemented with modules:
The agent is responsible for learning the weights of its internal linear model.
The environment is responsible for generating labels and evaluating the agent’s reward metric.
We experiment with a non-stationary environment that reverses the sign of the linear regression parameters at a time step known only to the environment.
This example shows that it is quite simple to implement this online-learning task with WAX-ML tools. In particular, the functional workflow adopted here allows the functions implemented for one task to be reused in each new task of increasing complexity.
In this journey, we will use:
the Haiku basic linear module hk.Linear,
the Optax stochastic gradient descent optimizer sgd,
the WAX-ML modules OnlineSupervisedLearner, Lag and GymFeedback,
the WAX-ML helper functions dynamic_unroll and jit_init_apply.
%pylab inline
Populating the interactive namespace from numpy and matplotlib
import haiku as hk
import jax
import jax.numpy as jnp
import optax
from matplotlib import pyplot as plt
from wax.compile import jit_init_apply
from wax.modules import OnlineSupervisedLearner
Static Linear Regression¶
First, let’s implement a simple static linear regression.
Generate data¶
Let’s generate a batch of data:
seq = hk.PRNGSequence(42)
X = jax.random.normal(next(seq), (100, 3))
w_true = jnp.ones(3)
Define the model¶
We use the basic module hk.Linear
which is a linear layer.
By default, it initializes the weights with random values drawn from a truncated normal distribution
with a standard deviation of \(1 / \sqrt{N}\), where \(N\) is the size of the inputs
(see https://arxiv.org/abs/1502.03167v3).
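For reference, here is a minimal sketch of the equivalent explicit initializer (relying on Haiku's documented default; the model below simply uses the default, and explicit_linear is only an illustration):
import numpy as np
import haiku as hk
def explicit_linear(x):
    # Same as hk.Linear's default: truncated normal with stddev 1/sqrt(N),
    # where N is the size of the last input dimension.
    w_init = hk.initializers.TruncatedNormal(stddev=1.0 / np.sqrt(x.shape[-1]))
    return hk.Linear(output_size=1, with_bias=False, w_init=w_init)(x)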
@jit_init_apply
@hk.transform_with_state
def linear_model(x):
return hk.Linear(output_size=1, with_bias=False)(x)
Run the model¶
Let’s run the model using WAX-ML dynamic_unroll
on the batch of data.
from wax.unroll import dynamic_unroll
params, state = linear_model.init(next(seq), X[0])
linear_model.apply(params, state, None, X[0])
(DeviceArray([-0.2070887], dtype=float32), FlatMapping({}))
Y_pred, state = dynamic_unroll(linear_model, None, None, next(seq), False, X)
Y_pred.shape
(100, 1)
Check cost¶
Let’s look at the mean squared error for this non-trained model.
noise = jax.random.normal(next(seq), (100,))
Y = X.dot(w_true) + noise
L = jnp.square(Y - Y_pred.squeeze())  # per-time-step squared error
mean_loss = L.mean()
assert mean_loss > 0
Let’s look at the regret (cumulative sum of the loss) for the non-trained model.
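For reference (a standard definition, not stated in this notebook), the regret is usually measured against the best fixed weight in hindsight; the plots below use the cumulative loss as a simple proxy for it:
\(\mathrm{Regret}_T = \sum_{t=1}^{T} \ell_t(w_t) - \min_{w} \sum_{t=1}^{T} \ell_t(w)\)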
plt.plot(L.cumsum())
plt.title("Regret")
Text(0.5, 1.0, 'Regret')

As expected, the regret grows linearly when the model is not trained!
Online Linear Regression¶
We will now start training the model online. For a review of online-learning methods, see [1].
[1] Elad Hazan, Introduction to Online Convex Optimization
Define an optimizer¶
opt = optax.sgd(1e-3)
Define a loss¶
Since we are doing online learning, we need to define a local loss function: \(\ell_t(y, w, x) = \lVert y_t - w \cdot x_t \rVert^2\)
@jax.jit
def loss(y_pred, y):
return jnp.mean(jnp.square(y_pred - y))
Define a learning strategy¶
@jit_init_apply
@hk.transform_with_state
def learner(x, y):
return OnlineSupervisedLearner(linear_model, opt, loss)(x, y)
Generate data¶
def generate_many_observations(T=300, sigma=1.0e-2, rng=None):
    rng = jax.random.PRNGKey(42) if rng is None else rng
    # Use independent keys for the features and the noise.
    rng_x, rng_noise = jax.random.split(rng)
    X = jax.random.normal(rng_x, (T, 3))
    w_true = jnp.ones(3)
    noise = sigma * jax.random.normal(rng_noise, (T,))
    Y = X.dot(w_true) + noise
    return (X, Y)
T = 3000
X, Y = generate_many_observations(T)
Unroll the learner¶
output, online_state = dynamic_unroll(learner, None, None, next(seq), False, X, Y)
Plot the regret¶
Let’s look at the loss and regret over time.
fig, axs = plt.subplots(1, 2, figsize=(9, 3))
axs[0].plot(output["loss"].cumsum())
axs[0].set_title("Regret")
axs[1].plot(output["params"]["linear"]["w"][:, 0, 0])
axs[1].set_title("Weight[0,0]")
Text(0.5, 1.0, 'Weight[0,0]')

We have sub-linear regret!
Online learning with Gym¶
Now we will recast the online linear regression learning task as a reinforcement learning task
implemented with the GymFeedback
module of WAX-ML.
For that, we define:
observations (obs): pairs (x, y) of features and labels,
raw observations (raw_obs): pairs (x, noise) of features and generative noise.
Linear regression agent¶
In WAX-ML, an agent is a simple function with the following API:

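As a hypothetical stand-in for the diagram (a toy sketch, not a WAX-ML API): an agent is just a function mapping an observation to an action. Here the observation is a (features, label) pair and the action is a dict holding the prediction, matching the y_pred key used by the environment below.
import jax.numpy as jnp
def toy_agent(obs):
    x, y = obs
    # A constant, untrained prediction just to illustrate the signature.
    return {"y_pred": jnp.zeros_like(y)}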
Let’s define a simple linear regression agent with the elements we have defined so far.
def linear_regression_agent(obs):
x, y = obs
@jit_init_apply
@hk.transform_with_state
def model(x):
return hk.Linear(output_size=1, with_bias=False)(x)
opt = optax.sgd(1e-3)
@jax.jit
def loss(y_pred, y):
return jnp.mean(jnp.square(y_pred - y))
def learner(x, y):
return OnlineSupervisedLearner(model, opt, loss)(x, y)
return learner(x, y)
Linear regression environment¶
In WAX-ML, an environment is a simple function with the following API:

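In the same spirit, here is a toy sketch standing in for the diagram: an environment maps (action, raw_obs) to (reward, obs). The real environment defined below additionally lags the labels before scoring the action.
import jax.numpy as jnp
def toy_env(action, raw_obs):
    x, noise = raw_obs
    # The environment builds the labels itself from a fixed weight vector.
    y = x @ jnp.ones(x.shape[-1]) + noise
    reward = jnp.mean(jnp.square(action["y_pred"] - y))
    return reward, (x, y)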
Let’s now define a linear regression environment that, for the moment, has static weights.
It is responsible for generating the real labels and evaluating the agent’s reward.
To evaluate the reward, we need the Lag
module to compare the action of
the agent with the labels generated at the previous time step.
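Before using it, here is a minimal sketch of the behavior we rely on, assuming Lag(1) simply delays its input by one step and emits an undefined (NaN) value at the first step, as suggested by the remark on the reward sequence later on:
import haiku as hk
import jax
import jax.numpy as jnp
from wax.modules import Lag
from wax.unroll import dynamic_unroll
@hk.transform_with_state
def lag_fun(y):
    return Lag(1)(y)
y_seq = jnp.arange(5.0)
lagged, _ = dynamic_unroll(lag_fun, None, None, jax.random.PRNGKey(0), False, y_seq)
# lagged is expected to look like [nan, 0., 1., 2., 3.]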
from wax.modules import Lag
def stationary_linear_regression_env(action, raw_obs):
    # Only the environment knows the true value of the parameters
    w_true = -jnp.ones(3)
    # The environment has its own loss definition
@jax.jit
def loss(y_pred, y):
return jnp.mean(jnp.square(y_pred - y))
# raw observation contains features and generative noise
x, noise = raw_obs
# generate targets
y = x @ w_true + noise
obs = (x, y)
y_previous = Lag(1)(y)
# evaluate the prediction made by the agent
y_pred = action["y_pred"]
reward = loss(y_pred, y_previous)
return reward, obs
Generate raw observation¶
Let’s define a function that generates the raw observations:
def generate_many_raw_observations(T=300, sigma=1.0e-2, rng=None):
    rng = jax.random.PRNGKey(42) if rng is None else rng
    # Use independent keys for the features and the noise.
    rng_x, rng_noise = jax.random.split(rng)
    X = jax.random.normal(rng_x, (T, 3))
    noise = sigma * jax.random.normal(rng_noise, (T,))
    return (X, noise)
Implement Feedback¶
We are now ready to set things up with the GymFeedback
module implemented in WAX-ML.
It implements the following feedback loop:

Equivalently, it can be described with the pair of init
and apply
functions:

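Since the two diagrams above are rendered as images, here is a rough pseudo-code sketch (not the actual GymFeedback implementation) of the loop it realizes: at each step the agent acts on the observation emitted at the previous step, and the environment scores that action and produces the next observation. The first_obs argument is a hypothetical placeholder for the initial-observation handling.
def gym_feedback_unroll(agent, env, raw_observations, first_obs):
    obs = first_obs
    rewards = []
    for raw_obs in raw_observations:
        action = agent(obs)
        reward, obs = env(action, raw_obs)
        rewards.append(reward)
    return rewards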
from wax.modules import GymFeedback
@hk.transform_with_state
def gym_fun(raw_obs):
return GymFeedback(
linear_regression_agent, stationary_linear_regression_env, return_action=True
)(raw_obs)
And now we can unroll it on a sequence of raw observations!
seq = hk.PRNGSequence(42)
T = 3000
raw_observations = generate_many_raw_observations(T)
rng = next(seq)
output_sequence, final_state = dynamic_unroll(
gym_fun,
None,
None,
rng,
True,
raw_observations,
)
Let’s visualize the outputs.
We now use pd.Series
to represent the reward sequence since its first value is NaN due to the use of the lag operator.
import pandas as pd
fig, axs = plt.subplots(1, 2, figsize=(9, 3))
pd.Series(output_sequence.reward).cumsum().plot(ax=axs[0], title="Regret")
axs[1].plot(output_sequence.action["params"]["linear"]["w"][:, 0, 0])
axs[1].set_title("Weight[0,0]")
Text(0.5, 1.0, 'Weight[0,0]')

Non-stationary environment¶
Now, let’s implement a non-stationary environment.
We implement it so that the sign of the weights is reversed after 2000 steps.
class NonStationaryEnvironment(hk.Module):
def __call__(self, action, raw_obs):
step = hk.get_state("step", [], init=lambda *_: 0)
        # Only the environment knows the true value of the parameters:
        # at step 2000, we flip the sign of the true parameters!
w_true = hk.cond(
step < 2000,
step,
lambda step: -jnp.ones(3),
step,
lambda step: jnp.ones(3),
)
        # The environment has its own loss definition
@jax.jit
def loss(y_pred, y):
return jnp.mean(jnp.square(y_pred - y))
# raw observation contains features and generative noise
x, noise = raw_obs
# generate targets
y = x @ w_true + noise
obs = (x, y)
# evaluate the prediction made by the agent
y_previous = Lag(1)(y)
y_pred = action["y_pred"]
reward = loss(y_pred, y_previous)
step += 1
hk.set_state("step", step)
return reward, obs
Now let’s run a gym simulation to see how the agent adapts to the change of environment.
@hk.transform_with_state
def gym_fun(raw_obs):
return GymFeedback(
linear_regression_agent, NonStationaryEnvironment(), return_action=True
)(raw_obs)
T = 6000
raw_observations = generate_many_raw_observations(T)
rng = jax.random.PRNGKey(42)
output_sequence, final_state = dynamic_unroll(
gym_fun,
None,
None,
rng,
True,
raw_observations,
)
import pandas as pd
fig, axs = plt.subplots(1, 2, figsize=(9, 3))
pd.Series(output_sequence.reward).cumsum().plot(ax=axs[0], title="Regret")
axs[1].plot(output_sequence.action["params"]["linear"]["w"][:, 0, 0])
axs[1].set_title("Weight[0,0]")
# plt.savefig("../_static/online_linear_regression_regret.png")
Text(0.5, 1.0, 'Weight[0,0]')

It adapts!
The regret first converges, then jumps at step 2000, and finally readjusts to the new regime.
We see that the weights converge to the correct values in both regimes.
Installing wax¶
First, obtain the WAX-ML source code:
git clone https://github.com/eserie/wax-ml
cd wax-ml
You can install wax
by running:
pip install -e .[complete] # install wax
To upgrade to the latest version from GitHub, just run git pull
from the WAX-ML
repository root. You shouldn’t have to reinstall wax
because pip install -e
sets up symbolic links from site-packages into the repository.
You can install wax
development tools by running:
pip install -e .[dev] # install wax-development-tools
Running the tests¶
To run all the WAX-ML tests, we recommend using pytest-xdist
, which can run tests in
parallel. First, install pytest-xdist
and pytest-benchmark
by running
pip install -r build/test-requirements.txt
.
Then, from the repository root directory run:
pytest -n auto .
You can run a more specific set of tests using pytest’s built-in selection mechanisms, or alternatively you can run a specific test file directly to see more detailed information about the cases being run:
pytest -v wax/accessors_test.py
The Colab notebooks are tested for errors as part of the documentation build and the GitHub Actions workflows.
Type checking¶
We use mypy
to check the type hints. To check types locally the same way as the GitHub Actions check does, you can run:
mypy wax
or
make mypy
Flake8¶
We use flake8
to check that the code follows the PEP 8 standard.
To check the code, you can run
make flake8
Formatting code¶
We use isort
and black
to format the code.
When you are in the root directory of the project, to format code in the package, you can run:
make format-package
To format notebooks in the documentation, you can use:
make format-notebooks
To format all files you can run:
make format
Note that the CI, running with GitHub Actions, will verify that formatting all source code does not modify the files. You can check this locally by running:
make check-format
Check actions¶
You can check that everything is ok by running:
make act
This will check flake8, mypy, isort and black formatting, and license headers, and will run the tests and coverage.
Update documentation¶
To rebuild the documentation, install several packages:
pip install -r docs/requirements.txt
And then run:
sphinx-build -b html docs docs/build/html
or run
make docs
This can take a long time because it executes many of the notebooks in the documentation source; if you’d prefer to build the docs without executing the notebooks, you can run:
sphinx-build -b html -D jupyter_execute_notebooks=off docs docs/build/html
or run
make docs-fast
You can then see the generated documentation in docs/build/html/index.html.
Update notebooks¶
We use jupytext to maintain three synced copies of the notebooks
in docs/notebooks
: one in ipynb
format, one in py
and one in md
format.
The advantage of the first is that it can be opened and executed directly in Colab;
the advantage of the second is that it makes it easier to refactor and format Python code;
the advantage of the third is that it makes it much easier to track diffs within version control.
Editing ipynb¶
For making large changes that substantially modify code and outputs, it is easiest to
edit the notebooks in Jupyter or in Colab. To edit notebooks in the Colab interface,
open http://colab.research.google.com and Upload
from your local repo.
Update it as needed, Run all cells
then Download ipynb
.
You may want to test that it executes properly, using sphinx-build
as explained above.
You can format the Python code in your notebooks by running make format
in the docs/notebooks
directory, or
make format-notebooks
in the root directory.
Editing md¶
For making smaller changes to the text content of the notebooks, it is easiest to edit the
.md
versions using a text editor.
Syncing notebooks¶
After editing either the ipynb or md versions of the notebooks, you can sync the two versions using jupytext by running:
jupytext --sync docs/notebooks/*
or:
cd docs/notebooks/
make sync
Alternatively, you can run this command via the pre-commit framework by executing the following in the main WAX-ML directory:
pre-commit run --all
See the pre-commit framework documentation for information on how to set your local git environment to execute this automatically.
Creating new notebooks¶
If you are adding a new notebook to the documentation and would like to use the jupytext --sync
command discussed here, you can set up your notebook for jupytext by using the following command:
jupytext --set-formats ipynb,py,md:myst path/to/the/notebook.ipynb
This works by adding a "jupytext"
metadata field to the notebook file which specifies the
desired formats, and which the jupytext --sync
command recognizes when invoked.
Notebooks within the sphinx build¶
Some of the notebooks are built automatically as part of the Travis pre-submit checks and
as part of the Read the docs build.
The build will fail if cells raise errors. If the errors are intentional,
you can either catch them, or tag the cell with raises-exceptions
metadata (example PR).
You have to add this metadata by hand in the .ipynb
file. It will be preserved when somebody else
re-saves the notebook.
We exclude some notebooks from the build, e.g., because they contain long computations.
See exclude_patterns
in conf.py.
Documentation building on readthedocs.io¶
WAX-ML’s auto-generated documentation is at https://wax-ml.readthedocs.io/.
The documentation building is controlled for the entire project by the
readthedocs WAX-ML settings. The current settings
trigger a documentation build as soon as code is pushed to the GitHub main
branch.
For each code version, the building process is driven by the
.readthedocs.yml
and the docs/conf.py
configuration files.
For each automated documentation build you can see the documentation build logs.
If you want to test the documentation generation on Readthedocs, you can push code to the test-docs
branch. That branch is also built automatically, and you can
see the generated documentation here. If the documentation build
fails you may want to wipe the build environment for test-docs.
For a local test, you can do it in a fresh directory by replaying the commands executed by Readthedocs and written in their logs:
mkvirtualenv wax-docs # A new virtualenv
mkdir wax-docs # A new directory
cd wax-docs
git clone --no-single-branch --depth 50 https://github.com/eserie/wax-ml
cd wax-ml
git checkout --force origin/test-docs
git clean -d -f -f
workon wax-docs
python -m pip install --upgrade --no-cache-dir pip
python -m pip install --upgrade --no-cache-dir -I Pygments==2.3.1 setuptools==41.0.1 docutils==0.14 mock==1.0.1 pillow==5.4.1 alabaster>=0.7,<0.8,!=0.7.5 commonmark==0.8.1 recommonmark==0.5.0 'sphinx<2' 'sphinx-rtd-theme<0.5' 'readthedocs-sphinx-ext<1.1'
python -m pip install --exists-action=w --no-cache-dir -r docs/requirements.txt
cd docs
python `which sphinx-build` -T -E -b html -d _build/doctrees-readthedocs -D language=en . _build/html
Public API: wax package¶
Subpackages¶
wax.modules package¶
Gym modules¶
In WAX-ML, agents and environments are simple functions:
A Gym feedback loop can be represented with the diagram:
Equivalently, it can be described with the pair of init and apply functions:
Gym feedback between an agent and a Gym environment.
Other Haiku modules¶
Buffer module.
Diff module.
Exponential moving average module.
Exponentially weighted variance module.
Exponentially weighted variance module.
Detect if something has changed.
Delay operator.
Open-High-Low-Close binning.
Relative change between the current and a prior element.
Rolling mean.
Apply a module when an event occurs, otherwise return the last computed output.
wax.gym package¶
Define accessors for xarray and pandas data containers.
Compilation helper for Haiku Transformed and TransformedWithState pairs of pure functions.
Encoding schemes to encode/decode numpy data types not supported by JAX, e.g.
Format nested data structures to numpy/xarray/pandas containers.
Define the Stream object used to synchronize in-memory data streams and unroll data transformations on them.
Transformation functions to work on batches of data.
Unroll modules on data along the first axis.
Some utility functions used in WAX-ML.