ciss_vae.utils.helpers.plot_vae_architecture

plot_vae_architecture(model, title=None, color_shared='skyblue', color_unshared='lightcoral', color_latent='gold', color_input='lightgreen', color_output='lightgreen', figsize=(16, 8), return_fig=False, fontsize_layer=12, fontsize_section=14, fontsize_title=16)

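Plots a horizontal schematic of a CISSVAE model's architecture: the input layer, the encoder and decoder hidden layers (distinguishing shared layers from cluster-specific, unshared ones), the latent layer, and the output layer.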
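To make the left-to-right layout concrete, here is a minimal, library-independent sketch of how such a schematic could be organised into columns. It assumes each hidden layer is tagged either shared (`"S"`) or unshared (`"U"`), and that an unshared layer is drawn as one block per cluster stacked in the same column. The `Block` class, the `layout_blocks` function, and this layer encoding are hypothetical illustrations, not the ciss_vae implementation; only the default colours are taken from the signature above.

```python
from dataclasses import dataclass
from typing import List, Sequence


@dataclass
class Block:
    column: int  # horizontal position, from input (left) to output (right)
    row: int     # vertical position in the column (cluster index for unshared layers)
    label: str
    color: str


# Default colours copied from the plot_vae_architecture signature.
COLORS = {
    "input": "lightgreen",
    "shared": "skyblue",
    "unshared": "lightcoral",
    "latent": "gold",
    "output": "lightgreen",
}


def layout_blocks(encoder: Sequence[str], decoder: Sequence[str], n_clusters: int) -> List[Block]:
    """Hypothetical layout: 'S' = one shared block, 'U' = one block per cluster."""
    blocks: List[Block] = []
    column = 0

    def add_column(kind: str, name: str) -> None:
        nonlocal column
        if kind == "U":
            # Cluster-specific layer: one block per cluster, stacked vertically.
            for c in range(n_clusters):
                blocks.append(Block(column, c, f"{name} (cluster {c})", COLORS["unshared"]))
        else:
            # Shared hidden layer, or a single input/latent/output block.
            color = COLORS["shared"] if kind == "S" else COLORS[kind]
            blocks.append(Block(column, 0, name, color))
        column += 1

    add_column("input", "Input")
    for i, kind in enumerate(encoder):
        add_column(kind, f"Enc {i}")
    add_column("latent", "Latent")
    for i, kind in enumerate(decoder):
        add_column(kind, f"Dec {i}")
    add_column("output", "Output")
    return blocks
```

For example, `layout_blocks(["S", "U"], ["U", "S"], n_clusters=3)` places the input at column 0, the latent layer at column 3, and the output at column 6, with three stacked `lightcoral` blocks in each unshared column. A plotting routine could then draw one rectangle per `Block`.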

Parameters:
  • model (nn.Module) – The CISSVAE model instance to visualize

  • title (str, optional) – Title of the plot, defaults to None

  • color_shared (str, optional) – Color for shared hidden layers, defaults to “skyblue”

  • color_unshared (str, optional) – Color for unshared hidden layers, defaults to “lightcoral”

  • color_latent (str, optional) – Color for latent layer, defaults to “gold”

  • color_input (str, optional) – Color for input layer, defaults to “lightgreen”

  • color_output (str, optional) – Color for output layer, defaults to “lightgreen”

  • figsize (tuple, optional) – Size of the matplotlib figure, defaults to (16, 8)

  • return_fig (bool, optional) – Whether to return the figure object instead of displaying it, defaults to False

  • fontsize_layer (int, optional) – Font size of the text inside layer blocks, defaults to 12

  • fontsize_section (int, optional) – Font size of encoder/decoder labels, defaults to 14

  • fontsize_title (int, optional) – Font size of the title, defaults to 16

Returns:

Matplotlib figure object if return_fig is True, otherwise None

Return type:

matplotlib.figure.Figure or None
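Because return_fig=True hands back the figure instead of displaying it, the plot can be written to disk with the standard Matplotlib `Figure.savefig` method. The sketch below wraps this in a small helper; `save_vae_architecture` and its `dpi` default are hypothetical conveniences, and only `plot_vae_architecture` and its keyword arguments come from this page. It assumes `model` is an existing CISSVAE instance.

```python
from pathlib import Path


def save_vae_architecture(model, path, dpi=150, **plot_kwargs):
    """Hypothetical helper: render a CISSVAE architecture diagram and save it to ``path``."""
    # Imported inside the function so the helper can be defined without ciss_vae installed.
    from ciss_vae.utils.helpers import plot_vae_architecture

    # Force return_fig=True so the Figure is returned rather than displayed;
    # overriding here also avoids a duplicate-keyword error if a caller passes it.
    plot_kwargs["return_fig"] = True
    fig = plot_vae_architecture(model, **plot_kwargs)

    out = Path(path)
    # Figure.savefig infers the file format from the extension (e.g. .png, .pdf).
    fig.savefig(out, dpi=dpi, bbox_inches="tight")
    return out
```

Usage, with styling arguments passed straight through: `save_vae_architecture(model, "cissvae_architecture.png", title="CISSVAE", figsize=(12, 6))`. When generating many diagrams in a loop, consider closing each figure afterwards (e.g. with `matplotlib.pyplot.close`) to free memory.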