Source code for ise.models.pretrained

"""Pretrained ISEFlow weight management.

Weights are hosted on HuggingFace Hub at ``pvankatwyk/ISEFlow`` and are
downloaded automatically on first use via ``huggingface_hub``.  The downloaded
files are cached in the HuggingFace Hub cache directory
(``~/.cache/huggingface/hub`` by default, or ``$HF_HOME/hub`` when
``HF_HOME`` is set).

During local development, if the HuggingFace download fails (e.g. no internet
access) and local weights exist under ``ise/models/pretrained/ISEFlow/``, the
loader falls back to those local paths transparently.
"""

import os
import sys

from huggingface_hub import snapshot_download, try_to_load_from_cache
from huggingface_hub.utils import (
    are_progress_bars_disabled,
    disable_progress_bars,
    enable_progress_bars,
)

# HuggingFace Hub repository from which every ISEFlow weight snapshot is fetched.
HF_REPO_ID = "pvankatwyk/ISEFlow"

#: The most recent ISEFlow pretrained model version. Distinct from the ise-py package version.
ISEFLOW_LATEST_MODEL_VERSION = "v1.1.0"

# Directory holding this module; bundled fallback weights are looked up
# beneath ``<this directory>/ISEFlow/<version>/...``.
_LOCAL_PRETRAINED_DIR = os.path.dirname(__file__)


def get_model_dir(version: str, ice_sheet: str) -> str:
    """Return the local directory containing weights for a given model version.

    Downloads the weights from HuggingFace Hub if not already cached. Falls
    back to the bundled local path when HF is unavailable, or when the Hub
    snapshot does not contain the requested subfolder, and local weights exist
    (development / air-gapped environments).

    Parameters
    ----------
    version : str
        Model version string, e.g. ``'v1.0.0'`` or ``'v1.1.0'``.
    ice_sheet : str
        Ice sheet identifier — ``'AIS'`` or ``'GrIS'``.

    Returns
    -------
    str
        Absolute path to the directory containing ``deep_ensemble.pth``,
        ``normalizing_flow.pth``, ``scaler_X.pkl``, and ``scaler_y.pkl``.

    Raises
    ------
    RuntimeError
        If the weights can be obtained neither from HuggingFace Hub nor from
        the bundled local fallback directory.
    """
    subfolder = _subfolder(version, ice_sheet)
    local_fallback = os.path.join(_LOCAL_PRETRAINED_DIR, "ISEFlow", subfolder)
    already_cached = _is_cached(subfolder)

    # Progress-bar visibility is process-wide state in huggingface_hub.
    # Remember it so this call does not leave bars disabled (or force-enabled)
    # for every later HuggingFace operation in the user's process.
    bars_were_disabled = are_progress_bars_disabled()
    hub_error = None
    model_dir = None
    try:
        if already_cached:
            # Quiet path: weights already on disk, suppress HF's progress bars.
            disable_progress_bars()
        else:
            enable_progress_bars()
            print(
                f"[ise] Downloading ISEFlow {ice_sheet} {version} weights from "
                f"HuggingFace Hub ({HF_REPO_ID})... (first-time only; cached afterwards)",
                file=sys.stderr,
                flush=True,
            )
        snapshot_dir = snapshot_download(
            repo_id=HF_REPO_ID,
            allow_patterns=[f"{subfolder}/**"],
        )
        model_dir = os.path.join(snapshot_dir, subfolder)
    except Exception as err:
        # Deliberate best-effort boundary: any Hub failure (network, auth,
        # missing cache entry, ...) falls through to the bundled weights below.
        hub_error = err
    finally:
        if bars_were_disabled:
            disable_progress_bars()
        else:
            enable_progress_bars()

    # ``allow_patterns`` that match nothing still yield a valid snapshot dir,
    # so confirm the subfolder actually materialised before returning it.
    if model_dir is not None and os.path.isdir(model_dir):
        if not already_cached:
            print(
                f"[ise] Finished downloading ISEFlow {ice_sheet} {version} weights.",
                file=sys.stderr,
                flush=True,
            )
        return model_dir

    # Fall back to bundled weights (local dev or air-gapped HPC).
    if os.path.isdir(local_fallback):
        return local_fallback

    if hub_error is not None:
        raise RuntimeError(
            f"Could not download weights from HuggingFace Hub ({HF_REPO_ID}) "
            f"and no local fallback found at {local_fallback}. "
            "Install huggingface_hub and ensure internet access, or place "
            "the weights at the local path."
        ) from hub_error
    raise RuntimeError(
        f"HuggingFace Hub repo {HF_REPO_ID} has no weights under '{subfolder}' "
        f"and no local fallback found at {local_fallback}. "
        "Check the requested version and ice_sheet."
    )
