Source code for ise.data.dataclasses

"""PyTorch Dataset classes for ISEFlow training and inference.

This module provides four ``torch.utils.data.Dataset`` subclasses for loading
ice-sheet emulator data.  The default for ISEFlow is ``EmulatorDataset``,
which handles the 86-timestep projection structure and the sequence padding
needed by the LSTM members of ``DeepEnsemble``.

Dataset classes
---------------
EmulatorDataset (default for ISEFlow):
    Wraps a flat ``(N_projections * 86, features)`` or batched
    ``(N_projections, 86, features)`` feature matrix.  ``__getitem__``
    returns a zero-padded sliding window of ``sequence_length`` timesteps so
    that the LSTM always receives a fixed-length context window even at the
    start of a projection.  Used by both ``LSTM.fit()`` and
    ``NormalizingFlow.fit()``::

        from ise.data.dataclasses import EmulatorDataset
        from torch.utils.data import DataLoader

        ds = EmulatorDataset(X, y, sequence_length=5, projection_length=86)
        loader = DataLoader(ds, batch_size=64, shuffle=True)

PyTorchDataset:
    Minimal ``(X[i], y[i])`` pair dataset with no sequence logic.  Used when
    data is already structured as individual feature vectors (e.g. for the
    normalizing flow, which uses ``sequence_length=1``).

TSDataset:
    Similar to ``EmulatorDataset`` but expects pre-batched 3-D tensors
    ``(N, T, F)``.  Kept for backward compatibility.

ScenarioDataset:
    Simple ``(features[idx], labels[idx])`` pair dataset used in the
    experimental scenario-classification models.

Padding convention
------------------
All sequence-aware datasets pad *at the beginning* of each projection with
the zero vector so that the most recent timestep is always at index ``-1``
of the returned sequence.  This means the LSTM sees a causal context that
grows from zero padding at t=1 to a full ``sequence_length`` window by
t=``sequence_length``.
"""

import warnings

import numpy as np
import pandas as pd
import torch
from torch.utils.data import Dataset


