ciss_vae.classes.vae.CISSVAE
- class CISSVAE(*args: Any, **kwargs: Any)
Bases: Module
Clustering-Informed Shared-Structure Variational Autoencoder (CISSVAE).
Supports flexible mixtures of shared and unshared layers across clusters in both the encoder and the decoder. Unshared layers are applied per cluster (each cluster has its own replica of the layer), while shared layers are applied to all samples.
- Parameters:
input_dim (int) – Number of input features (columns).
hidden_dims (list[int]) – Width of each hidden layer (the encoder uses this order; the decoder uses the reverse).
layer_order_enc (list[str]) – Per-encoder-layer directive: "shared" or "unshared".
layer_order_dec (list[str]) – Per-decoder-layer directive: "shared" or "unshared".
latent_shared (bool) – If True, the latent heads (mu, logvar) are shared across clusters; otherwise there is one head per cluster.
latent_dim (int) – Dimensionality of the latent space.
output_shared (bool) – If True, the final output layer is shared; otherwise there is one output layer per cluster.
num_clusters (int) – Number of clusters present in the data.
debug (bool) – If True, prints routing shapes and asserts row-order invariants.
- Raises:
ValueError – If an item of layer_order_enc or layer_order_dec is not one of {"shared", "unshared", "s", "u"} (case-insensitive), or if their lengths do not match len(hidden_dims) for the respective path.
- Expected shapes
Encoder input x: (batch, input_dim)
Cluster labels cluster_labels: (batch,) (LongTensor with values in [0, num_clusters-1])
Decoder/output: (batch, input_dim)
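The directive rules in the Raises entry can be illustrated with a small standalone helper. This is a minimal sketch, not the library's own validation code: the name normalize_layer_order and the alias table are hypothetical; only the accepted spellings, the case-insensitivity and the length rule come from the documentation above.

```python
from typing import List

# Hypothetical helper (not part of ciss_vae): accepted spellings are the
# ones listed in the Raises entry, matched case-insensitively.
_ALIASES = {"shared": "shared", "s": "shared", "unshared": "unshared", "u": "unshared"}


def normalize_layer_order(order: List[str], hidden_dims: List[int], path: str) -> List[str]:
    """Return the directives as "shared"/"unshared", or raise ValueError."""
    if len(order) != len(hidden_dims):
        raise ValueError(f"{path}: expected {len(hidden_dims)} directives, got {len(order)}")
    normalized = []
    for item in order:
        key = str(item).lower()  # case-insensitive comparison
        if key not in _ALIASES:
            raise ValueError(f"{path}: invalid directive {item!r}")
        normalized.append(_ALIASES[key])
    return normalized


print(normalize_layer_order(["S", "Unshared"], [64, 32], "layer_order_enc"))
# ['shared', 'unshared']
```

Normalizing the short forms once up front means the rest of the model only has to compare against two canonical strings.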
- Notes
The decoder consumes hidden_dims[::-1] (reverse order).
Unshared layers maintain per-cluster ModuleList/ModuleDict replicas.
Routing never reorders rows; masks are used to apply cluster-specific sublayers in place.
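A minimal sketch of how shared layers and per-cluster replicas could be organized, assuming plain nn.Linear layers, a ModuleList for shared depths, and a ModuleDict keyed by str(cluster_id) for unshared depths. The builder build_stack is hypothetical; activations, latent heads and output heads are omitted. It also shows the decoder being built from hidden_dims[::-1].

```python
import torch
from torch import nn


def build_stack(in_dim, dims, layer_order, num_clusters):
    """Hypothetical builder: one shared Linear per "shared" depth and one
    Linear per cluster for each "unshared" depth."""
    shared = nn.ModuleList()
    # ModuleDict keys must be strings, so clusters are keyed as "0", "1", ...
    unshared = nn.ModuleDict({str(c): nn.ModuleList() for c in range(num_clusters)})
    d_in = in_dim
    for d_out, kind in zip(dims, layer_order):
        if kind == "shared":
            shared.append(nn.Linear(d_in, d_out))
        else:
            for c in range(num_clusters):
                unshared[str(c)].append(nn.Linear(d_in, d_out))
        d_in = d_out  # next depth consumes this depth's width
    return shared, unshared


hidden_dims = [64, 32]
enc_shared, enc_unshared = build_stack(10, hidden_dims, ["shared", "unshared"], 3)
# The decoder mirrors the encoder: latent_dim -> hidden_dims[::-1]
dec_shared, dec_unshared = build_stack(4, hidden_dims[::-1], ["unshared", "shared"], 3)
print(len(enc_shared), len(enc_unshared["0"]))  # 1 1
print(dec_shared[0].in_features, dec_shared[0].out_features)  # 32 64
```

Registering replicas through ModuleList/ModuleDict (rather than plain Python lists or dicts) ensures that their parameters are visible to .parameters(), .to(device) and state_dict().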
- __init__(input_dim, hidden_dims, layer_order_enc, layer_order_dec, latent_shared, latent_dim, output_shared, num_clusters, activation_groups=None, debug=False)
Variational Autoencoder supporting flexible shared/unshared layers across clusters.
- Parameters:
input_dim (int) – Number of input features.
hidden_dims (list[int]) – Width of each hidden layer (the encoder uses this order; the decoder uses the reverse).
layer_order_enc (list[str]) – Layer type for each encoder layer ("shared" or "unshared").
layer_order_dec (list[str]) – Layer type for each decoder layer ("shared" or "unshared").
latent_shared (bool) – Whether the latent representation is shared across clusters.
latent_dim (int) – Dimensionality of the latent space.
output_shared (bool) – Whether the output layer is shared across clusters.
num_clusters (int) – Number of clusters.
activation_groups (dict[str, list[int]] | None) – Dictionary mapping feature groups to column indices; this is an attribute of the ClusterDataset object. Expected format:
{"continuous": [int, …], "binary": [int, …], "<categorical_name>": [int, …], …}
Each key defines a feature group:
"continuous": indices of continuous-valued features
"binary": indices of binary features
any additional key corresponds to a grouped categorical variable (e.g., the one-hot encoded columns belonging to the same variable)
This structure is used outside the model to determine loss functions and output transformations.
debug (bool) – If True, print shape and routing information.
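Because activation_groups is described as driving loss selection outside the model, here is one hypothetical way such a dispatch could look: mean squared error for continuous columns, binary cross-entropy on logits for binary columns, and cross-entropy for each categorical group. The function grouped_loss, the choice of losses, the categorical name "color", and the absence of any missing-value mask are all assumptions for illustration, not the library's training objective.

```python
import torch
import torch.nn.functional as F


def grouped_loss(logits, target, activation_groups):
    """Hypothetical reconstruction loss that dispatches on activation_groups."""
    total = logits.new_zeros(())
    for name, cols in activation_groups.items():
        if not cols:
            continue
        idx = torch.tensor(cols, dtype=torch.long)
        out, tgt = logits[:, idx], target[:, idx]
        if name == "continuous":
            total = total + F.mse_loss(out, tgt, reduction="sum")
        elif name == "binary":
            # The model emits raw logits, so use the logits form of BCE.
            total = total + F.binary_cross_entropy_with_logits(out, tgt, reduction="sum")
        else:
            # Categorical group: one-hot target columns -> class index per row.
            total = total + F.cross_entropy(out, tgt.argmax(dim=1), reduction="sum")
    return total


groups = {"continuous": [0, 1], "binary": [2], "color": [3, 4, 5]}
target = torch.tensor([[0.5, -1.0, 1.0, 0.0, 1.0, 0.0]])
logits = torch.zeros(1, 6)
loss = grouped_loss(logits, target, groups)
print(loss)
```

With all-zero logits, the three terms are 0.25 + 1.0 (continuous), log 2 (binary) and log 3 (a uniform three-way categorical).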
Methods
__init__(input_dim, hidden_dims, ...[, ...]) – Variational Autoencoder supporting flexible shared/unshared layers across clusters.
decode(z, cluster_labels) – Decoder forward pass from latent z to reconstruction.
encode(x, cluster_labels) – Encoder forward pass producing mu and logvar.
forward(x, cluster_labels[, deterministic, ...]) – Full VAE forward pass: encode → reparameterize → decode.
Returns the learning rate stored with self.set_final_lr/
get_imputed_valdata(dataset[, device, ...]) – Compute imputed values for the validation dataset.
reparameterize(mu, logvar[, generator]) – Reparameterization trick: z = mu + eps * exp(0.5 * logvar).
route_through_layers(x, cluster_labels, ...) – Apply a sequence of shared/unshared layers according to layer_type_list.
set_activation_groups(activation_groups) – Attach feature activation structure to the model.
set_final_lr(final_lr) – Store the final learning rate from the initial training loop as a model attribute so it can be accessed in the refit loop.
- route_through_layers(x, cluster_labels, layer_type_list, shared_layers, unshared_layers)
Apply a sequence of shared/unshared layers according to layer_type_list.
For each position i:
* if layer_type_list[i] is shared → apply shared_layers[i_shared] to all rows;
* if unshared → for each cluster c, apply the c-specific layer at that depth to the subset of rows where cluster_labels == c.
- Parameters:
x (torch.Tensor, shape (batch, d_in)) – Input activations to be routed.
cluster_labels (torch.LongTensor, shape (batch,)) – Cluster id per row.
layer_type_list (list[str]) – Sequence of "shared"/"unshared" flags (length = number of layers at this stage).
shared_layers (torch.nn.ModuleList) – Layers used when the directive is shared (the index increases only when a shared layer is consumed).
unshared_layers (dict[str, torch.nn.ModuleList] | torch.nn.ModuleDict) – Per-cluster lists of layers for unshared directives (the per-cluster index increases only when an unshared layer is consumed).
- Returns:
Routed activations.
- Return type:
- Raises:
ValueError – If an entry in layer_type_list is invalid, or if the per-cluster unshared stacks are inconsistent with the directives.
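The routing rule can be sketched as follows, assuming each layer maps (rows, d_in) to (rows, d_out), unshared_layers is keyed by str(cluster_id), and no activation is applied between layers. This illustrates the mask-based, row-order-preserving scheme described above; it is not the library's source.

```python
import torch
from torch import nn


def route_through_layers(x, cluster_labels, layer_type_list, shared_layers, unshared_layers):
    """Sketch: layers only, no activations; unshared_layers keys are str(cluster_id)."""
    i_shared = 0    # advances only when a shared layer is consumed
    i_unshared = 0  # per-cluster depth; advances only on unshared layers
    for kind in layer_type_list:
        if kind == "shared":
            x = shared_layers[i_shared](x)  # every row goes through the same layer
            i_shared += 1
        elif kind == "unshared":
            out = None
            for key, stack in unshared_layers.items():
                mask = cluster_labels == int(key)  # rows belonging to this cluster
                if not mask.any():
                    continue
                y = stack[i_unshared](x[mask])
                if out is None:
                    out = x.new_zeros(x.shape[0], y.shape[1])
                out[mask] = y  # write back at the same row positions
            if out is None:
                raise ValueError("no row matched any cluster in unshared_layers")
            x = out
            i_unshared += 1
        else:
            raise ValueError(f"invalid layer type {kind!r}")
    return x


torch.manual_seed(0)
shared = nn.ModuleList([nn.Linear(5, 8)])
unshared = nn.ModuleDict({str(c): nn.ModuleList([nn.Linear(8, 3)]) for c in range(2)})
x = torch.randn(4, 5)
labels = torch.tensor([1, 0, 1, 0])
out = route_through_layers(x, labels, ["shared", "unshared"], shared, unshared)
print(out.shape)  # torch.Size([4, 3])
```

Writing results back through the same boolean mask is what keeps row i of the output aligned with row i of the input, so no permutation needs to be undone afterwards.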
- encode(x, cluster_labels)
Encoder forward pass producing mu and logvar.
- Parameters:
x (torch.Tensor, shape (batch, input_dim)) – Input batch.
cluster_labels (torch.LongTensor, shape (batch,)) – Cluster id per row.
- Returns:
Tuple (mu, logvar).
- Return type:
- reparameterize(mu, logvar, generator=None)
Reparameterization trick: z = mu + eps * exp(0.5 * logvar), where eps is drawn from a standard normal distribution.
- Parameters:
mu (torch.Tensor, shape (batch, latent_dim)) – Mean of the approximate posterior.
logvar (torch.Tensor, shape (batch, latent_dim)) – Log-variance of the approximate posterior.
generator (torch.Generator) – Optional generator for RNG control (default None).
- Returns:
Sampled latent codes z.
- Return type:
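A minimal sketch of the reparameterization trick with optional generator control. Two generators seeded identically yield identical samples, and gradients flow to mu because the randomness is isolated in eps. A torch.Generator() created without a device argument lives on the CPU, so this sketch assumes CPU tensors.

```python
import torch


def reparameterize(mu, logvar, generator=None):
    std = torch.exp(0.5 * logvar)  # logvar = log(sigma^2), so this is sigma
    # eps ~ N(0, I), drawn from the optional generator for reproducibility
    eps = torch.randn(mu.shape, generator=generator, dtype=mu.dtype, device=mu.device)
    return mu + eps * std


mu = torch.zeros(2, 3)
logvar = torch.zeros(2, 3)
z1 = reparameterize(mu, logvar, torch.Generator().manual_seed(123))
z2 = reparameterize(mu, logvar, torch.Generator().manual_seed(123))
print(torch.equal(z1, z2))  # True
```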
- decode(z, cluster_labels)
Decoder forward pass from latent z to reconstruction.
- Parameters:
z (torch.Tensor, shape (batch, latent_dim)) – Latent codes.
cluster_labels (torch.LongTensor, shape (batch,)) – Cluster id per row.
- Returns:
Reconstructed inputs.
- Return type:
torch.Tensor, shape (batch, input_dim)
- forward(x, cluster_labels, deterministic=False, *, generator=None)
Full VAE forward pass: encode → reparameterize → decode.
- Parameters:
x (torch.Tensor, shape (batch, input_dim)) – Input batch.
cluster_labels (torch.LongTensor, shape (batch,)) – Cluster id per row.
deterministic (bool) – If True, evaluate the model deterministically, as used for imputation (default False).
generator (torch.Generator) – Optional generator for RNG control (default None); keyword-only.
- Returns:
Tuple (recon, mu, logvar).
- Return type:
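The encode → reparameterize → decode flow is sketched below with a hypothetical all-shared toy model (TinyVAE) that accepts but ignores cluster_labels. The documentation does not state exactly what deterministic=True does; this sketch assumes it uses z = mu instead of sampling, which is a common convention for imputation but is only an assumption here.

```python
import torch
from torch import nn


class TinyVAE(nn.Module):
    """Hypothetical all-shared stand-in; cluster_labels is accepted but unused."""

    def __init__(self, input_dim=6, hidden=8, latent_dim=2):
        super().__init__()
        self.enc = nn.Sequential(nn.Linear(input_dim, hidden), nn.ReLU())
        self.mu_head = nn.Linear(hidden, latent_dim)
        self.logvar_head = nn.Linear(hidden, latent_dim)
        self.dec = nn.Sequential(nn.Linear(latent_dim, hidden), nn.ReLU(),
                                 nn.Linear(hidden, input_dim))

    def forward(self, x, cluster_labels, deterministic=False, *, generator=None):
        h = self.enc(x)                                       # encode
        mu, logvar = self.mu_head(h), self.logvar_head(h)
        if deterministic:
            z = mu  # assumption: deterministic mode skips sampling
        else:                                                 # reparameterize
            eps = torch.randn(mu.shape, generator=generator, dtype=mu.dtype, device=mu.device)
            z = mu + eps * torch.exp(0.5 * logvar)
        recon = self.dec(z)                                   # decode
        return recon, mu, logvar


torch.manual_seed(0)
model = TinyVAE().eval()
x, labels = torch.randn(4, 6), torch.zeros(4, dtype=torch.long)
with torch.no_grad():
    r1, mu, logvar = model(x, labels, deterministic=True)
    r2, _, _ = model(x, labels, deterministic=True)
print(r1.shape, torch.equal(r1, r2))
```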
- set_final_lr(final_lr)
Store the final learning rate from the initial training loop as a model attribute so that it can be accessed in the refit loop.
- get_imputed_valdata(dataset, device='cpu', deterministic=True)
Compute imputed values for the validation dataset.
Performs a forward pass through the trained VAE and reconstructs only the validation-masked entries. Outputs are interpreted according to activation_groups:
continuous → denormalized using mean/std
binary → sigmoid applied
categorical → softmax + argmax (one-hot reconstruction)
- Parameters:
dataset (ClusterDataset) – Dataset object containing:
data : normalized tensor (N, D)
val_data : original data with NaNs at validation positions
cluster_labels : (N,)
feature_means : (D,)
feature_stds : (D,)
feature_names : list[str]
activation_groups : dict[str, list[int]]
device (str, default="cpu") – Device to run computations on.
deterministic (bool, default=True) – Whether to use deterministic forward pass.
- Returns:
Tensor of shape (N, D) containing imputed validation values. Non-validation entries are set to NaN.
- Return type:
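The per-group output interpretation can be sketched as follows. The function postprocess, its explicit boolean val_mask argument, and the inverse normalization x = z * std + mean are assumptions for illustration; the documented method reads these values from the ClusterDataset instead. Binary columns are left as sigmoid probabilities, matching the description above, and "color" is a hypothetical categorical group.

```python
import torch
import torch.nn.functional as F


def postprocess(logits, activation_groups, means, stds, val_mask):
    """Hypothetical decoding of raw logits; entries outside val_mask become NaN."""
    decoded = logits.clone()
    for name, cols in activation_groups.items():
        idx = torch.tensor(cols, dtype=torch.long)
        if name == "continuous":
            # Undo (x - mean) / std normalization
            decoded[:, idx] = logits[:, idx] * stds[idx] + means[idx]
        elif name == "binary":
            decoded[:, idx] = torch.sigmoid(logits[:, idx])
        else:
            # Categorical group: softmax, then argmax back to a one-hot row
            probs = torch.softmax(logits[:, idx], dim=1)
            decoded[:, idx] = F.one_hot(probs.argmax(dim=1), num_classes=len(cols)).to(logits.dtype)
    out = torch.full_like(logits, float("nan"))
    out[val_mask] = decoded[val_mask]  # keep only validation positions
    return out


groups = {"continuous": [0], "binary": [1], "color": [2, 3, 4]}
logits = torch.tensor([[0.5, 0.0, 0.1, 2.0, -1.0],
                       [-1.0, 3.0, 0.0, 0.0, 5.0]])
means = torch.tensor([10.0, 0.0, 0.0, 0.0, 0.0])
stds = torch.tensor([2.0, 1.0, 1.0, 1.0, 1.0])
val_mask = torch.tensor([[True, True, True, True, True],
                         [True, False, False, False, False]])
imputed = postprocess(logits, groups, means, stds, val_mask)
print(imputed)
```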
- set_activation_groups(activation_groups)
Attach feature activation structure to the model.
This method stores the resolved activation_groups dictionary, which defines how each input column should be interpreted during loss computation and imputation.
The model itself does NOT use this information during the forward pass; it is stored purely as metadata to ensure consistency between the model, dataset, loss functions, and imputation routines.
- Parameters:
activation_groups (dict) – Dictionary mapping feature types to column indices. Expected format:
{"continuous": [col_idx, …], "binary": [col_idx, …], "<categorical_name>": [col_idx, …], …}
"continuous": indices of continuous-valued features
"binary": indices of binary features
Each additional key represents a grouped categorical variable (multiple columns corresponding to one variable).
- Raises:
ValueError – If activation_groups is not a dictionary or contains invalid column indices.
- Return type:
- Notes
This replaces the old set_binary_features functionality.
The model outputs raw logits for all features; interpretation is handled externally using activation_groups.
This function is primarily used after loading a model to reattach the dataset structure if needed.
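A minimal sketch of the checks named in the Raises entry, assuming that "invalid column index" means a non-integer index or one outside [0, input_dim). The function name validate_activation_groups and its input_dim argument are hypothetical.

```python
def validate_activation_groups(activation_groups, input_dim):
    """Hypothetical check mirroring the Raises entry; input_dim is the model width."""
    if not isinstance(activation_groups, dict):
        raise ValueError("activation_groups must be a dict")
    for name, cols in activation_groups.items():
        for col in cols:
            # bool is a subclass of int in Python, so reject it explicitly.
            if isinstance(col, bool) or not isinstance(col, int) or not 0 <= col < input_dim:
                raise ValueError(f"group {name!r}: invalid column index {col!r}")
    return activation_groups


groups = {"continuous": [0, 1], "binary": [2], "color": [3, 4, 5]}
validate_activation_groups(groups, input_dim=6)  # passes silently
```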