def _is_cached(subfolder: str) -> bool:
    """Return True if the canonical weight files for this subfolder are already cached.

    Only the two ``.pth`` files are checked. Any result from
    ``try_to_load_from_cache`` other than a path string is treated as
    "not cached".
    """
    sentinel_files = (
        f"{subfolder}/deep_ensemble.pth",
        f"{subfolder}/normalizing_flow.pth",
    )
    return all(
        isinstance(try_to_load_from_cache(repo_id=HF_REPO_ID, filename=filename), str)
        for filename in sentinel_files
    )


def _subfolder(version: str, ice_sheet: str) -> str:
    """Return the HuggingFace subfolder path for a given version and ice sheet.

    Example: ``_subfolder("v1.1.0", "AIS")`` -> ``"v1.1.0/ISEFlow_AIS_v1-1-0"``.
    """
    tag = version.replace(".", "-")  # e.g. v1.1.0 -> v1-1-0
    return f"{version}/ISEFlow_{ice_sheet}_{tag}"


# ---------------------------------------------------------------------------
# Backward-compat path constants (kept as shims). They point at the bundled
# local directory, which may not exist; use get_model_dir() to resolve or
# download weights from HuggingFace Hub.
# ---------------------------------------------------------------------------
def _lazy_path(version: str, ice_sheet: str) -> str:
    """Return the bundled local weight directory for a version and ice sheet.

    No HuggingFace call is made, so this is safe to evaluate at import time.
    The returned directory is not checked for existence.
    """
    return os.path.join(_LOCAL_PRETRAINED_DIR, "ISEFlow", _subfolder(version, ice_sheet))


ISEFlow_AIS_v1_0_0_path = _lazy_path("v1.0.0", "AIS")
ISEFlow_GrIS_v1_0_0_path = _lazy_path("v1.0.0", "GrIS")
ISEFlow_AIS_v1_1_0_path = _lazy_path("v1.1.0", "AIS")
ISEFlow_GrIS_v1_1_0_path = _lazy_path("v1.1.0", "GrIS")


# ---------------------------------------------------------------------------
# Variable lists (unchanged — these define model feature order)
# ---------------------------------------------------------------------------
ISEFlow_AIS_v1_1_0_variables: list[str] = [
    "year", "sector", "initial_year",
    "numerics_FD", "numerics_FE", "numerics_FE/FV",
    "stress_balance_HO", "stress_balance_Hybrid", "stress_balance_L1L2",
    "stress_balance_SIA_SSA", "stress_balance_SSA", "stress_balance_Stokes",
    "resolution_16", "resolution_20", "resolution_32", "resolution_4",
    "resolution_8", "resolution_variable",
    "init_method_DA", "init_method_DA_geom", "init_method_DA_relax",
    "init_method_Eq", "init_method_SP", "init_method_SP_icethickness",
    "melt_Floating_condition", "melt_No", "melt_Sub-grid",
    "ice_front_Div", "ice_front_Fix", "ice_front_MH", "ice_front_RO", "ice_front_StR",
    "open_melt_param_Lin", "open_melt_param_Nonlocal_Slope", "open_melt_param_PICO",
    "open_melt_param_PICOP", "open_melt_param_Plume", "open_melt_param_Quad",
    "standard_melt_param_Local", "standard_melt_param_Local_anom",
    "standard_melt_param_Nonlocal", "standard_melt_param_Nonlocal_anom",
    "Ocean forcing_Open", "Ocean forcing_Standard",
    "Ocean sensitivity_High", "Ocean sensitivity_Low", "Ocean sensitivity_Medium",
    "Ocean sensitivity_PIGL",
    "Ice shelf fracture_False", "Ice shelf fracture_True",
    "pr_anomaly", "evspsbl_anomaly", "smb_anomaly", "ts_anomaly",
    "thermal_forcing", "salinity", "temperature",
    "pr_anomaly.lag1", "evspsbl_anomaly.lag1", "smb_anomaly.lag1", "ts_anomaly.lag1",
    "thermal_forcing.lag1", "salinity.lag1", "temperature.lag1",
    "pr_anomaly.lag2", "evspsbl_anomaly.lag2", "smb_anomaly.lag2", "ts_anomaly.lag2",
    "thermal_forcing.lag2", "salinity.lag2", "temperature.lag2",
    "pr_anomaly.lag3", "evspsbl_anomaly.lag3", "smb_anomaly.lag3", "ts_anomaly.lag3",
    "thermal_forcing.lag3", "salinity.lag3", "temperature.lag3",
    "pr_anomaly.lag4", "evspsbl_anomaly.lag4", "smb_anomaly.lag4", "ts_anomaly.lag4",
    "thermal_forcing.lag4", "salinity.lag4", "temperature.lag4",
    "pr_anomaly.lag5", "evspsbl_anomaly.lag5", "smb_anomaly.lag5", "ts_anomaly.lag5",
    "thermal_forcing.lag5", "salinity.lag5", "temperature.lag5",
]

