ciss_vae.classes package

Submodules

ciss_vae.classes.cluster_dataset module

Dataset utilities for clustering-aware masking and normalization.

This module defines ClusterDataset, a PyTorch torch.utils.data.Dataset that (1) optionally holds out a validation subset of observed entries on a per-cluster basis, (2) normalizes features using statistics computed on the masked training matrix, and (3) exposes tensors required by the CISS-VAE training loops: normalized data with missing values filled, cluster labels, and binary observation masks.

Typical usage:

ds = ClusterDataset(
    data=df,                       # (N, P) with NaNs for missing
    cluster_labels=clusters,       # length-N array-like
    val_proportion=0.1,            # or per-cluster mapping/sequence
    replacement_value=0.0,
    columns_ignore=["id"]          # columns to exclude from validation masking
)
class ClusterDataset(*args: Any, **kwargs: Any)

Bases: Dataset

Dataset that handles cluster-wise masking and normalization for VAE training.

  1. Optionally holds out a validation subset per cluster from observed (non-NaN) entries according to val_proportion.

  2. Combines original missingness with validation-held-out entries.

  3. Normalizes observed values column-wise (mean/std), keeps masks for NaNs, and replaces NaNs (including held-out values) with replacement_value.
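Steps 1 and 2 above can be sketched with NumPy as follows. This is an illustrative approximation, not the library's implementation: the helper name holdout_per_cluster, the seed argument, and the rule of rounding the number of held-out entries down are all assumptions.

```python
import numpy as np

def holdout_per_cluster(data, cluster_labels, val_proportion, seed=0):
    """Hypothetical helper: hold out a fraction of observed entries per cluster.

    Returns (train, val). `train` has the held-out entries set to NaN in
    addition to the original NaNs; `val` holds only the held-out values.
    """
    rng = np.random.default_rng(seed)
    data = np.asarray(data, dtype=float)
    labels = np.asarray(cluster_labels)
    train = data.copy()
    val = np.full_like(data, np.nan)
    for c in np.unique(labels):
        rows = np.flatnonzero(labels == c)
        # (row, column) positions of observed entries inside this cluster.
        obs_r, obs_c = np.nonzero(~np.isnan(data[rows]))
        n_hold = int(np.floor(val_proportion * obs_r.size))  # assumed rounding
        pick = rng.choice(obs_r.size, size=n_hold, replace=False)
        r, k = rows[obs_r[pick]], obs_c[pick]
        val[r, k] = data[r, k]   # keep the true value for validation
        train[r, k] = np.nan     # hide it from training
    return train, val

X = np.array([[1.0, 2.0], [3.0, np.nan], [5.0, 6.0], [7.0, 8.0]])
train, val = holdout_per_cluster(X, [0, 0, 1, 1], val_proportion=0.5)
```

With this toy input, cluster 0 has three observed entries and cluster 1 has four, so one and two entries respectively are moved into val.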

Parameters:
  • data (pandas.DataFrame | numpy.ndarray | torch.Tensor) – Input matrix of shape (n_samples, n_features). May contain NaNs.

  • cluster_labels (array-like or None) – Cluster assignment per sample (length n_samples). If None, all rows are assigned to a single cluster 0.

  • val_proportion (float | collections.abc.Sequence | collections.abc.Mapping | pandas.Series) –

    Per-cluster fraction of non-missing entries to hold out for validation. Accepted forms:

    • float in [0, 1]: same fraction for all clusters

    • sequence (length = number of clusters): aligned to sorted(unique(cluster_labels))

    • mapping (e.g. {cluster_id: fraction}) covering all clusters

    • pandas.Series indexed by cluster IDs covering all clusters

  • replacement_value (float) – Value used to fill missing and held-out entries after masking.

  • columns_ignore (list[str | int] or None) – Columns to exclude from validation masking. Use column names for DataFrame and indices otherwise.

  • imputable (pandas.DataFrame | numpy.ndarray | torch.Tensor) – Matrix indicating which entries are eligible for imputation (1 = impute, 0 = exclude from imputation). Must have the same shape as data.

  • binary_feature_mask (list[bool] | numpy.ndarray) – Boolean vector of length n_features indicating binary columns. Used to construct activation_groups. Categorical dummy columns must also be marked as True.

  • categorical_column_map (dict[str, list[str | int]] or None) –

    Optional mapping from original categorical variable names to their corresponding dummy-variable columns. Example:

    {"C1": ["C1b1", "C1b2"], "C2": ["C2b1", "C2b2"]}
    

    These columns are grouped together in activation_groups and treated as categorical variables. All listed columns must also be marked as True in binary_feature_mask.
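The accepted forms of val_proportion can be reduced to one per-cluster mapping. The sketch below uses only the standard library and is not taken from the package: resolve_val_proportion is a hypothetical name, and the pandas.Series form is left out (a Series could be converted with its to_dict() method before calling the helper).

```python
from collections.abc import Mapping, Sequence

def resolve_val_proportion(val_proportion, cluster_ids):
    """Hypothetical helper: return {cluster_id: fraction} for every cluster."""
    ids = sorted(set(cluster_ids))
    if isinstance(val_proportion, bool):
        raise TypeError("val_proportion must not be a bool")
    if isinstance(val_proportion, (int, float)):
        # One fraction shared by all clusters.
        props = {c: float(val_proportion) for c in ids}
    elif isinstance(val_proportion, Mapping):
        missing = [c for c in ids if c not in val_proportion]
        if missing:
            raise ValueError(f"no proportion given for clusters {missing}")
        props = {c: float(val_proportion[c]) for c in ids}
    elif isinstance(val_proportion, Sequence) and not isinstance(val_proportion, str):
        # Sequences are aligned to the sorted unique cluster labels.
        if len(val_proportion) != len(ids):
            raise ValueError("sequence length must equal the number of clusters")
        props = dict(zip(ids, map(float, val_proportion)))
    else:
        raise TypeError(f"unsupported type: {type(val_proportion).__name__}")
    bad = {c: p for c, p in props.items() if not 0.0 <= p <= 1.0}
    if bad:
        raise ValueError(f"proportions outside [0, 1]: {bad}")
    return props

per_cluster = resolve_val_proportion([0.1, 0.2], cluster_ids=[5, 3, 5])
```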

Variables:
  • raw_data (torch.FloatTensor) – Original data converted to float tensor (NaNs preserved).

  • data (torch.FloatTensor) – Normalized data with NaNs replaced by replacement_value.

  • masks (torch.BoolTensor) – Boolean mask where True indicates observed (non-NaN) entries before replacement.

  • val_data (torch.FloatTensor) – Tensor containing only validation-held-out values (others are NaN).

  • cluster_labels (torch.LongTensor) – Cluster ID for each row.

  • indices (torch.LongTensor) – Original row indices (from DataFrame index or arange for arrays/tensors).

  • feature_names (list[str]) – Column names (from DataFrame) or synthetic names (V1, V2, …).

  • n_clusters (int) – Number of unique clusters.

  • shape (tuple[int, int]) – Shape of self.data as (n_samples, n_features).

  • binary_feature_mask (numpy.ndarray) – Boolean mask indicating binary features.

  • activation_groups (dict) –

    Mapping of feature groups to column indices. Structure:

    {
        "continuous": [int, ...],
        "binary": [int, ...],
        "<categorical_name>": [int, ...],
        ...
    }
    
    • "continuous": indices of continuous-valued features

    • "binary": indices of binary features

    • Each additional key corresponds to a grouped categorical variable

    This structure is used for loss computation, imputation, and validation logic.
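As a rough illustration of how binary_feature_mask and categorical_column_map could be combined into this structure, consider the sketch below. The helper build_activation_groups is hypothetical, and it assumes that dummy columns belonging to a categorical variable are listed only under that variable's key rather than also under "binary"; the package may resolve this differently.

```python
def build_activation_groups(feature_names, binary_feature_mask, categorical_column_map=None):
    """Hypothetical helper: group column indices by feature type."""
    categorical_column_map = categorical_column_map or {}
    index_of = {name: i for i, name in enumerate(feature_names)}
    groups = {"continuous": [], "binary": []}
    grouped = set()
    for var_name, columns in categorical_column_map.items():
        # Columns may be given by name or by integer position.
        idx = [index_of[c] if isinstance(c, str) else int(c) for c in columns]
        if not all(binary_feature_mask[i] for i in idx):
            raise ValueError(f"all columns of {var_name!r} must be marked binary")
        groups[var_name] = idx
        grouped.update(idx)
    for i, is_binary in enumerate(binary_feature_mask):
        if i in grouped:
            continue  # assumed: dummy columns live only in their categorical group
        groups["binary" if is_binary else "continuous"].append(i)
    return groups

names = ["age", "smoker", "C1b1", "C1b2"]
groups = build_activation_groups(
    names, [False, True, True, True], {"C1": ["C1b1", "C1b2"]}
)
```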

Raises:
  • TypeError – If data or cluster_labels are invalid types, or if val_proportion is not a supported type.

  • ValueError – If any proportion is outside [0, 1], or if cluster coverage is incomplete, or sequence lengths do not match number of clusters.

Note

  • Normalization uses column-wise mean and standard deviation computed from observed values after validation masking.

  • Zero standard deviations are replaced with 1 to avoid division by zero.

  • Feature types are resolved into activation_groups and used throughout training, loss computation, and imputation.
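The normalization described in this note can be sketched in NumPy. The function name normalize_observed is hypothetical, and the use of the population standard deviation (NumPy's default ddof=0) is an assumption; the package may use a different estimator.

```python
import numpy as np

def normalize_observed(train, replacement_value=0.0):
    """Hypothetical helper: z-score observed entries, then fill the NaNs."""
    train = np.asarray(train, dtype=float)
    mask = ~np.isnan(train)              # True where an entry is observed
    mean = np.nanmean(train, axis=0)     # statistics ignore NaN entries
    std = np.nanstd(train, axis=0)       # assumed: population std (ddof=0)
    std = np.where(std == 0, 1.0, std)   # guard against division by zero
    normalized = (train - mean) / std
    normalized = np.where(mask, normalized, replacement_value)
    return normalized, mask, mean, std

train = np.array([[1.0, 4.0], [3.0, np.nan], [np.nan, 4.0]])
data, masks, mean, std = normalize_observed(train)
```

In this example the second column is constant, so its zero standard deviation is replaced by 1 and its observed entries normalize to 0.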

get_activation_groups(exclude_ignored=False)

Return activation groups, optionally excluding ignored columns.

Parameters:

exclude_ignored (bool) – If True, removes columns listed in columns_ignore.

Returns:

Activation groups. When exclude_ignored is True, columns listed in columns_ignore are removed from every group.

Return type:

dict
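A minimal sketch of the filtering step, assuming the ignored columns are already available as integer indices and that groups left empty are kept rather than dropped (neither point is confirmed by this page; filter_activation_groups is a hypothetical name):

```python
def filter_activation_groups(activation_groups, ignored_indices):
    """Hypothetical helper: drop ignored column indices from every group."""
    ignored = set(ignored_indices)
    return {
        name: [i for i in cols if i not in ignored]
        for name, cols in activation_groups.items()
    }

all_groups = {"continuous": [0, 1], "binary": [2], "C1": [3, 4]}
filtered = filter_activation_groups(all_groups, ignored_indices=[0, 2])
```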

copy()

Creates a deep copy of the ClusterDataset instance, including all of its attributes.

Returns:

Deep copy of the dataset

Return type:

ClusterDataset

ciss_vae.classes.vae module

Variational Autoencoder with cluster‑aware shared/unshared layers.

This module defines CISSVAE, a VAE that can route samples through either shared or cluster‑specific (unshared) layers in the encoder and decoder, controlled by per‑layer directives. The model outputs raw logits for all features. Feature-specific activation functions (e.g., sigmoid for binary, softmax for categorical) are applied externally, using activation_groups to map each feature to the correct activation function.
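To make the external activation step concrete, the following NumPy sketch applies a sigmoid to the "binary" group and a softmax within each categorical group, leaving continuous columns untouched. It is a stand-in written for this page, not the package's code; apply_activations is a hypothetical name.

```python
import numpy as np

def apply_activations(logits, activation_groups):
    """Hypothetical helper: map raw logits to per-feature outputs."""
    logits = np.asarray(logits, dtype=float)
    out = logits.copy()  # continuous columns pass through unchanged
    for name, cols in activation_groups.items():
        if not cols or name == "continuous":
            continue
        block = logits[:, cols]
        if name == "binary":
            out[:, cols] = 1.0 / (1.0 + np.exp(-block))  # element-wise sigmoid
        else:
            # Softmax across the dummy columns of one categorical variable.
            shifted = block - block.max(axis=1, keepdims=True)
            exp = np.exp(shifted)
            out[:, cols] = exp / exp.sum(axis=1, keepdims=True)
    return out

logits = np.array([[0.5, 0.0, 2.0, 2.0]])
probs = apply_activations(logits, {"continuous": [0], "binary": [1], "C1": [2, 3]})
```

Subtracting the row maximum before exponentiating does not change the softmax result and avoids overflow for large logits.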

class CISSVAE(*args: Any, **kwargs: Any)

Bases: Module

Clustering-Informed Shared-Structure Variational Autoencoder (CISSVAE).

Supports flexible mixtures of shared and unshared layers across clusters in both encoder and decoder. Unshared layers are applied by cluster, shared layers are applied to all samples.

Parameters:
  • input_dim (int) – Number of input features (columns).

  • hidden_dims (list[int]) – Width of each hidden layer (encoder goes forward, decoder uses the reverse).

  • layer_order_enc (list[str]) – Per‑encoder‑layer directive: "shared" or "unshared".

  • layer_order_dec (list[str]) – Per‑decoder‑layer directive: "shared" or "unshared".

  • latent_shared (bool) – If True, the latent heads (mu, logvar) are shared across clusters; otherwise one head per cluster.

  • latent_dim (int) – Dimensionality of the latent space.

  • output_shared (bool) – If True, the final output layer is shared; otherwise one output layer per cluster.

  • num_clusters (int) – Number of clusters present in the data.

  • debug (bool) – If True, prints routing shapes and asserts row order invariants.

Raises:

ValueError – If an item of layer_order_enc or layer_order_dec is not one of {"shared", "unshared", "s", "u"} (case‑insensitive), or if their lengths do not match len(hidden_dims) for the respective path.

Expected shapes
  • Encoder input x: (batch, input_dim)

  • Cluster labels cluster_labels: (batch,) (LongTensor with values in [0, num_clusters-1])

  • Decoder/Output: (batch, input_dim)

Notes
  • The decoder consumes hidden_dims[::-1] (reverse order).

  • Unshared layers maintain per‑cluster ModuleList/ModuleDict replicas.

  • Routing never reorders rows; masks are used to apply cluster‑specific sublayers in‑place.
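The directive check described under Raises could look roughly like the sketch below. The helper normalize_layer_order is hypothetical; only the accepted spellings and the length rule come from the text above.

```python
def normalize_layer_order(layer_order, n_layers):
    """Hypothetical helper: validate directives and expand "s"/"u" aliases."""
    aliases = {"shared": "shared", "s": "shared", "unshared": "unshared", "u": "unshared"}
    if len(layer_order) != n_layers:
        raise ValueError(
            f"expected {n_layers} directives, got {len(layer_order)}"
        )
    normalized = []
    for item in layer_order:
        key = str(item).lower()  # matching is case-insensitive
        if key not in aliases:
            raise ValueError(f"invalid layer directive: {item!r}")
        normalized.append(aliases[key])
    return normalized

hidden_dims = [64, 32, 16]
enc_order = normalize_layer_order(["Shared", "u", "UNSHARED"], len(hidden_dims))
```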

route_through_layers(x, cluster_labels, layer_type_list, shared_layers, unshared_layers)

Apply a sequence of shared/unshared layers according to layer_type_list.

For each position i:

  • if layer_type_list[i] is shared → apply shared_layers[i_shared] to all rows;

  • if unshared → for each cluster c, apply the c‑specific layer at that depth to the subset of rows where cluster_labels == c.

Parameters:
  • x (torch.Tensor, shape (batch, d_in)) – Input activations to be routed.

  • cluster_labels (torch.LongTensor, shape (batch,)) – Cluster id per row.

  • layer_type_list (list[str]) – Sequence of "shared"/"unshared" flags (length = number of layers at this stage).

  • shared_layers (torch.nn.ModuleList) – Layers used when the directive is shared (index increases only when a shared layer is consumed).

  • unshared_layers (dict[str, torch.nn.ModuleList] | torch.nn.ModuleDict) – Per‑cluster lists of layers for unshared directives (index per cluster increases only when an unshared layer is consumed).

Returns:

Routed activations.

Return type:

torch.Tensor

Raises:

ValueError – If an entry in layer_type_list is invalid or if per‑cluster unshared stacks are inconsistent with the directives.
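The routing rule can be illustrated without PyTorch by treating layers as plain callables. The sketch below is a NumPy stand-in under stated assumptions: route_sketch is a hypothetical name, unshared layers are keyed by the cluster id converted to a string, and every cluster's layer at a given depth produces the same output width.

```python
import numpy as np

def route_sketch(x, cluster_labels, layer_type_list, shared_layers, unshared_layers):
    """Illustrative routing: separate counters for shared and unshared layers."""
    i_shared = i_unshared = 0
    for kind in layer_type_list:
        if kind == "shared":
            x = shared_layers[i_shared](x)  # applied to all rows
            i_shared += 1
        elif kind == "unshared":
            out = None
            for c in np.unique(cluster_labels):
                rows = cluster_labels == c  # boolean mask, row order preserved
                y = unshared_layers[str(int(c))][i_unshared](x[rows])
                if out is None:
                    out = np.empty((x.shape[0], y.shape[1]), dtype=y.dtype)
                out[rows] = y               # write results back in place
            x = out
            i_unshared += 1
        else:
            raise ValueError(f"invalid directive: {kind!r}")
    return x

shared = [lambda h: h + 1.0]
unshared = {"0": [lambda h: h * 10.0], "1": [lambda h: h * 100.0]}
x = np.array([[1.0], [2.0], [3.0]])
labels = np.array([0, 1, 0])
routed = route_sketch(x, labels, ["shared", "unshared"], shared, unshared)
```

Rows 0 and 2 (cluster 0) and row 1 (cluster 1) pass through different unshared layers, yet the output keeps the original row order because results are written back through the same masks.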

encode(x, cluster_labels)

Encoder forward pass producing mu and logvar.

Parameters:
  • x (torch.Tensor, shape (batch, input_dim)) – Input batch.

  • cluster_labels (torch.LongTensor, shape (batch,)) – Cluster id per row.

Returns:

Tuple (mu, logvar).

Return type:

tuple[torch.Tensor, torch.Tensor]

reparameterize(mu, logvar, generator=None)

Reparameterization trick: z = mu + eps * exp(0.5 * logvar).

Parameters:
  • mu (torch.Tensor, shape (batch, latent_dim)) – Mean of the approximate posterior.

  • logvar (torch.Tensor, shape (batch, latent_dim)) – Log‑variance of the approximate posterior.

  • generator (torch.Generator) – Optional generator for RNG control (default: None).

Returns:

Sampled latent codes z.

Return type:

torch.Tensor
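In the formula above, eps is drawn from a standard normal distribution and exp(0.5 * logvar) is the standard deviation. A NumPy sketch of the same computation follows; it stands in for the PyTorch method, and the rng argument (a numpy.random.Generator) plays the role of the generator parameter by assumption.

```python
import numpy as np

def reparameterize_sketch(mu, logvar, rng=None):
    """Illustrative NumPy version of z = mu + eps * exp(0.5 * logvar)."""
    rng = np.random.default_rng() if rng is None else rng
    eps = rng.standard_normal(mu.shape)   # eps ~ N(0, I)
    return mu + eps * np.exp(0.5 * logvar)

mu = np.array([[0.0, 1.0]])
logvar = np.array([[0.0, -60.0]])         # second dimension has near-zero variance
z1 = reparameterize_sketch(mu, logvar, np.random.default_rng(0))
z2 = reparameterize_sketch(mu, logvar, np.random.default_rng(0))
```

Passing generators seeded identically yields identical samples, which is the kind of RNG control the generator argument is meant to provide.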

decode(z, cluster_labels)

Decoder forward pass from latent z to reconstruction.

Parameters:
  • z (torch.Tensor, shape (batch, latent_dim)) – Latent codes.

  • cluster_labels (torch.LongTensor, shape (batch,)) – Cluster id per row.

Returns:

Reconstructed inputs.

Return type:

torch.Tensor, shape (batch, input_dim)

forward(x, cluster_labels, deterministic=False, *, generator=None)

Full VAE forward pass: encode → reparameterize → decode.

Parameters:
  • x (torch.Tensor, shape (batch, input_dim)) – Input batch.

  • cluster_labels (torch.LongTensor, shape (batch,)) – Cluster id per row.

  • deterministic (bool) – If True, the model is evaluated deterministically, as used for imputation (default: False).

  • generator (torch.Generator) – Optional generator for RNG control (default: None).

Returns:

Tuple (recon, mu, logvar).

Return type:

tuple[torch.Tensor, torch.Tensor, torch.Tensor]

set_final_lr(final_lr)

Stores the final learning rate from the initial training loop as a model attribute so that the refit loop can access it.

get_final_lr()

Returns the learning rate stored with self.set_final_lr().

get_imputed_valdata(dataset, device='cpu', deterministic=True)

Compute imputed values for the validation dataset.

Performs a forward pass through the trained VAE and reconstructs only the validation-masked entries. Outputs are interpreted according to activation_groups:

  • continuous → denormalized using mean/std

  • binary → sigmoid applied

  • categorical → softmax + argmax (one-hot reconstruction)

Parameters:
  • dataset (ClusterDataset) –

    Dataset object containing:
    • data : normalized tensor (N, D)

    • val_data : validation-held-out values, with NaN at all other positions (N, D)

    • cluster_labels : (N,)

    • feature_means : (D,)

    • feature_stds : (D,)

    • feature_names : list[str]

    • activation_groups : dict[str, list[int]]

  • device (str, default="cpu") – Device to run computations on.

  • deterministic (bool, default=True) – Whether to use deterministic forward pass.

Returns:

Tensor of shape (N, D) containing imputed validation values. Non-validation entries are set to NaN.

Return type:

torch.Tensor
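The post-processing described for this method can be approximated in NumPy. The sketch below starts from an already computed reconstruction and is not the package's code: imputed_validation_values is a hypothetical name, and the assumption is that validation positions are exactly the non-NaN entries of val_data.

```python
import numpy as np

def imputed_validation_values(recon, val_data, feature_means, feature_stds, activation_groups):
    """Hypothetical helper: interpret logits, keep validation positions only."""
    recon = np.asarray(recon, dtype=float)
    means = np.asarray(feature_means, dtype=float)
    stds = np.asarray(feature_stds, dtype=float)
    out = recon.copy()
    for name, cols in activation_groups.items():
        if not cols:
            continue
        block = recon[:, cols]
        if name == "continuous":
            out[:, cols] = block * stds[cols] + means[cols]  # undo z-scoring
        elif name == "binary":
            out[:, cols] = 1.0 / (1.0 + np.exp(-block))      # sigmoid
        else:
            # Softmax is monotone, so the argmax of the logits is the
            # argmax of the softmax probabilities.
            winners = block.argmax(axis=1)
            one_hot = np.zeros_like(block)
            one_hot[np.arange(block.shape[0]), winners] = 1.0
            out[:, cols] = one_hot
    # Assumed: validation positions are the non-NaN entries of val_data.
    return np.where(np.isnan(val_data), np.nan, out)

recon = np.array([[1.0, 0.0, 3.0, 1.0]])
val_data = np.array([[12.0, np.nan, 1.0, 0.0]])
imputed = imputed_validation_values(
    recon, val_data, [10.0, 0.0, 0.0, 0.0], [2.0, 1.0, 1.0, 1.0],
    {"continuous": [0], "binary": [1], "C1": [2, 3]},
)
```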

set_activation_groups(activation_groups)

Attach feature activation structure to the model.

This method stores the resolved activation_groups dictionary, which defines how each input column should be interpreted during loss computation and imputation.

The model itself does NOT use this information during the forward pass; it is stored purely as metadata to ensure consistency between the model, dataset, loss functions, and imputation routines.

Parameters:

activation_groups (dict) –

Dictionary mapping feature types to column indices. Expected format:

{
    "continuous": [col_idx, ...],
    "binary": [col_idx, ...],
    "<categorical_name>": [col_idx, ...],
    ...
}

  • "continuous": indices of continuous-valued features

  • "binary": indices of binary features

  • Each additional key represents a grouped categorical variable (multiple columns corresponding to one variable)

Raises:

ValueError – If activation_groups is not a dictionary or contains invalid column indices.

Return type:

None

Notes

  • This replaces the old set_binary_features functionality.

  • The model outputs raw logits for all features; interpretation is handled externally using activation_groups.

  • This function is primarily used after loading a model to reattach dataset structure if needed.
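The checks listed under Raises might be implemented along the lines of the sketch below. The function validate_activation_groups and its input_dim argument are hypothetical, and the definition of an invalid index (not an integer, or outside [0, input_dim)) is an assumption.

```python
def validate_activation_groups(activation_groups, input_dim):
    """Hypothetical check mirroring the documented ValueError conditions."""
    if not isinstance(activation_groups, dict):
        raise ValueError("activation_groups must be a dict")
    for name, cols in activation_groups.items():
        for i in cols:
            # bool is a subclass of int, so reject it explicitly.
            is_index = isinstance(i, int) and not isinstance(i, bool)
            if not is_index or not 0 <= i < input_dim:
                raise ValueError(f"invalid column index {i!r} in group {name!r}")
    return activation_groups

checked = validate_activation_groups(
    {"continuous": [0], "binary": [1], "C1": [2, 3]}, input_dim=4
)
```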

Module contents

CISSVAE and ClusterDataset are also exposed at the package level. Their documentation is identical to the entries under ciss_vae.classes.vae and ciss_vae.classes.cluster_dataset above.