# Source code for ciss_vae.classes.cluster_dataset

"""
Dataset utilities for clustering-aware masking and normalization.

This module defines :class:`ClusterDataset`, a PyTorch :class:`torch.utils.data.Dataset`
that (1) optionally holds out a validation subset of *observed* entries on a
per-cluster basis, (2) normalizes features using statistics computed on the
masked training matrix, and (3) exposes tensors required by the CISS-VAE
training loops: normalized data with missing values filled, cluster labels,
and binary observation masks.

Typical usage::

    ds = ClusterDataset(
        data=df,                       # (N, P) with NaNs for missing
        cluster_labels=clusters,       # length-N array-like
        val_proportion=0.1,            # or per-cluster mapping/sequence
        replacement_value=0.0,
        columns_ignore=["id"]          # columns to exclude from validation masking
    )
"""
from torch.utils.data import Dataset
import torch
import numpy as np
import pandas as pd
import copy
from collections.abc import Mapping, Sequence

class ClusterDataset(Dataset):
    r"""
    Dataset that handles cluster-wise masking and normalization for VAE training.

    1. Optionally holds out a validation subset **per cluster** from *observed*
       (non-NaN) entries according to ``val_proportion``.
    2. Combines original missingness with validation-held-out entries.
    3. Normalizes observed non-binary values column-wise (mean/std), keeps masks
       for NaNs, and replaces NaNs (including held-out values) with
       ``replacement_value``.

    :param data: Input matrix of shape ``(n_samples, n_features)``. May contain NaNs.
    :type data: pandas.DataFrame | numpy.ndarray | torch.Tensor
    :param cluster_labels: Cluster assignment per sample (length ``n_samples``).
        If ``None``, all rows are assigned to a single cluster ``0``.
    :type cluster_labels: pandas.Series | numpy.ndarray | torch.Tensor | list | tuple | None
    :param val_proportion: Per-cluster fraction of **non-missing** entries to hold
        out for validation. Accepted forms:

        * float in ``[0, 1]``: same fraction for all clusters
        * sequence or 1-D numpy array (length = number of clusters): aligned to
          ``sorted(unique(cluster_labels))``
        * mapping (e.g. ``{cluster_id: fraction}``) covering all clusters
        * pandas.Series indexed by cluster IDs covering all clusters

    :type val_proportion: float | collections.abc.Sequence | numpy.ndarray | collections.abc.Mapping | pandas.Series
    :param replacement_value: Value used to fill missing and held-out entries after masking.
    :type replacement_value: float
    :param columns_ignore: Columns to exclude from validation masking. Use column
        names for DataFrame and indices otherwise. A single name/index is accepted.
    :type columns_ignore: list[str | int] | str | int | None
    :param imputable: Matrix indicating which entries may be imputed
        (1 = impute, 0 = exclude). Entries marked 0 are never selected for
        validation. Must have the same shape as ``data``.
    :type imputable: pandas.DataFrame | numpy.ndarray | torch.Tensor
    :param val_seed: Seed for the random generator that selects validation entries.
    :type val_seed: int
    :param binary_feature_mask: Boolean vector of length ``n_features`` indicating
        binary columns. Used to construct ``activation_groups``; binary columns are
        not normalized. Categorical dummy columns must also be marked as True.
    :type binary_feature_mask: list[bool] | numpy.ndarray
    :param categorical_column_map: Optional mapping from original categorical
        variable names to their corresponding dummy-variable columns. Example::

            {"C1": ["C1b1", "C1b2"], "C2": ["C2b1", "C2b2"]}

        These columns are grouped together in ``activation_groups`` and are held
        out for validation as a unit (all dummies of a row together). All listed
        columns must also be marked as True in ``binary_feature_mask``. The names
        ``"binary"`` and ``"continuous"`` are reserved.
    :type categorical_column_map: dict[str, list[str | int]] or None

    :ivar raw_data: Original data converted to float tensor (NaNs preserved).
    :vartype raw_data: torch.FloatTensor
    :ivar data: Normalized data with NaNs replaced by ``replacement_value``.
    :vartype data: torch.FloatTensor
    :ivar masks: Boolean mask where ``True`` indicates observed (non-NaN) entries
        before replacement (validation-held-out entries are ``False``).
    :vartype masks: torch.BoolTensor
    :ivar val_mask: Boolean mask of validation-held-out entries.
    :vartype val_mask: torch.BoolTensor
    :ivar val_data: Tensor containing only validation-held-out values (others are NaN).
    :vartype val_data: torch.FloatTensor
    :ivar cluster_labels: Cluster ID for each row.
    :vartype cluster_labels: torch.LongTensor
    :ivar unique_clusters: Sorted unique cluster IDs.
    :vartype unique_clusters: numpy.ndarray
    :ivar indices: Positional row indices ``0 .. n_samples - 1`` (the DataFrame
        index, if any, is not preserved).
    :vartype indices: torch.LongTensor
    :ivar feature_names: Column names (from DataFrame) or synthetic names
        (``V1``, ``V2``, ...).
    :vartype feature_names: list[str]
    :ivar ignore_indices: Column positions excluded from validation masking.
    :vartype ignore_indices: list[int]
    :ivar feature_means: Per-feature mean of the masked training matrix.
    :vartype feature_means: numpy.ndarray
    :ivar feature_stds: Per-feature std of the masked training matrix (zeros replaced by 1).
    :vartype feature_stds: numpy.ndarray
    :ivar n_clusters: Number of unique clusters.
    :vartype n_clusters: int
    :ivar shape: Shape of ``self.data`` as ``(n_samples, n_features)``.
    :vartype shape: tuple[int, int]
    :ivar binary_feature_mask: Boolean mask indicating binary features.
    :vartype binary_feature_mask: numpy.ndarray
    :ivar activation_groups: Mapping of feature groups to column indices::

            {
                "binary": [int, ...],
                "continuous": [int, ...],
                "<categorical_name>": [int, ...],
                ...
            }

        * ``"binary"``: binary features not belonging to a categorical group
        * ``"continuous"``: all remaining non-binary features
        * each additional key is a grouped categorical variable

        Every column appears in exactly one group (ignored columns included).
    :vartype activation_groups: dict
    :ivar validation_units: Units used for validation selection, keyed by feature
        name (single columns) or category name (grouped dummies). Ignored columns
        are excluded.
    :vartype validation_units: dict

    :raises TypeError: If ``data``, ``cluster_labels`` or ``imputable`` are of an
        unsupported type, if non-ignored DataFrame columns are non-numeric, or if
        ``val_proportion`` is not a supported type.
    :raises ValueError: If any proportion is outside ``[0, 1]``, if cluster
        coverage is incomplete, if sequence lengths do not match the number of
        clusters, if ``cluster_labels`` / ``imputable`` / ``binary_feature_mask``
        do not match ``data``, or if ``categorical_column_map`` is invalid.
    :raises RuntimeError: If ``categorical_column_map`` is given without
        ``binary_feature_mask``.

    .. note::
        * Normalization uses column-wise mean and standard deviation computed
          from observed values after validation masking.
        * Zero standard deviations are replaced with 1 to avoid division by zero.
        * Feature types are resolved into ``activation_groups`` and used
          throughout training, loss computation, and imputation.
    """

    def __init__(
        self,
        data,
        cluster_labels,
        val_proportion=0.1,
        replacement_value=0,
        columns_ignore=None,
        imputable=None,
        val_seed=42,
        binary_feature_mask=None,
        categorical_column_map=None,
    ):
        """Build the dataset, apply per-cluster validation masking, and normalize.

        Steps:

        1. Convert inputs to a float32 matrix; keep column names if a DataFrame.
        2. Resolve per-cluster validation proportions from ``val_proportion``.
        3. For each cluster and validation unit (a single feature or a categorical
           group), randomly mark the requested fraction of **observed**,
           imputable entries as validation targets.
        4. Create ``val_data`` (validation targets only) and training ``data``
           where validation entries are set to NaN.
        5. Compute per-feature mean/std over non-NaN entries in ``data``,
           normalize non-binary features, then replace remaining NaNs with
           ``replacement_value``.

        :param data: Input matrix, shape ``(n_samples, n_features)``. May contain NaNs.
        :param cluster_labels: Cluster per sample, or ``None`` for a single cluster ``0``.
        :param val_proportion: Per-cluster fraction of observed entries to hold out, default 0.1.
        :param replacement_value: Fill value for missing/held-out entries, default 0.
        :param columns_ignore: Columns never used for validation (names for DataFrame,
            indices otherwise), default None.
        :param imputable: Optional 0/1 matrix; entries marked 0 are never held out.
        :param val_seed: Seed for validation selection, default 42.
        :param binary_feature_mask: Length-``n_features`` bool vector marking binary columns.
        :param categorical_column_map: ``{category: [dummy columns]}``; requires
            ``binary_feature_mask``.
        """
        # ------------------------------------------------------------
        # Basic options
        # ------------------------------------------------------------
        # Dedicated generator: validation selection is reproducible and does not
        # depend on (or disturb) the global numpy RNG state.
        self.val_seed = val_seed
        self._rng = np.random.default_rng(self.val_seed)

        # Columns listed here are never selected for validation.
        if columns_ignore is None:
            self.columns_ignore = []
        elif isinstance(columns_ignore, (str, int, np.integer)):
            # A single name/index; list("id") would split it into characters.
            self.columns_ignore = [columns_ignore]
        elif hasattr(columns_ignore, "tolist"):
            # pandas Index / Series / numpy array
            self.columns_ignore = columns_ignore.tolist()
        else:
            self.columns_ignore = list(columns_ignore)

        self.binary_feature_mask = (
            None if binary_feature_mask is None else np.array(binary_feature_mask)
        )
        if categorical_column_map is not None and binary_feature_mask is None:
            raise RuntimeError("binary_feature_mask required to use categorical_column_map")

        # ------------------------------------------------------------
        # Convert input data to a float32 numpy matrix
        # ------------------------------------------------------------
        if hasattr(data, "iloc"):  # pandas DataFrame
            n_rows = data.shape[0]
            # Positional indices: safe for any index dtype.
            self.indices = torch.arange(n_rows, dtype=torch.long)
            self.feature_names = list(data.columns)
            self.ignore_indices = [
                i for i, col in enumerate(self.feature_names) if col in self.columns_ignore
            ]

            converted_cols = []
            bad_cols = []
            for j, col in enumerate(self.feature_names):
                # iloc always yields a Series, even with duplicate column names.
                s = data.iloc[:, j]
                if j in self.ignore_indices:
                    # Ignored columns keep the matrix shape but are never used;
                    # non-numeric ones become all-NaN float columns.
                    if pd.api.types.is_numeric_dtype(s):
                        converted_cols.append(s.astype("float32"))
                    else:
                        converted_cols.append(
                            pd.Series(np.nan, index=s.index, dtype="float32")
                        )
                else:
                    # Must be numeric: flag values that become NaN only via coercion
                    # (genuine NaNs are allowed).
                    sc = pd.to_numeric(s, errors="coerce")
                    if ((~s.isna()) & sc.isna()).any():
                        bad_cols.append(col)
                    converted_cols.append(sc.astype("float32"))
            if bad_cols:
                raise TypeError(
                    "Non-numeric values found in columns not listed in columns_ignore: "
                    f"{bad_cols}. Convert them to numeric or add them to `columns_ignore`."
                )
            raw_data_np = np.column_stack(
                [c.to_numpy(dtype=np.float32) for c in converted_cols]
            )

        elif isinstance(data, np.ndarray):
            self.indices = torch.arange(data.shape[0], dtype=torch.long)
            self.feature_names = [f"V{i+1}" for i in range(data.shape[1])]
            if not np.issubdtype(data.dtype, np.number):
                raise TypeError(
                    "ndarray input must be numeric. For mixed types, pass a DataFrame "
                    "and use columns_ignore."
                )
            raw_data_np = data.astype(np.float32, copy=False)
            # For arrays, columns_ignore is interpreted as column positions.
            self.ignore_indices = list(self.columns_ignore)

        elif isinstance(data, torch.Tensor):
            self.indices = torch.arange(data.shape[0], dtype=torch.long)
            self.feature_names = [f"V{i+1}" for i in range(data.shape[1])]
            # detach() permits tensors that require grad; casting in torch first
            # also covers dtypes numpy cannot represent (e.g. bfloat16).
            raw_data_np = data.detach().cpu().to(torch.float32).numpy()
            self.ignore_indices = list(self.columns_ignore)

        else:
            raise TypeError("Unsupported data format. Must be DataFrame, ndarray, or Tensor.")

        self.raw_data = torch.tensor(raw_data_np, dtype=torch.float32)
        n_samples, n_features = raw_data_np.shape

        if self.binary_feature_mask is not None:
            if len(self.binary_feature_mask) != n_features:
                raise ValueError("binary_feature_mask must match number of features")

        # ------------------------------------------------------------
        # Optional 'imputable' matrix (0 = never hold out / impute)
        # ------------------------------------------------------------
        if imputable is not None:
            if hasattr(imputable, "iloc"):
                imputable_np = imputable.values.astype(np.float32)
            elif isinstance(imputable, np.ndarray):
                imputable_np = imputable.astype(np.float32)
            elif isinstance(imputable, torch.Tensor):
                imputable_np = imputable.detach().cpu().to(torch.float32).numpy()
            else:
                raise TypeError(
                    "Unsupported imputable matrix format. Must be DataFrame, ndarray, or Tensor."
                )
            self.imputable = torch.tensor(imputable_np, dtype=torch.int64)
            expected_shape = tuple(self.raw_data.shape)
            if self.imputable.shape != expected_shape:
                raise ValueError(
                    f"`imputable` shape {self.imputable.shape} does not match "
                    f"data shape {expected_shape}."
                )
            dni_np = self.imputable.cpu().numpy().astype(bool)
        else:
            self.imputable = None
            dni_np = None

        # ------------------------------------------------------------
        # Cluster labels
        # ------------------------------------------------------------
        if cluster_labels is None:
            cluster_labels_np = np.zeros(n_samples, dtype=np.int64)
        elif hasattr(cluster_labels, "iloc"):
            cluster_labels_np = np.asarray(cluster_labels.values)
        elif isinstance(cluster_labels, np.ndarray):
            cluster_labels_np = cluster_labels
        elif isinstance(cluster_labels, torch.Tensor):
            cluster_labels_np = cluster_labels.detach().cpu().numpy()
        elif isinstance(cluster_labels, (list, tuple, range)):
            cluster_labels_np = np.asarray(cluster_labels)
        else:
            raise TypeError(
                "Unsupported cluster_labels format. Must be Series, ndarray, Tensor, list, or tuple."
            )

        # A mismatch would otherwise mask the wrong rows or fail with IndexError later.
        if np.shape(cluster_labels_np)[:1] != (n_samples,):
            raise ValueError(
                f"cluster_labels must have one entry per row: got shape "
                f"{np.shape(cluster_labels_np)} for data with {n_samples} rows."
            )

        # Unique clusters are fixed once, in sorted order, so sequence-valued
        # val_proportion aligns deterministically.
        self.unique_clusters = np.sort(np.unique(cluster_labels_np))
        self.cluster_labels = torch.tensor(cluster_labels_np, dtype=torch.long)
        self.n_clusters = len(self.unique_clusters)

        # ------------------------------------------------------------
        # Resolve per-cluster validation proportion
        # ------------------------------------------------------------
        def _as_per_cluster_props(vp):
            """Normalize ``val_proportion`` to ``{cluster_id: proportion}``."""
            if isinstance(vp, np.ndarray):
                vp = vp.item() if vp.ndim == 0 else vp.tolist()

            # scalar -> broadcast
            if isinstance(vp, (int, float, np.integer, np.floating)):
                p = float(vp)
                if not (0 <= p <= 1):
                    raise ValueError("`val_proportion` scalar must be in [0, 1].")
                return {c: p for c in self.unique_clusters}

            # pandas Series with labeled index, or any Mapping
            if isinstance(vp, (pd.Series, Mapping)):
                kind = "Series" if isinstance(vp, pd.Series) else "mapping"
                mapping = {int(k): float(v) for k, v in vp.items()}
                missing = [c for c in self.unique_clusters if c not in mapping]
                if missing:
                    raise ValueError(f"`val_proportion` {kind} missing clusters: {missing}")
                return mapping

            # Sequence aligned to sorted unique clusters (strings are not proportions)
            if isinstance(vp, Sequence) and not isinstance(vp, (str, bytes)):
                vals = list(vp)
                if len(vals) != len(self.unique_clusters):
                    raise ValueError(
                        f"`val_proportion` sequence length ({len(vals)}) must equal "
                        f"number of clusters ({len(self.unique_clusters)})."
                    )
                return {c: float(v) for c, v in zip(self.unique_clusters, vals)}

            raise TypeError(
                "`val_proportion` must be float in [0,1], a sequence (len = #clusters), "
                "a pandas Series (index=cluster), or a mapping {cluster: proportion}."
            )

        per_cluster_prop = _as_per_cluster_props(val_proportion)
        for cid, p in per_cluster_prop.items():
            if not (0.0 <= p <= 1.0):
                raise ValueError(f"`val_proportion` for cluster {cid} must be in [0, 1]; got {p}.")

        # ------------------------------------------------------------
        # Resolve categorical groups with full validation
        # ------------------------------------------------------------
        self.categorical_column_map = categorical_column_map
        self.categorical_group_indices = {}

        def _resolve_col_to_index(col_id):
            """Convert a column name (str) or integer position to an integer index."""
            if isinstance(col_id, str):
                if col_id not in self.feature_names:
                    raise ValueError(
                        f"Column '{col_id}' from categorical_column_map not found in data."
                    )
                return self.feature_names.index(col_id)
            if isinstance(col_id, (int, np.integer)):
                col_id = int(col_id)
                if not (0 <= col_id < len(self.feature_names)):
                    raise ValueError(
                        f"Column index {col_id} from categorical_column_map is out of bounds."
                    )
                return col_id
            raise TypeError(
                "categorical_column_map entries must be column names (str) or integer indices."
            )

        if categorical_column_map is not None:
            if not isinstance(categorical_column_map, dict):
                raise TypeError(
                    "`categorical_column_map` must be a dictionary like "
                    "{'C1': ['C1b1','C1b2'], ...}"
                )
            used_cols = set()
            for cat_name, cat_cols in categorical_column_map.items():
                # These keys already name the built-in activation groups.
                if cat_name in ("binary", "continuous"):
                    raise ValueError(
                        f"Category name '{cat_name}' is reserved for activation_groups; "
                        "choose another name."
                    )
                if not isinstance(cat_cols, (list, tuple)):
                    raise TypeError(
                        f"Value for category '{cat_name}' must be a list/tuple of dummy columns."
                    )
                if len(cat_cols) == 0:
                    raise ValueError(f"Category '{cat_name}' cannot have empty dummy column list.")

                resolved = [_resolve_col_to_index(c) for c in cat_cols]

                if len(set(resolved)) != len(resolved):
                    raise ValueError(f"Duplicate dummy columns found within category '{cat_name}'.")

                overlap = used_cols.intersection(resolved)
                if overlap:
                    overlap_names = [self.feature_names[i] for i in sorted(overlap)]
                    raise ValueError(
                        f"Dummy columns {overlap_names} appear in more than one category."
                    )

                for idx in resolved:
                    if self.binary_feature_mask is None or not self.binary_feature_mask[idx]:
                        name = self.feature_names[idx]
                        raise ValueError(
                            f"Dummy column '{name}' is listed in categorical_column_map but is "
                            f"not marked True in binary_feature_mask. '{name}' must be binary "
                            "and marked True in binary_feature_mask."
                        )

                # Overlap with columns_ignore is allowed here; it is resolved when
                # building validation units below.
                self.categorical_group_indices[cat_name] = resolved
                used_cols.update(resolved)

        # ------------------------------------------------------------
        # Build activation_groups (INCLUDES ignored columns)
        # ------------------------------------------------------------
        grouped_dummy_cols = set()
        for g in self.categorical_group_indices.values():
            grouped_dummy_cols.update(g)

        def _is_binary(i):
            """True when column ``i`` is marked binary in ``binary_feature_mask``."""
            return self.binary_feature_mask is not None and bool(self.binary_feature_mask[i])

        self.activation_groups = {"binary": [], "continuous": []}
        for i in range(len(self.feature_names)):
            if i in grouped_dummy_cols:
                continue
            self.activation_groups["binary" if _is_binary(i) else "continuous"].append(i)
        for k, cols in self.categorical_group_indices.items():
            self.activation_groups[k] = list(cols)

        # Ensure ALL columns are covered exactly once.
        all_grouped_cols = set()
        for cols in self.activation_groups.values():
            all_grouped_cols.update(cols)
        expected_cols = set(range(len(self.feature_names)))
        missing = expected_cols - all_grouped_cols
        extra = all_grouped_cols - expected_cols
        if missing:
            raise RuntimeError(
                f"The following columns are missing from activation_groups: "
                f"{[self.feature_names[i] for i in sorted(missing)]}"
            )
        if extra:
            raise RuntimeError(
                f"Invalid column indices found in activation_groups: {sorted(extra)}"
            )
        seen = set()
        for cols in self.activation_groups.values():
            overlap = seen.intersection(cols)
            if overlap:
                raise RuntimeError(f"Columns appear in multiple activation groups: {overlap}")
            seen.update(cols)

        # ------------------------------------------------------------
        # Build validation_units (EXCLUDES ignored columns)
        # ------------------------------------------------------------
        self.validation_units = {}
        for i, name in enumerate(self.feature_names):
            if i in self.ignore_indices or i in grouped_dummy_cols:
                continue
            kind = "binary" if _is_binary(i) else "continuous"
            self.validation_units[name] = {"kind": kind, "cols": [i]}

        ignore_set = set(self.ignore_indices)
        for k, cols in self.categorical_group_indices.items():
            cols_set = set(cols)
            ignored_in_group = cols_set.intersection(ignore_set)
            if not ignored_in_group:
                self.validation_units[k] = {"kind": "categorical", "cols": list(cols)}
            elif len(ignored_in_group) == len(cols_set):
                continue  # whole category ignored -> skip entirely
            else:
                bad_cols = [self.feature_names[i] for i in sorted(ignored_in_group)]
                raise ValueError(
                    f"Categorical group '{k}' has partially ignored columns: {bad_cols}. "
                    "Either ignore ALL dummy columns for this category or NONE."
                )

        # ------------------------------------------------------------
        # Validation masking
        # ------------------------------------------------------------
        val_mask_np = np.zeros_like(raw_data_np, dtype=bool)
        for cid in self.unique_clusters:
            rows = np.where(cluster_labels_np == cid)[0]
            if rows.size == 0:
                continue
            prop = per_cluster_prop[cid]
            cluster_data = raw_data_np[rows]

            for info in self.validation_units.values():
                # Single features are one-column groups; categorical units hold out
                # all dummy columns of a chosen row together.
                group = np.array(info["cols"])
                # Eligible rows: every column of the unit observed ...
                valid = np.all(~np.isnan(cluster_data[:, group]), axis=1)
                if dni_np is not None:
                    # ... and every column of the unit imputable.
                    valid &= dni_np[rows][:, group].all(axis=1)
                idxs = np.flatnonzero(valid)
                if idxs.size == 0:
                    continue
                # Hold out at least one entry whenever a positive proportion is requested.
                n_val = int(idxs.size * prop)
                if n_val == 0 and prop > 0:
                    n_val = 1
                chosen = self._rng.choice(idxs, size=n_val, replace=False)
                val_mask_np[np.ix_(rows[chosen], group)] = True

        self.val_mask = torch.tensor(val_mask_np, dtype=torch.bool)

        # ------------------------------------------------------------
        # Set aside val_data (ignored columns are never validation units,
        # so they are already NaN here)
        # ------------------------------------------------------------
        self.val_data = self.raw_data.clone()
        self.val_data[~self.val_mask] = torch.nan

        # ------------------------------------------------------------
        # Combine true + validation-masked missingness
        # ------------------------------------------------------------
        self.data = self.raw_data.clone()
        self.data[self.val_mask] = torch.nan

        # ------------------------------------------------------------
        # Normalize using statistics of observed training entries
        # ------------------------------------------------------------
        data_np = self.data.numpy()
        self.feature_means = np.nan_to_num(np.nanmean(data_np, axis=0), nan=0.0)
        self.feature_stds = np.nan_to_num(np.nanstd(data_np, axis=0), nan=1.0)

        zero_std_idx = np.where(self.feature_stds == 0)[0]
        if zero_std_idx.size > 0:
            bad_feats = [self.feature_names[i] for i in zero_std_idx]
            print(
                f"[Warning] {len(zero_std_idx)} feature(s) had zero std after masking. "
                f"Replaced with 1.0 to avoid div-by-zero. "
                f"Features: {bad_feats}"
            )
        self.feature_stds[self.feature_stds == 0] = 1.0  # avoid division by zero

        norm_data_np = (data_np - self.feature_means) / self.feature_stds
        if self.binary_feature_mask is not None:
            # Binary columns keep their raw 0/1 values (NaNs stay NaN).
            norm_data_np = np.where(
                self.binary_feature_mask.astype(bool), data_np, norm_data_np
            )
        self.data = torch.tensor(norm_data_np, dtype=torch.float32)

        # ------------------------------------------------------------
        # Track missingness & replace with value
        # ------------------------------------------------------------
        self.masks = ~torch.isnan(self.data)  # True where value is observed
        self.data = torch.where(
            self.masks, self.data, torch.tensor(replacement_value, dtype=torch.float32)
        )
        self.shape = self.data.shape

    def get_activation_groups(self, exclude_ignored: bool = False):
        """
        Return activation groups, optionally excluding ignored columns.

        Parameters
        ----------
        exclude_ignored : bool
            If True, removes columns listed in ``ignore_indices`` and drops groups
            that become empty.

        Returns
        -------
        dict
            ``self.activation_groups`` itself when ``exclude_ignored`` is False,
            otherwise a new filtered dict.
        """
        if not exclude_ignored:
            return self.activation_groups

        ignore_set = set(self.ignore_indices)
        filtered = {}
        for name, cols in self.activation_groups.items():
            kept = [c for c in cols if c not in ignore_set]
            if kept:  # only keep non-empty groups
                filtered[name] = kept
        return filtered

    def __len__(self):
        """
        Number of samples in the dataset.

        :return: ``N`` (number of rows).
        """
        return len(self.data)

    def __getitem__(self, index):
        """
        Get a single sample.

        :param index: Row index.
        :return: Tuple ``(x, cluster_id, mask, row_index)`` where:

            * **x** – normalized input row with NaNs replaced (``(P,)``).
            * **cluster_id** – integer cluster label.
            * **mask** – boolean mask of observed entries before replacement (``(P,)``).
            * **row_index** – positional row index (``self.indices[index]``).
        """
        return (
            self.data[index],  # input with missing replaced
            self.cluster_labels[index],  # cluster label
            self.masks[index],  # observation mask
            self.indices[index],  # positional row index
        )

    def __repr__(self):
        """Summarize samples, features, clusters, original missingness and the
        share of non-missing entries held out for validation.

        Missingness statistics are computed over non-ignored columns only.

        :return: String representation of the dataset
        :rtype: str
        """
        n, p = self.data.shape
        ignored = list(getattr(self, "ignore_indices", []))
        kept_cols = [j for j in range(p) if j not in ignored]
        total_values = n * len(kept_cols)

        # Percent originally missing (before validation mask), non-ignored columns only:
        # ignored non-numeric DataFrame columns are all-NaN and would distort this.
        original_missing = int(torch.isnan(self.raw_data[:, kept_cols]).sum().item())
        original_missing_pct = 100 * original_missing / total_values if total_values else 0.0

        # Percent used for validation (out of non-missing entries)
        val_entries = int(torch.sum(~torch.isnan(self.val_data)).item())
        n_observed = total_values - original_missing
        val_pct_of_nonmissing = 100 * val_entries / n_observed if n_observed else 0.0

        # Count non-imputable entries (where imputable == 0)
        non_imputable_count = None
        if getattr(self, "imputable", None) is not None:
            non_imputable_count = int((self.imputable == 0).sum().item())

        out = (
            f"ClusterDataset(n_samples={n}, n_features={p}, n_clusters={len(torch.unique(self.cluster_labels))})\n"
            f" • Original missing: {original_missing} / {total_values} "
            f"({original_missing_pct:.2f}%)\n"
            f" • Validation held-out: {val_entries} "
            f"({val_pct_of_nonmissing:.2f}% of non-missing)\n"
            f" • .data shape: {tuple(self.data.shape)}\n"
            f" • .masks shape: {tuple(self.masks.shape)}\n"
            f" • .val_data shape: {tuple(self.val_data.shape)}"
        )
        if non_imputable_count is not None:
            out += f"\n • Non-imputable entries: {non_imputable_count}"
        if hasattr(self, "validation_units"):
            unit_summary = {
                k: {"kind": v["kind"], "cols": [self.feature_names[i] for i in v["cols"]]}
                for k, v in self.validation_units.items()
            }
            out += f"\n • Validation units: {unit_summary}"
        return out

    def copy(self):
        """Create a deep copy of the dataset including all attributes.

        :return: Deep copy of the dataset
        :rtype: ClusterDataset
        """
        return copy.deepcopy(self)

    def __str__(self):
        """Same summary as :meth:`__repr__`.

        :return: String representation of the dataset
        :rtype: str
        """
        return self.__repr__()