ISEFlow_AIS_v1_0_0_variables: list[str] = [
    "year", "sector",
    "pr_anomaly", "evspsbl_anomaly", "mrro_anomaly", "smb_anomaly",
    "ts_anomaly", "thermal_forcing", "salinity", "temperature",
    "pr_anomaly.lag1", "evspsbl_anomaly.lag1", "mrro_anomaly.lag1", "smb_anomaly.lag1",
    "ts_anomaly.lag1", "thermal_forcing.lag1", "salinity.lag1", "temperature.lag1",
    "pr_anomaly.lag2", "evspsbl_anomaly.lag2", "mrro_anomaly.lag2", "smb_anomaly.lag2",
    "ts_anomaly.lag2", "thermal_forcing.lag2", "salinity.lag2", "temperature.lag2",
    "pr_anomaly.lag3", "evspsbl_anomaly.lag3", "mrro_anomaly.lag3", "smb_anomaly.lag3",
    "ts_anomaly.lag3", "thermal_forcing.lag3", "salinity.lag3", "temperature.lag3",
    "pr_anomaly.lag4", "evspsbl_anomaly.lag4", "mrro_anomaly.lag4", "smb_anomaly.lag4",
    "ts_anomaly.lag4", "thermal_forcing.lag4", "salinity.lag4", "temperature.lag4",
    "pr_anomaly.lag5", "evspsbl_anomaly.lag5", "mrro_anomaly.lag5", "smb_anomaly.lag5",
    "ts_anomaly.lag5", "thermal_forcing.lag5", "salinity.lag5", "temperature.lag5",
    "initial_year",
    "numerics_FD", "numerics_FE", "numerics_FE/FV",
    "stress_balance_HO", "stress_balance_Hybrid", "stress_balance_L1L2",
    "stress_balance_SIA_SSA", "stress_balance_SSA", "stress_balance_Stokes",
    "resolution_16", "resolution_20", "resolution_32", "resolution_4",
    "resolution_8", "resolution_variable",
    "init_method_DA", "init_method_DA_geom", "init_method_DA_relax",
    "init_method_Eq", "init_method_SP", "init_method_SP_icethickness",
    "melt_Floating_condition", "melt_No", "melt_Sub-grid",
    "ice_front_Div", "ice_front_Fix", "ice_front_MH", "ice_front_RO", "ice_front_StR",
    "open_melt_param_Lin", "open_melt_param_Nonlocal_Slope", "open_melt_param_PICO",
    "open_melt_param_PICOP", "open_melt_param_Plume", "open_melt_param_Quad",
    "standard_melt_param_Local", "standard_melt_param_Local_anom",
    "standard_melt_param_Nonlocal", "standard_melt_param_Nonlocal_anom",
    "Ocean forcing_Open", "Ocean forcing_Standard",
    "Ocean sensitivity_High", "Ocean sensitivity_Low", "Ocean sensitivity_Medium",
    "Ocean sensitivity_PIGL",
    "Ice shelf fracture_False", "Ice shelf fracture_True",
]

