r"""
Variational Autoencoder with cluster‑aware shared/unshared layers.
This module defines :class:`CISSVAE`, a VAE that can route samples through
either **shared** or **cluster‑specific** (unshared) layers in the encoder and
decoder, controlled by per‑layer directives. For all features, the model outputs raw logits. Feature-specific activation
functions (e.g., sigmoid for binary, softmax for categorical) are applied
externally using ``activation_groups`` to map features to correct activation function.
"""
import torch
import torch.nn as nn
from typing import Iterable, Optional, Sequence, Union
import numpy as np
[docs]
class CISSVAE(nn.Module):
r"""
Clustering-Informed Shared-Structure Variational Autoencoder (CISSVAE).
Supports flexible mixtures of **shared** and **unshared** layers across
clusters in both encoder and decoder. Unshared layers are applied by cluster, shared layers are applied to all samples.
:param input_dim: Number of input features (columns).
:type input_dim: int
:param hidden_dims: Width of each hidden layer (encoder goes forward, decoder uses the reverse).
:type hidden_dims: list[int]
:param layer_order_enc: Per‑encoder‑layer directive: ``"shared"`` or ``"unshared"``.
:type layer_order_enc: list[str]
:param layer_order_dec: Per‑decoder‑layer directive: ``"shared"`` or ``"unshared"``.
:type layer_order_dec: list[str]
:param latent_shared: If ``True``, the latent heads (``mu``, ``logvar``) are shared across clusters; otherwise one head per cluster.
:type latent_shared: bool
:param latent_dim: Dimensionality of the latent space.
:type latent_dim: int
:param output_shared: If ``True``, the final output layer is shared; otherwise one output layer per cluster.
:type output_shared: bool
:param num_clusters: Number of clusters present in the data.
:type num_clusters: int
:param debug: If ``True``, prints the activation-group metadata resolved at construction time (the routing shape and row-order checks are currently commented out).
:type debug: bool
:raises ValueError: If an item of ``layer_order_enc`` or ``layer_order_dec`` is not one of
``{"shared","unshared","s","u"}`` (case‑insensitive), or if their lengths do not match
``len(hidden_dims)`` for the respective path.
**Expected shapes**
* Encoder input ``x``: ``(batch, input_dim)``
* Cluster labels ``cluster_labels``: ``(batch,)`` (``LongTensor`` with values in ``[0, num_clusters-1]``)
* Decoder/Output: ``(batch, input_dim)``
**Notes**
* The decoder consumes ``hidden_dims[::-1]`` (reverse order).
* Unshared layers maintain per‑cluster ``ModuleList``/``ModuleDict`` replicas.
* Routing never reorders rows; masks are used to apply cluster‑specific sublayers in‑place.
"""
[docs]
def __init__(self,
input_dim,
hidden_dims,
layer_order_enc,
layer_order_dec,
latent_shared,
latent_dim,
output_shared,
num_clusters,
# new optional inputs to define binary features at init time -> udpdate 14OCT2025
activation_groups = None,
debug=False,):
"""
Variational Autoencoder supporting flexible shared/unshared layers across clusters.
:param input_dim: Number of input features.
:type input_dim: int
:param hidden_dims: Dimensions of hidden layers.
:type hidden_dims: list[int]
:param layer_order_enc: Layer type for each encoder layer (``"shared"`` or ``"unshared"``).
:type layer_order_enc: list[str]
:param layer_order_dec: Layer type for each decoder layer (``"shared"`` or ``"unshared"``).
:type layer_order_dec: list[str]
:param latent_shared: Whether latent representation is shared across clusters.
:type latent_shared: bool
:param latent_dim: Dimensionality of the latent space.
:type latent_dim: int
:param output_shared: Whether output layer is shared across clusters.
:type output_shared: bool
:param num_clusters: Number of clusters.
:type num_clusters: int
:param activation_groups: Dictionary mapping feature groups to column indices. Attribute of ClusterDataset object.
Expected format:
{
"continuous": [int, ...],
"binary": [int, ...],
"<categorical_name>": [int, ...],
...
}
Each key defines a feature group:
- "continuous": indices of continuous-valued features
- "binary": indices of binary features
- additional keys correspond to grouped categorical variables
(e.g., one-hot encoded columns belonging to the same variable)
This structure is used to determine loss functions and output
transformations outside the model.
:type activation_groups: dict[str, list[int]]
:param debug: If ``True``, print shape and routing information.
:type debug: bool
"""
super().__init__()
self.debug = debug
self.num_clusters = num_clusters
self.latent_shared = latent_shared
self.layer_order_enc = self._normalize_layer_order(layer_order_enc, "layer_order_enc")
self.layer_order_dec = self._normalize_layer_order(layer_order_dec, "layer_order_dec")
self.latent_dim = latent_dim
self.input_dim = input_dim
self.output_shared = output_shared
self.hidden_dims = hidden_dims
# -------------------------
# (NEW) Activation groups metadata
# - Replaces binary_feature_mask
# - Stores feature structure for downstream use (loss, imputation, validation)
# -------------------------
if activation_groups is not None:
if self.debug:
print(f"Activation Groups: {activation_groups}\n")
if not isinstance(activation_groups, dict):
raise ValueError("activation_groups must be a dictionary.")
# Validate and normalize
validated_groups = {}
for key, cols in activation_groups.items():
if not isinstance(cols, (list, tuple)):
raise ValueError(f"activation_groups['{key}'] must be a list of column indices.")
clean_cols = []
for c in cols:
if not isinstance(c, (int, np.integer)):
raise ValueError(f"Column index '{c}' in group '{key}' is not an integer.")
c = int(c)
if not (0 <= c < input_dim):
raise ValueError(
f"Column index {c} in group '{key}' is out of bounds for input_dim={input_dim}."
)
clean_cols.append(c)
validated_groups[key] = clean_cols
# Store as plain attribute (NOT a buffer)
self.activation_groups = validated_groups
else:
# Default: treat everything as continuous if nothing provided
if self.debug:
print("No activation_groups provided; defaulting to all continuous.\n")
self.activation_groups = {
"continuous": list(range(input_dim))
}
# ----------------------------
# Encoder: shared and unshared
# ----------------------------
self.encoder_layers = nn.ModuleList()
self.cluster_encoder_layers = nn.ModuleDict({
str(i): nn.ModuleList() for i in range(num_clusters)
})
in_features = input_dim
for idx, (out_features, layer_type) in enumerate(zip(hidden_dims, layer_order_enc)):
lt = layer_type.lower()
if lt in ["shared", "s"]:
# (A) no dtype/device kwargs → use PyTorch defaults
self.encoder_layers.append(
nn.Sequential(
nn.Linear(in_features, out_features),
nn.ReLU()
)
)
elif lt in ["unshared", "u"]:
for c in range(num_clusters):
self.cluster_encoder_layers[str(c)].append(
nn.Sequential(
nn.Linear(in_features, out_features),
nn.ReLU()
)
)
else:
raise ValueError(f"Invalid encoder layer type at index {idx}: {layer_type}")
in_features = out_features
# ----------------------------
# Latent layers
# ----------------------------
if latent_shared:
# These also use defaults
self.fc_mu = nn.Linear(in_features, latent_dim)
self.fc_logvar = nn.Linear(in_features, latent_dim)
else:
self.cluster_fc_mu = nn.ModuleDict({
str(i): nn.Linear(in_features, latent_dim)
for i in range(num_clusters)
})
self.cluster_fc_logvar = nn.ModuleDict({
str(i): nn.Linear(in_features, latent_dim)
for i in range(num_clusters)
})
# ----------------------------
# Decoder: shared and unshared
# ----------------------------
self.decoder_layers = nn.ModuleList()
self.cluster_decoder_layers = nn.ModuleDict({
str(i): nn.ModuleList() for i in range(num_clusters)
})
in_features = latent_dim
for idx, (out_features, layer_type) in enumerate(zip(hidden_dims[::-1], layer_order_dec)):
if layer_type == "shared":
self.decoder_layers.append(
nn.Sequential(
nn.Linear(in_features, out_features),
nn.ReLU()
)
)
elif layer_type == "unshared":
for c in range(num_clusters):
self.cluster_decoder_layers[str(c)].append(
nn.Sequential(
nn.Linear(in_features, out_features),
nn.ReLU()
)
)
else:
raise ValueError(f"Invalid decoder layer type at index {idx}: {layer_type}")
in_features = out_features
# ----------------------------
# Output Layer
# ----------------------------
if output_shared:
self.final_layer = nn.Linear(in_features, input_dim)
else:
self.cluster_final_layer = nn.ModuleDict({
str(i): nn.Linear(in_features, input_dim)
for i in range(num_clusters)
})
[docs]
def route_through_layers(self, x, cluster_labels,
                         layer_type_list,
                         shared_layers,
                         unshared_layers):
    r"""
    Apply a sequence of shared/unshared layers according to ``layer_type_list``.

    For each directive, in order:

    * shared → the next unused entry of ``shared_layers`` is applied to all rows;
    * unshared → for each cluster ``c`` in ``range(self.num_clusters)``, the next
      unused entry of ``unshared_layers[str(c)]`` is applied to the rows where
      ``cluster_labels == c``; results are written back at the same row positions.

    :param x: Input activations to be routed.
    :type x: torch.Tensor, shape ``(batch, d_in)``
    :param cluster_labels: Cluster id per row.
    :type cluster_labels: torch.Tensor, shape ``(batch,)``
    :param layer_type_list: Sequence of ``"shared"``/``"unshared"`` (or ``"s"``/``"u"``) flags.
    :type layer_type_list: list[str]
    :param shared_layers: Layers consumed by shared directives.
    :type shared_layers: torch.nn.ModuleList
    :param unshared_layers: Per-cluster layer lists, indexed by ``str(cluster_id)``.
    :type unshared_layers: dict[str, torch.nn.ModuleList] | torch.nn.ModuleDict
    :returns: Routed activations.
    :rtype: torch.Tensor
    :raises ValueError: If a directive is not a shared/unshared flag, or if some row's
        label is outside ``[0, num_clusters - 1]`` when an unshared layer is reached.
    :raises IndexError: If fewer layers were supplied than the directives consume.
    """
    shared_idx = 0
    # Every cluster owns one layer per unshared directive, so a single counter suffices.
    unshared_idx = 0
    n_rows = x.shape[0]

    for layer_num, layer_type in enumerate(layer_type_list):
        kind = layer_type.lower()

        if kind in ("shared", "s"):
            x = shared_layers[shared_idx](x)
            shared_idx += 1
            continue

        if kind not in ("unshared", "u"):
            raise ValueError(
                f"Invalid layer type at index {layer_num}: {layer_type}"
            )

        routed = None
        n_assigned = 0
        for c in range(self.num_clusters):
            mask = (cluster_labels == c)
            if not mask.any():
                continue
            x_out = unshared_layers[str(c)][unshared_idx](x[mask])
            if routed is None:
                # Allocate from the layer output so dtype/device follow what the
                # layer actually produced (it may differ from x, e.g. under autocast).
                routed = x_out.new_empty((n_rows, x_out.shape[1]))
            routed[mask] = x_out
            n_assigned += x_out.shape[0]

        if n_assigned != n_rows:
            # Otherwise the unmatched rows would keep uninitialised memory.
            raise ValueError(
                f"cluster_labels must lie in [0, {self.num_clusters - 1}]; "
                f"{n_rows - n_assigned} row(s) were not routed at layer {layer_num}."
            )

        if routed is None:
            # Empty batch: push the empty tensor through one replica so the
            # feature dimension still advances.
            routed = unshared_layers[str(0)][unshared_idx](x)

        x = routed
        unshared_idx += 1

    return x
