Linear Probing

LPModel

class scshift.LPModel(adata, x, y, label=None, x2_label=None, mask_label=None, mode='Weighted', norm='none', hidden_dim=10)[source]

Bases: object

A linear probing model class supporting multiple modes of analysis.

This class provides four linear probing strategies (Weighted, WeightedDisentangled, Contrast, and Individual) for analyzing single-cell data.

Parameters:
  • adata (AnnData) – An AnnData object containing single-cell data. The relevant embeddings/features are expected in adata.obsm.

  • x (str) – Key in adata.obsm for the primary feature embedding.

  • y (str) – Key in adata.obs for the target variable (categorical).

  • label (str, optional) – Key in adata.obs for grouping samples or conditions.

  • x2_label (str, optional) – Used only when mode is ‘WeightedDisentangled’ or ‘Contrast’. Key for the secondary input: a second feature embedding or a second label, depending on the mode.

  • mask_label (str, optional) – If provided, the training/validation/test split is made over the groups defined by this label rather than over individual cells.

  • mode (str, default: 'Weighted') – The linear probing mode. One of [‘Weighted’, ‘WeightedDisentangled’, ‘Contrast’, ‘Individual’].

  • norm (str, default: 'none') – Type of normalization to apply to the inputs. One of [‘none’, ‘batch’, ‘layer’].

  • hidden_dim (int, default: 10) – Size of hidden dimension if needed (used in some modes).

Returns:

Initializes the probing model in the specified mode.

Return type:

None
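The norm options can be illustrated with a short NumPy sketch. It assumes that ‘batch’ and ‘layer’ follow the usual batch-normalization and layer-normalization definitions, and it omits the learned scale/shift parameters and running statistics that a trained layer would carry. The function names are illustrative and are not part of the package.

```python
import numpy as np

def batch_normalize(x, eps=1e-5):
    """Standardize each feature (column) across the samples in the batch."""
    mean = x.mean(axis=0, keepdims=True)
    var = x.var(axis=0, keepdims=True)
    return (x - mean) / np.sqrt(var + eps)

def layer_normalize(x, eps=1e-5):
    """Standardize each sample (row) across its own features."""
    mean = x.mean(axis=1, keepdims=True)
    var = x.var(axis=1, keepdims=True)
    return (x - mean) / np.sqrt(var + eps)

rng = np.random.default_rng(0)
x = rng.normal(loc=3.0, scale=2.0, size=(8, 5))  # 8 cells, 5 embedding dimensions

xb = batch_normalize(x)  # statistics taken over cells, one per feature
xl = layer_normalize(x)  # statistics taken over features, one per cell
```

The only difference between the two is the axis over which the mean and variance are computed, so ‘batch’ depends on which cells share a batch while ‘layer’ treats every cell independently.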

Methods Table

scshift.LPModel.train

Train the linear probing model.

scshift.LPModel.predict

Predict using the trained linear probing model.

Training

LPModel.train(max_epochs=None, use_gpu=None, train_size=0.9, batch_size=None, validation_size=None, lr=0.001, weight_decay=0.0001, device=None, seed=None)[source]

Train the linear probing model.

Splits the cells (or groups, when mask_label is set) into training and validation sets according to train_size and validation_size; any remainder forms a test set. The chosen linear probing mode is then fit with an Adam optimizer.

Parameters:
  • max_epochs (int, optional) – Maximum number of training epochs. Defaults to a heuristic if not provided.

  • use_gpu (Union[str, int, bool], optional) – If True or a valid GPU identifier, trains on GPU if available; else on CPU.

  • train_size (float, default: 0.9) – Proportion of data (or groups) to include in the training set.

  • batch_size (int, optional) – Batch size for DataLoader. If None, data is loaded in one batch.

  • validation_size (float, optional) – Proportion of data (or groups) for the validation set. Defaults to 1 - train_size.

  • lr (float, default: 1e-3) – Learning rate for the optimizer.

  • weight_decay (float, default: 1e-4) – Weight decay (L2 penalty).

  • device (torch.device or str, optional) – Device to use. If None, the device is chosen automatically: a GPU when use_gpu is True and one is available, otherwise the CPU.

  • seed (int, optional) – Random seed for sampling training and validation sets. If not specified, then scvi.settings.seed is used.

Returns:

A tuple of two lists: (train_loss, val_loss). If validation_size > 0, val_loss is populated; otherwise it’s empty.

Return type:

(list, list)
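The group-wise split described for train() can be sketched as follows. This is a minimal illustration, not the package's implementation: the function name split_by_group, the rounding rule, and the shuffling strategy are assumptions. The point is that whole groups are assigned to one subset, so no group contributes cells to both training and validation.

```python
import numpy as np

def split_by_group(group_labels, train_size=0.9, validation_size=None, seed=0):
    """Return boolean cell masks (train, val, test) that keep whole groups together."""
    group_labels = np.asarray(group_labels)
    if validation_size is None:
        validation_size = 1.0 - train_size  # documented default
    if train_size + validation_size > 1.0 + 1e-9:
        raise ValueError("train_size + validation_size must not exceed 1")

    groups = np.unique(group_labels)
    shuffled = np.random.default_rng(seed).permutation(groups)

    # Assumed rounding rule: nearest integer number of groups per subset.
    n_train = int(round(train_size * len(groups)))
    n_val = min(int(round(validation_size * len(groups))), len(groups) - n_train)

    train_groups = shuffled[:n_train]
    val_groups = shuffled[n_train:n_train + n_val]
    test_groups = shuffled[n_train + n_val:]  # whatever is left over

    return (np.isin(group_labels, train_groups),
            np.isin(group_labels, val_groups),
            np.isin(group_labels, test_groups))

# 10 hypothetical donors with 20 cells each
donors = np.repeat([f"donor_{i}" for i in range(10)], 20)
train_mask, val_mask, test_mask = split_by_group(
    donors, train_size=0.6, validation_size=0.2, seed=0
)
```

With train_size=0.6 and validation_size=0.2, six donors go to training, two to validation, and the remaining two form the test set.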

Inference

LPModel.predict(adata)[source]

Predict using the trained linear probing model.

Depending on the mode, the inputs required for prediction may include multiple embeddings and/or label arrays.

Parameters:

adata (AnnData) – The AnnData object containing embeddings/features in adata.obsm. Must have the same structure used when initializing this model.

Returns:

  • If mode is ‘Individual’, returns predictions as a NumPy array.

  • Otherwise, returns (predictions, mapped_categories), where mapped_categories indicates the unique label groupings used.

Return type:

Union[np.ndarray, Tuple[np.ndarray, pd.DataFrame]]

Probing Modules

class pertvi.model.linearprob.WeightedLinearProbing(input_dim, y_dim, hidden_dim, use_norm='none')[source]

A weighted linear probing module applying per-donor averaging with learned weights.

Parameters:
  • input_dim (int) – Dimensionality of the input features.

  • y_dim (int) – Dimensionality of the target (number of classes).

  • hidden_dim (int) – Size of hidden dimension (unused in the current minimal design, but kept for API compatibility).

  • use_norm (str, default: 'none') – Type of normalization to apply. One of [‘none’, ‘batch’, ‘layer’].

forward(x, labels, y)[source]

Forward pass for Weighted linear probing.

Parameters:
  • x (torch.Tensor) – Input feature of shape (batch_size, input_dim).

  • labels (torch.Tensor) – Group labels for each sample, used to aggregate by group.

  • y (torch.Tensor) – One-hot encoded target of shape (batch_size, y_dim).

Returns:

Scalar loss (cross-entropy).

Return type:

torch.Tensor

predict(x, labels)[source]

Prediction for Weighted linear probing.

Parameters:
  • x (torch.Tensor) – Input feature of shape (n_samples, input_dim).

  • labels (torch.Tensor) – Group labels used for aggregation.

Returns:

  • prediction: shape (n_groups, y_dim)

  • unique group labels (on the device)

Return type:

(torch.Tensor, torch.Tensor)
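To make the shape contract of predict concrete, here is a NumPy sketch of group-wise weighted pooling followed by a linear probe. How the module actually computes its learned weights is not stated here, so the scoring vector w_score (a softmax over per-cell scores within each group) is an assumption made for illustration, as are the names W and b.

```python
import numpy as np

def softmax(z, axis=-1):
    z = z - z.max(axis=axis, keepdims=True)  # subtract max for numerical stability
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

def weighted_group_predict(x, labels, w_score, W, b):
    """Pool cells within each group using softmax weights, then apply a linear probe."""
    groups = np.unique(labels)
    pooled = np.empty((len(groups), x.shape[1]))
    for i, g in enumerate(groups):
        x_g = x[labels == g]                      # cells belonging to group g
        weights = softmax(x_g @ w_score, axis=0)  # one weight per cell, summing to 1
        pooled[i] = weights @ x_g                 # weighted average embedding
    return softmax(pooled @ W + b, axis=1), groups

rng = np.random.default_rng(0)
x = rng.normal(size=(12, 4))      # 12 cells, input_dim = 4
labels = np.repeat([0, 1, 2], 4)  # 3 groups of 4 cells
w_score = rng.normal(size=4)      # hypothetical per-cell scoring vector
W = rng.normal(size=(4, 2))       # probe weights, y_dim = 2
b = np.zeros(2)

prediction, groups = weighted_group_predict(x, labels, w_score, W, b)
```

The output has one row per group, not per cell, which matches the documented shape (n_groups, y_dim). With w_score set to zero the weights become uniform and the pooling reduces to a plain group mean.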

class pertvi.model.linearprob.WeightedDisentangledLinearProbing(x1_dim, x2_dim, y_dim, hidden_dim, use_norm='none')[source]

A weighted, disentangled linear probing module that combines two embeddings.

Parameters:
  • x1_dim (int) – Dimensionality of the first input embedding.

  • x2_dim (int) – Dimensionality of the second input embedding.

  • y_dim (int) – Dimensionality of the target (number of classes).

  • hidden_dim (int) – Hidden dimension size.

  • use_norm (str, default: 'none') – Type of normalization to apply. One of [‘none’, ‘batch’, ‘layer’].

forward(x1, x2, labels, y)[source]

Forward pass for WeightedDisentangled linear probing.

Parameters:
  • x1 (torch.Tensor) – First embedding input of shape (batch_size, x1_dim).

  • x2 (torch.Tensor) – Second embedding input of shape (batch_size, x2_dim).

  • labels (torch.Tensor) – Group labels used for aggregation.

  • y (torch.Tensor) – One-hot encoded target of shape (batch_size, y_dim).

Returns:

Scalar cross-entropy loss.

Return type:

torch.Tensor

predict(x1, x2, labels)[source]

Prediction method for WeightedDisentangled linear probing.

Parameters:
  • x1 (torch.Tensor) – First embedding of shape (n_samples, x1_dim).

  • x2 (torch.Tensor) – Second embedding of shape (n_samples, x2_dim).

  • labels (torch.Tensor) – Group labels used for aggregation.

Returns:

  • predictions: shape (n_groups, y_dim)

  • unique label IDs (on the device)

Return type:

(torch.Tensor, torch.Tensor)

class pertvi.model.linearprob.ContrastLinearProbing(input_dim, y_dim, hidden_dim, use_norm='none')[source]

A contrastive linear probing module computing differences between two subgroups.

Parameters:
  • input_dim (int) – Dimensionality of the input features.

  • y_dim (int) – Dimensionality of the target (number of classes).

  • hidden_dim (int) – Hidden dimension size (unused in the current minimal design).

  • use_norm (str, default: 'none') – Type of normalization to apply. One of [‘none’, ‘batch’, ‘layer’].

forward(x, labels, labels2, y)[source]

Forward pass for contrastive linear probing.

Parameters:
  • x (torch.Tensor) – Input feature of shape (batch_size, input_dim).

  • labels (torch.Tensor) – Group labels for each sample (e.g., cell type).

  • labels2 (torch.Tensor) – Another label dividing data into exactly two subgroups for contrast.

  • y (torch.Tensor) – One-hot encoded target of shape (batch_size, y_dim).

Returns:

Scalar loss.

Return type:

torch.Tensor

predict(x, labels, labels2)[source]

Prediction method for contrastive linear probing.

Aggregates the input features by labels within each subgroup of labels2, then calculates the difference of those aggregated embeddings.

Parameters:
  • x (torch.Tensor) – Input feature of shape (n_samples, input_dim).

  • labels (torch.Tensor) – Group labels for each sample.

  • labels2 (torch.Tensor) – Another label dividing data into exactly two subgroups for contrast.

Returns:

  • prediction_diff: Softmax of the linear probe applied to the difference between the aggregated embeddings of the two subgroups.

  • unique labels used in the grouping (for alignment).

Return type:

(torch.Tensor, torch.Tensor)
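The aggregation described for the contrastive predict can be sketched in NumPy as follows. The sketch assumes plain (unweighted) means, assumes every group has cells in both subgroups, and picks the sign convention "second subgroup minus first"; none of these details are confirmed by the description above. W and b are hypothetical probe parameters.

```python
import numpy as np

def softmax(z, axis=-1):
    z = z - z.max(axis=axis, keepdims=True)  # subtract max for numerical stability
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

def contrast_predict(x, labels, labels2, W, b):
    """Probe the per-group difference between two subgroup means."""
    groups = np.unique(labels)
    sub_a, sub_b = np.unique(labels2)  # raises unless there are exactly two subgroups
    diff = np.empty((len(groups), x.shape[1]))
    for i, g in enumerate(groups):
        in_g = labels == g
        mean_a = x[in_g & (labels2 == sub_a)].mean(axis=0)
        mean_b = x[in_g & (labels2 == sub_b)].mean(axis=0)
        diff[i] = mean_b - mean_a      # assumed sign convention
    return softmax(diff @ W + b, axis=1), groups

rng = np.random.default_rng(0)
x = rng.normal(size=(16, 3))                # 16 cells, input_dim = 3
labels = np.repeat([0, 1], 8)               # e.g. two cell types
labels2 = np.tile(np.repeat([0, 1], 4), 2)  # e.g. control vs. treated within each type
W = rng.normal(size=(3, 2))                 # y_dim = 2
b = np.zeros(2)

prediction_diff, groups = contrast_predict(x, labels, labels2, W, b)
```

Because the probe sees only the difference of the two subgroup means, any component of the embedding shared by both subgroups of a group cancels out before classification.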

class pertvi.model.linearprob.LinearProbing(input_dim, y_dim, use_norm='none')[source]

A simple linear probing module without per-group weighting.

Parameters:
  • input_dim (int) – Dimensionality of the input features.

  • y_dim (int) – Dimensionality of the target (number of classes).

  • use_norm (str, default: 'none') – Type of normalization to apply. One of [‘none’, ‘batch’, ‘layer’].

forward(x, y)[source]

Forward pass for simple linear probing.

Parameters:
  • x (torch.Tensor) – Input feature of shape (batch_size, input_dim).

  • y (torch.Tensor) – One-hot encoded target of shape (batch_size, y_dim).

Returns:

Scalar cross-entropy loss.

Return type:

torch.Tensor

predict(x)[source]

Inference for simple linear probing.

Parameters:

x (torch.Tensor) – Input feature of shape (n_samples, input_dim).

Returns:

Predictions of shape (n_samples, y_dim).

Return type:

torch.Tensor
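As a rough illustration of the forward/predict pair above, the following NumPy sketch computes the cross-entropy loss of a single linear layer against one-hot targets and returns per-cell class probabilities. It assumes predict returns softmax probabilities rather than raw logits, and it leaves out normalization and training. All function names and the parameters W and b are illustrative.

```python
import numpy as np

def one_hot(codes, n_classes):
    """Turn integer class codes into a one-hot matrix of shape (n_samples, n_classes)."""
    out = np.zeros((len(codes), n_classes))
    out[np.arange(len(codes)), codes] = 1.0
    return out

def log_softmax(z):
    z = z - z.max(axis=1, keepdims=True)  # subtract max for numerical stability
    return z - np.log(np.exp(z).sum(axis=1, keepdims=True))

def linear_probe_loss(x, y_onehot, W, b):
    """Mean cross-entropy between softmax(x @ W + b) and the one-hot targets."""
    log_probs = log_softmax(x @ W + b)
    return float(-(y_onehot * log_probs).sum(axis=1).mean())

def linear_probe_predict(x, W, b):
    """Class probabilities of shape (n_samples, y_dim)."""
    return np.exp(log_softmax(x @ W + b))

rng = np.random.default_rng(0)
x = rng.normal(size=(6, 4))                             # 6 cells, input_dim = 4
y = one_hot(np.array([0, 2, 1, 1, 0, 2]), n_classes=3)  # y_dim = 3
W = np.zeros((4, 3))                                    # untrained probe
b = np.zeros(3)

loss = linear_probe_loss(x, y, W, b)    # equals log(3) for an all-zero probe
probs = linear_probe_predict(x, W, b)   # uniform: every entry is 1/3
```

An all-zero probe assigns equal probability to every class, so its loss of log(y_dim) is a useful baseline: a trained probe should fall below it if the embedding carries information about the target.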