ciss_vae.utils.helpers.plot_vae_architecture
- plot_vae_architecture(model, title=None, color_shared='skyblue', color_unshared='lightcoral', color_latent='gold', color_input='lightgreen', color_output='lightgreen', figsize=(16, 8), return_fig=False, fontsize_layer=12, fontsize_section=14, fontsize_title=16)
Plots a horizontal schematic of the VAE architecture, showing shared and cluster-specific layers.
- Parameters:
model (nn.Module) – An instance of the CISSVAE model to visualize
title (str, optional) – Title of the plot, defaults to None
color_shared (str, optional) – Color for shared hidden layers, defaults to “skyblue”
color_unshared (str, optional) – Color for unshared hidden layers, defaults to “lightcoral”
color_latent (str, optional) – Color for latent layer, defaults to “gold”
color_input (str, optional) – Color for input layer, defaults to “lightgreen”
color_output (str, optional) – Color for output layer, defaults to “lightgreen”
figsize (tuple, optional) – Size of the matplotlib figure, defaults to (16, 8)
return_fig (bool, optional) – Whether to return the figure object instead of displaying it, defaults to False
fontsize_layer (int, optional) – Font size of the layer block labels, defaults to 12
fontsize_section (int, optional) – Font size of the encoder/decoder labels, defaults to 14
fontsize_title (int, optional) – Font size of the title, defaults to 16
- Returns:
Matplotlib figure object if return_fig is True, otherwise None
- Return type:
matplotlib.figure.Figure or None
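How the real function extracts layer sizes from a CISSVAE model is not documented here, so the following is only a minimal, hypothetical sketch of the same kind of plot: one colored box per layer, laid out left to right from input through the latent layer to output. `sketch_vae_schematic` and its `input_dim`, `encoder_layers`, `latent_dim`, and `decoder_layers` arguments are illustrative assumptions, not part of `ciss_vae`; layers are described by an assumed list of `(width, is_shared)` tuples, with `is_shared=False` marking a cluster-specific layer. The color and font-size defaults mirror the documented ones.

```python
import matplotlib

matplotlib.use("Agg")  # headless backend so the sketch runs without a display
import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle


def sketch_vae_schematic(
    input_dim,
    encoder_layers,
    latent_dim,
    decoder_layers,
    title=None,
    color_shared="skyblue",
    color_unshared="lightcoral",
    color_latent="gold",
    color_input="lightgreen",
    color_output="lightgreen",
    figsize=(16, 8),
    fontsize_layer=12,
    fontsize_title=16,
):
    """Hypothetical sketch (not the ciss_vae API): one labelled box per layer.

    encoder_layers / decoder_layers: lists of (width, is_shared) tuples,
    where is_shared=False marks a cluster-specific (unshared) layer.
    """

    def hidden_block(width, shared):
        # Shared and cluster-specific hidden layers get different colors.
        kind = "shared" if shared else "per-cluster"
        return f"{width}\n{kind}", color_shared if shared else color_unshared

    blocks = [(f"Input\n{input_dim}", color_input)]
    blocks += [hidden_block(w, s) for w, s in encoder_layers]
    blocks.append((f"Latent\n{latent_dim}", color_latent))
    blocks += [hidden_block(w, s) for w, s in decoder_layers]
    blocks.append((f"Output\n{input_dim}", color_output))

    fig, ax = plt.subplots(figsize=figsize)
    box_w, box_h, gap = 1.0, 2.0, 0.5
    for i, (label, color) in enumerate(blocks):
        x = i * (box_w + gap)  # horizontal layout: block i starts at x
        ax.add_patch(
            Rectangle((x, 0), box_w, box_h, facecolor=color, edgecolor="black")
        )
        ax.text(x + box_w / 2, box_h / 2, label,
                ha="center", va="center", fontsize=fontsize_layer)
    ax.set_xlim(-gap, len(blocks) * (box_w + gap))
    ax.set_ylim(-0.5, box_h + 0.5)
    ax.axis("off")  # a schematic needs no axes
    if title is not None:
        ax.set_title(title, fontsize=fontsize_title)
    return fig


fig = sketch_vae_schematic(
    input_dim=20,
    encoder_layers=[(64, True), (32, False)],
    latent_dim=10,
    decoder_layers=[(32, False), (64, True)],
    title="Example CISS-VAE layout",
)
```

Returning the `Figure` rather than showing it plays the same role as `return_fig=True` in the documented function: the caller can adjust it further or save it with `fig.savefig(...)`.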