[docs] class EmulatorDataset(Dataset): """ A PyTorch dataset for loading emulator data, designed to handle sequence-based inputs and projections. Args: X (pandas.DataFrame, numpy.ndarray, or torch.Tensor): The input data. y (pandas.DataFrame, numpy.ndarray, or torch.Tensor): The target data. sequence_length (int, optional): The length of the input sequence. Default is 5. projection_length (int or tuple, optional): The length of the projection period. Default is 86. Attributes: X (torch.Tensor): The input data converted to a PyTorch tensor. y (torch.Tensor): The target data converted to a PyTorch tensor. sequence_length (int): The length of the input sequence. xdim (int): The number of dimensions in X. num_projections (int): The number of projections in the dataset. num_timesteps (int): The number of timesteps per projection. num_features (int): The number of features in the dataset. Methods: _to_tensor(x): Converts input data to a PyTorch tensor. __len__(): Returns the total number of samples. __getitem__(i): Retrieves the i-th sample from the dataset, including proper padding. """ def __init__(self, X, y, sequence_length=5, projection_length=86): super().__init__() if isinstance(projection_length, tuple): if len(projection_length) == 1: projection_length = projection_length[0] else: raise ValueError( "Projection length must be a single integer or a tuple of two integers." ) if X.shape[0] < projection_length and projection_length == 86: warnings.warn( f"Full projections of {projection_length} timesteps are not present in the dataset. This may lead to unexpected behavior." ) self.X = self._to_tensor(X) self.y = self._to_tensor(y) self.sequence_length = sequence_length self.xdim = len(X.shape) if self.xdim == 3: # Batched by projection self.num_projections, self.num_timesteps, self.num_features = X.shape elif self.xdim == 2: # Unbatched (rows of projections*timestamps) self.projections_and_timesteps, _ = X.shape self.num_timesteps = projection_length self.num_projections = self.projections_and_timesteps // self.num_timesteps self.num_features = X.shape[1] self.features = self.num_features
[docs] def _to_tensor(self, x): """ Converts input data to a PyTorch tensor of type float. Args: x (pandas.DataFrame, numpy.ndarray, or torch.Tensor): The input data. Returns: torch.Tensor: A PyTorch tensor of type float. """ if x is None: return None if isinstance(x, pd.DataFrame): x = torch.tensor(x.values) elif isinstance(x, np.ndarray): x = torch.tensor(x) elif isinstance(x, torch.Tensor): pass else: raise ValueError("Data must be a pandas dataframe, numpy array, or PyTorch tensor") return x.float()
[docs] def __len__(self): """ Returns the total number of samples in the dataset. Returns: int: The dataset length. """ if self.xdim == 2: return self.X.shape[0] else: return self.X.shape[0] * self.X.shape[1]
[docs] def __getitem__(self, i): """ Retrieves the i-th sample from the dataset, applying padding if necessary. Args: i (int): The index of the item to retrieve. Returns: tuple: A tuple containing the input sequence and corresponding target value (if available). """ # Calculate projection index and timestep index projection_index = i // self.num_timesteps time_step_index = i % self.num_timesteps # Initialize a sequence with zeros for padding sequence = torch.zeros((self.sequence_length, self.num_features)) # Calculate start and end points for copying data start_point = max(0, time_step_index - self.sequence_length + 1) end_point = time_step_index + 1 length_of_data = end_point - start_point # Copy the data from the dataset to the end of the sequence to preserve recent data at the end if self.xdim == 3: sequence[-length_of_data:] = self.X[projection_index, start_point:end_point] elif self.xdim == 2: sequence[-length_of_data:] = self.X[ projection_index * self.num_timesteps + start_point : projection_index * self.num_timesteps + end_point ] if self.y is None: return sequence return sequence, self.y[i]
[docs] class PyTorchDataset(Dataset): """ A PyTorch dataset for general-purpose data loading. Args: X (torch.Tensor): The input data. y (torch.Tensor): The target data. Methods: __getitem__(index): Retrieves the sample at the specified index. __len__(): Returns the total dataset length. """ def __init__(self, X, y): self.X_data = X self.y_data = y
[docs] def __getitem__(self, index): """ Retrieves the sample at the specified index. Args: index (int): The index of the sample. Returns: tuple: The input data and corresponding target (if available). """ if self.y_data is None: return self.X_data[index] return self.X_data[index], self.y_data[index]
[docs] def __len__(self): """ Returns the total number of samples. Returns: int: The dataset length. """ return len(self.X_data)
[docs] class TSDataset(Dataset): """ A PyTorch dataset for handling time series data with sequence-based input. Args: X (torch.Tensor): The input data. y (torch.Tensor): The target data. sequence_length (int, optional): The length of the input sequence. Default is 5. Attributes: X (torch.Tensor): The input data. y (torch.Tensor): The target data. sequence_length (int): The sequence length. Methods: __len__(): Returns the dataset length. __getitem__(i): Retrieves the i-th time series sample. """ def __init__(self, X, y, sequence_length=5): super().__init__() self.X = X self.y = y self.sequence_length = sequence_length
[docs] def __len__(self): """ Returns the length of the dataset. Returns: int: The dataset length. """ return len(self.X)
[docs] def __getitem__(self, i): """ Retrieves the i-th sample, applying padding if needed. Args: i (int): The index of the sample. Returns: tuple: A tuple containing the input sequence and corresponding target value (if available). """ if i >= self.sequence_length - 1: i_start = i - self.sequence_length + 1 x = self.X[i_start : (i + 1), :] else: padding = self.X[0].repeat(self.sequence_length - i - 1, 1) x = self.X[0 : (i + 1), :] x = torch.cat((padding, x), 0) if self.y is None: return x return x, self.y[i]
[docs] class ScenarioDataset(Dataset): """ A PyTorch dataset designed for scenario-based data loading. Args: features (torch.Tensor): The input features. labels (torch.Tensor): The target labels. Attributes: features (torch.Tensor): The input features. labels (torch.Tensor): The target labels. Methods: __len__(): Returns the dataset length. __getitem__(idx): Retrieves the sample at the given index. """ def __init__(self, features, labels): self.features = features self.labels = labels
[docs] def __len__(self): """ Returns the total number of samples. Returns: int: The dataset length. """ return len(self.features)
[docs] def __getitem__(self, idx): """ Retrieves the sample at the given index. Args: idx (int): The index of the sample. Returns: tuple: A tuple containing the input features and corresponding labels. """ return self.features[idx], self.labels[idx]