ISEFlow_GrIS_v1_0_0_variables: list[str] = [
    "year", "sector",
    "aSMB", "aST", "thermal_forcing", "basin_runoff",
    "aSMB.lag1", "aST.lag1", "thermal_forcing.lag1", "basin_runoff.lag1",
    "aSMB.lag2", "aST.lag2", "thermal_forcing.lag2", "basin_runoff.lag2",
    "aSMB.lag3", "aST.lag3", "thermal_forcing.lag3", "basin_runoff.lag3",
    "aSMB.lag4", "aST.lag4", "thermal_forcing.lag4", "basin_runoff.lag4",
    "aSMB.lag5", "aST.lag5", "thermal_forcing.lag5", "basin_runoff.lag5",
    "initial_year",
    "numerics_FD", "numerics_FD_FV5", "numerics_FE", "numerics_FV",
    "ice_flow_HO", "ice_flow_HYB", "ice_flow_SIA", "ice_flow_SSA",
    "initialization_CYC_DAI", "initialization_CYC_NDM", "initialization_CYC_NDS",
    "initialization_DAV", "initialization_SP_DAI", "initialization_SP_DAS",
    "initialization_SP_DAV", "initialization_SP_NDM", "initialization_SP_NDS",
    "initial_smb_BOX_MAR", "initial_smb_BOX_RA3", "initial_smb_HIR", "initial_smb_ISMB",
    "initial_smb_MAR", "initial_smb_RA1", "initial_smb_RA3",
    "velocity_J", "velocity_RM",
    "bed_B", "bed_M",
    "surface_thickness_M",
    "ghf_G", "ghf_MIX", "ghf_SR",
    "res_min_0.2", "res_min_0.25", "res_min_0.5", "res_min_0.75", "res_min_0.9",
    "res_min_1.0", "res_min_1.2", "res_min_2.0", "res_min_3.0", "res_min_4.0",
    "res_min_5.0", "res_min_8.0", "res_min_16.0",
    "res_max_0.9", "res_max_2.0", "res_max_4.0", "res_max_4.8", "res_max_5.0",
    "res_max_7.5", "res_max_8.0", "res_max_14.0", "res_max_15.0", "res_max_16.0",
    "res_max_20.0", "res_max_25.0", "res_max_30.0",
    "Ocean forcing_Standard",
    "Ocean sensitivity_High", "Ocean sensitivity_Low", "Ocean sensitivity_Medium",
    "Ice shelf fracture_False",
]

ISEFlow_GrIS_v1_1_0_variables: list[str] = [
    "year", "sector", "initial_year",
    "numerics_FD", "numerics_FD_FV5", "numerics_FE", "numerics_FV",
    "ice_flow_HO", "ice_flow_HYB", "ice_flow_SIA", "ice_flow_SSA",
    "initialization_CYC_DAI", "initialization_CYC_NDM", "initialization_CYC_NDS",
    "initialization_DAV", "initialization_SP_DAI", "initialization_SP_DAS",
    "initialization_SP_DAV", "initialization_SP_NDM", "initialization_SP_NDS",
    "initial_smb_BOX_MAR", "initial_smb_BOX_RA3", "initial_smb_HIR", "initial_smb_ISMB",
    "initial_smb_MAR", "initial_smb_RA1", "initial_smb_RA3",
    "velocity_J", "velocity_RM",
    "bed_B", "bed_M",
    "surface_thickness_M",
    "ghf_G", "ghf_MIX", "ghf_SR",
    "res_min_0.2", "res_min_0.25", "res_min_0.5", "res_min_0.75", "res_min_0.9",
    "res_min_1.0", "res_min_1.2", "res_min_2.0", "res_min_3.0", "res_min_4.0",
    "res_min_5.0", "res_min_8.0", "res_min_16.0",
    "res_max_0.9", "res_max_2.0", "res_max_4.0", "res_max_4.8", "res_max_5.0",
    "res_max_7.5", "res_max_8.0", "res_max_14.0", "res_max_15.0", "res_max_16.0",
    "res_max_20.0", "res_max_25.0", "res_max_30.0",
    "Ocean forcing_Standard",
    "Ocean sensitivity_High", "Ocean sensitivity_Low", "Ocean sensitivity_Medium",
    "Ice shelf fracture_False",
    "aSMB", "aST", "thermal_forcing", "basin_runoff",
    "aSMB.lag1", "aST.lag1", "thermal_forcing.lag1", "basin_runoff.lag1",
    "aSMB.lag2", "aST.lag2", "thermal_forcing.lag2", "basin_runoff.lag2",
    "aSMB.lag3", "aST.lag3", "thermal_forcing.lag3", "basin_runoff.lag3",
    "aSMB.lag4", "aST.lag4", "thermal_forcing.lag4", "basin_runoff.lag4",
    "aSMB.lag5", "aST.lag5", "thermal_forcing.lag5", "basin_runoff.lag5",
]