Skip to content

Commit

Permalink
feature(scar): introduce caching functionality
Browse files Browse the repository at this point in the history
  • Loading branch information
CaibinSh committed Jul 27, 2024
1 parent 707a0e0 commit 3f07c35
Showing 1 changed file with 57 additions and 16 deletions.
73 changes: 57 additions & 16 deletions scar/main/_scar.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from typing import Optional, Union
import numpy as np, pandas as pd, anndata as ad

from collections import OrderedDict
from sklearn.model_selection import train_test_split
from tqdm import tqdm
from tqdm.contrib import DummyTqdmFile
Expand Down Expand Up @@ -207,6 +208,7 @@ def __init__(
sparsity: float = 0.9,
batch_key: str = None,
device: str = "auto",
cache_capacity: int = 20000,
verbose: bool = True,
):
"""initialize object"""
Expand Down Expand Up @@ -268,6 +270,11 @@ def __init__(
"""float, the sparsity of expected native signals. (0, 1]. \
Forced to be one in the mode of "sgRNA(s)" and "tag(s)".
"""
self.cache_capacity = cache_capacity
"""int, the capacity of cache.
.. versionadded:: 0.6.1
"""

if isinstance(raw_count, ad.AnnData):
if batch_key:
Expand Down Expand Up @@ -438,9 +445,10 @@ def train(
train_ids, test_ids = train_test_split(list_ids, train_size=train_size)

# Generators
training_set = UMIDataset(self.raw_count, self.ambient_profile, self.batch_id, list_ids=train_ids)
training_set = UMIDataset(self.raw_count, self.ambient_profile, self.batch_id, list_ids=train_ids, device=self.device, cache_capacity=self.cache_capacity)
training_generator = torch.utils.data.DataLoader(
training_set, batch_size=batch_size, shuffle=shuffle
training_set, batch_size=batch_size, shuffle=shuffle,
drop_last=True
)
# val_set = UMIDataset(self.raw_count, self.ambient_profile, self.batch_id, list_ids=test_ids)
# val_generator = torch.utils.data.DataLoader(
Expand Down Expand Up @@ -492,9 +500,9 @@ def train(
vae_nets.train()
for x_batch, ambient_freq, batch_id_onehot in training_generator:
# Move data to device
x_batch = x_batch.to(self.device)
ambient_freq = ambient_freq.to(self.device)
batch_id_onehot = batch_id_onehot.to(self.device)
# x_batch = x_batch.to(self.device)
# ambient_freq = ambient_freq.to(self.device)
# batch_id_onehot = batch_id_onehot.to(self.device)

optim.zero_grad()
dec_nr, dec_prob, means, var, dec_dp = vae_nets(x_batch, batch_id_onehot)
Expand Down Expand Up @@ -589,7 +597,7 @@ def inference(
native_frequencies, and noise_ratio. \
A feature_assignment will be added in 'sgRNA' or 'tag' or 'CMO' feature type.
"""
total_set = UMIDataset(self.raw_count, self.ambient_profile, self.batch_id)
total_set = UMIDataset(self.raw_count, self.ambient_profile, self.batch_id, device=self.device, cache_capacity=self.cache_capacity)
n_features = self.n_features
sample_size = self.raw_count.shape[0]
self.native_counts = np.empty([sample_size, n_features])
Expand All @@ -606,9 +614,9 @@ def inference(

for x_batch_tot, ambient_freq_tot, x_batch_id_onehot_tot in generator_full_data:
# Move data to device
x_batch_tot = x_batch_tot.to(self.device)
x_batch_id_onehot_tot = x_batch_id_onehot_tot.to(self.device)
ambient_freq_tot = ambient_freq_tot.to(self.device)
# x_batch_tot = x_batch_tot.to(self.device)
# x_batch_id_onehot_tot = x_batch_id_onehot_tot.to(self.device)
# ambient_freq_tot = ambient_freq_tot.to(self.device)

minibatch_size = x_batch_tot.shape[
0
Expand Down Expand Up @@ -710,30 +718,63 @@ def assignment(self, cutoff=3, moi=None):
raise NotImplementedError


class LRUCache:
def __init__(self, capacity: int):
self.cache = OrderedDict()
self.capacity = capacity

def get(self, key):
if key not in self.cache:
return None
self.cache.move_to_end(key)
return self.cache[key]

def put(self, key, value):
if key in self.cache:
self.cache.move_to_end(key)
self.cache[key] = value
if len(self.cache) > self.capacity:
self.cache.popitem(last=False)

class UMIDataset(torch.utils.data.Dataset):
"""Characterizes dataset for PyTorch"""

def __init__(self, raw_count, ambient_profile, batch_id, list_ids=None):
def __init__(self, raw_count, ambient_profile, batch_id, device, list_ids=None, cache_capacity=20000):
"""Initialization"""

self.raw_count = torch.from_numpy(raw_count.fillna(0).values).int() if isinstance(raw_count, pd.DataFrame) else raw_count
self.ambient_profile = torch.from_numpy(ambient_profile).float()
self.batch_id = torch.from_numpy(batch_id).to(torch.int64)
self.batch_onehot = torch.from_numpy(np.eye(len(np.unique(batch_id)))).to(torch.int64)
self.ambient_profile = torch.from_numpy(ambient_profile).float().to(device)
self.batch_id = torch.from_numpy(batch_id).to(torch.int64).to(device)
self.batch_onehot = torch.from_numpy(np.eye(len(np.unique(batch_id)))).to(torch.int64).to(device)
self.device = device
self.cache_capacity = cache_capacity

if list_ids:
self.list_ids = list_ids
else:
self.list_ids = list(range(raw_count.shape[0]))

# Cache data
self.cache = {}

def __len__(self):
"""Denotes the total number of samples"""
return len(self.list_ids)

def __getitem__(self, index):
"""Generates one sample of data"""

if index in self.cache:
return self.cache[index]

# Select sample
sc_id = self.list_ids[index]
sc_count = self.raw_count[sc_id] if isinstance(self.raw_count, torch.Tensor) else torch.from_numpy(self.raw_count[sc_id].X.toarray().flatten()).int()
sc_ambient = self.ambient_profile[self.batch_id[sc_id], :]
sc_batch_id_onehot = self.batch_onehot[self.batch_id[sc_id], :]
sc_count = self.raw_count[sc_id].to(self.device) if isinstance(self.raw_count, torch.Tensor) else torch.from_numpy(self.raw_count[sc_id].X.toarray().flatten()).int().to(self.device)
sc_ambient = self.ambient_profile[self.batch_id[sc_id], :].to(self.device)
sc_batch_id_onehot = self.batch_onehot[self.batch_id[sc_id], :].to(self.device)

# Cache the sample
if len(self.cache) <= self.cache_capacity:
self.cache[index] = (sc_count, sc_ambient, sc_batch_id_onehot)

return sc_count, sc_ambient, sc_batch_id_onehot

0 comments on commit 3f07c35

Please sign in to comment.