Welcome to WAX-ML

WAX-ML is a library for machine-learning on streaming data.

For an introduction to WAX-ML, start at the WAX-ML GitHub page.

# Uncomment to run the notebook in Colab
# ! pip install -q "wax-ml[complete]@git+https://github.com/eserie/wax-ml.git"
# ! pip install -q --upgrade jax jaxlib==0.1.67+cuda111 -f https://storage.googleapis.com/jax-releases/jax_releases.html
# check available devices
import jax
print("jax backend {}".format(jax.lib.xla_bridge.get_backend().platform))
jax.devices()
WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
jax backend cpu
[<jaxlib.xla_extension.Device at 0x7ffb501244b0>]

〰 Compute exponential moving averages with xarray and pandas accessors 〰


WAX-ML implements pandas and xarray accessors to ease the use of machine-learning algorithms with high-level data APIs:

  • pandas’s DataFrame and Series

  • xarray’s Dataset and DataArray.

These accessors let you easily execute any function built from Haiku modules on these data containers.

For instance, WAX-ML proposes an implementation of the exponential moving average built with this mechanism.

Let’s show how it works.

Load accessors

First you need to load accessors:

from wax.accessors import register_wax_accessors

register_wax_accessors()

EWMA on dataframes

Let’s look at a simple example: the exponential moving average (EWMA).

Let’s apply the EWMA algorithm to the NCEP/NCAR air temperature data.

🌡 Load temperature dataset 🌡

import xarray as xr

da = xr.tutorial.open_dataset("air_temperature")

Let’s see what this dataset looks like:

da
<xarray.Dataset>
Dimensions:  (lat: 25, time: 2920, lon: 53)
Coordinates:
  * lat      (lat) float32 75.0 72.5 70.0 67.5 65.0 ... 25.0 22.5 20.0 17.5 15.0
  * lon      (lon) float32 200.0 202.5 205.0 207.5 ... 322.5 325.0 327.5 330.0
  * time     (time) datetime64[ns] 2013-01-01 ... 2014-12-31T18:00:00
Data variables:
    air      (time, lat, lon) float32 ...
Attributes:
    Conventions:  COARDS
    title:        4x daily NMC reanalysis (1948)
    description:  Data is from NMC initialized reanalysis\n(4x/day).  These a...
    platform:     Model
    references:   http://www.esrl.noaa.gov/psd/data/gridded/data.ncep.reanaly...

To compute an EWMA on some variables of a dataset, we usually need to convert the data into a pandas Series or DataFrame.

So, let’s convert the dataset into a dataframe to illustrate the accessors on a dataframe:

dataframe = da.air.to_series().unstack(["lon", "lat"])

EWMA with pandas

%%time
air_temp_ewma = dataframe.ewm(alpha=1.0 / 10.0).mean()
_ = air_temp_ewma.iloc[:, 0].plot()
CPU times: user 406 ms, sys: 20.5 ms, total: 426 ms
Wall time: 426 ms
_images/01_demo_EWMA_17_1.png

EWMA with WAX-ML

%%time
air_temp_ewma = dataframe.wax.ewm(alpha=1.0 / 10.0).mean()
_ = air_temp_ewma.iloc[:, 0].plot()
CPU times: user 1.78 s, sys: 452 ms, total: 2.23 s
Wall time: 2.22 s
_images/01_demo_EWMA_19_1.png

On small data, WAX-ML’s EWMA is slower than pandas’ because of the expensive data-conversion steps. WAX-ML’s accessors become interesting on large workloads (see our three-steps workflow below).

Apply a custom function to a Dataset

Now let’s illustrate how WAX-ML accessors work on xarray datasets.

from wax.modules import EWMA


def my_custom_function(da):
    return {
        "air_10": EWMA(1.0 / 10.0)(da["air"]),
        "air_100": EWMA(1.0 / 100.0)(da["air"]),
    }


da = xr.tutorial.open_dataset("air_temperature")
output, state = da.wax.stream().apply(my_custom_function, format_dims=da.air.dims)

_ = output.isel(lat=0, lon=0).drop(["lat", "lon"]).to_pandas().plot(figsize=(12, 8))
_images/01_demo_EWMA_22_0.png

⏱ Synchronize data streams ⏱


Physicists, and not the least of them 😅, have proposed a solution to the synchronization problem. See the Poincaré-Einstein synchronization Wikipedia page for more details.

In WAX-ML we strive to follow their recommendations and implement a synchronization mechanism between different data streams. Using the terminology of Henri Poincaré (see link above), we introduce the notion of “local time” to unravel the stream in which the user wants to apply transformations. We call the other streams “secondary streams”. They can work at different frequencies, lower or higher. The data from these secondary streams will be represented in the “local time” either with the use of a forward filling mechanism for lower frequencies or a buffering mechanism for higher frequencies.

We implement a “data tracing” mechanism to optimize access to out-of-sync streams. This mechanism works on in-memory data. We perform a first pass on the data, without actually accessing it, and determine the indices necessary to access the data later. In doing so, we are careful not to let any “future” information pass through, and thus guarantee data processing that respects causality.

The buffering mechanism used in the case of higher frequencies works with a fixed buffer size (see the WAX-ML module wax.modules.Buffer) which allows us to use JAX / XLA optimizations and have efficient processing.
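
To make this more concrete, here is a minimal sketch using wax.modules.Buffer and dynamic_unroll. The function name keep_last_3 and the toy stream are made up for this example, and the expected output shape assumes that Buffer(n) simply keeps the n most recent observations (the padding of the not-yet-filled slots depends on the implementation):

import haiku as hk
import jax
import jax.numpy as jnp

from wax.modules import Buffer
from wax.unroll import dynamic_unroll


@hk.transform_with_state
def keep_last_3(x):
    # Keep a fixed-size window of the 3 most recent observations.
    return Buffer(3)(x)


xs = jnp.arange(6.0).reshape(6, 1)  # a toy stream of 6 scalar observations
rng = jax.random.PRNGKey(0)
outputs, state = dynamic_unroll(keep_last_3, None, None, rng, False, xs)
# outputs[t] should hold the window of the last 3 observations seen up to step t,
# so we expect outputs.shape == (6, 3, 1).
print(outputs.shape)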

Let’s illustrate with a small example how wax.stream.Stream synchronizes data streams.

Let’s use the “air temperature” dataset with:

  • An air temperature defined at a 6-hourly (4x daily) resolution.

  • A “fake” ground temperature defined at a daily resolution as the air temperature minus 10 degrees.

import xarray as xr

da = xr.tutorial.open_dataset("air_temperature")
da["ground"] = da.air.resample(time="d").last().rename({"time": "day"}) - 10

Let’s see what this dataset looks like:

da
<xarray.Dataset>
Dimensions:  (lat: 25, time: 2920, lon: 53, day: 730)
Coordinates:
  * lat      (lat) float32 75.0 72.5 70.0 67.5 65.0 ... 25.0 22.5 20.0 17.5 15.0
  * lon      (lon) float32 200.0 202.5 205.0 207.5 ... 322.5 325.0 327.5 330.0
  * time     (time) datetime64[ns] 2013-01-01 ... 2014-12-31T18:00:00
  * day      (day) datetime64[ns] 2013-01-01 2013-01-02 ... 2014-12-31
Data variables:
    air      (time, lat, lon) float32 ...
    ground   (day, lat, lon) float32 231.9 231.8 231.8 ... 286.5 286.2 285.7
Attributes:
    Conventions:  COARDS
    title:        4x daily NMC reanalysis (1948)
    description:  Data is from NMC initialized reanalysis\n(4x/day).  These a...
    platform:     Model
    references:   http://www.esrl.noaa.gov/psd/data/gridded/data.ncep.reanaly...
from wax.accessors import register_wax_accessors

register_wax_accessors()
from wax.modules import EWMA


def my_custom_function(da):
    return {
        "air_10": EWMA(1.0 / 10.0)(da["air"]),
        "air_100": EWMA(1.0 / 100.0)(da["air"]),
        "ground_100": EWMA(1.0 / 100.0)(da["ground"]),
    }
results, state = da.wax.stream(local_time="time", pbar=True).apply(
    my_custom_function, format_dims=da.air.dims
)
_ = results.isel(lat=0, lon=0).drop(["lat", "lon"]).to_pandas().plot(figsize=(12, 8))
_images/02_Synchronize_data_streams_11_0.png

🌡 Binning temperatures 🌡


Let’s consider the air temperature dataset again. It is sampled at a 6-hourly (4x daily) resolution. We will build “trailing” air temperature bins during each day and “reset” the bin aggregation process at each day change.

import numpy as onp
import xarray as xr
from wax.accessors import register_wax_accessors
from wax.modules import OHLC, HasChanged

register_wax_accessors()
da = xr.tutorial.open_dataset("air_temperature")
da["date"] = da.time.dt.date.astype(onp.datetime64)
def bin_temperature(da):
    day_change = HasChanged()(da["date"])
    return OHLC()(da["air"], reset_on=day_change)


output, state = da.wax.stream().apply(
    bin_temperature, format_dims=onp.array(da.air.dims)
)
output = xr.Dataset(output._asdict())
df = output.isel(lat=0, lon=0).drop(["lat", "lon"]).to_pandas().loc["2013-01"]
_ = df.plot(figsize=(12, 8))
_images/03_ohlc_temperature_9_0.png

The UpdateOnEvent module

The OHLC module uses the primitive wax.modules.UpdateOnEvent.

Its implementation required completing Haiku with a central function, set_params_or_state_dict, which, for now, we have integrated into this WAX-ML module.

We have opened an issue on the Haiku GitHub to integrate it into Haiku.

# Uncomment to run the notebook in Colab
# ! pip install -q "wax-ml[complete]@git+https://github.com/eserie/wax-ml.git"
# ! pip install -q --upgrade jax jaxlib==0.1.67+cuda111 -f https://storage.googleapis.com/jax-releases/jax_releases.html
# check available devices
import jax
print("jax backend {}".format(jax.lib.xla_bridge.get_backend().platform))
jax.devices()
jax backend gpu
[GpuDevice(id=0, process_index=0)]

🎛 The 3-steps workflow 🎛


Thanks to WAX-ML accessors, it is already very useful to be able to execute a JAX function on a dataframe in a single step, with a single line of code.

WAX-ML’s one-step stream API works like this:

<data-container>.stream(...).apply(...)
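
For instance, the EWMA computation on the air temperature Dataset from the first notebook above is an instance of this one-step pattern (repeated here only to make the pattern concrete):

import xarray as xr

from wax.accessors import register_wax_accessors
from wax.modules import EWMA

register_wax_accessors()
da = xr.tutorial.open_dataset("air_temperature")


def my_custom_function(da):
    return {
        "air_10": EWMA(1.0 / 10.0)(da["air"]),
        "air_100": EWMA(1.0 / 100.0)(da["air"]),
    }


output, state = da.wax.stream().apply(my_custom_function, format_dims=da.air.dims)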

But this is not optimal because, under the hood, there are mainly three costly steps:

  • (1) (synchronize | data tracing | encode): make the data “JAX ready”

  • (2) (compile | code tracing | execution): compile and optimize a function for XLA, execute it.

  • (3) (format): convert data back to pandas/xarray/numpy format.

With the wax.stream primitives, it is quite easy to explicitly split the 1-step workflow into a 3-step workflow.

This will allow the user to have full control over each step and iterate on each one.

It is especially useful to iterate on step (2), the “calculation step”, when you are doing research. You can then take full advantage of the JAX primitives, especially the jit primitive.

Let’s illustrate how to reimplement WAX-ML EWMA yourself with the WAX-ML 3-step workflow.

Imports

import haiku as hk
import numpy as onp
import pandas as pd
import xarray as xr

from wax.accessors import register_wax_accessors
from wax.compile import jit_init_apply
from wax.external.eagerpy import convert_to_tensors
from wax.format import format_dataframe
from wax.modules import EWMA
from wax.stream import tree_access_data
from wax.unroll import dynamic_unroll

register_wax_accessors()

Performance on big dataframes

Generate data

T = 1.0e5
N = 1000
%%time
T, N = map(int, (T, N))
dataframe = pd.DataFrame(
    onp.random.normal(size=(T, N)), index=pd.date_range("1970", periods=T, freq="s")
)
CPU times: user 3.87 s, sys: 143 ms, total: 4.01 s
Wall time: 3.99 s

Pandas EWMA

%%time
df_ewma_pandas = dataframe.ewm(alpha=1.0 / 10.0).mean()
CPU times: user 1.95 s, sys: 480 ms, total: 2.42 s
Wall time: 2.42 s

WAX-ML EWMA

%%time
df_ewma_wax = dataframe.wax.ewm(alpha=1.0 / 10.0).mean()
CPU times: user 2.21 s, sys: 679 ms, total: 2.89 s
Wall time: 3.11 s

Here it is about as fast as pandas, not faster…

WAX-ML EWMA (without format step)

Let’s disable the final formatting step (the output is now in raw JAX format):

%%time
df_ewma_wax_no_format = dataframe.wax.ewm(alpha=1.0 / 10.0, format_outputs=False).mean()
CPU times: user 1.77 s, sys: 563 ms, total: 2.33 s
Wall time: 2.22 s
type(df_ewma_wax_no_format)
jaxlib.xla_extension.DeviceArray

Let’s check the device on which the calculation was performed (if you have GPU available, this should be GpuDevice otherwise it will be CpuDevice):

df_ewma_wax_no_format.device()
GpuDevice(id=0, process_index=0)

That’s better! In fact (see below) there is a performance problem in the final formatting step. See WEP3 for a proposal to improve the formatting step.

Generate data (in dataset format)

WAX-ML’s Stream object works on datasets. So let’s transform the DataFrame into an xarray Dataset:

dataset = xr.DataArray(dataframe).to_dataset(name="dataarray")

Step (1) (synchronize | data tracing | encode)

In this step, WAX-ML does:

  • “data tracing”: prepare the indices for fast access in the JAX function access_data.

  • synchronize streams if there are multiple ones. This functionality has options: freq, ffills.

  • encode and convert data from numpy to JAX: use encoders for datetime64 and string_ dtypes. Be aware that by default JAX works in float32 (see JAX’s Common Gotchas to work in float64; a snippet enabling float64 is shown just below).
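
As a side note, if you do need float64 precision, you can enable JAX’s x64 mode before running any computation. This is standard JAX configuration, not a WAX-ML-specific setting:

import jax

# Enable float64 / int64 support in JAX (disabled by default).
jax.config.update("jax_enable_x64", True)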

We have a function Stream.prepare that implements this step (1). It prepares a function that wraps the input function with the actual data and indices in a pair of pure functions (a TransformedWithState Haiku tuple).

%%time
stream = dataframe.wax.stream()
CPU times: user 20 µs, sys: 5 µs, total: 25 µs
Wall time: 29.8 µs

Define our custom function, to be applied on a dict of arrays having the same structure as the original dataset:

def my_ewma_on_dataset(dataset):
    return EWMA(alpha=1.0 / 10.0, adjust=True)(dataset["dataarray"])
transform_dataset, jxs = stream.prepare(dataset, my_ewma_on_dataset)

Let’s define the initial parameters and state of the transformation we will apply.

Init params and state

from wax.unroll import init_params_state
rng = jax.random.PRNGKey(42)
params, state = init_params_state(transform_dataset, rng, jxs)
params
FlatMapping({'ewma': FlatMapping({'alpha': DeviceArray(0.1, dtype=float32)})})
assert state["ewma"]["count"].shape == (N,)
assert state["ewma"]["mean"].shape == (N,)

Step (2) (compile | code tracing | execution)

In this step we:

  • prepare a pure function (with Haiku’s transform mechanism), i.e. define a “transformation” function which:

    • accesses the data,

    • applies another transformation, here: EWMA;

  • compile it with jax.jit;

  • perform code tracing and execution (the last line):

    • unroll the transformation on the “steps” xs (a np.arange vector).

rng = next(hk.PRNGSequence(42))
outputs, state = dynamic_unroll(transform_dataset, params, state, rng, False, jxs)
outputs.device()
GpuDevice(id=0, process_index=0)

Once it has been compiled and “traced” by JAX, the function is much faster to execute:

%%timeit
outputs, _ = dynamic_unroll(transform_dataset, params, state, rng, False, jxs)
1 loop, best of 5: 1.18 s per loop
%%time
outputs, _ = dynamic_unroll(transform_dataset, params, state, rng, False, jxs)
CPU times: user 1.29 s, sys: 6.21 ms, total: 1.29 s
Wall time: 1.24 s

This is 3x faster than the pandas implementation!

(The 3x factor is obtained by measuring the execution with %timeit. We don’t know why, but when executing a code cell one at a time, the execution time can vary a lot, and we sometimes observe executions with a speed-up of 100x.)

Manually prepare the data and manage the device

In order to manage the device on which the computations take place, we need to have even more control over the execution flow. Instead of calling stream.prepare to build the transform_dataset function, we can do it ourselves by:

  • using the stream.trace_dataset function,

  • converting the numpy data to JAX ourselves,

  • putting the data on the device we want.

np_data, np_index, xs = stream.trace_dataset(dataset)
jnp_data, jnp_index, jxs = convert_to_tensors((np_data, np_index, xs), "jax")
/usr/local/lib/python3.7/dist-packages/jax/_src/numpy/lax_numpy.py:2983: UserWarning: Explicitly requested dtype float64 requested in asarray is not available, and will be truncated to dtype float32. To enable more dtypes, set the jax_enable_x64 configuration option or the JAX_ENABLE_X64 shell environment variable. See https://github.com/google/jax#current-gotchas for more.
  lax._check_user_dtype_supported(dtype, "asarray")
/usr/local/lib/python3.7/dist-packages/jax/_src/numpy/lax_numpy.py:2983: UserWarning: Explicitly requested dtype int64 requested in asarray is not available, and will be truncated to dtype int32. To enable more dtypes, set the jax_enable_x64 configuration option or the JAX_ENABLE_X64 shell environment variable. See https://github.com/google/jax#current-gotchas for more.
  lax._check_user_dtype_supported(dtype, "asarray")

We explicitly put the data on CPUs (this is not needed if you only have CPUs):

from jax.tree_util import tree_leaves, tree_map

cpus = jax.devices("cpu")
jnp_data, jnp_index, jxs = tree_map(
    lambda x: jax.device_put(x, cpus[0]), (jnp_data, jnp_index, jxs)
)
print("data copied to CPU device.")
data copied to CPU device.

We now have “JAX-ready” data for later fast access.

Let’s define the transformation that wraps the actual data and indices in a pair of pure functions:

%%time
@jit_init_apply
@hk.transform_with_state
def transform_dataset(step):
    dataset = tree_access_data(jnp_data, jnp_index, step)
    return EWMA(alpha=1.0 / 10.0, adjust=True)(dataset["dataarray"])
CPU times: user 277 µs, sys: 0 ns, total: 277 µs
Wall time: 289 µs

And we can call it as before:

%%time
outputs, state = dynamic_unroll(transform_dataset, None, None, rng, False, jxs)
CPU times: user 373 ms, sys: 163 ms, total: 537 ms
Wall time: 533 ms
outputs.device()
CpuDevice(id=0)

Step(3) (format)

Let’s come back to pandas/xarray:

%%time
y = format_dataframe(
    dataset.coords, onp.array(outputs), format_dims=dataset.dataarray.dims
)
CPU times: user 2.14 s, sys: 9.66 ms, total: 2.15 s
Wall time: 2.13 s

It’s quite slow (see WEP3 enhancement proposal).

GPU execution

Let’s now look at execution on a GPU.

try:
    gpus = jax.devices("gpu")
    jnp_data, jnp_index, jxs = tree_map(
        lambda x: jax.device_put(x, gpus[0]), (jnp_data, jnp_index, jxs)
    )
    print("data copied to GPU device.")
    GPU_AVAILABLE = True
except RuntimeError as err:
    print(err)
    GPU_AVAILABLE = False
data copied to GPU device.

Let’s check that our data is on the GPUs:

tree_leaves(jnp_data)[0].device()
GpuDevice(id=0, process_index=0)
tree_leaves(jnp_index)[0].device()
GpuDevice(id=0, process_index=0)
jxs.device()
GpuDevice(id=0, process_index=0)
%%time
if GPU_AVAILABLE:
    rng = next(hk.PRNGSequence(42))
    outputs, state = dynamic_unroll(transform_dataset, None, None, rng, False, jxs)
CPU times: user 1.7 s, sys: 555 ms, total: 2.25 s
Wall time: 2.14 s

Let’s redefine our function transform_dataset, explicitly specifying the device option to jax.jit.

%%time
if GPU_AVAILABLE:

    @hk.transform_with_state
    def transform_dataset(step):
        dataset = tree_access_data(jnp_data, jnp_index, step)
        return EWMA(alpha=1.0 / 10.0, adjust=True)(dataset["dataarray"])

    transform_dataset = type(transform_dataset)(
        transform_dataset.init, jax.jit(transform_dataset.apply, device=gpus[0])
    )

    rng = next(hk.PRNGSequence(42))
    outputs, state = dynamic_unroll(transform_dataset, None, None, rng, False, jxs)
CPU times: user 1.48 s, sys: 17.6 ms, total: 1.5 s
Wall time: 2.13 s
outputs.device()
GpuDevice(id=0, process_index=0)
%%timeit
if GPU_AVAILABLE:
    outputs, state = dynamic_unroll(transform_dataset, None, None, rng, False, jxs)
1 loop, best of 5: 1.18 s per loop

🔭 Reconstructing the light curve of stars with LSTM 🔭


Let’s take a walk through the stars…

This notebook is based on the study done in this post by Christophe Pere and on the notebook available on the author’s GitHub.

We will repeat this study on starlight using the LSTM architecture to predict the observed light flux through time.

Our LSTM implementation is based on this notebook from Haiku’s github repository.

We’ll see how to use WAX-ML to ease the preparation of time-series data stored in dataframes containing NaNs, before calling a “standard” deep-learning workflow.

Disclaimer

Although this code works with real data, the results presented here should not be considered scientific insights: to the knowledge of the authors of WAX-ML, neither the results nor the data source have been reviewed by an astrophysics peer.

The purpose of this notebook is only to demonstrate how WAX-ML can be used when applying a “standard” machine learning workflow, here LSTM, to analyze time series.

Download the data

%matplotlib inline
import io
from pathlib import Path
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import requests
from sklearn.preprocessing import MinMaxScaler
from tqdm.auto import tqdm

from wax.accessors import register_wax_accessors
from wax.modules import RollingMean

register_wax_accessors()
# Parameters
STAR = "007609553"
SEQ_LEN = 64
BATCH_SIZE = 8
TRAIN_STEPS = 2 ** 16
TRAIN_SIZE = 2 ** 16
TOTAL_LEN = None
TRAIN_DATE = "2016"
CACHE_DIR = Path("./cached_data/")
%%time
filename = CACHE_DIR / "kep_lightcurves.parquet"
try:
    raw_dataframe = pd.read_parquet(open(filename, "rb"))
    print(f"data read from {filename}")
except FileNotFoundError:
    # Downloading the csv file from Christophe Pere's GitHub account
    download = requests.get(
        "https://raw.github.com/Christophe-pere/Time_series_RNN/master/kep_lightcurves.csv"
    ).content
    raw_dataframe = pd.read_csv(io.StringIO(download.decode("utf-8")))
    # set date index
    raw_dataframe.index = pd.Index(
        pd.date_range("2009-03-07", periods=len(raw_dataframe.index), freq="h"),
        name="time",
    )
    # save dataframe locally in CACHE_DIR
    CACHE_DIR.mkdir(exist_ok=True)
    raw_dataframe.to_parquet(filename)
    print(f"data saved in {filename}")
data saved in cached_data/kep_lightcurves.parquet
CPU times: user 1.2 s, sys: 197 ms, total: 1.39 s
Wall time: 3.73 s
# shortening of data to speed up the execution of the notebook in the CI
if TOTAL_LEN:
    raw_dataframe = raw_dataframe.iloc[:TOTAL_LEN]

Let’s visualize the description of this dataset:

raw_dataframe.describe().T.to_xarray()
<xarray.Dataset>
Dimensions:  (index: 52)
Coordinates:
  * index    (index) object '001430305_orig' ... '011611275_res'
Data variables:
    count    (index) float64 6.48e+04 5.674e+04 ... 5.673e+04 5.673e+04
    mean     (index) float64 6.776e+04 -0.2265 0.01231 ... 0.001437 0.004351
    std      (index) float64 1.363e+03 15.42 15.27 12.45 ... 4.648 6.415 4.904
    min      (index) float64 6.529e+04 -123.3 -75.59 ... -20.32 -31.97 -20.89
    25%      (index) float64 6.619e+04 -9.488 -9.875 ... -3.269 -4.281 -3.279
    50%      (index) float64 6.806e+04 -0.3476 0.007812 ... 0.007812 -0.06529
    75%      (index) float64 6.882e+04 8.988 10.02 8.092 ... 2.872 4.277 3.213
    max      (index) float64 7.021e+04 128.7 72.31 69.34 ... 26.53 30.94 29.45
stars = raw_dataframe.columns
stars = sorted(list(set([i.split("_")[0] for i in stars])))
print(f"The number of stars available is: {len(stars)}")
print(f"star identifiers: {stars}")
The number of stars available is: 13
star identifiers: ['001430305', '001724719', '005209845', '007596240', '007609553', '008241079', '008247770', '009345933', '009347009', '009349482', '009349757', '010024701', '011611275']
dataframe = raw_dataframe[[i + "_rscl" for i in stars]].rename(
    columns=lambda c: c.replace("_rscl", "")
)
dataframe.columns.names = ["star"]
dataframe.shape
(71427, 13)

Rolling mean

We will smooth the data by applying a rolling mean with a window of 100 periods.

Count nan values

But first, since the dataset has some NaN values, we will extract a few statistics about the density of NaN values in windows of size 100.

This is also an occasion to show a usage of the wax.modules.Buffer module with the format_outputs=False option of the dataframe accessor .wax.stream.

import jax.numpy as jnp
import numpy as onp
from wax.modules import Buffer

Let’s apply the Buffer module to the data:

buffer, _ = dataframe.wax.stream(format_outputs=False).apply(lambda x: Buffer(100)(x))
WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
assert isinstance(buffer, jnp.ndarray)

Let’s describe the NaN statistics with pandas:

count_nan = jnp.isnan(buffer).sum(axis=1)
pd.DataFrame(onp.array(count_nan)).stack().describe().astype(int)
count    928551
mean         20
std          27
min           0
25%           5
50%           8
75%          19
max         100
dtype: int64

Computing the rolling mean

We will choose a min_periods of 5 in order to keep at least 75% of the points.

%%time
dataframe_mean, _ = dataframe.wax.stream().apply(
    lambda x: RollingMean(100, min_periods=5)(x)
)
CPU times: user 446 ms, sys: 11.5 ms, total: 458 ms
Wall time: 456 ms
dataframe.loc[:, "008241079"].plot()
dataframe_mean.loc[:, "008241079"].plot()
<AxesSubplot:xlabel='time'>
_images/05_reconstructing_the_light_curve_of_stars_30_1.png

With Dataset API

Let’s illustrate how to do the same rolling-mean operation, but using the WAX accessors on an xarray Dataset.

from functools import partial

from jax.tree_util import tree_map
dataset = dataframe.to_xarray()
dataset
<xarray.Dataset>
Dimensions:    (time: 71427)
Coordinates:
  * time       (time) datetime64[ns] 2009-03-07 ... 2017-04-30T02:00:00
Data variables: (12/13)
    001430305  (time) float64 -4.943 2.338 nan -0.9275 ... 13.92 7.728 nan 1.33
    001724719  (time) float64 -9.95 -19.69 -6.298 nan ... 7.535 nan 0.8825
    005209845  (time) float64 nan nan nan nan ... -5.528 -6.864 -25.47 -25.75
    007596240  (time) float64 -1.353 -1.534 -9.497 -3.48 ... 4.234 -3.83 -0.6448
    007609553  (time) float64 38.36 19.7 19.08 27.18 ... 172.2 178.2 163.8 169.9
    008241079  (time) float64 -22.68 -12.27 nan -10.34 ... 3.236 8.97 18.74
    ...         ...
    009345933  (time) float64 -9.756 -9.812 -8.808 ... 5.486 -10.21 -1.196
    009347009  (time) float64 2.219 -3.694 -3.056 -3.843 ... -3.58 11.99 -1.917
    009349482  (time) float64 -4.975 -11.5 -7.711 -9.017 ... 2.338 1.825 0.9793
    009349757  (time) float64 nan -16.64 -20.52 -15.6 ... nan 19.15 17.15 18.18
    010024701  (time) float64 -45.78 -54.53 -35.46 ... -292.7 -301.2 -283.9
    011611275  (time) float64 -1.308 -4.728 -5.136 2.284 ... 12.73 -6.223 -8.024
%%time
dataset_mean, _ = dataset.wax.stream().apply(
    partial(tree_map, lambda x: RollingMean(100, min_periods=5)(x)),
    format_dims=["time"],
)
CPU times: user 5.53 s, sys: 132 ms, total: 5.67 s
Wall time: 5.67 s

(It’s much slower than with the dataframe.)

TODO: this is an issue that we should solve.

dataset_mean
<xarray.Dataset>
Dimensions:    (time: 71427)
Coordinates:
  * time       (time) datetime64[ns] 2009-03-07 ... 2017-04-30T02:00:00
Data variables: (12/13)
    001430305  (time) float32 nan nan nan nan ... -0.1169 0.1599 0.2577 0.3024
    001724719  (time) float32 nan nan nan nan ... -6.384 -6.214 -6.223 -6.147
    005209845  (time) float32 nan nan nan nan ... -9.909 -9.939 -10.13 -10.29
    007596240  (time) float32 nan nan nan nan nan ... 1.165 1.19 1.14 1.134
    007609553  (time) float32 nan nan nan nan 25.99 ... 145.5 146.0 146.4 146.9
    008241079  (time) float32 nan nan nan nan nan ... 13.57 13.57 13.4 13.47
    ...         ...
    009345933  (time) float32 nan nan nan nan ... -14.8 -14.53 -14.48 -14.31
    009347009  (time) float32 nan nan nan nan ... -3.367 -3.462 -3.263 -3.25
    009349482  (time) float32 nan nan nan nan -8.398 ... 1.861 1.858 1.817 1.825
    009349757  (time) float32 nan nan nan nan nan ... 10.3 10.61 10.91 11.08
    010024701  (time) float32 nan nan nan nan ... -322.8 -323.0 -322.8 -322.9
    011611275  (time) float32 nan nan nan nan ... -4.214 -4.037 -4.106 -4.192
dataset["008241079"].plot()
dataset_mean["008241079"].plot()
[<matplotlib.lines.Line2D at 0x7fad38d28b70>]
_images/05_reconstructing_the_light_curve_of_stars_38_1.png

With dataarray

dataarray = dataframe.to_xarray().to_array("star").transpose("time", "star")
dataarray
<xarray.DataArray (time: 71427, star: 13)>
array([[  -4.94312846,   -9.94989465,           nan, ...,           nan,
         -45.78391029,   -1.30840132],
       [   2.33812154,  -19.6881759 ,           nan, ...,  -16.6422077 ,
         -54.53000404,   -4.7283232 ],
       [          nan,   -6.2975509 ,           nan, ...,  -20.52306708,
         -35.45969154,   -5.13554976],
       ...,
       [   7.7280167 ,    7.53484594,   -6.86435205, ...,   19.15124702,
        -292.7287725 ,   12.73319537],
       [          nan,           nan,  -25.47275049, ...,   17.15027046,
        -301.15455375,   -6.22285932],
       [   1.3295792 ,    0.88250219,  -25.74862939, ...,   18.18347358,
        -283.88111625,   -8.02364057]])
Coordinates:
  * time     (time) datetime64[ns] 2009-03-07 ... 2017-04-30T02:00:00
  * star     (star) <U9 '001430305' '001724719' ... '010024701' '011611275'
%%time
dataarray_mean, _ = dataarray.wax.stream().apply(
    lambda x: RollingMean(100, min_periods=5)(x)
)
CPU times: user 426 ms, sys: 7.23 ms, total: 433 ms
Wall time: 431 ms

(With the DataArray, it is about as fast as with the dataframe.)

dataarray_mean
<xarray.DataArray (time: 71427, star: 13)>
array([[           nan,            nan,            nan, ...,
                   nan,            nan,            nan],
       [           nan,            nan,            nan, ...,
                   nan,            nan,            nan],
       [           nan,            nan,            nan, ...,
                   nan,            nan,            nan],
       ...,
       [ 1.5992440e-01, -6.2136836e+00, -9.9386721e+00, ...,
         1.0611248e+01, -3.2304874e+02, -4.0370173e+00],
       [ 2.5768742e-01, -6.2231779e+00, -1.0134243e+01, ...,
         1.0905440e+01, -3.2282303e+02, -4.1056514e+00],
       [ 3.0240160e-01, -6.1467724e+00, -1.0286570e+01, ...,
         1.1077212e+01, -3.2285059e+02, -4.1921792e+00]], dtype=float32)
Coordinates:
  * time     (time) datetime64[ns] 2009-03-07 ... 2017-04-30T02:00:00
  * star     (star) <U9 '001430305' '001724719' ... '010024701' '011611275'
dataarray.sel(star="008241079").plot()
dataarray_mean.sel(star="008241079").plot()
[<matplotlib.lines.Line2D at 0x7fadbac4e518>]
_images/05_reconstructing_the_light_curve_of_stars_45_1.png

Forecasting with Machine Learning

We need to forecast in this data; if you look carefully, you’ll see micro holes and big holes.

import warnings
from typing import NamedTuple, Tuple, TypeVar

import haiku as hk
import jax
import jax.numpy as jnp
import numpy as np
import optax
import pandas as pd
import plotnine as gg

T = TypeVar("T")
Pair = Tuple[T, T]


class Pair(NamedTuple):
    x: T
    y: T


class TrainSplit(NamedTuple):
    train: T
    validation: T


gg.theme_set(gg.theme_bw())
warnings.filterwarnings("ignore")
import matplotlib.pyplot as plt

plt.rcParams["figure.figsize"] = 18, 8
fig, (ax, lax) = plt.subplots(ncols=2, gridspec_kw={"width_ratios": [4, 1]})
dataframe.plot(ax=ax, title="raw data")
ax.legend(bbox_to_anchor=(0, 0, 1, 1), bbox_transform=lax.transAxes)
lax.axis("off")
(0.0, 1.0, 0.0, 1.0)
_images/05_reconstructing_the_light_curve_of_stars_48_1.png
import matplotlib.pyplot as plt

plt.rcParams["figure.figsize"] = 18, 8
fig, (ax, lax) = plt.subplots(ncols=2, gridspec_kw={"width_ratios": [4, 1]})
dataframe_mean.plot(ax=ax, title="Smoothed data")
ax.legend(bbox_to_anchor=(0, 0, 1, 1), bbox_transform=lax.transAxes)
lax.axis("off")
(0.0, 1.0, 0.0, 1.0)
_images/05_reconstructing_the_light_curve_of_stars_49_1.png

Normalize data

dataframe_mean.stack().hist(bins=100)
<AxesSubplot:>
_images/05_reconstructing_the_light_curve_of_stars_51_1.png
from wax.encode import Encoder


def min_max_scaler(values: pd.DataFrame, output_format: str = "dataframe") -> Encoder:
    scaler = MinMaxScaler(feature_range=(0, 1))
    scaler.fit(values)
    index = values.index
    columns = values.columns

    def encode(dataframe: pd.DataFrame):
        nonlocal index
        nonlocal columns
        index = dataframe.index
        columns = dataframe.columns
        array_normed = scaler.transform(dataframe)
        if output_format == "dataframe":
            return pd.DataFrame(array_normed, index, columns)
        elif output_format == "jax":
            return jnp.array(array_normed)
        else:
            return array_normed

    def decode(array_scaled):
        value = scaler.inverse_transform(array_scaled)
        if output_format == "dataframe":
            return pd.DataFrame(value, index, columns)
        else:
            return value

    return Encoder(encode, decode)
scaler = min_max_scaler(dataframe_mean)
dataframe_normed = scaler.encode(dataframe_mean)
assert (scaler.decode(dataframe_normed) - dataframe_mean).stack().abs().max() < 1.0e-4
dataframe_normed.stack().hist(bins=100)
<AxesSubplot:>
_images/05_reconstructing_the_light_curve_of_stars_54_1.png

Prepare train / validation datasets

from wax.modules import FillNanInf, Lag
def split_feature_target(dataframe, look_back=SEQ_LEN) -> Pair:
    x, _ = dataframe.wax.stream(format_outputs=False).apply(
        lambda x: FillNanInf()(Lag(1)(Buffer(look_back)(x)))
    )
    B, T, F = x.shape
    x = x.transpose(1, 0, 2)

    y, _ = dataframe.wax.stream(format_outputs=False).apply(
        lambda x: FillNanInf()(Buffer(look_back)(x))
    )
    y = y.transpose(1, 0, 2)
    return Pair(x, y)


def split_feature_target(
    dataframe,
    look_back=SEQ_LEN,
    stack=True,
    shuffle=False,
    min_periods_ratio: float = 0.8,
) -> Pair:
    x, _ = dataframe.wax.stream(format_outputs=False).apply(
        lambda x: Lag(1)(Buffer(look_back)(x))
    )
    x = x.transpose(1, 0, 2)

    y, _ = dataframe.wax.stream(format_outputs=False).apply(
        lambda x: Buffer(look_back)(x)
    )
    y = y.transpose(1, 0, 2)

    T, B, F = x.shape

    if stack:
        x = x.reshape(T, B * F, 1)
        y = y.reshape(T, B * F, 1)

    if shuffle:
        rng = jax.random.PRNGKey(42)
        idx = jnp.arange(x.shape[1])
        idx = jax.random.shuffle(rng, idx)
        x = x[:, idx]
        y = y[:, idx]

    if min_periods_ratio:
        count_nan = jnp.isnan(x).sum(axis=0)
        mask = count_nan < min_periods_ratio * T
        idx = jnp.where(mask)
        # print("count_nan = ", count_nan)
        # print("B = ", B)
        x = x[:, idx[0], :]
        y = y[:, idx[0], :]
        T, B, F = x.shape
        # print("B = ", B)

    # round batch size to a power of two
    B_round = int(2 ** jnp.floor(jnp.log2(B)))
    x = x[:, :B_round, :]
    y = y[:, :B_round, :]

    # fillnan by zeros
    fill_nan_inf = hk.transform(lambda x: FillNanInf()(x))
    params = fill_nan_inf.init(None, jnp.full(x.shape, jnp.nan, x.dtype))
    x = fill_nan_inf.apply(params, None, x)
    y = fill_nan_inf.apply(params, None, y)

    return Pair(x, y)
def split_train_validation(dataframe, stars, train_size, look_back) -> TrainSplit:

    # prepare scaler
    dataframe_train = dataframe[stars].iloc[:train_size]
    scaler = min_max_scaler(dataframe_train)

    # prepare train data
    dataframe_train_normed = scaler.encode(dataframe_train)
    train = split_feature_target(dataframe_train_normed, look_back)

    # prepare validation data
    valid_size = len(dataframe[stars]) - train_size
    valid_size = int(2 ** jnp.floor(jnp.log2(valid_size)))
    valid_end = int(train_size + valid_size)
    dataframe_valid = dataframe[stars].iloc[train_size:valid_end]
    dataframe_valid_normed = scaler.encode(dataframe_valid)
    valid = split_feature_target(dataframe_valid_normed, look_back)

    return TrainSplit(train, valid)
print(f"Look at star: {STAR}")
train, valid = split_train_validation(dataframe_normed, [STAR], TRAIN_SIZE, SEQ_LEN)
Look at star: 007609553
train[0].shape, train[1].shape, valid[0].shape, valid[1].shape
((64, 32768, 1), (64, 32768, 1), (64, 2048, 1), (64, 2048, 1))
TRAIN_SIZE, VALID_SIZE = len(train.x), len(valid.x)
seq = hk.PRNGSequence(42)
# Plot an observation/target pair.
batch_plot = jax.random.choice(next(seq), len(train[0]))
df = pd.DataFrame(
    {"x": train[0][:, batch_plot, 0], "y": train[1][:, batch_plot, 0]}
).reset_index()
df = pd.melt(df, id_vars=["index"], value_vars=["x", "y"])
plot = (
    gg.ggplot(df)
    + gg.aes(x="index", y="value", color="variable")
    + gg.geom_line()
    + gg.scales.scale_y_log10()
)
_ = plot.draw()
_images/05_reconstructing_the_light_curve_of_stars_63_0.png

Dataset iterator

class Dataset:
    """An iterator over a numpy array, revealing batch_size elements at a time."""

    def __init__(self, xy: Pair, batch_size: int):
        self._x, self._y = xy
        self._batch_size = batch_size
        self._length = self._x.shape[1]
        self._idx = 0
        if self._length % batch_size != 0:
            msg = "dataset size {} must be divisible by batch_size {}."
            raise ValueError(msg.format(self._length, batch_size))

    def __next__(self) -> Pair:
        start = self._idx
        end = start + self._batch_size
        x, y = self._x[:, start:end], self._y[:, start:end]
        if end >= self._length:
            end = end % self._length
            assert end == 0  # Guaranteed by ctor assertion.
        self._idx = end
        return x, y
train_ds = Dataset(train, BATCH_SIZE)
valid_ds = Dataset(valid, BATCH_SIZE)
del train, valid  # Don't leak temporaries.

Training an LSTM

To train the LSTM, we define a Haiku function which unrolls the LSTM over the input sequence, generating predictions for all output values. The LSTM always starts with its initial state at the start of the sequence.

The Haiku function is then transformed into a pure function through hk.transform, and is trained with Adam on an L2 prediction loss.

from wax.compile import jit_init_apply
x, y = next(train_ds)
x.shape, y.shape
((64, 8, 1), (64, 8, 1))
from collections import defaultdict
def unroll_net(seqs: jnp.ndarray):
    """Unrolls an LSTM over seqs, mapping each output to a scalar."""
    # seqs is [T, B, F].
    core = hk.LSTM(32)
    batch_size = seqs.shape[1]
    outs, state = hk.dynamic_unroll(core, seqs, core.initial_state(batch_size))
    # We could include this Linear as part of the recurrent core!
    # However, it's more efficient on modern accelerators to run the linear once
    # over the entire sequence than once per sequence element.
    return hk.BatchApply(hk.Linear(1))(outs), state
model = jit_init_apply(hk.transform(unroll_net))
def train_model(
    train_ds: Dataset, valid_ds: Dataset, max_iterations: int = -1
) -> hk.Params:
    """Initializes and trains a model on train_ds, returning the final params."""
    rng = jax.random.PRNGKey(428)
    opt = optax.adam(1e-3)

    @jax.jit
    def loss(params, x, y):
        pred, _ = model.apply(params, None, x)
        return jnp.mean(jnp.square(pred - y))

    @jax.jit
    def update(step, params, opt_state, x, y):
        l, grads = jax.value_and_grad(loss)(params, x, y)
        grads, opt_state = opt.update(grads, opt_state)
        params = optax.apply_updates(params, grads)
        return l, params, opt_state

    # Initialize state.
    sample_x, _ = next(train_ds)
    params = model.init(rng, sample_x)
    opt_state = opt.init(params)

    step = 0
    records = defaultdict(list)

    def _format_results(records):
        records = {key: jnp.stack(l) for key, l in records.items()}
        return records

    with tqdm() as pbar:
        while True:
            if step % 100 == 0:
                x, y = next(valid_ds)
                valid_loss = loss(params, x, y)
                # print("Step {}: valid loss {}".format(step, valid_loss))
                records["step"].append(step)
                records["valid_loss"].append(valid_loss)

            try:
                x, y = next(train_ds)
            except StopIteration:
                return params, _format_results(records)
            train_loss, params, opt_state = update(step, params, opt_state, x, y)
            if step % 100 == 0:
                # print("Step {}: train loss {}".format(step, train_loss))
                records["train_loss"].append(train_loss)

            step += 1
            pbar.update()
            if max_iterations > 0 and step >= max_iterations:
                return params, _format_results(records)
%%time
trained_params, records = train_model(train_ds, valid_ds, TRAIN_STEPS)
CPU times: user 2min 36s, sys: 6.9 s, total: 2min 42s
Wall time: 1min 23s
# Plot losses
losses = pd.DataFrame(records)
df = pd.melt(losses, id_vars=["step"], value_vars=["train_loss", "valid_loss"])
plot = (
    gg.ggplot(df)
    + gg.aes(x="step", y="value", color="variable")
    + gg.geom_line()
    + gg.scales.scale_y_log10()
)
_ = plot.draw()
_images/05_reconstructing_the_light_curve_of_stars_75_0.png

Sampling

The point of training models is so that they can make predictions! How can we generate predictions with the trained model?

If we’re allowed to feed in the ground truth, we can just run the original model’s apply function.

def plot_samples(truth: np.ndarray, prediction: np.ndarray) -> gg.ggplot:
    assert truth.shape == prediction.shape
    df = pd.DataFrame(
        {"truth": truth.squeeze(), "predicted": prediction.squeeze()}
    ).reset_index()
    df = pd.melt(df, id_vars=["index"], value_vars=["truth", "predicted"])
    plot = (
        gg.ggplot(df) + gg.aes(x="index", y="value", color="variable") + gg.geom_line()
    )
    return plot
# Grab a sample from the validation set.
sample_x, _ = next(valid_ds)
sample_x = sample_x[:, :1]  # Shrink to batch-size 1.

# Generate a prediction, feeding in ground truth at each point as input.
predicted, _ = model.apply(trained_params, None, sample_x)

plot = plot_samples(sample_x[1:], predicted[:-1])
plot.draw()
del sample_x, predicted
_images/05_reconstructing_the_light_curve_of_stars_78_0.png

Run autoregressively

If we can’t feed in the ground truth (because we don’t have it), we can also run the model autoregressively.

def autoregressive_predict(
    trained_params: hk.Params,
    context: jnp.ndarray,
    seq_len: int,
):
    """Given a context, autoregressively generate the rest of a sine wave."""
    ar_outs = []
    context = jax.device_put(context)
    times = range(seq_len - context.shape[0])
    for _ in times:
        full_context = jnp.concatenate([context] + ar_outs)
        outs, _ = jax.jit(model.apply)(trained_params, None, full_context)
        # Append the newest prediction to ar_outs.
        ar_outs.append(outs[-1:])
    # Return the final full prediction.
    return outs
sample_x, _ = next(valid_ds)
context_length = SEQ_LEN // 8
# Cut the batch-size 1 context from the start of the sequence.
context = sample_x[:context_length, :1]
%%time
# We can reuse params we got from training for inference - as long as the
# declaration order is the same.
predicted = autoregressive_predict(trained_params, context, SEQ_LEN)

plot = plot_samples(sample_x[1:, :1], predicted)
plot += gg.geom_vline(xintercept=len(context), linetype="dashed")
plot.draw()
del predicted
CPU times: user 9.71 s, sys: 194 ms, total: 9.91 s
Wall time: 9.82 s
_images/05_reconstructing_the_light_curve_of_stars_83_1.png
Sharing parameters with a different function.

Unfortunately, this is a bit slow - we’re doing O(N^2) computation for a sequence of length N.

It’d be better if we could do the autoregressive sampling all at once - but we need to write a new Haiku function for that.

We’re in luck - if the Haiku module names match, the same parameters can be used for multiple Haiku functions.

This can be achieved through a combination of two techniques:

  1. If we manually give a unique name to a module, we can ensure that the parameters are directed to the right places (a small sketch of this is given after this list).

  2. If modules are instantiated in the same order, they’ll have the same names in different functions.
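
As a small illustration of technique #1 (not used in this notebook; the names f, g and "head" are made up for this sketch), two Haiku functions can share parameters as soon as their explicitly named modules match:

import haiku as hk
import jax
import jax.numpy as jnp


def f(x):
    return hk.Linear(1, name="head")(x)


def g(x):
    # A different function, but the same explicitly named module.
    return 2.0 * hk.Linear(1, name="head")(x)


f_t, g_t = hk.transform(f), hk.transform(g)
x = jnp.ones((4, 3))
params = f_t.init(jax.random.PRNGKey(0), x)
# The parameters initialized for `f` can be reused by `g` because the
# parameter tree keys (under "head") are identical in both functions.
out = g_t.apply(params, None, x)
print(out.shape)  # (4, 1)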

Here, we rely on method #2 to create a fast autoregressive prediction.

def fast_autoregressive_predict_fn(context, seq_len):
    """Given a context, autoregressively generate the rest of a sine wave."""
    core = hk.LSTM(32)
    dense = hk.Linear(1)
    state = core.initial_state(context.shape[1])
    # Unroll over the context using `hk.dynamic_unroll`.
    # As before, we `hk.BatchApply` the Linear for efficiency.
    context_outs, state = hk.dynamic_unroll(core, context, state)
    context_outs = hk.BatchApply(dense)(context_outs)

    # Now, unroll one step at a time using the running recurrent state.
    ar_outs = []
    x = context_outs[-1]
    times = range(seq_len - context.shape[0])
    for _ in times:
        x, state = core(x, state)
        x = dense(x)
        ar_outs.append(x)
    return jnp.concatenate([context_outs, jnp.stack(ar_outs)])


fast_ar_predict = hk.transform(fast_autoregressive_predict_fn)
fast_ar_predict = jax.jit(fast_ar_predict.apply, static_argnums=3)
%%time
# Reuse the same context from the previous cell.
predicted = fast_ar_predict(trained_params, None, context, SEQ_LEN)

# The plots should be equivalent!
plot = plot_samples(sample_x[1:, :1], predicted[:-1])
plot += gg.geom_vline(xintercept=len(context), linetype="dashed")
_ = plot.draw()
CPU times: user 6.67 s, sys: 144 ms, total: 6.82 s
Wall time: 6.75 s
_images/05_reconstructing_the_light_curve_of_stars_86_1.png
%timeit autoregressive_predict(trained_params, context, SEQ_LEN)
%timeit fast_ar_predict(trained_params, None, context, SEQ_LEN)
86.3 ms ± 1.63 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
34.2 µs ± 549 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)

Train all stars

Training

def split_train_validation_date(dataframe, stars, date, look_back) -> TrainSplit:
    train_size = len(dataframe.loc[:date])
    return split_train_validation(dataframe, stars, train_size, look_back)
%%time
train, valid = split_train_validation_date(dataframe_normed, stars, TRAIN_DATE, SEQ_LEN)
TRAIN_SIZE = train[0].shape[1]
print(f"TRAIN_SIZE = {TRAIN_SIZE}")
TRAIN_SIZE = 524288
CPU times: user 5.45 s, sys: 1.75 s, total: 7.2 s
Wall time: 4.42 s
train[0].shape, train[1].shape, valid[0].shape, valid[1].shape
((64, 524288, 1), (64, 524288, 1), (64, 16384, 1), (64, 16384, 1))
train_ds = Dataset(train, BATCH_SIZE)
valid_ds = Dataset(valid, BATCH_SIZE)
del train, valid  # Don't leak temporaries.
%%time
trained_params, records = train_model(train_ds, valid_ds, TRAIN_STEPS)
CPU times: user 2min 36s, sys: 7.03 s, total: 2min 43s
Wall time: 1min 24s
# Plot losses
losses = pd.DataFrame(records)
df = pd.melt(losses, id_vars=["step"], value_vars=["train_loss", "valid_loss"])
plot = (
    gg.ggplot(df)
    + gg.aes(x="step", y="value", color="variable")
    + gg.geom_line()
    + gg.scales.scale_y_log10()
)
_ = plot.draw()
_images/05_reconstructing_the_light_curve_of_stars_95_0.png

Sampling

# Grab a sample from the validation set.
sample_x, _ = next(valid_ds)
sample_x = sample_x[:, :1]  # Shrink to batch-size 1.

# Generate a prediction, feeding in ground truth at each point as input.
predicted, _ = model.apply(trained_params, None, sample_x)

plot = plot_samples(sample_x[1:], predicted[:-1])
plot.draw()
del sample_x, predicted
_images/05_reconstructing_the_light_curve_of_stars_97_0.png

Run autoregressively

%%time
sample_x, _ = next(valid_ds)
context_length = SEQ_LEN // 8
# Cut the batch-size 1 context from the start of the sequence.
context = sample_x[:context_length, :1]

# Reuse the same context from the previous cell.
predicted = fast_ar_predict(trained_params, None, context, SEQ_LEN)

# The plots should be equivalent!
plot = plot_samples(sample_x[1:, :1], predicted[:-1])
plot += gg.geom_vline(xintercept=len(context), linetype="dashed")
_ = plot.draw()
CPU times: user 195 ms, sys: 18.2 ms, total: 213 ms
Wall time: 144 ms
_images/05_reconstructing_the_light_curve_of_stars_99_1.png

🦎 Online linear regression with a non-stationary environment 🦎


We implement an online-learning, non-stationary linear regression problem.

We get there progressively, by showing how a linear regression problem can be cast into an online learning problem thanks to the OnlineSupervisedLearner module.

Then, to tackle a non-stationary linear regression problem (i.e. with weights that can vary in time), we reformulate the problem as a reinforcement learning problem that we implement with the GymFeedback module of WAX-ML.

We then need to define an “agent” and an “environment” using simple functions implemented with modules:

  • The agent is responsible for learning the weights of its internal linear model.

  • The environment is responsible for generating labels and evaluating the agent’s reward metric.

We experiment with a non-stationary environment that reverses the sign of the linear regression parameters at a given time step, known only to the environment.

This example shows that it is quite simple to implement this online-learning task with WAX-ML tools. In particular, the functional workflow adopted here allows reusing the functions implemented for one task in each new task of increasing complexity.

In this journey, we will use:

  • Haiku basic linear module hk.Linear.

  • Optax stochastic gradient descent optimizer: sgd.

  • WAX-ML modules: OnlineSupervisedLearner, Lag, GymFeedback

  • WAX-ML helper functions: dynamic_unroll, jit_init_apply

%pylab inline
Populating the interactive namespace from numpy and matplotlib
import haiku as hk
import jax
import jax.numpy as jnp
import optax
from matplotlib import pyplot as plt
from wax.compile import jit_init_apply
from wax.modules import OnlineSupervisedLearner

Static Linear Regression

First, let’s implement a simple linear regression

Generate data

Let’s generate a batch of data:

seq = hk.PRNGSequence(42)
X = jax.random.normal(next(seq), (100, 3))
w_true = jnp.ones(3)

Define the model

We use the basic module hk.Linear, which is a linear layer. By default, it initializes the weights with random values drawn from a truncated normal with a standard deviation of \(1 / \sqrt{N}\), where \(N\) is the size of the inputs (see https://arxiv.org/abs/1502.03167v3).

@jit_init_apply
@hk.transform_with_state
def linear_model(x):
    return hk.Linear(output_size=1, with_bias=False)(x)

Run the model

Let’s run the model using WAX-ML dynamic_unroll on the batch of data.

from wax.unroll import dynamic_unroll
params, state = linear_model.init(next(seq), X[0])
linear_model.apply(params, state, None, X[0])
(DeviceArray([-0.2070887], dtype=float32), FlatMapping({}))
Y_pred, state = dynamic_unroll(linear_model, None, None, next(seq), False, X)
Y_pred.shape
(100, 1)

Check cost

Let’s look at the mean squared error for this non-trained model.

noise = jax.random.normal(next(seq), (100,))
Y = X.dot(w_true) + noise
L = ((Y[:, None] - Y_pred) ** 2).sum(axis=1)  # align shapes to (100, 1) to avoid accidental broadcasting
mean_loss = L.mean()
assert mean_loss > 0

Let’s look at the regret (cumulative sum of the loss) for the non-trained model.

plt.plot(L.cumsum())
plt.title("Regret")
Text(0.5, 1.0, 'Regret')
_images/06_Online_Linear_Regression_26_1.png

As expected, the regret is linear when the model is not trained!

Online Linear Regression

We will now start training the model online. For a review of online-learning methods, see [1].

[1] Elad Hazan, Introduction to Online Convex Optimization

Define an optimizer

opt = optax.sgd(1e-3)

Define a loss

Since we are doing online learning, we need to define a local loss function: \( \ell_t(y, w, x) = \lVert y_t - w \cdot x_t \rVert^2 \)

@jax.jit
def loss(y_pred, y):
    return jnp.mean(jnp.square(y_pred - y))

Define a learning strategy

@jit_init_apply
@hk.transform_with_state
def learner(x, y):
    return OnlineSupervisedLearner(linear_model, opt, loss)(x, y)

Generate data

def generate_many_observations(T=300, sigma=1.0e-2, rng=None):
    rng = jax.random.PRNGKey(42) if rng is None else rng
    X = jax.random.normal(rng, (T, 3))
    w_true = jnp.ones(3)
    noise = sigma * jax.random.normal(rng, (T,))
    Y = X.dot(w_true) + noise
    return (X, Y)
T = 3000
X, Y = generate_many_observations(T)

Unroll the learner

output, online_state = dynamic_unroll(learner, None, None, next(seq), False, X, Y)

Plot the regret

Let’s look at the loss and regret over time.

fig, axs = plt.subplots(1, 2, figsize=(9, 3))
axs[0].plot(output["loss"].cumsum())
axs[0].set_title("Regret")

axs[1].plot(output["params"]["linear"]["w"][:, 0, 0])
axs[1].set_title("Weight[0,0]")
Text(0.5, 1.0, 'Weight[0,0]')
_images/06_Online_Linear_Regression_44_1.png

We have sub-linear regret!

Online learning with Gym

Now we will recast the online linear regression learning task as a reinforcement learning task implemented with the GymFeedback module of WAX-ML.

For that, we define:

  • observations (obs): pairs (x, y) of features and labels

  • raw observations (raw_obs): pairs (x, noise) of features and noise.

Linear regression agent

In WAX-ML, an agent is a simple function with the following API:

[diagram: the agent maps an observation (obs) to an action]

Let’s define a simple linear regression agent with the elements we have defined so far.

def linear_regression_agent(obs):
    x, y = obs

    @jit_init_apply
    @hk.transform_with_state
    def model(x):
        return hk.Linear(output_size=1, with_bias=False)(x)

    opt = optax.sgd(1e-3)

    @jax.jit
    def loss(y_pred, y):
        return jnp.mean(jnp.square(y_pred - y))

    def learner(x, y):
        return OnlineSupervisedLearner(model, opt, loss)(x, y)

    return learner(x, y)

Linear regression environment

In WAX-ML, an environment is a simple function with the following API:

[diagram: the environment maps (action, raw observation) to (reward, observation)]

Let’s now define a linear regression environment that, for the moment, has static weights.

It is responsible for generating the true labels and evaluating the agent’s reward.

To evaluate the reward, we need the Lag module to compare the agent’s action with the labels generated at the previous time step (a small sketch of Lag is given just below).
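
Here is a minimal, standalone sketch of what Lag(1) does (the function name lag_one and the toy stream are made up for this example; the NaN on the first step is consistent with the remark about pd.Series further below):

import haiku as hk
import jax
import jax.numpy as jnp

from wax.modules import Lag
from wax.unroll import dynamic_unroll


@hk.transform_with_state
def lag_one(x):
    # Return the value observed at the previous time step.
    return Lag(1)(x)


xs = jnp.arange(1.0, 5.0).reshape(4, 1)  # stream: 1., 2., 3., 4.
rng = jax.random.PRNGKey(0)
outputs, _ = dynamic_unroll(lag_one, None, None, rng, False, xs)
# We expect outputs[t] == xs[t - 1] for t >= 1, with NaN at t == 0.
print(outputs.squeeze())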

from wax.modules import Lag
def stationary_linear_regression_env(action, raw_obs):

    # Only the environment knows the true value of the parameters
    w_true = -jnp.ones(3)

    # The environment has its proper loss definition
    @jax.jit
    def loss(y_pred, y):
        return jnp.mean(jnp.square(y_pred - y))

    # raw observation contains features and generative noise
    x, noise = raw_obs

    # generate targets
    y = x @ w_true + noise
    obs = (x, y)

    y_previous = Lag(1)(y)
    # evaluate the prediction made by the agent
    y_pred = action["y_pred"]
    reward = loss(y_pred, y_previous)

    return reward, obs

Generate raw observation

Let’s define a function that generates the raw observations:

def generate_many_raw_observations(T=300, sigma=1.0e-2, rng=None):
    rng = jax.random.PRNGKey(42) if rng is None else rng
    X = jax.random.normal(rng, (T, 3))
    noise = sigma * jax.random.normal(rng, (T,))
    return (X, noise)

Implement Feedback

We are now ready to set things up with the GymFeedback module implemented in WAX-ML.

It implements the following feedback loop:

[diagram: the GymFeedback feedback loop between the agent and the environment]

Equivalently, it can be described with the pair of init and apply functions:

[diagram: the pair of init and apply functions describing GymFeedback]
from wax.modules import GymFeedback
@hk.transform_with_state
def gym_fun(raw_obs):
    return GymFeedback(
        linear_regression_agent, stationary_linear_regression_env, return_action=True
    )(raw_obs)

And now we can unroll it on a sequence of raw observations!

seq = hk.PRNGSequence(42)
T = 3000
raw_observations = generate_many_raw_observations(T)
rng = next(seq)
output_sequence, final_state = dynamic_unroll(
    gym_fun,
    None,
    None,
    rng,
    True,
    raw_observations,
)

Let’s visualize the outputs.

We now use pd.Series to represent the reward sequence since its first value is NaN due to the use of the lag operator.

import pandas as pd
fig, axs = plt.subplots(1, 2, figsize=(9, 3))
pd.Series(output_sequence.reward).cumsum().plot(ax=axs[0], title="Regret")
axs[1].plot(output_sequence.action["params"]["linear"]["w"][:, 0, 0])
axs[1].set_title("Weight[0,0]")
Text(0.5, 1.0, 'Weight[0,0]')
_images/06_Online_Linear_Regression_67_1.png

Non-stationary environment

Now, let’s implement a non-stationary environment.

We implement it so that the sign of the true weights is flipped after 2000 steps.
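
The switch we want can be written directly with jnp.where; a small sketch of the weight schedule (the helper name is ours), which the environment below expresses with hk.cond:

import jax.numpy as jnp

def w_true_at(step):
    # true weights: -1 before step 2000, +1 afterwards
    return jnp.where(step < 2000, -jnp.ones(3), jnp.ones(3))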

class NonStationaryEnvironment(hk.Module):
    def __call__(self, action, raw_obs):

        step = hk.get_state("step", [], init=lambda *_: 0)

        # Only the environment knows the true value of the parameters:
        # at step 2000 we flip the sign of the true parameters!
        w_true = hk.cond(
            step < 2000,
            step,
            lambda step: -jnp.ones(3),
            step,
            lambda step: jnp.ones(3),
        )

        # The environment has its own loss definition
        @jax.jit
        def loss(y_pred, y):
            return jnp.mean(jnp.square(y_pred - y))

        # raw observation contains features and generative noise
        x, noise = raw_obs

        # generate targets
        y = x @ w_true + noise
        obs = (x, y)

        # evaluate the prediction made by the agent
        y_previous = Lag(1)(y)
        y_pred = action["y_pred"]
        reward = loss(y_pred, y_previous)

        step += 1
        hk.set_state("step", step)

        return reward, obs
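
The step counter relies on Haiku state (hk.get_state / hk.set_state), which dynamic_unroll carries from one time step to the next; in isolation, the pattern looks like this minimal sketch:

import haiku as hk

class Counter(hk.Module):
    """Minimal sketch of the stateful step-counter pattern used above."""

    def __call__(self):
        step = hk.get_state("step", [], init=lambda *_: 0)
        hk.set_state("step", step + 1)
        return step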

Now let’s run a gym simulation to see how the agent adapts to the change of environment.

@hk.transform_with_state
def gym_fun(raw_obs):
    return GymFeedback(
        linear_regression_agent, NonStationaryEnvironment(), return_action=True
    )(raw_obs)
T = 6000
raw_observations = generate_many_raw_observations(T)
rng = jax.random.PRNGKey(42)
output_sequence, final_state = dynamic_unroll(
    gym_fun,
    None,
    None,
    rng,
    True,
    raw_observations,
)
import pandas as pd
fig, axs = plt.subplots(1, 2, figsize=(9, 3))
pd.Series(output_sequence.reward).cumsum().plot(ax=axs[0], title="Regret")
axs[1].plot(output_sequence.action["params"]["linear"]["w"][:, 0, 0])
axs[1].set_title("Weight[0,0]")
# plt.savefig("../_static/online_linear_regression_regret.png")
Text(0.5, 1.0, 'Weight[0,0]')
_images/06_Online_Linear_Regression_75_1.png

It adapts!

The regret first converges, then jumps at step 2000, and finally readjusts to the new regime.

We see that the weights converge to the correct values in both regimes.
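
A quick way to check the second regime (using the same indexing as in the plot above) is to look at the weights at the last step; they should be close to the post-flip true value of +1:

output_sequence.action["params"]["linear"]["w"][-1]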

Change log

Best viewed here.

wax 0.0.2 (Unreleased)

  • First release.

Installing wax

First, obtain the WAX-ML source code:

git clone https://github.com/eserie/wax-ml
cd wax-ml

You can install wax by running:

pip install -e .[complete]  # install wax

To upgrade to the latest version from GitHub, just run git pull from the WAX-ML repository root. You shouldn’t have to reinstall wax because pip install -e sets up symbolic links from site-packages into the repository.

You can install wax development tools by running:

pip install -e .[dev]  # install wax-development-tools

Running the tests

To run all the WAX-ML tests, we recommend using pytest-xdist, which can run tests in parallel. First, install pytest-xdist and pytest-benchmark by running pip install -r build/test-requirements.txt. Then, from the repository root directory run:

pytest -n auto .

You can run a more specific set of tests using pytest’s built-in selection mechanisms, or alternatively you can run a specific test file directly to see more detailed information about the cases being run:

pytest -v wax/accessors_test.py

The Colab notebooks are tested for errors as part of the documentation build and the GitHub Actions workflows.

Type checking

We use mypy to check the type hints. To check types locally the same way the GitHub Actions workflow does, you can run:

mypy wax

or

make mypy

Flake8

We use flake8 to check that the code follows the PEP 8 standard. To check the code, you can run:

make flake8

Formatting code

We use isort and black to format the code.

When you are in the root directory of the project, to format code in the package, you can run:

make format-package

To format notebooks in the documentation, you can use:

make format-notebooks

To format all files you can run:

make format

Note that the CI running with GitHub Actions verifies that formatting the source code does not modify any files. You can check this locally by running:

make check-format

Check actions

You can check that everything is ok by running:

make act

This will check flake8, mypy, isort and black formatting, and license headers, and will run the tests with coverage.

Update documentation

To rebuild the documentation, install several packages:

pip install -r docs/requirements.txt

And then run:

sphinx-build -b html docs docs/build/html

or run

make docs

This can take a long time because it executes many of the notebooks in the documentation source; if you’d prefer to build the docs without executing the notebooks, you can run:

sphinx-build -b html -D jupyter_execute_notebooks=off docs docs/build/html

or run

make docs-fast

You can then see the generated documentation in docs/build/html/index.html.

Update notebooks

We use jupytext to maintain three synced copies of the notebooks in docs/notebooks: one in ipynb format, one in py format, and one in md format. The advantage of the ipynb format is that it can be opened and executed directly in Colab; the advantage of the py format is that it makes it easier to refactor and format the Python code; the advantage of the md format is that it makes it much easier to track diffs within version control.

Editing ipynb

For making large changes that substantially modify code and outputs, it is easiest to edit the notebooks in Jupyter or in Colab. To edit notebooks in the Colab interface, open http://colab.research.google.com and Upload from your local repo. Update it as needed, Run all cells then Download ipynb. You may want to test that it executes properly, using sphinx-build as explained above.

You can format the Python code in your notebooks by running make format in the docs/notebooks directory, or make format-notebooks from the root directory.

Editing md

For making smaller changes to the text content of the notebooks, it is easiest to edit the .md versions using a text editor.

Syncing notebooks

After editing either the ipynb or md versions of the notebooks, you can sync the two versions using jupytext by running:

jupytext --sync docs/notebooks/*

or:

cd docs/notebooks/
make sync

Alternatively, you can run this command via the pre-commit framework by executing the following in the main WAX-ML directory:

pre-commit run --all

See the pre-commit framework documentation for information on how to set your local git environment to execute this automatically.

Creating new notebooks

If you are adding a new notebook to the documentation and would like to use the jupytext --sync command discussed here, you can set up your notebook for jupytext by using the following command:

jupytext --set-formats ipynb,py,md:myst path/to/the/notebook.ipynb

This works by adding a "jupytext" metadata field to the notebook file which specifies the desired formats, and which the jupytext --sync command recognizes when invoked.

Notebooks within the sphinx build

Some of the notebooks are built automatically as part of the Travis pre-submit checks and as part of the Read the docs build. The build will fail if cells raise errors. If the errors are intentional, you can either catch them, or tag the cell with raises-exceptions metadata (example PR). You have to add this metadata by hand in the .ipynb file. It will be preserved when somebody else re-saves the notebook.

We exclude some notebooks from the build, e.g., because they contain long computations. See exclude_patterns in conf.py.

Documentation building on readthedocs.io

WAX-ML’s auto-generated documentation is at https://wax-ml.readthedocs.io/.

The documentation building is controlled for the entire project by the readthedocs WAX-ML settings. The current settings trigger a documentation build as soon as code is pushed to the GitHub main branch. For each code version, the building process is driven by the .readthedocs.yml and the docs/conf.py configuration files.

For each automated documentation build you can see the documentation build logs.

If you want to test the documentation generation on Readthedocs, you can push code to the test-docs branch. That branch is also built automatically, and you can see the generated documentation here. If the documentation build fails you may want to wipe the build environment for test-docs.

For a local test, you can replay, in a fresh directory, the commands executed by Readthedocs as written in their logs:

mkvirtualenv wax-docs  # A new virtualenv
mkdir wax-docs  # A new directory
cd wax-docs
git clone --no-single-branch --depth 50 https://github.com/eserie/wax-ml
cd wax-ml
git checkout --force origin/test-docs
git clean -d -f -f
workon wax-docs

python -m pip install --upgrade --no-cache-dir pip
python -m pip install --upgrade --no-cache-dir -I Pygments==2.3.1 setuptools==41.0.1 docutils==0.14 mock==1.0.1 pillow==5.4.1 alabaster>=0.7,<0.8,!=0.7.5 commonmark==0.8.1 recommonmark==0.5.0 'sphinx<2' 'sphinx-rtd-theme<0.5' 'readthedocs-sphinx-ext<1.1'
python -m pip install --exists-action=w --no-cache-dir -r docs/requirements.txt
cd docs
python `which sphinx-build` -T -E -b html -d _build/doctrees-readthedocs -D language=en . _build/html

Public API: wax package

Subpackages

wax.modules package

Gym modules

In WAX-ML, agents and environments are simple functions:

_images/agent_env.png

A Gym feedback loop can be represented with the following diagram:

_images/gymfeedback.png

Equivalently, it can be described with the pair of init and apply functions:

_images/gymfeedback_init_apply.png

gym_feedback

Gym feedback between an agent and a Gym environment.

Online Learning

WAX-ML contains a module to perform online learning for supervised problems.

online_supervised_learner

Other Haiku modules

buffer

Buffer module.

diff

Diff module.

ewma

Exponential moving average module.

ewmcov

Exponentially weighted covariance module.

ewmvar

Exponentially weighted variance module.

has_changed

Detect if something has changed.

lag

Delay operator.

ohlc

Open-High-Low-Close binning.

pct_change

Relative change between the current and a prior element.

rolling_mean

Rolling mean.

update_on_event

Apply a module when an event occurs, otherwise return the last computed output.

wax.gym package

Gym objects

agent

Define API for Gym agents.

callbacks

Callbacks to work with wax.gym.gym_unroll.

entity

Base class for Gym Agent and Env classes.

env

Define API for Gym environments.

haiku_agent

Gym agent defined from a Haiku module.

haiku_env

Gym environment defined from a Haiku module.

wax.datasets package

generation functions

generate_temperature_data

Generate fake temperature data for test purposes.

wax.universal package

Universal Haiku modules

eager_ewma

Universal Exponential moving average module and unroll implementations.

wax.accessors

Define accessors for xarray and pandas data containers.

wax.compile

Compilation helper for Haiku Transformed and TransformedWithState pairs of pure functions.

wax.encode

Encoding schemes to encode/decode numpy data types not supported by JAX, e.g.

wax.format

Format nested data structures to numpy/xarray/pandas containers.

wax.stream

Define Stream object used to synchronize in-memory data streams and unroll data transformations on it.

wax.transform

Transformation functions to work on batches of data.

wax.unroll

Unroll modules on data along first axis.

wax.utils

Some utils functions used in WAX-ML.

Indices and tables