[docs]
def encode(self, x, cluster_labels):
    r"""
    Encoder forward pass producing ``mu`` and ``logvar``.

    The input is routed through the encoder stack with
    :meth:`route_through_layers`; the latent heads are then applied either once
    (``latent_shared``) or per cluster, writing each cluster's result back to the
    rows it came from.

    :param x: Input batch.
    :type x: torch.Tensor, shape ``(batch, input_dim)``
    :param cluster_labels: Cluster id per row.
    :type cluster_labels: torch.Tensor, shape ``(batch,)``
    :returns: Tuple ``(mu, logvar)``, each of shape ``(batch, latent_dim)``.
    :rtype: tuple[torch.Tensor, torch.Tensor]
    :raises ValueError: In per-cluster latent mode, if some row's label is outside
        ``[0, num_clusters - 1]``.
    """
    hidden = self.route_through_layers(
        x, cluster_labels,
        self.layer_order_enc,
        self.encoder_layers,
        self.cluster_encoder_layers
    )

    if self.latent_shared:
        return self.fc_mu(hidden), self.fc_logvar(hidden)

    n_rows = hidden.size(0)
    mu = None
    logvar = None
    n_assigned = 0
    for c in range(self.num_clusters):
        mask = (cluster_labels == c)
        if not mask.any():
            continue
        hidden_c = hidden[mask]  # gather once, feed both heads
        mu_c = self.cluster_fc_mu[str(c)](hidden_c)
        logvar_c = self.cluster_fc_logvar[str(c)](hidden_c)
        if mu is None:
            # Allocate from the head outputs so dtype/device match what is scattered in.
            mu = mu_c.new_empty((n_rows, self.latent_dim))
            logvar = logvar_c.new_empty((n_rows, self.latent_dim))
        mu[mask] = mu_c
        logvar[mask] = logvar_c
        n_assigned += hidden_c.shape[0]

    if n_assigned != n_rows:
        # Unmatched rows would otherwise carry uninitialised memory into the latent space.
        raise ValueError(
            f"cluster_labels must lie in [0, {self.num_clusters - 1}]; "
            f"{n_rows - n_assigned} row(s) have no latent head."
        )

    if mu is None:
        # Empty batch: nothing to scatter, return correctly shaped empty tensors.
        mu = hidden.new_empty((0, self.latent_dim))
        logvar = torch.empty_like(mu)

    return mu, logvar
