import logging
import warnings
from functools import partial
from typing import Dict, Iterable, List, Optional, Sequence, Union
import numpy as np
import pandas as pd
import math
import torch
import scanpy as sc
from anndata import AnnData
from tqdm import tqdm
import torch
import torch.nn as nn
import torch.optim as optim
from scvi import settings
from torch.utils.data import DataLoader, TensorDataset
class LPModel:
    """
    A linear probing model class supporting multiple modes of analysis.

    This class provides different linear probing strategies ("Weighted",
    "WeightedDisentangled", "Contrast", "Individual") for analyzing
    single-cell data stored in an AnnData object.

    Parameters
    ----------
    adata : AnnData
        An AnnData object containing single-cell data.
        The relevant embeddings/features are expected in `adata.obsm`.
    x : str
        Key in `adata.obsm` for the primary feature embedding.
    y : str
        Key in `adata.obs` for the target variable (categorical).
    label : str, optional
        Key in `adata.obs` for grouping cells (e.g. samples or conditions).
        Required by every mode except 'Individual'.
    x2_label : str, optional
        Required by 'WeightedDisentangled', where it is a key in `adata.obsm`
        for the secondary embedding, and by 'Contrast', where it is a key in
        `adata.obs` for the label splitting cells into the two contrasted
        subgroups.
    mask_label : str, optional
        Key in `adata.obs` used to split training/validation/test by group
        instead of by individual cell. Defaults to `label`; when both are
        None the split is done per cell.
    mode : str, default: 'Weighted'
        The linear probing mode. One of ["Weighted", "WeightedDisentangled",
        "Contrast", "Individual"].
    norm : str, default: 'none'
        Normalization applied to the inputs ('none', 'batch', or 'layer').
    hidden_dim : int, default: 10
        Passed to the module constructor in the 'Weighted',
        'WeightedDisentangled' and 'Contrast' modes.

    Raises
    ------
    AssertionError
        If `mode` is not one of the supported modes.
    ValueError
        If `label` or `x2_label` is missing for a mode that needs it.
    """

    def __init__(
        self,
        adata: "AnnData",
        x: str,
        y: str,
        label: Optional[str] = None,
        x2_label: Optional[str] = None,
        mask_label: Optional[str] = None,
        mode: str = 'Weighted',
        norm: str = 'none',
        hidden_dim: int = 10,
    ):
        available_method_kinds = ["Weighted", "WeightedDisentangled", "Contrast", "Individual"]
        assert mode in available_method_kinds, (
            f"mode = {mode} is not one of"
            f" {available_method_kinds}"
        )
        # Fail early instead of with an AttributeError/KeyError during training.
        if mode != 'Individual' and label is None:
            raise ValueError(f"mode = {mode} requires `label` to group cells.")
        if mode in ('WeightedDisentangled', 'Contrast') and x2_label is None:
            raise ValueError(f"mode = {mode} requires `x2_label`.")

        self.mode = mode
        self.norm = norm
        self.x_label = x
        self.y_label = y
        self.ind_label = label
        self.x2_label = x2_label

        y_dummies = pd.get_dummies(adata.obs[y])
        self.y_index = y_dummies.columns
        # One-hot targets; cast explicitly because recent pandas returns bool dummies.
        self.Y = y_dummies.values.astype(np.float32)

        # Integer group code per cell (None when no grouping label is given).
        self.labels = None if label is None else self._codes(adata.obs[label])
        if mask_label is None:
            # Split by the grouping label; with no label at all, split per cell.
            self.mask_label = label
            self.mask_labels = self.labels
        else:
            self.mask_label = mask_label
            self.mask_labels = self._codes(adata.obs[mask_label])

        if mode == 'WeightedDisentangled':
            self.X1 = np.asarray(adata.obsm[x], dtype=np.float32)
            self.X2 = np.asarray(adata.obsm[x2_label], dtype=np.float32)
            self.module = WeightedDisentangledLinearProbing(
                self.X1.shape[1], self.X2.shape[1], self.Y.shape[1], hidden_dim, norm
            )
        else:
            self.X = np.asarray(adata.obsm[x], dtype=np.float32)
            if mode == 'Weighted':
                self.module = WeightedLinearProbing(self.X.shape[1], self.Y.shape[1], hidden_dim, norm)
            elif mode == 'Contrast':
                self.labels2 = self._codes(adata.obs[x2_label])
                self.module = ContrastLinearProbing(self.X.shape[1], self.Y.shape[1], hidden_dim, norm)
            else:
                self.module = LinearProbing(self.X.shape[1], self.Y.shape[1], norm)

    @staticmethod
    def _codes(values):
        """Return each entry's column index in ``pd.get_dummies(values)``."""
        return pd.get_dummies(values).values.argmax(axis=1)

    @staticmethod
    def _resolve_device(use_gpu):
        """Map ``use_gpu`` to a device; 'cpu' when no GPU is requested or CUDA is unavailable."""
        if use_gpu is None or use_gpu is False or not torch.cuda.is_available():
            return 'cpu'
        if use_gpu is True:
            return torch.device("cuda")
        if isinstance(use_gpu, int):
            # An integer is treated as a CUDA device index.
            return torch.device("cuda", use_gpu)
        return torch.device(use_gpu)

    def _split_masks(self, train_size, validation_size, seed):
        """Return (train, val, test) cell selectors, split by `mask_label` groups or by cell."""
        random_state = np.random.RandomState(seed=seed)
        if self.mask_label is not None:
            groups = np.unique(self.mask_labels)
            n_groups = groups.shape[0]
            n_train = int(train_size * n_groups)
            n_val = int(validation_size * n_groups)
            shuffled = groups[random_state.permutation(n_groups)]
            parts = (shuffled[:n_train], shuffled[n_train:n_train + n_val], shuffled[n_train + n_val:])
            # Boolean masks: every cell of a group lands in the same split.
            return tuple(np.isin(self.mask_labels, part) for part in parts)
        n_cells = self.Y.shape[0]
        n_train = int(train_size * n_cells)
        n_val = int(validation_size * n_cells)
        permutation = random_state.permutation(n_cells)
        return permutation[:n_train], permutation[n_train:n_train + n_val], permutation[n_train + n_val:]

    def _training_tensors(self):
        """Convert stored arrays to float32 tensors (in place); return them in batch order."""
        self.Y = torch.as_tensor(self.Y, dtype=torch.float32)
        if self.mode == 'WeightedDisentangled':
            self.X1 = torch.as_tensor(self.X1, dtype=torch.float32)
            self.X2 = torch.as_tensor(self.X2, dtype=torch.float32)
            self.labels = torch.as_tensor(self.labels, dtype=torch.float32)
            return self.X1, self.X2, self.Y, self.labels
        self.X = torch.as_tensor(self.X, dtype=torch.float32)
        if self.mode == 'Weighted':
            self.labels = torch.as_tensor(self.labels, dtype=torch.float32)
            return self.X, self.Y, self.labels
        if self.mode == 'Contrast':
            self.labels = torch.as_tensor(self.labels, dtype=torch.float32)
            self.labels2 = torch.as_tensor(self.labels2, dtype=torch.float32)
            return self.X, self.Y, self.labels, self.labels2
        return self.X, self.Y

    def train(
        self,
        max_epochs: Optional[int] = None,
        use_gpu: Optional[Union[str, int, bool]] = None,
        train_size: float = 0.9,
        batch_size: Optional[int] = None,
        validation_size: Optional[float] = None,
        lr=1e-3,
        weight_decay=1e-4,
        device=None,
        seed=None,
    ):
        """
        Train the linear probing model.

        Splits data (or groups) into train/validation (and test, from the
        leftover) according to `train_size` and `validation_size`. Then fits
        the chosen linear probing mode using an Adam optimizer.

        Parameters
        ----------
        max_epochs : int, optional
            Maximum number of training epochs. Defaults to
            ``min(round(20000 / n_cells * 400), 400)`` (at least 1).
        use_gpu : Union[str, int, bool], optional
            True for the default CUDA device, an int for a CUDA device index,
            or a device string. Falls back to CPU when CUDA is unavailable.
        train_size : float, default: 0.9
            Proportion of data (or groups) to include in the training set.
        batch_size : int, optional
            Batch size for DataLoader. If None, each split is loaded as one batch.
        validation_size : float, optional
            Proportion of data (or groups) for the validation set. Defaults to `1 - train_size`.
        lr : float, default: 1e-3
            Learning rate for the optimizer.
        weight_decay : float, default: 1e-4
            Weight decay (L2 penalty).
        device : torch.device or str, optional
            Device to use. If None, it is derived from `use_gpu`.
        seed : int, optional
            Random seed for sampling training and validation sets.
            If not specified, then scvi.settings.seed is used.

        Returns
        -------
        (list, list)
            A tuple of two lists: (train_loss, val_loss), one entry per epoch.
            If `validation_size > 0`, val_loss is populated; otherwise it's empty.
        """
        n_cells = self.Y.shape[0]
        if max_epochs is None:
            max_epochs = max(1, min(round((20000 / n_cells) * 400), 400))
        max_epochs = int(max_epochs)
        if validation_size is None:
            validation_size = 1 - train_size
        if device is None:
            device = self._resolve_device(use_gpu)
        self.module = self.module.to(device)
        self.device = device
        if seed is None:
            seed = settings.seed

        self.train_mask, self.val_mask, self.test_mask = self._split_masks(
            train_size, validation_size, seed
        )
        tensors = self._training_tensors()
        train_index = torch.as_tensor(self.train_mask)
        val_index = torch.as_tensor(self.val_mask)
        train_dataset = TensorDataset(*(t[train_index] for t in tensors))
        val_dataset = TensorDataset(*(t[val_index] for t in tensors))

        def make_loader(dataset):
            # DataLoader(batch_size=None) yields unbatched samples, which the modules
            # cannot handle; None therefore means "the whole split in one batch".
            size = batch_size if batch_size is not None else max(len(dataset), 1)
            return DataLoader(dataset, batch_size=size, shuffle=True)

        train_dataloader = make_loader(train_dataset)
        val_dataloader = make_loader(val_dataset) if validation_size > 0 else None

        optimizer = optim.Adam(self.module.parameters(), lr=lr, weight_decay=weight_decay)
        train_loss = []
        val_loss = []
        pbar = tqdm(range(1, max_epochs + 1))
        for epoch in pbar:
            train_loss.append(_train(self.module, train_dataloader, self.mode, device, optimizer))
            pbar.set_description('Epoch ' + str(epoch) + '/' + str(max_epochs))
            if val_dataloader is not None:
                val_loss.append(_eval(self.module, val_dataloader, self.mode, device))
                pbar.set_postfix(train_loss=train_loss[-1], val_loss=val_loss[-1])
            else:
                pbar.set_postfix(train_loss=train_loss[-1])
        return train_loss, val_loss

    @torch.no_grad()
    def predict(
        self,
        adata: "AnnData",
    ):
        """
        Predict using the trained linear probing model.

        Depending on the `mode`, the inputs required for prediction
        may include multiple embeddings and/or label arrays.

        Parameters
        ----------
        adata : AnnData
            The AnnData object containing embeddings/features in `adata.obsm`.
            Must have the same keys used when initializing this model.

        Returns
        -------
        Union[np.ndarray, Tuple[np.ndarray, pd.Index]]
            - If mode is 'Individual', returns per-cell predictions as a NumPy array.
            - Otherwise, returns (predictions, mapped_categories), where
              `mapped_categories` holds the `label` category of each prediction row.
        """
        self.module.eval()
        # Use wherever the module's parameters live, so predict also works before train().
        param = next(self.module.parameters(), None)
        device = param.device if param is not None else torch.device('cpu')

        def embedding(key):
            return torch.as_tensor(np.asarray(adata.obsm[key], dtype=np.float32), device=device)

        if self.mode == 'Individual':
            return self.module.predict(embedding(self.x_label)).cpu().numpy()

        group_dummies = pd.get_dummies(adata.obs[self.ind_label])
        labels = torch.as_tensor(group_dummies.values.argmax(axis=1), dtype=torch.float32, device=device)
        if self.mode == 'Weighted':
            predictions, unique_labels = self.module.predict(embedding(self.x_label), labels)
        elif self.mode == 'Contrast':
            labels2 = torch.as_tensor(self._codes(adata.obs[self.x2_label]), dtype=torch.float32, device=device)
            predictions, unique_labels = self.module.predict(embedding(self.x_label), labels, labels2)
        else:  # 'WeightedDisentangled'
            predictions, unique_labels = self.module.predict(
                embedding(self.x_label), embedding(self.x2_label), labels
            )
        indices = unique_labels.to(device='cpu', dtype=torch.int64).numpy()
        return predictions.cpu().numpy(), group_dummies.columns[indices]
