"""Model class for scShift that disentangles batch-dependent and batch-independent variations in data."""
import logging
import warnings
from functools import partial
from typing import Dict, Iterable, List, Optional, Sequence, Union
import numpy as np
import pandas as pd
import math
import torch
import scanpy as sc
import pytorch_lightning as pl
import torch.optim as optim
from scvi import settings
from tqdm import tqdm
import torch.nn.functional as F
import matplotlib.pyplot as plt
from scipy.sparse import issparse
from torch.distributions import Normal
from torch.autograd import Variable as V
from random import choices
from anndata import AnnData
from scvi import REGISTRY_KEYS
from scvi.data import AnnDataManager
from scvi.data.fields import (
CategoricalJointObsField,
CategoricalObsField,
ObsmField,
LayerField,
NumericalJointObsField,
LabelsWithUnlabeledObsField,
NumericalObsField,
)
from scvi.dataloaders import AnnDataLoader
from scvi.dataloaders._ann_dataloader import BatchSampler
from scvi.dataloaders._anntorchdataset import AnnTorchDataset
from scvi.model._utils import (
_get_batch_code_from_category,
_init_library_size,
scrna_raw_counts_properties,
)
from scvi.model.base import BaseModelClass, ArchesMixin
from scvi.model.base._utils import _de_core
from scvi.utils import setup_anndata_dsp
from scvi.train import SemiSupervisedTrainingPlan, TrainingPlan, TrainRunner
from scvi.dataloaders import DataSplitter, SemiSupervisedDataSplitter
from scvi.model._utils import parse_use_gpu_arg
from scvi.dataloaders._data_splitting import validate_data_split
from pertvi.module.pertvi import PertVIModule
logger = logging.getLogger(__name__)
Number = Union[int, float]
# (Stray "[docs]" Sphinx viewcode link marker removed; it was not part of the module.)
class PertVIModel(BaseModelClass, ArchesMixin):
    """
    Model class for scShift.

    This model aims to disentangle batch-dependent and batch-independent
    variations in single-cell data using variational inference.

    Parameters
    ----------
    adata : AnnData
        AnnData object containing single-cell data. Must have been set up
        via :meth:`PertVIModel.setup_anndata` and must carry the dataset /
        perturbation encoding in ``adata.obsm['pert']`` (see
        :meth:`PertVIModel.get_pert`).
    n_batch : int, default: 0
        Ignored. The value is overwritten with the batch count taken from the
        registered summary statistics; the parameter is kept for interface
        compatibility only.
    n_hidden : int, default: 128
        Number of nodes per hidden layer.
    n_latent : int, default: 10
        Dimensionality of the latent space (passed to the module as
        ``n_output``).
    n_layers : int, default: 2
        Number of hidden layers in encoder/decoder neural networks.
    dropout_rate : float, default: 0
        Dropout rate to apply to layers.
    use_observed_lib_size : bool, default: True
        If True, use observed library size as scaling factor in the mean of
        the distribution.
    lam_l0 : float, default: 50
        Regularization coefficient (L0) for dataset label encoding through
        the stochastic gate mechanism.
    lam_l1 : float, default: 0.0
        L1 penalty coefficient.
    lam_corr : float, default: 5
        Independence regularization between centralized embedding and
        dataset label encoding.
    var_eps : float, default: 1e-4
        Minimal variance for the variational posteriors.
    kl_weight : float, default: 1
        KL divergence scale factor.
    """

    def __init__(
        self,
        adata: AnnData,
        n_batch: int = 0,
        n_hidden: int = 128,
        n_latent: int = 10,
        n_layers: int = 2,
        dropout_rate: float = 0,
        use_observed_lib_size: bool = True,
        lam_l0: float = 50,
        lam_l1: float = 0,
        lam_corr: float = 5,
        var_eps: float = 1e-4,
        kl_weight: float = 1,
    ) -> None:
        super().__init__(adata)

        # The caller-supplied `n_batch` is deliberately superseded by the
        # batch count recorded when the AnnData was registered.
        n_batch = self.summary_stats.n_batch

        # Library-size priors are only initialised when no explicit size
        # factor was registered in `setup_anndata`.
        has_size_factor = (
            REGISTRY_KEYS.SIZE_FACTOR_KEY in self.adata_manager.data_registry
        )
        library_log_means, library_log_vars = None, None
        if not has_size_factor:
            library_log_means, library_log_vars = _init_library_size(
                self.adata_manager, n_batch
            )

        # NOTE(review): the obsm key is hard-coded to 'pert' although
        # `setup_anndata` accepts a custom `pert_key`; a non-default key will
        # raise KeyError here. Confirm which key PertVIModule expects before
        # generalising.
        self.module = PertVIModule(
            n_input=self.summary_stats["n_vars"],
            n_pert=adata.obsm['pert'].shape[1],
            n_batch=n_batch,
            n_hidden=n_hidden,
            n_output=n_latent,
            n_layers=n_layers,
            dropout_rate=dropout_rate,
            use_observed_lib_size=use_observed_lib_size,
            library_log_means=library_log_means,
            library_log_vars=library_log_vars,
            lam_l0=lam_l0,
            lam_l1=lam_l1,
            lam_corr=lam_corr,
            var_eps=var_eps,
            kl_weight=kl_weight,
        )
        self._model_summary_string = "PertVI"
        # Necessary line to get params to be used for saving and loading.
        self.init_params_ = self._get_init_params(locals())
        logger.info("The model has been initialized")

    @classmethod
    @setup_anndata_dsp.dedent
    def setup_anndata(
        cls,
        adata: AnnData,
        pert_key: str = 'pert',
        layer: Optional[str] = None,
        batch_key: Optional[str] = None,
        labels_key: Optional[str] = None,
        size_factor_key: Optional[str] = None,
        categorical_covariate_keys: Optional[List[str]] = None,
        continuous_covariate_keys: Optional[List[str]] = None,
        **kwargs,
    ):
        """
        Set up AnnData instance for scShift model. Need to run get_pert first.

        Parameters
        ----------
        adata : AnnData
            AnnData object containing raw counts. Rows represent cells, columns
            represent features.
        pert_key : str, default: 'pert'
            Key in `adata.obsm` for perturbation encoding. The same string is
            used as the registry key for that field.
        layer : str, optional
            If not None, uses this as the key in adata.layers for raw count data.
        batch_key : str, optional
            Key in `adata.obs` for batch information. Categories will automatically be
            converted into integer categories.
        labels_key : str, optional
            Key in `adata.obs` for label information. The category 'label_0'
            is registered as the unlabeled category.
        size_factor_key : str, optional
            Key in `adata.obs` for size factor information. If not provided,
            library size will be used.
        categorical_covariate_keys : List[str], optional
            Keys in `adata.obs` corresponding to categorical data.
        continuous_covariate_keys : List[str], optional
            Keys in `adata.obs` corresponding to continuous data.
        **kwargs
            Additional keyword arguments forwarded to
            `AnnDataManager.register_fields`.

        Returns
        -------
        None
            The `adata` is modified in place to include the necessary fields
            for PertVIModel. An `AnnDataManager` is then registered to this class.
        """
        # Must stay the first statement: `locals()` has to contain only the
        # call arguments at this point.
        setup_method_args = cls._get_setup_method_args(**locals())
        anndata_fields = [
            LayerField(REGISTRY_KEYS.X_KEY, layer, is_count_data=True),
            ObsmField(pert_key, pert_key),
            CategoricalObsField(REGISTRY_KEYS.BATCH_KEY, batch_key),
            LabelsWithUnlabeledObsField(
                REGISTRY_KEYS.LABELS_KEY, labels_key, 'label_0'
            ),
            NumericalObsField(
                REGISTRY_KEYS.SIZE_FACTOR_KEY, size_factor_key, required=False
            ),
            CategoricalJointObsField(
                REGISTRY_KEYS.CAT_COVS_KEY, categorical_covariate_keys
            ),
            NumericalJointObsField(
                REGISTRY_KEYS.CONT_COVS_KEY, continuous_covariate_keys
            ),
        ]
        adata_manager = AnnDataManager(
            fields=anndata_fields, setup_method_args=setup_method_args
        )
        adata_manager.register_fields(adata, **kwargs)
        cls.register_manager(adata_manager)

    @staticmethod
    def _soft_threshold(values: torch.Tensor, threshold: float = 0.1) -> torch.Tensor:
        """Shrink ``values`` towards zero by ``threshold``, zeroing entries with magnitude below it."""
        return torch.sign(values) * (
            torch.clamp(torch.abs(values), min=threshold) - threshold
        )

    @torch.no_grad()
    def get_latent_representation(
        self,
        adata: Optional[AnnData] = None,
        indices: Optional[Sequence[int]] = None,
        give_mean: bool = True,
        batch_size: Optional[int] = None,
        use_mask: bool = False,
        representation_kind: str = "all",
    ) -> np.ndarray:
        """
        Return the latent representation for each cell.

        Parameters
        ----------
        adata : AnnData, optional
            AnnData object with equivalent structure to initial AnnData. If `None`,
            uses the AnnData object used to initialize the model.
        indices : Sequence[int], optional
            Indices of cells in adata to use. If `None`, use all cells.
        give_mean : bool, default: True
            If True, return the mean-based embedding; otherwise return the
            sampled latent produced by the module.
        batch_size : int, optional
            Mini-batch size for data loading.
        use_mask : bool, default: False
            If True, inference inputs are built with
            `module._get_inference_input`; otherwise with
            `module._get_inference_input_eval`.
        representation_kind : str, default: "all"
            One of "base", "pert" or "all".

            * "base": mean `q_m` / sample `z`.
            * "pert": soft-thresholded `p_m` (threshold 0.1) / sample `z_pert`.
            * "all": `q_m` plus soft-thresholded `p_m` / sample `z_all`.

        Returns
        -------
        np.ndarray
            Per-cell latent representations, concatenated over mini-batches
            in data order.

        Raises
        ------
        ValueError
            If `representation_kind` is not one of the supported values.
        """
        available_representation_kinds = ["base", "pert", "all"]
        # Explicit raise rather than `assert`, so the check survives `python -O`.
        if representation_kind not in available_representation_kinds:
            raise ValueError(
                f"representation_kind = {representation_kind} is not one of"
                f" {available_representation_kinds}"
            )

        adata = self._validate_anndata(adata)
        dataloader = self._make_data_loader(
            adata=adata,
            indices=indices,
            batch_size=batch_size,
            shuffle=False,
            data_loader_class=AnnDataLoader,
        )

        latent = []
        for tensors in dataloader:
            if use_mask:
                inference_inputs = self.module._get_inference_input(tensors)
            else:
                inference_inputs = self.module._get_inference_input_eval(tensors)
            outputs = self.module.inference(**inference_inputs)

            if representation_kind == 'base':
                latent_m = outputs["q_m"]
                latent_sample = outputs["z"]
            elif representation_kind == 'pert':
                latent_m = self._soft_threshold(outputs["p_m"])
                latent_sample = outputs["z_pert"]
            else:
                latent_m = outputs["q_m"] + self._soft_threshold(outputs["p_m"])
                latent_sample = outputs["z_all"]

            selected = latent_m if give_mean else latent_sample
            latent.append(selected.detach().cpu())
        return torch.cat(latent).numpy()

    @staticmethod
    def get_pert(
        adata: AnnData,
        pert_label: Optional[str] = None,
        drug_label: Optional[str] = None,
        dose_label: Optional[str] = None,
        ct_pert: Optional[str] = None,
        ct_drug: Optional[str] = None,
    ) -> None:
        """
        Create a one-hot dataset label encoding and store it in `adata.obsm['pert']`.

        Exactly one of `pert_label` and `drug_label` must be given.

        Parameters
        ----------
        adata : AnnData
            The AnnData object.
        pert_label : str, optional
            Key in `adata.obs` corresponding to drug / data identity. The
            encoding is the plain one-hot matrix of this column.
        drug_label : str, optional
            Key in `adata.obs` corresponding to drug / data identity. The
            one-hot matrix of this column may additionally be scaled by
            `dose_label`.
        dose_label : str, optional
            Key in `adata.obs` for the dosage levels. Only used together with
            `drug_label`; each cell's one-hot row is multiplied by its dose.
        ct_pert : str, optional
            Deprecated and unused. Formerly the control (unperturbed)
            category in `pert_label`.
        ct_drug : str, optional
            Deprecated and unused. Formerly the control category in
            `drug_label`.

        Returns
        -------
        None
            Modifies `adata.obsm['pert']` in place with the new encoding.

        Raises
        ------
        ValueError
            If both or neither of `pert_label` and `drug_label` are provided.
        """
        if (pert_label is None) == (drug_label is None):
            raise ValueError(
                "Exactly one of `pert_label` or `drug_label` must be provided."
            )

        if pert_label is None:
            # `* 1` converts the boolean/uint8 dummies to plain integers.
            df = pd.get_dummies(adata.obs[drug_label]) * 1
            if dose_label is not None:
                # Row-wise scaling. `Series[:, None]` is no longer supported
                # by pandas >= 2.0, so go through a NumPy array instead.
                doses = adata.obs[dose_label].to_numpy()
                df = df.mul(doses, axis=0)
        else:
            df = pd.get_dummies(adata.obs[pert_label]) * 1

        adata.obsm['pert'] = df.values
        return

    def train(
        self,
        max_epochs: Optional[int] = None,
        use_gpu: Optional[Union[str, int, bool]] = None,
        batch_size: int = 128,
        early_stopping: bool = False,
        train_size: float = 0.9,
        validation_size: Optional[float] = None,
        n_samples_per_label: int = 100,
        lr: float = 1e-3,
        weight_decay: float = 1e-4,
        n_epochs_kl_warmup: Optional[int] = None,
        n_steps_kl_warmup: int = 1600,
        **trainer_kwargs,
    ) -> None:
        """
        Train the scShift model using a semi-supervised data splitter.

        Builds a `SemiSupervisedDataSplitter` and a
        `SemiSupervisedTrainingPlan` around the module and runs them through
        a `TrainRunner`.

        Parameters
        ----------
        max_epochs : int, optional
            Number of passes through the dataset. Forwarded unchanged to
            `TrainRunner`. NOTE(review): no size-based default is computed
            here when this is None; confirm what the trainer does in that
            case.
        use_gpu : Union[str, int, bool], optional
            Whether to use GPU for training. Can be None, True/False, or a specific
            GPU index or name, e.g. "cuda:0". Forwarded to both the data
            splitter and the runner.
        batch_size : int, default: 128
            Mini-batch size for data loading during training.
        early_stopping : bool, default: False
            Default for the `early_stopping` trainer option. An explicit
            `early_stopping` entry in `trainer_kwargs` takes precedence.
        train_size : float, default: 0.9
            Proportion of cells to include in the training set.
        validation_size : float, optional
            Proportion of cells to include in the validation set. Forwarded
            to the data splitter.
        n_samples_per_label : int, default: 100
            Number of labeled samples to use per label category in semi-supervised mode.
        lr : float, default: 1e-3
            Learning rate for the optimizer.
        weight_decay : float, default: 1e-4
            Weight decay for the optimizer.
        n_epochs_kl_warmup : int, optional
            Number of epochs over which to warm up the KL term (forwarded to
            the training plan).
        n_steps_kl_warmup : int, default: 1600
            Number of training steps over which to warm up the KL divergence term
            (forwarded to the training plan).
        **trainer_kwargs
            Additional keyword arguments forwarded to `TrainRunner`.

        Returns
        -------
        None
            The model is trained in place; the value returned by the runner
            call is passed through.
        """
        data_splitter = SemiSupervisedDataSplitter(
            self.adata_manager,
            n_samples_per_label=n_samples_per_label,
            train_size=train_size,
            validation_size=validation_size,
            batch_size=batch_size,
            use_gpu=use_gpu,
        )
        training_plan = SemiSupervisedTrainingPlan(
            self.module,
            lr=lr,
            weight_decay=weight_decay,
            n_steps_kl_warmup=n_steps_kl_warmup,
            n_epochs_kl_warmup=n_epochs_kl_warmup,
        )

        # An explicit `early_stopping` in trainer_kwargs wins over the argument.
        trainer_kwargs.setdefault("early_stopping", early_stopping)

        runner = TrainRunner(
            self,
            training_plan=training_plan,
            data_splitter=data_splitter,
            max_epochs=max_epochs,
            use_gpu=use_gpu,
            **trainer_kwargs,
        )
        return runner()