[docs]
def reparameterize(self, mu, logvar, generator=None):
    r"""
    Draw latent codes with the reparameterization trick.

    Computes ``z = mu + eps * exp(0.5 * logvar)`` where ``eps`` is standard
    normal noise with the same shape, dtype and device as the standard deviation.

    :param mu: Mean of the approximate posterior.
    :type mu: torch.Tensor, shape ``(batch, latent_dim)``
    :param logvar: Log-variance of the approximate posterior.
    :type logvar: torch.Tensor, shape ``(batch, latent_dim)``
    :param generator: Optional RNG used to draw ``eps``; when ``None`` the global
        RNG is used.
    :type generator: torch.Generator
    :returns: Sampled latent codes ``z``.
    :rtype: torch.Tensor
    """
    sigma = torch.exp(0.5 * logvar)

    if generator is not None:
        # An explicit generator (owned by the caller) needs the shape/dtype/device spelled out.
        noise = torch.randn(
            sigma.shape,
            device=sigma.device,
            dtype=sigma.dtype,
            generator=generator,
        )
    else:
        noise = torch.randn_like(sigma)

    return mu + noise * sigma
[docs]
def decode(self, z, cluster_labels):
    r"""
    Decoder forward pass from latent ``z`` to raw output logits.

    The latent codes are routed through the decoder stack with
    :meth:`route_through_layers`, then through the final linear layer — a single
    shared one (``output_shared``) or the one belonging to each row's cluster.
    No output activation is applied here.

    :param z: Latent codes.
    :type z: torch.Tensor, shape ``(batch, latent_dim)``
    :param cluster_labels: Cluster id per row.
    :type cluster_labels: torch.Tensor, shape ``(batch,)``
    :returns: Raw logits.
    :rtype: torch.Tensor, shape ``(batch, input_dim)``
    :raises ValueError: In per-cluster output mode, if some row's label is outside
        ``[0, num_clusters - 1]``.
    """
    hidden = self.route_through_layers(
        z, cluster_labels,
        self.layer_order_dec,
        self.decoder_layers,
        self.cluster_decoder_layers
    )

    if self.output_shared:
        return self.final_layer(hidden)

    n_rows = hidden.shape[0]
    logits = None
    n_assigned = 0
    for c in range(self.num_clusters):
        cluster_mask = (cluster_labels == c)
        if not cluster_mask.any():
            continue
        out_c = self.cluster_final_layer[str(c)](hidden[cluster_mask])
        if logits is None:
            # Allocate from the layer output so dtype/device match what is scattered in.
            logits = out_c.new_empty((n_rows, out_c.shape[1]))
        logits[cluster_mask] = out_c
        n_assigned += out_c.shape[0]

    if n_assigned != n_rows:
        # Unmatched rows would otherwise be returned as uninitialised memory.
        raise ValueError(
            f"cluster_labels must lie in [0, {self.num_clusters - 1}]; "
            f"{n_rows - n_assigned} row(s) have no output layer."
        )

    if logits is None:
        # Empty batch: no cluster matched, so return an empty (0, input_dim) tensor.
        logits = hidden.new_empty((0, self.input_dim))

    return logits