def _train(model, dataloader, mode, device,optimizer):
"""
Internal training loop for one epoch.
Parameters
----------
model : nn.Module
The PyTorch model implementing a specific linear probing strategy.
dataloader : DataLoader
DataLoader yielding batches of (X, Y, ...) depending on mode.
mode : str
Probing mode, e.g. "Weighted", "Contrast", etc.
device : str or torch.device
Device on which computation is performed.
optimizer : torch.optim.Optimizer
Optimizer used to update model parameters.
Returns
-------
float
Mean training loss for this epoch.
"""
model.train()
train_loss = []
if mode == 'Weighted':
for i, (x,y,label) in enumerate(dataloader):
x = x.to(device)
y = y.to(device)
label = label.to(device)
optimizer.zero_grad()
loss = model(x,label,y)
loss.backward()
optimizer.step()
train_loss.append(loss.detach().cpu())
elif mode == 'Contrast':
for i, (x,y,label,label2) in enumerate(dataloader):
x = x.to(device)
y = y.to(device)
label = label.to(device)
label2 = label2.to(device)
optimizer.zero_grad()
loss = model(x,label,label2,y)
loss.backward()
optimizer.step()
train_loss.append(loss.detach().cpu())
elif mode == 'WeightedDisentangled':
for i, (x1,x2,y,label) in enumerate(dataloader):
x1 = x1.to(device)
x2 = x2.to(device)
y = y.to(device)
label = label.to(device)
optimizer.zero_grad()
loss = model(x1,x2,label,y)
loss.backward()
optimizer.step()
train_loss.append(loss.detach().cpu())
else:
for i, (x,y) in enumerate(dataloader):
x = x.to(device)
y = y.to(device)
optimizer.zero_grad()
loss = model(x,y)
loss.backward()
optimizer.step()
train_loss.append(loss.detach().cpu())
return np.array(train_loss).mean()
def _eval(model, dataloader, mode, device):
"""
Internal validation loop for one epoch.
Parameters
----------
model : nn.Module
The PyTorch model implementing a specific linear probing strategy.
dataloader : DataLoader
DataLoader yielding batches of (X, Y, ...) depending on mode.
mode : str
Probing mode, e.g. "Weighted", "Contrast", etc.
device : str or torch.device
Device on which computation is performed.
Returns
-------
float
Mean validation loss for this epoch.
"""
model.eval()
val_loss = []
if mode == 'Weighted':
for i, (x,y,label,label2) in enumerate(dataloader):
x = x.to(device)
y = y.to(device)
label = label.to(device)
label2 = label2.to(device)
loss = model(x,label,label2,y)
val_loss.append(loss.detach().cpu())
elif mode == 'Contrast':
for i, (x,y,label) in enumerate(dataloader):
x = x.to(device)
y = y.to(device)
label = label.to(device)
loss = model(x,label,y)
val_loss.append(loss.detach().cpu())
elif mode == 'WeightedDisentangled':
for i, (x1,x2,y,label) in enumerate(dataloader):
x1 = x1.to(device)
x2 = x2.to(device)
y = y.to(device)
label = label.to(device)
loss = model(x1,x2,label,y)
val_loss.append(loss.detach().cpu())
else:
for i, (x,y) in enumerate(dataloader):
x = x.to(device)
y = y.to(device)
loss = model(x,y)
val_loss.append(loss.detach().cpu())
return np.array(val_loss).mean()
class WeightedLinearProbing(nn.Module):
    """
    A weighted linear probing module applying per-group averaging with learned weights.

    Each cell's features are scaled by a learned sigmoid gate, averaged within
    each group of ``labels``, and mapped to class logits by a linear layer.

    Parameters
    ----------
    input_dim : int
        Dimensionality of the input features.
    y_dim : int
        Dimensionality of the target (number of classes).
    hidden_dim : int
        Size of hidden dimension (unused in current minimal design, but kept for API).
    use_norm : str, default: 'none'
        Type of normalization to apply. One of ['none', 'batch', 'layer'].

    Raises
    ------
    ValueError
        If `use_norm` is not one of the supported values.
    """

    def __init__(self, input_dim, y_dim, hidden_dim, use_norm='none'):
        super().__init__()
        if use_norm not in ('none', 'batch', 'layer'):
            raise ValueError(f"use_norm must be 'none', 'batch' or 'layer'; got {use_norm!r}")
        self.use_norm = use_norm
        if self.use_norm == 'batch':
            self.norm = nn.BatchNorm1d(input_dim)
        elif self.use_norm == 'layer':
            self.norm = nn.LayerNorm(input_dim)
        self.linear = nn.Linear(input_dim, 1)  # one scalar gate per cell
        self.linear2 = nn.Linear(input_dim, y_dim)  # group embedding -> class logits
        self.relu = nn.ReLU()  # unused; kept so existing attribute access keeps working
        self.softmax = nn.Softmax(dim=1)
        self.loss = nn.CrossEntropyLoss()

    @staticmethod
    def _group_mean(values, labels):
        """
        Average the rows of ``values`` (n, d) within each group of ``labels`` (n,).

        Returns (means, unique): means has one row per unique label, in the
        sorted order of ``unique``.
        """
        unique, inverse = torch.unique(labels, sorted=True, return_inverse=True)
        n_groups = unique.shape[0]
        sums = torch.zeros(n_groups, values.shape[1], dtype=values.dtype, device=values.device)
        sums = sums.index_add(0, inverse, values)
        counts = torch.bincount(inverse, minlength=n_groups).unsqueeze(1).to(values.dtype)
        return sums / counts, unique

    def _group_logits(self, x, labels):
        """Gate, group-average and project ``x``; returns (logits, unique labels)."""
        if self.use_norm != 'none':
            x = self.norm(x)
        gated = torch.sigmoid(self.linear(x)) * x
        means, unique = self._group_mean(gated, labels)
        return self.linear2(means), unique

    def forward(self, x, labels, y):
        """
        Forward pass for Weighted linear probing.

        Parameters
        ----------
        x : torch.Tensor
            Input feature of shape (batch_size, input_dim).
        labels : torch.Tensor
            Group labels for each sample, used to aggregate by group.
        y : torch.Tensor
            One-hot encoded target of shape (batch_size, y_dim).

        Returns
        -------
        torch.Tensor
            Scalar cross-entropy loss between the group logits and the
            group-averaged targets.
        """
        logits, _ = self._group_logits(x, labels)
        y_, _ = self._group_mean(y, labels)
        # CrossEntropyLoss applies log-softmax itself, so it must receive raw logits.
        return self.loss(logits, y_)

    def predict(self, x, labels):
        """
        Prediction for Weighted linear probing.

        Parameters
        ----------
        x : torch.Tensor
            Input feature of shape (n_samples, input_dim).
        labels : torch.Tensor
            Group labels used for aggregation.

        Returns
        -------
        (torch.Tensor, torch.Tensor)
            - prediction: class probabilities, shape (n_groups, y_dim)
            - unique group labels (sorted, on the input device)
        """
        logits, unique = self._group_logits(x, labels)
        return self.softmax(logits).detach(), unique
