ciss_vae.classes.cluster_dataset.ClusterDataset
- class ClusterDataset(*args: Any, **kwargs: Any)
Bases: Dataset
Dataset that handles cluster-wise masking and normalization for VAE training.
1. Optionally holds out a validation subset per cluster from observed (non-NaN) entries according to val_proportion.
2. Combines original missingness with validation-held-out entries.
3. Normalizes observed values column-wise (mean/std), keeps masks for NaNs, and replaces NaNs (including held-out values) with replacement_value.
- Parameters:
data (pandas.DataFrame | numpy.ndarray | torch.Tensor) – Input matrix of shape (n_samples, n_features). May contain NaNs.
cluster_labels (array-like or None) – Cluster assignment per sample (length n_samples). If None, all rows are assigned to a single cluster 0.
val_proportion (float | collections.abc.Sequence | collections.abc.Mapping | pandas.Series) – Per-cluster fraction of non-missing entries to hold out for validation. Accepted forms:
  - float in [0, 1]: same fraction for all clusters
  - sequence (length = number of clusters): aligned to sorted(unique(cluster_labels))
  - mapping (e.g. {cluster_id: fraction}) covering all clusters
  - pandas.Series indexed by cluster IDs covering all clusters
replacement_value (float) – Value used to fill missing and held-out entries after masking.
columns_ignore (list[str | int] or None) – Columns to exclude from validation masking. Use column names for a DataFrame and column indices otherwise.
imputable (pandas.DataFrame | numpy.ndarray | torch.Tensor) – Matrix indicating which entries should be imputed (1 = impute, 0 = exclude from imputation). Must have the same shape as data.
binary_feature_mask (list[bool] | numpy.ndarray) – Boolean vector of length n_features indicating binary columns. Used to construct activation_groups. Categorical dummy columns must also be marked as True.
categorical_column_map (dict[str, list[str | int]] or None) – Optional mapping from original categorical variable names to their corresponding dummy-variable columns. Example:
{"C1": ["C1b1", "C1b2"], "C2": ["C2b1", "C2b2"]}
These columns are grouped together in activation_groups and treated as categorical variables. All listed columns must also be marked as True in binary_feature_mask.
- Variables:
raw_data (torch.FloatTensor) – Original data converted to a float tensor (NaNs preserved).
data (torch.FloatTensor) – Normalized data with NaNs replaced by replacement_value.
masks (torch.BoolTensor) – Boolean mask where True indicates observed (non-NaN) entries before replacement.
val_data (torch.FloatTensor) – Tensor containing only the validation-held-out values (all other entries are NaN).
cluster_labels (torch.LongTensor) – Cluster ID for each row.
indices (torch.LongTensor) – Original row indices (from the DataFrame index, or arange for arrays/tensors).
feature_names (list[str]) – Column names (from a DataFrame) or synthetic names (V1, V2, …).
n_clusters (int) – Number of unique clusters.
shape (tuple[int, int]) – Shape of self.data as (n_samples, n_features).
binary_feature_mask (numpy.ndarray) – Boolean mask indicating binary features.
activation_groups (dict) – Mapping of feature groups to column indices. Structure:
{"continuous": [int, ...], "binary": [int, ...], "<categorical_name>": [int, ...], ...}
  - "continuous": indices of continuous-valued features
  - "binary": indices of binary features
  - each additional key corresponds to a grouped categorical variable
This structure is used for loss computation, imputation, and validation logic.
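The activation_groups structure above can be derived from binary_feature_mask and categorical_column_map. The following is a minimal sketch, not the library's implementation: build_activation_groups is a hypothetical helper, and it assumes that dummy columns claimed by a categorical group are removed from the "binary" group, which the documentation does not state explicitly.

```python
from __future__ import annotations

from collections.abc import Mapping, Sequence


def build_activation_groups(
    feature_names: Sequence[str],
    binary_feature_mask: Sequence[bool],
    categorical_column_map: Mapping[str, Sequence[str | int]] | None = None,
) -> dict[str, list[int]]:
    """Hypothetical helper: group column indices by activation type."""
    if len(binary_feature_mask) != len(feature_names):
        raise ValueError("binary_feature_mask must have length n_features")
    name_to_idx = {name: i for i, name in enumerate(feature_names)}
    categorical: dict[str, list[int]] = {}
    claimed: set[int] = set()
    for cat_name, dummy_cols in (categorical_column_map or {}).items():
        # Dummy columns may be given by name or by position.
        idx = [c if isinstance(c, int) else name_to_idx[c] for c in dummy_cols]
        if not all(binary_feature_mask[i] for i in idx):
            raise ValueError(f"all dummy columns of {cat_name!r} must be marked binary")
        categorical[cat_name] = idx
        claimed.update(idx)
    continuous = [i for i, is_bin in enumerate(binary_feature_mask) if not is_bin]
    # Assumption: columns claimed by a categorical group leave the "binary" group.
    binary = [i for i, is_bin in enumerate(binary_feature_mask) if is_bin and i not in claimed]
    return {"continuous": continuous, "binary": binary, **categorical}


if __name__ == "__main__":
    names = ["age", "smoker", "C1b1", "C1b2"]
    print(build_activation_groups(names, [False, True, True, True], {"C1": ["C1b1", "C1b2"]}))
    # {'continuous': [0], 'binary': [1], 'C1': [2, 3]}
```

The check that every dummy column is flagged as binary mirrors the documented requirement that categorical columns must also be marked True in binary_feature_mask.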
- Raises:
TypeError – If data or cluster_labels are of an invalid type, or if val_proportion is not a supported type.
ValueError – If any proportion is outside [0, 1], if cluster coverage is incomplete, or if a sequence's length does not match the number of clusters.
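The accepted val_proportion forms and the TypeError/ValueError conditions above can be illustrated with a small resolver. This is a sketch under assumptions, not the library's code: the helper name resolve_val_proportion is hypothetical, a pandas.Series is detected by duck typing on its to_dict() method (so pandas is not imported), and the result is a plain dict from cluster ID to fraction.

```python
from __future__ import annotations

from collections.abc import Hashable, Mapping, Sequence
from typing import Any


def resolve_val_proportion(val_proportion: Any, cluster_labels: Sequence[Hashable]) -> dict[Hashable, float]:
    """Hypothetical helper: return {cluster_id: fraction} for every cluster."""
    clusters = sorted(set(cluster_labels))
    # A pandas.Series is not a Mapping; duck-type it via Series.to_dict().
    if not isinstance(val_proportion, Mapping) and hasattr(val_proportion, "to_dict"):
        val_proportion = val_proportion.to_dict()
    if isinstance(val_proportion, (int, float)):
        props = {c: float(val_proportion) for c in clusters}
    elif isinstance(val_proportion, Mapping):
        missing = [c for c in clusters if c not in val_proportion]
        if missing:
            raise ValueError(f"val_proportion does not cover clusters {missing}")
        props = {c: float(val_proportion[c]) for c in clusters}
    elif isinstance(val_proportion, Sequence) and not isinstance(val_proportion, str):
        if len(val_proportion) != len(clusters):
            raise ValueError(f"expected {len(clusters)} proportions, got {len(val_proportion)}")
        # Sequence entries are aligned to sorted(unique(cluster_labels)).
        props = {c: float(p) for c, p in zip(clusters, val_proportion)}
    else:
        raise TypeError(f"unsupported val_proportion type: {type(val_proportion).__name__}")
    out_of_range = {c: p for c, p in props.items() if not 0.0 <= p <= 1.0}
    if out_of_range:
        raise ValueError(f"proportions must lie in [0, 1], got {out_of_range}")
    return props
```

For example, resolve_val_proportion([0.1, 0.3], [5, 2, 5]) assigns 0.1 to cluster 2 and 0.3 to cluster 5, because sequence entries follow the sorted cluster IDs rather than order of appearance.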
Note
Normalization uses the column-wise mean and standard deviation computed from observed values after validation masking.
* Zero standard deviations are replaced with 1 to avoid division by zero.
* Feature types are resolved into activation_groups and used throughout training, loss computation, and imputation.
- __init__(data, cluster_labels, val_proportion=0.1, replacement_value=0, columns_ignore=None, imputable=None, val_seed=42, binary_feature_mask=None, categorical_column_map=None)
Build the dataset, apply per-cluster validation masking, and normalize.
Steps:
1. Convert inputs to tensors; preserve indices and column names if the input is a DataFrame.
2. Resolve per-cluster validation proportions from val_proportion.
3. For each cluster and feature, randomly mark the requested fraction of observed entries as validation targets.
4. Create val_data (validation targets only) and the training data, in which validation entries are set to NaN.
5. Compute per-feature mean/std over the non-NaN entries of the training data and normalize; then replace the remaining NaNs with replacement_value.
- Parameters:
data (pandas.DataFrame or numpy.ndarray or torch.Tensor) – Input matrix of shape (n_samples, n_features). May contain NaNs.
cluster_labels (array-like or None) – Cluster assignment per sample (length n_samples). If None, all rows are assigned to a single cluster 0.
val_proportion (float or collections.abc.Sequence or collections.abc.Mapping or pandas.Series, optional) – Per-cluster fraction of non-missing entries to hold out for validation, defaults to 0.1.
replacement_value (float, optional) – Value used to fill missing and held-out entries in self.data after masking, defaults to 0.
columns_ignore (list[str or int] or None, optional) – Columns to exclude from validation masking (names for a DataFrame, indices otherwise), defaults to None.
imputable (pandas.DataFrame or numpy.ndarray or torch.Tensor, optional) – Matrix indicating which entries to impute (1 = impute, 0 = exclude from imputation), shape (n_samples, n_features); must have the same shape as data. Defaults to None.
val_seed (int, optional) – Seed for the random number generator that selects the validation entries, defaults to 42.
binary_feature_mask (list[bool], optional) – 1D boolean vector of length input_dim (that is, n_features); True marks a binary column. Defaults to None.
categorical_column_map (dict, optional) – Dictionary whose keys are the original categorical variable names and whose values are the resulting dummy-variable columns. Requires binary_feature_mask to be set. Defaults to None.
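Steps 3 to 5 can be sketched with NumPy. This is an illustration under stated assumptions, not the library's implementation: mask_and_normalize is a hypothetical function that takes a plain float array, a per-cluster proportion dict (such as one produced by resolving val_proportion), and positional columns_ignore; the hold-out count is rounded with int(round(p * n_observed)) and the standard deviation uses ddof=0, neither of which the documentation specifies; and the imputable matrix is not modeled.

```python
from __future__ import annotations

from collections.abc import Iterable, Mapping, Sequence

import numpy as np


def mask_and_normalize(
    data: np.ndarray,
    cluster_labels: Sequence[int],
    props: Mapping[int, float],
    columns_ignore: Iterable[int] = (),
    replacement_value: float = 0.0,
    val_seed: int = 42,
):
    """Hypothetical sketch of steps 3-5: hold out validation entries, then z-score."""
    x = np.asarray(data, dtype=float)
    labels = np.asarray(cluster_labels)
    rng = np.random.default_rng(val_seed)
    ignored = set(columns_ignore)
    train = x.copy()
    val_data = np.full_like(x, np.nan)
    for cluster, p in props.items():
        rows = np.flatnonzero(labels == cluster)
        for j in range(x.shape[1]):
            if j in ignored:
                continue
            observed = rows[~np.isnan(x[rows, j])]
            # Assumption: the hold-out count is rounded to the nearest integer.
            n_val = int(round(p * observed.size))
            if n_val == 0:
                continue
            chosen = rng.choice(observed, size=n_val, replace=False)
            val_data[chosen, j] = x[chosen, j]  # keep the held-out targets
            train[chosen, j] = np.nan  # hide them from training
    # Mean/std use only entries still observed after validation masking
    # (population std, ddof=0, is an assumption).
    mean = np.nanmean(train, axis=0)
    std = np.nanstd(train, axis=0)
    std[std == 0] = 1.0  # constant columns: avoid division by zero
    masks = ~np.isnan(train)
    normalized = np.where(masks, (train - mean) / std, replacement_value)
    return normalized, masks, val_data
```

Sampling per (cluster, feature) pair keeps each cluster's validation fraction independent of cluster size, and computing statistics only after masking prevents held-out values from leaking into the normalization.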
Methods
__init__(data, cluster_labels[, ...]) – Build the dataset, apply per-cluster validation masking, and normalize.
copy() – Create a deep copy of the ClusterDataset object containing all attributes.
get_activation_groups([exclude_ignored]) – Return activation groups, optionally excluding ignored columns.
- get_activation_groups(exclude_ignored=False)
Return activation groups, optionally excluding ignored columns.
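One plausible reading of exclude_ignored=True is that the indices of the columns listed in columns_ignore are dropped from every group. The sketch below assumes exactly that; filter_activation_groups is a hypothetical helper, and whether the real method also renumbers indices or removes empty groups is not documented here.

```python
from __future__ import annotations

from collections.abc import Iterable, Mapping, Sequence


def filter_activation_groups(
    activation_groups: Mapping[str, Sequence[int]],
    ignored_indices: Iterable[int],
) -> dict[str, list[int]]:
    """Hypothetical sketch of exclude_ignored=True: drop ignored column indices."""
    ignored = set(ignored_indices)
    # Assumption: remaining indices keep their original column positions
    # (no renumbering), and groups that become empty are kept as empty lists.
    return {name: [i for i in idx if i not in ignored] for name, idx in activation_groups.items()}


if __name__ == "__main__":
    groups = {"continuous": [0, 4], "binary": [1], "C1": [2, 3]}
    print(filter_activation_groups(groups, [4]))
    # {'continuous': [0], 'binary': [1], 'C1': [2, 3]}
```

Returning a new dict leaves the stored activation_groups untouched, which matches the default exclude_ignored=False returning the full mapping.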