ciss_vae.classes.vae.CISSVAE

class CISSVAE(*args: Any, **kwargs: Any)

Bases: Module

Clustering-Informed Shared-Structure Variational Autoencoder (CISSVAE).

Supports flexible mixtures of shared and unshared layers in both the encoder and the decoder. Unshared layers are applied per cluster; shared layers are applied to all samples.

Parameters:
  • input_dim (int) – Number of input features (columns).

  • hidden_dims (list[int]) – Width of each hidden layer. The encoder uses the widths in the given order; the decoder uses them in reverse.

  • layer_order_enc (list[str]) – Per‑encoder‑layer directive: "shared" or "unshared".

  • layer_order_dec (list[str]) – Per‑decoder‑layer directive: "shared" or "unshared".

  • latent_shared (bool) – If True, the latent heads (mu, logvar) are shared across clusters; otherwise one head per cluster.

  • latent_dim (int) – Dimensionality of the latent space.

  • output_shared (bool) – If True, the final output layer is shared; otherwise one output layer per cluster.

  • num_clusters (int) – Number of clusters present in the data.

  • debug (bool) – If True, prints routing shapes and asserts row order invariants.

Raises:

ValueError – If an item of layer_order_enc or layer_order_dec is not one of {"shared", "unshared", "s", "u"} (case‑insensitive), or if their lengths do not match len(hidden_dims) for the respective path.
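The directive rule above can be sketched as a small normalizer. This is a hypothetical helper (normalize_layer_order is not part of the documented API), assuming "s" and "u" are aliases for "shared" and "unshared" and that both paths must have exactly len(hidden_dims) directives.

```python
def normalize_layer_order(layer_order, hidden_dims, path_name):
    """Hypothetical helper: map each directive to "shared" or "unshared".

    Accepts "shared", "unshared", "s", "u" in any letter case and raises
    ValueError on an unknown directive or a length mismatch with hidden_dims.
    """
    aliases = {"shared": "shared", "s": "shared", "unshared": "unshared", "u": "unshared"}
    if len(layer_order) != len(hidden_dims):
        raise ValueError(
            f"{path_name}: expected {len(hidden_dims)} directives, got {len(layer_order)}"
        )
    normalized = []
    for i, item in enumerate(layer_order):
        key = str(item).lower()  # case-insensitive comparison
        if key not in aliases:
            raise ValueError(f"{path_name}[{i}]: invalid directive {item!r}")
        normalized.append(aliases[key])
    return normalized


print(normalize_layer_order(["S", "unshared", "u"], [64, 32, 16], "layer_order_enc"))
# prints ['shared', 'unshared', 'unshared']
```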

Expected shapes
  • Encoder input x: (batch, input_dim)

  • Cluster labels cluster_labels: (batch,) (LongTensor with values in [0, num_clusters-1])

  • Decoder/Output: (batch, input_dim)

Notes
  • The decoder consumes hidden_dims[::-1] (reverse order).

  • Unshared layers maintain per‑cluster ModuleList/ModuleDict replicas.

  • Routing never reorders rows; masks are used to apply cluster‑specific sublayers in‑place.
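To make the dimension bookkeeping in the Notes concrete, here is a sketch of the (in_features, out_features) pairs each stage would see. It assumes every hidden layer is a plain fully connected layer, that the latent heads map the last encoder width to latent_dim, and that the output layer maps the last decoder width back to input_dim; layer_widths is a hypothetical helper, not part of the class.

```python
def layer_widths(input_dim, hidden_dims, latent_dim):
    """Hypothetical sketch: (in_features, out_features) for each stage.

    Assumes fully connected layers. The decoder walks hidden_dims in
    reverse, so its last width is hidden_dims[0].
    """
    enc_sizes = [input_dim] + list(hidden_dims)
    encoder = list(zip(enc_sizes[:-1], enc_sizes[1:]))
    latent_head = (hidden_dims[-1], latent_dim)   # same shape for the mu and logvar heads
    dec_sizes = [latent_dim] + list(hidden_dims[::-1])
    decoder = list(zip(dec_sizes[:-1], dec_sizes[1:]))
    output = (hidden_dims[0], input_dim)          # last decoder width -> input_dim
    return {"encoder": encoder, "latent_head": latent_head, "decoder": decoder, "output": output}


print(layer_widths(10, [64, 32], 4))
# {'encoder': [(10, 64), (64, 32)], 'latent_head': (32, 4), 'decoder': [(4, 32), (32, 64)], 'output': (64, 10)}
```

Whether a given stage is shared or replicated per cluster does not change these widths; an unshared stage simply holds num_clusters layers of the same shape.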

__init__(input_dim, hidden_dims, layer_order_enc, layer_order_dec, latent_shared, latent_dim, output_shared, num_clusters, activation_groups=None, debug=False)

Variational Autoencoder supporting flexible shared/unshared layers across clusters.

Parameters:
  • input_dim (int) – Number of input features.

  • hidden_dims (list[int]) – Dimensions of hidden layers.

  • layer_order_enc (list[str]) – Layer type for each encoder layer ("shared" or "unshared").

  • layer_order_dec (list[str]) – Layer type for each decoder layer ("shared" or "unshared").

  • latent_shared (bool) – Whether the latent heads (mu, logvar) are shared across clusters.

  • latent_dim (int) – Dimensionality of the latent space.

  • output_shared (bool) – Whether the output layer is shared across clusters.

  • num_clusters (int) – Number of clusters.

  • activation_groups (dict, optional) – Dictionary mapping feature groups to column indices, typically the activation_groups attribute of a ClusterDataset object.

Expected format:

    {
        "continuous": [int, ...],
        "binary": [int, ...],
        "<categorical_name>": [int, ...],
        ...
    }

Each key defines a feature group:

  • "continuous": indices of continuous-valued features

  • "binary": indices of binary features

  • additional keys correspond to grouped categorical variables (e.g., one-hot encoded columns belonging to the same variable)

This structure is used to determine loss functions and output transformations outside the model.

  • debug (bool) – If True, print shape and routing information.
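Because activation_groups tells code outside the model how to treat each column, a reconstruction loss can be assembled group by group. The sketch below is illustrative and is not the package's loss: it assumes MSE for continuous columns, binary cross-entropy on logits for binary columns, and cross-entropy over each categorical group's one-hot columns, with no missing-value masking. The function name grouped_loss is hypothetical.

```python
import torch
import torch.nn.functional as F


def grouped_loss(recon, target, activation_groups):
    """Hypothetical per-group reconstruction loss (no missingness mask).

    recon:  raw model outputs, shape (batch, D)
    target: reference values, shape (batch, D); categorical groups are one-hot
    """
    total = recon.new_zeros(())
    for name, cols in activation_groups.items():
        if not cols:
            continue
        idx = torch.as_tensor(cols, dtype=torch.long)
        pred, true = recon[:, idx], target[:, idx]
        if name == "continuous":
            total = total + F.mse_loss(pred, true)
        elif name == "binary":
            total = total + F.binary_cross_entropy_with_logits(pred, true)
        else:
            # Categorical group: the group's columns are treated as class logits.
            total = total + F.cross_entropy(pred, true.argmax(dim=1))
    return total


groups = {"continuous": [0, 1], "binary": [2], "color": [3, 4, 5]}
recon = torch.randn(8, 6)
target = torch.cat(
    [
        torch.randn(8, 2),                                    # continuous columns
        torch.randint(0, 2, (8, 1)).float(),                  # binary column
        F.one_hot(torch.randint(0, 3, (8,)), 3).float(),      # one-hot categorical group
    ],
    dim=1,
)
print(grouped_loss(recon, target, groups))
```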

Methods

__init__(input_dim, hidden_dims, ...[, ...])

Variational Autoencoder supporting flexible shared/unshared layers across clusters.

decode(z, cluster_labels)

Decoder forward pass from latent z to reconstruction.

encode(x, cluster_labels)

Encoder forward pass producing mu and logvar.

forward(x, cluster_labels[, deterministic, ...])

Full VAE forward pass: encode → reparameterize → decode.

get_final_lr()

Return the learning rate stored by set_final_lr().

get_imputed_valdata(dataset[, device, ...])

Compute imputed values for the validation dataset.

reparameterize(mu, logvar[, generator])

Reparameterization trick: z = mu + eps * exp(0.5 * logvar).

route_through_layers(x, cluster_labels, ...)

Apply a sequence of shared/unshared layers according to layer_type_list.

set_activation_groups(activation_groups)

Attach feature activation structure to the model.

set_final_lr(final_lr)

Store the final learning rate from the initial training loop as a model attribute for use in the refit loop.

route_through_layers(x, cluster_labels, layer_type_list, shared_layers, unshared_layers)

Apply a sequence of shared/unshared layers according to layer_type_list.

For each position i:

  • if layer_type_list[i] is shared, apply the next shared layer (shared_layers[i_shared]) to all rows;

  • if it is unshared, then for each cluster c, apply cluster c's layer at that depth to the subset of rows where cluster_labels == c.

Parameters:
  • x (torch.Tensor, shape (batch, d_in)) – Input activations to be routed.

  • cluster_labels (torch.LongTensor, shape (batch,)) – Cluster id per row.

  • layer_type_list (list[str]) – Sequence of "shared"/"unshared" flags (length = number of layers at this stage).

  • shared_layers (torch.nn.ModuleList) – Layers used when the directive is shared (index increases only when a shared layer is consumed).

  • unshared_layers (dict[str, torch.nn.ModuleList] | torch.nn.ModuleDict) – Per‑cluster lists of layers for unshared directives (index per cluster increases only when an unshared layer is consumed).

Returns:

Routed activations.

Return type:

torch.Tensor

Raises:

ValueError – If an entry in layer_type_list is invalid or if per‑cluster unshared stacks are inconsistent with the directives.
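The routing rule above can be sketched as follows. This is a minimal illustration, not the package's implementation: route_sketch is a hypothetical name, the directives are assumed to be already normalized to "shared"/"unshared", unshared_layers is assumed to be keyed by str(cluster_id), and each per-cluster layer at a given depth is assumed to produce the same output width so the results fit in one tensor. Activations are omitted.

```python
import torch
import torch.nn as nn


def route_sketch(x, cluster_labels, layer_type_list, shared_layers, unshared_layers):
    """Hypothetical re-implementation of the routing rule described above."""
    i_shared = 0    # advances only when a shared directive is consumed
    i_unshared = 0  # advances only when an unshared directive is consumed
    for kind in layer_type_list:
        if kind == "shared":
            x = shared_layers[i_shared](x)  # one layer applied to every row
            i_shared += 1
        elif kind == "unshared":
            out = None
            covered = torch.zeros_like(cluster_labels, dtype=torch.bool)
            for key, stack in unshared_layers.items():
                mask = cluster_labels == int(key)  # rows that belong to cluster `key`
                if not mask.any():
                    continue
                y = stack[i_unshared](x[mask])
                if out is None:
                    out = y.new_empty((x.shape[0], y.shape[1]))
                out[mask] = y  # write back at the same row positions: no reordering
                covered |= mask
            if out is None or not bool(covered.all()):
                raise ValueError("some rows have no matching per-cluster layer stack")
            x = out
            i_unshared += 1
        else:
            raise ValueError(f"invalid layer type: {kind!r}")
    return x


torch.manual_seed(0)
shared = nn.ModuleList([nn.Linear(4, 8)])
unshared = nn.ModuleDict({str(c): nn.ModuleList([nn.Linear(8, 3)]) for c in range(2)})
x = torch.randn(5, 4)
labels = torch.tensor([0, 1, 1, 0, 1])
out = route_sketch(x, labels, ["shared", "unshared"], shared, unshared)
print(out.shape)  # torch.Size([5, 3])
```

Writing into a preallocated tensor through boolean masks keeps row i of the output aligned with row i of the input, which is the row-order invariant the debug flag is described as asserting.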

encode(x, cluster_labels)

Encoder forward pass producing mu and logvar.

Parameters:
  • x (torch.Tensor, shape (batch, input_dim)) – Input batch.

  • cluster_labels (torch.LongTensor, shape (batch,)) – Cluster id per row.

Returns:

Tuple (mu, logvar).

Return type:

tuple[torch.Tensor, torch.Tensor]

reparameterize(mu, logvar, generator=None)

Reparameterization trick: z = mu + eps * exp(0.5 * logvar).

Parameters:
  • mu (torch.Tensor, shape (batch, latent_dim)) – Mean of the approximate posterior.

  • logvar (torch.Tensor, shape (batch, latent_dim)) – Log‑variance of the approximate posterior.

  • generator (torch.Generator, optional) – Random number generator for reproducible sampling of eps (default None).

Returns:

Sampled latent codes z.

Return type:

torch.Tensor
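A standalone version of the formula, assuming eps is drawn from a standard normal with torch.randn; reparameterize_sketch is a hypothetical name used only for illustration. Passing a seeded torch.Generator makes the draw reproducible. Note that a CPU generator can only be used with CPU tensors.

```python
import torch


def reparameterize_sketch(mu, logvar, generator=None):
    """z = mu + eps * exp(0.5 * logvar), with eps ~ N(0, I)."""
    std = torch.exp(0.5 * logvar)
    # torch.randn accepts a generator, so a seeded generator fixes eps.
    eps = torch.randn(mu.shape, generator=generator, dtype=mu.dtype, device=mu.device)
    return mu + eps * std


mu = torch.zeros(3, 2)
logvar = torch.zeros(3, 2)
z1 = reparameterize_sketch(mu, logvar, torch.Generator().manual_seed(0))
z2 = reparameterize_sketch(mu, logvar, torch.Generator().manual_seed(0))
print(torch.equal(z1, z2))  # True: same seed, same eps
```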

decode(z, cluster_labels)

Decoder forward pass from latent z to reconstruction.

Parameters:
  • z (torch.Tensor, shape (batch, latent_dim)) – Latent codes.

  • cluster_labels (torch.LongTensor, shape (batch,)) – Cluster id per row.

Returns:

Reconstructed inputs.

Return type:

torch.Tensor, shape (batch, input_dim)

forward(x, cluster_labels, deterministic=False, *, generator=None)

Full VAE forward pass: encode → reparameterize → decode.

Parameters:
  • x (torch.Tensor, shape (batch, input_dim)) – Input batch.

  • cluster_labels (torch.LongTensor, shape (batch,)) – Cluster id per row.

  • deterministic (bool) – If True, evaluate the model deterministically, for example when producing imputations (default False).

  • generator (torch.Generator, optional) – Keyword-only random number generator for reproducible sampling (default None).

Returns:

Tuple (recon, mu, logvar).

Return type:

tuple[torch.Tensor, torch.Tensor, torch.Tensor]
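A toy stand-in makes the forward() contract concrete. The sketch assumes that deterministic=True skips sampling and decodes the posterior mean mu, a common convention for imputation that is an assumption here rather than documented behavior. TinyVAE is hypothetical, has a single encoder and decoder layer, and ignores cluster_labels.

```python
import torch
import torch.nn as nn


class TinyVAE(nn.Module):
    """Hypothetical single-cluster stand-in mirroring the forward() signature."""

    def __init__(self, input_dim, latent_dim):
        super().__init__()
        self.enc = nn.Linear(input_dim, 2 * latent_dim)  # mu and logvar in one projection
        self.dec = nn.Linear(latent_dim, input_dim)

    def forward(self, x, cluster_labels, deterministic=False, *, generator=None):
        mu, logvar = self.enc(x).chunk(2, dim=1)
        if deterministic:
            z = mu  # assumed: no sampling noise in deterministic mode
        else:
            eps = torch.randn(mu.shape, generator=generator, dtype=mu.dtype, device=mu.device)
            z = mu + eps * torch.exp(0.5 * logvar)
        recon = self.dec(z)  # cluster_labels is unused in this toy model
        return recon, mu, logvar


torch.manual_seed(0)
model = TinyVAE(6, 2)
x = torch.randn(4, 6)
labels = torch.zeros(4, dtype=torch.long)
recon, mu, logvar = model(x, labels, deterministic=True)
print(recon.shape, mu.shape, logvar.shape)
```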

set_final_lr(final_lr)

Store the final learning rate from the initial training loop as a model attribute, so that the refit loop can access it.

get_final_lr()

Return the learning rate stored by set_final_lr().
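A minimal sketch of how the stored rate could flow from the initial training loop into a refit, assuming a standard PyTorch optimizer and learning-rate scheduler. LRMemoryModel is a hypothetical stand-in that mirrors only the two documented methods, and the attribute name final_lr is an assumption.

```python
import torch
import torch.nn as nn


class LRMemoryModel(nn.Module):
    """Hypothetical stand-in exposing set_final_lr / get_final_lr."""

    def __init__(self):
        super().__init__()
        self.lin = nn.Linear(3, 1)
        self.final_lr = None  # assumed attribute name

    def set_final_lr(self, final_lr):
        self.final_lr = float(final_lr)

    def get_final_lr(self):
        return self.final_lr


lr_model = LRMemoryModel()
opt = torch.optim.SGD(lr_model.parameters(), lr=0.1)
sched = torch.optim.lr_scheduler.ExponentialLR(opt, gamma=0.5)
for _ in range(3):  # stand-in for the initial training loop
    opt.step()
    sched.step()
lr_model.set_final_lr(opt.param_groups[0]["lr"])  # 0.1 * 0.5**3 = 0.0125

# Refit loop: start a fresh optimizer from the stored rate.
refit_opt = torch.optim.SGD(lr_model.parameters(), lr=lr_model.get_final_lr())
print(refit_opt.param_groups[0]["lr"])
```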

get_imputed_valdata(dataset, device='cpu', deterministic=True)

Compute imputed values for the validation dataset.

Performs a forward pass through the trained VAE and reconstructs only the validation-masked entries. Outputs are interpreted according to activation_groups:

  • continuous → denormalized using mean/std

  • binary → sigmoid applied

  • categorical → softmax + argmax (one-hot reconstruction)

Parameters:
  • dataset (ClusterDataset) –

    Dataset object containing:
    • data : normalized tensor (N, D)

    • val_data : original data with NaNs at validation positions

    • cluster_labels : (N,)

    • feature_means : (D,)

    • feature_stds : (D,)

    • feature_names : list[str]

    • activation_groups : dict[str, list[int]]

  • device (str, default="cpu") – Device to run computations on.

  • deterministic (bool, default=True) – Whether to use deterministic forward pass.

Returns:

Tensor of shape (N, D) containing imputed validation values. Non-validation entries are set to NaN.

Return type:

torch.Tensor
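The output interpretation listed above can be sketched on a raw reconstruction tensor. The sketch assumes recon holds raw outputs in normalized space, that continuous columns are denormalized as recon * std + mean, and that binary outputs stay as sigmoid probabilities without thresholding. It takes an explicit boolean val_mask rather than deriving the validation positions from val_data. postprocess_validation is a hypothetical name.

```python
import torch
import torch.nn.functional as F


def postprocess_validation(recon, val_mask, means, stds, activation_groups):
    """Hypothetical sketch: map raw outputs to data space, keep only validation entries.

    recon:       (N, D) raw outputs in normalized space
    val_mask:    (N, D) bool, True where an entry was held out for validation
    means, stds: (D,) per-feature normalization statistics
    """
    out = recon.clone()
    for name, cols in activation_groups.items():
        if not cols:
            continue
        idx = torch.as_tensor(cols, dtype=torch.long)
        block = recon[:, idx]
        if name == "continuous":
            out[:, idx] = block * stds[idx] + means[idx]  # undo standardization
        elif name == "binary":
            out[:, idx] = torch.sigmoid(block)  # logits -> probabilities
        else:
            # Categorical group: softmax, argmax, then write back as one-hot.
            winner = torch.softmax(block, dim=1).argmax(dim=1)
            out[:, idx] = F.one_hot(winner, num_classes=len(cols)).to(out.dtype)
    # Entries outside the validation mask are reported as NaN.
    return torch.where(val_mask, out, torch.full_like(out, float("nan")))


groups = {"continuous": [0], "binary": [1], "cat": [2, 3]}
recon = torch.tensor([[1.0, 0.0, 2.0, -1.0]])
means = torch.tensor([10.0, 0.0, 0.0, 0.0])
stds = torch.tensor([2.0, 1.0, 1.0, 1.0])
mask = torch.tensor([[True, False, True, True]])
print(postprocess_validation(recon, mask, means, stds, groups))  # values: 12.0, nan, 1.0, 0.0
```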

set_activation_groups(activation_groups)

Attach feature activation structure to the model.

This method stores the resolved activation_groups dictionary, which defines how each input column should be interpreted during loss computation and imputation.

The model itself does NOT use this information during the forward pass; it is stored purely as metadata to ensure consistency between the model, dataset, loss functions, and imputation routines.

Parameters:

activation_groups (dict) –

Dictionary mapping feature types to column indices. Expected format:

    {
        "continuous": [col_idx, ...],
        "binary": [col_idx, ...],
        "<categorical_name>": [col_idx, ...],
        ...
    }

  • "continuous": indices of continuous-valued features

  • "binary": indices of binary features

  • Each additional key represents a grouped categorical variable (multiple columns corresponding to one variable)

Raises:

ValueError – If activation_groups is not a dictionary or contains invalid column indices.

Return type:

None

Notes

  • This replaces the old set_binary_features functionality.

  • The model outputs raw logits for all features; interpretation is handled externally using activation_groups.

  • This function is primarily used after loading a model to reattach dataset structure if needed.
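The documented ValueError conditions might be checked along these lines. This is a hypothetical validator: the documented method takes only activation_groups, so the input_dim argument here is an assumption used to define "invalid column indices" as anything that is not an integer in [0, input_dim). The source does not say whether overlapping groups are rejected, so this sketch does not check for overlap.

```python
import numbers


def validate_activation_groups(activation_groups, input_dim):
    """Hypothetical check mirroring the documented ValueError conditions."""
    if not isinstance(activation_groups, dict):
        raise ValueError("activation_groups must be a dict")
    for name, cols in activation_groups.items():
        for c in cols:
            # bool is an Integral subtype, so reject it explicitly; numpy ints pass.
            if isinstance(c, bool) or not isinstance(c, numbers.Integral) or not 0 <= c < input_dim:
                raise ValueError(f"group {name!r} has invalid column index {c!r}")
    return activation_groups


groups = {"continuous": [0, 1], "binary": [2], "color": [3, 4, 5]}
print(validate_activation_groups(groups, input_dim=6) is groups)  # True
```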