"""Custom loss functions for ice-sheet emulator training.
Standard MSE treats all timesteps and sectors equally, but ice sheet
projections have two properties that motivate custom losses:
1. **Extreme-value importance** — large SLE departures (high melt scenarios)
are scientifically most consequential and are underrepresented in the
training set. Weighted losses upweight these samples.
2. **Spatial smoothness** — when predicting full 2-D grid fields (rather than
sector-averaged scalars), neighbouring cells should vary smoothly. The
grid-aware losses add a total-variation regularisation term.
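All classes are ``torch.nn.Module`` subclasses. Most use the standard
``forward(input, target)`` argument order (prediction first); the grid losses
(``WeightedGridLoss`` and ``GridCriterion``) instead take
``forward(true, predicted, ...)`` with the ground truth first.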
Loss classes
------------
WeightedMSELoss:
MSE with per-sample weights that grow with the deviation of the target
from the dataset mean, normalised by the dataset std. Each squared error
is multiplied by ``1 + weight_factor * |z|``, where
``z = (target - data_mean) / data_std``::

    from ise.models.loss import WeightedMSELoss
    criterion = WeightedMSELoss(data_mean=y_mean, data_std=y_std, weight_factor=2.0)
    loss = criterion(predictions, targets)

WeightedMSELossWithSignPenalty:
Extends ``WeightedMSELoss`` with an extra additive penalty when the
predicted sign differs from the true sign. Useful for preventing the
model from predicting sea-level fall when rise is expected.
WeightedMSEPCALoss:
Adds user-supplied per-batch ``custom_weights`` on top of the deviation
weights. Intended for PCA-based training where different principal
components should have different loss contributions.
WeightedGridLoss:
Pixel-wise weighted MSE + total variation regularisation (TVR) for
full 2-D grid predictions. TVR penalises large spatial gradients in
``predicted`` to encourage smooth output fields.
GridCriterion:
Simpler grid loss: pixel-wise MSE + TVR with a fixed ``smoothness_weight``,
without the extreme-value weighting.
WeightedPCALoss:
MSE with fixed per-component weights for PCA coefficient regression. The
first (dominant) principal component can be penalised more heavily.
MSEDeviationLoss:
MSE plus an extra ``penalty_multiplier * MSE`` term applied to samples
whose absolute error exceeds a ``threshold``.
"""
import torch
from ise.utils.functions import get_device
[docs]
class WeightedGridLoss(torch.nn.Module):
    """
    Grid loss combining extreme-value-weighted pixel-wise MSE with a smoothness term.

    The loss is the sum of two components:

    1. **Pixel-wise weighted MSE** - squared errors at cells whose true value
       satisfies ``|true| > extreme_value_threshold`` are multiplied by a
       larger weight; all other cells have weight 1.
    2. **Total variation regularization (TVR)** - the summed absolute
       difference between adjacent cells of the prediction, scaled by
       ``smoothness_weight``.

    Attributes:
        device: Value returned by ``get_device()``. Inputs to ``forward`` are
            moved to this device before the loss is computed.
    """

    def __init__(self):
        super().__init__()
        self.device = get_device()
        self.to(self.device)

    def total_variation_regularization(self, grid):
        """
        Computes the total variation regularization (TVR) loss for spatial smoothness.

        Args:
            grid (Tensor): Either a batch of grids whose dimensions 1 and 2 are
                the two grid axes, or a single 2-D grid, which is treated as a
                batch containing one grid.

        Returns:
            Tensor: Scalar tensor. The absolute differences between neighbouring
            cells along dimensions 1 and 2 are summed per grid and the result is
            averaged over the remaining entries.
        """
        # A lone 2-D grid has no batch axis; add one so the indexing below
        # (dims 1 and 2) applies to it as well.
        if grid.dim() == 2:
            grid = grid.unsqueeze(0)

        horizontal_diff = torch.abs(torch.diff(grid, dim=2))
        vertical_diff = torch.abs(torch.diff(grid, dim=1))
        total_variation = torch.sum(horizontal_diff, dim=(1, 2)) + torch.sum(
            vertical_diff, dim=(1, 2)
        )
        return torch.mean(total_variation)

    def weighted_pixelwise_mse(self, true, predicted, weights):
        """
        Computes the pixel-wise mean squared error (MSE) with custom weights.

        Args:
            true (Tensor): Ground truth values.
            predicted (Tensor): Model predictions.
            weights (Tensor): Multiplicative weight applied to each squared error.

        Returns:
            Tensor: Scalar mean of ``weights * (true - predicted) ** 2``.
        """
        squared_error = (true - predicted) ** 2
        return torch.mean(weights * squared_error)

    def forward(
        self,
        true,
        predicted,
        smoothness_weight=0.001,
        extreme_value_threshold=1e-6,
        extreme_value_weight=10.0,
    ):
        """
        Computes the final weighted loss combining pixel-wise MSE and TVR.

        Note that the ground truth is the FIRST argument here.

        Args:
            true (Tensor or array-like): Ground truth values.
            predicted (Tensor or array-like): Model predictions.
            smoothness_weight (float, optional): Weighting factor for the TVR
                loss. Defaults to 0.001.
            extreme_value_threshold (float or None, optional): Cells with
                ``|true|`` strictly above this value are treated as extreme.
                ``None`` disables the extra weighting. Defaults to 1e-6.
            extreme_value_weight (float, optional): Weight applied to the
                squared error of extreme cells. Defaults to 10.0.

        Returns:
            Tensor: Scalar tensor ``weighted_mse + smoothness_weight * tvr``.
        """
        # Accept tensors as well as anything torch.as_tensor understands, and
        # normalise both inputs to float32 on the loss device.
        true = (
            true.to(self.device).float()
            if isinstance(true, torch.Tensor)
            else torch.as_tensor(true, dtype=torch.float32, device=self.device)
        )
        predicted = (
            predicted.to(self.device).float()
            if isinstance(predicted, torch.Tensor)
            else torch.as_tensor(predicted, dtype=torch.float32, device=self.device)
        )

        if extreme_value_threshold is not None:
            # The mask is derived from the ground truth, not the prediction.
            extreme_mask = torch.abs(true) > extreme_value_threshold
            weights = torch.where(
                extreme_mask,
                torch.full_like(true, extreme_value_weight),
                torch.ones_like(true),
            )
        else:
            weights = torch.ones_like(true)

        pixelwise_mse = self.weighted_pixelwise_mse(true, predicted, weights)
        # Only the prediction is regularised for smoothness.
        tvr = self.total_variation_regularization(predicted)
        return pixelwise_mse + smoothness_weight * tvr
[docs]
class WeightedMSELoss(torch.nn.Module):
    """
    Mean squared error in which targets far from the dataset mean count more.

    Every element's squared error is multiplied by
    ``1 + weight_factor * |target - data_mean| / data_std`` before averaging.

    Attributes:
        device: Value returned by ``get_device()``.
        data_mean (Tensor): Dataset mean, stored as float32 on ``device``.
        data_std (Tensor): Dataset standard deviation, float32 on ``device``.
        weight_factor (Tensor): Strength of the extra weighting, float32 on ``device``.

    Methods:
        - forward: Computes the weighted mean squared error loss.
    """

    def __init__(self, data_mean, data_std, weight_factor=1.0):
        super().__init__()
        self.device = get_device()
        # All three statistics are converted the same way: float32 on the loss device.
        statistics = {
            "data_mean": data_mean,
            "data_std": data_std,
            "weight_factor": weight_factor,
        }
        for attr_name, raw_value in statistics.items():
            setattr(
                self,
                attr_name,
                torch.tensor(raw_value, dtype=torch.float32, device=self.device),
            )
        self.to(self.device)

    def forward(self, input, target):
        """
        Computes the weighted mean squared error.

        Args:
            input (Tensor): Predicted values.
            target (Tensor): Ground truth values.

        Returns:
            Tensor: Scalar mean of the weighted squared errors.
        """
        # Bring both operands onto the device that holds the stored statistics.
        predicted = input.to(self.device)
        observed = target.to(self.device)

        # Absolute deviation of the target from the mean, in units of std.
        abs_z = torch.abs(observed - self.data_mean) / self.data_std
        sample_weights = 1 + (abs_z * self.weight_factor)

        elementwise_sq = torch.nn.functional.mse_loss(predicted, observed, reduction="none")
        return torch.mean(sample_weights * elementwise_sq)
[docs]
class WeightedMSEPCALoss(torch.nn.Module):
    """
    Deviation-weighted MSE with optional user-supplied multiplicative weights.

    Each squared error is multiplied by
    ``1 + weight_factor * |target - data_mean| / data_std`` and, when
    ``custom_weights`` is given, additionally by those weights.
    ``custom_weights`` may have the full input/target shape or any shape that
    broadcasts to it (for example one weight per component for a
    ``(batch, components)`` input).

    Attributes:
        device: Value returned by ``get_device()``.
        data_mean (Tensor): Dataset mean, float32 on ``device``.
        data_std (Tensor): Dataset standard deviation, float32 on ``device``.
        weight_factor (Tensor): Strength of the deviation weighting.
        custom_weights (Tensor or None): Extra weights, or ``None`` to disable.

    Methods:
        - forward: Computes the weighted MSE loss.
    """

    def __init__(self, data_mean, data_std, weight_factor=1.0, custom_weights=None):
        super().__init__()
        self.device = get_device()
        self.to(self.device)
        self.data_mean = torch.tensor(data_mean, dtype=torch.float32, device=self.device)
        self.data_std = torch.tensor(data_std, dtype=torch.float32, device=self.device)
        self.weight_factor = torch.tensor(weight_factor, dtype=torch.float32, device=self.device)
        self.custom_weights = (
            torch.tensor(custom_weights, dtype=torch.float32, device=self.device)
            if custom_weights is not None
            else None
        )

    def forward(self, input, target):
        """
        Computes the weighted mean squared error loss.

        Args:
            input (Tensor): Predicted values.
            target (Tensor): Ground truth values.

        Returns:
            Tensor: Scalar mean of the weighted squared errors.

        Raises:
            ValueError: If ``input`` and ``target`` differ in shape, or if
                ``custom_weights`` cannot be broadcast to that shape.
        """
        input = input.to(self.device)
        target = target.to(self.device)

        if input.shape != target.shape:
            raise ValueError("Input and target must have the same shape.")

        # Absolute deviation of the target from the mean, in units of std.
        deviation = torch.abs(target - self.data_mean)
        normalized_deviation = deviation / self.data_std
        weights = 1 + (normalized_deviation * self.weight_factor)

        if self.custom_weights is not None:
            # Broadcast instead of demanding an identical shape: a strict
            # equality check would reject per-component weights for every
            # batch size other than 1.
            try:
                cw = self.custom_weights.expand_as(weights)
            except RuntimeError as err:
                raise ValueError(
                    "Custom weights shape must match input/target shape."
                ) from err
            weights = weights * cw

        squared_error = (input - target) ** 2
        weighted_squared_error = weights * squared_error
        return torch.mean(weighted_squared_error)
[docs]
class WeightedMSELossWithSignPenalty(torch.nn.Module):
    """
    Deviation-weighted MSE with an extra penalty for predictions of the wrong sign.

    For every element the loss term is
    ``weights * (squared_error + sign_penalty)`` where

    * ``weights = 1 + weight_factor * |target - data_mean| / data_std`` and
    * ``sign_penalty = sign_penalty_factor * |input - target|`` when
      ``torch.sign(input) != torch.sign(target)``, otherwise 0.

    Because ``torch.sign`` maps 0 to 0, an exactly-zero prediction for a
    non-zero target (or vice versa) also counts as a sign mismatch.

    Attributes:
        device: Value returned by ``get_device()``.
        data_mean (Tensor): Dataset mean, float32 on ``device``.
        data_std (Tensor): Dataset standard deviation, float32 on ``device``.
        weight_factor (Tensor): Strength of the deviation weighting.
        sign_penalty_factor (Tensor): Scale of the sign-mismatch penalty.

    Methods:
        - forward: Computes the weighted loss with sign penalties.
    """

    def __init__(self, data_mean, data_std, weight_factor=1.0, sign_penalty_factor=1.0):
        super().__init__()
        self.device = get_device()
        self.data_mean = torch.tensor(data_mean, dtype=torch.float32, device=self.device)
        self.data_std = torch.tensor(data_std, dtype=torch.float32, device=self.device)
        self.weight_factor = torch.tensor(weight_factor, dtype=torch.float32, device=self.device)
        self.sign_penalty_factor = torch.tensor(
            sign_penalty_factor, dtype=torch.float32, device=self.device
        )
        self.to(self.device)

    def forward(self, input, target):
        """
        Computes the weighted MSE loss with an additional sign penalty.

        Args:
            input (Tensor): Predicted values.
            target (Tensor): Ground truth values.

        Returns:
            Tensor: Scalar mean of ``weights * (squared_error + sign_penalty)``.
        """
        # The stored statistics live on self.device; move the operands there so
        # the arithmetic below never mixes devices.
        input = input.to(self.device)
        target = target.to(self.device)

        # Absolute deviation of the target from the mean, in units of std.
        deviation = torch.abs(target - self.data_mean)
        normalized_deviation = deviation / self.data_std
        weights = 1 + (normalized_deviation * self.weight_factor)

        squared_error = torch.nn.functional.mse_loss(input, target, reduction="none")

        # Linear (not squared) penalty on elements whose predicted sign is wrong.
        sign_penalty = torch.where(
            torch.sign(input) != torch.sign(target),
            torch.abs(input - target) * self.sign_penalty_factor,
            torch.zeros_like(input),
        )

        # The deviation weights scale the sign penalty as well as the squared error.
        weighted_squared_error = weights * (squared_error + sign_penalty)
        return torch.mean(weighted_squared_error)
[docs]
class GridCriterion(torch.nn.Module):
    """
    Unweighted grid loss: mean squared error plus a total-variation smoothness term.

    Methods:
        - total_variation_regularization: Computes the smoothness loss.
        - forward: Computes ``mse + smoothness_weight * tvr``.
    """

    def __init__(self):
        super().__init__()

    def total_variation_regularization(self, grid):
        """
        Computes the total variation regularization (TVR) loss.

        Args:
            grid (Tensor): Tensor whose dimensions 1 and 2 are differenced, so it
                needs at least three dimensions - presumably
                ``(batch, rows, cols)``; confirm against the caller.

        Returns:
            Tensor: Scalar tensor. Absolute neighbour differences along dims 2
            and 1 are each summed over dims (1, 2), added together and averaged.
        """
        variation_dim2 = torch.diff(grid, dim=2).abs().sum(dim=(1, 2))
        variation_dim1 = torch.diff(grid, dim=1).abs().sum(dim=(1, 2))
        return (variation_dim2 + variation_dim1).mean()

    def forward(self, true, predicted, smoothness_weight=0.001):
        """
        Computes the final loss by combining pixel-wise MSE and TVR.

        Note that the ground truth is the FIRST argument here.

        Args:
            true (Tensor): Ground truth values.
            predicted (Tensor): Model predictions.
            smoothness_weight (float, optional): Weight for TVR. Defaults to 0.001.

        Returns:
            Tensor: Scalar tensor ``mse + smoothness_weight * tvr``.
        """
        residual = true - predicted
        # Mean over every element, i.e. a single scalar for the whole batch.
        fidelity = torch.mean(torch.abs(residual).pow(2))
        # Smoothness is evaluated on the prediction only.
        roughness = self.total_variation_regularization(predicted)
        return fidelity + smoothness_weight * roughness
[docs]
class WeightedPCALoss(torch.nn.Module):
    """
    Squared-error loss with a fixed multiplicative weight per component.

    Attributes:
        device: Value returned by ``get_device()``.
        component_weights (Tensor): float32 weights on ``device``. A 1-D input is
            stored as a single row ``(1, n)`` so that it broadcasts across the
            leading dimension of the squared error.
        reduction (str): ``'mean'`` or ``'sum'``; any other value leaves the
            weighted errors unreduced.

    Methods:
        - forward: Computes the weighted PCA loss.
    """

    def __init__(self, component_weights, reduction="mean"):
        super().__init__()
        self.device = get_device()
        weights = torch.tensor(component_weights, dtype=torch.float32, device=self.device)
        if weights.dim() == 1:
            # Promote a flat vector to a row vector.
            weights = weights.unsqueeze(0)
        self.component_weights = weights
        self.reduction = reduction
        self.to(self.device)

    def forward(self, input, target):
        """
        Computes the weighted PCA loss.

        Args:
            input (Tensor): Predicted principal components.
            target (Tensor): Actual principal components.

        Returns:
            Tensor: Scalar for ``'mean'`` / ``'sum'`` reduction, otherwise the
            element-wise weighted squared errors.

        Raises:
            ValueError: If ``input`` and ``target`` differ in shape.
        """
        input = input.to(self.device)
        target = target.to(self.device)

        if input.shape != target.shape:
            raise ValueError("Input and target must have the same shape")

        residual_sq = (input - target) ** 2
        weighted = residual_sq * self.component_weights.to(input.device)

        if self.reduction == "mean":
            return torch.mean(weighted)
        if self.reduction == "sum":
            return torch.sum(weighted)
        # Anything else is treated like 'none'.
        return weighted
[docs]
class MSEDeviationLoss(torch.nn.Module):
    """
    MSE plus an extra squared-error penalty on elements with a large absolute error.

    Attributes:
        threshold (float): Absolute error strictly above which the penalty applies.
        penalty_multiplier (float): Factor applied to the squared error of
            penalised elements.

    Methods:
        - forward: Computes the loss with deviation penalties.
    """

    def __init__(self, threshold=1.0, penalty_multiplier=2.0):
        super().__init__()
        self.threshold = threshold
        self.penalty_multiplier = penalty_multiplier

    def forward(self, predictions, targets):
        """
        Computes the MSE loss with an additional deviation penalty.

        Args:
            predictions (Tensor): Predicted values.
            targets (Tensor): Ground truth values.

        Returns:
            Tensor: ``mean(err ** 2) + mean(penalty)``, where ``penalty`` is
            ``penalty_multiplier * err ** 2`` for elements with
            ``|err| > threshold`` and 0 elsewhere.
        """
        residual = predictions - targets
        squared = residual ** 2
        exceeds = torch.abs(residual) > self.threshold
        # Elements within the threshold contribute zero to the penalty average
        # but still count in its denominator.
        zero = torch.tensor(0.0, device=predictions.device)
        penalty = torch.where(exceeds, self.penalty_multiplier * squared, zero)
        return torch.mean(squared) + torch.mean(penalty)