wrap SCALE in SCALE_function
jsxlei committed Oct 20, 2022
1 parent 72b5e0c commit 9be2249
Showing 5 changed files with 261 additions and 130 deletions.
README.md (6 changes: 1 addition & 5 deletions)

```diff
@@ -45,10 +45,6 @@ Installation only requires a few minutes.
 #### Run
 
 SCALE.py -d [input]
 
-if cluster number k is known:
-
-SCALE.py -d [input] -k [k]
-
 #### Output
 Output will be saved in the output folder including:
@@ -69,7 +65,7 @@ or get numerical imputed data in adata.h5ad file using scanpy **adata.obsm['impute']**
 * save results in a specific folder: [-o] or [--outdir]
 * embed feature by tSNE or UMAP: [--embed] tSNE/UMAP
 * filter low quality cells by valid peaks number, default 100: [--min_peaks]
-* filter low quality peaks by valid cells number, default 0.01: [--min_cells]
+* filter low quality peaks by valid cells number, default 3: [--min_cells]
 * filter peaks by selecting highly variable features, default 100,000: [--n_feature], disable by [--n_feature] -1.
 * modify the initial learning rate, default is 0.002: [--lr]
 * change iterations by watching the convergence of loss, default is 30000: [-i] or [--max_iter]
```
SCALE.py (147 changes: 27 additions & 120 deletions)

```diff
@@ -14,23 +14,9 @@
 """
 
 
-import time
-import torch
-
-import numpy as np
-import pandas as pd
-import os
-import scanpy as sc
-import argparse
-
-from scale import SCALE
-from scale.dataset import load_dataset
-from scale.utils import read_labels, cluster_report, estimate_k, binarization
-from scale.plot import plot_embedding
-
-from sklearn.preprocessing import MaxAbsScaler
-from sklearn.cluster import KMeans
-from torch.utils.data import DataLoader
+from scale import SCALE_function
 
 
 if __name__ == '__main__':
@@ -49,7 +35,7 @@
     parser.add_argument('--decode_dim', type=int, nargs='*', default=[], help='encoder structure')
     parser.add_argument('--latent', '-l',type=int, default=10, help='latent layer dim')
     parser.add_argument('--min_peaks', type=float, default=100, help='Remove low quality cells with few peaks')
-    parser.add_argument('--min_cells', type=float, default=0.01, help='Remove low quality peaks')
+    parser.add_argument('--min_cells', type=float, default=3, help='Remove low quality peaks')
     parser.add_argument('--n_feature', type=int, default=100000, help='Keep the number of highly variable peaks')
     parser.add_argument('--log_transform', action='store_true', help='Perform log2(x+1) transform')
     parser.add_argument('--max_iter', '-i', type=int, default=30000, help='Max iteration')
```
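The `--min_cells` default changes from 0.01 to 3: the old value reads like a fraction of cells, the new one like an absolute cell count. A minimal sketch of the filtering this flag drives, assuming SCALE forwards it to a scanpy-style peak filter (the toy `adata` and the direct `sc.pp.filter_genes` call are illustrative assumptions, not the repository's actual code path):

```python
import numpy as np
import scanpy as sc
from anndata import AnnData

# Toy cells-by-peaks binary matrix: 5 cells, 4 peaks.
adata = AnnData(np.random.binomial(1, 0.3, size=(5, 4)).astype(float))

# Drop peaks (scanpy treats them as "genes") detected in fewer than 3 cells,
# mirroring --min_cells 3. How SCALE applies the flag internally is assumed here.
sc.pp.filter_genes(adata, min_cells=3)
print(adata.shape)  # peaks seen in fewer than 3 cells are removed
```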
```diff
@@ -62,109 +48,30 @@
 
     args = parser.parse_args()
 
-    # Set random seed
-    seed = args.seed
-    np.random.seed(seed)
-    torch.manual_seed(seed)
-
-    if torch.cuda.is_available(): # cuda device
-        device='cuda'
-        torch.cuda.set_device(args.gpu)
-    else:
-        device='cpu'
-    batch_size = args.batch_size
-
-    print("\n**********************************************************************")
-    print(" SCALE: Single-Cell ATAC-seq Analysis via Latent feature Extraction")
-    print("**********************************************************************\n")
-
-    adata, trainloader, testloader = load_dataset(
-        args.data_list,
-        batch_categories=None,
-        join='inner',
-        batch_key='batch',
-        batch_name='batch',
-        min_genes=args.min_peaks,
-        min_cells=args.min_cells,
-        batch_size=args.batch_size,
-        n_top_genes=args.n_feature,
-        log=None,
+    adata = SCALE_function(
+        args.data_list,
+        n_centroids = args.n_centroids,
+        outdir = args.outdir,
+        verbose = args.verbose,
+        pretrain = args.pretrain,
+        lr = args.lr,
+        batch_size = args.batch_size,
+        gpu = args.gpu,
+        seed = args.seed,
+        encode_dim = args.encode_dim,
+        decode_dim = args.decode_dim,
+        latent = args.latent,
+        min_peaks = args.min_peaks,
+        min_cells = args.min_cells,
+        n_feature = args.n_feature,
+        log_transform = args.log_transform,
+        max_iter = args.max_iter,
+        weight_decay = args.weight_decay,
+        impute = args.impute,
+        binary = args.binary,
+        embed = args.embed,
+        reference = args.reference,
+        cluster_method = args.cluster_method,
     )
 
-    cell_num = adata.shape[0]
-    input_dim = adata.shape[1]
-
-    # if args.n_centroids is None:
-    #     k = estimate_k(adata.X.T)
-    #     print('Estimate k = {}'.format(k))
-    # else:
-    #     k = args.n_centroids
-    lr = args.lr
-    k = args.n_centroids
-
-    outdir = args.outdir+'/'
-    if not os.path.exists(outdir):
-        os.makedirs(outdir)
-
-    print("\n======== Parameters ========")
-    print('Cell number: {}\nPeak number: {}\nn_centroids: {}\nmax_iter: {}\nbatch_size: {}\ncell filter by peaks: {}\npeak filter by cells: {}'.format(
-        cell_num, input_dim, k, args.max_iter, batch_size, args.min_peaks, args.min_cells))
-    print("============================")
-
-    dims = [input_dim, args.latent, args.encode_dim, args.decode_dim]
-    model = SCALE(dims, n_centroids=k)
-    print(model)
-
-    if not args.pretrain:
-        print('\n## Training Model ##')
-        model.init_gmm_params(testloader)
-        model.fit(trainloader,
-                  lr=lr,
-                  weight_decay=args.weight_decay,
-                  verbose=args.verbose,
-                  device = device,
-                  max_iter=args.max_iter,
-                  # name=name,
-                  outdir=outdir
-                  )
-        torch.save(model.state_dict(), os.path.join(outdir, 'model.pt')) # save model
-    else:
-        print('\n## Loading Model: {}\n'.format(args.pretrain))
-        model.load_model(args.pretrain)
-        model.to(device)
-
-    ### output ###
-    print('outdir: {}'.format(outdir))
-    # 1. latent feature
-    adata.obsm['latent'] = model.encodeBatch(testloader, device=device, out='z')
-
-    # 2. cluster
-    sc.pp.neighbors(adata, n_neighbors=30, use_rep='latent')
-    if args.cluster_method == 'leiden':
-        sc.tl.leiden(adata)
-    elif args.cluster_method == 'kmeans':
-        kmeans = KMeans(n_clusters=k, n_init=20, random_state=0)
-        adata.obs['kmeans'] = kmeans.fit_predict(adata.obsm['latent']).astype(str)
-
-    # if args.reference in adata.obs:
-    #     cluster_report(adata.obs[args.reference].cat.codes, adata.obs[args.cluster_method].astype(int))
-
-    sc.settings.figdir = outdir
-    sc.set_figure_params(dpi=80, figsize=(6,6), fontsize=10)
-    if args.embed == 'UMAP':
-        sc.tl.umap(adata, min_dist=0.1)
-        color = [c for c in ['celltype', args.cluster_method] if c in adata.obs]
-        sc.pl.umap(adata, color=color, save='.png', wspace=0.4, ncols=4)
-    elif args.embed == 'tSNE':
-        sc.tl.tsne(adata, use_rep='latent')
-        color = [c for c in ['celltype', args.cluster_method] if c in adata.obs]
-        sc.pl.tsne(adata, color=color, save='.png', wspace=0.4, ncols=4)
-
-    if args.impute:
-        adata.obsm['impute'] = model.encodeBatch(testloader, device=device, out='x')
-    if args.binary:
-        adata.obsm['impute'] = model.encodeBatch(testloader, device=device, out='x')
-        adata.obsm['binary'] = binarization(adata.obsm['impute'], adata.X)
-        del adata.obsm['impute']
-
-    adata.write(outdir+'adata.h5ad', compression='gzip')
```
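With the pipeline folded into `SCALE_function`, the script body reduces to a single call, so the same entry point can presumably be driven from Python directly. A minimal sketch under that assumption, passing a subset of the keyword arguments shown in the diff (the input path and `n_centroids` value are illustrative, and it is assumed the omitted arguments default sensibly):

```python
from scale import SCALE_function

# Run the wrapped pipeline programmatically; the path and k are illustrative.
adata = SCALE_function(
    ['data/atac_matrix.h5ad'],  # data_list: one or more input files, as with -d
    n_centroids=10,             # expected number of clusters
    outdir='output/',           # results directory (adata.h5ad, figures)
    min_peaks=100,              # cell filter: minimum peaks per cell
    min_cells=3,                # peak filter: minimum cells per peak
    embed='UMAP',               # 2D embedding for the saved plots
)

# Judging from the removed script code, the returned AnnData should carry
# the latent features and cluster assignments:
print(adata.obsm['latent'].shape)
```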