[docs]
def forward(self, x, cluster_labels, deterministic=False, *, generator = None):
    r"""
    Run the full VAE pass: encode → (optionally) sample → decode.

    :param x: Input batch.
    :type x: torch.Tensor, shape ``(batch, input_dim)``
    :param cluster_labels: Cluster id per row.
    :type cluster_labels: torch.LongTensor, shape ``(batch,)``
    :param deterministic: When true, decode from ``mu`` directly instead of drawing
        a sample (default ``False``).
    :type deterministic: bool
    :param generator: Optional RNG forwarded to :meth:`reparameterize` (default ``None``).
    :type generator: torch.Generator
    :returns: Tuple ``(recon, mu, logvar)``.
    :rtype: tuple[torch.Tensor, torch.Tensor, torch.Tensor]
    """
    mu, logvar = self.encode(x, cluster_labels)

    # The posterior mean is used as the latent code when sampling is disabled.
    latent = mu if deterministic else self.reparameterize(mu, logvar, generator = generator)

    return self.decode(latent, cluster_labels), mu, logvar
def __repr__(self):
    r"""
    String summary of the architecture (encoder/latent/decoder/output layout).

    The summary is derived from the stored configuration (``hidden_dims``, the
    layer-order lists and the shared/unshared flags), not from the instantiated
    sub-modules.

    :returns: Human-readable multi-line description.
    :rtype: str
    """
    # Note the ", " before num_clusters: the header previously ran the two
    # fields together as "output_shared=...,num_clusters=...".
    lines = [f"CISSVAE(input_dim={self.input_dim}, latent_dim={self.latent_dim}, "
             f"latent_shared={self.latent_shared}, output_shared={self.output_shared}, "
             f"num_clusters={self.num_clusters})"]

    lines.append("Encoder Layers:")
    in_dim = self.input_dim
    for i, (out_dim, lt) in enumerate(zip(self.hidden_dims, self.layer_order_enc)):
        lines.append(f" [{i}] {lt.upper():<8} {in_dim} → {out_dim}")
        in_dim = out_dim

    lines.append("\nLatent Layer:")
    if self.latent_shared:
        lines.append(f" SHARED {in_dim} → {self.latent_dim}")
    else:
        lines.extend(
            f" UNSHARED (cluster {c}) {in_dim} → {self.latent_dim}"
            for c in range(self.num_clusters)
        )

    lines.append("\nDecoder Layers:")
    in_dim = self.latent_dim
    # The decoder walks the hidden widths in reverse.
    for i, (out_dim, lt) in enumerate(zip(self.hidden_dims[::-1], self.layer_order_dec)):
        lines.append(f" [{i}] {lt.upper():<8} {in_dim} → {out_dim}")
        in_dim = out_dim

    lines.append("\nFinal Output Layer:")
    if self.output_shared:
        lines.append(f" SHARED {in_dim} → {self.input_dim}")
    else:
        lines.extend(
            f" UNSHARED (cluster {c}) {in_dim} → {self.input_dim}"
            for c in range(self.num_clusters)
        )

    return "\n".join(lines)
