"""
Optuna-based hyperparameter tuning for CISS-VAE.
This module defines:
- :class:`SearchSpace`: a structured container describing tunable/fixed hyperparameters.
- :func:`autotune`: runs Optuna trials that train CISSVAE models and selects the best trial
by total validation imputation error (MSE + BCE + categorical CE), then retrains a final model with the best settings.
"""
import torch
import optuna
import json
import pandas as pd
import numpy as np
from torch.utils.data import DataLoader
from ciss_vae.classes.vae import CISSVAE
from ciss_vae.classes.cluster_dataset import ClusterDataset
from ciss_vae.training.train_initial import train_vae_initial
from ciss_vae.training.train_refit import impute_and_refit_loop
from ciss_vae.utils.helpers import compute_val_mse
from itertools import combinations, product
import random
import sys
from pathlib import Path
# NEW: Rich imports for track() function
from rich.progress import track
from rich.console import Console
class SearchSpace:
    """Defines tunable and fixed hyperparameter ranges for the Optuna search.

    Parameters are specified as:

    - **scalar**: fixed value (e.g., ``latent_dim=16``)
    - **list**: categorical choice (e.g., ``hidden_dims=[64, 128, 256]``)
    - **tuple**: range ``(min, max)`` for ``suggest_int`` or ``suggest_float``

    The list/tuple distinction is significant: a list is sampled as a set of
    categorical choices, a tuple as a numeric range.

    :param num_hidden_layers: Number of encoder/decoder hidden layers, defaults to (1, 4)
    :type num_hidden_layers: int or list[int] or tuple[int, int], optional
    :param hidden_dims: Hidden dimension specification - int for repeated per layer, list for per-layer choices, tuple for range, defaults to [64, 512]
    :type hidden_dims: int or list[int] or tuple[int, int], optional
    :param latent_dim: Latent dimension size, choices, or range, defaults to [10, 100]
    :type latent_dim: int or list[int] or tuple[int, int], optional
    :param latent_shared: Whether latent space is shared across clusters, defaults to [True, False]
    :type latent_shared: bool or list[bool], optional
    :param output_shared: Whether output layer is shared across clusters, defaults to [True, False]
    :type output_shared: bool or list[bool], optional
    :param lr: Initial learning rate or range, defaults to (1e-4, 1e-3)
    :type lr: float or tuple[float, float], optional
    :param decay_factor: Learning rate exponential decay factor or range, defaults to (0.9, 0.999)
    :type decay_factor: float or tuple[float, float], optional
    :param weight_decay: Weight-decay value (or range) handed to the training routines, defaults to 0.001
    :type weight_decay: float or tuple[float, float], optional
    :param beta: KL divergence weight or range, defaults to 0.01
    :type beta: float or tuple[float, float], optional
    :param num_epochs: Number of epochs for initial training, defaults to 1000
    :type num_epochs: int or tuple[int, int], optional
    :param batch_size: Mini-batch size, defaults to 64
    :type batch_size: int or tuple[int, int], optional
    :param num_shared_encode: Candidate counts of shared encoder layers, defaults to [0, 1, 3]
    :type num_shared_encode: list[int], optional
    :param num_shared_decode: Candidate counts of shared decoder layers, defaults to [0, 1, 3]
    :type num_shared_decode: list[int], optional
    :param encoder_shared_placement: Strategy for arranging shared vs unshared layers in encoder, defaults to ["at_end", "at_start", "alternating", "random"]
    :type encoder_shared_placement: list[str], optional
    :param decoder_shared_placement: Strategy for arranging shared vs unshared layers in decoder, defaults to ["at_end", "at_start", "alternating", "random"]
    :type decoder_shared_placement: list[str], optional
    :param refit_patience: Early-stop patience for refit loops, defaults to 2
    :type refit_patience: int or tuple[int, int], optional
    :param refit_loops: Maximum number of refit loops, defaults to 100
    :type refit_loops: int or tuple[int, int], optional
    :param epochs_per_loop: Number of epochs per refit loop, defaults to 1000
    :type epochs_per_loop: int or tuple[int, int], optional
    :param reset_lr_refit: Whether to reset learning rate before refit, defaults to [True, False]
    :type reset_lr_refit: bool or list[bool], optional
    """

    # Extra top-level key written by save(): names of the fields whose value
    # was a tuple, so load() can restore them (JSON itself has no tuple type).
    _TUPLE_FIELDS_KEY = "_tuple_fields"

    def __init__(self,
                 num_hidden_layers=(1, 4),
                 hidden_dims=[64, 512],
                 latent_dim=[10, 100],
                 latent_shared=[True, False],
                 output_shared=[True, False],
                 lr=(1e-4, 1e-3),
                 decay_factor=(0.9, 0.999),
                 weight_decay=0.001,
                 beta=0.01,
                 num_epochs=1000,
                 batch_size=64,
                 num_shared_encode=[0, 1, 3],
                 num_shared_decode=[0, 1, 3],
                 encoder_shared_placement=["at_end", "at_start", "alternating", "random"],
                 decoder_shared_placement=["at_end", "at_start", "alternating", "random"],
                 refit_patience=2,
                 refit_loops=100,
                 epochs_per_loop=1000,
                 reset_lr_refit=[True, False]):
        def own(value):
            # Copy lists so that neither the shared default objects in the
            # signature nor a caller's list can be mutated through an instance.
            return list(value) if isinstance(value, list) else value

        # Assignment order is the order used by __str__/__repr__/save().
        self.num_hidden_layers = own(num_hidden_layers)
        self.hidden_dims = own(hidden_dims)
        self.latent_dim = own(latent_dim)
        self.latent_shared = own(latent_shared)
        self.output_shared = own(output_shared)
        self.lr = own(lr)
        self.decay_factor = own(decay_factor)
        self.weight_decay = own(weight_decay)
        self.beta = own(beta)
        self.num_epochs = own(num_epochs)
        self.batch_size = own(batch_size)
        self.num_shared_encode = own(num_shared_encode)
        self.num_shared_decode = own(num_shared_decode)
        self.encoder_shared_placement = own(encoder_shared_placement)
        self.decoder_shared_placement = own(decoder_shared_placement)
        self.refit_patience = own(refit_patience)
        self.refit_loops = own(refit_loops)
        self.epochs_per_loop = own(epochs_per_loop)
        self.reset_lr_refit = own(reset_lr_refit)

    def _as_jsonable(self):
        """Return a dict of fields with tuples converted to lists (JSON-safe)."""
        def convert(x):
            if isinstance(x, (tuple, list)):
                return [convert(v) for v in x]
            if isinstance(x, dict):
                return {k: convert(v) for k, v in x.items()}
            return x

        return {k: convert(v) for k, v in self.__dict__.items()}

    def save(self, file_path):
        """Save this search space to a JSON file.

        Every field is written under its own name (tuples become JSON arrays).
        An additional ``_tuple_fields`` entry lists the fields that were tuples
        so that :meth:`load` can restore the range-vs-categorical distinction.

        :param file_path: Path to save file.
        :type file_path: string
        """
        payload = self._as_jsonable()
        payload[self._TUPLE_FIELDS_KEY] = [
            name for name, value in self.__dict__.items() if isinstance(value, tuple)
        ]
        with Path(file_path).open("w", encoding="utf-8") as f:
            json.dump(payload, f, indent=2)

    @classmethod
    def load(cls, file_path):
        """Load a search space from a JSON file and return a new instance.

        Fields named in the file's ``_tuple_fields`` entry are converted back
        to tuples. Files without that entry are loaded with every sequence as
        a list, i.e. as categorical choices.

        :param file_path: Path to saved SearchSpace.
        :type file_path: string
        :return: A new instance built from the stored fields.
        :rtype: SearchSpace
        """
        with Path(file_path).open("r", encoding="utf-8") as f:
            data = json.load(f)
        for name in data.pop(cls._TUPLE_FIELDS_KEY, []):
            if isinstance(data.get(name), list):
                data[name] = tuple(data[name])
        return cls(**data)

    def __str__(self):
        """Readable summary showing which parameters are tunable vs fixed."""
        lines = ["SearchSpace("]
        for k, v in self.__dict__.items():
            flag = "TUNABLE" if isinstance(v, (list, tuple)) else "FIXED"
            lines.append(f" {k}: {v!r} [{flag}]")
        lines.append(")")
        return "\n".join(lines)

    def __repr__(self):
        """Compact representation listing tunable and fixed parameter names."""
        tunables = [k for k, v in self.__dict__.items() if isinstance(v, (list, tuple))]
        fixed = [k for k in self.__dict__ if k not in tunables]
        return f"<SearchSpace tunable={tunables} fixed={fixed}>"
