# Source code for ise.models.normalizing_flow

"""Conditional autoregressive normalizing flow for aleatoric uncertainty.

This module provides ``NormalizingFlow`` — a conditional density estimator
built with the ``nflows`` library.  It models the conditional distribution
P(y | X) of sea level equivalent (SLE) output given climate forcing features
X, enabling both probabilistic sampling and aleatoric uncertainty estimation.

Architecture
------------
- **Base distribution:** ``ConditionalDiagonalNormal`` with a 2-layer MLP
  context encoder (input_size → flow_hidden_features → output_size * 2),
  mapping forcing features X to the mean and log standard deviation of the latent z.
- **Transforms:** ``num_flow_transforms`` alternating pairs of
  ``RandomPermutation`` and ``MaskedAffineAutoregressiveTransform`` steps,
  with each autoregressive step conditioned on the full feature vector X
  (the permutations themselves take no context).
- **Flow:** ``nflows.flows.base.Flow`` wrapping the composite transform and
  the base distribution.

Role in ISEFlow
---------------
The ``NormalizingFlow`` is trained **first** (before the ``DeepEnsemble``)
via maximum likelihood (negative log-probability).  Once trained it serves
two roles:

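1. **Latent features:** ``get_latent(X)`` samples ``z ~ q(z | X)`` from the
   conditional base distribution (legacy v1.0.0 models instead push a zero
   vector through the flow transform, which is deterministic), providing a
   single low-dimensional context variable that the ``DeepEnsemble`` members
   receive as extra input alongside X.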
2. **Aleatoric uncertainty:** ``aleatoric(X, num_samples)`` draws ``num_samples``
   SLE values from the full flow for each input row and returns the std across
   draws, interpreted as the inherent data uncertainty (aleatoric component).

Usage
-----
::

    from ise.models.normalizing_flow import NormalizingFlow

    nf = NormalizingFlow(input_size=83, output_size=1, num_flow_transforms=5,
                         flow_hidden_features=16)
    nf.fit(X_train, y_train, epochs=500, batch_size=64,
           X_val=X_val, y_val=y_val, early_stopping=True, patience=20)

    z = nf.get_latent(X_val)          # shape (N, 1) — latent context for DeepEnsemble
    uncertainty = nf.aleatoric(X_val, num_samples=100)  # shape (N,)
    samples = nf.sample(X_val[:5], num_samples=200)     # shape (5, 200)

    nf.save("nf_checkpoint.pth")
    nf_loaded = NormalizingFlow.load("nf_checkpoint.pth")

Checkpointing and early stopping
---------------------------------
``fit()`` integrates with ``CheckpointSaver`` / ``EarlyStoppingCheckpointer``
from ``ise.models.training``.  If a checkpoint file already exists at
``checkpoint_path``, training resumes from the saved epoch.  After training
completes, the best checkpoint (if one was written) is loaded back and the
checkpoint file is deleted.
Optional ``wandb`` logging is supported via the ``wandb_run`` argument.
"""

import json
import os

import numpy as np
import torch
import wandb
from nflows import distributions, flows, transforms
from torch import nn, optim

from ise.data.dataclasses import EmulatorDataset
from ise.models.training import CheckpointSaver, EarlyStoppingCheckpointer
from ise.utils.functions import get_device, to_tensor


class NormalizingFlow(nn.Module):
    """A conditional normalizing flow modelling P(y | X) with invertible transformations.

    Samples from a conditional diagonal-normal base distribution are pushed
    through a stack of ``RandomPermutation`` + ``MaskedAffineAutoregressiveTransform``
    pairs; the autoregressive steps (and the base distribution's context
    encoder) are conditioned on the input features X.

    Attributes:
        num_flow_transforms (int): Number of permutation/autoregressive transform pairs.
        num_input_features (int): Number of conditioning (input) features.
        num_predicted_sle (int): Dimensionality of the predicted sea-level equivalent target.
        flow_hidden_features (int): Width of the context encoder and autoregressive hidden
            layers (forced to ``output_size * 2`` for the legacy v1.0.0 architecture).
        output_sequence_length (int): Projection length passed to ``EmulatorDataset`` in ``fit``.
        legacy_v1_0_0 (bool): Whether the v1.0.0 architecture variant was built.
        device: Device returned by ``get_device()``; the model is moved there at construction.
        base_distribution (distributions.normal.ConditionalDiagonalNormal): Base normal
            distribution whose parameters are produced from the input features.
        t (transforms.base.CompositeTransform): Composite transformation of the flow.
        flow (flows.base.Flow): The normalizing flow (transform + base distribution).
        optimizer (torch.optim.AdamW): Optimizer; re-created with the requested lr/wd in ``fit``.
        criterion (callable): ``flow.log_prob``; set by ``fit``.
        trained (bool): Whether the model has been trained (or loaded from a checkpoint).
    """

    def __init__(
        self,
        input_size=43,
        output_size=1,
        output_sequence_length=86,
        num_flow_transforms=5,
        flow_hidden_features=16,
        legacy_v1_0_0=False,
    ):
        """Construct the normalizing flow architecture.

        Args:
            input_size (int, optional): Number of conditioning features. Defaults to 43.
            output_size (int, optional): Dimensionality of the target variable (SLE).
                Defaults to 1.
            output_sequence_length (int, optional): Projection length used in dataset
                batching. Defaults to 86.
            num_flow_transforms (int, optional): Number of RandomPermutation +
                MaskedAffineAutoregressive transform pairs. Defaults to 5.
            flow_hidden_features (int, optional): Width of the context encoder and
                autoregressive hidden layers. Defaults to 16.
            legacy_v1_0_0 (bool, optional): Build the v1.0.0 architecture variant: a single
                ``nn.Linear`` context encoder (no hidden layer) and
                ``flow_hidden_features = output_size * 2``. Used only to load v1.0.0 ISEFlow
                weights — leave False for any newly trained model. Defaults to False.
        """
        super().__init__()
        self.num_flow_transforms = num_flow_transforms
        self.num_input_features = input_size
        self.num_predicted_sle = output_size
        self.legacy_v1_0_0 = legacy_v1_0_0
        # v1.0.0 hardcoded flow_hidden_features = output_size * 2.
        self.flow_hidden_features = output_size * 2 if legacy_v1_0_0 else flow_hidden_features
        self.output_sequence_length = output_sequence_length
        self.device = get_device()

        # Base distribution. v1.0.0 used a single Linear context encoder; current models
        # use a 2-layer MLP. Both emit output_size * 2 values, which
        # ConditionalDiagonalNormal splits into means and log standard deviations.
        if legacy_v1_0_0:
            context_encoder = nn.Linear(self.num_input_features, output_size * 2)
        else:
            context_encoder = nn.Sequential(
                nn.Linear(self.num_input_features, self.flow_hidden_features),
                nn.ReLU(),
                nn.Linear(self.flow_hidden_features, output_size * 2),
            )
        self.base_distribution = distributions.normal.ConditionalDiagonalNormal(
            shape=[self.num_predicted_sle],
            context_encoder=context_encoder,
        )

        # Flow transforms: alternating permutation / conditional autoregressive steps.
        # (Construction order is kept stable so seeded RNG draws are unchanged.)
        t = []
        for _ in range(self.num_flow_transforms):
            t.append(
                transforms.permutations.RandomPermutation(
                    features=self.num_predicted_sle,
                )
            )
            t.append(
                transforms.autoregressive.MaskedAffineAutoregressiveTransform(
                    features=self.num_predicted_sle,
                    hidden_features=self.flow_hidden_features,
                    context_features=self.num_input_features,
                )
            )
        self.t = transforms.base.CompositeTransform(t)
        self.flow = flows.base.Flow(transform=self.t, distribution=self.base_distribution)

        # Move to the target device only now that every submodule is registered
        # (calling .to() before that is a no-op), and before building the optimizer
        # so it references the device-resident parameters.
        self.to(self.device)

        self.trained = False
        self.wandb_run = None
        self.optimizer = optim.AdamW(
            self.flow.parameters(),
        )

    # ------------------------------------------------------------------ training

    def _prepare_xy(self, X, y):
        """Convert ``X``/``y`` to tensors on ``self.device``; give a 1-D ``y`` a unit feature axis."""
        X, y = to_tensor(X).to(self.device), to_tensor(y).to(self.device)
        if y.ndimension() == 1:
            y = y.unsqueeze(1)
        return X, y

    def _make_loader(self, X, y, batch_size, shuffle):
        """Wrap ``X``/``y`` in an ``EmulatorDataset`` (sequence_length=1) and a ``DataLoader``."""
        dataset = EmulatorDataset(
            X, y, sequence_length=1, projection_length=self.output_sequence_length
        )
        return torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=shuffle)

    def _batch_nll(self, x, y):
        """Mean negative log-likelihood of targets ``y`` given the batch context ``x``."""
        # Flatten everything after the batch axis into the context vector.
        x = x.to(self.device).view(x.shape[0], -1)
        y = y.to(self.device)
        return torch.mean(-self.flow.log_prob(inputs=y, context=x))

    def _validation_loss(self, loader):
        """Average per-batch NLL over ``loader`` in eval mode; ``inf`` if the loader is empty."""
        self.eval()
        with torch.no_grad():
            losses = [self._batch_nll(val_x, val_y).item() for val_x, val_y in loader]
        self.train()
        return sum(losses) / len(losses) if losses else float("inf")

    def _restore_best_checkpoint(self, checkpoint_path):
        """Log the checkpoint to wandb (if attached), load it into the model, then delete it."""
        if self.wandb_run:
            # basename + splitext handles any extension (".pt", ".pth", ...) and OS separators.
            model_name = os.path.splitext(os.path.basename(checkpoint_path))[0]
            artifact = wandb.Artifact(model_name, type="model")
            artifact.add_file(checkpoint_path)
            self.wandb_run.log_artifact(artifact)
        checkpoint = torch.load(checkpoint_path, map_location="cpu", weights_only=True)
        if isinstance(checkpoint, dict) and "model_state_dict" in checkpoint:
            self.load_state_dict(checkpoint["model_state_dict"])
            self.optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
            self.best_loss = checkpoint.get("best_loss", float("inf"))
            self.epochs_trained = checkpoint.get("epoch")
        else:
            self.load_state_dict(checkpoint)
        os.remove(checkpoint_path)

    def fit(
        self,
        X,
        y,
        X_val=None,
        y_val=None,
        epochs=100,
        batch_size=64,
        save_checkpoints=True,
        checkpoint_path="checkpoint.pt",
        early_stopping=True,
        patience=10,
        verbose=True,
        wandb_run=None,
        lr=1e-4,
        wd=1e-6,
    ):
        """Train the normalizing flow via maximum likelihood (negative log-probability).

        If ``checkpoint_path`` already exists, training resumes from the saved epoch.
        After training, the best checkpoint (if one was written) is loaded back into
        the model and the checkpoint file is deleted.

        Args:
            X (array-like): Input features of shape ``(num_samples, num_features)``.
            y (array-like): Target values of shape ``(num_samples, output_size)``.
            X_val (array-like, optional): Validation features for early stopping.
            y_val (array-like, optional): Validation targets for early stopping.
            epochs (int, optional): Number of training epochs. Defaults to 100.
            batch_size (int, optional): Batch size for training. Defaults to 64.
            save_checkpoints (bool, optional): Whether to save model checkpoints. Early
                stopping is only active when this is True. Defaults to True.
            checkpoint_path (str, optional): Path to save model checkpoints.
                Defaults to ``'checkpoint.pt'``.
            early_stopping (bool, optional): Whether to use early stopping. Defaults to True.
            patience (int, optional): Number of epochs to wait before early stopping.
                Defaults to 10.
            verbose (bool, optional): Whether to print training progress. Defaults to True.
            wandb_run (wandb.run, optional): Weights & Biases run for logging. Defaults to None.
            lr (float, optional): AdamW learning rate. Defaults to ``1e-4``.
            wd (float, optional): AdamW weight decay. Defaults to ``1e-6``.
        """
        if self.trained:
            print("Model is already trained. Skipping training.")
            return

        self.optimizer = optim.AdamW(self.flow.parameters(), lr=lr, weight_decay=wd)
        self.criterion = self.flow.log_prob
        self.to(self.device)
        X, y = self._prepare_xy(X, y)
        self.wandb_run = wandb_run
        validate = X_val is not None and y_val is not None

        start_epoch = 1
        best_loss = float("inf")
        if os.path.exists(checkpoint_path):
            # map to CPU so a GPU-saved checkpoint also resumes on a CPU-only host;
            # load_state_dict copies into the (device-resident) parameters.
            checkpoint = torch.load(checkpoint_path, map_location="cpu", weights_only=True)
            self.load_state_dict(checkpoint["model_state_dict"])
            self.optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
            start_epoch = checkpoint["epoch"] + 1
            best_loss = checkpoint.get("best_loss", float("inf"))
            if verbose:
                print(
                    f"Resuming from checkpoint at epoch {start_epoch} with validation loss {best_loss:.6f}"
                )

        data_loader = self._make_loader(X, y, batch_size, shuffle=True)
        self.train()
        if validate:
            X_val, y_val = self._prepare_xy(X_val, y_val)
            val_data_loader = self._make_loader(X_val, y_val, batch_size, shuffle=False)

        checkpointer = None
        if save_checkpoints:
            if early_stopping:
                checkpointer = EarlyStoppingCheckpointer(
                    self, self.optimizer, checkpoint_path, patience, verbose
                )
            else:
                checkpointer = CheckpointSaver(self, self.optimizer, checkpoint_path, verbose)
            checkpointer.best_loss = best_loss

        if start_epoch > epochs and verbose:
            print(f"Training already completed ({epochs}/{epochs}).")

        for epoch in range(start_epoch, epochs + 1):
            train_losses = []
            for batch_x, batch_y in data_loader:
                self.optimizer.zero_grad()
                loss = self._batch_nll(batch_x, batch_y)
                loss.backward()
                self.optimizer.step()
                train_losses.append(loss.item())
            # Guard against an empty loader instead of dividing by zero.
            train_avg_loss = (
                sum(train_losses) / len(train_losses) if train_losses else float("inf")
            )

            if validate:
                average_epoch_loss = self._validation_loss(val_data_loader)
                if self.wandb_run:
                    self.wandb_run.log(
                        {
                            "epoch": epoch,
                            "val_loss": average_epoch_loss,
                            "train_loss": train_avg_loss,
                        }
                    )
            else:
                average_epoch_loss = train_avg_loss
                if self.wandb_run:
                    self.wandb_run.log({"epoch": epoch, "loss": average_epoch_loss})

            if checkpointer is not None:
                checkpointer(average_epoch_loss, epoch)
                if getattr(checkpointer, "early_stop", False):
                    if verbose:
                        print("Early stopping")
                    break

            if verbose:
                suffix = f" -- {checkpointer.log}" if checkpointer is not None else ""
                print(f"[epoch/total]: [{epoch}/{epochs}], loss: {average_epoch_loss}{suffix}")

        self.trained = True
        self.model_dir = None
        # The checkpointer only writes when the loss improves, so very short runs can
        # finish without a checkpoint file; only restore when one exists.
        if save_checkpoints and os.path.exists(checkpoint_path):
            self._restore_best_checkpoint(checkpoint_path)

    # ----------------------------------------------------------------- inference

    def sample(self, features, num_samples, return_type="numpy"):
        """Generate samples from the conditional flow for each row of ``features``.

        Args:
            features (array-like or torch.Tensor): Conditioning features, shape
                ``(N, num_features)``.
            num_samples (int): Number of samples to generate per input row.
            return_type (str, optional): ``"numpy"`` returns a detached NumPy array;
                any other value returns the torch tensor. Defaults to ``"numpy"``.

        Returns:
            np.ndarray or torch.Tensor: Samples of shape ``(N, num_samples)`` when
            ``output_size == 1``, otherwise ``(N, num_samples, output_size)``.
        """
        # Move inputs to the model's device (matches get_latent; required on GPU).
        features = to_tensor(features).to(self.device)
        samples = self.flow.sample(num_samples, context=features)
        if self.num_predicted_sle == 1:
            samples = samples.reshape(features.shape[0], num_samples)
        if return_type == "numpy":
            return samples.detach().cpu().numpy()
        return samples

    def get_latent(self, x, latent_dim=1):
        """Compute the latent representation of the given input.

        Two approaches are used depending on the model version:

        - **v1.0.0 (legacy)**: deterministically pushes a zero vector through the
          forward transform conditioned on ``x``. Same ``x`` always yields the same
          ``z``, so the flow acts as a learned deterministic feature extractor.
        - **v1.1.0+ (default)**: draws ``latent_dim`` samples from the conditional base
          distribution. ``z`` is a stochastic summary of ``x`` reflecting the modeled
          conditional distribution rather than a single fixed point.

        These approaches produce different downstream DeepEnsemble inputs; the legacy
        path is reproducible at inference, while the sampling path aligns more closely
        with the probabilistic interpretation of the flow.

        Args:
            x (array-like or torch.Tensor): Input data of shape (num_samples, num_features).
            latent_dim (int, optional): Number of latent samples to draw (post-v1.0.0).
                Defaults to 1.

        Returns:
            torch.Tensor: Latent representation. For the default path the base
            distribution's sample is squeezed on axis 2, giving ``(N, latent_dim)``
            when ``output_size == 1``.
        """
        x = to_tensor(x).to(self.device)
        if self.legacy_v1_0_0:
            # v1.0.0 deterministically pushed a zero vector through the forward
            # transform conditioned on x. The DeepEnsemble was trained on these
            # latents, so inference must reproduce the same operation.
            latent_constant_tensor = torch.zeros(
                (x.shape[0], self.num_predicted_sle), device=self.device
            )
            z, _ = self.t(latent_constant_tensor.float(), context=x)
            return z
        return self.base_distribution.sample(latent_dim, context=x).squeeze(2)

    def aleatoric(self, features, num_samples, batch_size=128):
        """Estimate aleatoric uncertainty as the std across flow samples per input row.

        For each input row, draws ``num_samples`` samples from the conditional flow and
        returns the standard deviation across those samples. Non-finite samples are
        treated as NaN and ignored when computing the std.

        Args:
            features (array-like or torch.Tensor): Input features of shape
                ``(N, num_features)``.
            num_samples (int): Number of flow samples drawn per input row.
            batch_size (int, optional): Number of rows processed per forward pass.
                Defaults to 128.

        Returns:
            numpy.ndarray: Per-row aleatoric uncertainty, shape ``(N,)`` when
            ``output_size == 1`` (``(N, output_size)`` otherwise).
        """
        features = to_tensor(features).to(self.device)
        num_rows = features.shape[0]
        if num_rows == 0:
            empty_shape = (0,) if self.num_predicted_sle == 1 else (0, self.num_predicted_sle)
            return np.empty(empty_shape)

        per_batch_std = []
        with torch.no_grad():
            for start_idx in range(0, num_rows, batch_size):
                batch_features = features[start_idx : start_idx + batch_size]
                samples = self.flow.sample(num_samples, context=batch_features).cpu().numpy()
                samples = np.where(np.isfinite(samples), samples, np.nan)
                std = np.nanstd(samples, axis=1)  # std over the sample axis
                # Drop only the trailing unit output axis; a bare .squeeze() would also
                # collapse a single-row batch to 0-d and break np.concatenate.
                if std.shape[-1] == 1:
                    std = std[..., 0]
                per_batch_std.append(std)
        return np.concatenate(per_batch_std)

    # ----------------------------------------------------------------- persistence

    def save(self, path):
        """Save the trained model weights and a JSON metadata sidecar.

        Writes the state dicts to ``path`` and the metadata to
        ``path + "_metadata.json"``.

        Args:
            path (str): Path to save the model checkpoint.

        Raises:
            ValueError: If the model has not been trained before saving.
        """
        if not self.trained:
            raise ValueError("Train the model before saving.")

        # load() sets these to None when the metadata lacks them (e.g. v1.0.0 models),
        # so fall back to defaults instead of calling float(None) / int(None).
        best_loss = getattr(self, "best_loss", None)
        epochs_trained = getattr(self, "epochs_trained", None)
        metadata = {
            "input_size": self.num_input_features,
            "output_size": self.num_predicted_sle,
            "device": str(self.device),
            "best_loss": float("inf") if best_loss is None else float(best_loss),
            "epochs_trained": 0 if epochs_trained is None else int(epochs_trained),
            "flow_hidden_size": self.flow_hidden_features,
            "num_flows": self.num_flow_transforms,
            "output_sequence_length": self.output_sequence_length,
            # Persist the architecture variant explicitly: once flow_hidden_size/num_flows
            # are written, load() can no longer infer "legacy" from missing keys.
            "legacy_v1_0_0": bool(self.legacy_v1_0_0),
        }
        metadata_path = path + "_metadata.json"
        with open(metadata_path, "w") as f:
            json.dump(metadata, f, indent=4)

        torch.save(
            {
                "model_state_dict": self.state_dict(),
                "optimizer_state_dict": self.optimizer.state_dict(),
                "trained": self.trained,
            },
            path,
        )
        print(f"Model and metadata saved to {path} and {metadata_path}, respectively.")

    @staticmethod
    def load(path):
        """Load a trained normalizing flow model from a saved checkpoint.

        Args:
            path (str): Path to the saved model checkpoint; the metadata is read from
                ``path + "_metadata.json"``.

        Returns:
            NormalizingFlow: A restored instance in eval mode, marked as trained.
        """
        metadata_path = path + "_metadata.json"
        with open(metadata_path) as f:
            metadata = json.load(f)

        # v1.0.0 metadata only stored input_size/output_size (single-Linear context
        # encoder, 5 transforms). Newer saves record the variant explicitly.
        is_legacy_v1_0_0 = (
            bool(metadata.get("legacy_v1_0_0", False))
            or "flow_hidden_size" not in metadata
            or "num_flows" not in metadata
        )
        extra_kwargs = {}
        if "output_sequence_length" in metadata:
            extra_kwargs["output_sequence_length"] = metadata["output_sequence_length"]

        if is_legacy_v1_0_0:
            model = NormalizingFlow(
                input_size=metadata["input_size"],
                output_size=metadata["output_size"],
                num_flow_transforms=metadata.get("num_flows", 5),
                legacy_v1_0_0=True,
                **extra_kwargs,
            )
        else:
            model = NormalizingFlow(
                input_size=metadata["input_size"],
                output_size=metadata["output_size"],
                flow_hidden_features=metadata["flow_hidden_size"],
                num_flow_transforms=metadata["num_flows"],
                **extra_kwargs,
            )

        # Load onto CPU so GPU-saved checkpoints open anywhere; load_state_dict copies
        # into the model's (device-resident) parameters.
        checkpoint = torch.load(path, map_location="cpu", weights_only=True)
        if isinstance(checkpoint, dict) and "model_state_dict" in checkpoint:
            model.load_state_dict(checkpoint["model_state_dict"])
            model.optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
        else:
            model.load_state_dict(checkpoint)
        # A loaded model is always considered trained (the checkpoint's own flag was
        # always overridden to True as well).
        model.trained = True
        model.model_dir = os.path.dirname(path)
        model.best_loss = metadata.get("best_loss", None)
        model.epochs_trained = metadata.get("epochs_trained", None)
        model.flow_hidden_size = metadata.get("flow_hidden_size", None)
        model.num_flows = metadata.get("num_flows", None)
        model.to(model.device)
        model.eval()
        return model