def __str__(self):
    """Return the same architecture summary that ``__repr__`` produces."""
    summary = self.__repr__()
    return summary
[docs]
def set_final_lr(self, final_lr):
    """Record ``final_lr`` on the model as the ``final_lr`` attribute.

    Intended to carry the learning rate reached at the end of the initial
    training loop over to the refit loop.
    """
    setattr(self, "final_lr", final_lr)
[docs]
def get_final_lr(self):
    """Return the learning rate previously stored with :meth:`set_final_lr`."""
    stored_lr = self.final_lr
    return stored_lr
[docs]
def get_imputed_valdata(self, dataset, device="cpu", deterministic=True):
    """
    Compute model reconstructions on the validation entries of ``dataset``.

    Switches the model to eval mode, runs one forward pass without gradients
    and post-processes the raw logits per feature group of
    ``dataset.activation_groups``:

    - ``"continuous"`` → de-normalized as ``logit * std + mean``
    - ``"binary"`` → sigmoid
    - any other key → treated as one categorical variable; the column with the
      largest logit is set to 1 and the rest of the group to 0

    Parameters
    ----------
    dataset : ClusterDataset
        Object providing ``data`` (model input), ``val_data``,
        ``cluster_labels``, ``feature_means``, ``feature_stds``,
        ``feature_names`` and ``activation_groups``.
        NOTE(review): every position where ``val_data`` is NaN is blanked in the
        result. That matches the documented return only if ``val_data`` holds
        values at validation positions and NaN elsewhere — confirm against
        ClusterDataset.
    device : str, default="cpu"
        Device the dataset tensors are moved to for the computation.
    deterministic : bool, default=True
        Forwarded to :meth:`forward`.

    Returns
    -------
    torch.Tensor
        ``float32`` tensor shaped like the logits. Entries where
        ``dataset.val_data`` is NaN are set to NaN.

    Notes
    -----
    The model is left in eval mode. ``dataset.feature_stds`` is never modified:
    zero standard deviations are replaced by 1.0 on a private copy.
    """
    self.eval()

    # -------------------------
    # Load data
    # -------------------------
    full_x = dataset.data.to(device)
    full_cluster = dataset.cluster_labels.to(device)
    val_data = dataset.val_data.to(device)
    blank_mask = torch.isnan(val_data)  # positions that are NaN-ed out in the result

    # -------------------------
    # Forward pass (logits)
    # -------------------------
    with torch.no_grad():
        logits, _, _ = self.forward(
            full_x, full_cluster, deterministic=deterministic
        )

    # -------------------------
    # Feature stats
    # -------------------------
    means = torch.as_tensor(dataset.feature_means, dtype=torch.float32, device=device)
    stds = torch.as_tensor(dataset.feature_stds, dtype=torch.float32, device=device)

    # torch.as_tensor may return a view of dataset.feature_stds, so patch zero
    # stds out-of-place instead of writing into the caller's data.
    zero_std = (stds == 0)
    if zero_std.any():
        bad_feats = [dataset.feature_names[i] for i in torch.where(zero_std)[0].tolist()]
        print(f"[Warning] std == 0 → replaced with 1.0: {bad_feats}")
        stds = torch.where(zero_std, torch.ones_like(stds), stds)

    # -------------------------
    # Output container
    # -------------------------
    recon_out = logits.clone().to(torch.float32)
    row_idx = torch.arange(recon_out.shape[0], device=device)

    # -------------------------
    # Apply activation groups
    # -------------------------
    for name, cols in dataset.activation_groups.items():
        # Index tensors must be integer typed (an empty list would default to float).
        cols = torch.as_tensor(cols, dtype=torch.long, device=device)
        if cols.numel() == 0:
            continue  # nothing to transform; also avoids argmax over an empty dim

        if name == "continuous":
            recon_out[:, cols] = logits[:, cols] * stds[cols] + means[cols]
        elif name == "binary":
            recon_out[:, cols] = torch.sigmoid(logits[:, cols])
        else:
            # Softmax is monotonic, so the winning class is the arg-max logit.
            winner = torch.argmax(logits[:, cols], dim=1)
            recon_out[:, cols] = 0.0
            recon_out[row_idx, cols[winner]] = 1.0

    # -------------------------
    # Blank every position where val_data is NaN
    # -------------------------
    recon_out[blank_mask] = float("nan")
    return recon_out