class ContrastLinearProbing(nn.Module):
    """
    A contrastive linear probing module computing differences between two subgroups.

    Cells are gated and averaged per group of ``labels`` separately within each
    of the two values of ``labels2``. The difference of the two group averages
    (second minus first, in sorted ``labels2`` order) is mapped to class logits.

    Parameters
    ----------
    input_dim : int
        Dimensionality of the input features.
    y_dim : int
        Dimensionality of the target (number of classes).
    hidden_dim : int
        Hidden dimension size (unused in minimal design).
    use_norm : str, default: 'none'
        Type of normalization to apply. One of ['none', 'batch', 'layer'].

    Raises
    ------
    ValueError
        If `use_norm` is not one of the supported values.
    """

    def __init__(self, input_dim, y_dim, hidden_dim, use_norm='none'):
        super().__init__()
        if use_norm not in ('none', 'batch', 'layer'):
            raise ValueError(f"use_norm must be 'none', 'batch' or 'layer'; got {use_norm!r}")
        self.use_norm = use_norm
        if self.use_norm == 'batch':
            self.norm = nn.BatchNorm1d(input_dim)
        elif self.use_norm == 'layer':
            self.norm = nn.LayerNorm(input_dim)
        self.linear = nn.Linear(input_dim, 1)  # one scalar gate per cell
        self.linear2 = nn.Linear(input_dim, y_dim)  # contrast embedding -> class logits
        self.relu = nn.ReLU()  # unused; kept so existing attribute access keeps working
        self.softmax = nn.Softmax(dim=1)
        self.loss = nn.CrossEntropyLoss()

    @staticmethod
    def _group_mean(values, labels):
        """
        Average the rows of ``values`` (n, d) within each group of ``labels`` (n,).

        Returns (means, unique) with one row per unique label, in sorted order.
        """
        unique, inverse = torch.unique(labels, sorted=True, return_inverse=True)
        n_groups = unique.shape[0]
        sums = torch.zeros(n_groups, values.shape[1], dtype=values.dtype, device=values.device)
        sums = sums.index_add(0, inverse, values)
        counts = torch.bincount(inverse, minlength=n_groups).unsqueeze(1).to(values.dtype)
        return sums / counts, unique

    def _contrast_logits(self, x, labels, labels2):
        """
        Compute logits of (mean in 2nd labels2 group - mean in 1st) per label.

        Raises
        ------
        ValueError
            If `labels2` does not have exactly two values, or the two
            subgroups do not contain the same set of `labels`.
        """
        if self.use_norm != 'none':
            x = self.norm(x)
        gated = torch.sigmoid(self.linear(x)) * x
        subgroups = torch.unique(labels2)
        if len(subgroups) != 2:
            raise ValueError("label2 should have exactly two unique values")
        means, uniques = [], []
        for value in subgroups:
            mask = labels2 == value
            mean, unique = self._group_mean(gated[mask], labels[mask])
            means.append(mean)
            uniques.append(unique)
        # torch.equal (unlike torch.eq) also handles unique sets of different sizes.
        if not torch.equal(uniques[0], uniques[1]):
            raise ValueError(
                f"not match: labels differ between the two label2 groups "
                f"({uniques[0].tolist()} vs {uniques[1].tolist()})"
            )
        return self.linear2(means[1] - means[0]), uniques[1]

    def forward(self, x, labels, labels2, y):
        """
        Forward pass for contrastive linear probing.

        Parameters
        ----------
        x : torch.Tensor
            Input feature of shape (batch_size, input_dim).
        labels : torch.Tensor
            Group labels for each sample.
        labels2 : torch.Tensor
            Another label dividing data into exactly two subgroups for contrast.
        y : torch.Tensor
            One-hot encoded target of shape (batch_size, y_dim).

        Returns
        -------
        torch.Tensor
            Scalar cross-entropy loss.
        """
        logits, _ = self._contrast_logits(x, labels, labels2)
        # Targets are averaged per label over the whole batch (both subgroups).
        # After the check above, the batch's labels equal each subgroup's, so rows align.
        y_, _ = self._group_mean(y, labels)
        # CrossEntropyLoss applies log-softmax itself, so it must receive raw logits.
        return self.loss(logits, y_)

    def predict(self, x, labels, labels2):
        """
        Prediction method for contrastive linear probing.

        Aggregates the input features by labels within each subgroup of `labels2`,
        then classifies the difference of those aggregated embeddings.

        Parameters
        ----------
        x : torch.Tensor
            Input feature of shape (n_samples, input_dim).
        labels : torch.Tensor
            Group labels for each sample.
        labels2 : torch.Tensor
            Another label dividing data into exactly two subgroups for contrast.

        Returns
        -------
        (torch.Tensor, torch.Tensor)
            - prediction_diff: class probabilities, shape (n_groups, y_dim).
            - unique labels used in the grouping (for alignment).
        """
        logits, unique = self._contrast_logits(x, labels, labels2)
        return self.softmax(logits).detach(), unique
