scShift

scShift Model

scshift.scShift

alias of PertVIModel

class pertvi.model.pertvi.PertVIModel(adata, n_batch=0, n_hidden=128, n_latent=10, n_layers=2, dropout_rate=0, use_observed_lib_size=True, lam_l0=50, lam_l1=0, lam_corr=5, var_eps=0.0001, kl_weight=1)

Bases: BaseModelClass, ArchesMixin

Model class for scShift.

This model aims to disentangle batch-dependent and batch-independent variations in single-cell data using variational inference.

Parameters:
  • adata (AnnData) – AnnData object containing single-cell data. Must have been set up via scShift.setup_anndata or an equivalent.

  • n_batch (int, default: 0) – Number of batches in the dataset.

  • n_hidden (int, default: 128) – Number of nodes per hidden layer.

  • n_latent (int, default: 10) – Dimensionality of the latent space.

  • n_layers (int, default: 2) – Number of hidden layers in encoder/decoder neural networks.

  • dropout_rate (float, default: 0) – Dropout rate to apply to layers.

  • use_observed_lib_size (bool, default: True) – If True, use observed library size as scaling factor in the mean of the distribution.

  • lam_l0 (float, default: 50) – Regularization coefficient (L0) for dataset label encoding through the stochastic gate mechanism.

  • lam_l1 (float, default: 0.0) – L1 penalty coefficient.

  • lam_corr (float, default: 5) – Coefficient of the independence regularization between the centralized embedding and the dataset label encoding.

  • var_eps (float, default: 1e-4) – Minimal variance for the variational posteriors.

  • kl_weight (float, default: 1) – KL divergence scale factor.

Returns:

The model is initialized in place.

Return type:

None
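The var_eps parameter is easiest to read as a floor on the posterior variance. The following is a minimal sketch, assuming the encoder emits an unconstrained value that is exponentiated and then offset by var_eps. This is the pattern used by scvi-tools encoders; whether PertVIModel does exactly this is an assumption, and posterior_variance is a hypothetical helper name.

```python
import math


def posterior_variance(raw_output, var_eps=1e-4):
    """Map an unconstrained encoder output to a strictly positive variance.

    Assumption: variance = exp(raw_output) + var_eps, so the result
    never falls below var_eps even when exp(raw_output) is tiny.
    """
    return math.exp(raw_output) + var_eps


# A very negative raw output would give a near-zero variance without the offset.
floored = posterior_variance(-50.0)
# A raw output of 0 gives exp(0) + var_eps.
typical = posterior_variance(0.0)
```

Under this reading, a larger var_eps keeps the posterior from collapsing to a point, at the cost of a slightly noisier sampled embedding.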

Methods Table

scshift.scShift.get_pert

A simple function to create one-hot dataset label encoding and store in adata.obsm['pert'].

scshift.scShift.setup_anndata

Set up AnnData instance for scShift model.

scshift.scShift.train

Train the scShift model using a semi-supervised data splitter.

scshift.scShift.get_latent_representation

Return the latent representation for each cell.
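The four methods in the table above are listed in the order a typical run would call them. The sketch below strings them together using only the argument names documented on this page. Two things are assumptions: that get_pert accepts adata as its first positional argument (adata appears in its parameter list but not in its printed signature), and the helper name run_scshift_workflow, which is hypothetical. The model class is passed in as a parameter so the function does not depend on a particular import path.

```python
def run_scshift_workflow(model_cls, adata, pert_label, max_epochs=None, **model_kwargs):
    """Call the documented scShift methods in order and return (model, latent).

    model_cls is expected to expose get_pert, setup_anndata, train and
    get_latent_representation as documented on this page.
    """
    # 1. Build the one-hot dataset label encoding in adata.obsm['pert'].
    #    Assumption: adata is the first positional argument.
    model_cls.get_pert(adata, pert_label=pert_label)

    # 2. Register the AnnData fields; pert_key defaults to 'pert'.
    model_cls.setup_anndata(adata, pert_key="pert")

    # 3. Construct the model; extra keyword arguments (n_latent, lam_l0, ...)
    #    are forwarded unchanged.
    model = model_cls(adata, **model_kwargs)

    # 4. Train in place.
    model.train(max_epochs=max_epochs)

    # 5. Extract the latent representation for every cell.
    latent = model.get_latent_representation(representation_kind="all")
    return model, latent
```

With the package installed, a call might look like run_scshift_workflow(scshift.scShift, adata, pert_label="dataset"), where "dataset" stands for whichever adata.obs column holds the dataset label.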

Preprocessing

scShift.get_pert(pert_label=None, drug_label=None, dose_label=None, ct_pert=None, ct_drug=None)

A simple function to create one-hot dataset label encoding and store in adata.obsm['pert'].

Parameters:
  • adata (AnnData) – The AnnData object.

  • pert_label (str, optional) – Key in adata.obs corresponding to drug / data identity.

  • drug_label (str, optional) – Key in adata.obs corresponding to drug / data identity.

  • dose_label (str, optional) – Key in adata.obs for the dosage levels.

  • ct_pert (str, optional) – Name or category representing control (unperturbed) in pert_label (deprecated).

  • ct_drug (str, optional) – Name or category representing control drug in drug_label (deprecated).

Returns:

Modifies adata.obsm['pert'] in place with the new encoding.

Return type:

None
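To make the stored encoding concrete, the following is a minimal sketch of one-hot encoding for a single label column. It is not the get_pert implementation: how get_pert combines pert_label, drug_label, and dose_label is not described here, and one_hot_encode is a hypothetical helper that works on plain Python lists rather than the array stored in adata.obsm['pert'].

```python
def one_hot_encode(labels):
    """Return (categories, matrix) for a sequence of hashable labels.

    categories is the sorted list of distinct labels; matrix has one row
    per input label and one column per category, with a single 1.0 per row.
    """
    categories = sorted(set(labels))
    column_of = {category: j for j, category in enumerate(categories)}
    matrix = []
    for label in labels:
        row = [0.0] * len(categories)
        row[column_of[label]] = 1.0  # mark the column of this cell's label
        matrix.append(row)
    return categories, matrix


# Four cells drawn from three datasets.
categories, encoding = one_hot_encode(["ctrl", "drugA", "ctrl", "drugB"])
```

Sorting the categories is a choice made in this sketch to keep the column order reproducible; the column order used by the library may differ.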

classmethod scShift.setup_anndata(adata, pert_key='pert', layer=None, batch_key=None, labels_key=None, size_factor_key=None, categorical_covariate_keys=None, continuous_covariate_keys=None, **kwargs)

Set up AnnData instance for scShift model. Run get_pert first so that the encoding referenced by pert_key exists in adata.obsm.

Parameters:
  • adata (AnnData) – AnnData object containing raw counts. Rows represent cells, columns represent features.

  • pert_key (str, default: 'pert') – Key in adata.obsm for perturbation encoding.

  • layer (str, optional) – If not None, uses this as the key in adata.layers for raw count data.

  • batch_key (str, optional) – Key in adata.obs for batch information. Categories will automatically be converted into integer categories.

  • labels_key (str, optional) – Key in adata.obs for label information.

  • size_factor_key (str, optional) – Key in adata.obs for size factor information. If not provided, library size will be used.

  • categorical_covariate_keys (List[str], optional) – Keys in adata.obs corresponding to categorical data.

  • continuous_covariate_keys (List[str], optional) – Keys in adata.obs corresponding to continuous data.

  • **kwargs – Additional keyword arguments for registration.

Returns:

The adata is modified in place to include the necessary fields for PertVIModel. An AnnDataManager is then registered to this class.

Return type:

None

Training

scShift.train(max_epochs=None, use_gpu=None, batch_size=128, early_stopping=False, train_size=0.9, validation_size=None, n_samples_per_label=100, lr=0.001, weight_decay=0.0001, n_epochs_kl_warmup=None, n_steps_kl_warmup=1600, **trainer_kwargs)

Train the scShift model using a semi-supervised data splitter.

This method sets up a training loop with the chosen data splitter and training plan. It can optionally perform early stopping if desired.

Parameters:
  • max_epochs (int, optional) – Number of passes through the dataset. Defaults to a heuristic based on the number of cells if not specified.

  • use_gpu (Union[str, int, bool], optional) – Whether to use GPU for training. Can be None, True/False, or a specific GPU index or name, e.g. "cuda:0".

  • batch_size (int, default: 128) – Mini-batch size for data loading during training.

  • early_stopping (bool, default: False) – If True, perform early stopping based on validation loss.

  • train_size (float, default: 0.9) – Proportion of cells to include in the training set.

  • validation_size (float, optional) – Proportion of cells to include in the validation set. If None, uses 1 - train_size. Additional cells, if any, form a test set.

  • n_samples_per_label (int, default: 100) – Number of labeled samples to use per label category in semi-supervised mode.

  • lr (float, default: 1e-3) – Learning rate for the optimizer.

  • weight_decay (float, default: 1e-4) – Weight decay for the optimizer, acting as an L2 regularization.

  • n_epochs_kl_warmup (int, optional) – Number of epochs over which to scale up the KL term from 0 to 1.

  • n_steps_kl_warmup (int, default: 1600) – Number of training steps over which to warm up the KL divergence term.

  • **trainer_kwargs – Additional keyword arguments passed to the Trainer or SemiSupervisedTrainingPlan.

Returns:

The model is trained in place. Check logs for training progress or potential early stopping triggers.

Return type:

None
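Two of the train parameters involve simple arithmetic that a sketch makes easier to check: how train_size and validation_size partition the cells, and how n_steps_kl_warmup ramps the KL term. Both functions below are illustrative assumptions rather than the library's code. The rounding (ceiling for the training set, floor for the validation set) and the linear ramp mirror common scvi-tools behaviour; the names split_sizes and kl_weight_at are hypothetical.

```python
import math


def split_sizes(n_cells, train_size=0.9, validation_size=None):
    """Return (n_train, n_val, n_test) for the given proportions."""
    if not 0 < train_size <= 1:
        raise ValueError("train_size must be in (0, 1]")
    n_train = math.ceil(train_size * n_cells)
    if validation_size is None:
        # All remaining cells go to validation; no test set.
        n_val = n_cells - n_train
    else:
        if train_size + validation_size > 1:
            raise ValueError("train_size + validation_size must not exceed 1")
        n_val = math.floor(validation_size * n_cells)
    # Whatever is left over forms the test set.
    n_test = n_cells - n_train - n_val
    return n_train, n_val, n_test


def kl_weight_at(step, n_steps_kl_warmup=1600, max_weight=1.0):
    """Linearly scale the KL weight from 0 to max_weight over the warmup steps."""
    if not n_steps_kl_warmup or n_steps_kl_warmup <= 0:
        return max_weight  # no warmup configured
    return max_weight * min(1.0, step / n_steps_kl_warmup)
```

For example, split_sizes(1000, train_size=0.8, validation_size=0.1) leaves 100 cells for the test set, and kl_weight_at(800) is halfway through the default warmup. How n_epochs_kl_warmup and n_steps_kl_warmup interact when both are set is not covered by this sketch.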

Post-Training / Inference

scShift.get_latent_representation(adata=None, indices=None, give_mean=True, batch_size=None, use_mask=False, representation_kind='all')

Return the latent representation for each cell.

Parameters:
  • adata (AnnData, optional) – AnnData object with equivalent structure to initial AnnData. If None, uses the AnnData object used to initialize the model.

  • indices (Sequence[int], optional) – Indices of cells in adata to use. If None, use all cells.

  • give_mean (bool, default: True) – Whether to return the mean of the distribution or a sampled value.

  • batch_size (int, optional) – Mini-batch size for data loading.

  • use_mask (bool, default: False) – If True, uses a masked inference input instead of the full inference input.

  • representation_kind (str, default: "all") – One of "base", "pert", or "all". Controls how the latent embedding is computed.

Returns:

A NumPy array of shape (n_cells, n_latent) containing latent representations.

Return type:

np.ndarray
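The give_mean flag chooses between a deterministic and a stochastic embedding. The following is a minimal sketch of that choice for one cell, assuming a diagonal Gaussian posterior with a per-dimension mean and variance. The posterior family is an assumption, latent_value is a hypothetical helper, and plain lists stand in for the arrays the model returns.

```python
import math
import random


def latent_value(mean, variance, give_mean=True, rng=None):
    """Return the posterior mean, or one sample mean + sqrt(variance) * noise."""
    if give_mean:
        return list(mean)
    rng = rng or random.Random()
    # Draw independent standard normal noise for each latent dimension.
    return [m + math.sqrt(v) * rng.gauss(0.0, 1.0) for m, v in zip(mean, variance)]


mean, variance = [0.5, -1.0], [0.1, 0.2]
deterministic = latent_value(mean, variance)  # same result on every call
sampled = latent_value(mean, variance, give_mean=False, rng=random.Random(0))
```

The mean is usually the better choice for downstream clustering or visualization because repeated calls agree; sampling is mainly useful when the posterior uncertainty itself is of interest.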