scShift
scShift Model
- scshift.scShift
alias of PertVIModel
- class pertvi.model.pertvi.PertVIModel(adata, n_batch=0, n_hidden=128, n_latent=10, n_layers=2, dropout_rate=0, use_observed_lib_size=True, lam_l0=50, lam_l1=0, lam_corr=5, var_eps=0.0001, kl_weight=1)
Bases: BaseModelClass, ArchesMixin
Model class for scShift.
This model aims to disentangle batch-dependent and batch-independent variations in single-cell data using variational inference.
- Parameters:
adata (AnnData) – AnnData object containing single-cell data. Must have been set up via scShift.setup_anndata or an equivalent.
n_batch (int, default: 0) – Number of batches in the dataset.
n_hidden (int, default: 128) – Number of nodes per hidden layer.
n_latent (int, default: 10) – Dimensionality of the latent space.
n_layers (int, default: 2) – Number of hidden layers in encoder/decoder neural networks.
dropout_rate (float, default: 0) – Dropout rate to apply to layers.
use_observed_lib_size (bool, default: True) – If True, use observed library size as scaling factor in the mean of the distribution.
lam_l0 (float, default: 50) – L0 regularization coefficient for the dataset label encoding, applied through a stochastic gate mechanism.
lam_l1 (float, default: 0) – L1 penalty coefficient.
lam_corr (float, default: 5) – Coefficient of the independence regularization between the centralized embedding and the dataset label encoding.
var_eps (float, default: 1e-4) – Minimal variance for the variational posteriors.
kl_weight (float, default: 1) – KL divergence scale factor.
- Returns:
The model is initialized in place.
- Return type:
None
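The stochastic gate behind lam_l0 is not spelled out on this page. As a minimal sketch, assuming a Gaussian-relaxed gate of the kind used in stochastic-gate (STG) feature selection, each gate is z = clip(mu + sigma * eps, 0, 1) with eps ~ N(0, 1), and the expected L0 norm is the sum of Phi(mu / sigma), where Phi is the standard normal CDF. The function names stochastic_gate and expected_l0 and the value sigma = 0.5 are illustrative assumptions, not part of the scShift API.

```python
import math

import numpy as np


def stochastic_gate(mu, sigma=0.5, rng=None, training=True):
    """Hypothetical Gaussian-relaxed gate values z = clip(mu + sigma * eps, 0, 1)."""
    mu = np.asarray(mu, dtype=float)
    if not training:
        # Inference: drop the noise and keep the clipped gate means.
        return np.clip(mu, 0.0, 1.0)
    rng = np.random.default_rng() if rng is None else rng
    eps = rng.standard_normal(mu.shape)
    return np.clip(mu + sigma * eps, 0.0, 1.0)


def expected_l0(mu, sigma=0.5):
    """Expected number of open gates: sum_d P(z_d > 0) = sum_d Phi(mu_d / sigma)."""
    mu = np.asarray(mu, dtype=float).ravel()
    # Standard normal CDF written with math.erf to avoid a SciPy dependency.
    cdf = [0.5 * (1.0 + math.erf(m / (sigma * math.sqrt(2.0)))) for m in mu]
    return float(sum(cdf))
```

Under this assumption the penalty term would be lam_l0 * expected_l0(mu); with the default lam_l0=50 and gate means [-1, 0, 1], that is 50 * 1.5 = 75. In such a scheme, gates pushed toward 0 switch off the corresponding encoding entries, which is how an L0 penalty encourages a sparse dataset label encoding.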
Methods Table
get_pert | Create a one-hot dataset label encoding and store it in adata.obsm['pert'].
setup_anndata | Set up an AnnData instance for the scShift model.
train | Train the scShift model using a semi-supervised data splitter.
get_latent_representation | Return the latent representation for each cell.
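These methods are meant to be called in a fixed order: get_pert, then setup_anndata, then model construction, training, and inference. The sketch below wraps that order in a helper that takes the model class as an argument, so it runs without scShift installed. run_scshift_workflow is a hypothetical helper, and passing adata positionally to get_pert is an assumption based on its parameter list (the rendered signature omits adata).

```python
from typing import Any, Optional


def run_scshift_workflow(model_cls: Any, adata: Any, pert_label: str,
                         max_epochs: Optional[int] = None) -> Any:
    """Hypothetical helper chaining the documented scShift calls in order."""
    # 1. One-hot dataset label encoding, stored in adata.obsm['pert'].
    #    Assumption: adata is passed as the first positional argument.
    model_cls.get_pert(adata, pert_label=pert_label)
    # 2. Register the fields the model needs; this must come after get_pert.
    model_cls.setup_anndata(adata, pert_key="pert")
    # 3. Build the model on the registered AnnData.
    model = model_cls(adata)
    # 4. Train in place; max_epochs=None leaves the choice to the library.
    model.train(max_epochs=max_epochs)
    # 5. Latent representation, one row per cell.
    return model.get_latent_representation()
```

With scShift installed, this might be called as run_scshift_workflow(scshift.scShift, adata, pert_label="dataset"), where "dataset" stands for a hypothetical adata.obs column.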
Preprocessing
- scShift.get_pert(pert_label=None, drug_label=None, dose_label=None, ct_pert=None, ct_drug=None)
Create a one-hot dataset label encoding and store it in adata.obsm['pert'].
- Parameters:
adata (AnnData) – The AnnData object.
pert_label (str, optional) – Key in adata.obs corresponding to perturbation / data identity.
drug_label (str, optional) – Key in adata.obs corresponding to drug / data identity.
dose_label (str, optional) – Key in adata.obs for the dosage levels.
ct_pert (str, optional) – Name or category representing the control (unperturbed) condition in pert_label (deprecated).
ct_drug (str, optional) – Name or category representing the control drug in drug_label (deprecated).
- Returns:
Modifies adata.obsm['pert'] in place with the new encoding.
- Return type:
None
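The core of get_pert, one-hot encoding a categorical label column, can be sketched with NumPy. This is an illustration under stated assumptions, not the function's source: one_hot_labels is a hypothetical name, categories are sorted alphabetically, and the dose_label handling is not reproduced.

```python
import numpy as np


def one_hot_labels(labels):
    """Hypothetical sketch: return (categories, one-hot matrix of shape (n_cells, n_categories))."""
    # np.unique sorts the categories and maps each label to its category index.
    categories, codes = np.unique(np.asarray(labels), return_inverse=True)
    codes = codes.reshape(-1)
    onehot = np.zeros((codes.size, categories.size), dtype=np.float32)
    # Set exactly one entry per row (cell) to 1.
    onehot[np.arange(codes.size), codes] = 1.0
    return categories, onehot
```

The result has one row per cell, the layout an adata.obsm entry requires. Assigning it yourself, e.g. adata.obsm['pert'] = onehot, is only a stand-in for get_pert, whose actual column order and dose handling may differ.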
- classmethod scShift.setup_anndata(adata, pert_key='pert', layer=None, batch_key=None, labels_key=None, size_factor_key=None, categorical_covariate_keys=None, continuous_covariate_keys=None, **kwargs)
Set up an AnnData instance for the scShift model. get_pert must be run first.
- Parameters:
adata (AnnData) – AnnData object containing raw counts. Rows represent cells, columns represent features.
pert_key (str, default: 'pert') – Key in adata.obsm for perturbation encoding.
layer (str, optional) – If not None, uses this as the key in adata.layers for raw count data.
batch_key (str, optional) – Key in adata.obs for batch information. Categories will automatically be converted into integer categories.
labels_key (str, optional) – Key in adata.obs for label information.
size_factor_key (str, optional) – Key in adata.obs for size factor information. If not provided, library size will be used.
categorical_covariate_keys (List[str], optional) – Keys in adata.obs corresponding to categorical data.
continuous_covariate_keys (List[str], optional) – Keys in adata.obs corresponding to continuous data.
**kwargs – Additional keyword arguments for registration.
- Returns:
adata is modified in place to include the fields required by PertVIModel, and an AnnDataManager is registered with this class.
- Return type:
None
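When size_factor_key is not given, the library size is used instead. A cell's library size is its total raw count, i.e. the row sum of the cells by features count matrix. The sketch below shows that computation together with the size-factor override; library_size is a hypothetical helper, and whether scShift uses the value directly or on a log scale internally is not stated here.

```python
import numpy as np


def library_size(counts, size_factors=None):
    """Hypothetical per-cell scaling factor: explicit size factors if given, else total counts.

    counts: NumPy array or scipy.sparse matrix of shape (n_cells, n_features),
    rows are cells and columns are features, as in setup_anndata.
    """
    n_cells = counts.shape[0]
    if size_factors is not None:
        # Mirrors passing size_factor_key: one user-provided value per cell.
        sf = np.asarray(size_factors, dtype=float).reshape(-1)
        if sf.shape[0] != n_cells:
            raise ValueError("need exactly one size factor per cell")
        return sf
    # Default: row sums. reshape(-1) also flattens the (n_cells, 1)
    # np.matrix that scipy.sparse .sum(axis=1) returns.
    return np.asarray(counts.sum(axis=1), dtype=float).reshape(-1)
```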
Training
- scShift.train(max_epochs=None, use_gpu=None, batch_size=128, early_stopping=False, train_size=0.9, validation_size=None, n_samples_per_label=100, lr=0.001, weight_decay=0.0001, n_epochs_kl_warmup=None, n_steps_kl_warmup=1600, **trainer_kwargs)
Train the scShift model using a semi-supervised data splitter.
This method sets up a training loop with the chosen data splitter and training plan, and can optionally stop training early based on validation loss.
- Parameters:
max_epochs (int, optional) – Number of passes through the dataset. Defaults to a heuristic based on the number of cells if not specified.
use_gpu (Union[str, int, bool], optional) – Whether to use GPU for training. Can be None, True/False, or a specific GPU index or name, e.g. "cuda:0".
batch_size (int, default: 128) – Mini-batch size for data loading during training.
early_stopping (bool, default: False) – If True, perform early stopping based on validation loss.
train_size (float, default: 0.9) – Proportion of cells to include in the training set.
validation_size (float, optional) – Proportion of cells to include in the validation set. If None, uses 1 - train_size. Additional cells, if any, form a test set.
n_samples_per_label (int, default: 100) – Number of labeled samples to use per label category in semi-supervised mode.
lr (float, default: 1e-3) – Learning rate for the optimizer.
weight_decay (float, default: 1e-4) – Weight decay for the optimizer, acting as an L2 regularization.
n_epochs_kl_warmup (int, optional) – Number of epochs over which to scale up the KL term from 0 to 1.
n_steps_kl_warmup (int, default: 1600) – Number of training steps over which to warm up the KL divergence term.
**trainer_kwargs – Additional keyword arguments passed to the Trainer or SemiSupervisedTrainingPlan.
- Returns:
The model is trained in place. Check logs for training progress or potential early stopping triggers.
- Return type:
None
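Two pieces of arithmetic behind train can be sketched directly: how train_size and validation_size turn into cell counts, and how a linear KL warm-up weight ramps from 0 to 1. Both functions are hypothetical. The rounding (ceiling for training, floor for validation) and the rule that n_epochs_kl_warmup takes precedence over n_steps_kl_warmup follow common scvi-tools conventions and are assumptions for scShift.

```python
import math
from typing import Optional, Tuple


def split_sizes(n_cells: int, train_size: float = 0.9,
                validation_size: Optional[float] = None) -> Tuple[int, int, int]:
    """Hypothetical (n_train, n_val, n_test) cell counts for a random split."""
    if not 0.0 < train_size <= 1.0:
        raise ValueError("train_size must be in (0, 1]")
    n_train = math.ceil(train_size * n_cells)
    if validation_size is None:
        # Documented default: validation gets the remaining 1 - train_size.
        n_val = n_cells - n_train
    else:
        if train_size + validation_size > 1.0:
            raise ValueError("train_size + validation_size must not exceed 1")
        n_val = math.floor(validation_size * n_cells)
    # Any cells left over form the test set.
    return n_train, n_val, n_cells - n_train - n_val


def kl_warmup_weight(epoch: int, step: int,
                     n_epochs_kl_warmup: Optional[int] = None,
                     n_steps_kl_warmup: Optional[int] = 1600) -> float:
    """Hypothetical linear KL warm-up weight in [0, 1]."""
    if n_epochs_kl_warmup:
        # Assumption: an epoch-based schedule overrides the step-based one.
        return min(1.0, epoch / n_epochs_kl_warmup)
    if n_steps_kl_warmup:
        return min(1.0, step / n_steps_kl_warmup)
    # No warm-up configured: full KL weight from the start.
    return 1.0
```

For example, 1,000 cells with the defaults give 900 training and 100 validation cells and no test set, and at step 800 of the default 1,600-step warm-up the KL term is weighted by 0.5.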
Post-Training / Inference
- scShift.get_latent_representation(adata=None, indices=None, give_mean=True, batch_size=None, use_mask=False, representation_kind='all')
Return the latent representation for each cell.
- Parameters:
adata (AnnData, optional) – AnnData object with equivalent structure to initial AnnData. If None, uses the AnnData object used to initialize the model.
indices (Sequence[int], optional) – Indices of cells in adata to use. If None, use all cells.
give_mean (bool, default: True) – Whether to return the mean of the distribution or a sampled value.
batch_size (int, optional) – Mini-batch size for data loading.
use_mask (bool, default: False) – If True, uses a masked inference input instead of the full inference input.
representation_kind (str, default: "all") – Either "base", "pert", or "all". Controls how the latent embedding is computed.
- Returns:
A NumPy array of shape (n_cells, n_latent) containing latent representations.
- Return type:
np.ndarray
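give_mean switches between the posterior mean and a random draw. Assuming a Gaussian variational posterior (suggested by var_eps, the minimal posterior variance), the draw uses the reparameterization z = mean + sqrt(var) * eps. latent_from_posterior is a hypothetical helper operating on precomputed posterior parameters; it does not reproduce representation_kind or use_mask, whose exact semantics are not described here.

```python
import numpy as np


def latent_from_posterior(qz_mean, qz_var, give_mean=True, rng=None):
    """Hypothetical: (n_cells, n_latent) latent matrix from Gaussian posterior parameters."""
    qz_mean = np.asarray(qz_mean, dtype=float)
    qz_var = np.asarray(qz_var, dtype=float)
    if qz_mean.shape != qz_var.shape:
        raise ValueError("mean and variance must have the same shape")
    if give_mean:
        # Deterministic embedding: the posterior mean itself.
        return qz_mean.copy()
    # One reparameterized sample per cell: z = mean + std * eps, eps ~ N(0, I).
    rng = np.random.default_rng() if rng is None else rng
    return qz_mean + np.sqrt(qz_var) * rng.standard_normal(qz_mean.shape)
```

The mean is the usual choice for downstream neighbor graphs or clustering because repeated calls return the same embedding; sampled values change from call to call.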