CISS-VAE Quickstart

The Clustering-Informed Shared-Structure Variational Autoencoder (CISS-VAE) is a flexible deep learning model for missing data imputation that accommodates all three missing data mechanisms: Missing Completely At Random (MCAR), Missing At Random (MAR), and Missing Not At Random (MNAR). It is particularly well suited to MNAR scenarios, where missingness patterns carry informative signals, and it also functions effectively under MAR assumptions. Please see our publication for more details.
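To make the three mechanisms concrete, the following is a minimal simulation sketch. It does not use CISS-VAE; the variables `x` and `y`, the 30% rate, and the logistic missingness model are illustrative assumptions, not taken from the package or its example dataset.

```python
import numpy as np

rng = np.random.default_rng(0)
n = 10_000

# Two correlated variables: x is always observed, y will receive missing values.
x = rng.normal(size=n)
y = 0.8 * x + rng.normal(scale=0.6, size=n)


def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))


# MCAR: every entry of y has the same 30% chance of being missing.
mcar_mask = rng.random(n) < 0.3

# MAR: the chance that y is missing depends only on the observed x.
mar_mask = rng.random(n) < sigmoid(2.0 * x - 1.0)

# MNAR: the chance that y is missing depends on the value of y itself.
mnar_mask = rng.random(n) < sigmoid(2.0 * y - 1.0)

for name, mask in [("MCAR", mcar_mask), ("MAR", mar_mask), ("MNAR", mnar_mask)]:
    observed_mean = y[~mask].mean()
    print(f"{name}: {mask.mean():.1%} missing, mean of observed y = {observed_mean:+.3f}")
```

In this toy setup the mean of the observed `y` stays close to the true mean under MCAR but is pulled downward under MAR and MNAR. Under MAR the shift can in principle be corrected using `x`, because missingness depends only on observed values; under MNAR it depends on the very values that are missing.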

Example CISS-VAE for Imputation Workflow

Installation

The CISS-VAE package is currently available for Python, with an R package to be released soon. It can be installed from either PyPI or GitHub.

# From PyPI 
pip install ciss-vae
# From GitHub (latest development version)
pip install git+https://github.com/CISS-VAE/CISS-VAE-python.git

Note

If you want run_cissvae to handle clustering, install the clustering dependencies (scikit-learn, leidenalg, and python-igraph) with pip:

pip install scikit-learn leidenalg python-igraph

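Alternatively, install them through the clustering extra (the quotes keep shells such as zsh from interpreting the square brackets):

pip install "ciss-vae[clustering]"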
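If you are unsure whether the optional clustering dependencies are present in your environment, a quick check along the following lines may help. This is a standalone sketch using only the standard library; the helper name `missing_packages` is hypothetical and not part of the ciss-vae package.

```python
from importlib.util import find_spec

# Import names differ from the pip package names:
# scikit-learn -> sklearn, python-igraph -> igraph.
CLUSTERING_MODULES = {
    "scikit-learn": "sklearn",
    "leidenalg": "leidenalg",
    "python-igraph": "igraph",
}


def missing_packages(modules=CLUSTERING_MODULES):
    """Return the pip package names whose import module cannot be found."""
    return [pkg for pkg, module in modules.items() if find_spec(module) is None]


missing = missing_packages()
if missing:
    print("Missing optional dependencies:", ", ".join(missing))
    print("Install with: pip install " + " ".join(missing))
else:
    print("All clustering dependencies are importable.")
```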

Quickstart

If your data has binary or categorical variables, please see the binary data vignette. To load the sample dataset:

from ciss_vae.data import load_example_dataset
df_missing, df_complete, clusters = load_example_dataset()
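Before imputing, it can be useful to look at how much is missing in each column. The sketch below uses only pandas; `summarize_missingness` is a hypothetical helper, and the small `toy` frame (including its `age` column) is a made-up stand-in so the snippet runs without the example dataset.

```python
import numpy as np
import pandas as pd


def summarize_missingness(df: pd.DataFrame) -> pd.DataFrame:
    """Count and fraction of missing values per column, most-missing first."""
    summary = pd.DataFrame({
        "n_missing": df.isna().sum(),
        "frac_missing": df.isna().mean(),
    })
    return summary.sort_values("frac_missing", ascending=False)


# Toy stand-in for df_missing so the snippet runs on its own.
toy = pd.DataFrame({
    "age": [34.0, 51.0, 29.0, 62.0],
    "Y11": [1.2, np.nan, 0.7, np.nan],
    "Y12": [np.nan, 2.1, 1.9, 2.4],
})
print(summarize_missingness(toy))
```

With the example dataset loaded, the same call would be `summarize_missingness(df_missing)`.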

To run the CISS-VAE imputation model with default parameters (assuming known clusters):

import pandas as pd
from ciss_vae.training.run_cissvae import run_cissvae
# optional, display vae architecture
from ciss_vae.utils.helpers import plot_vae_architecture

imputed_data, vae = run_cissvae(
    data = df_missing,
    columns_ignore = df_missing.columns[:5], ## columns to ignore when selecting validation dataset (and clustering if you do not provide clusters).
    clusters = clusters
)

## OPTIONAL - PLOT VAE ARCHITECTURE
plot_vae_architecture(model = vae, title = None)

Output:
Cluster dataset:
 ClusterDataset(n_samples=8000, n_features=30, n_clusters=4)
  • Original missing: 61800 / 200000 (30.90%)
  • Validation held-out: 13783 (9.97% of non-missing)
  • .data shape:     (8000, 30)
  • .masks shape:    (8000, 30)
  • .val_data shape: (8000, 30)
  • Validation units: {'Y11': {'kind': 'continuous', 'cols': ['Y11']}, 'Y12': {'kind': 'continuous', 'cols': ['Y12']}, 'Y13': {'kind': 'continuous', 'cols': ['Y13']}, 'Y14': {'kind': 'continuous', 'cols': ['Y14']}, 'Y15': {'kind': 'continuous', 'cols': ['Y15']}, 'Y21': {'kind': 'continuous', 'cols': ['Y21']}, 'Y22': {'kind': 'continuous', 'cols': ['Y22']}, 'Y23': {'kind': 'continuous', 'cols': ['Y23']}, 'Y24': {'kind': 'continuous', 'cols': ['Y24']}, 'Y25': {'kind': 'continuous', 'cols': ['Y25']}, 'Y31': {'kind': 'continuous', 'cols': ['Y31']}, 'Y32': {'kind': 'continuous', 'cols': ['Y32']}, 'Y33': {'kind': 'continuous', 'cols': ['Y33']}, 'Y34': {'kind': 'continuous', 'cols': ['Y34']}, 'Y35': {'kind': 'continuous', 'cols': ['Y35']}, 'Y41': {'kind': 'continuous', 'cols': ['Y41']}, 'Y42': {'kind': 'continuous', 'cols': ['Y42']}, 'Y43': {'kind': 'continuous', 'cols': ['Y43']}, 'Y44': {'kind': 'continuous', 'cols': ['Y44']}, 'Y45': {'kind': 'continuous', 'cols': ['Y45']}, 'Y51': {'kind': 'continuous', 'cols': ['Y51']}, 'Y52': {'kind': 'continuous', 'cols': ['Y52']}, 'Y53': {'kind': 'continuous', 'cols': ['Y53']}, 'Y54': {'kind': 'continuous', 'cols': ['Y54']}, 'Y55': {'kind': 'continuous', 'cols': ['Y55']}}
[Figure: VAE architecture diagram produced by plot_vae_architecture]
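Because the example dataset ships with a complete version (df_complete), one way to sanity-check the result is to compare imputed values against the true values at the positions that were originally missing. The sketch below assumes that imputed_data is a pandas DataFrame with the same row order and column names as df_missing; that assumption, and the helper name `imputation_rmse`, are not confirmed by this page. The toy frames are made up so the snippet runs on its own.

```python
import numpy as np
import pandas as pd


def imputation_rmse(imputed: pd.DataFrame, complete: pd.DataFrame,
                    missing: pd.DataFrame) -> float:
    """RMSE over entries that are missing in `missing` but known in `complete`.

    Rows are assumed to be aligned by position across all three frames.
    """
    cols = [c for c in complete.columns
            if c in imputed.columns and c in missing.columns]
    mask = missing[cols].isna().to_numpy() & complete[cols].notna().to_numpy()
    diff = (imputed[cols].to_numpy(dtype=float)
            - complete[cols].to_numpy(dtype=float))
    return float(np.sqrt(np.mean(diff[mask] ** 2)))


# Toy stand-ins for df_complete, df_missing, and imputed_data.
complete = pd.DataFrame({"Y11": [1.0, 2.0, 3.0], "Y12": [4.0, 5.0, 6.0]})
missing = complete.copy()
missing.loc[1, "Y11"] = np.nan
missing.loc[2, "Y12"] = np.nan
imputed = missing.copy()
imputed.loc[1, "Y11"] = 2.5   # true value is 2.0
imputed.loc[2, "Y12"] = 5.0   # true value is 6.0

rmse = imputation_rmse(imputed, complete, missing)
print(f"RMSE on originally missing entries: {rmse:.4f}")
```

Under the stated assumption, the corresponding call for the quickstart would be `imputation_rmse(imputed_data, df_complete, df_missing)`.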

To have run_cissvae() perform data clustering with Leiden:

imputed_data, vae = run_cissvae(
    data = df_missing,
    columns_ignore = df_missing.columns[:5], ## columns to ignore when selecting validation dataset (and clustering if you do not provide clusters).
    clusters = None,
    k_neighbors = 500,
    leiden_resolution = 0.001,
    epochs = 100
)
## OPTIONAL - PLOT VAE ARCHITECTURE
plot_vae_architecture(model = vae, title = None)

Output:
Cluster dataset:
 ClusterDataset(n_samples=8000, n_features=30, n_clusters=2)
  • Original missing: 61800 / 200000 (30.90%)
  • Validation held-out: 13808 (9.99% of non-missing)
  • .data shape:     (8000, 30)
  • .masks shape:    (8000, 30)
  • .val_data shape: (8000, 30)
  • Validation units: {'Y11': {'kind': 'continuous', 'cols': ['Y11']}, 'Y12': {'kind': 'continuous', 'cols': ['Y12']}, 'Y13': {'kind': 'continuous', 'cols': ['Y13']}, 'Y14': {'kind': 'continuous', 'cols': ['Y14']}, 'Y15': {'kind': 'continuous', 'cols': ['Y15']}, 'Y21': {'kind': 'continuous', 'cols': ['Y21']}, 'Y22': {'kind': 'continuous', 'cols': ['Y22']}, 'Y23': {'kind': 'continuous', 'cols': ['Y23']}, 'Y24': {'kind': 'continuous', 'cols': ['Y24']}, 'Y25': {'kind': 'continuous', 'cols': ['Y25']}, 'Y31': {'kind': 'continuous', 'cols': ['Y31']}, 'Y32': {'kind': 'continuous', 'cols': ['Y32']}, 'Y33': {'kind': 'continuous', 'cols': ['Y33']}, 'Y34': {'kind': 'continuous', 'cols': ['Y34']}, 'Y35': {'kind': 'continuous', 'cols': ['Y35']}, 'Y41': {'kind': 'continuous', 'cols': ['Y41']}, 'Y42': {'kind': 'continuous', 'cols': ['Y42']}, 'Y43': {'kind': 'continuous', 'cols': ['Y43']}, 'Y44': {'kind': 'continuous', 'cols': ['Y44']}, 'Y45': {'kind': 'continuous', 'cols': ['Y45']}, 'Y51': {'kind': 'continuous', 'cols': ['Y51']}, 'Y52': {'kind': 'continuous', 'cols': ['Y52']}, 'Y53': {'kind': 'continuous', 'cols': ['Y53']}, 'Y54': {'kind': 'continuous', 'cols': ['Y54']}, 'Y55': {'kind': 'continuous', 'cols': ['Y55']}}
[Figure: VAE architecture diagram produced by plot_vae_architecture]
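For intuition about what a "missingness pattern" is, the sketch below labels each row by which of its columns are missing and counts how often each pattern occurs. This is not the Leiden procedure that run_cissvae applies: exact-match grouping is a much simpler stand-in, and the idea that clustering operates on missingness patterns is an assumption based on the introduction rather than something this page spells out. The helper `missingness_patterns` and the toy frame are hypothetical.

```python
import numpy as np
import pandas as pd


def missingness_patterns(df: pd.DataFrame, columns_ignore=()) -> pd.Series:
    """Label each row by its missingness pattern ('1' = missing, '0' = observed)."""
    ignore = set(columns_ignore)
    cols = [c for c in df.columns if c not in ignore]
    mask = df[cols].isna().astype(int).astype(str)
    # Join the per-column flags of each row into one string such as "01".
    return mask.apply("".join, axis=1)


# Toy stand-in for df_missing; "id" plays the role of an ignored column.
toy = pd.DataFrame({
    "id": [1, 2, 3, 4],
    "Y11": [1.0, np.nan, 0.7, np.nan],
    "Y12": [np.nan, 2.1, 1.9, 2.4],
})
patterns = missingness_patterns(toy, columns_ignore=["id"])
print(patterns.value_counts())
```

Rows sharing a label have exactly the same columns missing. A graph-based method such as Leiden can instead group rows whose patterns are merely similar, which is why parameters like k_neighbors and leiden_resolution affect the number of clusters found.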