"""Conditional autoregressive normalizing flow for aleatoric uncertainty.
This module provides ``NormalizingFlow`` — a conditional density estimator
built with the ``nflows`` library. It models the conditional distribution
P(y | X) of sea level equivalent (SLE) output given climate forcing features
X, enabling both probabilistic sampling and aleatoric uncertainty estimation.
Architecture
------------
- **Base distribution:** ``ConditionalDiagonalNormal`` with a 2-layer MLP
context encoder (input_size → flow_hidden_features → output_size * 2),
mapping forcing features X to the mean and log standard deviation of the latent z.
- **Transforms:** ``num_flow_transforms`` alternating pairs of
``RandomPermutation`` and ``MaskedAffineAutoregressiveTransform`` steps,
each conditioned on the full feature vector X.
- **Flow:** ``nflows.flows.base.Flow`` wrapping the composite transform and
the base distribution.
Role in ISEFlow
---------------
The ``NormalizingFlow`` is trained **first** (before the ``DeepEnsemble``)
via maximum likelihood (negative log-probability). Once trained it serves
two roles:
1. **Latent features:** ``get_latent(X)`` samples ``z ~ q(z | X)`` from the
base distribution, providing a single low-dimensional context variable that
the ``DeepEnsemble`` members receive as extra input alongside X.
2. **Aleatoric uncertainty:** ``aleatoric(X, num_samples)`` draws ``num_samples``
SLE values from the full flow for each input row and returns the std across
draws, interpreted as the inherent data uncertainty (aleatoric component).
Usage
-----
::
from ise.models.normalizing_flow import NormalizingFlow
nf = NormalizingFlow(input_size=83, output_size=1, num_flow_transforms=5,
flow_hidden_features=16)
nf.fit(X_train, y_train, epochs=500, batch_size=64,
X_val=X_val, y_val=y_val, early_stopping=True, patience=20)
z = nf.get_latent(X_val) # shape (N, 1) — latent context for DeepEnsemble
uncertainty = nf.aleatoric(X_val, num_samples=100) # shape (N,)
samples = nf.sample(X_val[:5], num_samples=200) # shape (5, 200)
nf.save("nf_checkpoint.pth")
nf_loaded = NormalizingFlow.load("nf_checkpoint.pth")
Checkpointing and early stopping
---------------------------------
``fit()`` integrates with ``CheckpointSaver`` / ``EarlyStoppingCheckpointer``
from ``ise.models.training``. If a checkpoint file already exists at
``checkpoint_path``, training resumes from the saved epoch. After training
completes the best checkpoint is loaded back and the temporary file is deleted.
Optional ``wandb`` logging is supported via the ``wandb_run`` argument.
"""
import json
import os
import numpy as np
import torch
import wandb
from nflows import distributions, flows, transforms
from torch import nn, optim
from ise.data.dataclasses import EmulatorDataset
from ise.models.training import CheckpointSaver, EarlyStoppingCheckpointer
from ise.utils.functions import get_device, to_tensor
[docs]
class NormalizingFlow(nn.Module):
    """Conditional normalizing flow that models ``P(y | X)`` with ``nflows``.

    The model is a ``flows.base.Flow`` made of a conditional diagonal-normal
    base distribution (its parameters come from a context encoder applied to
    ``X``) and a stack of ``RandomPermutation`` +
    ``MaskedAffineAutoregressiveTransform`` pairs, each conditioned on ``X``.

    Attributes:
        num_flow_transforms (int): Number of permutation/autoregressive pairs.
        num_input_features (int): Number of conditioning features.
        num_predicted_sle (int): Dimensionality of the modelled target.
        legacy_v1_0_0 (bool): Whether the v1.0.0 architecture variant was built.
        flow_hidden_features (int): Hidden width of the context encoder and the
            autoregressive transforms (``output_size * 2`` in legacy mode).
        output_sequence_length (int): Passed to ``EmulatorDataset`` as
            ``projection_length`` when batching in ``fit``.
        device: Value returned by ``get_device()``.
        base_distribution (distributions.normal.ConditionalDiagonalNormal):
            Base distribution conditioned on the input features.
        t (transforms.base.CompositeTransform): Composite flow transform.
        flow (flows.base.Flow): The assembled flow.
        optimizer (torch.optim.AdamW): Created with default hyper-parameters in
            ``__init__`` and re-created with ``lr``/``wd`` at the start of ``fit``.
        criterion (callable): ``flow.log_prob``; only set once ``fit`` runs.
        trained (bool): True after ``fit`` or ``load``.
        wandb_run: Optional Weights & Biases run used for logging in ``fit``.
        model_dir: ``None`` after ``fit``; directory of the checkpoint after ``load``.
        best_loss, epochs_trained: Set by ``fit`` (when the best checkpoint is
            reloaded) and by ``load`` (from the metadata file).
    """

    def __init__(
        self,
        input_size=43,
        output_size=1,
        output_sequence_length=86,
        num_flow_transforms=5,
        flow_hidden_features=16,
        legacy_v1_0_0=False,
    ):
        """Construct the normalizing flow architecture.

        Args:
            input_size (int, optional): Number of conditioning features. Defaults to 43.
            output_size (int, optional): Dimensionality of the target variable (SLE).
                Defaults to 1.
            output_sequence_length (int, optional): Projection length used in dataset
                batching. Defaults to 86.
            num_flow_transforms (int, optional): Number of RandomPermutation +
                MaskedAffineAutoregressive transform pairs. Defaults to 5.
            flow_hidden_features (int, optional): Width of the context encoder and
                autoregressive hidden layers. Ignored when ``legacy_v1_0_0`` is True.
                Defaults to 16.
            legacy_v1_0_0 (bool, optional): Build the v1.0.0 architecture variant: a
                single ``nn.Linear`` context encoder (no hidden layer) and
                ``flow_hidden_features = output_size * 2``. Used only to load v1.0.0
                ISEFlow weights; leave False for any newly trained model.
                Defaults to False.
        """
        super().__init__()
        self.num_flow_transforms = num_flow_transforms
        self.num_input_features = input_size
        self.num_predicted_sle = output_size
        self.legacy_v1_0_0 = legacy_v1_0_0
        # v1.0.0 hardcoded flow_hidden_features = output_size * 2.
        self.flow_hidden_features = output_size * 2 if legacy_v1_0_0 else flow_hidden_features
        self.output_sequence_length = output_sequence_length
        self.device = get_device()

        # Base distribution. v1.0.0 used a single Linear; current models use a
        # 2-layer MLP context encoder. The encoder emits 2 * output_size values.
        if legacy_v1_0_0:
            context_encoder = nn.Linear(self.num_input_features, output_size * 2)
        else:
            context_encoder = nn.Sequential(
                nn.Linear(self.num_input_features, self.flow_hidden_features),
                nn.ReLU(),
                nn.Linear(self.flow_hidden_features, output_size * 2),
            )
        self.base_distribution = distributions.normal.ConditionalDiagonalNormal(
            shape=[self.num_predicted_sle],
            context_encoder=context_encoder,
        )

        # Flow transforms: one permutation + one autoregressive step per pair.
        steps = []
        for _ in range(self.num_flow_transforms):
            steps.append(
                transforms.permutations.RandomPermutation(
                    features=self.num_predicted_sle,
                )
            )
            steps.append(
                transforms.autoregressive.MaskedAffineAutoregressiveTransform(
                    features=self.num_predicted_sle,
                    hidden_features=self.flow_hidden_features,
                    context_features=self.num_input_features,
                )
            )
        # Attribute names (base_distribution, t, flow) are part of the
        # state_dict key layout and must not change.
        self.t = transforms.base.CompositeTransform(steps)
        self.flow = flows.base.Flow(transform=self.t, distribution=self.base_distribution)

        self.trained = False
        self.wandb_run = None

        # Move to the target device only now: calling ``.to`` before the
        # submodules are registered (as the previous code did) moves nothing.
        self.to(self.device)
        self.optimizer = optim.AdamW(
            self.flow.parameters(),
        )

    def _batch_nll(self, inputs, context):
        """Return the mean negative log-probability of ``inputs`` given ``context``."""
        return torch.mean(-self.flow.log_prob(inputs=inputs, context=context))

    def fit(
        self,
        X,
        y,
        X_val=None,
        y_val=None,
        epochs=100,
        batch_size=64,
        save_checkpoints=True,
        checkpoint_path="checkpoint.pt",
        early_stopping=True,
        patience=10,
        verbose=True,
        wandb_run=None,
        lr=1e-4,
        wd=1e-6,
    ):
        """Train the normalizing flow via maximum likelihood (negative log-probability).

        If ``checkpoint_path`` already exists when ``fit`` is called, its model and
        optimizer state are loaded and training resumes from the saved epoch. After
        training, if checkpoints are enabled and the file exists, the checkpoint is
        loaded back into the model and the file is deleted. Calling ``fit`` on a
        model whose ``trained`` flag is already True prints a message and returns.

        Validation is only performed when both ``X_val`` and ``y_val`` are given; in
        that case the checkpointer tracks the validation loss, otherwise the
        training loss.

        Args:
            X (array-like): Input features of shape ``(num_samples, num_features)``.
            y (array-like): Target values of shape ``(num_samples, output_size)``;
                1-D targets are expanded to a trailing dimension of size 1.
            X_val (array-like, optional): Validation features for early stopping.
            y_val (array-like, optional): Validation targets for early stopping.
            epochs (int, optional): Number of training epochs. Defaults to 100.
            batch_size (int, optional): Batch size for training. Defaults to 64.
            save_checkpoints (bool, optional): Whether to save model checkpoints.
                Defaults to True.
            checkpoint_path (str, optional): Path to save model checkpoints.
                Defaults to ``'checkpoint.pt'``.
            early_stopping (bool, optional): Whether to use early stopping (only
                consulted when ``save_checkpoints`` is True). Defaults to True.
            patience (int, optional): Number of epochs to wait before early stopping.
                Defaults to 10.
            verbose (bool, optional): Whether to print training progress.
                Defaults to True.
            wandb_run (wandb.run, optional): Weights & Biases run for logging.
                Defaults to None.
            lr (float, optional): AdamW learning rate. Defaults to ``1e-4``.
            wd (float, optional): AdamW weight decay. Defaults to ``1e-6``.

        Returns:
            None
        """
        if self.trained:
            print("Model is already trained. Skipping training.")
            return

        self.optimizer = optim.AdamW(self.flow.parameters(), lr=lr, weight_decay=wd)
        self.criterion = self.flow.log_prob
        self.to(self.device)

        X, y = to_tensor(X).to(self.device), to_tensor(y).to(self.device)
        if y.ndimension() == 1:
            y = y.unsqueeze(1)
        self.wandb_run = wandb_run
        validate = X_val is not None and y_val is not None

        # Resume from an existing checkpoint file, if any.
        start_epoch = 1
        best_loss = float("inf")
        if os.path.exists(checkpoint_path):
            checkpoint = torch.load(checkpoint_path, weights_only=True)
            self.load_state_dict(checkpoint["model_state_dict"])
            self.optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
            start_epoch = checkpoint["epoch"] + 1
            best_loss = checkpoint.get("best_loss", float("inf"))
            if verbose:
                print(
                    f"Resuming from checkpoint at epoch {start_epoch} with validation loss {best_loss:.6f}"
                )

        dataset = EmulatorDataset(
            X, y, sequence_length=1, projection_length=self.output_sequence_length
        )
        data_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)
        self.train()

        if validate:
            X_val, y_val = to_tensor(X_val).to(self.device), to_tensor(y_val).to(self.device)
            if y_val.ndimension() == 1:
                y_val = y_val.unsqueeze(1)
            val_dataset = EmulatorDataset(
                X_val, y_val, sequence_length=1, projection_length=self.output_sequence_length
            )
            val_data_loader = torch.utils.data.DataLoader(
                val_dataset, batch_size=batch_size, shuffle=False
            )

        if save_checkpoints:
            if early_stopping:
                checkpointer = EarlyStoppingCheckpointer(
                    self, self.optimizer, checkpoint_path, patience, verbose
                )
            else:
                checkpointer = CheckpointSaver(self, self.optimizer, checkpoint_path, verbose)
            # Carry over the best loss from a resumed checkpoint.
            checkpointer.best_loss = best_loss

        if start_epoch > epochs and verbose:
            print(f"Training already completed ({epochs}/{epochs}).")

        # When start_epoch > epochs the range is empty and no training happens.
        for epoch in range(start_epoch, epochs + 1):
            epoch_loss = []
            # Distinct batch names: do not shadow the ``y`` argument.
            for batch_x, batch_y in data_loader:
                batch_x = batch_x.to(self.device).view(batch_x.shape[0], -1)
                batch_y = batch_y.to(self.device)
                self.optimizer.zero_grad()
                loss = self._batch_nll(batch_y, batch_x)
                loss.backward()
                self.optimizer.step()
                epoch_loss.append(loss.item())

            # An empty loader yields ``inf`` rather than a ZeroDivisionError.
            train_avg_loss = sum(epoch_loss) / len(epoch_loss) if epoch_loss else float("inf")

            if validate:
                self.eval()
                val_losses = []
                with torch.no_grad():
                    for val_x, val_y in val_data_loader:
                        val_x = val_x.to(self.device).view(val_x.shape[0], -1)
                        val_y = val_y.to(self.device)
                        val_losses.append(self._batch_nll(val_y, val_x).item())
                average_epoch_loss = (
                    sum(val_losses) / len(val_losses) if val_losses else float("inf")
                )
                if self.wandb_run:
                    self.wandb_run.log(
                        {
                            "epoch": epoch,
                            "val_loss": average_epoch_loss,
                            "train_loss": train_avg_loss,
                        }
                    )
                self.train()
            else:
                average_epoch_loss = train_avg_loss
                if self.wandb_run:
                    self.wandb_run.log({"epoch": epoch, "loss": average_epoch_loss})

            if save_checkpoints:
                checkpointer(average_epoch_loss, epoch)
                # Only the early-stopping checkpointer is expected to expose
                # ``early_stop``; the plain saver is treated as "never stop".
                if getattr(checkpointer, "early_stop", False):
                    if verbose:
                        print("Early stopping")
                    break

            if verbose:
                checkpoint_note = f" -- {checkpointer.log}" if save_checkpoints else ""
                print(
                    f"[epoch/total]: [{epoch}/{epochs}], loss: {average_epoch_loss}{checkpoint_note}"
                )

        self.trained = True
        self.model_dir = None

        # Load best checkpoint back into the model — but only if it actually
        # exists. The checkpointer only writes when loss improves, so very
        # short training runs can finish without a checkpoint file.
        if save_checkpoints and os.path.exists(checkpoint_path):
            if self.wandb_run:
                # Strip directory and extension properly: the previous
                # ``replace(".pt", "")`` turned "x.pth" into "xh".
                model_name = os.path.splitext(os.path.basename(checkpoint_path))[0]
                artifact = wandb.Artifact(model_name, type="model")
                artifact.add_file(checkpoint_path)
                self.wandb_run.log_artifact(artifact)

            checkpoint = torch.load(checkpoint_path, weights_only=True)
            if isinstance(checkpoint, dict) and "model_state_dict" in checkpoint:
                self.load_state_dict(checkpoint["model_state_dict"])
                self.optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
                # Same tolerance as the resume path above.
                self.best_loss = checkpoint.get("best_loss", float("inf"))
                self.epochs_trained = checkpoint["epoch"]
            else:
                # Bare state_dict checkpoint.
                self.load_state_dict(checkpoint)
            os.remove(checkpoint_path)

    def sample(self, features, num_samples, return_type="numpy"):
        """Draw samples from the flow conditioned on each row of ``features``.

        Args:
            features (array-like or torch.Tensor): Conditioning features, one row
                per input.
            num_samples (int): Number of samples to generate per input row.
            return_type (str, optional): ``"numpy"`` returns a NumPy array; any
                other value returns the ``torch.Tensor``. Defaults to ``"numpy"``.

        Returns:
            np.ndarray or torch.Tensor: Samples of shape ``(N, num_samples)`` when
            ``output_size == 1`` (``N`` = number of feature rows). For
            ``output_size > 1`` the array is returned as produced by
            ``flow.sample`` (expected ``(N, num_samples, output_size)`` — TODO
            confirm against the installed nflows version).
        """
        # Match get_latent(): the context must live on the model's device.
        features = to_tensor(features).to(self.device)
        samples = self.flow.sample(num_samples, context=features)
        if self.num_predicted_sle == 1:
            # Drop the trailing singleton output dimension.
            samples = samples.reshape(features.shape[0], num_samples)
        if return_type == "numpy":
            return samples.detach().cpu().numpy()
        return samples

    def get_latent(self, x, latent_dim=1):
        """Compute the latent representation of ``x``.

        Two approaches are used depending on the model version:

        - **v1.0.0 (legacy)**: deterministically pushes a zero vector through the
          forward transform conditioned on ``x``. The same ``x`` always yields the
          same ``z``, so the flow acts as a learned deterministic feature extractor.
        - **v1.1.0+ (default)**: draws ``latent_dim`` samples from the conditional
          base distribution, so ``z`` is a stochastic summary of ``x``.

        The two paths produce different downstream DeepEnsemble inputs; the legacy
        path is reproducible at inference, while the sampling path follows the
        probabilistic interpretation of the flow.

        Args:
            x (array-like or torch.Tensor): Input data of shape
                ``(num_samples, num_features)``.
            latent_dim (int, optional): Number of latent samples to draw
                (non-legacy path only). Defaults to 1.

        Returns:
            torch.Tensor: Latent representation of the input data.
        """
        x = to_tensor(x).to(self.device)
        if self.legacy_v1_0_0:
            # v1.0.0 deterministically pushed a zero vector through the forward
            # transform conditioned on x. The DeepEnsemble was trained on these
            # latents, so inference must reproduce the same operation. The width
            # equals the transform's feature count (1 for the shipped models).
            zero_input = torch.zeros(
                (x.shape[0], self.num_predicted_sle), device=self.device
            ).float()
            z, _ = self.t(zero_input, context=x)
            return z
        # squeeze(2) removes the output dimension when it has size 1.
        return self.base_distribution.sample(latent_dim, context=x).squeeze(2)

    def aleatoric(self, features, num_samples, batch_size=128):
        """Estimate aleatoric uncertainty as the std across flow samples per input row.

        For each input row, draws ``num_samples`` samples from the conditional flow
        and returns the standard deviation across those samples. Non-finite samples
        are replaced by NaN and ignored when computing the std.

        Args:
            features (array-like or torch.Tensor): Input features of shape
                ``(N, num_features)``.
            num_samples (int): Number of flow samples drawn per input row.
            batch_size (int, optional): Number of rows processed per forward pass.
                Defaults to 128.

        Returns:
            numpy.ndarray: Per-row aleatoric uncertainty, shape ``(N,)`` when
            ``output_size == 1`` (otherwise one column per output dimension).

        Raises:
            ValueError: If ``batch_size`` is smaller than 1.
        """
        if batch_size < 1:
            raise ValueError("batch_size must be a positive integer.")

        features = to_tensor(features)
        num_rows = features.shape[0]
        per_batch_std = []
        for start_idx in range(0, num_rows, batch_size):
            # Move one batch at a time so only the current slice is on the device.
            batch_features = features[start_idx : start_idx + batch_size].to(self.device)
            samples = self.flow.sample(num_samples, context=batch_features)
            samples = samples.detach().cpu().numpy()
            samples = np.where(np.isfinite(samples), samples, np.nan)
            std = np.nanstd(samples, axis=1)
            # Drop only the trailing singleton output dimension. A blanket
            # ``squeeze()`` collapses single-row batches to 0-d arrays, which
            # ``np.concatenate`` rejects.
            if std.ndim > 1 and std.shape[-1] == 1:
                std = std[..., 0]
            per_batch_std.append(std)

        if not per_batch_std:
            # No input rows: return an empty result instead of letting
            # ``np.concatenate([])`` raise.
            empty_shape = (0,) if self.num_predicted_sle == 1 else (0, self.num_predicted_sle)
            return np.empty(empty_shape, dtype=np.float32)
        return np.concatenate(per_batch_std)

    def save(self, path):
        """Save the trained model weights and a JSON metadata sidecar.

        The weights, optimizer state and ``trained`` flag are written to ``path``;
        the metadata is written to ``path + "_metadata.json"``.

        Args:
            path (str): Path to save the model checkpoint.

        Raises:
            ValueError: If the model has not been trained before saving.
        """
        if not self.trained:
            raise ValueError("Train the model before saving.")

        # ``load`` sets these to None when the metadata lacks them; treat None
        # like "missing" so a loaded model can be saved again.
        best_loss = getattr(self, "best_loss", None)
        epochs_trained = getattr(self, "epochs_trained", None)

        metadata = {
            "input_size": self.num_input_features,
            "output_size": self.num_predicted_sle,
            # str() keeps a torch.device value JSON-serialisable.
            "device": str(self.device),
            "best_loss": float("inf") if best_loss is None else float(best_loss),
            "epochs_trained": 0 if epochs_trained is None else int(epochs_trained),
            "flow_hidden_size": self.flow_hidden_features,
            "num_flows": self.num_flow_transforms,
            # Explicit flag so a re-saved legacy model is rebuilt with the
            # legacy architecture on load.
            "legacy_v1_0_0": bool(getattr(self, "legacy_v1_0_0", False)),
        }
        metadata_path = path + "_metadata.json"
        with open(metadata_path, "w", encoding="utf-8") as f:
            json.dump(metadata, f, indent=4)

        torch.save(
            {
                "model_state_dict": self.state_dict(),
                "optimizer_state_dict": self.optimizer.state_dict(),
                "trained": self.trained,
            },
            path,
        )
        print(f"Model and metadata saved to {path} and {metadata_path}, respectively.")

    @staticmethod
    def load(path):
        """Load a trained normalizing flow from a checkpoint and its metadata file.

        The architecture is rebuilt from ``path + "_metadata.json"`` and the
        weights are read from ``path``. The returned model is in eval mode with
        ``trained`` set to True.

        Args:
            path (str): Path to the saved model checkpoint.

        Returns:
            NormalizingFlow: A restored instance of the NormalizingFlow model.
        """
        metadata_path = path + "_metadata.json"
        with open(metadata_path, encoding="utf-8") as f:
            metadata = json.load(f)

        # v1.0.0 metadata only stored input_size/output_size; the architecture used
        # a single-Linear context encoder and num_flow_transforms=5. Detect it by
        # the absence of the post-v1.0.0 keys, or by the explicit flag that newer
        # ``save`` calls write.
        missing_arch_keys = "flow_hidden_size" not in metadata or "num_flows" not in metadata
        is_legacy_v1_0_0 = bool(metadata.get("legacy_v1_0_0", False)) or missing_arch_keys

        if is_legacy_v1_0_0:
            model = NormalizingFlow(
                input_size=metadata["input_size"],
                output_size=metadata["output_size"],
                num_flow_transforms=metadata.get("num_flows", 5),
                legacy_v1_0_0=True,
            )
        else:
            model = NormalizingFlow(
                input_size=metadata["input_size"],
                output_size=metadata["output_size"],
                flow_hidden_features=metadata["flow_hidden_size"],
                num_flow_transforms=metadata["num_flows"],
            )

        # Map tensors straight onto the device this process will use.
        checkpoint = torch.load(path, map_location=model.device, weights_only=True)
        if isinstance(checkpoint, dict) and "model_state_dict" in checkpoint:
            model.load_state_dict(checkpoint["model_state_dict"])
            # Optimizer state is optional: weights-only checkpoints still load.
            optimizer_state = checkpoint.get("optimizer_state_dict")
            if optimizer_state is not None:
                model.optimizer.load_state_dict(optimizer_state)
        else:
            # Bare state_dict checkpoint.
            model.load_state_dict(checkpoint)

        # A model restored from disk is always treated as trained.
        model.trained = True
        model.model_dir = os.path.dirname(path)
        model.best_loss = metadata.get("best_loss", None)
        model.epochs_trained = metadata.get("epochs_trained", None)
        model.flow_hidden_size = metadata.get("flow_hidden_size", None)
        model.num_flows = metadata.get("num_flows", None)
        model.to(model.device)
        model.eval()
        return model