Welcome to WAX-ML

WAX-ML is a library for machine-learning on streaming data.

For an introduction to WAX-ML, start at the WAX-ML GitHub page.

# Uncomment to run the notebook in Colab
# ! pip install -q "wax-ml[complete]@git+https://github.com/eserie/wax-ml.git"
# ! pip install -q --upgrade jax jaxlib==0.1.70+cuda111 -f https://storage.googleapis.com/jax-releases/jax_releases.html
# check available devices
import jax
print("jax backend {}".format(jax.lib.xla_bridge.get_backend().platform))
jax.devices()
WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
jax backend cpu
[CpuDevice(id=0)]

〰 Compute exponential moving averages with xarray and pandas accessors 〰

Open in Colab

WAX-ML implements pandas and xarray accessors to ease the use of machine-learning algorithms with high-level data APIs:

  • pandas’s DataFrame and Series

  • xarray’s Dataset and DataArray.

These accessors make it easy to execute any function built from Haiku modules on these data containers.

For instance, WAX-ML provides an implementation of the exponential moving average built with this mechanism.

Let’s show how it works.

Load accessors

First you need to load accessors:

from wax.accessors import register_wax_accessors

register_wax_accessors()

EWMA on dataframes

Let’s look at a simple example: the exponential moving average (EWMA).

Let’s apply the EWMA algorithm to the NCEP/NCAR air temperature data.
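
As a reminder (our notation, not the exact adjusted formula used by default in pandas and WAX-ML), the non-adjusted EWMA follows the recursion

\[
\mathrm{EWMA}_t = (1 - \alpha)\,\mathrm{EWMA}_{t-1} + \alpha\, x_t ,
\]

while the adjusted variant (adjust=True) normalizes the weighted sum of past observations by \(\sum_{i=0}^{t} (1-\alpha)^i\).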

🌡 Load temperature dataset 🌡

import xarray as xr

dataset = xr.tutorial.open_dataset("air_temperature")

Let’s see what this dataset looks like:

dataset
<xarray.Dataset>
Dimensions:  (lat: 25, lon: 53, time: 2920)
Coordinates:
  * lat      (lat) float32 75.0 72.5 70.0 67.5 65.0 ... 25.0 22.5 20.0 17.5 15.0
  * lon      (lon) float32 200.0 202.5 205.0 207.5 ... 322.5 325.0 327.5 330.0
  * time     (time) datetime64[ns] 2013-01-01 ... 2014-12-31T18:00:00
Data variables:
    air      (time, lat, lon) float32 ...
Attributes:
    Conventions:  COARDS
    title:        4x daily NMC reanalysis (1948)
    description:  Data is from NMC initialized reanalysis\n(4x/day).  These a...
    platform:     Model
    references:   http://www.esrl.noaa.gov/psd/data/gridded/data.ncep.reanaly...

To compute an EWMA on some variables of a dataset, we usually need to convert the data into a pandas Series or DataFrame.

So, let’s convert the dataset into a dataframe to illustrate accessors on a dataframe:

dataframe = dataset.air.to_series().unstack(["lon", "lat"])

EWMA with pandas

air_temp_ewma = dataframe.ewm(alpha=1.0 / 10.0).mean()
_ = air_temp_ewma.iloc[:, 0].plot()
_images/01_demo_EWMA_17_0.png

EWMA with WAX-ML

air_temp_ewma = dataframe.wax.ewm(alpha=1.0 / 10.0).mean()
_ = air_temp_ewma.iloc[:, 0].plot()
_images/01_demo_EWMA_19_0.png

On small data, WAX-ML’s EWMA is slower than pandas’ because of the expensive data-conversion steps. WAX-ML’s accessors become interesting on large workloads (see the 3-steps workflow section below).

Apply a custom function to a Dataset

Now let’s illustrate how WAX-ML accessors work on xarray datasets.

from wax.modules import EWMA


def my_custom_function(dataset):
    return {
        "air_10": EWMA(1.0 / 10.0)(dataset["air"]),
        "air_100": EWMA(1.0 / 100.0)(dataset["air"]),
    }


dataset = xr.tutorial.open_dataset("air_temperature")
output, state = dataset.wax.stream().apply(
    my_custom_function, format_dims=dataset.air.dims
)

_ = output.isel(lat=0, lon=0).drop(["lat", "lon"]).to_pandas().plot(figsize=(12, 8))
_images/01_demo_EWMA_22_0.png
# Uncomment to run the notebook in Colab
# ! pip install -q "wax-ml[complete]@git+https://github.com/eserie/wax-ml.git"
# ! pip install -q --upgrade jax jaxlib==0.1.70+cuda111 -f https://storage.googleapis.com/jax-releases/jax_releases.html
# check available devices
import jax
print("jax backend {}".format(jax.lib.xla_bridge.get_backend().platform))
jax.devices()
WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
jax backend cpu
[CpuDevice(id=0)]

⏱ Synchronize data streams ⏱

Open in Colab

Physicists, and not the least of them 😅, have proposed a solution to the synchronization problem. See the Poincaré-Einstein synchronization Wikipedia page for more details.

In WAX-ML we strive to follow their recommendations and implement a synchronization mechanism between different data streams. Using the terminology of Henri Poincaré (see the link above), we introduce the notion of “local time” to designate the stream in which the user wants to apply transformations. We call the other streams “secondary streams”; they can work at different frequencies, lower or higher. The data from these secondary streams is represented in the “local time” either through a forward-filling mechanism for lower frequencies or a buffering mechanism for higher frequencies.

We implement a “data tracing” mechanism to optimize access to out-of-sync streams. This mechanism works on in-memory data. We perform a first pass on the data, without actually accessing it, and determine the indices necessary to access it later. In doing so, we are careful not to let any “future” information leak through, which guarantees data processing that respects causality.

The buffering mechanism used in the case of higher frequencies works with a fixed buffer size (see the WAX-ML module wax.modules.Buffer) which allows us to use JAX / XLA optimizations and have efficient processing.
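
As a rough, standalone sketch of the forward-filling idea (plain pandas on hypothetical toy data, not the WAX-ML implementation):

import pandas as pd

# Hypothetical toy streams: a 6-hourly "local time" stream and a daily secondary stream.
local_time = pd.date_range("2013-01-01", periods=8, freq="6H")
air = pd.Series(range(8), index=local_time, name="air")
ground = pd.Series(
    [0.0, 1.0], index=pd.date_range("2013-01-01", periods=2, freq="D"), name="ground"
)

# Represent the daily stream in the "local time": at each step only past daily
# values are visible, so causality is respected.
ground_in_local_time = ground.reindex(local_time, method="ffill")
print(pd.concat([air, ground_in_local_time], axis=1))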

Let’s illustrate with a small example how wax.stream.Stream synchronizes data streams.

Let’s use the “air temperature” dataset with:

  • an air temperature defined at an hourly resolution;

  • a “fake” ground temperature defined at a daily resolution as the air temperature minus 10 degrees.

import xarray as xr

dataset = xr.tutorial.open_dataset("air_temperature")
dataset["ground"] = dataset.air.resample(time="d").last().rename({"time": "day"}) - 10

Let’s see what this dataset looks like:

dataset
<xarray.Dataset>
Dimensions:  (day: 730, lat: 25, lon: 53, time: 2920)
Coordinates:
  * lat      (lat) float32 75.0 72.5 70.0 67.5 65.0 ... 25.0 22.5 20.0 17.5 15.0
  * lon      (lon) float32 200.0 202.5 205.0 207.5 ... 322.5 325.0 327.5 330.0
  * time     (time) datetime64[ns] 2013-01-01 ... 2014-12-31T18:00:00
  * day      (day) datetime64[ns] 2013-01-01 2013-01-02 ... 2014-12-31
Data variables:
    air      (time, lat, lon) float32 ...
    ground   (day, lat, lon) float32 231.9 231.8 231.8 ... 286.5 286.2 285.7
Attributes:
    Conventions:  COARDS
    title:        4x daily NMC reanalysis (1948)
    description:  Data is from NMC initialized reanalysis\n(4x/day).  These a...
    platform:     Model
    references:   http://www.esrl.noaa.gov/psd/data/gridded/data.ncep.reanaly...
from wax.accessors import register_wax_accessors

register_wax_accessors()
from wax.modules import EWMA


def my_custom_function(dataset):
    return {
        "air_10": EWMA(1.0 / 10.0)(dataset["air"]),
        "air_100": EWMA(1.0 / 100.0)(dataset["air"]),
        "ground_100": EWMA(1.0 / 100.0)(dataset["ground"]),
    }
results, state = dataset.wax.stream(
    local_time="time", ffills={"day": 1}, pbar=True
).apply(my_custom_function, format_dims=dataset.air.dims)
_ = results.isel(lat=0, lon=0).drop(["lat", "lon"]).to_pandas().plot(figsize=(12, 8))
_images/02_Synchronize_data_streams_11_0.png
# Uncomment to run the notebook in Colab
# ! pip install -q "wax-ml[complete]@git+https://github.com/eserie/wax-ml.git"
# ! pip install -q --upgrade jax jaxlib==0.1.70+cuda111 -f https://storage.googleapis.com/jax-releases/jax_releases.html
# check available devices
import jax
print("jax backend {}".format(jax.lib.xla_bridge.get_backend().platform))
jax.devices()
WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
jax backend cpu
[CpuDevice(id=0)]

🌡 Binning temperatures 🌡

Open in Colab

Let’s again consider the air temperature dataset. It is sampled at an hourly resolution. We will compute “trailing” air temperature bins during each day and “reset” the bin aggregation process at each day change.

import numpy as onp
import xarray as xr
from wax.accessors import register_wax_accessors
from wax.modules import OHLC, HasChanged

register_wax_accessors()
dataset = xr.tutorial.open_dataset("air_temperature")
dataset["date"] = dataset.time.dt.date.astype(onp.datetime64)
dataset
<xarray.Dataset>
Dimensions:  (lat: 25, lon: 53, time: 2920)
Coordinates:
  * lat      (lat) float32 75.0 72.5 70.0 67.5 65.0 ... 25.0 22.5 20.0 17.5 15.0
  * lon      (lon) float32 200.0 202.5 205.0 207.5 ... 322.5 325.0 327.5 330.0
  * time     (time) datetime64[ns] 2013-01-01 ... 2014-12-31T18:00:00
Data variables:
    air      (time, lat, lon) float32 ...
    date     (time) datetime64[ns] 2013-01-01 2013-01-01 ... 2014-12-31
Attributes:
    Conventions:  COARDS
    title:        4x daily NMC reanalysis (1948)
    description:  Data is from NMC initialized reanalysis\n(4x/day).  These a...
    platform:     Model
    references:   http://www.esrl.noaa.gov/psd/data/gridded/data.ncep.reanaly...
def bin_temperature(da):
    day_change = HasChanged()(da["date"])
    return OHLC()(da["air"], reset_on=day_change)


output, state = dataset.wax.stream().apply(
    bin_temperature, format_dims=onp.array(dataset.air.dims)
)
output = xr.Dataset(output._asdict())
df = output.isel(lat=0, lon=0).drop(["lat", "lon"]).to_pandas().loc["2013-01"]
_ = df.plot(figsize=(12, 8), title="Trailing Open-High-Low-Close temperatures")
_images/03_ohlc_temperature_10_0.png

The UpdateOnEvent module

The OHLC module uses the primitive wax.modules.UpdateOnEvent.

Its implementation required extending Haiku with a central function, set_params_or_state_dict, which we have integrated directly into this WAX-ML module.

We have opened an issue on the Haiku GitHub repository to integrate it into Haiku.

# Uncomment to run the notebook in Colab
# ! pip install -q "wax-ml[complete]@git+https://github.com/eserie/wax-ml.git"
# ! pip install -q --upgrade jax jaxlib==0.1.70+cuda111 -f https://storage.googleapis.com/jax-releases/jax_releases.html
# check available devices
import jax
print("jax backend {}".format(jax.lib.xla_bridge.get_backend().platform))
jax.devices()
WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
jax backend cpu
[<jaxlib.xla_extension.Device at 0x17678ceb0>]

🎛 The 3-steps workflow 🎛

Open in Colab

Thanks to WAX-ML accessors, it is already very convenient to execute a JAX function on a dataframe in a single step, with a single line of code.

The 1-step WAX-ML stream API works as follows:

<data-container>.stream(...).apply(...)

But this is not optimal because, under the hood, there are mainly three costly steps:

  • (1) (synchronize | data tracing | encode): make the data “JAX ready”

  • (2) (compile | code tracing | execution): compile and optimize a function for XLA, execute it.

  • (3) (format): convert data back to pandas/xarray/numpy format.

With the wax.stream primitives, it is quite easy to explicitly split the 1-step workflow into a 3-step workflow.

This will allow the user to have full control over each step and iterate on each one.

It is actually very useful to iterate on step (2), the “calculation step”, when you are doing research. You can then take full advantage of JAX primitives, especially jit.

Let’s illustrate how to reimplement the WAX-ML EWMA yourself with the WAX-ML 3-step workflow.

Imports

import numpy as onp
import pandas as pd
import xarray as xr

from wax.accessors import register_wax_accessors
from wax.external.eagerpy import convert_to_tensors
from wax.format import format_dataframe
from wax.modules import EWMA
from wax.stream import tree_access_data
from wax.unroll import unroll

register_wax_accessors()

Performance on big dataframes

Generate data

T = 1.0e5
N = 1000
T, N = map(int, (T, N))
dataframe = pd.DataFrame(
    onp.random.normal(size=(T, N)), index=pd.date_range("1970", periods=T, freq="s")
)

pandas EWMA

%%time
df_ewma_pandas = dataframe.ewm(alpha=1.0 / 10.0).mean()
CPU times: user 2.03 s, sys: 167 ms, total: 2.19 s
Wall time: 2.19 s

WAX-ML EWMA

%%time
df_ewma_wax = dataframe.wax.ewm(alpha=1.0 / 10.0).mean()
CPU times: user 1.8 s, sys: 876 ms, total: 2.68 s
Wall time: 2.67 s

It’s not really faster here: on data of this size, the data-conversion steps eat up the gain…

WAX-ML EWMA (without format step)

Let’s disable the final formatting step (the output is now in raw JAX format):

%%time
df_ewma_wax_no_format = dataframe.wax.ewm(alpha=1.0 / 10.0, format_outputs=False).mean()
df_ewma_wax_no_format.block_until_ready()
CPU times: user 1.8 s, sys: 1 s, total: 2.8 s
Wall time: 3.55 s
DeviceArray([[-2.9661492e-01, -3.2103235e-01, -9.6144844e-03, ...,
               4.4276214e-01,  1.1568004e+00, -1.0724162e+00],
             [-9.5860586e-02, -2.0366390e-01, -3.7967765e-01, ...,
              -6.1594141e-01,  7.7942121e-01,  1.8111229e-02],
             [ 2.8211397e-01,  1.7749734e-01, -2.0034584e-01, ...,
              -7.6095390e-01,  5.3778893e-01,  3.1442198e-01],
             ...,
             [-1.1917123e-01,  1.3764068e-01,  2.4761766e-01, ...,
               2.0842913e-01,  2.5283977e-01, -1.1205430e-01],
             [-1.0947308e-01,  3.6484647e-01,  2.4164049e-01, ...,
               2.7038181e-01,  2.4539444e-01,  6.2920153e-05],
             [ 4.8219025e-02,  1.5648599e-01,  1.2161890e-01, ...,
               2.0765728e-01,  8.9837506e-02,  1.0943251e-01]],            dtype=float32)
type(df_ewma_wax_no_format)
jaxlib.xla_extension.DeviceArray

Let’s check the device on which the calculation was performed (if you have a GPU available, this should be a GpuDevice, otherwise a CpuDevice):

df_ewma_wax_no_format.device()
<jaxlib.xla_extension.Device at 0x17678ceb0>

Now we will see how to break down WAX-ML one-liners <dataset>.ewm(...).mean() or <dataset>.stream(...).apply(...) into 3 steps:

  • a preparation step where we prepare JAX-ready data and functions.

  • a processing step where we execute the JAX program

  • a post-processing step where we format the results in pandas or xarray format.
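
Schematically, and only as a hedged outline of the code below (stream.prepare, unroll and format_dataframe are the functions used in the rest of this section):

# Step (1): prepare JAX-ready data and a pair of pure functions.
stream = dataset.wax.stream()
transform_dataset, jxs = stream.prepare(dataset, my_ewma_on_dataset)

# Step (2): trace, compile and execute the JAX program on the "steps" vector jxs.
outputs = unroll(transform_dataset)(jxs)

# Step (3): come back to pandas/xarray format.
y = format_dataframe(
    dataset.coords, onp.array(outputs), format_dims=dataset.dataarray.dims
)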

Generate data (in dataset format)

The WAX-ML Stream object works on datasets, so let’s transform the DataFrame into an xarray Dataset:

dataset = xr.DataArray(dataframe).to_dataset(name="dataarray")
del dataframe

Step (1) (synchronize | data tracing | encode)

In this step, WAX-ML does the following:

  • “data tracing”: prepare the indices for fast access in the JAX function access_data;

  • synchronize streams if there are multiple ones (this functionality has the options freq and ffills);

  • encode and convert data from numpy to JAX, using encoders for datetime64 and string_ dtypes. Be aware that by default JAX works in float32 (see JAX’s Common Gotchas, and the snippet below, to work in float64).
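
For instance, a minimal sketch of the standard JAX switch to double precision (this is a plain JAX setting, not a WAX-ML option, and it must be set before any array is created):

# Enable float64 in JAX; do this at startup, before creating any array.
from jax import config

config.update("jax_enable_x64", True)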

The function Stream.prepare implements this step (1). It wraps the input function with the actual data and indices in a pair of pure functions (a Haiku TransformedWithState tuple).

%%time
stream = dataset.wax.stream()
CPU times: user 244 µs, sys: 36 µs, total: 280 µs
Wall time: 287 µs

Let’s define our custom function to be applied to a dict of arrays having the same structure as the original dataset:

def my_ewma_on_dataset(dataset):
    return EWMA(alpha=1.0 / 10.0, adjust=True)(dataset["dataarray"])
transform_dataset, jxs = stream.prepare(dataset, my_ewma_on_dataset)

Let’s define the init parameters and state of the transformation we will apply.

Step (2) (compile | code tracing | execution)

In this step we:

  • prepare a pure function (with Haiku’s transform mechanism): we define a “transformation” function which:

    • accesses the data

    • applies another transformation, here an EWMA

  • compile it with jax.jit

  • perform code tracing and execution (the last line):

    • unroll the transformation over the “steps” xs (an np.arange vector).

outputs = unroll(transform_dataset)(jxs)
outputs.device()
<jaxlib.xla_extension.Device at 0x17678ceb0>

Once it has been compiled and “traced” by JAX, the function is much faster to execute:

%%timeit
outputs = unroll(transform_dataset)(jxs)
_ = outputs.block_until_ready()
619 ms ± 9.49 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

This is about 3x faster than the pandas implementation!

Manually prepare the data and manage the device

In order to manage the device on which the computations take place, we need even more control over the execution flow. Instead of calling stream.prepare to build the transform_dataset function, we can do it ourselves by:

  • using the stream.trace_dataset function

  • converting the numpy data to JAX ourselves

  • putting the data on the device we want.

np_data, np_index, xs = stream.trace_dataset(dataset)
jnp_data, jnp_index, jxs = convert_to_tensors((np_data, np_index, xs), "jax")

We explicitly put the data on the CPU (this is not needed if you only have CPUs):

from jax.tree_util import tree_leaves, tree_map

cpus = jax.devices("cpu")
jnp_data, jnp_index, jxs = tree_map(
    lambda x: jax.device_put(x, cpus[0]), (jnp_data, jnp_index, jxs)
)
print("data copied to CPU device.")
data copied to CPU device.

We have now “JAX-ready” data for later fast access.

Let’s define the transformation that wraps the actual data and indices in a pair of pure functions:

@jax.jit
@unroll
def transform_dataset(step):
    dataset = tree_access_data(jnp_data, jnp_index, step)
    return EWMA(alpha=1.0 / 10.0, adjust=True)(dataset["dataarray"])

And we can call it as before:

%%time
outputs = transform_dataset(jxs)
_ = outputs.block_until_ready()
CPU times: user 1.72 s, sys: 1.06 s, total: 2.78 s
Wall time: 3.31 s
%%time
outputs = transform_dataset(jxs)
_ = outputs.block_until_ready()
CPU times: user 546 ms, sys: 52.8 ms, total: 599 ms
Wall time: 567 ms
outputs.device()
<jaxlib.xla_extension.Device at 0x17678ceb0>

Step (3) (format)

Let’s come back to pandas/xarray:

%%time
y = format_dataframe(
    dataset.coords, onp.array(outputs), format_dims=dataset.dataarray.dims
)
CPU times: user 27.9 ms, sys: 70.4 ms, total: 98.2 ms
Wall time: 144 ms

It’s quite slow (see the WEP3 enhancement proposal).

GPU execution

Let’s try executing on GPU:

try:
    gpus = jax.devices("gpu")
    jnp_data, jnp_index, jxs = tree_map(
        lambda x: jax.device_put(x, gpus[0]), (jnp_data, jnp_index, jxs)
    )
    print("data copied to GPU device.")
    GPU_AVAILABLE = True
except RuntimeError as err:
    print(err)
    GPU_AVAILABLE = False
Requested backend gpu, but it failed to initialize: Not found: Could not find registered platform with name: "cuda". Available platform names are: Interpreter Host

Let’s check which device our data is on:

tree_leaves(jnp_data)[0].device()
<jaxlib.xla_extension.Device at 0x17678ceb0>
tree_leaves(jnp_index)[0].device()
<jaxlib.xla_extension.Device at 0x17678ceb0>
jxs.device()
<jaxlib.xla_extension.Device at 0x17678ceb0>
%%time
if GPU_AVAILABLE:
    outputs = unroll(transform_dataset)(jxs)
CPU times: user 3 µs, sys: 2 µs, total: 5 µs
Wall time: 11.2 µs

Let’s redefine our function transform_dataset, explicitly passing the device option to jax.jit.

%%time
from functools import partial

if GPU_AVAILABLE:

    @partial(jax.jit, device=gpus[0])
    @unroll
    def transform_dataset(step):
        dataset = tree_access_data(jnp_data, jnp_index, step)
        return EWMA(alpha=1.0 / 10.0, adjust=True)(dataset["dataarray"])

    outputs = transform_dataset(jxs)
CPU times: user 8 µs, sys: 2 µs, total: 10 µs
Wall time: 15 µs
outputs.device()
<jaxlib.xla_extension.Device at 0x17678ceb0>
%%timeit
if GPU_AVAILABLE:
    outputs = unroll(transform_dataset)(jxs)
    _ = outputs.block_until_ready()
12 ns ± 0.0839 ns per loop (mean ± std. dev. of 7 runs, 100000000 loops each)
# Uncomment to run the notebook in Colab
# ! pip install -q "wax-ml[complete]@git+https://github.com/eserie/wax-ml.git"
# ! pip install -q --upgrade jax jaxlib==0.1.70+cuda111 -f https://storage.googleapis.com/jax-releases/jax_releases.html
%matplotlib inline
import io
import warnings
from collections import defaultdict
from pathlib import Path
from typing import Any, Callable, NamedTuple, Optional, TypeVar
import haiku as hk
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
import numpy as onp
import optax
import pandas as pd
import plotnine as gg
import requests
from sklearn.preprocessing import MinMaxScaler
from tqdm.auto import tqdm
from wax.accessors import register_wax_accessors
from wax.compile import jit_init_apply
from wax.encode import Encoder
from wax.modules import Buffer, FillNanInf, Lag, RollingMean
from wax.unroll import unroll
print("jax backend {}".format(jax.lib.xla_bridge.get_backend().platform))
jax.devices()
WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
jax backend cpu
[<jaxlib.xla_extension.Device at 0x14ade9f70>]

🔭 Reconstructing the light curve of stars with LSTM 🔭

Open in Colab

Let’s take a walk through the stars…

This notebook is based on the study done in this post by Christophe Pere and the notebook available on the author’s GitHub.

We will repeat this study on starlight using the LSTM architecture to predict the observed light flux through time.

Our LSTM implementation is based on this notebook from Haiku’s GitHub repository.

We’ll see how to use WAX-ML to ease the preparation of time-series data stored in dataframes that contain NaNs, before calling a “standard” deep-learning workflow.

Disclaimer

Although this code works with real data, the results presented here should not be considered scientific insights: to the knowledge of the WAX-ML authors, neither the results nor the data source have been reviewed by an astrophysics peer.

The purpose of this notebook is only to demonstrate how WAX-ML can be used when applying a “standard” machine learning workflow, here LSTM, to analyze time series.

Download the data

register_wax_accessors()
# Parameters
STAR = "007609553"
SEQ_LEN = 64
BATCH_SIZE = 8
TRAIN_SIZE = 2 ** 16
NUM_EPOCHS = 10
NUM_STARS = None
RECORD_FREQ = 100
TOTAL_LEN = None
TRAIN_DATE = "2016"
CACHE_DIR = Path("./cached_data/")
%%time
filename = CACHE_DIR / "kep_lightcurves.parquet"
try:
    raw_dataframe = pd.read_parquet(open(filename, "rb"))
    print(f"data read from {filename}")
except FileNotFoundError:
    # Download the csv file from Christophe Pere's GitHub account
    download = requests.get(
        "https://raw.github.com/Christophe-pere/Time_series_RNN/master/kep_lightcurves.csv"
    ).content
    raw_dataframe = pd.read_csv(io.StringIO(download.decode("utf-8")))
    # set date index
    raw_dataframe.index = pd.Index(
        pd.date_range("2009-03-07", periods=len(raw_dataframe.index), freq="h"),
        name="time",
    )
    # save dataframe locally in CACHE_DIR
    CACHE_DIR.mkdir(exist_ok=True)
    raw_dataframe.to_parquet(filename)
    print(f"data saved in {filename}")
data read from cached_data/kep_lightcurves.parquet
CPU times: user 55.3 ms, sys: 32.2 ms, total: 87.5 ms
Wall time: 40.8 ms
# shortening of data to speed up the execution of the notebook in the CI
if TOTAL_LEN:
    raw_dataframe = raw_dataframe.iloc[:TOTAL_LEN]

Let’s visualize the description of this dataset:

raw_dataframe.describe().T.to_xarray()
<xarray.Dataset>
Dimensions:  (index: 52)
Coordinates:
  * index    (index) object '001430305_orig' ... '011611275_res'
Data variables:
    count    (index) float64 6.48e+04 5.674e+04 ... 5.673e+04 5.673e+04
    mean     (index) float64 6.776e+04 -0.2265 0.01231 ... 0.001437 0.004351
    std      (index) float64 1.363e+03 15.42 15.27 12.45 ... 4.648 6.415 4.904
    min      (index) float64 6.529e+04 -123.3 -75.59 ... -20.32 -31.97 -20.89
    25%      (index) float64 6.619e+04 -9.488 -9.875 ... -3.269 -4.281 -3.279
    50%      (index) float64 6.806e+04 -0.3476 0.007812 ... 0.007812 -0.06529
    75%      (index) float64 6.882e+04 8.988 10.02 8.092 ... 2.872 4.277 3.213
    max      (index) float64 7.021e+04 128.7 72.31 69.34 ... 26.53 30.94 29.45
stars = raw_dataframe.columns
stars = sorted(list(set([i.split("_")[0] for i in stars])))
print(f"The number of stars available is: {len(stars)}")
print(f"star identifiers: {stars}")
The number of stars available is: 13
star identifiers: ['001430305', '001724719', '005209845', '007596240', '007609553', '008241079', '008247770', '009345933', '009347009', '009349482', '009349757', '010024701', '011611275']
dataframe = raw_dataframe[[i + "_rscl" for i in stars]].rename(
    columns=lambda c: c.replace("_rscl", "")
)
dataframe.columns.names = ["star"]
dataframe.shape
(71427, 13)
if NUM_STARS:
    columns = dataframe.columns.tolist()
    columns.remove(STAR)
    dataframe = dataframe[[STAR] + columns[: NUM_STARS - 1]]

Rolling mean

We will smooth the data by applying a rolling mean with a window of 100 periods.

Count nan values

But first, since the dataset has some NaN values, let’s extract a few statistics about the density of NaNs in windows of size 100.

This is also an opportunity to show a usage of the wax.modules.Buffer module with the format_outputs=False option of the dataframe accessor .wax.stream.

Let’s apply the Buffer module to the data:

buffer, _ = dataframe.wax.stream(format_outputs=False).apply(lambda x: Buffer(100)(x))
assert isinstance(buffer, jnp.ndarray)

Equivalently, we can use the WAX-ML unroll function:

buffer = unroll(lambda x: Buffer(100)(x))(jax.device_put(dataframe.values))

Let’s describe the statistics of NaNs with pandas:

count_nan = jnp.isnan(buffer).sum(axis=1)
pd.DataFrame(onp.array(count_nan)).stack().describe().astype(int)
count    928551
mean         20
std          27
min           0
25%           5
50%           8
75%          19
max         100
dtype: int64

Computing the rolling mean

We will choose min_periods=5 in order to keep at least 75% of the points.

%%time
dataframe_mean, _ = dataframe.wax.stream().apply(
    lambda x: RollingMean(100, min_periods=5)(x)
)
CPU times: user 262 ms, sys: 9.48 ms, total: 272 ms
Wall time: 270 ms
dataframe.iloc[:, :2].plot()
<AxesSubplot:xlabel='time'>
_images/05_reconstructing_the_light_curve_of_stars_31_1.png

Forecasting with Machine Learning

There are two kinds of gaps to forecast in this data: if you look closely, you’ll see micro holes and big holes.

T = TypeVar("T")
class Pair(NamedTuple):
    x: T
    y: T
class TrainSplit(NamedTuple):
    train: T
    validation: T
gg.theme_set(gg.theme_bw())
warnings.filterwarnings("ignore")
plt.rcParams["figure.figsize"] = 18, 8
fig, (ax, lax) = plt.subplots(ncols=2, gridspec_kw={"width_ratios": [4, 1]})
dataframe.plot(ax=ax, title="raw data")
ax.legend(bbox_to_anchor=(0, 0, 1, 1), bbox_transform=lax.transAxes)
lax.axis("off")
(0.0, 1.0, 0.0, 1.0)
_images/05_reconstructing_the_light_curve_of_stars_37_1.png
plt.rcParams["figure.figsize"] = 18, 8
fig, (ax, lax) = plt.subplots(ncols=2, gridspec_kw={"width_ratios": [4, 1]})
dataframe_mean.plot(ax=ax, title="Smoothed data")
ax.legend(bbox_to_anchor=(0, 0, 1, 1), bbox_transform=lax.transAxes)
lax.axis("off")
# -
(0.0, 1.0, 0.0, 1.0)
_images/05_reconstructing_the_light_curve_of_stars_38_1.png

Normalize data

dataframe_mean.stack().hist(bins=100, log=True)
<AxesSubplot:>
_images/05_reconstructing_the_light_curve_of_stars_40_1.png
def min_max_scaler(values: pd.DataFrame, output_format: str = "dataframe") -> Encoder:
    scaler = MinMaxScaler(feature_range=(0, 1))
    scaler.fit(values)
    index = values.index
    columns = values.columns

    def encode(dataframe: pd.DataFrame):
        nonlocal index
        nonlocal columns

        index = dataframe.index
        columns = dataframe.columns
        array_normed = scaler.transform(dataframe)

        if output_format == "dataframe":
            return pd.DataFrame(array_normed, index, columns)
        elif output_format == "jax":
            return jnp.array(array_normed)
        else:
            return array_normed

    def decode(array_scaled):

        value = scaler.inverse_transform(array_scaled)

        if output_format == "dataframe":
            return pd.DataFrame(value, index, columns)
        else:
            return value

    return Encoder(encode, decode)
scaler = min_max_scaler(dataframe_mean)
dataframe_normed = scaler.encode(dataframe_mean)
assert (scaler.decode(dataframe_normed) - dataframe_mean).stack().abs().max() < 1.0e-4
dataframe_normed.stack().hist(bins=100)
<AxesSubplot:>
_images/05_reconstructing_the_light_curve_of_stars_44_1.png

Prepare train / validation datasets

def split_feature_target(
    dataframe,
    look_back=SEQ_LEN,
    shuffle=True,
    stack=True,
    min_periods_ratio: float = 0.8,
    rng=None,
) -> Pair:
    # NB: this first prepare_xy is overridden by the second definition below.
    def prepare_xy(data):
        buffer = Buffer(look_back + 1)(data)
        x = buffer[:-1]
        y = buffer[-1]
        return x, y

    def prepare_xy(data):
        y = Buffer(look_back)(data)
        x = Lag(1)(y)
        return x, y

    x, y = unroll(prepare_xy)(jax.device_put(dataframe.values))

    if shuffle:
        if rng is None:
            rng = jax.random.PRNGKey(42)

        B = x.shape[0]
        idx = jnp.arange(B)
        idx = jax.random.shuffle(rng, idx)

        x = x[idx]
        y = y[idx]

    if stack:
        B, T, F = x.shape
        x = x.transpose(1, 0, 2).reshape(T, B * F, 1).transpose(1, 0, 2)
        y = y.transpose(1, 0, 2).reshape(T, B * F, 1).transpose(1, 0, 2)

    if min_periods_ratio:
        T = x.shape[1]
        count_nan = jnp.isnan(x).sum(axis=1)
        mask = count_nan < min_periods_ratio * T
        idx = jnp.where(mask)
        x = x[idx[0]]
        y = y[idx[0]]

    # round the batch size down to a power of two
    B = x.shape[0]
    B_round = int(2 ** jnp.floor(jnp.log2(B)))
    print(f"{B} batches rounded to {B_round} batches.")
    x = x[:B_round]
    y = y[:B_round]

    # fillnan by zeros
    x, y = hk.testing.transform_and_run(lambda x: FillNanInf()(x))((x, y))

    return Pair(x, y)
# split_feature_target(dataframe)
def split_train_validation(
    dataframe, train_size, look_back, scaler: Optional[Callable] = None
) -> TrainSplit:

    # prepare scaler
    train_df = dataframe.iloc[:train_size]

    if scaler:
        scaler = scaler(train_df)

    # prepare train data
    if scaler:
        train_df = scaler.encode(train_df)

    train_xy = split_feature_target(train_df, look_back)

    # prepare validation data
    valid_size = len(dataframe) - train_size
    valid_size = int(2 ** jnp.floor(jnp.log2(valid_size)))

    valid_end = int(train_size + valid_size)
    valid_df = dataframe.iloc[train_size:valid_end]

    if scaler:
        valid_df = scaler.encode(valid_df)

    valid_xy = split_feature_target(valid_df, look_back)

    return TrainSplit(train_xy, valid_xy)
TRAIN_SIZE
65536
print(f"Look at star: {STAR}")
train, valid = split_train_validation(dataframe_normed[[STAR]], TRAIN_SIZE, SEQ_LEN)
Look at star: 007609553
63871 batches rounded to 32768 batches.
3597 batches rounded to 2048 batches.
train[0].shape, train[1].shape, valid[0].shape, valid[1].shape
((32768, 64, 1), (32768, 64, 1), (2048, 64, 1), (2048, 64, 1))
# TRAIN_SIZE, VALID_SIZE = len(train.x), len(valid.x)
print(
    f"effective train_size = {len(train.x)}, " f"effective valid size= {len(valid.x)}"
)
effective train_size = 32768, effective valid size= 2048
# Plot an observation/target pair.
rng = jax.random.PRNGKey(42)
batch_plot = jax.random.choice(rng, len(train[0]))
df = pd.DataFrame(
    {"x": train.x[batch_plot, :, 0], "y": train.y[batch_plot, :, 0]}
).reset_index()
df = pd.melt(df, id_vars=["index"], value_vars=["x", "y"])
plot = (
    gg.ggplot(df)
    + gg.aes(x="index", y="value", color="variable")
    + gg.geom_line()
    + gg.scales.scale_y_log10()
)
_ = plot.draw()
_images/05_reconstructing_the_light_curve_of_stars_53_0.png

Dataset iterator

class Dataset:
    """An iterator over a numpy array, revealing batch_size elements at a time."""

    def __init__(self, xy: Pair, batch_size: int):
        self._x, self._y = xy
        self._batch_size = batch_size
        self._length = self._x.shape[0]
        self._idx = 0
        if self._length % batch_size != 0:
            msg = "dataset size {} must be divisible by batch_size {}."
            raise ValueError(msg.format(self._length, batch_size))

    def __next__(self) -> Pair:
        start = self._idx
        end = start + self._batch_size
        x, y = self._x[start:end], self._y[start:end]
        if end >= self._length:
            print(f"End of the data set (size={end}). Return to the beginning.")
            end = end % self._length
            assert end == 0  # Guaranteed by ctor assertion.
        self._idx = end
        return Pair(x, y)
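
For example, mirroring its use in the training section below (train and BATCH_SIZE are defined earlier in this notebook):

# Wrap the (x, y) pairs and pull one mini-batch at a time.
train_ds = Dataset(train, BATCH_SIZE)
x_batch, y_batch = next(train_ds)
assert x_batch.shape[0] == BATCH_SIZE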

Training an LSTM

To train the LSTM, we define a Haiku function which unrolls the LSTM over the input sequence, generating predictions for all output values. The LSTM always starts with its initial state at the start of the sequence.

The Haiku function is then transformed into a pure function through hk.transform, and is trained with Adam on an L2 prediction loss.

def unroll_net(seqs: jnp.ndarray):
    """Unrolls an LSTM over seqs, mapping each output to a scalar."""
    # seqs is [T, B, F].
    core = hk.LSTM(32)
    batch_size = seqs.shape[0]
    outs, state = hk.dynamic_unroll(
        core, seqs, core.initial_state(batch_size), time_major=False
    )
    # We could include this Linear as part of the recurrent core!
    # However, it's more efficient on modern accelerators to run the linear once
    # over the entire sequence than once per sequence element.
    return hk.BatchApply(hk.Linear(1))(outs), state
model = jit_init_apply(hk.transform(unroll_net))
@jax.jit
def loss(pred, y):
    return jnp.mean(jnp.square(pred - y))


def model_with_loss(x, y):
    pred, _ = unroll_net(x)
    return loss(pred, y)
class TrainState(NamedTuple):
    step: int
    params: Any
    opt_state: Any
    rng: jnp.ndarray
    loss: float


def train_model(
    model_with_loss: Callable,
    train_ds: Dataset,
    valid_ds: Dataset,
    max_iterations: int = -1,
    rng=None,
    record_freq=100,
) -> hk.Params:
    """Initializes and trains a model on train_ds, returning the final params."""
    opt = optax.adam(1e-3)
    model_with_loss = jit_init_apply(hk.transform(model_with_loss))

    @jax.jit
    def update(train_state, x, y):
        step, params, opt_state, rng, _ = train_state
        if rng is not None:
            (rng,) = jax.random.split(rng, 1)
        l, grads = jax.value_and_grad(model_with_loss.apply)(params, rng, x, y)
        grads, opt_state = opt.update(grads, opt_state)
        params = optax.apply_updates(params, grads)
        return TrainState(step + 1, params, opt_state, rng, l)

    # Initialize state.
    def init():
        x, y = next(train_ds)
        params = model_with_loss.init(rng, x, y)
        opt_state = opt.init(params)
        return TrainState(0, params, opt_state, rng, jnp.inf)

    def _format_results(records):
        records = {key: jnp.stack(l) for key, l in records.items()}
        return records

    records = defaultdict(list)
    train_state = init()
    with tqdm(total=max_iterations if max_iterations > 0 else None) as pbar:
        while True:
            try:
                x, y = next(train_ds)
            except StopIteration:
                return train_state, _format_results(records)

            train_state = update(train_state, x, y)
            if train_state.step % record_freq == 0:
                x, y = next(valid_ds)
                if rng is not None:
                    (rng,) = jax.random.split(rng, 1)
                valid_loss = model_with_loss.apply(train_state.params, rng, x, y)
                records["step"].append(train_state.step)
                records["valid_loss"].append(valid_loss)
                records["train_loss"].append(train_state.loss)

            pbar.update()
            if max_iterations > 0 and train_state.step >= max_iterations:
                return train_state, _format_results(records)
%%time
train, valid = split_train_validation(dataframe_normed[[STAR]], TRAIN_SIZE, SEQ_LEN)
train_ds = Dataset(train, BATCH_SIZE)
valid_ds = Dataset(valid, BATCH_SIZE)


train_state, records = train_model(
    model_with_loss,
    train_ds,
    valid_ds,
    len(train.x) // BATCH_SIZE * NUM_EPOCHS,
    rng=jax.random.PRNGKey(42),
    record_freq=RECORD_FREQ,
)
63871 batches rounded to 32768 batches.
3597 batches rounded to 2048 batches.
End of the data set (size=32768). Return to the beginning.
End of the data set (size=32768). Return to the beginning.
End of the data set (size=32768). Return to the beginning.
End of the data set (size=32768). Return to the beginning.
End of the data set (size=32768). Return to the beginning.
End of the data set (size=32768). Return to the beginning.
End of the data set (size=2048). Return to the beginning.
End of the data set (size=32768). Return to the beginning.
End of the data set (size=32768). Return to the beginning.
End of the data set (size=32768). Return to the beginning.
End of the data set (size=32768). Return to the beginning.
CPU times: user 58.6 s, sys: 371 ms, total: 59 s
Wall time: 58.8 s
# train_state.params
# Plot losses
losses = pd.DataFrame(records)
df = pd.melt(losses, id_vars=["step"], value_vars=["train_loss", "valid_loss"])
plot = (
    gg.ggplot(df)
    + gg.aes(x="step", y="value", color="variable")
    + gg.geom_line()
    + gg.scales.scale_y_log10()
)
_ = plot.draw()
_images/05_reconstructing_the_light_curve_of_stars_63_0.png

Sampling

The point of training models is so that they can make predictions! How can we generate predictions with the trained model?

If we’re allowed to feed in the ground truth, we can just run the original model’s apply function.

def plot_samples(truth: np.ndarray, prediction: np.ndarray) -> gg.ggplot:
    assert truth.shape == prediction.shape
    df = pd.DataFrame(
        {"truth": truth.squeeze(), "predicted": prediction.squeeze()}
    ).reset_index()
    df = pd.melt(df, id_vars=["index"], value_vars=["truth", "predicted"])
    plot = (
        gg.ggplot(df) + gg.aes(x="index", y="value", color="variable") + gg.geom_line()
    )
    return plot
# Grab a sample from the validation set.
sample_x, sample_y = next(valid_ds)
sample_x = sample_x[:1]  # Shrink to batch-size 1.
sample_y = sample_y[:1]

# Generate a prediction, feeding in ground truth at each point as input.
predicted, _ = model.apply(train_state.params, None, sample_x)

plot = plot_samples(sample_y, predicted)
plot.draw()
del sample_x, predicted
_images/05_reconstructing_the_light_curve_of_stars_66_0.png

Run autoregressively

If we can’t feed in the ground truth (because we don’t have it), we can also run the model autoregressively.

def autoregressive_predict(
    trained_params: hk.Params,
    context: jnp.ndarray,
    seq_len: int,
    pbar=False,
):
    """Given a context, autoregressively generate the rest of a sine wave."""

    ar_outs = []
    context = jax.device_put(context)
    times = onp.arange(seq_len - context.shape[1] + 1)
    if pbar:
        times = tqdm(times)
    for _ in times:
        full_context = jnp.concatenate([context] + ar_outs, axis=1)

        outs, _ = model.apply(trained_params, None, full_context)
        # Append the newest prediction to ar_outs.
        ar_outs.append(outs[:, -1:, :])
    # Return the final full prediction.
    return outs
sample_x, sample_y = next(valid_ds)
sample_x = sample_x[:1]  # Shrink to batch-size 1.
sample_y = sample_y[:1]  # Shrink to batch-size 1.


context_length = SEQ_LEN // 8
print(f"context_length = {context_length}")
# Cut the batch-size 1 context from the start of the sequence.
context = sample_x[:, :context_length]
context_length = 8
%%time
# We can reuse params we got from training for inference - as long as the
# declaration order is the same.
predicted = autoregressive_predict(train_state.params, context, SEQ_LEN, pbar=True)
CPU times: user 7.5 s, sys: 123 ms, total: 7.63 s
Wall time: 7.56 s
sample_y.shape, predicted.shape
((1, 64, 1), (1, 64, 1))
plot = plot_samples(sample_y, predicted)
plot += gg.geom_vline(xintercept=context.shape[1], linetype="dashed")
_ = plot.draw()
_images/05_reconstructing_the_light_curve_of_stars_73_0.png
Sharing parameters with a different function.

Unfortunately, this is a bit slow - we’re doing O(N^2) computation for a sequence of length N.

It’d be better if we could do the autoregressive sampling all at once - but we need to write a new Haiku function for that.

We’re in luck - if the Haiku module names match, the same parameters can be used for multiple Haiku functions.

This can be achieved through a combination of two techniques:

  1. If we manually give a unique name to a module, we can ensure that the parameters are directed to the right places.

  2. If modules are instantiated in the same order, they’ll have the same names in different functions.

Here, we rely on method #2 to create a fast autoregressive prediction.
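
For completeness, a hedged sketch of technique #1 (explicit module names; the notebook itself relies on technique #2):

# Naming the modules explicitly ensures that two different Haiku functions
# resolve to the same parameter entries ("lstm" and "linear" here).
def unroll_net_named(seqs: jnp.ndarray):
    core = hk.LSTM(32, name="lstm")
    outs, state = hk.dynamic_unroll(
        core, seqs, core.initial_state(seqs.shape[0]), time_major=False
    )
    return hk.BatchApply(hk.Linear(1, name="linear"))(outs), state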

@hk.transform
def fast_autoregressive_predict_fn(context, seq_len):
    """Given a context, autoregressively generate the rest of a sine wave."""
    core = hk.LSTM(32)
    dense = hk.Linear(1)
    state = core.initial_state(context.shape[0])
    # Unroll over the context using `hk.dynamic_unroll`.
    # As before, we `hk.BatchApply` the Linear for efficiency.
    context_outs, state = hk.dynamic_unroll(
        core,
        context,
        state,
        time_major=False,
    )
    context_outs = hk.BatchApply(dense)(context_outs)

    # Now, unroll one step at a time using the running recurrent state.
    ar_outs = []
    x = context_outs[:, -1, :]
    times = range(seq_len - context.shape[1])
    for _ in times:
        x, state = core(x, state)
        x = dense(x)
        ar_outs.append(x)
    ar_outs = jnp.stack(ar_outs)
    ar_outs = ar_outs.transpose(1, 0, 2)
    return jnp.concatenate([context_outs, ar_outs], axis=1)


fast_autoregressive_predict = jax.jit(
    fast_autoregressive_predict_fn.apply, static_argnums=(3,)
)
%%time
# Reuse the same context from the previous cell.
predicted = fast_autoregressive_predict(train_state.params, None, context, SEQ_LEN)
CPU times: user 24.6 s, sys: 54.8 ms, total: 24.6 s
Wall time: 24.6 s
# The plots should be equivalent!
plot = plot_samples(sample_y, predicted)
plot += gg.geom_vline(xintercept=context.shape[1], linetype="dashed")
_ = plot.draw()
_images/05_reconstructing_the_light_curve_of_stars_77_0.png

Sample trajectories

sample_x, sample_y = next(valid_ds)
sample_x = sample_x[:1]  # Shrink to batch-size 1.
sample_y = sample_y[:1]  # Shrink to batch-size 1.


context_length = SEQ_LEN // 8
print(f"context_length = {context_length}")
# Cut the batch-size 1 context from the start of the sequence.
context = sample_x[:, :context_length]

# Reuse the same context from the previous cell.
predicted = fast_autoregressive_predict(train_state.params, None, context, SEQ_LEN)

# The plots should be equivalent!
plot = plot_samples(sample_y, predicted)
plot += gg.geom_vline(xintercept=context.shape[1], linetype="dashed")
_ = plot.draw()
context_length = 8
_images/05_reconstructing_the_light_curve_of_stars_79_1.png

timeit

%timeit autoregressive_predict(train_state.params, context, SEQ_LEN)
%timeit fast_autoregressive_predict(train_state.params, None, context, SEQ_LEN)
32.4 ms ± 331 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
25 µs ± 90.7 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)

Train all stars

Training

def split_train_validation_date(dataframe, date, look_back) -> TrainSplit:
    train_size = len(dataframe.loc[:date])
    return split_train_validation(dataframe, train_size, look_back)
%%time
train, valid = split_train_validation_date(dataframe_normed, TRAIN_DATE, SEQ_LEN)
print(f"effective train size = {train[0].shape[1]}")
838194 batches rounded to 524288 batches.
26455 batches rounded to 16384 batches.
effective train size = 64
CPU times: user 2.65 s, sys: 743 ms, total: 3.39 s
Wall time: 2.82 s
train[0].shape, train[1].shape, valid[0].shape, valid[1].shape
((524288, 64, 1), (524288, 64, 1), (16384, 64, 1), (16384, 64, 1))
train_ds = Dataset(train, BATCH_SIZE)
valid_ds = Dataset(valid, BATCH_SIZE)
# del train, valid  # Don't leak temporaries.
%%time
train_state, records = train_model(
    model_with_loss,
    train_ds,
    valid_ds,
    len(train.x) // BATCH_SIZE * 1,
    jax.random.PRNGKey(42),
    record_freq=RECORD_FREQ,
)
End of the data set (size=524288). Return to the beginning.
CPU times: user 1min 30s, sys: 408 ms, total: 1min 30s
Wall time: 1min 30s
# Plot losses
losses = pd.DataFrame(records)
df = pd.melt(losses, id_vars=["step"], value_vars=["train_loss", "valid_loss"])
plot = (
    gg.ggplot(df)
    + gg.aes(x="step", y="value", color="variable")
    + gg.geom_line()
    + gg.scales.scale_y_log10()
)
_ = plot.draw()
_images/05_reconstructing_the_light_curve_of_stars_89_0.png

Sampling

# Grab a sample from the validation set.
sample_x, sample_y = next(valid_ds)
sample_x = sample_x[:1]  # Shrink to batch-size 1.
sample_y = sample_y[:1]  # Shrink to batch-size 1.


# Generate a prediction, feeding in ground truth at each point as input.
predicted, _ = model.apply(train_state.params, None, sample_x)

plot = plot_samples(sample_y, predicted)
_ = plot.draw()
_images/05_reconstructing_the_light_curve_of_stars_91_0.png

Run autoregressively

%%time
sample_x, sample_y = next(valid_ds)
sample_x = sample_x[:1]  # Shrink to batch-size 1.
sample_y = sample_y[:1]  # Shrink to batch-size 1.


context_length = SEQ_LEN // 8
# Cut the batch-size 1 context from the start of the sequence.
context = sample_x[:, :context_length]

# Reuse the same context from the previous cell.
predicted = fast_autoregressive_predict(train_state.params, None, context, SEQ_LEN)

# The plots should be equivalent!
plot = plot_samples(sample_y, predicted)
plot += gg.geom_vline(xintercept=len(context), linetype="dashed")
_ = plot.draw()
CPU times: user 64.6 ms, sys: 2.56 ms, total: 67.2 ms
Wall time: 65.8 ms
_images/05_reconstructing_the_light_curve_of_stars_93_1.png
# Uncomment to run the notebook in Colab
# ! pip install -q "wax-ml[complete]@git+https://github.com/eserie/wax-ml.git"
# ! pip install -q --upgrade jax jaxlib==0.1.70+cuda111 -f https://storage.googleapis.com/jax-releases/jax_releases.html
# check available devices
import jax
print("jax backend {}".format(jax.lib.xla_bridge.get_backend().platform))
jax.devices()
WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
jax backend cpu
[CpuDevice(id=0)]

🦎 Online linear regression with a non-stationary environment 🦎

Open in Colab

We implement an online learning non-stationary linear regression problem.

We get there progressively, first showing how a linear regression problem can be cast into an online learning problem thanks to the OnlineSupervisedLearner module.

Then, to tackle a non-stationary linear regression problem (i.e. with weights that can vary in time), we reformulate it as a reinforcement learning problem that we implement with the GymFeedback module of WAX-ML.

We then need to define an “agent” and an “environment” using simple functions implemented with modules:

  • The agent is responsible for learning the weights of its internal linear model.

  • The environment is responsible for generating labels and evaluating the agent’s reward metric.

We experiment with a non-stationary environment that reverses the sign of the linear regression parameters at a given time step, known only to the environment.

This example shows that it is quite simple to implement this online-learning task with WAX-ML tools. In particular, the functional workflow adopted here allows the functions implemented for one task to be reused for each new task of increasing complexity.

In this journey, we will use:

  • Haiku basic linear module hk.Linear.

  • Optax stochastic gradient descent optimizer: sgd.

  • WAX-ML modules: OnlineSupervisedLearner, Lag, GymFeedback.

  • WAX-ML helper functions: unroll, jit_init_apply

%pylab inline
Populating the interactive namespace from numpy and matplotlib
import haiku as hk
import jax
import jax.numpy as jnp
import optax
from matplotlib import pyplot as plt
from wax.compile import jit_init_apply
from wax.modules import OnlineSupervisedLearner

Static Linear Regression

First, let’s implement a simple linear regression.

Generate data

Let’s generate a batch of data:

seq = hk.PRNGSequence(42)
X = jax.random.normal(next(seq), (100, 3))
w_true = jnp.ones(3)

Define the model

We use the basic module hk.Linear, which is a linear layer. By default, it initializes the weights with random values drawn from a truncated normal with standard deviation \(1 / \sqrt{N}\), where \(N\) is the size of the inputs (see https://arxiv.org/abs/1502.03167v3).

@jit_init_apply
@hk.transform_with_state
def linear_model(x):
    return hk.Linear(output_size=1, with_bias=False)(x)

Run the model

Let’s run the model using WAX-ML unroll on the batch of data.

from wax.unroll import unroll
params, state = linear_model.init(next(seq), X[0])
linear_model.apply(params, state, None, X[0])
(DeviceArray([-0.2070887], dtype=float32), FlatMapping({}))
Y_pred = unroll(linear_model, rng=next(seq))(X)
Y_pred.shape
(100, 1)

Check cost

Let’s look at the mean squared error for this non-trained model.

noise = jax.random.normal(next(seq), (100,))
Y = X.dot(w_true) + noise
L = ((Y[:, None] - Y_pred) ** 2).sum(axis=1)  # align shapes: Y is (100,), Y_pred is (100, 1)
mean_loss = L.mean()
assert mean_loss > 0

Let’s look at the regret (cumulative sum of the loss) for the non-trained model.

plt.plot(L.cumsum())
plt.title("Regret")
Text(0.5, 1.0, 'Regret')
_images/06_Online_Linear_Regression_26_1.png

As expected, we have a linear regret when we did not train the model!

Online Linear Regression

We will now start training the model online. For a review of online learning methods, see [1].

[1] Elad Hazan, Introduction to Online Convex Optimization

Define an optimizer

opt = optax.sgd(1e-3)

Define a loss

Since we are doing online learning, we need to define a local loss function:

\[
\ell_t(y, w, x) = \lVert y_t - w \cdot x_t \rVert^2
\]

@jax.jit
def loss(y_pred, y):
    return jnp.mean(jnp.square(y_pred - y))

Define a learning strategy

@jit_init_apply
@hk.transform_with_state
def learner(x, y):
    return OnlineSupervisedLearner(linear_model, opt, loss)(x, y)

Generate data

def generate_many_observations(T=300, sigma=1.0e-2, rng=None):
    rng = jax.random.PRNGKey(42) if rng is None else rng
    X = jax.random.normal(rng, (T, 3))
    noise = sigma * jax.random.normal(rng, (T,))
    w_true = jnp.ones(3)
    Y = X.dot(w_true) + noise
    return (X, Y)
T = 3000
X, Y = generate_many_observations(T)

Unroll the learner

(output, info) = unroll(learner, rng=next(seq))(X, Y)

Plot the regret

Let’s look at the loss and regret over time.

fig, axs = plt.subplots(1, 2, figsize=(9, 3))
axs[0].plot(info.loss.cumsum())
axs[0].set_title("Regret")

axs[1].plot(info.params["linear"]["w"][:, 0, 0])
axs[1].set_title("Weight[0,0]")
Text(0.5, 1.0, 'Weight[0,0]')
_images/06_Online_Linear_Regression_44_1.png

We have sub-linear regret!

Online learning with Gym

Now we will recast the online linear regression learning task as a reinforcement learning task implemented with the GymFeedback module of WAX-ML.

For that, we define:

  • observations (obs): pairs (x, y) of features and labels

  • raw observations (raw_obs): pairs (x, noise) of features and noise.

Linear regression agent

In WAX-ML, an agent is a simple function with the following API:

[diagram: agent API]

Let’s define a simple linear regression agent with the elements we have defined so far.

def linear_regression_agent(obs):
    x, y = obs

    @jit_init_apply
    @hk.transform_with_state
    def model(x):
        return hk.Linear(output_size=1, with_bias=False)(x)

    opt = optax.sgd(1e-3)

    @jax.jit
    def loss(y_pred, y):
        return jnp.mean(jnp.square(y_pred - y))

    return OnlineSupervisedLearner(model, opt, loss)(x, y)

Linear regression environment

In WAX-ML, an environment is a simple function with the following API:

[diagram: environment API]

Let’s now define a linear regression environment that, for the moment, has static weights.

It is responsible for generating the real labels and evaluating the agent’s reward.

To evaluate the reward, we need the Lag module to compare the agent’s action with the labels generated at the previous time step.

from wax.modules import Lag
def stationary_linear_regression_env(action, raw_obs):

    # Only the environment knows the true value of the parameters
    w_true = -jnp.ones(3)

    # The environment has its own loss definition
    @jax.jit
    def loss(y_pred, y):
        return jnp.mean(jnp.square(y_pred - y))

    # raw observation contains features and generative noise
    x, noise = raw_obs

    # generate targets
    y = x @ w_true + noise
    obs = (x, y)

    y_previous = Lag(1)(y)
    # evaluate the prediction made by the agent
    y_pred = action
    reward = loss(y_pred, y_previous)

    return reward, obs, {}

Generate raw observation

Let’s define a function that generates the raw observations:

def generate_many_raw_observations(T=300, sigma=1.0e-2, rng=None):
    rng = jax.random.PRNGKey(42) if rng is None else rng
    X = jax.random.normal(rng, (T, 3))
    noise = sigma * jax.random.normal(rng, (T,))
    return (X, noise)

Implement Feedback

We are now ready to set things up with the GymFeedback module implemented in WAX-ML.

It implements the following feedback loop:

[diagram: GymFeedback feedback loop]

Equivalently, it can be described with the pair of init and apply functions:

[diagram: GymFeedback init and apply functions]
from wax.modules import GymFeedback
@hk.transform_with_state
def gym_fun(raw_obs):
    return GymFeedback(
        linear_regression_agent, stationary_linear_regression_env, return_action=True
    )(raw_obs)

And now we can unroll it on a sequence of raw observations!

seq = hk.PRNGSequence(42)
T = 3000
raw_observations = generate_many_raw_observations(T)
rng = next(seq)
(gym_output, gym_info) = unroll(gym_fun, rng=rng, skip_first=True)(raw_observations)

Let’s visualize the outputs.

We now use pd.Series to represent the reward sequence since its first value is NaN due to the use of the lag operator.

import pandas as pd
fig, axs = plt.subplots(1, 2, figsize=(9, 3))
pd.Series(gym_output.reward).cumsum().plot(ax=axs[0], title="Regret")
axs[1].plot(gym_info.agent.params["linear"]["w"][:, 0, 0])
axs[1].set_title("Weight[0,0]")
Text(0.5, 1.0, 'Weight[0,0]')
_images/06_Online_Linear_Regression_67_1.png

Non-stationary environment

Now, let’s implement a non-stationary environment.

We implement it so that the sign of the weights is reversed after 2000 steps.

class NonStationaryEnvironment(hk.Module):
    def __call__(self, action, raw_obs):

        step = hk.get_state("step", [], init=lambda *_: 0)

        # Only the environment knows the true value of the parameters:
        # at step 2000 we flip the sign of the true parameters!
        w_true = hk.cond(
            step < 2000,
            step,
            lambda step: -jnp.ones(3),
            step,
            lambda step: jnp.ones(3),
        )

        # The environment has its own loss definition
        @jax.jit
        def loss(y_pred, y):
            return jnp.mean(jnp.square(y_pred - y))

        # raw observation contains features and generative noise
        x, noise = raw_obs

        # generate targets
        y = x @ w_true + noise
        obs = (x, y)

        # evaluate the prediction made by the agent
        y_previous = Lag(1)(y)
        y_pred = action
        reward = loss(y_pred, y_previous)

        step += 1
        hk.set_state("step", step)

        return reward, obs, {}

Now let’s run a gym simulation to see how the agent adapts to the change of environment.

@hk.transform_with_state
def gym_fun(raw_obs):
    return GymFeedback(
        linear_regression_agent, NonStationaryEnvironment(), return_action=True
    )(raw_obs)
T = 6000
raw_observations = generate_many_raw_observations(T)
rng = jax.random.PRNGKey(42)
(gym_output, gym_info) = unroll(gym_fun, rng=rng, skip_first=True)(
    raw_observations,
)
import pandas as pd
fig, axs = plt.subplots(1, 2, figsize=(9, 3))
pd.Series(gym_output.reward).cumsum().plot(ax=axs[0], title="Regret")
axs[1].plot(gym_info.agent.params["linear"]["w"][:, 0, 0])
axs[1].set_title("Weight[0,0]")
# plt.savefig("../_static/online_linear_regression_regret.png")
Text(0.5, 1.0, 'Weight[0,0]')
_images/06_Online_Linear_Regression_75_1.png

It adapts!

The regret first converges, then jumps at step 2000, and finally readjusts to the new regime.

We see that the weights converge to the correct values in both regimes.

🔄 Online learning for time series prediction 🔄

In [1], the authors develop an online learning method to predict time series generated by an ARMA (autoregressive moving average) model.

They develop an effective online learning algorithm based on an improper learning approach, which consists in using an AR model with a sufficiently long horizon for prediction, together with an online update of the prediction model parameters using either an “online Newton” algorithm [2] or a stochastic gradient descent algorithm.

This approach addresses the prediction problem without assuming that the noise terms are Gaussian, identically distributed, or even independent. Furthermore, they show that their algorithm’s performance asymptotically approaches that of the best ARMA model in hindsight.
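
Schematically (our notation, following [1]), the predictor is a long auto-regressive filter whose weights are updated online on the squared loss:

\(\hat{y}_t = \sum_{i=1}^{m} w_i^{(t)}\, y_{t-i}, \qquad \ell_t(w) = (y_t - \hat{y}_t)^2, \qquad w^{(t+1)} = \Pi\big(w^{(t)} - \eta\, \nabla \ell_t(w^{(t)})\big),\)

where \(\Pi\) denotes a projection onto the feasible set and the gradient step can be replaced by an online Newton step.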

We use WAX-ML to reproduce their empirical results.

We first focus on reproducing “setting 1” as a sanity check and show how to set up a training environment with WAX-ML to study improper learning of ARMA time-series models.

We then study the behavior of the method with different optimizers in the non-stationary environments proposed in [1] (settings 2, 3, and 4).

We use the following modules from WAX-ML:

  • ARMA : to generate a modeled time-series

  • SNARIMAX : to adaptively learn to predict the generated time-series.

  • GymFeedback: to set up a training loop.

  • VMap: to add batch dimensions to the training loop

  • optim.newton: a Newton algorithm as used in [1] and developed in [2]. It extends optax optimizers.

  • OnlineOptimizer: a wrapper combining a model, a loss, and an optimizer for online learning.

References

[1] Anava, O., Hazan, E., Mannor, S. and Shamir, O., 2013, June. Online learning for time series prediction. In Conference on learning theory (pp. 172-184)

[2] Hazan, E., Agarwal, A. and Kale, S., 2007. Logarithmic regret algorithms for online convex optimization. Machine Learning, 69(2-3), pp.169-192

%pylab inline
%load_ext autoreload
%autoreload 2
Populating the interactive namespace from numpy and matplotlib
from typing import Any, NamedTuple

import haiku as hk
import jax
import jax.numpy as jnp
import numpy as onp
import optax
import pandas as pd
import seaborn as sns
from matplotlib import pyplot as plt
from tqdm.auto import tqdm

from wax.modules import (
    ARMA,
    SNARIMAX,
    GymFeedback,
    Lag,
    OnlineOptimizer,
    UpdateParams,
    VMap,
)
from wax.optim import newton
from wax.unroll import unroll_transform_with_state
T = 10000
N_BATCH = 20
N_STEP_SIZE = 10
N_EPS = 5

ARMA

Let’s generate a sample of the “setting 1” of [1]:

alpha = jnp.array([0.6, -0.5, 0.4, -0.4, 0.3])
beta = jnp.array([0.3, -0.2])

rng = jax.random.PRNGKey(42)
eps = jax.random.normal(rng, (1000,))
sim = unroll_transform_with_state(lambda eps: ARMA(alpha, beta)(eps))
params, state = sim.init(rng, eps)
y, state = sim.apply(params, state, rng, eps)
pd.Series(y).plot()
_images/07_Online_Time_Series_Prediction_6_0.png
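
The generated process follows the standard ARMA recursion (our notation; we assume the ARMA module implements this recursion with the coefficients passed above):

\(y_t = \sum_{i=1}^{5} \alpha_i\, y_{t-i} + \sum_{j=1}^{2} \beta_j\, \varepsilon_{t-j} + \varepsilon_t, \qquad \varepsilon_t \sim \mathcal{N}(0, 1).\)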

SNARIMAX

Let’s set up an online model to try to learn the dynamics of the time series.

First, let’s run the filter with its initial random weights.

def predict(y, X=None):
    return SNARIMAX(10, 0, 0)(y, X)


sim = unroll_transform_with_state(predict)
rng = jax.random.PRNGKey(42)
eps = jax.random.normal(rng, (10,))
params, state = sim.init(rng, y)
(y_pred, _), state = sim.apply(params, state, rng, y)

pd.Series((y - y_pred)).plot()
_images/07_Online_Time_Series_Prediction_10_0.png
def evaluate(y_pred, y):
    return jnp.linalg.norm(y_pred - y) ** 2, {}


def lag(shift=1):
    def __call__(y, X=None):
        yp = Lag(shift)(y)
        Xp = Lag(shift)(X) if X is not None else None
        return yp, Xp

    return __call__


def predict_and_evaluate(y, X=None):
    # predict with lagged data
    y_pred, pred_info = predict(*lag(1)(y, X))

    # evaluate loss with actual data
    loss, loss_info = evaluate(y_pred, y)

    return loss, dict(pred_info=pred_info, loss_info=loss_info)
sim = unroll_transform_with_state(predict_and_evaluate)
rng = jax.random.PRNGKey(42)
eps = jax.random.normal(rng, (1000,))
params, state = sim.init(rng, y)
(loss, _), state = sim.apply(params, state, rng, y)

pd.Series(loss).expanding().mean().plot()
_images/07_Online_Time_Series_Prediction_12_0.png

Since the model is not trained and the coefficients of the SNARIMAX filter are chosen randomly, the loss may diverge.

params
FlatMapping({
  'snarimax/~/linear': FlatMapping({
                         'w': DeviceArray([[-0.47023755],
                                           [ 0.07070494],
                                           [-0.07388116],
                                           [ 0.13453043],
                                           [ 0.07728617],
                                           [ 0.0851969 ],
                                           [-0.03324771],
                                           [ 0.17115903],
                                           [ 0.1023274 ],
                                           [-0.4804019 ]], dtype=float32),
                         'b': DeviceArray([0.], dtype=float32),
                       }),
})

Learn

To learn the model parameters, we will use the OnlineOptimizer module of WAX-ML.
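
Conceptually, at each time step such an online optimizer evaluates the loss on the newly observed data, applies one optimizer update, and optionally projects the parameters. As a rough sketch (simplified, assumed signatures; not the actual OnlineOptimizer code), a single step built on optax could look like:

import jax
import jax.numpy as jnp
import optax


def online_step(loss_fn, opt, params, opt_state, y, X=None):
    # loss_fn(params, y, X) is assumed to return a pair (loss, aux_info)
    (loss, info), grads = jax.value_and_grad(loss_fn, has_aux=True)(params, y, X)
    # one optimizer update based on the current observation
    updates, opt_state = opt.update(grads, opt_state, params)
    params = optax.apply_updates(params, updates)
    # optional projection of the updated parameters, e.g. clipping to [-1, 1]
    params = jax.tree_map(lambda w: jnp.clip(w, -1, 1), params)
    return params, opt_state, loss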

Setup projection

We can set up a projection for the parameters:

def project_params(params, opt_state=None):
    w = params["snarimax/~/linear"]["w"]
    w = jnp.clip(w, -1, 1)
    params["snarimax/~/linear"]["w"] = w
    return params
project_params(hk.data_structures.to_mutable_dict(params))
{'snarimax/~/linear': {'w': DeviceArray([[-0.47023755],
               [ 0.07070494],
               [-0.07388116],
               [ 0.13453043],
               [ 0.07728617],
               [ 0.0851969 ],
               [-0.03324771],
               [ 0.17115903],
               [ 0.1023274 ],
               [-0.4804019 ]], dtype=float32),
  'b': DeviceArray([0.], dtype=float32)}}
def learn(y, X=None):
    optim_res = OnlineOptimizer(
        predict_and_evaluate, optax.sgd(1.0e-2), project_params=project_params
    )(y, X)
    return optim_res

Let’s train:

sim = unroll_transform_with_state(learn)
rng = jax.random.PRNGKey(42)
eps = jax.random.normal(rng, (1000,))
params, state = sim.init(rng, eps)
optim_res, state = sim.apply(params, state, rng, eps)

pd.Series(optim_res.loss).expanding().mean().plot()
_images/07_Online_Time_Series_Prediction_23_0.png

Let’s look at the latest weights:

jax.tree_map(lambda x: x[-1], optim_res.updated_params)
FlatMapping({
  'snarimax/~/linear': FlatMapping({
                         'b': DeviceArray([-0.14335798], dtype=float32),
                         'w': DeviceArray([[-0.06957585],
                                           [ 0.11199051],
                                           [-0.14902674],
                                           [-0.02122167],
                                           [-0.03742805],
                                           [ 0.00436568],
                                           [-0.10548005],
                                           [-0.05702855],
                                           [-0.01185708],
                                           [-0.01168777]], dtype=float32),
                       }),
})

Learn and Forecast

class ForecastInfo(NamedTuple):
    optim: Any
    forecast: Any
def learn_and_forecast(y, X=None):
    optim_res = OnlineOptimizer(
        predict_and_evaluate, optax.sgd(1.0e-3), project_params=project_params
    )(*lag(1)(y, X))

    predict_params = optim_res.updated_params

    forecast, forecast_info = UpdateParams(predict)(predict_params, y, X)
    return forecast, ForecastInfo(optim_res, forecast_info)
sim = unroll_transform_with_state(learn_and_forecast)
rng = jax.random.PRNGKey(42)
eps = jax.random.normal(rng, (1000,))
params, state = sim.init(rng, y)
(forecast, info), state = sim.apply(params, state, rng, y)

pd.Series(info.optim.loss).expanding().mean().plot()
<AxesSubplot:>
_images/07_Online_Time_Series_Prediction_29_1.png

Gym simulation

Now let’s wrap up the training loop in a gym feedback loop.

Environment

Let’s build an environment corresponding to “setting 1” of [1]:

def build_env():
    def env(action, obs):
        y_pred, eps = action, obs
        ar_coefs = jnp.array([0.6, -0.5, 0.4, -0.4, 0.3])
        ma_coefs = jnp.array([0.3, -0.2])

        y = ARMA(ar_coefs, ma_coefs)(eps)
        # prediction used on a fresh y observation.
        rw = -((y - y_pred) ** 2)

        env_info = {"y": y, "y_pred": y_pred}
        obs = y
        return rw, obs, env_info

    return env
env = build_env()

Agent

Let’s build an agent:

from optax._src.base import OptState
def build_agent(time_series_model=None, opt=None):
    if time_series_model is None:
        time_series_model = lambda y, X: SNARIMAX(10)(y, X)

    if opt is None:
        opt = optax.sgd(1.0e-3)

    class AgentInfo(NamedTuple):
        optim: Any
        forecast: Any

    class ModelWithLossInfo(NamedTuple):
        pred: Any
        loss: Any

    def agent(obs):
        if isinstance(obs, tuple):
            y, X = obs
        else:
            y = obs
            X = None

        def evaluate(y_pred, y):
            return jnp.linalg.norm(y_pred - y) ** 2, {}

        def model_with_loss(y, X=None):
            # predict with lagged data
            y_pred, pred_info = time_series_model(*lag(1)(y, X))

            # evaluate loss with actual data
            loss, loss_info = evaluate(y_pred, y)

            return loss, ModelWithLossInfo(pred_info, loss_info)

        def project_params(params: Any, opt_state: OptState = None):
            del opt_state
            return jax.tree_map(lambda w: jnp.clip(w, -1, 1), params)

        def split_params(params):
            def filter_params(m, n, p):
                # print(m, n, p)
                return m.endswith("snarimax/~/linear") and n == "w"

            return hk.data_structures.partition(filter_params, params)

        def learn_and_forecast(y, X=None):
            optim_res = OnlineOptimizer(
                model_with_loss,
                opt,
                project_params=project_params,
                split_params=split_params,
            )(*lag(1)(y, X))

            predict_params = optim_res.updated_params

            y_pred, forecast_info = UpdateParams(time_series_model)(
                predict_params, y, X
            )
            return y_pred, AgentInfo(optim_res, forecast_info)

        return learn_and_forecast(y, X)

    return agent
agent = build_agent()

Gym loop

def gym_loop(eps):
    return GymFeedback(agent, env)(eps)
rng = jax.random.PRNGKey(42)
eps = jax.random.normal(rng, (T,))
sim = unroll_transform_with_state(gym_loop)
params, state = sim.init(rng, eps)
(gym, info), state = sim.apply(params, state, rng, eps)
pd.Series(info.agent.optim.loss).expanding().mean().plot(label="agent loss")
pd.Series(-gym.reward).expanding().mean().plot(xlim=(0, 100), label="env loss")
plt.legend()
<matplotlib.legend.Legend at 0x167c42490>
_images/07_Online_Time_Series_Prediction_43_1.png
pd.Series(-gym.reward).expanding().mean().plot()  # ylim=(0.09, 0.15))
<AxesSubplot:>
_images/07_Online_Time_Series_Prediction_44_1.png

We see that the agent suffers the same loss as the environment, but with a time lag: the agent’s optimizer evaluates its loss on lagged observations, whereas the environment evaluates the prediction on the fresh observation.

Batch simulations

Average over 20 experiments

Slow version

First, let’s do it “naively” with a simple Python for loop.

%%time
rng = jax.random.PRNGKey(42)
sim = unroll_transform_with_state(gym_loop)

res = {}
for i in tqdm(onp.arange(N_BATCH)):
    rng, _ = jax.random.split(rng)
    eps = jax.random.normal(rng, (T,)) * 0.3
    params, state = sim.init(rng, eps)
    (gym_output, gym_info), final_state = sim.apply(params, state, rng, eps)
    res[i] = gym_info
CPU times: user 13.7 s, sys: 182 ms, total: 13.9 s
Wall time: 13.8 s
pd.DataFrame({k: pd.Series(v.agent.optim.loss) for k, v in res.items()}).mean(
    1
).expanding().mean().plot(ylim=(0.09, 0.15))
_images/07_Online_Time_Series_Prediction_50_0.png

Fast version with vmap

Instead of using a for loop, we can use JAX’s vmap transformation!

%%time
rng = jax.random.PRNGKey(42)
eps = jax.random.normal(rng, (N_BATCH, T)) * 0.3

rng = jax.random.PRNGKey(42)
rng = jax.random.split(rng, num=N_BATCH)
sim = unroll_transform_with_state(gym_loop)
params, state = jax.vmap(sim.init)(rng, eps)
(gym_output, gym_info), final_state = jax.vmap(sim.apply)(params, state, rng, eps)
CPU times: user 2.09 s, sys: 61 ms, total: 2.15 s
Wall time: 2.13 s

This is much faster: vmap vectorizes the computation over the batch instead of repeating the whole simulation in a Python loop.

pd.DataFrame(gym_info.agent.optim.loss).mean().expanding().mean().plot(
    ylim=(0.09, 0.15)
)
_images/07_Online_Time_Series_Prediction_54_0.png
i_batch = 0
w = gym_info.agent.optim.updated_params["snarimax/~/linear"]["w"][i_batch, :, :, 0]
w = pd.DataFrame(w)

ax = w.plot(title=f"weights on batch {i_batch}")
ax.legend(bbox_to_anchor=(1.0, 1.0))
ax.set_xlabel("time")
ax.set_ylabel("weight value")


plt.figure()
w.iloc[-1][::-1].plot(kind="bar")
_images/07_Online_Time_Series_Prediction_55_0.png _images/07_Online_Time_Series_Prediction_55_1.png
w = gym_info.agent.optim.updated_params["snarimax/~/linear"]["w"].mean(axis=0)[:, :, 0]
w = pd.DataFrame(w)
ax = w.plot(title="averaged weights (over batches)")
ax.legend(bbox_to_anchor=(1.0, 1.0))
ax.set_xlabel("time")
ax.set_ylabel("weight")

plt.figure()
w.iloc[-1][::-1].plot(kind="bar")
_images/07_Online_Time_Series_Prediction_56_0.png _images/07_Online_Time_Series_Prediction_56_1.png

With VMap module

We can use the wrapper module VMap of WAX-ML. It permits an even simpler syntax.

Note: we have to swap the time and batch dimensions when generating the noise variable eps.
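
Concretely, the two batching approaches only differ in the axis order of the noise (illustrative variable names):

eps_for_jax_vmap = jax.random.normal(rng, (N_BATCH, T))  # jax.vmap over sim.init/sim.apply: batch axis first
eps_for_wax_vmap = jax.random.normal(rng, (T, N_BATCH))  # VMap module inside the unrolled function: time axis first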

%%time
rng = jax.random.PRNGKey(42)
eps = jax.random.normal(rng, (T, N_BATCH)) * 0.3


def batched_gym_loop(eps):
    return VMap(gym_loop)(eps)


sim = unroll_transform_with_state(batched_gym_loop)

rng = jax.random.PRNGKey(43)
params, state = sim.init(rng, eps)
(gym_output, gym_info), final_state = sim.apply(params, state, rng, eps)
CPU times: user 1.53 s, sys: 15.5 ms, total: 1.54 s
Wall time: 1.53 s
pd.DataFrame(gym_info.agent.optim.loss).shape
(10000, 20)
pd.DataFrame(gym_info.agent.optim.loss).mean(1).expanding().mean().plot(
    ylim=(0.09, 0.15)
)
_images/07_Online_Time_Series_Prediction_61_0.png
i_batch = 0
w = gym_info.agent.optim.updated_params["snarimax/~/linear"]["w"][:, i_batch, :, 0]
w = pd.DataFrame(w)

ax = w.plot(title=f"weights on batch {i_batch}")
ax.legend(bbox_to_anchor=(1.0, 1.0))
ax.set_xlabel("time")
ax.set_ylabel("weight value")


plt.figure()
w.iloc[-1][::-1].plot(kind="bar")
_images/07_Online_Time_Series_Prediction_62_0.png _images/07_Online_Time_Series_Prediction_62_1.png
w = gym_info.agent.optim.updated_params["snarimax/~/linear"]["w"].mean(axis=1)[:, :, 0]
w = pd.DataFrame(w)
ax = w.plot(title="averaged weights (over batches)")
ax.legend(bbox_to_anchor=(1.0, 1.0))
ax.set_xlabel("time")
ax.set_ylabel("weight")

plt.figure()
w.iloc[-1][::-1].plot(kind="bar")
_images/07_Online_Time_Series_Prediction_63_0.png _images/07_Online_Time_Series_Prediction_63_1.png

Taking mean inside simulation

%%time


def add_batch(fun, take_mean=True):
    def fun_batch(*args, **kwargs):
        res = VMap(fun)(*args, **kwargs)
        if take_mean:
            res = jax.tree_map(lambda x: x.mean(axis=0), res)
        return res

    return fun_batch


gym_loop_batch = add_batch(gym_loop)
sim = unroll_transform_with_state(gym_loop_batch)

rng = jax.random.PRNGKey(42)
eps = jax.random.normal(rng, (T, N_BATCH)) * 0.3

params, state = sim.init(rng, eps)
(gym_output, gym_info), final_state = sim.apply(params, state, rng, eps)
CPU times: user 1.53 s, sys: 18.8 ms, total: 1.54 s
Wall time: 1.54 s
pd.Series(-gym_output.reward).expanding().mean().plot(ylim=(0.09, 0.15))
_images/07_Online_Time_Series_Prediction_66_0.png
w = gym_info.agent.optim.updated_params["snarimax/~/linear"]["w"][:, :, 0]
w = pd.DataFrame(w)
ax = w.plot(title="averaged weights (over batches)")
ax.legend(bbox_to_anchor=(1.0, 1.0))
ax.set_xlabel("time")
ax.set_ylabel("weight")

plt.figure()
w.iloc[-1][::-1].plot(kind="bar")
_images/07_Online_Time_Series_Prediction_67_0.png _images/07_Online_Time_Series_Prediction_67_1.png

Hyperparameter tuning

First order optimizers

We will consider different first-order optimizers, namely:

  • SGD

  • ADAM

  • ADAGRAD

  • RMSPROP

For each of them, we will scan the “step_size” parameter \(\eta\).

We will average results over N_BATCH = 20 independent trajectories (batches).

We will consider trajectories of length T = 10000.

Finally, we will pick the best parameter based on the minimum average loss over the last 5000 time steps.

%%time

STEP_SIZE_idx = pd.Index(onp.logspace(-4, 1, 30), name="step_size")
STEP_SIZE = jax.device_put(STEP_SIZE_idx.values)
OPTIMIZERS = [optax.sgd, optax.adagrad, optax.rmsprop, optax.adam]

res = {}
for optimizer in tqdm(OPTIMIZERS):

    def gym_loop_scan_hparams(eps):
        def scan_params(step_size):
            return GymFeedback(build_agent(opt=optimizer(step_size)), env)(eps)

        res = VMap(scan_params)(STEP_SIZE)
        return res

    sim = unroll_transform_with_state(add_batch(gym_loop_scan_hparams))
    rng = jax.random.PRNGKey(42)
    eps = jax.random.normal(rng, (T, N_BATCH)) * 0.3

    params, state = sim.init(rng, eps)
    _res, state = sim.apply(params, state, rng, eps)
    res[optimizer.__name__] = _res
CPU times: user 11.5 s, sys: 111 ms, total: 11.6 s
Wall time: 11.6 s
ax = None
BEST_STEP_SIZE = {}
BEST_GYM = {}
for name, (gym, info) in res.items():

    loss = pd.DataFrame(-gym.reward, columns=STEP_SIZE).iloc[-5000:].mean()

    BEST_STEP_SIZE[name] = loss.idxmin()
    best_idx = loss.reset_index(drop=True).idxmin()
    BEST_GYM[name] = jax.tree_map(lambda x: x[:, best_idx], gym)

    ax = loss[loss < 0.15].plot(logx=True, logy=False, ax=ax, label=name)
plt.legend()
_images/07_Online_Time_Series_Prediction_71_0.png
for name, gym in BEST_GYM.items():
    ax = (
        pd.Series(-gym.reward)
        .expanding()
        .mean()
        .plot(
            label=f"{name}    -    $\eta$={BEST_STEP_SIZE[name]:.2e}", ylim=(0.09, 0.15)
        )
    )
ax.legend(bbox_to_anchor=(1.0, 1.0))
_images/07_Online_Time_Series_Prediction_72_0.png
pd.Series(BEST_STEP_SIZE).plot(kind="bar", logy=True)
_images/07_Online_Time_Series_Prediction_73_0.png

Newton algorithm

Now let’s consider the Newton algorithm.
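
For reference, the online Newton step of [2] (our notation) replaces the scalar learning rate by a preconditioner built from outer products of past gradients:

\(A_t = \epsilon I + \sum_{s \le t} \nabla_s \nabla_s^{\top}, \qquad w_{t+1} = \Pi_{A_t}\big(w_t - \eta\, A_t^{-1} \nabla_t\big),\)

where \(\Pi_{A_t}\) is the projection in the norm induced by \(A_t\); we assume the eps argument of optim.newton plays the role of the regularization \(\epsilon\).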

First, let’s test it with one set of parameters, averaging over N_BATCH batches.

%%time


@add_batch
def gym_loop_newton(eps):
    return GymFeedback(build_agent(opt=newton(0.05, eps=20.0)), env)(eps)


sim = unroll_transform_with_state(gym_loop_newton)
rng = jax.random.PRNGKey(42)
eps = jax.random.normal(rng, (T, N_BATCH)) * 0.3

params, state = sim.init(rng, eps)
(gym, info), state = sim.apply(params, state, rng, eps)
CPU times: user 1.88 s, sys: 43.5 ms, total: 1.92 s
Wall time: 1.73 s
pd.Series(-gym.reward).expanding().mean().plot()
_images/07_Online_Time_Series_Prediction_78_0.png
%%time

STEP_SIZE = pd.Index(onp.logspace(-2, 3, 10), name="step_size")
EPS = pd.Index(onp.logspace(-4, 3, 5), name="eps")

HPARAMS_idx = pd.MultiIndex.from_product([STEP_SIZE, EPS])
HPARAMS = jnp.stack(list(map(onp.array, HPARAMS_idx)))


@add_batch
def gym_loop_scan_hparams(eps):
    def scan_params(hparams):
        step_size, newton_eps = hparams
        agent = build_agent(opt=newton(step_size, eps=newton_eps))
        return GymFeedback(agent, env)(eps)

    return VMap(scan_params)(HPARAMS)


sim = unroll_transform_with_state(gym_loop_scan_hparams)
rng = jax.random.PRNGKey(42)
eps = jax.random.normal(rng, (T, N_BATCH)) * 0.3

params, state = sim.init(rng, eps)
res_newton, state = sim.apply(params, state, rng, eps)
CPU times: user 16.1 s, sys: 118 ms, total: 16.2 s
Wall time: 16 s
gym_newton, info_newton = res_newton
loss_newton = pd.DataFrame(-gym_newton.reward, columns=HPARAMS_idx).mean().unstack()
loss_newton = (
    pd.DataFrame(-gym_newton.reward, columns=HPARAMS_idx).iloc[-5000:].mean().unstack()
)
sns.heatmap(loss_newton[loss_newton < 0.4], annot=True, cmap="YlGnBu")

STEP_SIZE, NEWTON_EPS = loss_newton.stack().idxmin()

x = -gym_newton.reward[-5000:].mean(axis=0)
x = jax.ops.index_update(x, jnp.isnan(x), jnp.inf)
I_BEST_PARAM = jnp.argmin(x)


BEST_NEWTON_GYM = jax.tree_map(lambda x: x[:, I_BEST_PARAM], gym_newton)
print("Best newton parameters: ", STEP_SIZE, NEWTON_EPS)
Best newton parameters:  0.464158883361278 0.31622776601683794
_images/07_Online_Time_Series_Prediction_83_1.png
for name, gym in BEST_GYM.items():
    pd.Series(-gym.reward).rolling(5000, min_periods=5000).mean().plot(
        label=f"{name}    -    $\eta$={BEST_STEP_SIZE[name]:.2e}", ylim=(0.09, 0.1)
    )

gym = BEST_NEWTON_GYM
ax = (
    pd.Series(-gym.reward)
    .rolling(5000, min_periods=5000)
    .mean()
    .plot(
        label=f"Newton    -    $\eta$={STEP_SIZE:.2e},    $\epsilon$={NEWTON_EPS:.2e}"
    )
)
ax.legend(bbox_to_anchor=(1.0, 1.0))
plt.title("Rolling mean of loss (5000) time-steps")
_images/07_Online_Time_Series_Prediction_84_0.png
for name, gym in BEST_GYM.items():
    pd.Series(-gym.reward).expanding().mean().plot(
        label=f"{name}    -    $\eta$={BEST_STEP_SIZE[name]:.2e}", ylim=(0.09, 0.15)
    )
gym = BEST_NEWTON_GYM
ax = (
    pd.Series(-gym.reward)
    .expanding()
    .mean()
    .plot(
        label=f"Newton    -    $\eta$={STEP_SIZE:.2e},    $\epsilon$={NEWTON_EPS:.2e}"
    )
)
ax.legend(bbox_to_anchor=(1.0, 1.0))
_images/07_Online_Time_Series_Prediction_85_0.png

In agreement with results in [1], we see that Newton’s algorithm performs much better than SGD.

In addition, we note that:

  • ADAGRAD’s performance is between Newton’s and SGD’s.

  • RMSPROP and ADAM do not perform well in this online setting.

🔄 Online learning in non-stationary environments 🔄

We reproduce the empirical results of [1].

References

[1] Anava, O., Hazan, E., Mannor, S. and Shamir, O., 2013, June. Online learning for time series prediction. In Conference on learning theory (pp. 172-184)

%pylab inline
%load_ext autoreload
%autoreload 2
Populating the interactive namespace from numpy and matplotlib
from typing import Any, NamedTuple

import haiku as hk
import jax
import jax.numpy as jnp
import numpy as onp
import optax
import pandas as pd
import seaborn as sns
from matplotlib import pyplot as plt
from tqdm.auto import tqdm

from wax.modules import ARMA, SNARIMAX, GymFeedback, OnlineOptimizer, UpdateParams, VMap
from wax.modules.lag import tree_lag
from wax.modules.vmap import add_batch
from wax.optim import newton
from wax.unroll import unroll_transform_with_state
T = 10000
N_BATCH = 20
N_STEP_SIZE = 30
N_STEP_SIZE_NEWTON = 10
N_EPS = 5

Agent

OPTIMIZERS = [optax.sgd, optax.adagrad, optax.rmsprop, optax.adam]
from optax._src.base import OptState


def build_agent(time_series_model=None, opt=None):
    if time_series_model is None:
        time_series_model = lambda y, X: SNARIMAX(10)(y, X)

    if opt is None:
        opt = optax.sgd(1.0e-3)

    class AgentInfo(NamedTuple):
        optim: Any
        forecast: Any

    class ModelWithLossInfo(NamedTuple):
        pred: Any
        loss: Any

    def agent(obs):
        if isinstance(obs, tuple):
            y, X = obs
        else:
            y = obs
            X = None

        def evaluate(y_pred, y):
            return jnp.linalg.norm(y_pred - y) ** 2, {}

        def model_with_loss(y, X=None):
            # predict with lagged data
            y_pred, pred_info = time_series_model(*tree_lag(1)(y, X))

            # evaluate loss with actual data
            loss, loss_info = evaluate(y_pred, y)

            return loss, ModelWithLossInfo(pred_info, loss_info)

        def project_params(params: Any, opt_state: OptState = None):
            del opt_state
            return jax.tree_map(lambda w: jnp.clip(w, -1, 1), params)

        def split_params(params):
            def filter_params(m, n, p):
                # print(m, n, p)
                return m.endswith("snarimax/~/linear") and n == "w"

            return hk.data_structures.partition(filter_params, params)

        def learn_and_forecast(y, X=None):
            # use lagged data for the optimizer
            optim_res = OnlineOptimizer(
                model_with_loss,
                opt,
                project_params=project_params,
                split_params=split_params,
            )(*tree_lag(1)(y, X))

            # use updated params to forecast with actual data
            predict_params = optim_res.updated_params

            y_pred, forecast_info = UpdateParams(time_series_model)(
                predict_params, y, X
            )
            return y_pred, AgentInfo(optim_res, forecast_info)

        return learn_and_forecast(y, X)

    return agent

Non-stationary environments

We will now wrap up the study of an environment and agent in a few analysis functions.

We will then use them to perform the same analysis in the non-stationary settings proposed in [1], namely:

  • setting 1 : sanity check (stationary ARMA environment).

  • setting 2 : slowly varying parameters.

  • setting 3 : abrupt variation of parameters.

  • setting 4 : non-stationary (random walk) noise.

Analysis functions

For each solver, we will select the best hyperparameters (step size \(\eta\), \(\epsilon\)) by measuring the average loss between steps 5000 and 10000.

First order solvers

def scan_hparams_first_order():

    STEP_SIZE_idx = pd.Index(onp.logspace(-4, 1, N_STEP_SIZE), name="step_size")
    STEP_SIZE = jax.device_put(STEP_SIZE_idx.values)

    rng = jax.random.PRNGKey(42)
    eps = sample_noise(rng)

    res = {}
    for optimizer in tqdm(OPTIMIZERS):

        def gym_loop_scan_hparams(eps):
            def scan_params(step_size):
                return GymFeedback(build_agent(opt=optimizer(step_size)), env)(eps)

            res = VMap(scan_params)(STEP_SIZE)
            return res

        sim = unroll_transform_with_state(add_batch(gym_loop_scan_hparams))

        params, state = sim.init(rng, eps)
        _res, state = sim.apply(params, state, rng, eps)
        res[optimizer.__name__] = _res

    ax = None
    BEST_STEP_SIZE = {}
    BEST_GYM = {}

    for name, (gym, info) in res.items():

        loss = (
            pd.DataFrame(-gym.reward, columns=STEP_SIZE).iloc[LEARN_TIME_SLICE].mean()
        )

        BEST_STEP_SIZE[name] = loss.idxmin()

        best_idx = jnp.argmax(gym.reward[LEARN_TIME_SLICE].mean(axis=0))
        BEST_GYM[name] = jax.tree_map(lambda x: x[:, best_idx], gym)

        ax = loss.plot(
            logx=True, logy=False, ax=ax, label=name, ylim=(MIN_ERR, MAX_ERR)
        )
    plt.legend()

    return BEST_STEP_SIZE, BEST_GYM

We will “cross-validate” the result by running the agent on new samples.

CROSS_VAL_RNG = jax.random.PRNGKey(44)
COLORS = sns.color_palette("hls")
def cross_validate_first_order(BEST_STEP_SIZE, BEST_GYM):
    plt.figure()
    eps = sample_noise(CROSS_VAL_RNG)
    CROSS_VAL_GYM = {}
    ax = None

    # def measure(reward):
    #     return pd.Series(-reward).rolling(T/2, min_periods=T/2).mean()

    def measure(reward):
        return pd.Series(-reward).expanding().mean()

    for i, (name, gym) in enumerate(BEST_GYM.items()):
        ax = measure(gym.reward).plot(
            ax=ax,
            color=COLORS[i],
            label=(f"(TRAIN) -  {name}    " f"-    $\eta$={BEST_STEP_SIZE[name]:.2e}"),
            style="--",
        )
    for i, optimizer in enumerate(tqdm(OPTIMIZERS)):

        name = optimizer.__name__

        def gym_loop(eps):
            return GymFeedback(build_agent(opt=optimizer(BEST_STEP_SIZE[name])), env)(
                eps
            )

        sim = unroll_transform_with_state(add_batch(gym_loop))

        rng = jax.random.PRNGKey(42)
        params, state = sim.init(rng, eps)
        (gym, info), state = sim.apply(params, state, rng, eps)
        CROSS_VAL_GYM[name] = gym

        ax = measure(gym.reward).plot(
            ax=ax,
            color=COLORS[i],
            ylim=(MIN_ERR, MAX_ERR),
            label=(
                f"(VALIDATE) -  {name}    " f"-    $\eta$={BEST_STEP_SIZE[name]:.2e}"
            ),
        )
    plt.legend()

    return CROSS_VAL_GYM

Newton solver

def scan_hparams_newton():
    STEP_SIZE = pd.Index(onp.logspace(-2, 3, N_STEP_SIZE_NEWTON), name="step_size")
    EPS = pd.Index(onp.logspace(-4, 3, N_EPS), name="eps")

    HPARAMS_idx = pd.MultiIndex.from_product([STEP_SIZE, EPS])
    HPARAMS = jnp.stack(list(map(onp.array, HPARAMS_idx)))

    @add_batch
    def gym_loop_scan_hparams(eps):
        def scan_params(hparams):
            step_size, newton_eps = hparams
            agent = build_agent(opt=newton(step_size, eps=newton_eps))
            return GymFeedback(agent, env)(eps)

        return VMap(scan_params)(HPARAMS)

    sim = unroll_transform_with_state(gym_loop_scan_hparams)

    rng = jax.random.PRNGKey(42)
    eps = sample_noise(rng)

    params, state = sim.init(rng, eps)
    res_newton, state = sim.apply(params, state, rng, eps)

    gym_newton, info_newton = res_newton

    loss_newton = (
        pd.DataFrame(-gym_newton.reward, columns=HPARAMS_idx)
        .iloc[LEARN_TIME_SLICE]
        .mean()
        .unstack()
    )

    sns.heatmap(loss_newton[loss_newton < 0.4], annot=True, cmap="YlGnBu")

    STEP_SIZE, NEWTON_EPS = loss_newton.stack().idxmin()

    x = -gym_newton.reward[LEARN_TIME_SLICE].mean(axis=0)
    x = jax.ops.index_update(x, jnp.isnan(x), jnp.inf)
    I_BEST_PARAM = jnp.argmin(x)

    BEST_NEWTON_GYM = jax.tree_map(lambda x: x[:, I_BEST_PARAM], gym_newton)
    print("Best newton parameters: ", STEP_SIZE, NEWTON_EPS)
    return (STEP_SIZE, NEWTON_EPS), BEST_NEWTON_GYM
def cross_validate_newton(BEST_HPARAMS, BEST_NEWTON_GYM):
    (STEP_SIZE, NEWTON_EPS) = BEST_HPARAMS
    plt.figure()

    # def measure(reward):
    #     return pd.Series(-reward).rolling(T/2, min_periods=T/2).mean()

    def measure(reward):
        return pd.Series(-reward).expanding().mean()

    @add_batch
    def gym_loop(eps):
        agent = build_agent(opt=newton(STEP_SIZE, eps=NEWTON_EPS))
        return GymFeedback(agent, env)(eps)

    sim = unroll_transform_with_state(gym_loop)

    rng = jax.random.PRNGKey(44)
    eps = sample_noise(rng)
    params, state = sim.init(rng, eps)
    (gym, info), state = sim.apply(params, state, rng, eps)

    ax = None
    i = 4
    ax = measure(BEST_NEWTON_GYM.reward).plot(
        ax=ax,
        color=COLORS[i],
        label=f"(TRAIN) -  Newton    -    $\eta$={STEP_SIZE:.2e},    $\epsilon$={NEWTON_EPS:.2e}",
        ylim=(MIN_ERR, MAX_ERR),
        style="--",
    )

    ax = measure(gym.reward).plot(
        ax=ax,
        color=COLORS[i],
        ylim=(MIN_ERR, MAX_ERR),
        label=f"(VALIDATE) - Newton    -    $\eta$={STEP_SIZE:.2e},    $\epsilon$={NEWTON_EPS:.2e}",
    )

    plt.legend()

    return gym

Plot everything

def plot_everything(BEST_STEP_SIZE, BEST_GYM, BEST_HPARAMS, BEST_NEWTON_GYM):
    MEASURES = []

    def measure(reward):
        return pd.Series(-reward).rolling(int(T / 2), min_periods=int(T / 2)).mean()

    MEASURES.append(("Rolling mean of loss (5000) time-steps", measure))

    def measure(reward):
        return pd.Series(-reward).expanding().mean()

    MEASURES.append(("Expanding means", measure))

    for MEASURE_NAME, MEASURE_FUNC in MEASURES:

        plt.figure()

        for i, (name, gym) in enumerate(BEST_GYM.items()):
            MEASURE_FUNC(gym.reward).plot(
                label=f"{name}    -    $\eta$={BEST_STEP_SIZE[name]:.2e}",
                ylim=(MIN_ERR, MAX_ERR),
                color=COLORS[i],
            )

        i = 4
        (STEP_SIZE, NEWTON_EPS) = BEST_HPARAMS
        gym = BEST_NEWTON_GYM
        ax = MEASURE_FUNC(gym.reward).plot(
            label=f"Newton    -    $\eta$={STEP_SIZE:.2e},    $\epsilon$={NEWTON_EPS:.2e}",
            ylim=(MIN_ERR, MAX_ERR),
            color=COLORS[i],
        )
        ax.legend(bbox_to_anchor=(1.0, 1.0))
        plt.title(MEASURE_NAME)

Setting 1

Environment

Let’s wrap up the results for “setting 1” of [1]:

from wax.modules import Counter


def build_env():
    def env(action, obs):
        y_pred, eps = action, obs

        ar_coefs = jnp.array([0.6, -0.5, 0.4, -0.4, 0.3])
        ma_coefs = jnp.array([0.3, -0.2])

        y = ARMA(ar_coefs, ma_coefs)(eps)

        rw = -((y - y_pred) ** 2)

        env_info = {"y": y, "y_pred": y_pred}
        obs = y
        return rw, obs, env_info

    return env


def sample_noise(rng):
    eps = jax.random.normal(rng, (T, 20)) * 0.3
    return eps


MIN_ERR = 0.09
MAX_ERR = 0.15
LEARN_TIME_SLICE = slice(int(T / 2), T)
env = build_env()
BEST_STEP_SIZE, BEST_GYM = scan_hparams_first_order()
_images/08_Online_learning_in_non_stationary_environments_26_1.png
CROSS_VAL_GYM = cross_validate_first_order(BEST_STEP_SIZE, BEST_GYM)
_images/08_Online_learning_in_non_stationary_environments_27_1.png
BEST_HPARAMS, BEST_NEWTON_GYM = scan_hparams_newton()
Best newton parameters:  0.464158883361278 0.31622776601683794
_images/08_Online_learning_in_non_stationary_environments_28_1.png
CROSS_VAL_GYM = cross_validate_newton(BEST_HPARAMS, BEST_NEWTON_GYM)
_images/08_Online_learning_in_non_stationary_environments_29_0.png
plot_everything(BEST_STEP_SIZE, BEST_GYM, BEST_HPARAMS, BEST_NEWTON_GYM)
_images/08_Online_learning_in_non_stationary_environments_30_0.png _images/08_Online_learning_in_non_stationary_environments_30_1.png

Conclusions

  • The NEWTON and ADAGRAD optimizers are the fastest to converge.

  • The SGD and ADAM optimizers have the worst performance.

Fixed setting

@add_batch
def gym_loop_newton(eps):
    return GymFeedback(build_agent(opt=newton(0.1, eps=0.3)), env)(eps)


def run_fixed_setting():
    rng = jax.random.PRNGKey(42)
    eps = sample_noise(rng)
    sim = unroll_transform_with_state(gym_loop_newton)
    params, state = sim.init(rng, eps)
    (gym, info), state = sim.apply(params, state, rng, eps)

    pd.Series(-gym.reward).expanding().mean().plot()  # ylim=(MIN_ERR, MAX_ERR))
%%time
run_fixed_setting()
CPU times: user 1.69 s, sys: 19.4 ms, total: 1.71 s
Wall time: 1.7 s
_images/08_Online_learning_in_non_stationary_environments_34_1.png

Setting 2

Environment

Let’s build an environment corresponding to “setting 2” of [1]:

from wax.modules import Counter


def build_env():
    def env(action, obs):
        y_pred, eps = action, obs
        t = Counter()()
        ar_coefs_1 = jnp.array([-0.4, -0.5, 0.4, 0.4, 0.1])
        ar_coefs_2 = jnp.array([0.6, -0.4, 0.4, -0.5, 0.5])
        ar_coefs = ar_coefs_1 * t / T + ar_coefs_2 * (1 - t / T)

        ma_coefs = jnp.array([0.32, -0.2])

        y = ARMA(ar_coefs, ma_coefs)(eps)
        # prediction used on a fresh y observation.
        rw = -((y - y_pred) ** 2)

        env_info = {"y": y, "y_pred": y_pred}
        obs = y
        return rw, obs, env_info

    return env


def sample_noise(rng):
    eps = jax.random.uniform(rng, (T, 20), minval=-0.5, maxval=0.5)
    return eps


MIN_ERR = 0.0833
MAX_ERR = 0.15
LEARN_TIME_SLICE = slice(int(T / 2), T)
env = build_env()
BEST_STEP_SIZE, BEST_GYM = scan_hparams_first_order()
_images/08_Online_learning_in_non_stationary_environments_39_1.png
CROSS_VAL_GYM = cross_validate_first_order(BEST_STEP_SIZE, BEST_GYM)
_images/08_Online_learning_in_non_stationary_environments_40_1.png
BEST_HPARAMS, BEST_NEWTON_GYM = scan_hparams_newton()
Best newton parameters:  1.6681005372000592 0.31622776601683794
_images/08_Online_learning_in_non_stationary_environments_41_1.png
CROSS_VAL_GYM = cross_validate_newton(BEST_HPARAMS, BEST_NEWTON_GYM)
_images/08_Online_learning_in_non_stationary_environments_42_0.png
plot_everything(BEST_STEP_SIZE, BEST_GYM, BEST_HPARAMS, BEST_NEWTON_GYM)
_images/08_Online_learning_in_non_stationary_environments_43_0.png _images/08_Online_learning_in_non_stationary_environments_43_1.png

Conclusions

  • The NEWTON and ADAGRAD optimizers adapt more efficiently to slowly changing environments.

  • The SGD and ADAM optimizers seem to have the worst performance.

Fixed setting

%%time
run_fixed_setting()
CPU times: user 1.77 s, sys: 22.5 ms, total: 1.79 s
Wall time: 1.76 s
_images/08_Online_learning_in_non_stationary_environments_46_1.png

Setting 3

Environment

Let us build an environment corresponding to “setting 3” of [1]. We modify it slightly by adding 10000 steps. We intentionally use steps 5000 to 10000 to optimize the hyperparameters. This allows us to evaluate how the models “over-optimize”.

from wax.modules import Counter


def build_env():
    def env(action, obs):
        y_pred, eps = action, obs
        t = Counter()()
        ar_coefs_1 = jnp.array([0.6, -0.5, 0.4, -0.4, 0.3])
        ar_coefs_2 = jnp.array([-0.4, -0.5, 0.4, 0.4, 0.1])

        ar_coefs = jnp.where(t < int(Tlong / 2), ar_coefs_1, ar_coefs_2)
        ma_coefs_1 = jnp.array([0.3, -0.2])
        ma_coefs_2 = jnp.array([-0.3, 0.2])
        ma_coefs = jnp.where(t < int(Tlong / 2), ma_coefs_1, ma_coefs_2)

        y = ARMA(ar_coefs, ma_coefs)(eps)

        # prediction used on a fresh y observation.
        rw = -((y - y_pred) ** 2)

        env_info = {"y": y, "y_pred": y_pred}
        obs = y
        return rw, obs, env_info

    return env


def sample_noise(rng):
    eps = jax.random.uniform(rng, (Tlong, N_BATCH), minval=-0.5, maxval=0.5)
    return eps


Tlong = 2 * T
MIN_ERR = 0.0833
MAX_ERR = 0.12
LEARN_TIME_SLICE = slice(int(T / 2), T)
env = build_env()
BEST_STEP_SIZE, BEST_GYM = scan_hparams_first_order()
_images/08_Online_learning_in_non_stationary_environments_51_1.png
CROSS_VAL_GYM = cross_validate_first_order(BEST_STEP_SIZE, BEST_GYM)
_images/08_Online_learning_in_non_stationary_environments_52_1.png
BEST_HPARAMS, BEST_NEWTON_GYM = scan_hparams_newton()
Best newton parameters:  0.464158883361278 17.78279410038923
_images/08_Online_learning_in_non_stationary_environments_53_1.png
CROSS_VAL_GYM = cross_validate_newton(BEST_HPARAMS, BEST_NEWTON_GYM)
_images/08_Online_learning_in_non_stationary_environments_54_0.png
plot_everything(BEST_STEP_SIZE, BEST_GYM, BEST_HPARAMS, BEST_NEWTON_GYM)
_images/08_Online_learning_in_non_stationary_environments_55_0.png _images/08_Online_learning_in_non_stationary_environments_55_1.png

Choosing hyperparameters on the whole period

It seems that the Newton solver is more prone to overfitting (recall that we chose its hyperparameters to optimize the average loss between steps 5000 and 10000, thus only in the first regime).

However, as stated in [1], the Newton algorithm can perform better if we choose its hyperparameters to obtain the best performance over both regimes.

Let us check this:

LEARN_TIME_SLICE = slice(int(Tlong / 2), None)
env = build_env()
BEST_STEP_SIZE, BEST_GYM = scan_hparams_first_order()
_images/08_Online_learning_in_non_stationary_environments_59_1.png
CROSS_VAL_GYM = cross_validate_first_order(BEST_STEP_SIZE, BEST_GYM)
_images/08_Online_learning_in_non_stationary_environments_60_1.png
BEST_HPARAMS, BEST_NEWTON_GYM = scan_hparams_newton()
Best newton parameters:  5.994842503189409 17.78279410038923
_images/08_Online_learning_in_non_stationary_environments_61_1.png
CROSS_VAL_GYM = cross_validate_newton(BEST_HPARAMS, BEST_NEWTON_GYM)
_images/08_Online_learning_in_non_stationary_environments_62_0.png
plot_everything(BEST_STEP_SIZE, BEST_GYM, BEST_HPARAMS, BEST_NEWTON_GYM)
_images/08_Online_learning_in_non_stationary_environments_63_0.png _images/08_Online_learning_in_non_stationary_environments_63_1.png

Conclusion

  • The ADAGRAD optimizer seems best suited for abrupt regime switching.

  • The SGD and NEWTON optimizers seem to behave similarly if their parameters are correctly chosen.

  • The ADAM optimizer seems to have the worst performance.

Fixed setting

%%time
run_fixed_setting()
CPU times: user 2.2 s, sys: 32.6 ms, total: 2.23 s
Wall time: 2.23 s
_images/08_Online_learning_in_non_stationary_environments_66_1.png

Setting 4

Environment

Let’s build an environment corresponding to “setting 4” of [1]:

from wax.modules import Counter


def build_env():
    def env(action, obs):
        y_pred, eps = action, obs
        t = Counter()()
        ar_coefs = jnp.array([0.11, -0.5])

        ma_coefs = jnp.array([0.41, -0.39, -0.685, 0.1])

        # rng = hk.next_rng_key()

        prev_eps = hk.get_state("prev_eps", (1,), init=lambda *_: jnp.zeros_like(eps))
        eps = prev_eps + eps  # jax.random.normal(rng, (1, N_BATCH))

        hk.set_state("prev_eps", eps)

        y = ARMA(ar_coefs, ma_coefs)(eps)
        # prediction used on a fresh y observation.
        rw = -((y - y_pred) ** 2)

        env_info = {"y": y, "y_pred": y_pred}
        obs = y
        return rw, obs, env_info

    return env


def sample_noise(rng):
    eps = jax.random.normal(rng, (T, N_BATCH)) * 0.3
    return eps


MIN_ERR = 0.09
MAX_ERR = 0.3
LEARN_TIME_SLICE = slice(int(T / 2), T)
env = build_env()
BEST_STEP_SIZE, BEST_GYM = scan_hparams_first_order()
_images/08_Online_learning_in_non_stationary_environments_71_1.png
BEST_STEP_SIZE
{'sgd': 9.999999747378752e-05,
 'adagrad': 0.05736152455210686,
 'rmsprop': 0.0016102619701996446,
 'adam': 0.0016102619701996446}
CROSS_VAL_GYM = cross_validate_first_order(BEST_STEP_SIZE, BEST_GYM)
_images/08_Online_learning_in_non_stationary_environments_73_1.png
BEST_HPARAMS, BEST_NEWTON_GYM = scan_hparams_newton()
Best newton parameters:  0.464158883361278 0.31622776601683794
_images/08_Online_learning_in_non_stationary_environments_74_1.png
CROSS_VAL_GYM = cross_validate_newton(BEST_HPARAMS, BEST_NEWTON_GYM)
_images/08_Online_learning_in_non_stationary_environments_75_0.png
plot_everything(BEST_STEP_SIZE, BEST_GYM, BEST_HPARAMS, BEST_NEWTON_GYM)
_images/08_Online_learning_in_non_stationary_environments_76_0.png _images/08_Online_learning_in_non_stationary_environments_76_1.png

As noted in [1], the Newton algorithm seems to be the only one whose average error rate converges to the variance of the noise (0.09).

Conclusion

In this environment with noise auto-correlations:

  • The NEWTON optimizer achieves the minimum theoretical average loss.

  • The other optimizers struggle to converge to the minimum theoretical loss and thus seem to suffer linear regret.

  • The SGD optimizer is the worst in this setting.

Fixed setting

%%time
run_fixed_setting()
CPU times: user 1.84 s, sys: 16.5 ms, total: 1.86 s
Wall time: 1.85 s
_images/08_Online_learning_in_non_stationary_environments_80_1.png

Change log

Best viewed here.

wax 0.2.0 (October 20 2021)

  • Documentation:

    • New notebook : 07_Online_Time_Series_Prediction

    • New notebook : 08_Online_learning_in_non_stationary_environments

  • API modifications:

    • refactor accessors and stream

    • GymFeedback now assumes that agent and env return info object

    • OnlineSupervisedLearner action is y_pred, loss and params are returned as info

  • Improvements:

    • introduce general unroll transformation.

    • dynamic_unroll can handle Callable objects

    • UpdateOnEvent can handle any signature for functions

    • EWMCov can handle the x and y arguments explicitly

    • add initial action option to GymFeedback

  • New Features:

    • New module UpdateParams

    • New module SNARIMAX, ARMA

    • New module OnlineOptimizer

    • New module VMap

    • add grads_fill_nan_inf option to OnlineSupervisedLearner

    • Introduce unroll_transform_with_state following Haiku API.

    • New function auto_format_with_shape and tree_auto_format_with_shape

    • New module Ffill

    • New module Counter

  • Deprecate:

    • deprecate dynamic_unroll and static_unroll, refactor their usages.

  • Fixes:

    • Simplify Buffer to work only on ndarrays (the pytree implementation was too complex)

    • EWMA behaves correctly with gradients

    • MaskStd behaves correctly with gradients

    • correct encode_int64 when working on int32

    • update notebook 06_Online_Linear_Regression and add it to run-notebooks rule

    • correct pct_change to behave correctly when input data has nan values.

    • correct eagerpy test for update of tensorflow, pytorch and jax

    • remove duplicate license comments

    • use numpy.allclose instead of jax.numpy.allclose for comparison of non-JAX objects

    • update comment in notebooks : jaxlib==0.1.67+cuda111 to jaxlib==0.1.70+cuda111

    • fix jupytext dependency

    • add seaborn as optional dependency

wax 0.1.0 (June 14 2021)

  • First release.

Installing wax

First, obtain the WAX-ML source code:

git clone https://github.com/eserie/wax-ml
cd wax-ml

You can install wax by running:

pip install -e .[complete]  # install wax

To upgrade to the latest version from GitHub, just run git pull from the WAX-ML repository root. You shouldn’t have to reinstall wax because pip install -e sets up symbolic links from site-packages into the repository.

You can install wax development tools by running:

pip install -e .[dev]  # install wax-development-tools

Running the tests

To run all the WAX-ML tests, we recommend using pytest-xdist, which can run tests in parallel. First, install pytest-xdist and pytest-benchmark by running pip install -r build/test-requirements.txt. Then, from the repository root directory, run:

pytest -n auto .

You can run a more specific set of tests using pytest’s built-in selection mechanisms, or alternatively you can run a specific test file directly to see more detailed information about the cases being run:

pytest -v wax/accessors_test.py

The Colab notebooks are tested for errors as part of the documentation build and the GitHub Actions workflows.

Type checking

We use mypy to check the type hints. To check types locally the same way GitHub Actions does, you can run:

mypy wax

or

make mypy

Flake8

We use flake8 to check that the code follows the PEP 8 standard. To check the code, you can run:

make flake8

Formatting code

We use isort and black to format the code.

When you are in the root directory of the project, to format code in the package, you can run:

make format-package

To format notebooks in the documentation, you can use:

make format-notebooks

To format all files you can run:

make format

Note that the CI running with GitHub Actions will verify that formatting all source code does not modify the files. You can check this locally by running:

make check-format

Check actions

You can check that everything is ok by running:

make act

This will check flake8, mypy, isort and black formatting, licenses headers and run tests and coverage.

Update documentation

To rebuild the documentation, install several packages:

pip install -r docs/requirements.txt

And then run:

sphinx-build -b html docs docs/build/html

or run

make docs

This can take a long time because it executes many of the notebooks in the documentation source; if you’d prefer to build the docs without executing the notebooks, you can run:

sphinx-build -b html -D jupyter_execute_notebooks=off docs docs/build/html

or run

make docs-fast

You can then see the generated documentation in docs/build/html/index.html.

Update notebooks

We use jupytext to maintain three synced copies of the notebooks in docs/notebooks: one in ipynb format, one in py format, and one in md format. The advantage of the first is that it can be opened and executed directly in Colab; the advantage of the second is that it makes it easier to refactor and format Python code; the advantage of the third is that it makes it much easier to track diffs within version control.

Editing ipynb

For making large changes that substantially modify code and outputs, it is easiest to edit the notebooks in Jupyter or in Colab. To edit notebooks in the Colab interface, open http://colab.research.google.com and Upload the notebook from your local repo. Update it as needed, Run all cells, then Download ipynb. You may want to test that it executes properly, using sphinx-build as explained above.

You can format the Python code in your notebooks by running make format in the docs/notebooks directory or make format-notebooks in the root directory.

Editing md

For making smaller changes to the text content of the notebooks, it is easiest to edit the .md versions using a text editor.

Syncing notebooks

After editing either the ipynb or md versions of the notebooks, you can sync the two versions using jupytext by running:

jupytext --sync docs/notebooks/*

or:

cd  docs/notebooks/
make sync

Alternatively, you can run this command via the pre-commit framework by executing the following in the main WAX-ML directory:

pre-commit run --all

See the pre-commit framework documentation for information on how to set your local git environment to execute this automatically.

Creating new notebooks

If you are adding a new notebook to the documentation and would like to use the jupytext --sync command discussed here, you can set up your notebook for jupytext by using the following command:

jupytext --set-formats ipynb,py,md:myst path/to/the/notebook.ipynb

This works by adding a "jupytext" metadata field to the notebook file which specifies the desired formats, and which the jupytext --sync command recognizes when invoked.

Notebooks within the sphinx build

Some of the notebooks are built automatically as part of the Travis pre-submit checks and as part of the Read the docs build. The build will fail if cells raise errors. If the errors are intentional, you can either catch them, or tag the cell with raises-exceptions metadata (example PR). You have to add this metadata by hand in the .ipynb file. It will be preserved when somebody else re-saves the notebook.

We exclude some notebooks from the build, e.g., because they contain long computations. See exclude_patterns in conf.py.

Documentation building on readthedocs.io

WAX-ML’s auto-generated documentation is at https://wax-ml.readthedocs.io/.

The documentation building is controlled for the entire project by the readthedocs WAX-ML settings. The current settings trigger a documentation build as soon as code is pushed to the GitHub main branch. For each code version, the building process is driven by the .readthedocs.yml and the docs/conf.py configuration files.

For each automated documentation build you can see the documentation build logs.

If you want to test the documentation generation on Readthedocs, you can push code to the test-docs branch. That branch is also built automatically, and you can see the generated documentation here. If the documentation build fails you may want to wipe the build environment for test-docs.

For a local test, you can do it in a fresh directory by replaying the commands executed by Readthedocs and written in their logs:

mkvirtualenv wax-ml-docs  # A new virtualenv
mkdir wax-ml-docs  # A new directory
cd wax-ml-docs
git clone --no-single-branch --depth 50 https://github.com/eserie/wax-ml
cd wax-ml
git checkout --force origin/test-docs
git clean -d -f 
workon wax-ml-docs

python -m pip install --upgrade --no-cache-dir pip
python -m pip install --upgrade --no-cache-dir -I Pygments==2.3.1 setuptools==41.0.1 docutils==0.14 mock==1.0.1 pillow==5.4.1 alabaster>=0.7,<0.8,!=0.7.5 commonmark==0.8.1 recommonmark==0.5.0 'sphinx<2' 'sphinx-rtd-theme<0.5' 'readthedocs-sphinx-ext<1.1'
python -m pip install --exists-action=w --no-cache-dir -r docs/requirements.txt
cd docs
python `which sphinx-build` -T -E -b html -d _build/doctrees-readthedocs -D language=en . _build/html

Public API: wax package

Subpackages

wax.modules package

Gym modules

In WAX-ML, agents and environments are simple functions:

_images/agent_env.png

A Gym feedback loop can be represented with the following diagram:

_images/gymfeedback.png

Equivalently, it can be described with the pair of init and apply functions:

_images/gymfeedback_init_apply.png

gym_feedback

Gym feedback between an agent and a Gym environment.

Online Learning

WAX-ML contains a module to perform online learning for supervised problems.

online_supervised_learner

Online supervised learner.

Other Haiku modules

buffer

Implement buffering mechanism.

diff

Implement difference of values on sequential data.

ewma

Compute exponential moving average.

ewmcov

Compute exponentially weighted covariance.

ewmvar

Compute exponentially weighted variance.

fill_nan_inf

Fill nan, posinf and neginf values.

has_changed

Detect if something has changed.

lag

Delay operator.

ohlc

Open-High-Low-Close binning.

pct_change

Relative change between the current and a prior element.

rolling_mean

Rolling mean.

update_on_event

Apply a module when an event occurs, otherwise return the last computed output.

wax.gym package

Gym objects

agent

Define API for Gym agents.

callbacks

Callbacks to work with wax.gym.gym_unroll

entity

Base class for Gym Agent and Env classes.

env

Define API for Gym environments.

haiku_agent

Gym agent defined from an Haiku module

haiku_env

Gym environment defined from an Haiku module

wax.datasets package

generation functions

generate_temperature_data

Generate fake temperature data for tests purposes.

wax.universal package

Universal Haiku modules

eager_ewma

Universal Exponential moving average module and unroll implementations.

wax.accessors

Define accessors for xarray and pandas data containers.

wax.compile

Compilation helper for Haiku Transformed and TransformedWithState pairs of pure functions.

wax.encode

Encoding schemes to encode/decode numpy data types not supported by JAX, e.g.

wax.format

Format nested data structures to numpy/xarray/pandas containers.

wax.stream

Define Stream object used to synchronize in-memory data streams and unroll data transformations on it.

wax.transform

Transformation functions to work on batches of data.

wax.unroll

Unroll modules on data along first axis.

wax.utils

Some utils functions used in WAX-ML.
