ciss_vae.classes package
Submodules
ciss_vae.classes.cluster_dataset module
Dataset utilities for clustering-aware masking and normalization.
This module defines ClusterDataset, a PyTorch torch.utils.data.Dataset
that (1) optionally holds out a validation subset of observed entries on a
per-cluster basis, (2) normalizes features using statistics computed on the
masked training matrix, and (3) exposes tensors required by the CISS-VAE
training loops: normalized data with missing values filled, cluster labels,
and binary observation masks.
Typical usage:
    ds = ClusterDataset(
        data=df,                    # (N, P) with NaNs for missing
        cluster_labels=clusters,    # length-N array-like
        val_proportion=0.1,         # or per-cluster mapping/sequence
        replacement_value=0.0,
        columns_ignore=["id"],      # columns to exclude from validation masking
    )
- class ClusterDataset(*args: Any, **kwargs: Any)
Bases: Dataset
Dataset that handles cluster-wise masking and normalization for VAE training.
1. Optionally holds out a validation subset per cluster from observed (non-NaN) entries according to val_proportion.
2. Combines original missingness with validation-held-out entries.
3. Normalizes observed values column-wise (mean/std), keeps masks for NaNs, and replaces NaNs (including held-out values) with replacement_value.
- Parameters:
data (pandas.DataFrame | numpy.ndarray | torch.Tensor) – Input matrix of shape (n_samples, n_features). May contain NaNs.
cluster_labels (array-like or None) – Cluster assignment per sample (length n_samples). If None, all rows are assigned to a single cluster 0.
val_proportion (float | collections.abc.Sequence | collections.abc.Mapping | pandas.Series) – Per-cluster fraction of non-missing entries to hold out for validation. Accepted forms:
    * float in [0, 1]: same fraction for all clusters
    * sequence (length = number of clusters): aligned to sorted(unique(cluster_labels))
    * mapping (e.g. {cluster_id: fraction}) covering all clusters
    * pandas.Series indexed by cluster IDs covering all clusters
replacement_value (float) – Value used to fill missing and held-out entries after masking.
columns_ignore (list[str | int] or None) – Columns to exclude from validation masking. Use column names for a DataFrame and integer indices otherwise.
imputable (pandas.DataFrame | numpy.ndarray | torch.Tensor) – Matrix marking which entries are eligible for imputation (1 = impute, 0 = exclude). Must have the same shape as data.
binary_feature_mask (list[bool] | numpy.ndarray) – Boolean vector of length n_features indicating binary columns. Used to construct activation_groups. Categorical dummy columns must also be marked as True.
categorical_column_map (dict[str, list[str | int]] or None) – Optional mapping from original categorical variable names to their corresponding dummy-variable columns. Example:
    {"C1": ["C1b1", "C1b2"], "C2": ["C2b1", "C2b2"]}
These columns are grouped together in activation_groups and treated as categorical variables. All listed columns must also be marked as True in binary_feature_mask.
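The accepted forms of val_proportion can be illustrated with a small sketch. The helper name resolve_val_proportion is hypothetical and is not part of ciss_vae; the sketch only mirrors the rules listed above (range check, cluster coverage, sequence alignment to the sorted cluster IDs) and leaves out the pandas.Series form to stay dependency-free.

```python
from collections.abc import Mapping, Sequence


def resolve_val_proportion(val_proportion, cluster_labels):
    """Expand val_proportion into {cluster_id: fraction} (hypothetical helper)."""
    clusters = sorted(set(cluster_labels))
    if isinstance(val_proportion, bool):
        raise TypeError("val_proportion must not be a bool")
    if isinstance(val_proportion, (int, float)):
        # A scalar applies the same fraction to every cluster.
        resolved = {c: float(val_proportion) for c in clusters}
    elif isinstance(val_proportion, Mapping):
        missing = [c for c in clusters if c not in val_proportion]
        if missing:
            raise ValueError(f"no proportion given for clusters {missing}")
        resolved = {c: float(val_proportion[c]) for c in clusters}
    elif isinstance(val_proportion, Sequence) and not isinstance(val_proportion, (str, bytes)):
        # A sequence is aligned to the sorted unique cluster IDs.
        if len(val_proportion) != len(clusters):
            raise ValueError("sequence length must equal the number of clusters")
        resolved = {c: float(p) for c, p in zip(clusters, val_proportion)}
    else:
        raise TypeError(f"unsupported val_proportion type: {type(val_proportion).__name__}")
    for c, p in resolved.items():
        if not 0.0 <= p <= 1.0:
            raise ValueError(f"proportion for cluster {c} is outside [0, 1]: {p}")
    return resolved
```

For labels [2, 0, 1, 0, 2], the sequence [0.1, 0.2, 0.3] is read as cluster 0 → 0.1, cluster 1 → 0.2, cluster 2 → 0.3, regardless of the order in which the labels first appear.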
- Variables:
raw_data (torch.FloatTensor) – Original data converted to float tensor (NaNs preserved).
data (torch.FloatTensor) – Normalized data with NaNs replaced by replacement_value.
masks (torch.BoolTensor) – Boolean mask where True indicates observed (non-NaN) entries before replacement.
val_data (torch.FloatTensor) – Tensor containing only validation-held-out values (others are NaN).
cluster_labels (torch.LongTensor) – Cluster ID for each row.
indices (torch.LongTensor) – Original row indices (from DataFrame index or arange for arrays/tensors).
feature_names (list[str]) – Column names (from DataFrame) or synthetic names (V1, V2, …).
n_clusters (int) – Number of unique clusters.
shape (tuple[int, int]) – Shape of self.data as (n_samples, n_features).
binary_feature_mask (numpy.ndarray) – Boolean mask indicating binary features.
activation_groups (dict) – Mapping of feature groups to column indices. Structure:
    { "continuous": [int, ...], "binary": [int, ...], "<categorical_name>": [int, ...], ... }
    * "continuous": indices of continuous-valued features
    * "binary": indices of binary features
    * Each additional key corresponds to a grouped categorical variable
This structure is used for loss computation, imputation, and validation logic.
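To make the activation_groups structure concrete, the following sketch derives it from feature_names, binary_feature_mask, and categorical_column_map. The function build_activation_groups is hypothetical, and one detail is an assumption rather than documented behaviour: dummy columns that belong to a categorical variable are listed only under that variable's key, not under "binary" as well.

```python
def build_activation_groups(feature_names, binary_feature_mask, categorical_column_map=None):
    """Group column indices by feature type (hypothetical helper, not the library code)."""
    index_of = {name: i for i, name in enumerate(feature_names)}
    groups = {"continuous": [], "binary": []}
    grouped = set()

    for var_name, columns in (categorical_column_map or {}).items():
        idx = [index_of[c] for c in columns]
        # Every dummy column must also be flagged in binary_feature_mask.
        unflagged = [feature_names[i] for i in idx if not binary_feature_mask[i]]
        if unflagged:
            raise ValueError(f"columns {unflagged} of {var_name!r} are not marked as binary")
        groups[var_name] = idx
        grouped.update(idx)

    for i, is_binary in enumerate(binary_feature_mask):
        if i in grouped:
            continue  # assumption: grouped dummies are not repeated under "binary"
        groups["binary" if is_binary else "continuous"].append(i)
    return groups


names = ["age", "sex", "C1b1", "C1b2"]
groups = build_activation_groups(names, [False, True, True, True], {"C1": ["C1b1", "C1b2"]})
```

Here groups holds one continuous column (age), one stand-alone binary column (sex), and one categorical variable C1 spanning two dummy columns.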
- Raises:
TypeError – If data or cluster_labels are invalid types, or if val_proportion is not a supported type.
ValueError – If any proportion is outside [0, 1], if cluster coverage is incomplete, or if sequence lengths do not match the number of clusters.
Note
    * Normalization uses column-wise mean and standard deviation computed from observed values after validation masking.
    * Zero standard deviations are replaced with 1 to avoid division by zero.
    * Feature types are resolved into activation_groups and used throughout training, loss computation, and imputation.
- get_activation_groups(exclude_ignored=False)
Return activation groups, optionally excluding ignored columns.
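The masking and normalization steps described for ClusterDataset can be sketched with NumPy. This is an illustration under stated assumptions, not the library's implementation: the function name mask_and_normalize is hypothetical, the validation mask is passed in rather than sampled per cluster, and the population standard deviation is used (the documentation does not say which estimator the library applies).

```python
import numpy as np


def mask_and_normalize(raw, val_mask, replacement_value=0.0):
    """Hide held-out entries, normalize on what remains, fill the gaps."""
    train = raw.astype(float).copy()
    train[val_mask] = np.nan              # held-out entries join the missing ones
    observed = ~np.isnan(train)           # True where a value is still visible

    count = np.maximum(observed.sum(axis=0), 1)   # avoid 0/0 for empty columns
    mean = np.where(observed, train, 0.0).sum(axis=0) / count
    var = np.where(observed, (train - mean) ** 2, 0.0).sum(axis=0) / count
    std = np.sqrt(var)
    std[std == 0] = 1.0                   # zero std replaced with 1

    normalized = (train - mean) / std
    data = np.where(observed, normalized, replacement_value)
    return data, observed, mean, std


raw = np.array([[1.0, 5.0], [3.0, 5.0], [np.nan, 5.0], [5.0, np.nan]])
val_mask = np.zeros(raw.shape, dtype=bool)
val_mask[3, 0] = True                     # hold out the 5.0 in the first column
data, masks, mean, std = mask_and_normalize(raw, val_mask)
```

Because the statistics are computed after the held-out value is hidden, the first column is centred on 2.0 (the mean of 1.0 and 3.0) rather than 3.0, and the constant second column keeps a divisor of 1.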
ciss_vae.classes.vae module
Variational Autoencoder with cluster‑aware shared/unshared layers.
This module defines CISSVAE, a VAE that can route samples through
either shared or cluster‑specific (unshared) layers in the encoder and
decoder, controlled by per‑layer directives. For all features, the model outputs raw logits.
Feature-specific activation functions (e.g., sigmoid for binary, softmax for categorical) are applied
externally, using activation_groups to map each feature to the correct activation function.
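Since the model emits raw logits, a downstream step has to apply the per-group activations. The sketch below shows one way to do that with NumPy; the function apply_activations is hypothetical, and leaving continuous columns untouched at this stage is an assumption of the sketch.

```python
import numpy as np


def apply_activations(logits, activation_groups):
    """Map raw outputs to probabilities, group by group (hypothetical helper)."""
    out = logits.astype(float).copy()
    for name, cols in activation_groups.items():
        if name == "continuous" or not cols:
            continue                       # continuous outputs are left as-is here
        block = logits[:, cols].astype(float)
        if name == "binary":
            out[:, cols] = 1.0 / (1.0 + np.exp(-block))            # element-wise sigmoid
        else:
            shifted = block - block.max(axis=1, keepdims=True)     # numerically stable softmax
            exp = np.exp(shifted)
            out[:, cols] = exp / exp.sum(axis=1, keepdims=True)
    return out


groups = {"continuous": [0], "binary": [1], "C1": [2, 3]}
logits = np.array([[0.5, 0.0, 1.0, 1.0],
                   [-1.0, 2.0, 0.0, np.log(3.0)]])
probs = apply_activations(logits, groups)
```

Each categorical group is normalized on its own, so the columns of C1 sum to 1 per row independently of the other features.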
- class CISSVAE(*args: Any, **kwargs: Any)
Bases: Module
Clustering-Informed Shared-Structure Variational Autoencoder (CISSVAE).
Supports flexible mixtures of shared and unshared layers across clusters in both encoder and decoder. Unshared layers are applied by cluster; shared layers are applied to all samples.
- Parameters:
input_dim (int) – Number of input features (columns).
hidden_dims (list[int]) – Width of each hidden layer (encoder goes forward, decoder uses the reverse).
layer_order_enc (list[str]) – Per‑encoder‑layer directive: "shared" or "unshared".
layer_order_dec (list[str]) – Per‑decoder‑layer directive: "shared" or "unshared".
latent_shared (bool) – If True, the latent heads (mu, logvar) are shared across clusters; otherwise one head per cluster.
latent_dim (int) – Dimensionality of the latent space.
output_shared (bool) – If True, the final output layer is shared; otherwise one output layer per cluster.
num_clusters (int) – Number of clusters present in the data.
debug (bool) – If True, prints routing shapes and asserts row order invariants.
- Raises:
ValueError – If an item of layer_order_enc or layer_order_dec is not one of {"shared", "unshared", "s", "u"} (case‑insensitive), or if their lengths do not match len(hidden_dims) for the respective path.
- Expected shapes
    * Encoder input x: (batch, input_dim)
    * Cluster labels cluster_labels: (batch,) (LongTensor with values in [0, num_clusters-1])
    * Decoder/Output: (batch, input_dim)
- Notes
    * The decoder consumes hidden_dims[::-1] (reverse order).
    * Unshared layers maintain per‑cluster ModuleList/ModuleDict replicas.
    * Routing never reorders rows; masks are used to apply cluster‑specific sublayers in‑place.
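The directive rules above (four accepted spellings, case-insensitive matching, one directive per hidden layer) can be expressed as a short check. The helper normalize_layer_order is hypothetical and only restates the documented constraints; how the library canonicalizes the directives internally is not specified here.

```python
_DIRECTIVES = {"shared": "shared", "s": "shared", "unshared": "unshared", "u": "unshared"}


def normalize_layer_order(layer_order, hidden_dims, path="encoder"):
    """Validate directives and return them in canonical form (hypothetical helper)."""
    if len(layer_order) != len(hidden_dims):
        raise ValueError(
            f"{path}: got {len(layer_order)} directives for {len(hidden_dims)} hidden layers"
        )
    canonical = []
    for item in layer_order:
        key = str(item).lower()            # matching is case-insensitive
        if key not in _DIRECTIVES:
            raise ValueError(f"{path}: invalid directive {item!r}")
        canonical.append(_DIRECTIVES[key])
    return canonical


hidden_dims = [64, 32, 16]
enc_order = normalize_layer_order(["S", "u", "Shared"], hidden_dims)
dec_dims = hidden_dims[::-1]               # the decoder walks the widths in reverse
```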
- route_through_layers(x, cluster_labels, layer_type_list, shared_layers, unshared_layers)
Apply a sequence of shared/unshared layers according to layer_type_list.
For each position i:
    * if layer_type_list[i] is shared → apply shared_layers[i_shared] to all rows;
    * if unshared → for each cluster c, apply the c‑specific layer at that depth to the subset of rows where cluster_labels == c.
- Parameters:
x (torch.Tensor, shape (batch, d_in)) – Input activations to be routed.
cluster_labels (torch.LongTensor, shape (batch,)) – Cluster id per row.
layer_type_list (list[str]) – Sequence of "shared"/"unshared" flags (length = number of layers at this stage).
shared_layers (torch.nn.ModuleList) – Layers used when the directive is shared (index increases only when a shared layer is consumed).
unshared_layers (dict[str, torch.nn.ModuleList] | torch.nn.ModuleDict) – Per‑cluster lists of layers for unshared directives (index per cluster increases only when an unshared layer is consumed).
- Returns:
Routed activations.
- Return type:
torch.Tensor
- Raises:
ValueError – If an entry in layer_type_list is invalid or if per‑cluster unshared stacks are inconsistent with the directives.
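The routing rule is easier to follow in a stripped-down form. The sketch below uses NumPy arrays and plain callables in place of tensors and torch.nn modules, so it is an analogy rather than the actual method; the name route_rows is hypothetical, string cluster keys follow the dict[str, ...] type given above, and a non-empty batch is assumed.

```python
import numpy as np


def route_rows(x, cluster_labels, layer_type_list, shared_layers, unshared_layers):
    """Apply shared layers to all rows and unshared layers per cluster, keeping row order."""
    i_shared = 0      # advances only when a shared layer is consumed
    i_unshared = 0    # advances only when an unshared layer is consumed
    for layer_type in layer_type_list:
        if layer_type == "shared":
            x = shared_layers[i_shared](x)
            i_shared += 1
        elif layer_type == "unshared":
            out = None
            for c in np.unique(cluster_labels):
                rows = cluster_labels == c             # boolean mask for cluster c
                y = unshared_layers[str(int(c))][i_unshared](x[rows])
                if out is None:
                    out = np.empty((x.shape[0], y.shape[1]), dtype=y.dtype)
                out[rows] = y                          # written back in original positions
            x = out
            i_unshared += 1
        else:
            raise ValueError(f"invalid layer type: {layer_type!r}")
    return x


shared = [lambda h: h * 2.0]
unshared = {"0": [lambda h: h + 1.0], "1": [lambda h: h - 1.0]}
x = np.array([[1.0], [2.0], [3.0]])
labels = np.array([1, 0, 1])
routed = route_rows(x, labels, ["unshared", "shared"], shared, unshared)
```

Rows 0 and 2 pass through the cluster-1 layer and row 1 through the cluster-0 layer; the results are scattered back by mask, so the output rows stay aligned with the input rows.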
- encode(x, cluster_labels)
Encoder forward pass producing mu and logvar.
- Parameters:
x (torch.Tensor, shape (batch, input_dim)) – Input batch.
cluster_labels (torch.LongTensor, shape (batch,)) – Cluster id per row.
- Returns:
Tuple (mu, logvar).
- Return type:
tuple[torch.Tensor, torch.Tensor]
- reparameterize(mu, logvar, generator=None)
Reparameterization trick: z = mu + eps * exp(0.5 * logvar).
- Parameters:
mu (torch.Tensor, shape (batch, latent_dim)) – Mean of the approximate posterior.
logvar (torch.Tensor, shape (batch, latent_dim)) – Log‑variance of the approximate posterior.
generator (torch.Generator) – Optional generator for RNG control (default None).
- Returns:
Sampled latent codes z.
- Return type:
torch.Tensor
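As a worked illustration of z = mu + eps * exp(0.5 * logvar), the sketch below uses NumPy, with numpy.random.Generator standing in for torch.Generator. It is an analogy to the method above, not the method itself, and the rng argument name is an assumption of the sketch.

```python
import numpy as np


def reparameterize(mu, logvar, rng=None):
    """Draw z ~ N(mu, exp(logvar)) as a deterministic function of mu, logvar, and noise."""
    rng = np.random.default_rng() if rng is None else rng
    eps = rng.standard_normal(mu.shape)    # eps ~ N(0, I)
    std = np.exp(0.5 * logvar)             # log-variance -> standard deviation
    return mu + eps * std


mu = np.zeros((4, 2))
logvar = np.full((4, 2), np.log(4.0))      # variance 4, so std 2
z_a = reparameterize(mu, logvar, rng=np.random.default_rng(0))
z_b = reparameterize(mu, logvar, rng=np.random.default_rng(0))
```

Passing generators seeded identically yields the same sample twice, which is the purpose of the optional generator argument.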
- decode(z, cluster_labels)
Decoder forward pass from latent z to reconstruction.
- Parameters:
z (torch.Tensor, shape (batch, latent_dim)) – Latent codes.
cluster_labels (torch.LongTensor, shape (batch,)) – Cluster id per row.
- Returns:
Reconstructed inputs.
- Return type:
torch.Tensor, shape (batch, input_dim)
- forward(x, cluster_labels, deterministic=False, *, generator=None)
Full VAE forward pass: encode → reparameterize → decode.
- Parameters:
x (torch.Tensor, shape (batch, input_dim)) – Input batch.
cluster_labels (torch.LongTensor, shape (batch,)) – Cluster id per row.
deterministic (bool) – If True, evaluate the model deterministically, as used for imputation (default False).
generator (torch.Generator) – Optional generator for RNG control (default None).
- Returns:
Tuple (recon, mu, logvar).
- Return type:
tuple[torch.Tensor, torch.Tensor, torch.Tensor]
- set_final_lr(final_lr)
Stores the final learning rate from the initial training loop as a model attribute so that the refit loop can access it.
- get_imputed_valdata(dataset, device='cpu', deterministic=True)
Compute imputed values for the validation dataset.
Performs a forward pass through the trained VAE and reconstructs only the validation-masked entries. Outputs are interpreted according to activation_groups:
    * continuous → denormalized using mean/std
    * binary → sigmoid applied
    * categorical → softmax + argmax (one-hot reconstruction)
- Parameters:
dataset (ClusterDataset) – Dataset object containing:
    * data : normalized tensor (N, D)
    * val_data : validation-held-out values, NaN elsewhere (N, D)
    * cluster_labels : (N,)
    * feature_means : (D,)
    * feature_stds : (D,)
    * feature_names : list[str]
    * activation_groups : dict[str, list[int]]
device (str, default="cpu") – Device to run computations on.
deterministic (bool, default=True) – Whether to use a deterministic forward pass.
- Returns:
Tensor of shape (N, D) containing imputed validation values. Non-validation entries are set to NaN.
- Return type:
torch.Tensor
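The post-processing for continuous columns can be sketched as follows: undo the normalization, then blank out every position that was not held out for validation. The sketch is written with NumPy, the function name keep_validation_entries is hypothetical, and it follows the ClusterDataset description of val_data (values only at held-out positions, NaN elsewhere).

```python
import numpy as np


def keep_validation_entries(recon, val_data, feature_means, feature_stds, continuous_cols):
    """Denormalize continuous outputs and keep only validation positions."""
    imputed = recon.astype(float).copy()
    # Invert (x - mean) / std for the continuous columns only.
    imputed[:, continuous_cols] = (
        recon[:, continuous_cols] * feature_stds[continuous_cols]
        + feature_means[continuous_cols]
    )
    imputed[np.isnan(val_data)] = np.nan   # non-validation entries become NaN
    return imputed


recon = np.array([[1.0, 0.2], [-1.0, 0.8]])          # model output, normalized scale
val_data = np.array([[np.nan, 1.0], [7.0, np.nan]])  # held-out truth, NaN elsewhere
means = np.array([10.0, 0.0])
stds = np.array([2.0, 1.0])
imputed = keep_validation_entries(recon, val_data, means, stds, continuous_cols=[0])
```

The result has the same NaN pattern as val_data, so the two arrays can be compared entry by entry when scoring imputation quality.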
- set_activation_groups(activation_groups)
Attach feature activation structure to the model.
This method stores the resolved activation_groups dictionary, which defines how each input column should be interpreted during loss computation and imputation.
The model itself does NOT use this information during the forward pass; it is stored purely as metadata to ensure consistency between the model, dataset, loss functions, and imputation routines.
- Parameters:
activation_groups (dict) – Dictionary mapping feature types to column indices. Expected format:
    { "continuous": [col_idx, …], "binary": [col_idx, …], "<categorical_name>": [col_idx, …], … }
    * "continuous": indices of continuous-valued features
    * "binary": indices of binary features
    * Each additional key represents a grouped categorical variable (multiple columns corresponding to one variable)
- Raises:
ValueError – If activation_groups is not a dictionary or contains invalid column indices.
Notes
    * This replaces the old set_binary_features functionality.
    * The model outputs raw logits for all features; interpretation is handled externally using activation_groups.
    * This function is primarily used after loading a model to reattach dataset structure if needed.
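The documented ValueError conditions suggest a check along the following lines. This is a guess at the shape of the validation, not the library's code: validate_activation_groups is a hypothetical name, and treating "invalid column indices" as "not an integer in [0, input_dim)" is an assumption.

```python
def validate_activation_groups(activation_groups, input_dim):
    """Reject non-dict input and out-of-range column indices (hypothetical helper)."""
    if not isinstance(activation_groups, dict):
        raise ValueError("activation_groups must be a dict")
    for name, cols in activation_groups.items():
        for col in cols:
            is_int = isinstance(col, int) and not isinstance(col, bool)
            if not is_int or not 0 <= col < input_dim:
                raise ValueError(f"group {name!r} has invalid column index {col!r}")
    return activation_groups


checked = validate_activation_groups(
    {"continuous": [0], "binary": [1], "C1": [2, 3]}, input_dim=4
)
```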
Module contents
CISSVAE and ClusterDataset are also available at the package level; their documentation is given in the ciss_vae.classes.vae and ciss_vae.classes.cluster_dataset sections above.