def autotune(
    search_space: SearchSpace,
    train_dataset: ClusterDataset,
    save_model_path=None,
    save_search_space_path=None,
    n_trials=20,
    study_name="vae_autotune",
    device_preference="cuda",
    optuna_dashboard_db=None,
    load_if_exists=True,
    seed=42,
    verbose=False,
    show_progress=False,
    constant_layer_size: bool = False,
    evaluate_all_orders: bool = False,
    max_exhaustive_orders: int = 100,
    return_history: bool = False,
    n_jobs=1,
    debug=False,
):
    r"""Optuna-based hyperparameter search for the CISSVAE model.

    Runs initial training followed by impute-refit loops per trial, selecting the
    trial with the lowest total imputation error (MSE + BCE + categorical CE). The best model is then retrained with
    optimal hyperparameters and returned along with the imputed dataset.

    Side effects visible to the caller: ``UserWarning`` is filtered globally,
    the torch/cuDNN determinism switches are turned on, and
    ``train_dataset.cluster_labels`` is replaced by labels remapped to
    ``0..K-1``.

    :param search_space: Hyperparameter ranges and fixed values for optimization
    :type search_space: SearchSpace
    :param train_dataset: Dataset containing processed inputs, validation masks,
        normalization statistics, cluster labels, and ``activation_groups`` defining
        feature types (continuous, binary, categorical).
    :type train_dataset: ClusterDataset
    :param save_model_path: Optional path; the final model object is written there with ``torch.save``, defaults to None
    :type save_model_path: str, optional
    :param save_search_space_path: Optional path to dump the search-space configuration, defaults to None
    :type save_search_space_path: str, optional
    :param n_trials: Number of Optuna trials to run, defaults to 20
    :type n_trials: int, optional
    :param study_name: Name identifier for the Optuna study, defaults to "vae_autotune"
    :type study_name: str, optional
    :param device_preference: Preferred device ("cuda" or "cpu"), falls back to CPU if CUDA unavailable, defaults to "cuda"
    :type device_preference: str, optional
    :param optuna_dashboard_db: RDB storage URL/file for Optuna dashboard or None for in-memory, defaults to None
    :type optuna_dashboard_db: str, optional
    :param load_if_exists: Whether to load existing study with the same name from storage, defaults to True
    :type load_if_exists: bool, optional
    :param seed: Base random number generator seed for reproducible order generation, defaults to 42
    :type seed: int, optional
    :param verbose: Whether to print detailed diagnostic logs during training, defaults to False
    :type verbose: bool, optional
    :param show_progress: Whether to display Rich progress bars during training, defaults to False
    :type show_progress: bool, optional
    :param constant_layer_size: Whether all hidden layers should use the same dimension size, defaults to False
    :type constant_layer_size: bool, optional
    :param evaluate_all_orders: Whether to permute and evaluate all possible shared/unshared layer orders, defaults to False
    :type evaluate_all_orders: bool, optional
    :param max_exhaustive_orders: Maximum number of layer order permutations to test when evaluate_all_orders is True, defaults to 100
    :type max_exhaustive_orders: int, optional
    :param return_history: Whether to return MSE training history dataframe of the best model, defaults to False
    :type return_history: bool, optional
    :param n_jobs: Number of parallel Optuna jobs. Defaults to 1.
        Note: Values other than 1 may result in non-deterministic behavior
        despite fixed random seeds.
    :type n_jobs: int, optional
    :param debug: Defaults to False. Set True for informative debugging statements.
    :type debug: bool, optional
    :return: Tuple containing (best_imputed_dataframe, best_model, optuna_study_object, results_dataframe, optional[best_model_history_df])
    :rtype: tuple[pandas.DataFrame, CISSVAE, optuna.study.Study, pandas.DataFrame] or tuple[pandas.DataFrame, CISSVAE, optuna.study.Study, pandas.DataFrame, pandas.DataFrame]
    :raises ValueError: If a search space parameter has an unsupported format
    """
    import warnings

    warnings.filterwarnings("ignore", category=UserWarning)

    console = Console()

    # Make torch deterministic (process-wide switches).
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
    torch.use_deterministic_algorithms(True)

    # --------------------------
    # Infer device
    # --------------------------
    if device_preference == "cuda" and not torch.cuda.is_available():
        print("[Warning] CUDA requested but not available. Falling back to CPU.")
        device = torch.device("cpu")
    else:
        device = torch.device(device_preference)

    direction = "minimize"

    # Save search space if asked
    if save_search_space_path is not None:
        search_space.save(save_search_space_path)

    # --------------------------
    # Infer input dim and num clusters
    # --------------------------
    input_dim = train_dataset.raw_data.shape[1]
    num_clusters = len(torch.unique(train_dataset.cluster_labels))

    # --------------------------
    # Remap cluster labels to consecutive integers 0..K-1.
    # The remap is idempotent, so it is done once here rather than inside
    # every trial; this also keeps parallel trials (n_jobs > 1) from
    # rewriting the shared dataset while another trial is reading it.
    # --------------------------
    with torch.no_grad():
        cl = train_dataset.cluster_labels.clone()
        uniq = torch.unique(cl).cpu().tolist()
        remap = {old: i for i, old in enumerate(uniq)}
        new_cl = cl.cpu().apply_(lambda v: remap[int(v)]).to(train_dataset.cluster_labels.device)
        train_dataset.cluster_labels = new_cl

    # --------------------------
    # Helper to sample from fixed, categorical, or range param
    # --------------------------
    def sample_param(trial, name, value):
        """Return ``value`` if it is a scalar, otherwise ask ``trial`` for a suggestion.

        list -> categorical; tuple of ints -> int range; tuple of floats ->
        float range (log-scaled when the lower bound is positive).
        """
        if isinstance(value, (int, float, bool, str)):
            return value
        elif isinstance(value, list):
            return trial.suggest_categorical(name, value)
        elif isinstance(value, tuple):
            if all(isinstance(v, int) for v in value):
                return trial.suggest_int(name, value[0], value[1])
            elif all(isinstance(v, float) for v in value):
                return trial.suggest_float(name, value[0], value[1], log=value[0] > 0)
        raise ValueError(f"Unsupported parameter format for '{name}': {value}")

    # --------------------------
    # Helpers to format order of shared/unshared + control for enumerating or sampling from orders
    # --------------------------
    def _format_order(order_list):
        """['shared','unshared',...] → 'S,U,...' (stable, readable, categorical)"""
        abbrev = {'shared': 'S', 'unshared': 'U'}
        return ",".join(abbrev[x] for x in order_list)

    def _decode_pattern(p: str):
        """'S,U,S' → ['shared','unshared','shared']"""
        m = {'S': 'shared', 'U': 'unshared'}
        return [m[x] for x in str(p).split(",")]

    def _enumerate_orders(n_layers: int, n_shared: int):
        """Deterministically enumerate **all** valid orders (no randomness)."""
        if n_layers < 0 or n_shared < 0 or n_shared > n_layers:
            return []
        patterns = []
        for idxs in combinations(range(n_layers), n_shared):
            arr = ['U'] * n_layers
            for i in idxs:
                arr[i] = 'S'
            patterns.append(",".join(arr))
        return patterns

    def _build_order(style: str, n_layers: int, n_shared: int, rng: random.Random):
        """Build a ['shared'|'unshared', ...] list with ``n_shared`` shared layers placed per ``style``."""
        n_shared = max(0, min(int(n_shared), int(n_layers)))
        shared_positions = list(range(n_layers))
        if style == "at_end":
            pos = list(range(n_layers - n_shared, n_layers))
        elif style == "at_start":
            pos = list(range(0, n_shared))
        elif style == "alternating":
            pos = list(range(0, n_layers, max(1, n_layers // max(1, n_shared))))[:n_shared] if n_shared > 0 else []
        elif style == "random":
            pos = rng.sample(shared_positions, n_shared)
        else:
            pos = list(range(n_layers - n_shared, n_layers))  # fallback: same as "at_end"
        arr = ['unshared'] * n_layers
        for i in pos:
            arr[i] = 'shared'
        return arr

    # ------------------------------------------------
    # Wrappers for train_vae_initial and impute_and_refit_loop that use track().
    # They serve both the per-trial training and the final retraining.
    # ------------------------------------------------
    def _model_of(result):
        """Return the model from a train_vae_initial result (a model, or a (model, history) tuple)."""
        return result[0] if isinstance(result, tuple) else result

    def train_vae_initial_with_progress(model, train_loader, epochs, initial_lr, decay_factor,
                                        weight_decay, beta, device, verbose_inner=False, seed=seed,
                                        description="Initial training", return_history=None):
        """Wrapper for train_vae_initial that uses Rich track() for progress.

        ``return_history`` is forwarded to train_vae_initial only when it is
        not None; any history it returns is dropped and only the model is
        returned.
        """
        shared_kwargs = dict(
            train_loader=train_loader,
            initial_lr=initial_lr,
            decay_factor=decay_factor,
            weight_decay=weight_decay,
            beta=beta,
            device=device,
            seed=seed,
        )
        if return_history is not None:
            shared_kwargs["return_history"] = return_history

        if show_progress:
            # NOTE(review): with the progress bar on, training is driven one
            # epoch per train_vae_initial call. This assumes epochs=1 calls
            # compose like a single multi-epoch call (optimizer/LR-decay state,
            # seeding) - TODO confirm against train_vae_initial.
            for _ in track(range(epochs), description=description):
                model = _model_of(train_vae_initial(
                    model=model,
                    epochs=1,
                    verbose=False,  # Disable verbose to avoid spam
                    **shared_kwargs,
                ))
        else:
            model = _model_of(train_vae_initial(
                model=model,
                epochs=epochs,
                verbose=verbose_inner,
                **shared_kwargs,
            ))
        return model

    def impute_and_refit_loop_with_progress(model, train_loader, max_loops, patience, epochs_per_loop,
                                            initial_lr, decay_factor, weight_decay, beta, device,
                                            verbose_inner=False, batch_size=64, seed=seed,
                                            description="Refit loops"):
        """Wrapper for impute_and_refit_loop that uses Rich track() for progress"""
        refit_kwargs = dict(
            model=model,
            train_loader=train_loader,
            max_loops=max_loops,
            patience=patience,
            epochs_per_loop=epochs_per_loop,
            initial_lr=initial_lr,
            decay_factor=decay_factor,
            weight_decay=weight_decay,
            beta=beta,
            device=device,
            verbose=verbose_inner,
            batch_size=batch_size,
            seed=seed,
        )
        if not show_progress:
            return impute_and_refit_loop(**refit_kwargs)

        # The bar is sized for the worst case (every loop runs to completion).
        progress_iter = iter(track(range(max_loops * epochs_per_loop), description=description))

        def progress_callback(n=1):
            try:
                for _ in range(n):
                    next(progress_iter)
            except StopIteration:
                pass  # Progress bar completed

        try:
            return impute_and_refit_loop(progress_epoch=progress_callback, **refit_kwargs)
        finally:
            # Refitting may stop before the bar is exhausted; close the
            # iterator explicitly so the bar is finalized before the next one.
            close = getattr(progress_iter, "close", None)
            if close is not None:
                close()

    # --------------------------
    # Create Optuna objective
    # --------------------------
    def objective(trial):
        """Runs initial training followed by impute-refit loops per trial, selecting the
        trial with the lowest total imputation error (MSE + BCE + categorical CE)."""
        # Seed every random generator from the base seed and the trial number
        trial_seed = seed * 10_000 + trial.number
        random.seed(trial_seed)
        np.random.seed(trial_seed)
        torch.manual_seed(trial_seed)
        torch.cuda.manual_seed_all(trial_seed)

        if show_progress:
            console.print(f"\n[green]Trial {trial.number}/{n_trials}")
        elif verbose:
            print(f"\nStarting Trial {trial.number}/{n_trials}")

        # --------------------------
        # Parse Parameters
        # --------------------------
        num_hidden_layers = sample_param(trial, "num_hidden_layers", search_space.num_hidden_layers)

        # ---- Hidden dimensions ----
        if constant_layer_size:
            base_dim = sample_param(trial, "hidden_dim_constant", search_space.hidden_dims)
            hidden_dims = [base_dim] * num_hidden_layers
        else:
            hidden_dims = [
                sample_param(trial, f"hidden_dim_{i}", search_space.hidden_dims)
                for i in range(num_hidden_layers)
            ]

        latent_dim = sample_param(trial, "latent_dim", search_space.latent_dim)
        latent_shared = sample_param(trial, "latent_shared", search_space.latent_shared)
        output_shared = sample_param(trial, "output_shared", search_space.output_shared)
        learning_rate = sample_param(trial, "lr", search_space.lr)
        decay_factor = sample_param(trial, "decay_factor", search_space.decay_factor)
        weight_decay = sample_param(trial, "weight_decay", search_space.weight_decay)
        beta = sample_param(trial, "beta", search_space.beta)
        num_epochs = sample_param(trial, "num_epochs", search_space.num_epochs)
        batch_size = sample_param(trial, "batch_size", search_space.batch_size)

        # Shared-layer counts are clamped to the number of hidden layers
        nse_raw = sample_param(trial, "num_shared_encode", search_space.num_shared_encode)
        nsd_raw = sample_param(trial, "num_shared_decode", search_space.num_shared_decode)
        num_shared_encode = max(0, min(int(nse_raw), int(num_hidden_layers)))
        num_shared_decode = max(0, min(int(nsd_raw), int(num_hidden_layers)))

        # Go through sample_param so that a single fixed strategy string is
        # used as-is instead of being handed to suggest_categorical as the
        # list of choices.
        encoder_shared_placement = sample_param(trial, "encoder_shared_placement", search_space.encoder_shared_placement)
        decoder_shared_placement = sample_param(trial, "decoder_shared_placement", search_space.decoder_shared_placement)

        refit_patience = sample_param(trial, "refit_patience", search_space.refit_patience)
        refit_loops = sample_param(trial, "refit_loops", search_space.refit_loops)
        epochs_per_loop = sample_param(trial, "epochs_per_loop", search_space.epochs_per_loop)
        reset_lr_refit = sample_param(trial, "reset_lr_refit", search_space.reset_lr_refit)

        trial.set_user_attr("num_shared_encode_effective", int(num_shared_encode))
        trial.set_user_attr("num_shared_decode_effective", int(num_shared_decode))

        # None is passed to the refit loop when the learning rate is not reset
        lr_refit = learning_rate if reset_lr_refit else None

        if show_progress:
            console.print(f"Parameters: layers={num_hidden_layers}, latent_dim={latent_dim}, lr={learning_rate:.2e}, batch_size={batch_size}")
        elif verbose:
            print(f" Parameters: layers={num_hidden_layers}, latent_dim={latent_dim}, lr={learning_rate:.2e}")

        # --------------------------
        # Build orders for shared/unshared layers
        # --------------------------
        if evaluate_all_orders:
            enc_pool = _enumerate_orders(num_hidden_layers, num_shared_encode)
            dec_pool = _enumerate_orders(num_hidden_layers, num_shared_decode)
            combos_to_eval = list(product(enc_pool, dec_pool))
            # Cap the number of evaluated combinations with a reproducible sample
            if len(combos_to_eval) > max_exhaustive_orders:
                rng_ord = random.Random(seed * 9912 + trial.number)
                combos_to_eval = rng_ord.sample(combos_to_eval, k=max_exhaustive_orders)
        else:
            # Reproducible generators for the "random" placement style
            rng_enc = random.Random(seed * 9973 + trial.number)
            rng_dec = random.Random(seed * 9967 + trial.number)
            layer_order_enc = _build_order(encoder_shared_placement, num_hidden_layers, num_shared_encode, rng_enc)
            layer_order_dec = _build_order(decoder_shared_placement, num_hidden_layers, num_shared_decode, rng_dec)
            combos_to_eval = [(_format_order(layer_order_enc), _format_order(layer_order_dec))]

        best_val = None
        best_patterns = None
        best_refit_history_df = None

        # --------------------------
        # Train each combination
        # --------------------------
        for enc_pat, dec_pat in combos_to_eval:
            layer_order_enc = _decode_pattern(enc_pat)
            layer_order_dec = _decode_pattern(dec_pat)

            # Seeded data loader
            g = torch.Generator()
            g.manual_seed(trial_seed)
            train_loader = DataLoader(
                train_dataset,
                batch_size=batch_size,
                shuffle=True,
                generator=g,
            )

            model = CISSVAE(
                input_dim=input_dim,
                hidden_dims=hidden_dims,
                layer_order_enc=layer_order_enc,
                layer_order_dec=layer_order_dec,
                latent_shared=latent_shared,
                num_clusters=num_clusters,
                latent_dim=latent_dim,
                output_shared=output_shared,
                activation_groups=train_dataset.activation_groups,
                debug=debug,
            ).to(device)

            model = train_vae_initial_with_progress(
                model=model,
                train_loader=train_loader,
                epochs=num_epochs,
                initial_lr=learning_rate,
                decay_factor=decay_factor,
                weight_decay=weight_decay,
                beta=beta,
                device=device,
                verbose_inner=verbose,
                seed=seed,
            )

            _, model, _ = impute_and_refit_loop_with_progress(
                model=model,
                train_loader=train_loader,
                max_loops=refit_loops,
                patience=refit_patience,
                epochs_per_loop=epochs_per_loop,
                initial_lr=lr_refit,
                decay_factor=decay_factor,
                weight_decay=weight_decay,
                beta=beta,
                device=device,
                verbose_inner=verbose,
                batch_size=batch_size,
                seed=seed,
            )

            # Validation errors; the first value is the objective being minimized
            imputation_error, val_mse, val_bce, val_ce = compute_val_mse(model, train_loader.dataset, device)

            if (best_val is None) or (imputation_error < best_val):
                best_val = imputation_error
                best_patterns = (enc_pat, dec_pat)
                best_refit_history_df = model.training_history_
                trial.set_user_attr("best_val_mse", val_mse)
                trial.set_user_attr("best_val_bce", val_bce)
                trial.set_user_attr("best_val_ce", val_ce)

        if show_progress:
            console.print(f"✓ Trial {trial.number + 1} complete - Total Imputation Error: {best_val:.4f}")
        elif verbose:
            print(f" Trial {trial.number + 1} complete - Total Imputation Error: {best_val:.6f}")

        # Report the best combination's error history to Optuna
        if best_refit_history_df is not None and "imputation_error" in best_refit_history_df.columns:
            for i, v in enumerate(best_refit_history_df["imputation_error"]):
                if pd.notna(v):
                    trial.report(float(v), step=i)

        # Record the chosen best patterns for this trial (read back for the final model)
        if best_patterns is not None:
            trial.set_user_attr("best_layer_order_enc", best_patterns[0])
            trial.set_user_attr("best_layer_order_dec", best_patterns[1])

        return best_val

    # -----------------------
    # Optuna study setup (seeded sampler)
    # -----------------------
    sampler = optuna.samplers.TPESampler(seed=seed)
    study = optuna.create_study(
        direction=direction,
        study_name=study_name,
        storage=optuna_dashboard_db,
        load_if_exists=load_if_exists,
        sampler=sampler,
    )
    study.set_metric_names(["Total Imputation Error"])

    # -----------------------
    # Run optimization
    # -----------------------
    if show_progress:
        console.print(f"[bold blue]Starting Optuna optimization with {n_trials} trials...")
    else:
        print(f"Starting Optuna optimization with {n_trials} trials...")

    study.optimize(objective, n_trials=n_trials, n_jobs=n_jobs)

    if show_progress:
        console.print("\n[bold green]✓ Optimization complete!")
        console.print(f"Best trial: {study.best_trial.number} (Total Imputation Error: {study.best_value:.6f})")
    else:
        print(f"Optimization complete. Best trial: {study.best_trial.number} (Total Imputation Error: {study.best_value:.6f})")

    # -----------------------
    # Final model training
    # -----------------------
    if show_progress:
        console.print("\n[bold cyan]Training final model with best parameters...")
    else:
        print("Training final model with best parameters...")

    # ---------------------
    # Extract best params for final model training
    # ---------------------
    best_params = study.best_trial.params

    def get_best_param(name):
        """Sampled value from the best trial, else the fixed value stored on the search space."""
        if name in best_params:
            return best_params[name]
        return getattr(search_space, name)

    best_num_hidden_layers = get_best_param("num_hidden_layers")

    if constant_layer_size:
        if "hidden_dim_constant" in best_params:
            base_dim = best_params["hidden_dim_constant"]
        else:
            base_dim = getattr(search_space, "hidden_dims", 64)
            if isinstance(base_dim, (list, tuple)):
                base_dim = base_dim[0]
        best_hidden_dims = [int(base_dim)] * best_num_hidden_layers
    elif "hidden_dim_0" in best_params:
        best_hidden_dims = [best_params[f"hidden_dim_{i}"] for i in range(best_num_hidden_layers)]
    else:
        hdims = get_best_param("hidden_dims")
        if isinstance(hdims, list):
            if len(hdims) == 1:
                best_hidden_dims = hdims * best_num_hidden_layers
            elif len(hdims) < best_num_hidden_layers:
                best_hidden_dims = (hdims * best_num_hidden_layers)[:best_num_hidden_layers]
            else:
                best_hidden_dims = hdims[:best_num_hidden_layers]
        else:
            best_hidden_dims = [hdims] * best_num_hidden_layers

    # Layer orders recorded by the best trial
    enc_pat = study.best_trial.user_attrs.get("best_layer_order_enc")
    dec_pat = study.best_trial.user_attrs.get("best_layer_order_dec")
    best_layer_order_enc = _decode_pattern(enc_pat)
    best_layer_order_dec = _decode_pattern(dec_pat)

    latent_shared = bool(get_best_param("latent_shared"))
    output_shared = bool(get_best_param("output_shared"))
    latent_dim = int(get_best_param("latent_dim"))
    num_epochs = int(get_best_param("num_epochs"))
    lr = float(get_best_param("lr"))
    decay_factor = float(get_best_param("decay_factor"))
    beta = float(get_best_param("beta"))
    batch_size = int(get_best_param("batch_size"))
    refit_patience = int(get_best_param("refit_patience"))
    refit_loops = int(get_best_param("refit_loops"))
    epochs_per_loop = int(get_best_param("epochs_per_loop"))
    reset_lr_refit = bool(get_best_param("reset_lr_refit"))
    weight_decay = float(get_best_param("weight_decay"))

    # Mirror the trial: the refit loop gets the learning rate only when the
    # best trial reset it, otherwise None (exactly what the objective passes).
    lr_refit = lr if reset_lr_refit else None

    # ---------------------------
    # Build & train final model
    # ---------------------------
    g_final = torch.Generator()
    g_final.manual_seed(seed)
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, generator=g_final)

    best_model = CISSVAE(
        input_dim=input_dim,
        hidden_dims=best_hidden_dims,
        layer_order_enc=best_layer_order_enc,
        layer_order_dec=best_layer_order_dec,
        latent_shared=latent_shared,
        num_clusters=num_clusters,
        latent_dim=latent_dim,
        output_shared=output_shared,
        activation_groups=train_dataset.activation_groups,
        debug=debug,
    ).to(device)

    best_model = train_vae_initial_with_progress(
        model=best_model,
        train_loader=train_loader,
        epochs=num_epochs,
        initial_lr=lr,
        decay_factor=decay_factor,
        weight_decay=weight_decay,
        beta=beta,
        device=device,
        verbose_inner=verbose,
        seed=seed,
        description="Final initial training",
        return_history=return_history,
    )

    best_imputed_df, best_model, _ = impute_and_refit_loop_with_progress(
        model=best_model,
        train_loader=train_loader,
        max_loops=refit_loops,
        patience=refit_patience,
        epochs_per_loop=epochs_per_loop,
        initial_lr=lr_refit,
        decay_factor=decay_factor,
        weight_decay=weight_decay,
        beta=beta,
        device=device,
        verbose_inner=verbose,
        batch_size=batch_size,
        seed=seed,
        description="Final refit loops",
    )

    # The returned history is whatever the model recorded on itself
    final_model_history = best_model.training_history_ if return_history else None

    if show_progress:
        console.print("[bold green]✓ Final model training complete!")
    else:
        print("Final model training complete.")

    # -----------------------
    # Save results
    # -----------------------
    if save_model_path:
        # Saves the full model object, not just its state_dict
        torch.save(best_model, save_model_path)
        print(f"Model saved to {save_model_path}")

    # One row per trial: objective value, sampled params, and the layer orders used
    rows = []
    for t in study.trials:
        row = {"trial_number": t.number, "imputation_error": t.value, **t.params}
        row["layer_order_enc_used"] = (
            t.params.get("layer_order_enc")
            or t.user_attrs.get("best_layer_order_enc")
            or t.user_attrs.get("layer_order_enc_used")
        )
        row["layer_order_dec_used"] = (
            t.params.get("layer_order_dec")
            or t.user_attrs.get("best_layer_order_dec")
            or t.user_attrs.get("layer_order_dec_used")
        )
        rows.append(row)
    results_df = pd.DataFrame(rows)

    if return_history:
        return best_imputed_df, best_model, study, results_df, final_model_history
    return best_imputed_df, best_model, study, results_df