Source code for ciss_vae.training.train_initial

import torch
import pandas as pd
import torch.optim as optim
from torch.optim import lr_scheduler
from ciss_vae.utils.loss import loss_function
import torch.nn.functional as F
from ciss_vae.utils.helpers import compute_val_mse, get_imputed_df

[docs] def train_vae_initial( model, train_loader, epochs: int = 10, initial_lr: float = 0.01, decay_factor: float = 0.999, beta: float = 0.1, device: str = "cpu", verbose: bool = False, *, return_history: bool = False, progress_callback = None, weight_decay = 0.001, seed = 42, ): """Train a VAE on masked data with validation monitoring for initial training phase. Performs the initial training of a CISSVAE model on data with missing values using masked loss computation. Tracks training loss and validation MSE across epochs, with optional progress reporting and learning rate decay. :param model: CISSVAE or compatible VAE model that implements forward(x, cluster_id) :type model: torch.nn.Module :param train_loader: DataLoader built on ClusterDataset containing validation data :type train_loader: torch.utils.data.DataLoader :param epochs: Number of training epochs, defaults to 10 :type epochs: int, optional :param initial_lr: Starting learning rate for Adam optimizer, defaults to 0.01 :type initial_lr: float, optional :param decay_factor: Exponential learning rate decay factor applied per epoch, defaults to 0.999 :type decay_factor: float, optional :param beta: Weight coefficient for KL divergence term in VAE loss, defaults to 0.1 :type beta: float, optional :param device: Device for training computations ("cpu" or "cuda"), defaults to "cpu" :type device: str, optional :param verbose: Whether to print per-epoch training metrics, defaults to False :type verbose: bool, optional :param return_history: Whether to return training history DataFrame along with model, defaults to False :type return_history: bool, optional :param progress_callback: Optional callback function for progress reporting, defaults to None :type progress_callback: callable, optional :return: Trained model, or tuple of (model, history_dataframe) if return_history=True :rtype: torch.nn.Module or tuple[torch.nn.Module, pandas.DataFrame] :raises ValueError: If dataset does not contain 'val_data' attribute for validation """ model.to(device) optimizer = optim.Adam(model.parameters(), lr=initial_lr, weight_decay=weight_decay) scheduler = lr_scheduler.ExponentialLR(optimizer, gamma=decay_factor) g = torch.Generator(device=device) g.manual_seed(seed) dataset = train_loader.dataset if not hasattr(dataset, "val_data"): raise ValueError("Dataset must include 'val_data'.") def _to_scalar(x): return x.detach().cpu().item() if torch.is_tensor(x) else x history = { "epoch": [], "train_loss": [], "train_mse": [], "train_bce": [], "train_ce": [], "imputation_error": [], "val_mse": [], "val_bce": [], "val_ce": [], "lr": [], } for epoch in range(epochs): model.train() total_loss = 0 for batch in train_loader: x_batch, cluster_batch, mask_batch, idx_batch = batch x_batch = x_batch.to(device) cluster_batch = cluster_batch.to(device) mask_batch = mask_batch.to(device) if hasattr(dataset, "imputable") and dataset.imputable is not None: imputable_batch = dataset.imputable[idx_batch].to(device).float() else: imputable_batch = None # ------------------------- # Forward (LOGITS) # ------------------------- recon_x, mu, logvar = model(x_batch, cluster_batch, generator=g) # ------------------------- # UPDATED LOSS CALL # ------------------------- loss, train_mse, train_bce, train_ce = loss_function( cluster_batch, mask_batch, recon_x, x_batch, dataset.activation_groups, # 🔥 FIXED mu, logvar, beta=beta, return_components=True, device=device, debug=model.debug, ) optimizer.zero_grad() loss.backward() optimizer.step() total_loss += loss.item() avg_train_loss = total_loss / len(train_loader.dataset) current_lr = optimizer.param_groups[0]["lr"] # ------------------------- # VALIDATION # ------------------------- try: imputation_error, val_mse, val_bce, val_ce = compute_val_mse( model, dataset, device ) except ValueError as e: if verbose: print(f"[WARNING] Epoch {epoch+1}: {e}") val_mse, val_bce, val_ce = float("inf"), float("inf"), float("inf") imputation_error = float("inf") # ------------------------- # LOGGING # ------------------------- history["epoch"].append(epoch) history["train_loss"].append(avg_train_loss) history["train_mse"].append(_to_scalar(train_mse)) history["train_bce"].append(_to_scalar(train_bce)) history["train_ce"].append(_to_scalar(train_ce)) history["imputation_error"].append(_to_scalar(imputation_error)) history["val_mse"].append(_to_scalar(val_mse)) history["val_bce"].append(_to_scalar(val_bce)) history["val_ce"].append(_to_scalar(val_ce)) history["lr"].append(current_lr) if verbose: print( f"Epoch {epoch:3d} | " f"Train Loss: {avg_train_loss:.6f} | " f"MSE: {train_mse:.4f} BCE: {train_bce:.4f} CE: {train_ce:.4f} | " f"Val MSE: {val_mse:.4f} BCE: {val_bce:.4f} CE: {val_ce:.4f} | " f"LR: {current_lr:.6f}" ) if progress_callback is not None: try: progress_callback(n=1) except Exception: pass scheduler.step() model.set_final_lr(optimizer.param_groups[0]["lr"]) history_df = pd.DataFrame( history, columns=[ "epoch", "train_loss", "train_mse", "train_bce", "train_ce", "imputation_error", "val_mse", "val_bce", "val_ce", "lr", ], ) model.training_history_ = history_df return (model, history_df) if return_history else model