@torch.no_grad()
def set_activation_groups(
    self,
    activation_groups: dict,
) -> None:
    """
    Validate and attach the feature activation structure to the model.

    The mapping is stored on ``self.activation_groups`` as a plain attribute
    (not a buffer), with every column index converted to a built-in ``int``.
    Nothing is stored if validation fails.

    Parameters
    ----------
    activation_groups : dict
        Mapping from group name to a list/tuple of column indices, e.g.::

            {
                "continuous": [col_idx, ...],
                "binary": [col_idx, ...],
                "<categorical_name>": [col_idx, ...],
                ...
            }

        Indices must be ``int`` or NumPy integers in ``[0, input_dim)``.

    Raises
    ------
    ValueError
        If ``activation_groups`` is not a dictionary, a group is not a list or
        tuple, or a column index is non-integer or out of bounds.

    Notes
    -----
    - This replaces the old ``set_binary_features`` functionality.
    - This method only records metadata; it does not modify any layer.
    """
    if not isinstance(activation_groups, dict):
        raise ValueError("activation_groups must be a dictionary.")

    resolved = {}
    for group_name, columns in activation_groups.items():
        if not isinstance(columns, (list, tuple)):
            raise ValueError(f"activation_groups['{group_name}'] must be a list of column indices.")

        for col in columns:
            if not isinstance(col, (int, np.integer)):
                raise ValueError(f"Column index '{col}' in group '{group_name}' is not an integer.")
            if int(col) < 0 or int(col) >= self.input_dim:
                raise ValueError(
                    f"Column index {col} in group '{group_name}' is out of bounds for input_dim={self.input_dim}."
                )

        resolved[group_name] = [int(col) for col in columns]

    # Assigned only after every group validated, so a failure leaves the model untouched.
    self.activation_groups = resolved