class WeightedDisentangledLinearProbing(nn.Module):
    """
    A weighted, disentangled linear probing module that combines two embeddings.

    A per-cell gate computed from the second embedding scales the first
    embedding. The gated features are averaged per group of ``labels`` and
    mapped to class logits.

    Parameters
    ----------
    x1_dim : int
        Dimensionality of the first input embedding.
    x2_dim : int
        Dimensionality of the second input embedding.
    y_dim : int
        Dimensionality of the target (number of classes).
    hidden_dim : int
        Hidden dimension size (unused in current minimal design, kept for API).
    use_norm : str, default: 'none'
        Type of normalization to apply. One of ['none', 'batch', 'layer'].

    Raises
    ------
    ValueError
        If `use_norm` is not one of the supported values.
    """

    def __init__(self, x1_dim, x2_dim, y_dim, hidden_dim, use_norm='none'):
        super().__init__()
        if use_norm not in ('none', 'batch', 'layer'):
            raise ValueError(f"use_norm must be 'none', 'batch' or 'layer'; got {use_norm!r}")
        self.use_norm = use_norm
        if self.use_norm == 'batch':
            self.n1 = nn.BatchNorm1d(x1_dim)
            self.n2 = nn.BatchNorm1d(x2_dim)
        elif self.use_norm == 'layer':
            self.n1 = nn.LayerNorm(x1_dim)
            self.n2 = nn.LayerNorm(x2_dim)
        self.linear = nn.Linear(x2_dim, 1)  # gate from the second embedding
        self.linear2 = nn.Linear(x1_dim, y_dim)  # group embedding -> class logits
        self.relu = nn.ReLU()  # unused; kept so existing attribute access keeps working
        self.softmax = nn.Softmax(dim=1)
        self.loss = nn.CrossEntropyLoss()
        self.y_dim = y_dim

    @staticmethod
    def _group_mean(values, labels):
        """
        Average the rows of ``values`` (n, d) within each group of ``labels`` (n,).

        Returns (means, unique) with one row per unique label, in sorted order.
        """
        unique, inverse = torch.unique(labels, sorted=True, return_inverse=True)
        n_groups = unique.shape[0]
        sums = torch.zeros(n_groups, values.shape[1], dtype=values.dtype, device=values.device)
        sums = sums.index_add(0, inverse, values)
        counts = torch.bincount(inverse, minlength=n_groups).unsqueeze(1).to(values.dtype)
        return sums / counts, unique

    def _group_logits(self, x1, x2, labels):
        """Gate ``x1`` by ``x2``, group-average and project; returns (logits, unique labels)."""
        if self.use_norm != 'none':
            x1 = self.n1(x1)
            x2 = self.n2(x2)
        gated = torch.sigmoid(self.linear(x2)) * x1
        means, unique = self._group_mean(gated, labels)
        return self.linear2(means), unique

    def forward(self, x1, x2, labels, y):
        """
        Forward pass for WeightedDisentangled linear probing.

        Parameters
        ----------
        x1 : torch.Tensor
            First embedding input of shape (batch_size, x1_dim).
        x2 : torch.Tensor
            Second embedding input of shape (batch_size, x2_dim).
        labels : torch.Tensor
            Group labels used for aggregation.
        y : torch.Tensor
            One-hot encoded target of shape (batch_size, y_dim).

        Returns
        -------
        torch.Tensor
            Scalar cross-entropy loss.
        """
        logits, _ = self._group_logits(x1, x2, labels)
        y_, _ = self._group_mean(y, labels)
        # CrossEntropyLoss applies log-softmax itself, so it must receive raw logits.
        return self.loss(logits, y_)

    def predict(self, x1, x2, labels):
        """
        Prediction method for WeightedDisentangled linear probing.

        Parameters
        ----------
        x1 : torch.Tensor
            First embedding of shape (n_samples, x1_dim).
        x2 : torch.Tensor
            Second embedding of shape (n_samples, x2_dim).
        labels : torch.Tensor
            Group labels used for aggregation.

        Returns
        -------
        (torch.Tensor, torch.Tensor)
            - predictions: class probabilities, shape (n_groups, y_dim)
            - unique label IDs (sorted, on the input device)
        """
        logits, unique = self._group_logits(x1, x2, labels)
        return self.softmax(logits).detach(), unique
