Linear Probing
LPModel
- class scshift.LPModel(adata, x, y, label=None, x2_label=None, mask_label=None, mode='Weighted', norm='none', hidden_dim=10)
Bases: object
A linear probing model class supporting multiple modes of analysis.
This class provides several linear probing strategies (Weighted, WeightedDisentangled, Contrast, and Individual) for analyzing single-cell data.
- Parameters:
adata (AnnData) – An AnnData object containing single-cell data. The relevant embeddings/features are expected in adata.obsm.
x (str) – Key in adata.obsm for the primary feature embedding.
y (str) – Key in adata.obs for the target variable (categorical).
label (str, optional) – Key in adata.obs for grouping samples or conditions.
x2_label (str, optional) – Used only when mode is ‘WeightedDisentangled’ or ‘Contrast’: for ‘WeightedDisentangled’ it names the secondary feature embedding, and for ‘Contrast’ the label that splits the data into the two subgroups being contrasted.
mask_label (str, optional) – Key in adata.obs. If provided, the training/validation/test sets are formed from the groups defined by this label rather than from individual cells.
mode (str, default: 'Weighted') – The linear probing mode. One of [“Weighted”, “WeightedDisentangled”, “Contrast”, “Individual”].
norm (str, default: 'none') – Type of normalization applied to the inputs: ‘none’, ‘batch’, or ‘layer’.
hidden_dim (int, default: 10) – Size of the hidden dimension; used only by modes that need one.
- Returns:
Initializes the probing model in the specified mode.
- Return type:
None
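The norm option presumably maps onto the standard batch and layer normalization schemes; the exact layers LPModel builds are not documented here. The NumPy sketch below uses a hypothetical normalize helper, with no learned scale or shift, to show which axis each option normalizes over.

```python
import numpy as np


def normalize(x, norm="none", eps=1e-5):
    """Apply 'none', 'batch' or 'layer' normalization to a 2-D array (sketch)."""
    x = np.asarray(x, dtype=float)
    if norm == "none":
        return x
    if norm == "batch":
        axis = 0  # each feature (column) is normalized across the samples in the batch
    elif norm == "layer":
        axis = 1  # each sample (row) is normalized across its own features
    else:
        raise ValueError(f"norm must be 'none', 'batch' or 'layer', got {norm!r}")
    mean = x.mean(axis=axis, keepdims=True)
    var = x.var(axis=axis, keepdims=True)
    return (x - mean) / np.sqrt(var + eps)


x = np.array([[1.0, 2.0, 3.0],
              [3.0, 6.0, 9.0]])
batch_normed = normalize(x, "batch")  # every column now has mean 0
layer_normed = normalize(x, "layer")  # every row now has mean 0
```

Batch normalization depends on the other samples in the batch, whereas layer normalization treats each sample independently, which matters when batch_size is small.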
Methods Table
| train | Train the linear probing model. |
| predict | Predict using the trained linear probing model. |
Training
- LPModel.train(max_epochs=None, use_gpu=None, train_size=0.9, batch_size=None, validation_size=None, lr=0.001, weight_decay=0.0001, device=None, seed=None)
Train the linear probing model.
Splits the data (or, when mask_label is set, the groups) into training and validation sets according to train_size and validation_size; any remaining fraction forms a test set. It then fits the chosen linear probing mode using an Adam optimizer.
- Parameters:
max_epochs (int, optional) – Maximum number of training epochs. Defaults to a heuristic if not provided.
use_gpu (Union[str, int, bool], optional) – If True or a valid GPU identifier, trains on GPU if available; else on CPU.
train_size (float, default: 0.9) – Proportion of data (or groups) to include in the training set.
batch_size (int, optional) – Batch size for the DataLoader. If None, all data is loaded as a single batch.
validation_size (float, optional) – Proportion of data (or groups) for the validation set. Defaults to 1 - train_size.
lr (float, default: 1e-3) – Learning rate for the optimizer.
weight_decay (float, default: 1e-4) – Weight decay (L2 penalty).
device (torch.device or str, optional) – Device to use. If None, the GPU is selected automatically when use_gpu is True.
seed (int, optional) – Random seed for sampling training and validation sets. If not specified, then scvi.settings.seed is used.
- Returns:
A tuple of two lists: (train_loss, val_loss). If validation_size > 0, val_loss is populated; otherwise it’s empty.
- Return type:
(list, list)
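A minimal sketch of a group-level split in the spirit of mask_label: whole groups rather than individual cells are assigned to the training, validation, and test sets, and whatever fraction is left over becomes the test set. The helper name split_groups and its rounding rule are assumptions for illustration, not LPModel's internal code.

```python
import numpy as np


def split_groups(groups, train_size=0.9, validation_size=None, seed=0):
    """Hypothetical helper: split unique group IDs into train/validation/test."""
    if validation_size is None:
        validation_size = 1.0 - train_size  # mirrors the documented default
    if train_size + validation_size > 1.0 + 1e-9:
        raise ValueError("train_size + validation_size must not exceed 1")
    unique = np.unique(np.asarray(groups))
    perm = np.random.default_rng(seed).permutation(unique)
    n_train = int(round(train_size * len(perm)))
    n_val = int(round(validation_size * len(perm)))
    train = perm[:n_train]
    val = perm[n_train:n_train + n_val]
    test = perm[n_train + n_val:]  # leftover groups, possibly empty
    return train, val, test


donors = np.array(["d1", "d1", "d2", "d3", "d3", "d4", "d5"])  # made-up group labels
train, val, test = split_groups(donors, train_size=0.6, validation_size=0.2, seed=0)
train_mask = np.isin(donors, train)  # per-cell boolean mask for the training set
```

Splitting by group keeps all cells of one donor in the same set, so validation scores are not inflated by near-duplicate cells leaking across the split.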
Inference
- LPModel.predict(adata)
Predict using the trained linear probing model.
Depending on the mode, the inputs required for prediction may include multiple embeddings and/or label arrays.
- Parameters:
adata (AnnData) – The AnnData object containing embeddings/features in adata.obsm. Must have the same structure (the same obsm and obs keys) as the AnnData used to initialize this model.
- Returns:
If mode is ‘Individual’, returns predictions as a NumPy array.
Otherwise, returns (predictions, mapped_categories), where mapped_categories identifies the unique label groups that the rows of predictions correspond to.
- Return type:
Union[np.ndarray, Tuple[np.ndarray, pd.DataFrame]]
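In the grouped modes each row of predictions belongs to one group. Assuming the columns follow the category order of adata.obs[y] (this reference does not confirm it), a hypothetical helper can turn the probabilities into class names. The class names and probabilities in the example are made up.

```python
import numpy as np


def to_class_labels(predictions, class_names):
    """Return the most likely class name for each row of a probability array."""
    predictions = np.asarray(predictions)
    if predictions.ndim != 2 or predictions.shape[1] != len(class_names):
        raise ValueError("predictions must have shape (n_rows, len(class_names))")
    return np.asarray(class_names)[predictions.argmax(axis=1)]


# Made-up probabilities for two groups and three made-up classes.
probs = np.array([[0.7, 0.2, 0.1],
                  [0.1, 0.3, 0.6]])
labels = to_class_labels(probs, ["healthy", "mild", "severe"])
# labels -> ['healthy', 'severe']
```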
Probing Modules
- class pertvi.model.linearprob.WeightedLinearProbing(input_dim, y_dim, hidden_dim, use_norm='none')
A weighted linear probing module applying per-donor averaging with learned weights.
- Parameters:
input_dim (int) – Dimensionality of the input features.
y_dim (int) – Dimensionality of the target (number of classes).
hidden_dim (int) – Size of the hidden dimension (unused in the current minimal design; kept for API compatibility).
use_norm (str, default: 'none') – Type of normalization to apply. One of [‘none’, ‘batch’, ‘layer’].
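This reference does not spell out how the per-donor weights are learned. One common choice, assumed here, is attention-style pooling: each sample is scored with a learned vector w, the scores are softmax-normalized within each group, and the group embedding is the resulting weighted average. A NumPy sketch with a hypothetical weighted_group_pool function:

```python
import numpy as np


def weighted_group_pool(x, labels, w):
    """Softmax-weighted average of the rows of x within each group.

    x: (n_samples, input_dim); labels: (n_samples,); w: (input_dim,) scoring vector.
    Returns (pooled, unique_labels), pooled having shape (n_groups, input_dim).
    """
    x = np.asarray(x, dtype=float)
    unique_labels, inverse = np.unique(np.asarray(labels), return_inverse=True)
    inverse = inverse.ravel()
    scores = x @ np.asarray(w, dtype=float)
    pooled = np.zeros((len(unique_labels), x.shape[1]))
    for g in range(len(unique_labels)):
        idx = inverse == g
        s = scores[idx] - scores[idx].max()  # shift for numerical stability
        alpha = np.exp(s) / np.exp(s).sum()  # weights sum to 1 within the group
        pooled[g] = alpha @ x[idx]
    return pooled, unique_labels


x = np.array([[1.0, 0.0], [3.0, 0.0], [0.0, 2.0]])
donors = np.array([0, 0, 1])
uniform, groups = weighted_group_pool(x, donors, np.zeros(2))     # equal weights
skewed, _ = weighted_group_pool(x, donors, np.array([1.0, 0.0]))  # favors large x[:, 0]
```

With w set to zeros every sample in a group receives the same weight, so the pooling reduces to a plain group mean; a learned w lets informative cells dominate their donor's embedding.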
- forward(x, labels, y)
Forward pass for Weighted linear probing.
- Parameters:
x (torch.Tensor) – Input feature of shape (batch_size, input_dim).
labels (torch.Tensor) – Group labels for each sample, used to aggregate by group.
y (torch.Tensor) – One-hot encoded target of shape (batch_size, y_dim).
- Returns:
Scalar loss (cross-entropy).
- Return type:
torch.Tensor
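forward returns a scalar cross-entropy. Because predictions are made per group while y is given per sample, one plausible reading (an assumption, not confirmed here) is that each group's target is the one-hot row shared by its samples. A NumPy sketch of that group-level loss, with a hypothetical function name:

```python
import numpy as np


def group_cross_entropy(logits, y_onehot, labels):
    """Mean cross-entropy between per-group logits and per-group one-hot targets.

    logits: (n_groups, y_dim), rows ordered like np.unique(labels).
    y_onehot: (n_samples, y_dim); assumed identical for all samples of a group.
    """
    _, first_idx = np.unique(np.asarray(labels), return_index=True)
    targets = np.asarray(y_onehot, dtype=float)[first_idx]  # one target row per group
    logits = np.asarray(logits, dtype=float)
    shifted = logits - logits.max(axis=1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    return float(-(targets * log_probs).sum(axis=1).mean())


labels = np.array([0, 0, 1])
y_onehot = np.array([[1, 0, 0], [1, 0, 0], [0, 0, 1]])
loss_uniform = group_cross_entropy(np.zeros((2, 3)), y_onehot, labels)  # log(3)
loss_confident = group_cross_entropy(np.array([[5.0, 0, 0], [0, 0, 5.0]]), y_onehot, labels)
```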
- predict(x, labels)
Prediction for Weighted linear probing.
- Parameters:
x (torch.Tensor) – Input feature of shape (n_samples, input_dim).
labels (torch.Tensor) – Group labels used for aggregation.
- Returns:
prediction: shape (n_groups, y_dim)
unique group labels (on the device)
- Return type:
(torch.Tensor, torch.Tensor)
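predict returns one row per group together with the unique group labels, so downstream code often needs to map each group's row back to its samples. A sketch with a hypothetical broadcast_to_samples helper; torch tensors would first be moved to the CPU and converted with .cpu().numpy().

```python
import numpy as np


def broadcast_to_samples(group_pred, unique_labels, labels):
    """Return an (n_samples, y_dim) array in which each sample gets its group's row."""
    row_of = {lab: i for i, lab in enumerate(np.asarray(unique_labels).tolist())}
    # Raises KeyError if a sample's label is not among unique_labels.
    rows = [row_of[lab] for lab in np.asarray(labels).tolist()]
    return np.asarray(group_pred)[rows]


group_pred = np.array([[0.9, 0.1], [0.2, 0.8]])  # made-up per-group probabilities
per_sample = broadcast_to_samples(group_pred, unique_labels=[10, 20], labels=[20, 10, 10])
# per_sample -> [[0.2, 0.8], [0.9, 0.1], [0.9, 0.1]]
```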
- class pertvi.model.linearprob.WeightedDisentangledLinearProbing(x1_dim, x2_dim, y_dim, hidden_dim, use_norm='none')
A weighted, disentangled linear probing module that combines two embeddings.
- Parameters:
x1_dim (int) – Dimensionality of the first input embedding.
x2_dim (int) – Dimensionality of the second input embedding.
y_dim (int) – Dimensionality of the target (number of classes).
hidden_dim (int) – Hidden dimension size.
use_norm (str, default: 'none') – Type of normalization to apply. One of [‘none’, ‘batch’, ‘layer’].
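How the two embeddings are combined is not documented here. Assuming each embedding is pooled per group and a single linear layer acts on their concatenation, the logits split exactly into one contribution per embedding, which is one way to read "disentangled". The sketch substitutes a plain group mean for the learned weighting and ignores hidden_dim; all function names are hypothetical.

```python
import numpy as np


def group_mean(x, labels):
    """Plain per-group mean (stand-in for the module's learned weighting)."""
    x = np.asarray(x, dtype=float)
    unique_labels, inverse = np.unique(np.asarray(labels), return_inverse=True)
    inverse = inverse.ravel()
    sums = np.zeros((len(unique_labels), x.shape[1]))
    np.add.at(sums, inverse, x)  # accumulate each row into its group
    counts = np.bincount(inverse)[:, None]
    return sums / counts, unique_labels


def disentangled_logits(x1, x2, labels, W1, W2, bias):
    """Linear probe on [pool(x1), pool(x2)], returned with each embedding's share."""
    z1, groups = group_mean(x1, labels)
    z2, _ = group_mean(x2, labels)
    logits = np.concatenate([z1, z2], axis=1) @ np.vstack([W1, W2]) + bias
    return logits, z1 @ W1, z2 @ W2, groups


rng = np.random.default_rng(0)
x1, x2 = rng.normal(size=(5, 3)), rng.normal(size=(5, 2))
W1, W2, bias = rng.normal(size=(3, 4)), rng.normal(size=(2, 4)), np.zeros(4)
logits, part1, part2, groups = disentangled_logits(x1, x2, [0, 0, 1, 1, 1], W1, W2, bias)
# logits equals part1 + part2 + bias: each embedding's contribution is separable.
```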
- forward(x1, x2, labels, y)
Forward pass for WeightedDisentangled linear probing.
- Parameters:
x1 (torch.Tensor) – First embedding input of shape (batch_size, x1_dim).
x2 (torch.Tensor) – Second embedding input of shape (batch_size, x2_dim).
labels (torch.Tensor) – Group labels used for aggregation.
y (torch.Tensor) – One-hot encoded target of shape (batch_size, y_dim).
- Returns:
Scalar cross-entropy loss.
- Return type:
torch.Tensor
- predict(x1, x2, labels)
Prediction method for WeightedDisentangled linear probing.
- Parameters:
x1 (torch.Tensor) – First embedding of shape (n_samples, x1_dim).
x2 (torch.Tensor) – Second embedding of shape (n_samples, x2_dim).
labels (torch.Tensor) – Group labels used for aggregation.
- Returns:
predictions: shape (n_groups, y_dim)
unique label IDs (on the device)
- Return type:
(torch.Tensor, torch.Tensor)
- class pertvi.model.linearprob.ContrastLinearProbing(input_dim, y_dim, hidden_dim, use_norm='none')
A contrastive linear probing module computing differences between two subgroups.
- Parameters:
input_dim (int) – Dimensionality of the input features.
y_dim (int) – Dimensionality of the target (number of classes).
hidden_dim (int) – Hidden dimension size (unused in the current minimal design).
use_norm (str, default: 'none') – Type of normalization to apply. One of [‘none’, ‘batch’, ‘layer’].
- forward(x, labels, labels2, y)
Forward pass for contrastive linear probing.
- Parameters:
x (torch.Tensor) – Input feature of shape (batch_size, input_dim).
labels (torch.Tensor) – Group labels for each sample (e.g., cell type).
labels2 (torch.Tensor) – A second label that divides the data into exactly two subgroups to be contrasted.
y (torch.Tensor) – One-hot encoded target of shape (batch_size, y_dim).
- Returns:
Scalar loss.
- Return type:
torch.Tensor
- predict(x, labels, labels2)
Prediction method for contrastive linear probing.
Aggregates the input features by labels within each of the two subgroups defined by labels2, then computes the difference between the two aggregated embeddings.
- Parameters:
x (torch.Tensor) – Input feature of shape (n_samples, input_dim).
labels (torch.Tensor) – Group labels for each sample.
labels2 (torch.Tensor) – A second label that divides the data into exactly two subgroups to be contrasted.
- Returns:
prediction_diff: the softmax of the linear probe's output on that difference.
unique labels used in the grouping (for alignment).
- Return type:
(torch.Tensor, torch.Tensor)
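The aggregation described for predict can be sketched in NumPy. Assumptions: aggregation is a plain mean, the difference is taken as the second subgroup minus the first in sorted order of labels2, and contrast_predict is a hypothetical name. The example data and the subgroup names "ctrl" and "stim" are made up.

```python
import numpy as np


def contrast_predict(x, labels, labels2, W, bias):
    """Softmax of a linear probe on per-group mean differences between two subgroups."""
    x = np.asarray(x, dtype=float)
    labels, labels2 = np.asarray(labels), np.asarray(labels2)
    sub = np.unique(labels2)
    if len(sub) != 2:
        raise ValueError("labels2 must contain exactly two distinct values")
    groups = np.unique(labels)
    diffs = []
    for g in groups:
        in_g = labels == g
        xa = x[in_g & (labels2 == sub[0])]
        xb = x[in_g & (labels2 == sub[1])]
        if len(xa) == 0 or len(xb) == 0:
            raise ValueError(f"group {g!r} has no samples in one of the subgroups")
        diffs.append(xb.mean(axis=0) - xa.mean(axis=0))  # second minus first subgroup
    logits = np.stack(diffs) @ W + bias
    logits = logits - logits.max(axis=1, keepdims=True)  # stable softmax
    probs = np.exp(logits) / np.exp(logits).sum(axis=1, keepdims=True)
    return probs, groups


x = np.array([[1.0, 0.0], [3.0, 0.0], [0.0, 1.0], [0.0, 5.0]])
probs, groups = contrast_predict(
    x,
    labels=[0, 0, 1, 1],
    labels2=["ctrl", "stim", "ctrl", "stim"],
    W=np.eye(2),
    bias=np.zeros(2),
)
```

Because only the within-group difference reaches the probe, any offset shared by both subgroups of a group (for example a donor-specific shift) cancels out before classification.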
- class pertvi.model.linearprob.LinearProbing(input_dim, y_dim, use_norm='none')
A simple linear probing module without per-group weighting.
- Parameters:
input_dim (int) – Dimensionality of the input features.
y_dim (int) – Dimensionality of the target (number of classes).
use_norm (str, default: 'none') – Type of normalization to apply. One of [‘none’, ‘batch’, ‘layer’].
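LinearProbing applies a linear layer to each sample without group pooling. The sketch below fits such a probe, softmax(x @ W + b), with full-batch gradient descent and an L2 penalty. It swaps in plain gradient descent for the Adam optimizer that LPModel.train uses, and the function names are hypothetical.

```python
import numpy as np


def predict_proba(x, W, b):
    """Per-sample class probabilities softmax(x @ W + b)."""
    logits = np.asarray(x, dtype=float) @ W + b
    logits = logits - logits.max(axis=1, keepdims=True)  # stable softmax
    p = np.exp(logits)
    return p / p.sum(axis=1, keepdims=True)


def fit_linear_probe(x, y, n_classes, lr=0.5, epochs=200, weight_decay=1e-4):
    """Fit W, b to integer labels y by minimizing mean cross-entropy plus L2 on W."""
    x = np.asarray(x, dtype=float)
    n, d = x.shape
    W = np.zeros((d, n_classes))
    b = np.zeros(n_classes)
    onehot = np.eye(n_classes)[np.asarray(y)]
    for _ in range(epochs):
        grad = (predict_proba(x, W, b) - onehot) / n  # d(mean cross-entropy)/d(logits)
        W -= lr * (x.T @ grad + weight_decay * W)     # weight_decay * W is the L2 gradient
        b -= lr * grad.sum(axis=0)
    return W, b


x = np.array([[-2.0], [-1.0], [1.0], [2.0]])  # made-up 1-D embeddings
y = np.array([0, 0, 1, 1])
W, b = fit_linear_probe(x, y, n_classes=2)
probs = predict_proba(x, W, b)
```

Keeping the probe linear means its accuracy reflects how linearly separable the target is in the embedding, which is the usual motivation for linear probing.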