def _normalize_layer_order(self, layer_order, name):
    """
    Map layer directives to their canonical ``"shared"``/``"unshared"`` spelling.

    :param layer_order: Directives; each must be a string equal (case-insensitively)
        to ``"shared"``, ``"s"``, ``"unshared"`` or ``"u"``.
    :param name: Argument name used in error messages.
    :returns: New list with one canonical directive per input entry.
    :rtype: list[str]
    :raises ValueError: If an entry is not a string or is not a recognised directive.
    """
    canonical_for = {
        "shared": "shared",
        "s": "shared",
        "unshared": "unshared",
        "u": "unshared",
    }

    canonical = []
    for position, entry in enumerate(layer_order):
        if not isinstance(entry, str):
            raise ValueError(f"{name}[{position}] must be a string.")

        resolved = canonical_for.get(entry.lower())
        if resolved is None:
            raise ValueError(
                f"Invalid value in {name}[{position}]: '{entry}'. "
                "Must be one of {'shared','unshared','s','u'}."
            )
        canonical.append(resolved)

    return canonical
# @torch.no_grad()
# def set_binary_features(self,
# mask: Optional[Union[torch.Tensor, Sequence[bool]]] = None,
# feature_names: Optional[Sequence[str]] = None,
# binary_feature_names: Optional[Iterable[str]] = None) -> None:
# """
# Update which columns are treated as binary at the output. This function should not be necessary for user to touch.
# You can pass either:
# - mask: 1D bool vector length `input_dim`, or
# - feature_names + binary_feature_names: names → mask is computed
# This is safe to call after loading a model or dataset schema.
# Can set w/ vae.set_binary_features(mask = dataset.binary_feature_mask)
# :param binary_feature_mask: Boolean vector of length p for n x p dataset. True for binary columns, False for continuous columns
# :type binary_feature_mask: Optional[Union[torch.Tensor, Sequence[bool]]]
# :param feature_names: List of all feature names - used with 'binary_feature_names'.
# :type feature_names: Optional[Sequence[str]]
# :param binary_feature_names: List of all binary features (features must also be included in 'feature_names').
# :type binary_feature_names: Optional[Iterable[str]]
# """
# if mask is not None:
# mask = torch.as_tensor(mask, dtype=torch.bool, device=self.binary_mask.device)
# if mask.ndim != 1 or mask.numel() != self.input_dim:
# raise ValueError("mask must be a 1D boolean vector of length input_dim.")
# self.binary_mask.copy_(mask) # in-place update to keep buffer reference
# return
# if (feature_names is None) or (binary_feature_names is None):
# raise ValueError("Provide either `mask` or (`feature_names` and `binary_feature_names`).")
# feat2idx = {name: i for i, name in enumerate(feature_names)}
# newmask = torch.zeros(self.input_dim, dtype=torch.bool, device=self.binary_mask.device)
# for bname in binary_feature_names:
# if bname not in feat2idx:
# raise ValueError(f"Binary feature name '{bname}' not found in feature_names.")
# newmask[feat2idx[bname]] = True
# self.binary_mask.copy_(newmask)