class LinearProbing(nn.Module):
    """
    A simple per-cell linear probing module without per-group weighting.

    Parameters
    ----------
    input_dim : int
        Dimensionality of the input features.
    y_dim : int
        Dimensionality of the target (number of classes).
    use_norm : str, default: 'none'
        Type of normalization to apply. One of ['none', 'batch', 'layer'].

    Raises
    ------
    ValueError
        If `use_norm` is not one of the supported values.
    """

    def __init__(self, input_dim, y_dim, use_norm='none'):
        super().__init__()
        if use_norm not in ('none', 'batch', 'layer'):
            raise ValueError(f"use_norm must be 'none', 'batch' or 'layer'; got {use_norm!r}")
        self.use_norm = use_norm
        if self.use_norm == 'batch':
            self.norm = nn.BatchNorm1d(input_dim)
        elif self.use_norm == 'layer':
            self.norm = nn.LayerNorm(input_dim)
        self.linear = nn.Linear(input_dim, y_dim)  # features -> class logits
        self.softmax = nn.Softmax(dim=1)
        self.loss = nn.CrossEntropyLoss()

    def _logits(self, x):
        """Optionally normalize ``x`` and return raw class logits."""
        if self.use_norm != 'none':
            x = self.norm(x)
        return self.linear(x)

    def forward(self, x, y):
        """
        Forward pass for a simple linear probing.

        Parameters
        ----------
        x : torch.Tensor
            Input feature of shape (batch_size, input_dim).
        y : torch.Tensor
            One-hot encoded target of shape (batch_size, y_dim).

        Returns
        -------
        torch.Tensor
            Scalar cross-entropy loss.
        """
        # CrossEntropyLoss applies log-softmax itself, so it must receive raw logits.
        return self.loss(self._logits(x), y)

    def predict(self, x):
        """
        Inference for a simple linear probing.

        Parameters
        ----------
        x : torch.Tensor
            Input feature of shape (n_samples, input_dim).

        Returns
        -------
        torch.Tensor
            Class probabilities of shape (n_samples, y_dim), detached from the graph.
        """
        return self.softmax(self._logits(x)).detach()