ise.models
Ice sheet emulator models: ISEFlow, LSTM, DeepEnsemble, NormalizingFlow, loss functions, training utilities, and pretrained weight paths.
The flagship model is ISEFlow (ise.models.iseflow), a hybrid emulator
that chains a NormalizingFlow (aleatoric uncertainty) with a DeepEnsemble
of LSTM networks (epistemic uncertainty). Pretrained weights for AIS v1.1.0
and GrIS v1.1.0 are accessed via ISEFlow_AIS and ISEFlow_GrIS.
Additional experimental models (GP, PCA, ScenarioPredictor,
VariationalLSTMEmulator) live in ise.models._experimental and are not
part of the primary API.
Submodules
ise.models.iseflow
ISEFlow hybrid ice sheet emulator — base class and pretrained variants.
This module is the primary user-facing entry point for running and training ISEFlow emulators. It provides:
ISEFlow (base class)
A hybrid torch.nn.Module combining:
- NormalizingFlow: conditional autoregressive flow trained first via maximum likelihood. At inference time it provides (a) a latent representation z fed as extra context to the ensemble, and (b) the aleatoric uncertainty estimate via Monte Carlo sampling.
- DeepEnsemble: ensemble of LSTMs trained on [X, z] (original features concatenated with the NF latent). Disagreement across members gives the epistemic uncertainty.
Training is sequential and order-sensitive: the NF must be trained before
the DE so that the latent features z are meaningful:
from ise.models.iseflow import ISEFlow
from ise.models.normalizing_flow import NormalizingFlow
from ise.models.deep_ensemble import DeepEnsemble
nf = NormalizingFlow(input_size=83, output_size=1, num_flow_transforms=5)
de = DeepEnsemble(input_size=83, output_size=1, num_ensemble_members=10)
model = ISEFlow(de, nf)
model.fit(X_train, y_train, nf_epochs=500, de_epochs=200, X_val=X_val, y_val=y_val)
model.save("my_model/")
Uncertainty decomposition:
predictions, uncertainties = model.predict(X)
# uncertainties = {"total": ..., "epistemic": ..., "aleatoric": ...}
# total = epistemic + aleatoric (scalar sum per timestep)
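The relationship between the three entries can be sketched with NumPy. This assumes both components are per-timestep standard deviations of the same shape. `combine_uncertainties` is a hypothetical helper for illustration and is not part of ise-py.

```python
import numpy as np


def combine_uncertainties(epistemic, aleatoric):
    """Hypothetical helper: assemble the uncertainty dict described above.

    Both inputs are per-timestep standard deviations of the same shape; the
    total is their element-wise sum, following the documented convention
    total = epistemic + aleatoric.
    """
    epistemic = np.asarray(epistemic, dtype=float)
    aleatoric = np.asarray(aleatoric, dtype=float)
    if epistemic.shape != aleatoric.shape:
        raise ValueError("epistemic and aleatoric must have the same shape")
    return {
        "total": epistemic + aleatoric,
        "epistemic": epistemic,
        "aleatoric": aleatoric,
    }


unc = combine_uncertainties([0.5, 1.0], [0.25, 0.5])
```

Under the documented convention, the two standard deviations are summed. This always gives a total at least as large as combining them in quadrature (sqrt(e**2 + a**2)), so the reported total is the more conservative of the two conventions.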
ISEFlow_AIS / ISEFlow_GrIS (pretrained)
Convenience subclasses that load bundled pretrained weights and expose a
simplified predict(inputs) interface:
from ise.models.iseflow import ISEFlow_AIS
from ise.data.inputs import ISEFlowAISInputs
model = ISEFlow_AIS(version="v1.1.0") # loads pretrained weights from package
inputs = ISEFlowAISInputs(...) # or ISEFlowAISInputs.from_absolute_forcings(...)
predictions, uncertainties = model.predict(inputs)
# predictions: numpy array shape (86, 1), SLE in mm, years 2015-2100
Supported versions
- v1.0.0: AIS only; includes mrro_anomaly as a forcing variable.
- v1.1.0 (default): AIS + GrIS; mrro_anomaly removed from AIS inputs.
Output smoothing
Both ISEFlow.predict() and ISEFlow_AIS/GrIS.predict() accept a
smoothing_window argument. When > 0, a uniform moving-average filter
of that width is applied after inverse-scaling so the smoothing acts on
physical SLE values rather than scaled outputs. Projection boundaries are
respected (no mixing between runs).
Deprecated classes
ISEFlow_AIS_DE_v1_0_0, ISEFlow_GrIS_DE_v1_0_0,
ISEFlow_AIS_NF_v1_0_0, ISEFlow_GrIS_NF_v1_0_0 are kept for loading
old v1.0.0 checkpoints only. Use ISEFlow_AIS / ISEFlow_GrIS instead.
- class ise.models.iseflow.ISEFlow(deep_ensemble, normalizing_flow)[source]
Bases: torch.nn.Module
ISEFlow is a hybrid ice sheet emulator that combines a deep ensemble model and a normalizing flow model.
This class provides methods to train, predict, save, and load hybrid models for ice sheet emulation. It integrates a deep ensemble to capture epistemic uncertainties and a normalizing flow to model aleatoric uncertainties.
- device
The computing device (‘cuda’ if available, else ‘cpu’).
- Type:
str
- deep_ensemble
The deep ensemble model for epistemic uncertainty.
- Type:
DeepEnsemble
- normalizing_flow
The normalizing flow model for aleatoric uncertainty.
- Type:
NormalizingFlow
- trained
Flag indicating whether the model has been trained.
- Type:
bool
- scaler_path
Path to the scaler used for output transformation.
- Type:
str or None
- fit(X, y, nf_epochs, de_epochs, batch_size=64, X_val=None, y_val=None, save_checkpoints=True, checkpoint_path='checkpoint_ensemble', early_stopping=True, sequence_length=5, patience=10, verbose=True)[source]
Trains the hybrid emulator using the provided data.
This method trains the normalizing flow model first, then uses its latent representations to train the deep ensemble model.
- Parameters:
X (array-like) – Input feature matrix.
y (array-like) – Target values.
nf_epochs (int) – Number of training epochs for the normalizing flow.
de_epochs (int) – Number of training epochs for the deep ensemble.
batch_size (int, optional) – Batch size for training. Defaults to 64.
X_val (array-like, optional) – Validation feature matrix. Defaults to None.
y_val (array-like, optional) – Validation target values. Defaults to None.
save_checkpoints (bool, optional) – Whether to save training checkpoints. Defaults to True.
checkpoint_path (str, optional) – Path prefix for saving model checkpoints. Defaults to ‘checkpoint_ensemble’.
early_stopping (bool, optional) – Whether to use early stopping. Defaults to True.
sequence_length (int, optional) – Sequence length for recurrent architectures. Defaults to 5.
patience (int, optional) – Number of epochs with no improvement before stopping. Defaults to 10.
verbose (bool, optional) – Whether to print training progress. Defaults to True.
- Raises:
Warning – If the model has already been trained.
- forward(x)[source]
Run a forward pass through the hybrid emulator.
- Parameters:
x (array-like) – Input feature matrix with shape (N, num_features).
- Returns:
(prediction, uncertainties), where:
prediction (numpy.ndarray): Mean prediction across ensemble members.
uncertainties (dict): Keys 'total', 'epistemic', 'aleatoric' with numpy arrays giving per-row uncertainty in scaled (model) units.
- Return type:
tuple
- Warns:
UserWarning – If the model has not been trained.
- static load(model_dir=None, deep_ensemble_path=None, normalizing_flow_path=None)[source]
Load a trained ISEFlow from saved deep ensemble and normalizing flow checkpoints.
Provide either model_dir (which expects deep_ensemble.pth and normalizing_flow.pth files inside it) or both deep_ensemble_path and normalizing_flow_path explicitly.
- Parameters:
model_dir (str, optional) – Directory containing the saved sub-model files.
deep_ensemble_path (str, optional) – Explicit path to the saved deep ensemble.
normalizing_flow_path (str, optional) – Explicit path to the saved normalizing flow.
- Returns:
The loaded model, with trained=True.
- Return type:
ISEFlow
- predict(x, output_scaler=True, smoothing_window=0)[source]
Predict SLE projections and uncertainties, applying inverse scaling and optional smoothing.
Smoothing is applied to the final unscaled predictions and uncertainties so the physical SLE curve is what gets smoothed (rather than scaled values).
- Parameters:
x (array-like) – Input feature matrix with shape (N, num_features).
output_scaler (bool or str, optional) – If True (default), loads the scaler bundled with the pretrained model. If False, returns un-rescaled predictions. If a string, loads the sklearn scaler at that path and uses it to inverse-transform the output.
smoothing_window (int, optional) – Width of a centered moving-average smoother applied to the unscaled predictions and uncertainties. 0 (default) disables smoothing.
- Returns:
(unscaled_predictions, uncertainties), where:
unscaled_predictions (numpy.ndarray): Predictions in mm SLE.
uncertainties (dict): Keys 'total', 'epistemic', 'aleatoric' with numpy arrays in mm SLE.
- Return type:
tuple
- Warns:
UserWarning – If no scaler is available; predictions and uncertainties are then returned in the model’s scaled output space rather than mm SLE.
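This page does not spell out the exact inverse-scaling step. The following is a hedged sketch that assumes the output scaler is a standardization like sklearn's StandardScaler, whose inverse is x * scale_ + mean_. Under that assumption, predictions take the full affine inverse, while standard-deviation-like uncertainties should only be multiplied by the scale factor. `unscale` is a hypothetical name, and this page does not confirm that ISEFlow.predict() treats uncertainties this way.

```python
import numpy as np


def unscale(pred_scaled, unc_scaled, mean, scale):
    """Invert a standardization for a single output column.

    Predictions get the full affine inverse (x * scale + mean); spreads such as
    standard deviations only get the multiplicative part, because shifting by
    the mean does not change how wide a distribution is.
    """
    pred = np.asarray(pred_scaled, dtype=float) * scale + mean
    unc = {k: np.asarray(v, dtype=float) * abs(scale) for k, v in unc_scaled.items()}
    return pred, unc


pred_mm, unc_mm = unscale(np.array([0.0, 1.0]), {"total": np.array([0.5, 0.5])}, mean=10.0, scale=4.0)
```

With a fitted sklearn StandardScaler `scaler` on one output column, the call would be `unscale(p, u, scaler.mean_[0], scaler.scale_[0])`.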
- save(save_dir, input_features=None, output_scaler_path=None)[source]
Saves the trained model and related components to a specified directory.
- Parameters:
save_dir (str) – Directory where the model should be saved.
input_features (list, optional) – List of input feature names. Defaults to None.
output_scaler_path (str, optional) – Path to the output scaler. Defaults to None.
- Raises:
ValueError – If the model has not been trained.
ValueError – If save_dir is a file instead of a directory.
ValueError – If input_features is not a list.
- class ise.models.iseflow.ISEFlow_AIS(version='v1.1.0')[source]
Bases: ISEFlow
Pretrained ISEFlow emulator for the Antarctic Ice Sheet (AIS).
Loads pretrained weights for AIS (18 sectors, 8 km resolution) from HuggingFace Hub and exposes predict(inputs), where inputs is an ISEFlowAISInputs instance.
Note: version refers to the ISEFlow model weights version, not the ise-py package version. See ise.models.pretrained.ISEFLOW_LATEST_MODEL_VERSION for the current default.
Supported model versions:
- v1.0.0: includes mrro_anomaly as a forcing variable.
- v1.1.0 (default): mrro_anomaly removed; improved GrIS+AIS joint training.
- Parameters:
version (str, optional) – ISEFlow model weights version. One of 'v1.0.0' or 'v1.1.0'. Defaults to the latest: 'v1.1.0'.
- Raises:
NotImplementedError – If an unsupported version string is provided.
- predict(inputs: ISEFlowAISInputs, smoothing_window: int = 0)[source]
Predicts AIS sea level contribution using the pretrained ISEFlow_AIS model.
Internally calls process() to scale, add lag variables, and one-hot encode the inputs before running the hybrid forward pass.
- Parameters:
inputs (ISEFlowAISInputs) – Validated input dataclass containing climate forcings and ISM configuration for a single sector.
smoothing_window (int, optional) – If > 0, applies a uniform moving-average smoother of this width to the output time series. Defaults to 0 (no smoothing).
- Returns:
A tuple containing:
predictions (numpy.ndarray, shape (86, 1)): Unscaled sea level equivalent (SLE) projections in mm for years 2015-2100.
uncertainties (dict): Dictionary with keys:
'total': total uncertainty (epistemic + aleatoric).
'epistemic': uncertainty from ensemble disagreement.
'aleatoric': uncertainty from normalizing-flow sampling.
- Return type:
tuple
- process(inputs: ISEFlowAISInputs)[source]
Preprocess ISEFlowAISInputs into the feature matrix expected by the model.
Applies input scaling (using the version-specific scaler_X.pkl), adds 5-step lag variables, one-hot encodes ISM configuration columns, and pads any missing one-hot columns with False.
- Parameters:
inputs (ISEFlowAISInputs) – Validated input dataclass for a single AIS sector.
- Returns:
Feature matrix aligned to the column order expected by the pretrained model weights for the current version.
- Return type:
pandas.DataFrame
- Raises:
ValueError – If mrro_anomaly is None when using v1.0.0.
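This page describes the three preprocessing steps in process() only at a high level. The pandas sketch below applies them to a toy frame. It rests on several assumptions: "5-step lag variables" is read as lags 1 through 5; the column names (`pr_anomaly`, `numerics`), the `.lagK` naming, and back-filling the first rows are all illustrative; none of these are the actual ISEFlow feature schema. `add_lags_and_onehot` is a hypothetical helper.

```python
import pandas as pd


def add_lags_and_onehot(df, lag_cols, cat_cols, expected_cols, n_lags=5):
    """Hypothetical sketch of lag creation, one-hot encoding, and column alignment."""
    out = df.copy()
    for col in lag_cols:
        for k in range(1, n_lags + 1):
            # Shift down by k rows; the first k rows have no history, so they
            # are back-filled with the first observed value (an assumption).
            out[f"{col}.lag{k}"] = out[col].shift(k).bfill()
    # One-hot encode ISM configuration columns (e.g. "numerics" -> "numerics_A").
    out = pd.get_dummies(out, columns=cat_cols)
    # Add one-hot columns that this input lacks (filled with False) and
    # enforce the column order the pretrained weights expect.
    return out.reindex(columns=expected_cols, fill_value=False)


df = pd.DataFrame({"pr_anomaly": [1.0, 2.0, 3.0, 4.0, 5.0, 6.0], "numerics": ["A"] * 6})
expected = ["pr_anomaly"] + [f"pr_anomaly.lag{k}" for k in range(1, 6)] + ["numerics_A", "numerics_B"]
X = add_lags_and_onehot(df, ["pr_anomaly"], ["numerics"], expected)
```

The padding step matters because a single sector's inputs contain only one ISM configuration. Without it, get_dummies would produce fewer columns than the model was trained on.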
- test(X_test)[source]
Tests the model on a test dataset.
- Parameters:
X_test (array-like) – Test feature matrix.
- Returns:
- A tuple containing:
unscaled_predictions (numpy.ndarray): Model predictions in the original scale.
- uncertainties (dict): Dictionary with keys:
'total' (numpy.ndarray): Total uncertainty.
'epistemic' (numpy.ndarray): Epistemic uncertainty.
'aleatoric' (numpy.ndarray): Aleatoric uncertainty.
- Return type:
tuple
- class ise.models.iseflow.ISEFlow_AIS_DE_v1_0_0[source]
Bases: DeepEnsemble
Deprecated AIS deep ensemble (v1.0.0). Use ISEFlow_AIS instead.
This hard-coded 10-member LSTM ensemble was used in ISEFlow v1.0.0 for AIS emulation (input_size=99, includes mrro_anomaly). It is kept for backward compatibility with saved v1.0.0 checkpoints only.
Deprecated: Use ISEFlow_AIS(version='v1.0.0') or ISEFlow_AIS(version='v1.1.0') instead. This class will be removed in a future release.
- input_size
99 (includes mrro_anomaly).
- Type:
int
- output_size
- Type:
int
- class ise.models.iseflow.ISEFlow_AIS_NF_v1_0_0(version='1.0.0')[source]
Bases: NormalizingFlow
Deprecated AIS normalizing flow (v1.0.0). Use ISEFlow_AIS instead.
This pre-configured NormalizingFlow was used in ISEFlow v1.0.0 for AIS aleatoric uncertainty estimation.
Deprecated: Use ISEFlow_AIS(version='v1.0.0') or ISEFlow_AIS(version='v1.1.0') instead. This class will be removed in a future release.
- class ise.models.iseflow.ISEFlow_GrIS(version='v1.1.0')[source]
Bases: ISEFlow
Pretrained ISEFlow emulator for the Greenland Ice Sheet (GrIS).
Loads pretrained weights for GrIS (6 drainage basins, 5 km resolution) from HuggingFace Hub and exposes predict(inputs), where inputs is an ISEFlowGrISInputs instance.
Note: version refers to the ISEFlow model weights version, not the ise-py package version. See ise.models.pretrained.ISEFLOW_LATEST_MODEL_VERSION for the current default.
Supported model versions:
- v1.0.0: initial GrIS release.
- v1.1.0 (default): improved AIS+GrIS joint training.
- Parameters:
version (str, optional) – ISEFlow model weights version. One of 'v1.0.0' or 'v1.1.0'. Defaults to the latest: 'v1.1.0'.
- Raises:
NotImplementedError – If an unsupported version string is provided.
- predict(inputs: ISEFlowGrISInputs, smoothing_window: int = 0)[source]
Predicts GrIS sea level contribution using the pretrained ISEFlow_GrIS model.
Internally calls process() to scale, add lag variables, and one-hot encode the inputs before running the hybrid forward pass.
- Parameters:
inputs (ISEFlowGrISInputs) – Validated input dataclass containing climate forcings and ISM configuration for a single GrIS drainage basin.
smoothing_window (int, optional) – If > 0, applies a uniform moving-average smoother of this width to the output time series. Defaults to 0 (no smoothing).
- Returns:
A tuple containing:
predictions (numpy.ndarray, shape (86, 1)): Unscaled sea level equivalent (SLE) projections in mm for years 2015-2100.
uncertainties (dict): Dictionary with keys:
'total': total uncertainty (epistemic + aleatoric).
'epistemic': uncertainty from ensemble disagreement.
'aleatoric': uncertainty from normalizing-flow sampling.
- Return type:
tuple
- process(inputs: ISEFlowGrISInputs)[source]
Preprocess ISEFlowGrISInputs into the feature matrix expected by the model.
Applies input scaling (using the version-specific scaler_X.pkl), adds 5-step lag variables, one-hot encodes ISM configuration columns, and pads any missing one-hot columns with False.
- Parameters:
inputs (ISEFlowGrISInputs) – Validated input dataclass for a single GrIS basin.
- Returns:
Feature matrix aligned to the column order expected by the pretrained model weights for the current version.
- Return type:
pandas.DataFrame
- test(X_test)[source]
Tests the model on a test dataset.
- Parameters:
X_test (array-like) – Test feature matrix.
- Returns:
- A tuple containing:
unscaled_predictions (numpy.ndarray): Model predictions in the original scale.
- uncertainties (dict): Dictionary with keys:
'total' (numpy.ndarray): Total uncertainty.
'epistemic' (numpy.ndarray): Epistemic uncertainty.
'aleatoric' (numpy.ndarray): Aleatoric uncertainty.
- Return type:
tuple
- class ise.models.iseflow.ISEFlow_GrIS_DE_v1_0_0[source]
Bases: DeepEnsemble
Deprecated GrIS deep ensemble (v1.0.0). Use ISEFlow_GrIS instead.
This hard-coded 10-member LSTM ensemble was used in ISEFlow v1.0.0 for GrIS emulation (input_size=90). It is kept for backward compatibility with saved v1.0.0 checkpoints only.
Deprecated: Use ISEFlow_GrIS(version='v1.0.0') or ISEFlow_GrIS(version='v1.1.0') instead. This class will be removed in a future release.
- input_size
90.
- Type:
int
- output_size
- Type:
int
- class ise.models.iseflow.ISEFlow_GrIS_NF_v1_0_0[source]
Bases: NormalizingFlow
Deprecated GrIS normalizing flow (v1.0.0). Use ISEFlow_GrIS instead.
This pre-configured NormalizingFlow was used in ISEFlow v1.0.0 for GrIS aleatoric uncertainty estimation (input_size=90).
Deprecated: Use ISEFlow_GrIS(version='v1.0.0') or ISEFlow_GrIS(version='v1.1.0') instead. This class will be removed in a future release.
- ise.models.iseflow.smooth_projections(data, window_size, projection_length=86)[source]
Apply moving-average smoothing to projections without mixing values across projection boundaries. Uses scipy.ndimage.uniform_filter1d.
- Parameters:
data (np.ndarray) – Array of shape (n_samples,) or (n_samples, 1) containing values to smooth
window_size (int) – Size of the smoothing window
projection_length (int) – Length of each projection segment (default: 86 years)
- Returns:
Smoothed data with same shape as input
- Return type:
np.ndarray
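Boundary-respecting smoothing can be sketched by reshaping the stacked series into one row per projection and filtering along each row only. No window can then straddle two runs. This sketch rests on three assumptions: the input is a concatenation of complete 86-year projections, out-of-range samples at each projection's ends are handled with mode="nearest" (the documented function's edge mode is not stated here), and `smooth_by_projection` is a hypothetical name rather than the library function itself.

```python
import numpy as np
from scipy.ndimage import uniform_filter1d


def smooth_by_projection(data, window_size, projection_length=86):
    """Moving-average each projection independently; output keeps the input shape."""
    arr = np.asarray(data, dtype=float)
    if window_size <= 0:
        return arr  # smoothing disabled
    flat = arr.reshape(-1)
    if flat.size % projection_length != 0:
        raise ValueError("data length must be a multiple of projection_length")
    # One row per projection, so the filter never mixes neighbouring runs.
    blocks = flat.reshape(-1, projection_length)
    smoothed = uniform_filter1d(blocks, size=window_size, axis=1, mode="nearest")
    return smoothed.reshape(arr.shape)


a = np.concatenate([np.zeros(86), np.full(86, 10.0)])  # two runs with a jump between them
s = smooth_by_projection(a, 5)
```

Because each run is filtered separately, the jump from 0 to 10 at the run boundary does not bleed into either projection.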
ise.models.deep_ensemble
Deep ensemble of LSTM models for epistemic uncertainty estimation.
This module provides DeepEnsemble, which wraps a collection of LSTM
instances and exposes a single forward() that returns the mean
prediction and the epistemic uncertainty (standard deviation across
ensemble members) simultaneously.
Epistemic uncertainty in ISEFlow
The ensemble captures uncertainty that arises from limited training data and
model capacity — the kind that would shrink if more ISMIP6 simulations were
available. Disagreement between members is used as a proxy: if all members
agree, the epistemic uncertainty is low; if they diverge, it is high. This
is combined additively with the aleatoric uncertainty from NormalizingFlow
to form the total reported uncertainty.
Ensemble construction
Members can be supplied explicitly (allowing heterogeneous architectures and loss functions, as in the pretrained ISEFlow weights) or auto-generated randomly:
from ise.models.deep_ensemble import DeepEnsemble
from ise.models.lstm import LSTM
import torch.nn as nn
# Explicit heterogeneous ensemble (matches pretrained v1.1.0 AIS members)
members = [
LSTM(1, 128, input_size=84, output_size=1, criterion=nn.HuberLoss()),
LSTM(2, 256, input_size=84, output_size=1, criterion=nn.MSELoss()),
# ... more members
]
de = DeepEnsemble(ensemble_members=members)
# Auto-generated random ensemble
de = DeepEnsemble(input_size=83, num_ensemble_members=10)
Note: input_size passed to DeepEnsemble is the feature dimensionality
before the NF latent is appended. The constructor automatically adds
latent_dim (default 1) so each LSTM member receives input_size + 1
features.
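The arithmetic in this note can be sketched in NumPy. Appending an (N, 1) latent to an (N, 83) feature matrix yields the 84 columns each member sees, which is why the explicit members above are built with input_size=84. `append_latent` is a hypothetical helper. ISEFlow.fit() builds [X, z] itself, so a step like this is only needed when training a DeepEnsemble by hand.

```python
import numpy as np


def append_latent(X, z):
    """Hypothetical helper: build X_latent = [X, z] column-wise."""
    X = np.asarray(X, dtype=float)
    z = np.asarray(z, dtype=float).reshape(len(X), -1)  # (N,) or (N, 1) -> (N, latent_dim)
    return np.concatenate([X, z], axis=1)


X = np.zeros((10, 83))  # 83 forcing/configuration features
z = np.ones((10, 1))    # latent_dim = 1, e.g. from NormalizingFlow.get_latent
X_latent = append_latent(X, z)
print(X_latent.shape)  # (10, 84)
```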
Training
fit() trains each member independently on the same (X_latent, y) data,
where X_latent = [X, z] and z is the latent from the pretrained
NormalizingFlow:
de.fit(X_latent, y, X_val=X_val_latent, y_val=y_val,
epochs=200, batch_size=128, sequence_length=5,
early_stopping=True, patience=15)
Persistence
save(path) writes a deep_ensemble.pth state dict, a
deep_ensemble_metadata.json with per-member architecture configs, and
individual ensemble_members/member_N.pth files. load(path) fully
reconstructs the ensemble from these artifacts.
- class ise.models.deep_ensemble.DeepEnsemble(ensemble_members=None, input_size=83, output_size=1, num_ensemble_members=3, output_sequence_length=86, latent_dim=1)[source]
Bases: torch.nn.Module
Deep Ensemble Model using multiple LSTMs for time series forecasting.
This class implements an ensemble of LSTM-based predictors. Each LSTM model is trained separately, and predictions from all ensemble members are aggregated to provide a mean prediction along with an epistemic uncertainty estimate.
- input_size
Size of the input features.
- Type:
int
- output_size
Size of the output features.
- Type:
int
- output_sequence_length
Length of the predicted output sequence.
- Type:
int
- loss_choices
List of loss functions used for different ensemble members.
- Type:
list
- ensemble_members
List of LSTM models used as ensemble members.
- Type:
list
- trained
Indicates whether all ensemble members have been trained.
- Type:
bool
- Parameters:
ensemble_members (list, optional) – Pretrained LSTM models. If None, a new ensemble is created.
input_size (int) – Number of input features.
output_size (int) – Number of output features.
num_ensemble_members (int) – Number of ensemble members to create if ensemble_members is None.
output_sequence_length (int) – Length of the output sequence to predict.
latent_dim (int) – Additional latent dimension added to the input.
- Raises:
ValueError – If ensemble_members is provided but does not contain only LSTM instances.
- fit(X, y, X_val=None, y_val=None, save_checkpoints=True, checkpoint_path='checkpoint_ensemble', early_stopping=True, epochs=100, batch_size=128, sequence_length=5, patience=10, verbose=True)[source]
Trains each ensemble member on the provided data.
The ensemble members are trained separately, allowing for independent learning dynamics. Checkpoints can be saved for each model, and early stopping is available to prevent overfitting.
- Parameters:
X (Tensor) – Training input data.
y (Tensor) – Training target data.
X_val (Tensor, optional) – Validation input data for early stopping.
y_val (Tensor, optional) – Validation target data for early stopping.
save_checkpoints (bool, optional) – Whether to save checkpoints during training. Defaults to True.
checkpoint_path (str, optional) – Path prefix for saving model checkpoints.
early_stopping (bool, optional) – Whether to use early stopping. Defaults to True.
epochs (int, optional) – Number of training epochs. Defaults to 100.
batch_size (int, optional) – Batch size for training. Defaults to 128.
sequence_length (int, optional) – Length of input sequences. Defaults to 5.
patience (int, optional) – Number of epochs to wait before early stopping. Defaults to 10.
verbose (bool, optional) – Whether to print training progress. Defaults to True.
- Raises:
Warning – If the model has already been trained, a warning is issued before proceeding.
- forward(x)[source]
Performs a forward pass through the ensemble, aggregating predictions.
Each ensemble member makes a prediction, and the mean and standard deviation of these predictions are computed to provide an estimate of epistemic uncertainty.
- Parameters:
x (Tensor) – Input tensor of shape (batch_size, sequence_length, input_size).
- Returns:
Mean prediction across all ensemble members.
Epistemic uncertainty (standard deviation of predictions).
- Return type:
Tuple[Tensor, Tensor]
Warning
If the model is not trained, a warning is issued indicating that predictions may be unreliable.
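The aggregation in forward() reduces to a mean and a standard deviation over a stacked "member" axis. Below is a NumPy sketch with the hypothetical helper `aggregate_members`. One caveat, which is an assumption about the implementation: this sketch uses NumPy's default population std (ddof=0), while torch.std defaults to the unbiased estimator. The library's values may therefore differ by a factor of sqrt(M / (M - 1)) for M members.

```python
import numpy as np


def aggregate_members(member_preds):
    """Mean prediction and epistemic spread across ensemble members.

    member_preds: list of M arrays, each of shape (N, output_size).
    Returns (mean, std), each of shape (N, output_size).
    """
    stacked = np.stack([np.asarray(p, dtype=float) for p in member_preds], axis=0)  # (M, N, out)
    return stacked.mean(axis=0), stacked.std(axis=0)


preds = [np.array([[1.0]]), np.array([[2.0]]), np.array([[3.0]])]  # three members, one row
mean, epistemic = aggregate_members(preds)
```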
- classmethod load(model_path)[source]
Loads a trained ensemble model from a file.
This method restores the ensemble’s state, including the metadata and individual LSTM members. The ensemble members are reinitialized and their state dictionaries are loaded from disk.
- Parameters:
model_path (str) – Path to the saved model file.
- Returns:
An instance of the loaded ensemble model.
- Return type:
DeepEnsemble
- Raises:
FileNotFoundError – If any ensemble member’s file is missing.
ValueError – If the saved model type does not match DeepEnsemble.
Notes
The method ensures compatibility between the saved metadata and the loaded model.
Loss functions are restored using a predefined lookup.
The model is set to evaluation mode after loading.
- predict(x)[source]
Makes predictions using the trained ensemble.
This method calls forward while ensuring the model is in evaluation mode.
- Parameters:
x (Tensor) – Input tensor for prediction.
- Returns:
Mean predictions across ensemble members.
Uncertainty estimates (standard deviation of predictions).
- Return type:
Tuple[Tensor, Tensor]
- save(model_path)[source]
Saves the ensemble model and its metadata.
This method stores the model parameters, metadata, and each ensemble member’s state dictionary. The metadata includes information about the ensemble members, such as their architecture, loss function, and training status.
- Parameters:
model_path (str) – File path to save the model.
- Raises:
ValueError – If attempting to save the model before it has been trained.
Notes
The model directory is automatically created if it does not exist.
Each ensemble member is saved in a separate subdirectory.
After saving, any temporary checkpoint files are removed.
ise.models.lstm
Single LSTM network for time series sea-level projection.
This module provides the LSTM class — the constituent building block of
DeepEnsemble. Each instance is an independent stacked LSTM followed by a
two-layer fully-connected head:
LSTM layers (num_layers, hidden_size)
→ FC layer (hidden_size → 32) + ReLU
→ FC output layer (32 → output_size)
The architecture is deliberately simple: hidden-to-output mapping uses only
the final hidden state hn[-1], making this a many-to-one sequence model
that takes a window of sequence_length feature vectors and predicts a
single SLE value for the last timestep.
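The many-to-one setup implies that each timestep t of a projection gets its own window of the last sequence_length feature rows, with the target being SLE at t. The sketch below builds such windows for one projection, in the (batch, sequence_length, input_size) layout that forward() documents. This page does not document how the first timesteps are padded. Repeating the first row is an assumption, and `make_windows` is a hypothetical helper rather than what EmulatorDataset does internally.

```python
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view


def make_windows(features, sequence_length=5):
    """One (sequence_length, n_features) window per timestep of a single projection.

    Window t holds rows t - sequence_length + 1 .. t; rows before the start of
    the projection are filled by repeating row 0 (an assumption).
    """
    features = np.asarray(features, dtype=float)  # (T, n_features)
    pad = np.repeat(features[:1], sequence_length - 1, axis=0)
    padded = np.concatenate([pad, features], axis=0)  # (T + L - 1, n_features)
    windows = sliding_window_view(padded, window_shape=sequence_length, axis=0)
    # sliding_window_view appends the window axis last: (T, F, L) -> (T, L, F)
    return windows.transpose(0, 2, 1)


f = np.arange(86 * 2, dtype=float).reshape(86, 2)  # one 86-year projection, 2 features
w = make_windows(f, sequence_length=5)
```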
Usage — stand-alone
from ise.models.lstm import LSTM
import torch.nn as nn
model = LSTM(
lstm_num_layers=2,
lstm_hidden_size=256,
input_size=84, # features after NF latent concat
output_size=1,
criterion=nn.HuberLoss(),
)
model.fit(X_train, y_train, epochs=200, sequence_length=5,
X_val=X_val, y_val=y_val, early_stopping=True, patience=15)
preds = model.predict(X_test, sequence_length=5) # Tensor shape (N, 1)
model.save("lstm.pth")
model_loaded = LSTM.load("lstm.pth")
Usage — inside DeepEnsemble
LSTM instances are assembled into a DeepEnsemble by passing a list of
them to the constructor. The ensemble calls member.predict(x) on each
member and aggregates the results. In this context the LSTM is never
called with fit() directly — DeepEnsemble.fit() handles that.
Checkpointing and early stopping
fit() uses CheckpointSaver / EarlyStoppingCheckpointer from
ise.models.training. If a checkpoint already exists at the given path,
training resumes from the saved epoch. After training, the best checkpoint
is reloaded before fit() returns.
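The patience rule behind early stopping can be sketched in plain Python. `PatienceTracker` is a hypothetical stand-in, not the interface of CheckpointSaver / EarlyStoppingCheckpointer, and `min_delta` is an illustrative extra. Where this sketch records the best epoch, a real checkpointer would write the weights. That saved checkpoint is the one reloaded before fit() returns.

```python
import math


class PatienceTracker:
    """Hypothetical early-stopping helper tracking the best validation loss."""

    def __init__(self, patience=10, min_delta=0.0):
        self.patience = patience
        self.min_delta = min_delta
        self.best_loss = math.inf
        self.best_epoch = None
        self.bad_epochs = 0

    def step(self, epoch, val_loss):
        """Record one epoch's validation loss; return True when training should stop."""
        if val_loss < self.best_loss - self.min_delta:
            self.best_loss = val_loss
            self.best_epoch = epoch  # a real checkpointer would save weights here
            self.bad_epochs = 0
        else:
            self.bad_epochs += 1
        return self.bad_epochs >= self.patience


tracker = PatienceTracker(patience=2)
stops = [tracker.step(e, loss) for e, loss in enumerate([1.0, 0.8, 0.9, 0.85])]
```

Here the loss stops improving after epoch 1, so the tracker signals a stop at epoch 3 and epoch 1's weights would be the ones restored.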
save() writes both a .pth state dict and a _metadata.json file
storing the full architecture and optimizer hyperparameters. load()
reads both to reconstruct the model without any prior constructor call.
- class ise.models.lstm.LSTM(lstm_num_layers, lstm_hidden_size, input_size=83, output_size=1, criterion=MSELoss(), output_sequence_length=86, optimizer=<class 'torch.optim.adamw.AdamW'>, lr=0.0001, wd=1e-06, dropout=0.0)[source]
Bases: torch.nn.Module
Long Short-Term Memory (LSTM) model for time series forecasting.
This class implements an LSTM network with multiple layers, dropout, and fully connected layers to generate predictions for sequential data.
- lstm_num_layers
Number of LSTM layers in the model.
- Type:
int
- lstm_hidden_size
Number of hidden units in each LSTM layer.
- Type:
int
- input_size
Number of input features.
- Type:
int
- output_size
Number of output features.
- Type:
int
- output_sequence_length
Number of time steps predicted by the model.
- Type:
int
- device
Device on which the model runs (‘cuda’ or ‘cpu’).
- Type:
str
- lstm
LSTM layer for sequence modeling.
- Type:
nn.LSTM
- relu
ReLU activation function.
- Type:
nn.ReLU
- linear1
Intermediate fully connected layer.
- Type:
nn.Linear
- linear_out
Output layer mapping to final predictions.
- Type:
nn.Linear
- optimizer
Optimization algorithm used for training.
- Type:
torch.optim.Optimizer
- dropout
Dropout layer to prevent overfitting.
- Type:
nn.Dropout
- criterion
Loss function used for training.
- Type:
torch.nn.modules.loss._Loss
- trained
Flag indicating whether the model has been trained.
- Type:
bool
- Parameters:
lstm_num_layers (int) – Number of LSTM layers.
lstm_hidden_size (int) – Number of hidden units in each LSTM layer.
input_size (int, optional) – Number of input features. Defaults to 83.
output_size (int, optional) – Number of output features. Defaults to 1.
criterion (torch.nn.modules.loss._Loss, optional) – Loss function. Defaults to MSELoss.
output_sequence_length (int, optional) – Number of output time steps. Defaults to 86.
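optimizer (torch.optim.Optimizer, optional) – Optimizer class (not an instance). Defaults to AdamW.
lr (float, optional) – Learning rate passed to the optimizer. Defaults to 0.0001.
wd (float, optional) – Weight decay passed to the optimizer. Defaults to 1e-06.
dropout (float, optional) – Dropout probability for the dropout layer. Defaults to 0.0.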
- fit(X, y, epochs=100, sequence_length=5, batch_size=64, criterion=None, X_val=None, y_val=None, save_checkpoints=True, checkpoint_path='checkpoint.pt', early_stopping=False, patience=10, verbose=True, dataclass=<class 'ise.data.dataclasses.EmulatorDataset'>, wandb_run=None)[source]
Trains the LSTM model on the provided data.
Supports optional checkpointing and early stopping. If a checkpoint exists, training resumes from the last saved state.
- Parameters:
X (Tensor or DataFrame) – Input training data.
y (Tensor or DataFrame) – Target values corresponding to the input data.
epochs (int, optional) – Number of epochs for training. Defaults to 100.
sequence_length (int, optional) – Length of input sequences. Defaults to 5.
batch_size (int, optional) – Batch size used in training. Defaults to 64.
criterion (torch.nn.modules.loss._Loss, optional) – Loss function. Defaults to None.
X_val (Tensor or DataFrame, optional) – Validation input data. Defaults to None.
y_val (Tensor or DataFrame, optional) – Validation target data. Defaults to None.
save_checkpoints (bool, optional) – Whether to save model checkpoints. Defaults to True.
checkpoint_path (str, optional) – Path to save model checkpoints. Defaults to ‘checkpoint.pt’.
early_stopping (bool, optional) – Whether to enable early stopping. Defaults to False.
patience (int, optional) – Number of epochs to wait before stopping. Defaults to 10.
verbose (bool, optional) – Whether to print training progress. Defaults to True.
dataclass (type, optional) – Dataset class for handling data. Defaults to EmulatorDataset.
wandb_run (wandb.run, optional) – Weights & Biases run for per-epoch metric logging. Defaults to None.
- Raises:
ValueError – If no loss function is provided.
Notes
If validation data is provided but early stopping is disabled, a warning is issued.
If a checkpoint exists, training resumes from the saved epoch.
If early stopping is enabled, the model stops training when validation loss stops improving.
- forward(x)[source]
Performs a forward pass through the LSTM network.
Given an input sequence, the LSTM processes the sequence to extract features, which are passed through a fully connected network to generate predictions.
- Parameters:
x (Tensor) – Input tensor of shape (batch_size, sequence_length, input_size).
- Returns:
Output tensor of shape (batch_size, output_size), representing the model’s predictions.
- Return type:
Tensor
- classmethod load(model_path: str) → LSTM[source]
Loads a trained LSTM model from disk.
- Expects:
<model_path> (a .pth with state_dict)
<model_path>_metadata.json (hyperparams & config)
- Returns:
A model instance reconstructed with saved hyperparams, loss, and optimizer type (with saved lr/weight_decay).
- Return type:
LSTM
- Raises:
FileNotFoundError – If weights or metadata files are missing.
ValueError – If the saved model_type does not match this class.
- predict(X, sequence_length=None, batch_size=64, dataclass=<class 'ise.data.dataclasses.EmulatorDataset'>)[source]
Generates predictions using the trained LSTM model.
The model processes input sequences and returns predictions. Predictions are computed in a batch-wise manner to optimize memory usage.
- Parameters:
X (Tensor or DataFrame) – Input data for prediction.
sequence_length (int, optional) – Length of input sequences. Defaults to 5.
batch_size (int, optional) – Batch size used for inference. Defaults to 64.
dataclass (type, optional) – Dataset class for handling data. Defaults to EmulatorDataset.
- Returns:
Predicted values for the input data.
- Return type:
Tensor
Notes
The model is set to evaluation mode before making predictions.
Data is converted to tensors if initially provided as pandas DataFrames.
- save(model_path: str)[source]
Saves the LSTM model weights and metadata.
Writes <model_path> (state_dict) and <model_path>_metadata.json (config).
Records architecture, optimizer type & hparams (lr/weight_decay), and loss name.
Removes the training checkpoint file if this instance has one.
- Parameters:
model_path (str) – Destination file path ending in ‘.pth’.
- Raises:
ValueError – If the model has not been trained yet.
ise.models.normalizing_flow
Conditional autoregressive normalizing flow for aleatoric uncertainty.
This module provides NormalizingFlow — a conditional density estimator
built with the nflows library. It models the conditional distribution
P(y | X) of sea level equivalent (SLE) output given climate forcing features
X, enabling both probabilistic sampling and aleatoric uncertainty estimation.
Architecture
- Base distribution: ConditionalDiagonalNormal with a 2-layer MLP context encoder (input_size → flow_hidden_features → output_size * 2), mapping forcing features X to the mean and log standard deviation of the latent z.
- Transforms: num_flow_transforms alternating pairs of RandomPermutation and MaskedAffineAutoregressiveTransform steps, each conditioned on the full feature vector X.
- Flow: nflows.flows.base.Flow wrapping the composite transform and the base distribution.
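The base-distribution step can be illustrated with a small NumPy sketch. It assumes a hypothetical 2-layer `context_encoder` with random placeholder weights. Its output of length 2 * output_size is split into means and log standard deviations, the split nflows' ConditionalDiagonalNormal performs. The sketch then evaluates the diagonal-Gaussian log-density of z. The autoregressive transforms are not included.

```python
import numpy as np


def context_encoder(X, W1, b1, W2, b2):
    """Hypothetical 2-layer MLP: input_size -> hidden -> 2 * output_size."""
    h = np.tanh(X @ W1 + b1)
    return h @ W2 + b2


def diag_normal_log_prob(z, params):
    """Log-density of z under N(mean, diag(exp(log_std))^2) given encoder output."""
    split = params.shape[-1] // 2
    mean, log_std = params[..., :split], params[..., split:]
    norm = (z - mean) / np.exp(log_std)
    # Sum over the event dimension of -0.5*norm^2 - log_std - 0.5*log(2*pi)
    return np.sum(-0.5 * norm**2 - log_std - 0.5 * np.log(2 * np.pi), axis=-1)


rng = np.random.default_rng(0)
input_size, hidden, output_size = 4, 16, 1
W1, b1 = rng.normal(size=(input_size, hidden)), np.zeros(hidden)
W2, b2 = rng.normal(size=(hidden, 2 * output_size)), np.zeros(2 * output_size)

X = rng.normal(size=(3, input_size))
params = context_encoder(X, W1, b1, W2, b2)  # shape (3, 2)
z = rng.normal(size=(3, output_size))
log_p = diag_normal_log_prob(z, params)  # shape (3,)
```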
Role in ISEFlow
The NormalizingFlow is trained first (before the DeepEnsemble)
via maximum likelihood (negative log-probability). Once trained it serves
two roles:
- Latent features: get_latent(X) samples z ~ q(z | X) from the base distribution, providing a single low-dimensional context variable that the DeepEnsemble members receive as extra input alongside X.
- Aleatoric uncertainty: aleatoric(X, num_samples) draws num_samples SLE values from the full flow for each input row and returns the std across draws, interpreted as the inherent data uncertainty (aleatoric component).
Usage
from ise.models.normalizing_flow import NormalizingFlow
nf = NormalizingFlow(input_size=83, output_size=1, num_flow_transforms=5,
flow_hidden_features=16)
nf.fit(X_train, y_train, epochs=500, batch_size=64,
X_val=X_val, y_val=y_val, early_stopping=True, patience=20)
z = nf.get_latent(X_val) # shape (N, 1) — latent context for DeepEnsemble
uncertainty = nf.aleatoric(X_val, num_samples=100) # shape (N,)
samples = nf.sample(X_val[:5], num_samples=200) # shape (5, 200)
nf.save("nf_checkpoint.pth")
nf_loaded = NormalizingFlow.load("nf_checkpoint.pth")
Checkpointing and early stopping
fit() integrates with CheckpointSaver / EarlyStoppingCheckpointer
from ise.models.training. If a checkpoint file already exists at
checkpoint_path, training resumes from the saved epoch. After training
completes, the best checkpoint is loaded back and the temporary file is deleted.
Optional wandb logging is supported via the wandb_run argument.
- class ise.models.normalizing_flow.NormalizingFlow(input_size=43, output_size=1, output_sequence_length=86, num_flow_transforms=5, flow_hidden_features=16, legacy_v1_0_0=False)[source]
Bases:
Module
A normalizing flow model for probabilistic modeling using invertible transformations.
This model utilizes a sequence of invertible transformations to model complex probability distributions. It is built with a base distribution and a series of transformations, leveraging autoregressive neural networks.
- num_flow_transforms
Number of flow transformations in the model.
- Type:
int
- num_input_features
Number of input features.
- Type:
int
- num_predicted_sle
Number of predicted sea-level equivalent values.
- Type:
int
Number of hidden features in the flow model.
- Type:
int
- output_sequence_length
Length of the output sequence.
- Type:
int
- device
Device on which the model is run (“cuda” or “cpu”).
- Type:
str
- base_distribution
The base normal distribution conditioned on input features.
- Type:
distributions.normal.ConditionalDiagonalNormal
- t
Composite transformation for the normalizing flow.
- Type:
transforms.base.CompositeTransform
- flow
The normalizing flow model.
- Type:
flows.base.Flow
- optimizer
Optimizer for training the model.
- Type:
torch.optim.Adam
- criterion
Log probability function used as the loss criterion.
- Type:
callable
- trained
Flag indicating if the model has been trained.
- Type:
bool
- aleatoric(features, num_samples, batch_size=128)[source]
Estimate aleatoric uncertainty as the std across flow samples per input row.
For each input row, draws num_samples samples from the conditional flow and returns the standard deviation across those samples. NaN samples are ignored when computing the std.
- Parameters:
features (array-like or torch.Tensor) – Input features of shape (N, num_features).
num_samples (int) – Number of flow samples drawn per input row.
batch_size (int, optional) – Number of rows processed per forward pass. Defaults to 128.
- Returns:
Per-row aleatoric uncertainty, shape (N,).
- Return type:
numpy.ndarray
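The per-row standard deviation described above can be sketched in NumPy. `aleatoric_std` and `sample_fn` are hypothetical names, and `sample_fn` stands in for drawing from the trained conditional flow. np.nanstd skips NaN draws and computes the population standard deviation (ddof=0), which may differ from the package's choice.

```python
import numpy as np


def aleatoric_std(sample_fn, features, num_samples, batch_size=128):
    """Per-row std across num_samples draws, ignoring NaN draws.

    sample_fn(batch, num_samples) is a hypothetical sampler returning an array of
    shape (len(batch), num_samples); it stands in for the conditional flow.
    """
    features = np.asarray(features, dtype=float)
    stds = []
    for start in range(0, len(features), batch_size):
        batch = features[start:start + batch_size]
        draws = sample_fn(batch, num_samples)  # (batch, num_samples)
        stds.append(np.nanstd(draws, axis=1))  # NaN draws are skipped
    return np.concatenate(stds)


rng = np.random.default_rng(1)


def toy_sampler(batch, n):
    # Noise scale equals |first feature|, so the estimated std should track it.
    scale = np.abs(batch[:, :1])
    return rng.normal(size=(len(batch), n)) * scale


X = np.array([[0.5, 0.0], [2.0, 0.0], [4.0, 0.0]])
sigma = aleatoric_std(toy_sampler, X, num_samples=5000, batch_size=2)
```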
- fit(X, y, X_val=None, y_val=None, epochs=100, batch_size=64, save_checkpoints=True, checkpoint_path='checkpoint.pt', early_stopping=True, patience=10, verbose=True, wandb_run=None, lr=0.0001, wd=1e-06)[source]
Train the normalizing flow via maximum likelihood (negative log-probability).
If checkpoint_path already exists, training resumes from the saved epoch. After training, the best checkpoint is loaded back into the model and the temporary file is deleted.
- Parameters:
X (array-like) – Input features of shape (num_samples, num_features).
y (array-like) – Target values of shape (num_samples, output_size).
X_val (array-like, optional) – Validation features for early stopping.
y_val (array-like, optional) – Validation targets for early stopping.
epochs (int, optional) – Number of training epochs. Defaults to 100.
batch_size (int, optional) – Batch size for training. Defaults to 64.
save_checkpoints (bool, optional) – Whether to save model checkpoints. Defaults to True.
checkpoint_path (str, optional) – Path to save model checkpoints. Defaults to 'checkpoint.pt'.
early_stopping (bool, optional) – Whether to use early stopping. Defaults to True.
patience (int, optional) – Number of epochs to wait before early stopping. Defaults to 10.
verbose (bool, optional) – Whether to print training progress. Defaults to True.
wandb_run (wandb.run, optional) – Weights & Biases run for logging. Defaults to None.
lr (float, optional) – AdamW learning rate. Defaults to 1e-4.
wd (float, optional) – AdamW weight decay. Defaults to 1e-6.
- get_latent(x, latent_dim=1)[source]
Computes the latent space representation of the given input.
Two approaches are used depending on the model version:
- v1.0.0 (legacy): deterministically pushes a zero vector through the forward transform conditioned on x. The same x always yields the same z, so the flow acts as a learned deterministic feature extractor.
- v1.1.0+ (default): draws latent_dim samples from the conditional base distribution. z is a stochastic summary of x, which is more statistically grounded: the latent reflects the modeled conditional distribution rather than a single fixed point.
These approaches behave differently and produce different downstream DeepEnsemble inputs; the choice may be worth revisiting in future versions. The legacy path is more pragmatic and reproducible at inference, while the sampling path aligns more closely with the probabilistic interpretation of the flow.
- Parameters:
x (array-like or torch.Tensor) – Input data of shape (num_samples, num_features).
latent_dim (int, optional) – Number of latent samples to draw (post-v1.0.0). Defaults to 1.
- Returns:
Latent space representation of the input data.
- Return type:
torch.Tensor
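The difference between the two latent paths can be made concrete with a toy conditional affine transform. This NumPy sketch is illustrative only. `conditioner`, `latent_legacy`, and `latent_sampled` are hypothetical names, and a single affine step stands in for the full stack of autoregressive transforms and the learned base distribution.

```python
import numpy as np


def conditioner(x):
    """Hypothetical conditioner: per-row shift and log-scale computed from x."""
    shift = x.sum(axis=1, keepdims=True)
    log_scale = 0.1 * x[:, :1]
    return shift, log_scale


def latent_legacy(x):
    """v1.0.0-style: push y = 0 through the forward map z = (y - shift) * exp(-log_scale)."""
    shift, log_scale = conditioner(x)
    y = np.zeros_like(shift)
    return (y - shift) * np.exp(-log_scale)  # deterministic: same x, same z


def latent_sampled(x, rng, latent_dim=1):
    """v1.1.0-style: draw latent_dim samples from N(mu(x), sigma(x)^2)."""
    mu, log_sigma = conditioner(x)
    eps = rng.normal(size=(x.shape[0], latent_dim))
    return mu + np.exp(log_sigma) * eps  # stochastic: same x, different z


x = np.array([[1.0, 2.0], [0.5, -0.5]])
rng = np.random.default_rng(0)
z_a, z_b = latent_legacy(x), latent_legacy(x)
s_a, s_b = latent_sampled(x, rng), latent_sampled(x, rng)
```

The legacy path returns identical arrays on repeated calls, while the sampled path does not. This is the reproducibility trade-off described above.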
- static load(path)[source]
Loads a trained normalizing flow model from a saved checkpoint.
- Parameters:
path (str) – Path to the saved model checkpoint.
- Returns:
A restored instance of the NormalizingFlow model.
- Return type:
NormalizingFlow
- sample(features, num_samples, return_type='numpy')[source]
Generates samples from the trained normalizing flow model.
- Parameters:
features (array-like or torch.Tensor) – Input features to condition the samples on.
num_samples (int) – Number of samples to generate per input feature set.
return_type (str, optional) – Return type, either “numpy” or “tensor”. Defaults to “numpy”.
- Returns:
Generated samples of shape (num_samples, output_size).
- Return type:
np.ndarray or torch.Tensor
ise.models.loss
Custom loss functions for ice-sheet emulator training.
Standard MSE treats all timesteps and sectors equally, but ice sheet projections have two properties that motivate custom losses:
Extreme-value importance — large SLE departures (high melt scenarios) are scientifically most consequential and are underrepresented in the training set. Weighted losses upweight these samples.
Spatial smoothness — when predicting full 2-D grid fields (rather than sector-averaged scalars), neighbouring cells should vary smoothly. The grid-aware losses add a total-variation regularisation term.
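All classes are torch.nn.Module subclasses whose forward() returns the loss. Argument order varies: the grid losses (GridCriterion, WeightedGridLoss) are called as
forward(true, predicted, ...), with the ground truth first.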
Loss classes
- WeightedMSELoss:
MSE with per-sample weights that grow with the deviation of the target from the dataset mean, normalised by the dataset std. Each squared error is scaled by (1 + weight_factor * |z|), where z = (target - data_mean) / data_std, so extreme targets receive a larger penalty:
from ise.models.loss import WeightedMSELoss
criterion = WeightedMSELoss(data_mean=y_mean, data_std=y_std, weight_factor=2.0)
loss = criterion(predictions, targets)
- WeightedMSELossWithSignPenalty:
Extends WeightedMSELoss with an extra additive penalty when the predicted sign differs from the true sign. Useful for preventing the model from predicting sea-level fall when rise is expected.
- WeightedMSEPCALoss:
Adds user-supplied per-batch custom_weights on top of the deviation weights. Intended for PCA-based training where different principal components should have different loss contributions.
- WeightedGridLoss:
Pixel-wise weighted MSE + total variation regularisation (TVR) for full 2-D grid predictions. TVR penalises large spatial gradients in the predicted grid to encourage smooth output fields.
- GridCriterion:
Simpler grid loss: pixel-wise MSE + TVR with a fixed smoothness_weight, without the extreme-value weighting.
- WeightedPCALoss:
MSE with fixed per-component weights for PCA coefficient regression. The first (dominant) principal component can be penalised more heavily.
- MSEDeviationLoss:
MSE plus an extra penalty_multiplier * MSE term applied to samples whose absolute error exceeds a threshold.
- class ise.models.loss.GridCriterion[source]
Bases:
Module
Custom loss function enforcing spatial smoothness using total variation regularization.
This function encourages smoothness in spatial predictions by penalizing large variations.
- total_variation_regularization
Computes the smoothness loss.
- forward
Computes the final loss.
- forward(true, predicted, smoothness_weight=0.001)[source]
Computes the final loss by combining pixel-wise MSE and TVR.
- Parameters:
true (Tensor) – Ground truth values.
predicted (Tensor) – Model predictions.
smoothness_weight (float, optional) – Weight for TVR. Defaults to 0.001.
- Returns:
Computed loss.
- Return type:
Tensor
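A NumPy sketch of GridCriterion's combination of pixel-wise MSE and total variation follows. It assumes an anisotropic TV term computed on the prediction: the mean absolute difference between vertically adjacent cells plus the same for horizontally adjacent cells. Whether the package uses absolute or squared differences, or sums rather than means, is an assumption here. The function names are hypothetical.

```python
import numpy as np


def total_variation(grid):
    """Anisotropic TV: mean |difference| between neighbouring cells (rows + cols)."""
    grid = np.asarray(grid, dtype=float)
    dy = np.abs(np.diff(grid, axis=-2))  # vertical neighbours
    dx = np.abs(np.diff(grid, axis=-1))  # horizontal neighbours
    return dy.mean() + dx.mean()


def grid_criterion(true, predicted, smoothness_weight=0.001):
    """Pixel-wise MSE plus smoothness_weight * TV(predicted)."""
    true = np.asarray(true, dtype=float)
    predicted = np.asarray(predicted, dtype=float)
    mse = np.mean((true - predicted) ** 2)
    return mse + smoothness_weight * total_variation(predicted)


flat = np.ones((4, 4))                           # perfectly smooth field: TV = 0
checker = np.indices((4, 4)).sum(axis=0) % 2     # alternating 0/1 pattern: maximal TV
```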
- class ise.models.loss.MSEDeviationLoss(threshold=1.0, penalty_multiplier=2.0)[source]
Bases:
Module
Custom MSE loss with an additional penalty for large deviations.
This function penalizes predictions that deviate significantly from the target.
- threshold
Deviation threshold for applying penalties.
- Type:
float
- penalty_multiplier
Multiplier controlling penalty severity.
- Type:
float
- forward
Computes the loss with deviation penalties.
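One reading of MSEDeviationLoss is sketched below in NumPy. It assumes the extra term is penalty_multiplier times the squared error masked to samples with |error| > threshold, averaged over all samples. The class may normalise the penalty differently. `mse_deviation_loss` is a hypothetical name.

```python
import numpy as np


def mse_deviation_loss(predictions, targets, threshold=1.0, penalty_multiplier=2.0):
    """MSE plus penalty_multiplier * (squared error restricted to |error| > threshold)."""
    err = np.asarray(predictions, dtype=float) - np.asarray(targets, dtype=float)
    sq = err ** 2
    large = np.abs(err) > threshold  # boolean mask of large deviations
    penalty = penalty_multiplier * np.mean(sq * large)
    return np.mean(sq) + penalty
```

With both errors below the threshold the penalty vanishes and the result is plain MSE.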
- class ise.models.loss.WeightedGridLoss[source]
Bases:
Module
Custom loss function that penalizes errors based on the total variation of a grid.
This loss function consists of two components:
1. Pixel-wise weighted mean squared error (MSE): higher weight is assigned to extreme values.
2. Total variation regularization (TVR): enforces spatial smoothness by penalizing large differences between adjacent grid values.
- device
The device on which the model runs (‘cuda’ or ‘cpu’).
- Type:
str
- forward(true, predicted, smoothness_weight=0.001, extreme_value_threshold=1e-06)[source]
Computes the final weighted loss combining pixel-wise MSE and TVR.
- Parameters:
true (Tensor) – Ground truth values.
predicted (Tensor) – Model predictions.
smoothness_weight (float, optional) – Weighting factor for the TVR loss. Defaults to 0.001.
extreme_value_threshold (float, optional) – Threshold to define extreme values. Defaults to 1e-6.
- Returns:
The total computed loss.
- Return type:
Tensor
- total_variation_regularization(grid)[source]
Computes the total variation regularization (TVR) loss for spatial smoothness.
- Parameters:
grid (Tensor) – A 2D tensor representing spatial data.
- Returns:
The total variation loss.
- Return type:
Tensor
- weighted_pixelwise_mse(true, predicted, weights)[source]
Computes the pixel-wise mean squared error (MSE) with custom weights.
- Parameters:
true (Tensor) – Ground truth values.
predicted (Tensor) – Model predictions.
weights (Tensor) – Weighting factor for each pixel.
- Returns:
Weighted mean squared error.
- Return type:
Tensor
- class ise.models.loss.WeightedMSELoss(data_mean, data_std, weight_factor=1.0)[source]
Bases:
Module
Custom loss function that applies a weighted penalty to extreme values.
This function increases the weight of extreme values based on their deviation from the dataset mean, normalizing by the standard deviation.
- data_mean
Mean value of the dataset.
- Type:
Tensor
- data_std
Standard deviation of the dataset.
- Type:
Tensor
- weight_factor
Factor controlling how much extreme values are penalized.
- Type:
Tensor
- forward
Computes the weighted mean squared error loss.
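A NumPy sketch of the deviation-weighted MSE follows. It assumes each squared error is scaled by exactly 1 + weight_factor * |z|, with z = (target - data_mean) / data_std, and then averaged. `weighted_mse` is a hypothetical name, and the torch class may differ in reduction or broadcasting details.

```python
import numpy as np


def weighted_mse(predictions, targets, data_mean, data_std, weight_factor=1.0):
    """Deviation-weighted MSE: extreme targets (large |z|) cost more."""
    predictions = np.asarray(predictions, dtype=float)
    targets = np.asarray(targets, dtype=float)
    z = (targets - data_mean) / data_std  # standardized deviation of each target
    weights = 1.0 + weight_factor * np.abs(z)
    return np.mean(weights * (predictions - targets) ** 2)


y = np.array([0.0, 1.0, 4.0])
pred = y + 1.0  # same absolute error everywhere; only the weights differ
loss = weighted_mse(pred, y, data_mean=0.0, data_std=2.0, weight_factor=2.0)
```

With weight_factor=0.0 this reduces to ordinary MSE, which is a quick sanity check when tuning the factor.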
- class ise.models.loss.WeightedMSELossWithSignPenalty(data_mean, data_std, weight_factor=1.0, sign_penalty_factor=1.0)[source]
Bases:
Module
Custom loss function that penalizes errors on extreme values and opposite-sign predictions.
This function extends WeightedMSELoss by adding a penalty when the sign of the prediction differs from the target.
- data_mean
Mean of the dataset.
- Type:
Tensor
- data_std
Standard deviation of the dataset.
- Type:
Tensor
- weight_factor
Factor controlling extreme value weighting.
- Type:
Tensor
- sign_penalty_factor
Factor controlling penalty for opposite sign predictions.
- Type:
Tensor
- forward
Computes the weighted loss with sign penalties.
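A sketch of the sign-penalty idea, built on the same deviation weights, follows. The exact form of the additive penalty is not documented here. This version assumes a hypothetical form: sign_penalty_factor times the fraction of samples whose predicted sign differs from the target's. An indicator like this contributes no gradient, so it only illustrates the idea.

```python
import numpy as np


def weighted_mse_with_sign_penalty(predictions, targets, data_mean, data_std,
                                   weight_factor=1.0, sign_penalty_factor=1.0):
    """Deviation-weighted MSE plus a (hypothetical) penalty for the wrong sign."""
    predictions = np.asarray(predictions, dtype=float)
    targets = np.asarray(targets, dtype=float)
    weights = 1.0 + weight_factor * np.abs((targets - data_mean) / data_std)
    wmse = np.mean(weights * (predictions - targets) ** 2)
    # Assumed penalty: fraction of samples where sign(pred) != sign(target).
    mismatch = np.sign(predictions) != np.sign(targets)
    return wmse + sign_penalty_factor * np.mean(mismatch)
```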
- class ise.models.loss.WeightedMSEPCALoss(data_mean, data_std, weight_factor=1.0, custom_weights=None)[source]
Bases:
Module
Extension of WeightedMSELoss that allows for custom per-batch weighting.
This loss function enables additional user-defined weights to further adjust penalties for different predictions.
- data_mean
Mean of the dataset.
- Type:
Tensor
- data_std
Standard deviation of the dataset.
- Type:
Tensor
- weight_factor
Controls the penalty for extreme values.
- Type:
Tensor
- custom_weights
User-defined weight tensor.
- Type:
Tensor, optional
- forward
Computes the batch-weighted MSE loss.
- class ise.models.loss.WeightedPCALoss(component_weights, reduction='mean')[source]
Bases:
Module
Custom loss function applying a different weight to the error in each principal component.
This function allows assigning higher penalties to the first components.
- component_weights
Weighting factors for each principal component.
- Type:
Tensor
- reduction
Specifies reduction mode (‘mean’, ‘sum’, or ‘none’).
- Type:
str
- forward
Computes the weighted PCA loss.
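A NumPy sketch of per-component weighting with the three documented reduction modes follows. `weighted_pca_loss` is a hypothetical name. Multiplying the squared errors by component_weights broadcast over the batch is an assumption about how the weights are applied.

```python
import numpy as np


def weighted_pca_loss(predictions, targets, component_weights, reduction="mean"):
    """Squared error per PCA coefficient, scaled by a fixed weight per component.

    predictions, targets: (batch, n_components); component_weights: (n_components,).
    """
    sq = (np.asarray(predictions, dtype=float) - np.asarray(targets, dtype=float)) ** 2
    weighted = sq * np.asarray(component_weights, dtype=float)  # broadcast over batch
    if reduction == "mean":
        return weighted.mean()
    if reduction == "sum":
        return weighted.sum()
    if reduction == "none":
        return weighted
    raise ValueError(f"unknown reduction: {reduction!r}")


w = np.array([4.0, 1.0, 1.0])  # penalise errors on the first component most
pred, target = np.zeros((2, 3)), np.ones((2, 3))
```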
ise.models.training
Checkpointing and early-stopping callbacks for PyTorch model training.
Both LSTM.fit() and NormalizingFlow.fit() accept a checkpoint_path
and use the classes in this module to save the best model state and optionally
stop training when the validation loss stops improving.
Classes
- CheckpointSaver:
Saves a full checkpoint dict {epoch, model_state_dict, optimizer_state_dict, best_loss} whenever the monitored loss improves. Passing the same checkpoint_path to fit() on a later run resumes training from where it left off:
from ise.models.training import CheckpointSaver
saver = CheckpointSaver(model, optimizer, "checkpoint.pt", verbose=True)
for epoch in range(1, epochs + 1):
    loss = train_one_epoch(...)
    saver(loss, epoch)  # saves only if loss < best_loss
- EarlyStoppingCheckpointer (extends CheckpointSaver):
Adds a patience counter on top of CheckpointSaver. Sets self.early_stop = True when the loss has not improved for patience consecutive calls. The training loop should check this flag and break:
from ise.models.training import EarlyStoppingCheckpointer
stopper = EarlyStoppingCheckpointer(model, optimizer, "ckpt.pt", patience=10, verbose=True)
for epoch in range(1, max_epochs + 1):
    val_loss = evaluate(...)
    stopper(val_loss, epoch)
    if stopper.early_stop:
        print("Early stopping")
        break
# After the loop, load the best checkpoint:
stopper.load_checkpoint()
- class ise.models.training.CheckpointSaver(model: Module, optimizer: Optimizer, checkpoint_path: str, verbose: bool = False)[source]
Bases:
object
A class to handle saving and loading of model checkpoints during training.
This class monitors the model’s loss and saves the model’s state when an improvement is detected. It can also be configured to save the model at every epoch.
- checkpoint_path
Path where the checkpoint will be saved.
- Type:
str
- model
The PyTorch model being trained.
- Type:
torch.nn.Module
- optimizer
The optimizer used during training.
- Type:
torch.optim.Optimizer
- best_loss
The best recorded loss value. Initially set to infinity.
- Type:
float
- verbose
If True, logs messages when a checkpoint is saved.
- Type:
bool
- log
Stores log messages for saving actions.
- Type:
str or None
- load_checkpoint(path: str | None = None)[source]
Loads a checkpoint and restores the model and optimizer states.
- Parameters:
path (str, optional) – The file path to load the checkpoint from. If None, the default path is used.
- Returns:
The epoch number from which training should resume.
- Return type:
int
- save_checkpoint(epoch, loss, path: str | None = None)[source]
Saves the model checkpoint, including model state, optimizer state, and epoch.
- Parameters:
epoch (int) – The current epoch number.
loss (float) – The loss value associated with this checkpoint.
path (str, optional) – The file path to save the checkpoint. If None, the default path is used.
- class ise.models.training.EarlyStoppingCheckpointer(model, optimizer, checkpoint_path='checkpoint.pt', patience=10, verbose=False)[source]
Bases:
CheckpointSaver
A class that extends CheckpointSaver to implement early stopping.
This class tracks model performance and stops training when the validation loss does not improve for a specified number of epochs (patience).
- patience
The number of epochs with no improvement before stopping.
- Type:
int
- counter
Tracks the number of epochs since the last improvement.
- Type:
int
- early_stop
Flag indicating whether early stopping should occur.
- Type:
bool
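The patience logic can be sketched in plain Python. `PatienceTracker` is a hypothetical stand-in that omits the checkpoint I/O. It assumes that an improvement means a strictly lower loss and that early_stop is set once counter reaches patience.

```python
import math


class PatienceTracker:
    """Hypothetical stand-in for the early-stopping bookkeeping described above."""

    def __init__(self, patience=10):
        self.patience = patience
        self.best_loss = math.inf
        self.counter = 0
        self.early_stop = False

    def __call__(self, loss):
        if loss < self.best_loss:
            self.best_loss = loss
            self.counter = 0
            return True  # improvement: the real class would save a checkpoint here
        self.counter += 1
        if self.counter >= self.patience:
            self.early_stop = True
        return False


tracker = PatienceTracker(patience=2)
history = []
for epoch, val_loss in enumerate([1.0, 0.8, 0.9, 0.85, 0.7], start=1):
    tracker(val_loss)
    history.append(epoch)
    if tracker.early_stop:
        break  # epochs 3 and 4 fail to beat 0.8, so training stops after epoch 4
```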
ise.models.pretrained
Pretrained ISEFlow weight management.
Weights are hosted on HuggingFace Hub at pvankatwyk/ISEFlow and are
downloaded automatically on first use via huggingface_hub. The downloaded
files are cached in the default HuggingFace cache directory
(~/.cache/huggingface/hub, or $HF_HOME/hub when HF_HOME is set).
During local development, if the HuggingFace download fails (e.g. no internet
access) and local weights exist under ise/models/pretrained/ISEFlow/, the
loader falls back to those local paths transparently.
- ise.models.pretrained.ISEFLOW_LATEST_MODEL_VERSION = 'v1.1.0'
The most recent ISEFlow pretrained model version. Distinct from the ise-py package version.
- ise.models.pretrained.get_model_dir(version: str, ice_sheet: str) str[source]
Return the local directory containing weights for a given model version.
Downloads the weights from HuggingFace Hub if not already cached. Falls back to the bundled local path when HF is unavailable and local weights exist (development / air-gapped environments).
- Parameters:
version (str) – Model version string, e.g. 'v1.0.0' or 'v1.1.0'.
ice_sheet (str) – Ice sheet identifier: 'AIS' or 'GrIS'.
- Returns:
Absolute path to the directory containing deep_ensemble.pth, normalizing_flow.pth, scaler_X.pkl, and scaler_y.pkl.
- Return type:
str
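The download-then-fallback behaviour can be sketched with the standard library. `resolve_model_dir` and its `download` argument are hypothetical. `download` stands in for a Hugging Face Hub call and, for this sketch, is assumed to raise OSError when offline. The `<local_root>/<version>/<ice_sheet>` layout is an assumed directory structure, not necessarily the one the package uses.

```python
import os
import tempfile


def resolve_model_dir(version, ice_sheet, download, local_root):
    """Return a weights directory: try `download` first, then a local copy.

    download(version, ice_sheet) is a hypothetical callable that returns a
    directory path or raises OSError (e.g. when there is no network access).
    """
    if ice_sheet not in ("AIS", "GrIS"):
        raise ValueError(f"ice_sheet must be 'AIS' or 'GrIS', got {ice_sheet!r}")
    try:
        return os.path.abspath(download(version, ice_sheet))
    except OSError:
        # Assumed local layout: <local_root>/<version>/<ice_sheet>/
        local_dir = os.path.join(local_root, version, ice_sheet)
        if os.path.isdir(local_dir):
            return os.path.abspath(local_dir)
        raise  # no local fallback available: re-raise the download error


def offline(version, ice_sheet):
    raise ConnectionError("no network")  # ConnectionError is a subclass of OSError


with tempfile.TemporaryDirectory() as root:
    os.makedirs(os.path.join(root, "v1.1.0", "AIS"))
    found = resolve_model_dir("v1.1.0", "AIS", offline, root)
    expected = os.path.abspath(os.path.join(root, "v1.1.0", "AIS"))
```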
Module contents
Ice sheet emulator models: ISEFlow, predictors, density estimators, and utilities.
This package provides ISEFlow (hybrid deep ensemble + normalizing flow), LSTM and DeepEnsemble predictors, NormalizingFlow density estimators, loss modules, and pretrained model loading for AIS and GrIS.
- class ise.models.DeepEnsemble(ensemble_members=None, input_size=83, output_size=1, num_ensemble_members=3, output_sequence_length=86, latent_dim=1)[source]
Bases:
Module
Deep ensemble model using multiple LSTMs for time series forecasting.
This class implements an ensemble of LSTM-based predictors. Each LSTM model is trained separately, and predictions from all ensemble members are aggregated to provide a mean prediction along with an epistemic uncertainty estimate.
- input_size
Size of the input features.
- Type:
int
- output_size
Size of the output features.
- Type:
int
- output_sequence_length
Length of the predicted output sequence.
- Type:
int
- loss_choices
List of loss functions used for different ensemble members.
- Type:
list
- ensemble_members
List of LSTM models used as ensemble members.
- Type:
list
- trained
Indicates whether all ensemble members have been trained.
- Type:
bool
- Parameters:
ensemble_members (list, optional) – Pretrained LSTM models. If None, a new ensemble is created.
input_size (int) – Number of input features.
output_size (int) – Number of output features.
num_ensemble_members (int) – Number of ensemble members to create if ensemble_members is None.
output_sequence_length (int) – Length of the output sequence to predict.
latent_dim (int) – Additional latent dimension added to the input.
- Raises:
ValueError – If ensemble_members is provided but does not contain only LSTM instances.
- fit(X, y, X_val=None, y_val=None, save_checkpoints=True, checkpoint_path='checkpoint_ensemble', early_stopping=True, epochs=100, batch_size=128, sequence_length=5, patience=10, verbose=True)[source]
Trains each ensemble member on the provided data.
The ensemble members are trained separately, allowing for independent learning dynamics. Checkpoints can be saved for each model, and early stopping is available to prevent overfitting.
- Parameters:
X (Tensor) – Training input data.
y (Tensor) – Training target data.
X_val (Tensor, optional) – Validation input data for early stopping.
y_val (Tensor, optional) – Validation target data for early stopping.
save_checkpoints (bool, optional) – Whether to save checkpoints during training. Defaults to True.
checkpoint_path (str, optional) – Path prefix for saving model checkpoints.
early_stopping (bool, optional) – Whether to use early stopping. Defaults to True.
epochs (int, optional) – Number of training epochs. Defaults to 100.
batch_size (int, optional) – Batch size for training. Defaults to 128.
sequence_length (int, optional) – Length of input sequences. Defaults to 5.
patience (int, optional) – Number of epochs to wait before early stopping. Defaults to 10.
verbose (bool, optional) – Whether to print training progress. Defaults to True.
- Raises:
Warning – If the model has already been trained, a warning is issued before proceeding.
- forward(x)[source]
Performs a forward pass through the ensemble, aggregating predictions.
Each ensemble member makes a prediction, and the mean and standard deviation of these predictions are computed to provide an estimate of epistemic uncertainty.
- Parameters:
x (Tensor) – Input tensor of shape (batch_size, sequence_length, input_size).
- Returns:
Mean prediction across all ensemble members.
Epistemic uncertainty (standard deviation of predictions).
- Return type:
Tuple[Tensor, Tensor]
Warning
If the model is not trained, a warning is issued indicating that predictions may be unreliable.
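The aggregation step can be sketched in NumPy. `ensemble_mean_std` is a hypothetical name. It stacks member outputs and takes the mean and the population standard deviation (ddof=0). torch.Tensor.std defaults to the unbiased estimator, so the package's epistemic values may differ slightly depending on which it uses.

```python
import numpy as np


def ensemble_mean_std(member_predictions):
    """Aggregate member outputs stacked as (num_members, N, ...) into mean and std.

    The std across members is the epistemic-uncertainty estimate.
    """
    stacked = np.stack(member_predictions, axis=0)
    return stacked.mean(axis=0), stacked.std(axis=0)


# Three members agree on the second row but disagree on the first.
members = [np.array([[1.0], [2.0]]), np.array([[3.0], [2.0]]), np.array([[2.0], [2.0]])]
mean, epistemic = ensemble_mean_std(members)
```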
- classmethod load(model_path)[source]
Loads a trained ensemble model from a file.
This method restores the ensemble’s state, including the metadata and individual LSTM members. The ensemble members are reinitialized and their state dictionaries are loaded from disk.
- Parameters:
model_path (str) – Path to the saved model file.
- Returns:
An instance of the loaded ensemble model.
- Return type:
DeepEnsemble
- Raises:
FileNotFoundError – If any ensemble member’s file is missing.
ValueError – If the saved model type does not match DeepEnsemble.
Notes
The method ensures compatibility between the saved metadata and the loaded model.
Loss functions are restored using a predefined lookup.
The model is set to evaluation mode after loading.
- predict(x)[source]
Makes predictions using the trained ensemble.
This method calls forward while ensuring the model is in evaluation mode.
- Parameters:
x (Tensor) – Input tensor for prediction.
- Returns:
Mean predictions across ensemble members.
Uncertainty estimates (standard deviation of predictions).
- Return type:
Tuple[Tensor, Tensor]
- save(model_path)[source]
Saves the ensemble model and its metadata.
This method stores the model parameters, metadata, and each ensemble member’s state dictionary. The metadata includes information about the ensemble members, such as their architecture, loss function, and training status.
- Parameters:
model_path (str) – File path to save the model.
- Raises:
ValueError – If attempting to save the model before it has been trained.
Notes
The model directory is automatically created if it does not exist.
Each ensemble member is saved in a separate subdirectory.
After saving, any temporary checkpoint files are removed.
- class ise.models.ISEFlow(deep_ensemble, normalizing_flow)[source]
Bases:
Module
ISEFlow is a hybrid ice sheet emulator that combines a deep ensemble model and a normalizing flow model.
This class provides methods to train, predict, save, and load hybrid models for ice sheet emulation. It integrates a deep ensemble to capture epistemic uncertainties and a normalizing flow to model aleatoric uncertainties.
- device
The computing device (‘cuda’ if available, else ‘cpu’).
- Type:
str
- deep_ensemble
The deep ensemble model for epistemic uncertainty.
- Type:
DeepEnsemble
- normalizing_flow
The normalizing flow model for aleatoric uncertainty.
- Type:
NormalizingFlow
- trained
Flag indicating whether the model has been trained.
- Type:
bool
- scaler_path
Path to the scaler used for output transformation.
- Type:
str or None
- fit(X, y, nf_epochs, de_epochs, batch_size=64, X_val=None, y_val=None, save_checkpoints=True, checkpoint_path='checkpoint_ensemble', early_stopping=True, sequence_length=5, patience=10, verbose=True)[source]
Trains the hybrid emulator using the provided data.
This method trains the normalizing flow model first, then uses its latent representations to train the deep ensemble model.
- Parameters:
X (array-like) – Input feature matrix.
y (array-like) – Target values.
nf_epochs (int) – Number of training epochs for the normalizing flow.
de_epochs (int) – Number of training epochs for the deep ensemble.
batch_size (int, optional) – Batch size for training. Defaults to 64.
X_val (array-like, optional) – Validation feature matrix. Defaults to None.
y_val (array-like, optional) – Validation target values. Defaults to None.
save_checkpoints (bool, optional) – Whether to save training checkpoints. Defaults to True.
checkpoint_path (str, optional) – Path prefix for saving model checkpoints. Defaults to ‘checkpoint_ensemble’.
early_stopping (bool, optional) – Whether to use early stopping. Defaults to True.
sequence_length (int, optional) – Sequence length for recurrent architectures. Defaults to 5.
patience (int, optional) – Number of epochs with no improvement before stopping. Defaults to 10.
verbose (bool, optional) – Whether to print training progress. Defaults to True.
- Raises:
Warning – If the model has already been trained.
- forward(x)[source]
Run a forward pass through the hybrid emulator.
- Parameters:
x (array-like) – Input feature matrix with shape (N, num_features).
- Returns:
(prediction, uncertainties) where:
prediction (numpy.ndarray): Mean prediction across ensemble members.
uncertainties (dict): Keys 'total', 'epistemic', 'aleatoric' with numpy arrays giving per-row uncertainty in scaled (model) units.
- Return type:
tuple
- Warns:
UserWarning – If the model has not been trained.
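Assembling the uncertainty dictionary can be sketched as follows, using the documented rule that total = epistemic + aleatoric. `combine_uncertainties` is a hypothetical helper, not part of the package.

```python
import numpy as np


def combine_uncertainties(epistemic, aleatoric):
    """Build the uncertainty dict; 'total' is the element-wise sum of the two parts."""
    epistemic = np.asarray(epistemic, dtype=float).reshape(-1)
    aleatoric = np.asarray(aleatoric, dtype=float).reshape(-1)
    if epistemic.shape != aleatoric.shape:
        raise ValueError("epistemic and aleatoric must have the same number of rows")
    return {
        "total": epistemic + aleatoric,
        "epistemic": epistemic,
        "aleatoric": aleatoric,
    }


u = combine_uncertainties([[0.1], [0.2]], [0.3, 0.4])
```

Adding the two standard deviations directly is more conservative than adding them in quadrature, since sqrt(a² + b²) ≤ a + b for non-negative a and b.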
- static load(model_dir=None, deep_ensemble_path=None, normalizing_flow_path=None)[source]
Load a trained ISEFlow from saved deep ensemble and normalizing flow checkpoints.
Provide either model_dir (which expects deep_ensemble.pth and normalizing_flow.pth files inside it) or both deep_ensemble_path and normalizing_flow_path explicitly.
- Parameters:
model_dir (str, optional) – Directory containing the saved sub-model files.
deep_ensemble_path (str, optional) – Explicit path to the saved deep ensemble.
normalizing_flow_path (str, optional) – Explicit path to the saved normalizing flow.
- Returns:
The loaded model, with trained=True.
- Return type:
ISEFlow
- predict(x, output_scaler=True, smoothing_window=0)[source]
Predict SLE projections and uncertainties, applying inverse scaling and optional smoothing.
Smoothing is applied to the final unscaled predictions and uncertainties so the physical SLE curve is what gets smoothed (rather than scaled values).
- Parameters:
x (array-like) – Input feature matrix with shape (N, num_features).
output_scaler (bool or str, optional) – If True (default), loads the scaler bundled with the pretrained model. If False, returns un-rescaled predictions. If a string, loads the sklearn scaler at that path and uses it to inverse-transform the output.
smoothing_window (int, optional) – Width of a centered moving-average smoother applied to the unscaled predictions and uncertainties. 0 (default) disables smoothing.
- Returns:
(unscaled_predictions, uncertainties) where:
unscaled_predictions (numpy.ndarray): Predictions in mm SLE.
uncertainties (dict): Keys 'total', 'epistemic', 'aleatoric' with numpy arrays in mm SLE.
- Return type:
tuple
- Warns:
UserWarning – If no scaler is available; predictions and uncertainties are then returned in the model’s scaled output space rather than mm SLE.
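A centered moving-average smoother can be sketched in NumPy. `centered_moving_average` is a hypothetical name. The edge handling, which repeats the end values so the output keeps its length, is an assumption; the package may treat the first and last years differently.

```python
import numpy as np


def centered_moving_average(series, window):
    """Centered moving average along axis 0; window <= 1 returns the input unchanged.

    Edges are padded by repeating the end values (an assumption) so the output
    keeps the input length, e.g. 86 annual steps for 2015-2100.
    """
    series = np.asarray(series, dtype=float)
    if window <= 1:
        return series
    kernel = np.ones(window) / window
    left, right = (window - 1) // 2, window // 2  # total padding = window - 1
    padded = np.pad(series, [(left, right)] + [(0, 0)] * (series.ndim - 1), mode="edge")
    # 'valid' convolution of the padded column returns exactly len(series) values
    return np.apply_along_axis(lambda col: np.convolve(col, kernel, mode="valid"), 0, padded)
```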
- save(save_dir, input_features=None, output_scaler_path=None)[source]
Saves the trained model and related components to a specified directory.
- Parameters:
save_dir (str) – Directory where the model should be saved.
input_features (list, optional) – List of input feature names. Defaults to None.
output_scaler_path (str, optional) – Path to the output scaler. Defaults to None.
- Raises:
ValueError – If the model has not been trained.
ValueError – If save_dir is a file instead of a directory.
ValueError – If input_features is not a list.
- class ise.models.ISEFlow_AIS(version='v1.1.0')[source]
Bases:
ISEFlow
Pretrained ISEFlow emulator for the Antarctic Ice Sheet (AIS).
Loads pretrained weights for AIS (18 sectors, 8 km resolution) from HuggingFace Hub and exposes predict(inputs), where inputs is an ISEFlowAISInputs instance.
Note
version refers to the ISEFlow model weights version, not the ise-py package version. See ise.models.pretrained.ISEFLOW_LATEST_MODEL_VERSION for the current default.
Supported model versions:
- v1.0.0: includes mrro_anomaly as a forcing variable.
- v1.1.0 (default): mrro_anomaly removed; improved GrIS+AIS joint training.
- Parameters:
version (str, optional) – ISEFlow model weights version. One of 'v1.0.0' or 'v1.1.0'. Defaults to the latest: 'v1.1.0'.
- Raises:
NotImplementedError – If an unsupported version string is provided.
- predict(inputs: ISEFlowAISInputs, smoothing_window: int = 0)[source]
Predicts AIS sea level contribution using the pretrained ISEFlow_AIS model.
Internally calls process() to scale, add lag variables, and one-hot encode the inputs before running the hybrid forward pass.
- Parameters:
inputs (ISEFlowAISInputs) – Validated input dataclass containing climate forcings and ISM configuration for a single sector.
smoothing_window (int, optional) – If > 0, applies a uniform moving-average smoother of this width to the output time series. Defaults to 0 (no smoothing).
- Returns:
A tuple containing:
predictions (numpy.ndarray, shape (86, 1)): Unscaled sea level equivalent (SLE) projections in mm for years 2015-2100.
uncertainties (dict): Dictionary with keys:
'total': total uncertainty (epistemic + aleatoric).
'epistemic': uncertainty from ensemble disagreement.
'aleatoric': uncertainty from normalizing-flow sampling.
- Return type:
tuple
- process(inputs: ISEFlowAISInputs)[source]
Preprocess ISEFlowAISInputs into the feature matrix expected by the model.
Applies input scaling (using the version-specific scaler_X.pkl), adds 5-step lag variables, one-hot encodes ISM configuration columns, and pads any missing one-hot columns with False.
- Parameters:
inputs (ISEFlowAISInputs) – Validated input dataclass for a single AIS sector.
- Returns:
Feature matrix aligned to the column order expected by the pretrained model weights for the current version.
- Return type:
pandas.DataFrame
- Raises:
ValueError – If mrro_anomaly is None when using v1.0.0.
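The lag and one-hot steps of process() can be sketched with pandas. The column names, the add_lags and one_hot_and_align helpers, and the choice to fill early lags with the first value are illustrative assumptions; scaling with scaler_X.pkl is omitted.

```python
import pandas as pd

# Hypothetical column names for illustration only.
FORCING_COLUMNS = ["temperature", "smb_anomaly"]
CATEGORICAL_COLUMNS = ["model"]


def add_lags(df, columns, n_lags=5):
    """Append lag-1..lag-n copies of each forcing column (5-step lags)."""
    out = df.copy()
    for col in columns:
        first = df[col].iloc[0]
        for lag in range(1, n_lags + 1):
            # Assumption: lags before the start of the series reuse the first value.
            out[f"{col}_lag{lag}"] = df[col].shift(lag).fillna(first)
    return out


def one_hot_and_align(df, categorical, expected_columns):
    """One-hot encode ISM configuration columns and align to the trained layout."""
    encoded = pd.get_dummies(df, columns=categorical)
    # Categories absent from this input become all-False columns;
    # columns the model was not trained on are dropped.
    return encoded.reindex(columns=expected_columns, fill_value=False)


# Usage: a 3-step toy input that only contains ISM configuration "A".
df = pd.DataFrame({
    "temperature": [0.1, 0.2, 0.4],
    "smb_anomaly": [1.0, 0.5, 0.0],
    "model": ["A", "A", "A"],
})
lagged = add_lags(df, FORCING_COLUMNS)
expected = [c for c in lagged.columns if c != "model"] + ["model_A", "model_B"]
X = one_hot_and_align(lagged, CATEGORICAL_COLUMNS, expected)
```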
- test(X_test)
Runs inference on a test feature matrix and returns predictions in the original (unscaled) units together with their uncertainties.
- Parameters:
X_test (array-like) – Test feature matrix.
- Returns:
- A tuple containing:
unscaled_predictions (numpy.ndarray): Model predictions in the original scale.
uncertainties (dict): Dictionary with keys:
'total' (numpy.ndarray): Total uncertainty.
'epistemic' (numpy.ndarray): Epistemic uncertainty.
'aleatoric' (numpy.ndarray): Aleatoric uncertainty.
- Return type:
tuple
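A sketch of how the three uncertainty entries relate, assuming epistemic uncertainty is the standard deviation across DeepEnsemble member predictions and aleatoric uncertainty is the per-row standard deviation of normalizing-flow samples. decompose_uncertainty is a hypothetical helper.

```python
import numpy as np


def decompose_uncertainty(member_predictions, aleatoric):
    """Combine epistemic and aleatoric terms into the documented total.

    member_predictions: shape (num_members, N), one row per ensemble member.
    aleatoric: shape (N,), e.g. per-row std of normalizing-flow samples.
    Assumption: epistemic uncertainty is the std across ensemble members.
    """
    members = np.asarray(member_predictions, dtype=float)
    prediction = members.mean(axis=0)   # ensemble-mean prediction
    epistemic = members.std(axis=0)     # disagreement across members
    aleatoric = np.asarray(aleatoric, dtype=float)
    total = epistemic + aleatoric       # total = epistemic + aleatoric
    return prediction, {"total": total, "epistemic": epistemic, "aleatoric": aleatoric}


# Usage: two ensemble members, two time steps.
members = np.array([[1.0, 2.0], [3.0, 2.0]])
pred, unc = decompose_uncertainty(members, np.array([0.5, 0.5]))
```

Summing the two standard deviations follows the documented total = epistemic + aleatoric; adding variances in quadrature would be a different convention.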
- class ise.models.ISEFlow_GrIS(version='v1.1.0')
Bases: ISEFlow
Pretrained ISEFlow emulator for the Greenland Ice Sheet (GrIS).
Loads pretrained weights for GrIS (6 drainage basins, 5 km resolution) from the Hugging Face Hub and exposes predict(inputs), where inputs is an ISEFlowGrISInputs instance.
Note
version refers to the ISEFlow model weights version, not the ise-py package version. See ise.models.pretrained.ISEFLOW_LATEST_MODEL_VERSION for the current default. Supported model versions:
v1.0.0: initial GrIS release.
v1.1.0 (default): improved AIS+GrIS joint training.
- Parameters:
version (str, optional) – ISEFlow model weights version. One of 'v1.0.0' or 'v1.1.0'. Defaults to the latest: 'v1.1.0'.
- Raises:
NotImplementedError – If an unsupported version string is provided.
- predict(inputs: ISEFlowGrISInputs, smoothing_window: int = 0)
Predicts the GrIS sea level contribution using the pretrained ISEFlow_GrIS model.
Internally calls process() to scale, add lag variables, and one-hot encode the inputs before running the hybrid forward pass.
- Parameters:
inputs (ISEFlowGrISInputs) – Validated input dataclass containing climate forcings and ISM configuration for a single GrIS drainage basin.
smoothing_window (int, optional) – If > 0, applies a uniform moving-average smoother of this width to the output time series. Defaults to 0 (no smoothing).
- Returns:
A tuple containing:
predictions (numpy.ndarray, shape (86, 1)): Unscaled sea level equivalent (SLE) projections in mm for years 2015-2100.
uncertainties (dict): Dictionary with keys:
'total': total uncertainty (epistemic + aleatoric).
'epistemic': uncertainty from ensemble disagreement.
'aleatoric': uncertainty from normalizing-flow sampling.
- Return type:
tuple
- process(inputs: ISEFlowGrISInputs)
Preprocess ISEFlowGrISInputs into the feature matrix expected by the model.
Applies input scaling (using the version-specific scaler_X.pkl), adds 5-step lag variables, one-hot encodes ISM configuration columns, and pads any missing one-hot columns with False.
- Parameters:
inputs (ISEFlowGrISInputs) – Validated input dataclass for a single GrIS basin.
- Returns:
Feature matrix aligned to the column order expected by the pretrained model weights for the current version.
- Return type:
pandas.DataFrame
- test(X_test)
Runs inference on a test feature matrix and returns predictions in the original (unscaled) units together with their uncertainties.
- Parameters:
X_test (array-like) – Test feature matrix.
- Returns:
- A tuple containing:
unscaled_predictions (numpy.ndarray): Model predictions in the original scale.
uncertainties (dict): Dictionary with keys:
'total' (numpy.ndarray): Total uncertainty.
'epistemic' (numpy.ndarray): Epistemic uncertainty.
'aleatoric' (numpy.ndarray): Aleatoric uncertainty.
- Return type:
tuple
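Returning predictions "in the original scale" means inverting the target scaling applied during training. Assuming a standard (z-score) scaler, which is an assumption rather than documented ise-py behavior, the inversion can be sketched as follows; unscale and unscale_std are hypothetical helpers.

```python
import numpy as np


def unscale(scaled, mean, scale):
    """Invert z-score scaling: x = z * scale + mean.

    Assumption: the target scaler is a standard (z-score) scaler; ise-py's
    actual scaler object and its attributes may differ.
    """
    return np.asarray(scaled, dtype=float) * scale + mean


def unscale_std(scaled_std, scale):
    """Standard deviations are unaffected by the shift, so only the scale applies."""
    return np.asarray(scaled_std, dtype=float) * abs(scale)


# Usage: scaled predictions and a scaled uncertainty back to physical units.
preds_mm = unscale([0.0, 1.0], mean=2.0, scale=3.0)
sigma_mm = unscale_std([1.0], scale=3.0)
```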
- class ise.models.LSTM(lstm_num_layers, lstm_hidden_size, input_size=83, output_size=1, criterion=MSELoss(), output_sequence_length=86, optimizer=<class 'torch.optim.adamw.AdamW'>, lr=0.0001, wd=1e-06, dropout=0.0)
Bases: torch.nn.Module
Long Short-Term Memory (LSTM) model for time series forecasting.
This class implements an LSTM network with multiple layers, dropout, and fully connected layers to generate predictions for sequential data.
- lstm_num_layers
Number of LSTM layers in the model.
- Type:
int
- lstm_hidden_size
Number of hidden units in each LSTM layer.
- Type:
int
- input_size
Number of input features.
- Type:
int
- output_size
Number of output features.
- Type:
int
- output_sequence_length
Number of time steps predicted by the model.
- Type:
int
- device
Device on which the model runs ('cuda' or 'cpu').
- Type:
str
- lstm
LSTM layer for sequence modeling.
- Type:
nn.LSTM
- relu
ReLU activation function.
- Type:
nn.ReLU
- linear1
Intermediate fully connected layer.
- Type:
nn.Linear
- linear_out
Output layer mapping to final predictions.
- Type:
nn.Linear
- optimizer
Optimization algorithm used for training.
- Type:
torch.optim.Optimizer
- dropout
Dropout layer to prevent overfitting.
- Type:
nn.Dropout
- criterion
Loss function used for training.
- Type:
torch.nn.modules.loss._Loss
- trained
Flag indicating whether the model has been trained.
- Type:
bool
- Parameters:
lstm_num_layers (int) – Number of LSTM layers.
lstm_hidden_size (int) – Number of hidden units in each LSTM layer.
input_size (int, optional) – Number of input features. Defaults to 83.
output_size (int, optional) – Number of output features. Defaults to 1.
criterion (torch.nn.modules.loss._Loss, optional) – Loss function. Defaults to MSELoss.
output_sequence_length (int, optional) – Number of output time steps. Defaults to 86.
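optimizer (type[torch.optim.Optimizer], optional) – Optimizer class (not an instance). Defaults to torch.optim.AdamW.
lr (float, optional) – Learning rate passed to the optimizer. Defaults to 1e-4.
wd (float, optional) – Weight decay passed to the optimizer. Defaults to 1e-6.
dropout (float, optional) – Dropout probability. Defaults to 0.0.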
- fit(X, y, epochs=100, sequence_length=5, batch_size=64, criterion=None, X_val=None, y_val=None, save_checkpoints=True, checkpoint_path='checkpoint.pt', early_stopping=False, patience=10, verbose=True, dataclass=<class 'ise.data.dataclasses.EmulatorDataset'>, wandb_run=None)
Trains the LSTM model on the provided data.
Supports optional checkpointing and early stopping. If a checkpoint exists, training resumes from the last saved state.
- Parameters:
X (Tensor or DataFrame) – Input training data.
y (Tensor or DataFrame) – Target values corresponding to the input data.
epochs (int, optional) – Number of epochs for training. Defaults to 100.
sequence_length (int, optional) – Length of input sequences. Defaults to 5.
batch_size (int, optional) – Batch size used in training. Defaults to 64.
criterion (torch.nn.modules.loss._Loss, optional) – Loss function. Defaults to None.
X_val (Tensor or DataFrame, optional) – Validation input data. Defaults to None.
y_val (Tensor or DataFrame, optional) – Validation target data. Defaults to None.
save_checkpoints (bool, optional) – Whether to save model checkpoints. Defaults to True.
checkpoint_path (str, optional) – Path to save model checkpoints. Defaults to 'checkpoint.pt'.
early_stopping (bool, optional) – Whether to enable early stopping. Defaults to False.
patience (int, optional) – Number of epochs without improvement in validation loss to wait before stopping. Defaults to 10.
verbose (bool, optional) – Whether to print training progress. Defaults to True.
dataclass (type, optional) – Dataset class for handling data. Defaults to EmulatorDataset.
wandb_run (wandb.run, optional) – Weights & Biases run for per-epoch metric logging. Defaults to None.
- Raises:
ValueError – If no loss function is provided.
Notes
If validation data is provided but early stopping is disabled, a warning is issued.
If a checkpoint exists, training resumes from the saved epoch.
If early stopping is enabled, the model stops training when validation loss stops improving.
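The early-stopping rule in these notes can be sketched as a small counter. EarlyStopping and min_delta are hypothetical names; the sketch assumes "improving" means a validation loss strictly lower than the best seen so far.

```python
class EarlyStopping:
    """Hypothetical helper: stop after `patience` epochs without val-loss improvement."""

    def __init__(self, patience=10, min_delta=0.0):
        self.patience = patience
        self.min_delta = min_delta
        self.best_loss = float("inf")
        self.epochs_without_improvement = 0

    def step(self, val_loss):
        """Record one epoch's validation loss; return True if training should stop."""
        if val_loss < self.best_loss - self.min_delta:
            self.best_loss = val_loss
            self.epochs_without_improvement = 0  # improvement: reset the counter
            # A real training loop would typically save a checkpoint here.
        else:
            self.epochs_without_improvement += 1
        return self.epochs_without_improvement >= self.patience


# Usage: with patience=3, three non-improving epochs after 0.8 trigger a stop.
stopper = EarlyStopping(patience=3)
losses = [1.0, 0.8, 0.81, 0.82, 0.83, 0.5]
stopped_at = next(epoch for epoch, loss in enumerate(losses) if stopper.step(loss))
```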
- forward(x)
Performs a forward pass through the LSTM network.
Given an input sequence, the LSTM processes the sequence to extract features, which are passed through a fully connected network to generate predictions.
- Parameters:
x (Tensor) – Input tensor of shape (batch_size, sequence_length, input_size).
- Returns:
Output tensor of shape (batch_size, output_size), representing the model’s predictions.
- Return type:
Tensor
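forward() expects inputs of shape (batch_size, sequence_length, input_size). One way to build such windows from a single 86-step projection is sketched below, assuming each time step's window ends at that step and early steps are padded by repeating the first row. make_sequences is hypothetical, and EmulatorDataset may construct its windows differently.

```python
import numpy as np


def make_sequences(features, sequence_length=5):
    """Build overlapping windows of shape (T, sequence_length, num_features).

    Window t holds rows t - sequence_length + 1 .. t. Assumption: the first
    rows are left-padded by repeating row 0.
    """
    x = np.asarray(features, dtype=float)
    pad = np.repeat(x[:1], sequence_length - 1, axis=0)
    padded = np.concatenate([pad, x], axis=0)
    return np.stack([padded[t:t + sequence_length] for t in range(len(x))])


# Usage: one 86-year projection with 3 features -> 86 windows of 5 steps each.
proj = np.arange(86 * 3, dtype=float).reshape(86, 3)
windows = make_sequences(proj, 5)
```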
- classmethod load(model_path: str) -> LSTM
Loads a trained LSTM model from disk.
- Expects:
<model_path> (a .pth with state_dict)
<model_path>_metadata.json (hyperparams & config)
- Returns:
A model instance reconstructed with the saved hyperparameters, loss, and optimizer type (with the saved lr/weight_decay).
- Return type:
LSTM
- Raises:
FileNotFoundError – If weights or metadata files are missing.
ValueError – If the saved model_type does not match this class.
- predict(X, sequence_length=None, batch_size=64, dataclass=<class 'ise.data.dataclasses.EmulatorDataset'>)
Generates predictions using the trained LSTM model.
The model processes input sequences and returns predictions, computing them batch by batch to limit memory usage.
- Parameters:
X (Tensor or DataFrame) – Input data for prediction.
sequence_length (int, optional) – Length of input sequences. Defaults to 5.
batch_size (int, optional) – Batch size used for inference. Defaults to 64.
dataclass (type, optional) – Dataset class for handling data. Defaults to EmulatorDataset.
- Returns:
Predicted values for the input data.
- Return type:
Tensor
Notes
The model is set to evaluation mode before making predictions.
Data is converted to tensors if initially provided as pandas DataFrames.
- save(model_path: str)
Saves the LSTM model weights and metadata.
Writes <model_path> (state_dict) and <model_path>_metadata.json (config).
Records architecture, optimizer type & hparams (lr/weight_decay), and loss name.
Removes the training checkpoint file if this instance has one.
- Parameters:
model_path (str) – Destination file path ending in ‘.pth’.
- Raises:
ValueError – If the model has not been trained yet.
- class ise.models.NormalizingFlow(input_size=43, output_size=1, output_sequence_length=86, num_flow_transforms=5, flow_hidden_features=16, legacy_v1_0_0=False)
Bases: torch.nn.Module
A normalizing flow model for probabilistic modeling using invertible transformations.
The model passes a conditional base distribution through a sequence of invertible transformations built from autoregressive neural networks, which lets it represent complex conditional probability distributions.
- num_flow_transforms
Number of flow transformations in the model.
- Type:
int
- num_input_features
Number of input features.
- Type:
int
- num_predicted_sle
Number of predicted sea-level equivalent values.
- Type:
int
Number of hidden features in the flow model.
- Type:
int
- output_sequence_length
Length of the output sequence.
- Type:
int
- device
Device on which the model is run ("cuda" or "cpu").
- Type:
str
- base_distribution
The base normal distribution conditioned on input features.
- Type:
distributions.normal.ConditionalDiagonalNormal
- t
Composite transformation for the normalizing flow.
- Type:
transforms.base.CompositeTransform
- flow
The normalizing flow model.
- Type:
flows.base.Flow
- optimizer
Optimizer for training the model.
- Type:
torch.optim.Adam
- criterion
Log probability function used as the loss criterion.
- Type:
callable
- trained
Flag indicating if the model has been trained.
- Type:
bool
- aleatoric(features, num_samples, batch_size=128)
Estimate aleatoric uncertainty as the std across flow samples per input row.
For each input row, draws num_samples samples from the conditional flow and returns the standard deviation across those samples. NaN samples are ignored when computing the std.
- Parameters:
features (array-like or torch.Tensor) – Input features of shape (N, num_features).
num_samples (int) – Number of flow samples drawn per input row.
batch_size (int, optional) – Number of rows processed per forward pass. Defaults to 128.
- Returns:
Per-row aleatoric uncertainty, shape (N,).
- Return type:
numpy.ndarray
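A numpy sketch of the computation described for aleatoric(), with sampler as a hypothetical stand-in for the flow's conditional sampler; the real method samples from the trained flow, and aleatoric_std and toy_sampler are illustrative names.

```python
import numpy as np


def aleatoric_std(features, num_samples, sampler, batch_size=128):
    """Per-row std across conditional samples, ignoring NaNs, in row batches.

    `sampler(batch, num_samples)` is a hypothetical stand-in for the flow's
    conditional sampler and must return shape (len(batch), num_samples).
    """
    x = np.asarray(features, dtype=float)
    out = []
    for start in range(0, len(x), batch_size):
        batch = x[start:start + batch_size]
        samples = sampler(batch, num_samples)
        out.append(np.nanstd(samples, axis=1))  # NaN samples are skipped
    return np.concatenate(out)


# Usage with a toy Gaussian sampler whose std equals each row's first feature.
rng = np.random.default_rng(0)


def toy_sampler(batch, n):
    return batch[:, :1] * rng.standard_normal((len(batch), n))


feats = np.column_stack([np.full(300, 2.0), np.zeros(300)])
sigma = aleatoric_std(feats, 5000, toy_sampler)  # approximately 2.0 per row
```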
- fit(X, y, X_val=None, y_val=None, epochs=100, batch_size=64, save_checkpoints=True, checkpoint_path='checkpoint.pt', early_stopping=True, patience=10, verbose=True, wandb_run=None, lr=0.0001, wd=1e-06)
Train the normalizing flow via maximum likelihood (negative log-probability).
If checkpoint_path already exists, training resumes from the saved epoch. After training, the best checkpoint is loaded back into the model and the temporary file is deleted.
- Parameters:
X (array-like) – Input features of shape (num_samples, num_features).
y (array-like) – Target values of shape (num_samples, output_size).
X_val (array-like, optional) – Validation features for early stopping.
y_val (array-like, optional) – Validation targets for early stopping.
epochs (int, optional) – Number of training epochs. Defaults to 100.
batch_size (int, optional) – Batch size for training. Defaults to 64.
save_checkpoints (bool, optional) – Whether to save model checkpoints. Defaults to True.
checkpoint_path (str, optional) – Path to save model checkpoints. Defaults to 'checkpoint.pt'.
early_stopping (bool, optional) – Whether to use early stopping. Defaults to True.
patience (int, optional) – Number of epochs to wait before early stopping. Defaults to 10.
verbose (bool, optional) – Whether to print training progress. Defaults to True.
wandb_run (wandb.run, optional) – Weights & Biases run for logging. Defaults to None.
lr (float, optional) – AdamW learning rate. Defaults to 1e-4.
wd (float, optional) – AdamW weight decay. Defaults to 1e-6.
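Maximum-likelihood training minimizes the negative log-probability of y given x. For a flow f, log p(y | x) = log p_base(f(y; x) | x) + log|det J_f(y; x)|. The sketch below shows only the base term for a diagonal Gaussian, which is the whole objective when the transform is the identity. gaussian_nll is a hypothetical helper, and the per-row mean and std would come from the conditioning network.

```python
import numpy as np


def gaussian_nll(y, mean, std):
    """Mean negative log-likelihood of y under a diagonal normal N(mean, std**2)."""
    y, mean, std = (np.asarray(a, dtype=float) for a in (y, mean, std))
    z = (y - mean) / std
    # -log N(y; mean, std) = 0.5*log(2*pi) + log(std) + 0.5*z**2, averaged over elements
    return float(np.mean(0.5 * np.log(2 * np.pi) + np.log(std) + 0.5 * z ** 2))


# Usage: targets far from the predicted mean are penalized more heavily.
close = gaussian_nll([0.1, -0.2], [0.0, 0.0], [1.0, 1.0])
far = gaussian_nll([2.0, -3.0], [0.0, 0.0], [1.0, 1.0])
```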
- get_latent(x, latent_dim=1)
Computes the latent space representation of the given input.
Two approaches are used depending on the model version:
v1.0.0 (legacy): deterministically pushes a zero vector through the forward transform conditioned on x. The same x always yields the same z, so the flow acts as a learned deterministic feature extractor.
v1.1.0+ (default): draws latent_dim samples from the conditional base distribution. z is a stochastic summary of x, which is more statistically grounded: the latent reflects the modeled conditional distribution rather than a single fixed point.
These approaches behave differently and produce different downstream DeepEnsemble inputs; the choice may be worth revisiting in future versions. The legacy path is more pragmatic and reproducible at inference, while the sampling path aligns more closely with the probabilistic interpretation of the flow.
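The contrast between the two get_latent paths can be illustrated with a toy conditional affine transform. toy_conditioner, latent_legacy and latent_sampled are hypothetical; the real flow uses learned autoregressive transforms and a ConditionalDiagonalNormal base, not this closed form.

```python
import numpy as np


def toy_conditioner(x):
    """Hypothetical conditioner mapping each input row to (shift, scale)."""
    x = np.asarray(x, dtype=float)
    shift = x.sum(axis=1, keepdims=True)
    scale = 1.0 + np.abs(x).mean(axis=1, keepdims=True)
    return shift, scale


def latent_legacy(x):
    """v1.0.0-style: push a zero vector through a conditional forward transform."""
    shift, scale = toy_conditioner(x)
    zeros = np.zeros_like(shift)
    return (zeros - shift) / scale  # same x -> same z on every call


def latent_sampled(x, latent_dim=1, rng=None):
    """v1.1.0-style: draw latent_dim samples from a conditional diagonal normal."""
    rng = np.random.default_rng() if rng is None else rng
    mean, std = toy_conditioner(x)
    return mean + std * rng.standard_normal((len(mean), latent_dim))


# Usage: the legacy latent is reproducible; the sampled latent varies with the RNG.
x = np.array([[0.5, -1.0, 2.0], [0.0, 0.0, 1.0]])
z_det = latent_legacy(x)
z_rand = latent_sampled(x, latent_dim=4, rng=np.random.default_rng(0))
```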
- Parameters:
x (array-like or torch.Tensor) – Input data of shape (num_samples, num_features).
latent_dim (int, optional) – Number of latent samples to draw (post-v1.0.0). Defaults to 1.
- Returns:
Latent space representation of the input data.
- Return type:
torch.Tensor
- static load(path)
Loads a trained normalizing flow model from a saved checkpoint.
- Parameters:
path (str) – Path to the saved model checkpoint.
- Returns:
A restored instance of the NormalizingFlow model.
- Return type:
NormalizingFlow
- sample(features, num_samples, return_type='numpy')
Generates samples from the trained normalizing flow model.
- Parameters:
features (array-like or torch.Tensor) – Input features to condition the samples on.
num_samples (int) – Number of samples to generate per input feature set.
return_type (str, optional) – Return type, either “numpy” or “tensor”. Defaults to “numpy”.
- Returns:
Generated samples of shape (num_samples, output_size).
- Return type:
np.ndarray or torch.Tensor