"""Single LSTM network for time series sea-level projection.
This module provides the ``LSTM`` class — the constituent building block of
``DeepEnsemble``. Each instance is an independent stacked LSTM followed by a
two-layer fully-connected head::

    LSTM layers (num_layers, hidden_size)
      → FC layer (hidden_size → 32) + ReLU
      → FC output layer (32 → output_size)

The architecture is deliberately simple: hidden-to-output mapping uses only
the final hidden state ``hn[-1]``, making this a many-to-one sequence model
that takes a window of ``sequence_length`` feature vectors and predicts a
single SLE value for the last timestep.

Usage — stand-alone
-------------------
::

    from ise.models.lstm import LSTM
    import torch.nn as nn

    model = LSTM(
        lstm_num_layers=2,
        lstm_hidden_size=256,
        input_size=84,  # features after NF latent concat
        output_size=1,
        criterion=nn.HuberLoss(),
    )
    model.fit(X_train, y_train, epochs=200, sequence_length=5,
              X_val=X_val, y_val=y_val, early_stopping=True, patience=15)
    preds = model.predict(X_test, sequence_length=5)  # Tensor shape (N, 1)
    model.save("lstm.pth")
    model_loaded = LSTM.load("lstm.pth")

Usage — inside DeepEnsemble
---------------------------
``LSTM`` instances are assembled into a ``DeepEnsemble`` by passing a list of
them to the constructor. The ensemble calls ``member.predict(x)`` on each
member and aggregates the results. In this context the ``LSTM`` is never
called with ``fit()`` directly — ``DeepEnsemble.fit()`` handles that.
Checkpointing and early stopping
---------------------------------
``fit()`` uses ``CheckpointSaver`` / ``EarlyStoppingCheckpointer`` from
``ise.models.training``. If a checkpoint already exists at the given path,
training resumes from the saved epoch. After training, the best checkpoint
is reloaded before ``fit()`` returns.
``save()`` writes both a ``.pth`` state dict and a ``_metadata.json`` file
storing the full architecture and optimizer hyperparameters. ``load()``
reads both to reconstruct the model without any prior constructor call.
"""
import json
import os
import warnings
import pandas as pd
import torch
import torch.nn.functional as F
import wandb
from torch import nn, optim
from ise.data.dataclasses import EmulatorDataset
from ise.models.training import CheckpointSaver, EarlyStoppingCheckpointer
from ise.utils.functions import get_device, to_tensor
class LSTM(nn.Module):
    """Stacked LSTM with a two-layer fully connected head (many-to-one).

    The input window is run through ``nn.LSTM``; only the final hidden state of
    the last LSTM layer (``hn[-1]``) is passed on to the head::

        Linear(hidden_size -> 32) -> ReLU -> [Dropout] -> Linear(32 -> output_size)

    Each instance owns its optimizer and loss function, so ``fit``, ``predict``,
    ``save`` and ``load`` can be used without any external training code.

    Attributes:
        lstm_num_layers (int): Number of stacked LSTM layers.
        lstm_num_hidden (int): Hidden units per LSTM layer.
        input_size (int): Number of input features per timestep.
        output_size (int): Number of output features.
        output_sequence_length (int): Forwarded to the dataset class as
            ``projection_length``.
        device: Value returned by ``get_device()``; used for every ``.to()`` call.
        lstm (nn.LSTM): Recurrent backbone (``batch_first=True``).
        relu (nn.ReLU): Activation between the two linear layers.
        linear1 (nn.Linear): Hidden fully connected layer (``hidden -> 32``).
        linear_out (nn.Linear): Output layer (``32 -> output_size``).
        dropout (nn.Dropout | None): Dropout applied after the ReLU, or ``None``
            when the dropout probability is 0.
        optimizer (torch.optim.Optimizer): Optimizer built over this model's
            parameters.
        criterion (torch.nn.Module): Training loss.
        trained (bool): Set to ``True`` once ``fit`` has completed.
        sequence_length (int | None): Window length last used by ``fit`` (or
            restored by ``load``); ``None`` until then.
    """

    def __init__(
        self,
        lstm_num_layers,
        lstm_hidden_size,
        input_size=83,
        output_size=1,
        # NOTE: this default instance is created once and shared by every model
        # that relies on it; kept as-is for interface compatibility.
        criterion=torch.nn.MSELoss(),
        output_sequence_length=86,
        optimizer=optim.AdamW,
        lr=1e-4,
        wd=1e-6,
        dropout=0.0,
    ):
        """Construct the LSTM network, its optimizer and its loss.

        Args:
            lstm_num_layers (int): Number of stacked LSTM layers (cast to ``int``).
            lstm_hidden_size (int): Hidden units per LSTM layer (cast to ``int``).
            input_size (int, optional): Number of input features. Defaults to 83.
            output_size (int, optional): Number of output features. Defaults to 1.
            criterion (torch.nn.Module, optional): Loss function. Defaults to MSELoss.
            output_sequence_length (int, optional): Projection length handed to the
                dataset class. Defaults to 86.
            optimizer (type, optional): Optimizer class; it is called as
                ``optimizer(params, lr=lr, weight_decay=wd)``. Defaults to
                ``optim.AdamW``.
            lr (float, optional): Learning rate. Defaults to 1e-4.
            wd (float, optional): Weight decay. Defaults to 1e-6.
            dropout (float, optional): Dropout probability. It is passed to
                ``nn.LSTM`` and, when greater than 0, also applied after the
                ReLU of the head. Defaults to 0.0 (no dropout).
        """
        super().__init__()

        # Cast once and reuse the cast values everywhere, so that float-valued
        # hyperparameters (e.g. 2.0 from a search library) do not crash
        # nn.LSTM / nn.Linear.
        num_layers = int(lstm_num_layers)
        hidden_size = int(lstm_hidden_size)

        self.lstm_num_layers = num_layers
        self.lstm_num_hidden = hidden_size
        self.input_size = input_size
        self.output_size = output_size
        self.output_sequence_length = output_sequence_length
        self.device = get_device()

        # Layers
        self.lstm = nn.LSTM(
            input_size=input_size,
            hidden_size=hidden_size,
            batch_first=True,
            num_layers=num_layers,
            dropout=dropout,
        )
        self.relu = nn.ReLU()
        self.linear1 = nn.Linear(in_features=hidden_size, out_features=32)
        self.linear_out = nn.Linear(in_features=32, out_features=output_size)
        self.dropout = nn.Dropout(p=dropout) if dropout > 0.0 else None

        # Move to the target device only now that parameters exist; calling
        # .to() before the layers are registered has no effect.
        self.to(self.device)

        # Optimizer and training state
        self.optimizer = optimizer(self.parameters(), lr=lr, weight_decay=wd)
        self.criterion = criterion
        self.trained = False
        self.sequence_length = None

    def forward(self, x):
        """Run a forward pass.

        Args:
            x (Tensor): Input of shape ``(batch_size, sequence_length, input_size)``
                (the LSTM is built with ``batch_first=True``).

        Returns:
            Tensor: Predictions of shape ``(batch_size, output_size)``.
        """
        # Only the final hidden state is needed; the per-step outputs and the
        # cell state are discarded.
        _, (hidden, _) = self.lstm(x)
        features = hidden[-1, :, :]  # last LSTM layer

        features = self.relu(self.linear1(features))
        if self.dropout is not None:
            features = self.dropout(features)
        return self.linear_out(features)

    def fit(
        self,
        X,
        y,
        epochs=100,
        sequence_length=5,
        batch_size=64,
        criterion=None,
        X_val=None,
        y_val=None,
        save_checkpoints=True,
        checkpoint_path="checkpoint.pt",
        early_stopping=False,
        patience=10,
        verbose=True,
        dataclass=EmulatorDataset,
        wandb_run=None,
    ):
        """Train the model, optionally with validation, checkpoints and early stopping.

        If ``checkpoint_path`` already exists, model and optimizer state are
        restored from it and training continues at the following epoch. After
        the loop, the checkpoint on disk (if any) is loaded back into the model.

        Args:
            X (Tensor or DataFrame): Training inputs; converted with ``to_tensor``.
            y (Tensor or DataFrame): Training targets; a 1-D target gets a
                trailing dimension added.
            epochs (int, optional): Last epoch number to train. Defaults to 100.
            sequence_length (int, optional): Window length passed to the dataset
                class. Defaults to 5.
            batch_size (int, optional): Batch size (cast to ``int``). Defaults to 64.
            criterion (torch.nn.Module, optional): Replaces the loss given at
                construction when not ``None``. Defaults to None.
            X_val (Tensor or DataFrame, optional): Validation inputs.
            y_val (Tensor or DataFrame, optional): Validation targets. Validation
                only runs when both ``X_val`` and ``y_val`` are given.
            save_checkpoints (bool, optional): Whether to create a checkpointer.
                Defaults to True.
            checkpoint_path (str, optional): Checkpoint file. Defaults to
                'checkpoint.pt'.
            early_stopping (bool, optional): Use ``EarlyStoppingCheckpointer``
                instead of ``CheckpointSaver``. Only effective together with
                ``save_checkpoints=True``. Defaults to False.
            patience (int, optional): Passed to ``EarlyStoppingCheckpointer``.
                Defaults to 10.
            verbose (bool, optional): Print per-epoch progress. Defaults to True.
            dataclass (type, optional): Dataset class, called as
                ``dataclass(X, y, sequence_length=..., projection_length=...)``.
                Defaults to EmulatorDataset.
            wandb_run (optional): Weights & Biases run used for per-epoch metric
                logging (validation mode only) and for uploading the checkpoint
                as an artifact. Defaults to None.

        Raises:
            ValueError: If no loss function is available.

        Notes:
            - The validation metric is always MSE, regardless of ``criterion``.
            - Without validation data, the checkpointer is driven by the mean
              training loss of the epoch.
        """
        # Place the model first: optimizer state restored from a checkpoint is
        # cast to the device of the parameters at load time, so the parameters
        # must already live on the training device.
        self.to(self.device)

        # to_tensor(...).to(...) yields tensors, so no DataFrame handling is
        # needed beyond this point.
        X, y = to_tensor(X).to(self.device), to_tensor(y).to(self.device)
        if y.ndimension() == 1:
            y = y.unsqueeze(1)

        self.wandb_run = wandb_run
        self.sequence_length = sequence_length
        self.checkpoint_path = checkpoint_path

        # Resume from an existing checkpoint, if any.
        start_epoch = 1
        best_loss = float("inf")
        if os.path.exists(checkpoint_path):
            checkpoint = torch.load(
                checkpoint_path, map_location=self.device, weights_only=True
            )
            self.load_state_dict(checkpoint["model_state_dict"])
            self.optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
            start_epoch = checkpoint["epoch"] + 1
            best_loss = checkpoint.get("best_loss", float("inf"))
            if verbose:
                print(
                    f"Resuming from checkpoint at epoch {start_epoch} with validation loss {best_loss:.6f}"
                )

        # Validation is enabled only when both inputs and targets are given.
        validate = X_val is not None and y_val is not None
        if validate:
            if not early_stopping:
                warnings.warn(
                    "Validation data provided but early_stopping is False. Early stopping is recommended for validation data."
                )
            X_val, y_val = to_tensor(X_val).to(self.device), to_tensor(y_val).to(self.device)

        # Resolve the loss function.
        if criterion is not None:
            self.criterion = criterion
        if self.criterion is None:
            raise ValueError("loss must be provided if criterion is None.")
        self.criterion = self.criterion.to(self.device)

        # Dataset and loader.
        dataset = dataclass(
            X, y, sequence_length=sequence_length, projection_length=self.output_sequence_length
        )
        data_loader = torch.utils.data.DataLoader(dataset, batch_size=int(batch_size), shuffle=True)

        # Checkpointer (plain or early-stopping).
        checkpointer = None
        if save_checkpoints:
            if early_stopping:
                checkpointer = EarlyStoppingCheckpointer(
                    self, self.optimizer, checkpoint_path, patience, verbose
                )
            else:
                checkpointer = CheckpointSaver(self, self.optimizer, checkpoint_path, verbose)
            checkpointer.best_loss = best_loss
        elif early_stopping:
            # The stop signal comes from the checkpointer, which is not created here.
            warnings.warn(
                "early_stopping=True has no effect when save_checkpoints=False."
            )

        if start_epoch > epochs and verbose:
            print(f"Training already completed ({epochs}/{epochs}).")

        # Training loop (empty when the checkpoint already covers all epochs).
        for epoch in range(start_epoch, epochs + 1):
            self.train()  # predict() below switches to eval mode every epoch
            batch_losses = []
            for x_batch, y_batch in data_loader:
                x_batch = x_batch.to(self.device)
                y_batch = y_batch.to(self.device)

                self.optimizer.zero_grad()
                loss = self.criterion(self.forward(x_batch), y_batch)
                loss.backward()
                self.optimizer.step()
                batch_losses.append(loss.item())

            if validate:
                val_preds = self.predict(
                    X_val, sequence_length=sequence_length, batch_size=batch_size
                ).to(self.device)
                val_loss = F.mse_loss(val_preds.squeeze(), y_val.squeeze())

                if save_checkpoints:
                    checkpointer(val_loss, epoch)
                    if getattr(checkpointer, "early_stop", False):
                        if verbose:
                            print("Early stopping")
                        break

                if self.wandb_run or verbose:
                    train_loss = sum(batch_losses) / len(batch_losses)
                if self.wandb_run:
                    # Log to the run that was passed in rather than to the
                    # process-global run.
                    self.wandb_run.log(
                        {
                            "epoch": epoch,
                            "train_loss": train_loss,
                            "val_loss": val_loss.item(),
                        }
                    )
                if verbose:
                    print(
                        f"[epoch/total]: [{epoch}/{epochs}], train loss: {train_loss}, val mse: {val_loss:.6f} -- {getattr(checkpointer, 'log', '')}"
                    )
            else:
                train_loss = sum(batch_losses) / len(batch_losses)
                # Without validation, checkpoint on training loss so the
                # post-training "load best model" step has a file to read.
                if save_checkpoints:
                    checkpointer(train_loss, epoch)
                if verbose:
                    print(f"[epoch/total]: [{epoch}/{epochs}], train loss: {train_loss}")

        self.trained = True

        # Reload the checkpoint, but only if one exists: very short runs (e.g.
        # epochs=0) can finish without a checkpoint file on disk.
        if save_checkpoints and os.path.exists(checkpoint_path):
            if self.wandb_run:
                artifact = wandb.Artifact(os.path.basename(checkpoint_path), type="model")
                artifact.add_file(checkpoint_path)
                self.wandb_run.log_artifact(artifact)

            # NOTE(review): weights_only=False unpickles arbitrary objects. The
            # file is expected to have been written by this training run; do not
            # point checkpoint_path at untrusted files.
            checkpoint = torch.load(
                checkpoint_path, map_location=self.device, weights_only=False
            )
            if isinstance(checkpoint, dict) and "model_state_dict" in checkpoint:
                self.load_state_dict(checkpoint["model_state_dict"])
                self.optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
                self.best_loss = checkpoint.get("best_loss", float("inf"))
                self.epochs_trained = checkpoint["epoch"]
            else:
                # Bare state dict without training metadata.
                self.load_state_dict(checkpoint)

    def predict(self, X, sequence_length=None, batch_size=64, dataclass=EmulatorDataset):
        """Generate predictions batch by batch, in eval mode and without gradients.

        Args:
            X (Tensor or DataFrame): Input data; a DataFrame is converted to its
                underlying array before being handed to the dataset class.
            sequence_length (int, optional): Window length. When ``None``, the
                value stored by the last ``fit``/``load`` (``self.sequence_length``)
                is used. Defaults to None.
            batch_size (int, optional): Inference batch size (cast to ``int``).
                Defaults to 64.
            dataclass (type, optional): Dataset class, called as
                ``dataclass(X, y=None, sequence_length=..., projection_length=...)``.
                Defaults to EmulatorDataset.

        Returns:
            Tensor: Predictions concatenated along dimension 0, on ``self.device``.
            An empty tensor is returned when the loader yields no batches.

        Notes:
            The model is left in evaluation mode. The returned tensor is detached
            from the autograd graph.
        """
        self.eval()
        self.to(self.device)

        if sequence_length is None:
            sequence_length = self.sequence_length

        if isinstance(X, pd.DataFrame):
            X = X.values

        dataset = dataclass(
            X,
            y=None,
            sequence_length=sequence_length,
            projection_length=self.output_sequence_length,
        )
        # int() mirrors fit(), which forwards its raw batch_size to this method.
        data_loader = torch.utils.data.DataLoader(
            dataset, batch_size=int(batch_size), shuffle=False
        )

        # Inference only: skip autograd bookkeeping and concatenate once at the
        # end instead of re-allocating the result for every batch.
        batch_preds = []
        with torch.no_grad():
            for x_batch in data_loader:
                batch_preds.append(self.forward(x_batch.to(self.device)))

        if not batch_preds:
            return torch.tensor([]).to(self.device)
        return torch.cat(batch_preds, 0)

    @staticmethod
    def _metadata_path(model_path: str) -> str:
        """Return the metadata JSON path that belongs to ``model_path``.

        A trailing ``.pth``/``.pt`` extension is replaced by ``_metadata.json``;
        for any other path the suffix is appended, so the metadata file can never
        coincide with the weights file.
        """
        root, ext = os.path.splitext(model_path)
        if ext in (".pth", ".pt"):
            return f"{root}_metadata.json"
        return f"{model_path}_metadata.json"

    def save(self, model_path: str):
        """Save the model weights and a JSON metadata file.

        - Writes ``model_path`` (state dict) and ``<stem>_metadata.json``, where
          ``<stem>`` is ``model_path`` without its ``.pth``/``.pt`` extension.
        - Records architecture, optimizer type and hyperparameters
          (lr / weight_decay), the loss class name and training bookkeeping.
        - The training checkpoint file, if any, is left untouched.

        Args:
            model_path (str): Destination file path, normally ending in '.pth'.

        Raises:
            ValueError: If the model has not been trained yet.
        """
        if not getattr(self, "trained", False):
            raise ValueError("Train the model before saving.")

        os.makedirs(os.path.dirname(model_path) or ".", exist_ok=True)

        # Optimizer hyperparameters are read from the first parameter group.
        optimizer = getattr(self, "optimizer", None)
        opt_group = optimizer.param_groups[0] if optimizer is not None else {}
        opt_type = type(optimizer).__name__ if optimizer is not None else "AdamW"

        sequence_length = int(self.sequence_length) if self.sequence_length is not None else 5

        metadata = {
            "model_type": self.__class__.__name__,
            "version": "1.0",
            # str() keeps the value JSON-serialisable whatever get_device() returns.
            "device": str(get_device()),
            "architecture": {
                "lstm_num_layers": int(self.lstm_num_layers),
                "lstm_num_hidden": int(self.lstm_num_hidden),
                "input_size": int(self.input_size),
                "output_size": int(self.output_size),
                "output_sequence_length": int(self.output_sequence_length),
                "sequence_length": sequence_length,
                "fc_hidden": int(self.linear1.out_features),
                # self.dropout is None when dropout is disabled -> 0.0
                "dropout_p": float(getattr(self.dropout, "p", 0.0)),
            },
            "criterion": type(self.criterion).__name__,
            "optimizer": {
                "type": opt_type,
                "lr": float(opt_group.get("lr", 1e-4)),
                "weight_decay": float(opt_group.get("weight_decay", 0.0)),
            },
            "trained": bool(getattr(self, "trained", False)),
            "best_loss": float(getattr(self, "best_loss", float("inf"))),
            "epochs_trained": int(getattr(self, "epochs_trained", 0)),
            "path": os.path.basename(model_path),
        }

        metadata_path = self._metadata_path(model_path)
        with open(metadata_path, "w", encoding="utf-8") as f:
            json.dump(metadata, f, indent=4)
        print(f"Model metadata saved to {metadata_path}")

        torch.save(self.state_dict(), model_path)
        print(f"Model parameters saved to {model_path}")

    @classmethod
    def load(cls, model_path: str) -> "LSTM":
        """Load a trained model from disk.

        Expects:
            - ``model_path``: a file holding the state dict.
            - ``<stem>_metadata.json``: hyperparameters and config, where
              ``<stem>`` is ``model_path`` without its ``.pth``/``.pt`` extension.

        Returns:
            LSTM: A model in eval mode, rebuilt from the saved architecture
            (including dropout), loss class and optimizer type with the saved
            lr / weight_decay. Loss functions are re-created with their default
            arguments; unknown loss or optimizer names fall back to MSELoss /
            AdamW with a warning.

        Raises:
            FileNotFoundError: If the weights or metadata file is missing.
            ValueError: If the saved model_type does not match this class.
        """
        metadata_path = cls._metadata_path(model_path)
        if not os.path.isfile(metadata_path):
            # Fall back to the naming scheme of earlier versions, which replaced
            # every ".pth"/".pt" occurrence in the path.
            legacy_path = model_path.replace(".pth", "_metadata.json").replace(
                ".pt", "_metadata.json"
            )
            if legacy_path != model_path and os.path.isfile(legacy_path):
                metadata_path = legacy_path
            else:
                raise FileNotFoundError(f"Metadata file not found: {metadata_path}")
        if not os.path.isfile(model_path):
            raise FileNotFoundError(f"Model weights file not found: {model_path}")

        with open(metadata_path, encoding="utf-8") as f:
            metadata = json.load(f)

        if metadata.get("model_type") != cls.__name__:
            raise ValueError(
                f"Metadata type {metadata.get('model_type')} does not match {cls.__name__}"
            )

        arch = metadata["architecture"]
        crit_name = metadata.get("criterion", "MSELoss")
        opt_info = metadata.get("optimizer", {})
        opt_name = opt_info.get("type", "AdamW")
        lr = float(opt_info.get("lr", 1e-4))
        wd = float(opt_info.get("weight_decay", 0.0))

        # Loss + optimizer lookup (extend as needed). Classes are stored so that
        # only the requested loss is instantiated.
        loss_lookup = {
            "MSELoss": torch.nn.MSELoss,
            "L1Loss": torch.nn.L1Loss,
            "HuberLoss": torch.nn.HuberLoss,
            "SmoothL1Loss": torch.nn.SmoothL1Loss,
            "CrossEntropyLoss": torch.nn.CrossEntropyLoss,
            "BCELoss": torch.nn.BCELoss,
            "BCEWithLogitsLoss": torch.nn.BCEWithLogitsLoss,
        }
        optim_lookup = {
            "Adam": optim.Adam,
            "AdamW": optim.AdamW,
            "SGD": optim.SGD,
            "RMSprop": optim.RMSprop,
            "Adagrad": optim.Adagrad,
        }

        crit_cls = loss_lookup.get(crit_name)
        if crit_cls is None:
            warnings.warn(f"Unknown criterion '{crit_name}' in metadata; using MSELoss.")
            crit_cls = torch.nn.MSELoss
        opt_cls = optim_lookup.get(opt_name)
        if opt_cls is None:
            warnings.warn(f"Unknown optimizer '{opt_name}' in metadata; using AdamW.")
            opt_cls = optim.AdamW

        # Re-instantiate the model with the saved hyperparameters.
        model = cls(
            lstm_num_layers=int(arch["lstm_num_layers"]),
            lstm_hidden_size=int(arch["lstm_num_hidden"]),
            input_size=int(arch["input_size"]),
            output_size=int(arch["output_size"]),
            output_sequence_length=int(arch["output_sequence_length"]),
            criterion=crit_cls(),
            optimizer=opt_cls,
            lr=lr,
            wd=wd,
            # Restore dropout so that further training matches the saved setup.
            dropout=float(arch.get("dropout_p", 0.0)),
        )

        # Deserialize on CPU regardless of where the weights were saved;
        # load_state_dict copies the values into the model's own parameters.
        state_dict = torch.load(model_path, map_location="cpu", weights_only=True)
        model.load_state_dict(state_dict)

        # Restore bookkeeping attributes.
        model.trained = bool(metadata.get("trained", True))
        model.best_loss = float(metadata.get("best_loss", float("inf")))
        model.epochs_trained = int(metadata.get("epochs_trained", 0))
        model.sequence_length = int(arch["sequence_length"])

        model.eval()
        return model