Commit 810cfe0: remove ph dependency
heathcliff233 committed Aug 27, 2024 (1 parent: 8fe77be)
Showing 7 changed files with 49 additions and 61 deletions.
.gitignore (3 changes: 2 additions & 1 deletion)

@@ -3,4 +3,5 @@
 __pycache__/
 *.py[cod]
 *$py.class
-*.DS_Store
+*.DS_Store
+outputs/
conf/pred_conf.yaml (4 changes: 2 additions & 2 deletions)

@@ -6,9 +6,9 @@ trainer:
   acc_step: 4
   accelerator: gpu
   precision: 32
-  batch_size: 600
+  batch_size: 200
   max_epochs: 100
-  ur90_path: /data/hongliang/Dense-Homolog-Retrieval/df-ebd.tsv
+  ur90_path: /data/hongliang/Dense-Homolog-Retrieval/example/df-ebd.tsv

 model:
   resume: True
do_agg.py (18 changes: 11 additions & 7 deletions)

@@ -1,7 +1,8 @@
 import torch
 import os
 import pandas as pd
-import phylopandas.phylopandas as ph
+# import phylopandas.phylopandas as ph
+from pyarrow import csv
 import argparse
 import faiss

@@ -17,23 +18,26 @@
 output_path = args.output_path

 # Load original sequence database
-seqdb_df = ph.read_fasta(seqdb_path, use_uids=False)
+# seqdb_df = ph.read_fasta(seqdb_path, use_uids=False)
+seqdb_df = csv.read_csv(seqdb_path,
+                        read_options=csv.ReadOptions(column_names=['id', 'sequence']),
+                        parse_options=csv.ParseOptions(delimiter='\t')).to_pandas()
 seqdb_df = seqdb_df.set_index('id')

 # Create Index
-index = faiss.IndexFlatL2(768)
+index = faiss.IndexFlatL2(480)
 id_lst = []

 # Load embedded database and process
 for rank in os.listdir(embdb_path):
     for pts in os.listdir(os.path.join(embdb_path, rank)):
         if pts.endswith(".pt"):
-            lst, vec = torch.load(os.path.join(embdb_path, rank, pts))
-            id_lst += lst
+            vec = torch.cat(torch.load(os.path.join(embdb_path, rank, pts))[0])
+            print(vec.shape)
             index.add(vec.cpu().numpy())

 # Write aggregated results
 os.makedirs(output_path, exist_ok=True)
-ord_df = seqdb_df.loc[id_lst].reset_index()
-ord_df.to_pickle(os.path.join(output_path, "df-ebd.pkl"))
+seqdb_df.reset_index(inplace=True)
+seqdb_df.to_pickle(os.path.join(output_path, "df-ebd.pkl"))
 faiss.write_index(index, os.path.join(output_path, "index-ebd.index"))
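Note: after this change the per-shard id list is no longer stored, so the pickled dataframe and the FAISS index stay aligned purely by insertion order. A minimal sketch of the resulting lookup contract; the output directory and the query vector here are hypothetical stand-ins, not part of the commit:

import os
import faiss
import numpy as np
import pandas as pd

output_path = "outputs"  # hypothetical directory produced by do_agg.py

# Row i of the dataframe must correspond to vector i in the index;
# do_agg.py guarantees this only if shards are added in source order.
df = pd.read_pickle(os.path.join(output_path, "df-ebd.pkl"))
index = faiss.read_index(os.path.join(output_path, "index-ebd.index"))

query = np.random.rand(1, 480).astype("float32")  # stand-in query embedding
scores, idxes = index.search(query, 5)            # positional ids into df
print(df.iloc[idxes[0]][["id", "sequence"]])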
do_retrieval.py (51 changes: 22 additions & 29 deletions)

@@ -1,19 +1,20 @@
 import torch
 import pytorch_lightning as pl
 from mydpr.model.biencoder import MyEncoder
-from mydpr.dataset.cath35 import PdDataModule
+from mydpr.dataset.cath35 import PdDataModule, ArrowDataset
+from torch.utils.data import DataLoader
 import sys
 import os
 import argparse
 import faiss
 import time
 import math
-sys.path.append("/share/hongliang")
+# sys.path.append("/share/hongliang")
 import pandas as pd
 import numpy as np
-import phylopandas.phylopandas as ph
+# import phylopandas.phylopandas as ph

-ckpt_path = "./cpu_model/fastmsa-cpu.ckpt" #-> modified by Sheng Wang at 2022.06.14
+ckpt_path = "/data/hongliang/Dense-Homolog-Retrieval/cpu_model"
 input_path = "./input_test.fasta"
 qjackhmmer = "./bin/qjackhmmer"
 out_path = "./testout/"

@@ -57,7 +58,7 @@ def my_align(out_dir, iter_num):
 s0 = time.time()

 # print("Start mkdir!!!")
-gen_query(input_path, out_path)
+# gen_query(input_path, out_path)
 s1 = time.time()
 # print("Mkdir output cost: %f s"%(s1-s0))

@@ -66,43 +67,35 @@ def my_align(out_dir, iter_num):
 # print("Load index cost: %f s"%(s2-s1))
 df = pd.read_pickle(dm_path)

-model = MyEncoder.load_from_checkpoint(checkpoint_path=ckpt_path)
-ds = PdDataModule(input_path, 40, model.alphabet)
+model = MyEncoder(bert_path=[os.path.join(ckpt_path, 'dhr_qencoder.pt'), os.path.join(ckpt_path, 'dhr_cencoder.pt')]).eval()
+ds = ArrowDataset(input_path)
+names = ds.id.to_pylist()
+bc = model.alphabet.get_batch_converter()

 s3 = time.time()
 # print("Load ckp cost: %f s"%(s3-s2))
-trainer = pl.Trainer() # gpus=[0])
-ret = trainer.predict(model, datamodule=ds, return_predictions=True)
-trainer.save_checkpoint(ckpt_path)
-s4 = time.time()
+# trainer = pl.Trainer() # gpus=[0])
+# ret = trainer.predict(model, datamodule=ds, return_predictions=True)
+# s4 = time.time()
 # print("Train predict cost: %f s"%(s4-s3))
 # names, qebd = ret[0]
+ebd = []
+dataloader = DataLoader(ds, batch_size=search_batch, collate_fn=bc, shuffle=False)
+for i, et in enumerate(dataloader):
+    a, b, c = et
+    with torch.no_grad():
+        ebd.append(model.forward_left(c))
+ebd = torch.cat(ebd, dim=0)
+encoded = ebd.numpy()

-tmp1 = []
-tmp2 = []
-for i in ret:
-    n1, q1 = i
-    tmp1 += n1
-    q1 = torch.tensor(q1).numpy()
-    tmp2.append(q1)
-encoded = np.concatenate(tmp2, axis=0)
-# encoded = np.concatenate([t.cpu().numpy() for t in tmp2])
-names = tmp1
-# print(encoded.shape)

-# encoded = qebd.numpy()
 # print("prepared model")
 s5 = time.time()
 # print("Encode model cost: %f s"%(s5-s4))

 os.makedirs(os.path.join(out_path, "db"), exist_ok=True)
 for i in range(math.ceil(encoded.shape[0]/search_batch)):
     scores, idxes = index.search(encoded[i*search_batch:(i+1)*search_batch], tar_num)
     idx_batch = len(idxes)
     for j in range(idx_batch):
         tar_idx = idxes[j]
         res = df.iloc[tar_idx]
-        res.phylo.to_fasta_dev(os.path.join(out_path, "db", names[i*search_batch+j]+'.fasta'))
+        res[['id', 'sequence']].to_csv(os.path.join(out_path, "db", names[i*search_batch+j]+'.tsv'), sep='\t', index=False, header=False)

 #end = time.time()
 #print("Time for predict %d : %f s"%(tar_num, end-s5))
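Note: the replacement for trainer.predict above follows a standard manual-inference pattern: batch the dataset with the model's own collate function, encode under torch.no_grad(), and concatenate per-batch embeddings so row order matches input order. A self-contained sketch of that pattern; the encoder and collate function below are illustrative stand-ins for MyEncoder.forward_left and the alphabet's batch converter, not the repo's API:

import torch
from torch.utils.data import DataLoader

class DummyEncoder(torch.nn.Module):
    # Stand-in for MyEncoder; forward_left maps token batches to embeddings.
    def __init__(self, dim=480):
        super().__init__()
        self.proj = torch.nn.Linear(8, dim)

    def forward_left(self, tokens):
        return self.proj(tokens)

def collate(batch):
    # Mimics the (labels, strings, tokens) triple unpacked as a, b, c above.
    ids = [i for i, _ in batch]
    seqs = [s for _, s in batch]
    tokens = torch.randn(len(batch), 8)  # placeholder token tensor
    return ids, seqs, tokens

data = [("q1", "GVHHY"), ("q2", "VVYPE"), ("q3", "SIGLA")]
model = DummyEncoder().eval()
loader = DataLoader(data, batch_size=2, collate_fn=collate, shuffle=False)

ebd = []
for _, _, tokens in loader:
    with torch.no_grad():  # inference only, no gradient tracking
        ebd.append(model.forward_left(tokens))
encoded = torch.cat(ebd, dim=0).numpy()
print(encoded.shape)  # (3, 480); rows align with input order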
input_test.fasta (20 changes: 0 additions & 20 deletions)

This file was deleted.
input_test.tsv (10 changes: 10 additions & 0 deletions)

@@ -0,0 +1,10 @@
+T0949-D1 GVHHYTIDEFNYYYKPDRMTWHVGEKVELTIDNRSQSAPPIAHQFSIGRTLVSIAVGWKDNFFDGVPITSGGQTGPVPAFSVSLNGGQKYTFSFVVPNKPGKWEYGCFLQTGQHFMNGMHGILDILPAQ
+T0950-D1 VVYPEINVKTLSQAVKNIWRLSHQQKSGIEIIQEKTLRISLYSRDLDEAARASVPQLQTVLRQLPPQDYFLTLTEIDTELEDPELDDETRNTLLEARSEHIRNLKKDVKGVIRSLRKEANLMASRIADVSNVVILERLESSLKEEQERKAEIQADIAQQEKNKAKLVVDRNKIIESQDVIRQYNLADMFKDYIPNISDLDKLDLANPKKELIKQAIKQGVEIAKKILGNISKGLKYIELADARAKLDERINQINKDCDDLKIQLKGVEQRIAGIEDVHQIDKERTTLLLQAAKLEQAWNIFAKQLQNTIDGKIDQQDLTKIIHKQLDFLDDLALQYHSMLLS
+T0951-D1 SIGLAHNVTILGSGETTVVLGHGYGTDQSVWKLLVPYLVDDYKVLLYDHMGAGTTNPDYFDFDRYSSLEGYSYDLIAILEEFQVSKCIYVGHSMSSMAAAVASIFRPDLFHKLVMISPTPRLINTEEYYGGFEQKVMDETLRSLDENFKSLSLGTAPLLLACDLESAAMQEYCRTLFNMRPDIACCITRMICGLDLRPYLGHVTVPCHIIQSSNDIMVPVAVGEYLRKNLGGPSVVEVMPTEGHLPHLSMPEVTIPVVLRHIRQDI
+T0953s1-D1 ASIAIGDNDTGLRWGGDGIVQIVANNAIVGGWNSTDIFTEAGKHITSNGNLNQWGGGAIYCRDLNVS
+T0953s2-D1 AVQGPWVGSSYVAETGQNWASLAANELRVTERPFWISSFIGRSK
+T0953s2-D2 EEIWEWTGENHSFNKDWLIGELRNRGGTPVVINIRAHQVSYTPGAPLFEFPGDLPNAYITLNIYADIYGRGGTGGVAYLGGNPGGDCIHNWIGNRLRINNQGWICGRAVVGTSPQWINVGNIAGSWL
+T0953s2-D3 GGGGGGGFRVGHTEAGGGGGRPLGAGGVSSLNLNGDNATLGAPGRGYQLGNDYAGNGGDVGNPGSASSAEMGGGAAG
+T0954-D1 KHKYHFQKTFTVSQAGNCRIMAYCDALSCLVISQPSPGFGVKMLSTANMKSSQYIPMHGKQIRGLAFSSYLRGLLLSASLDNTIKLTSLETNTVVQTYNAGRPVWSCCWCLDEANYIYAGLANGSILVYDVRNTSSHVQELVAQKARCPLVSLSYMPRAASAAFPYGGVLAGTLEDASFWEQKMDFSHWPHVLPLEPGGCIDFQTENSSRHCLVTYRPDKNHTTIRSVLMEMSYRLDDTGNPICSCQPVHTFFGGPTCKLLTKNAIFQSPENDGNILVCTGDEAANSALLWDAASGSLLQDLQTDQPVLDICPFEVNRNSYLATLTEKMVHIYKWE
+T0955-D1 SQETRKKCTEMKKKFKNCEVRCDESNHCVEVRCSDTKYTLC
+T0957s1-D1 NSFEVSSLPDANGKNHITAVKGDAKIPVDKIELYMRARVLEQAGIVNTASNNSMIMDKLLDSAQGATSANRKTSVVVSGPNGNVRIYATWTILPDGTKRLSTVTGTFK
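Note: queries now come from a two-column TSV (id, sequence) instead of FASTA. A minimal sketch of reading it the way the new code paths do, using the column names and delimiter from the diff; the relative file path is an assumption:

from pyarrow import csv

tbl = csv.read_csv(
    "input_test.tsv",
    read_options=csv.ReadOptions(column_names=["id", "sequence"]),
    parse_options=csv.ParseOptions(delimiter="\t"),
)
names = tbl["id"].to_pylist()  # query ids, preserved in file order
print(names[:3])               # ['T0949-D1', 'T0950-D1', 'T0951-D1']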
mydpr/dataset/cath35.py (4 changes: 2 additions & 2 deletions)

@@ -82,7 +82,7 @@ def __getitem__(self, index):
         return self.id[index].as_py(), self.seq[index].as_py()

 class PdDataModule(pl.LightningDataModule):
-    def __init__(self, data_path, batch_size, alphabet, trainer):
+    def __init__(self, data_path, batch_size, alphabet, trainer=None):
         super().__init__()
         self.path = data_path
         self.batch_size = batch_size

@@ -96,7 +96,7 @@ def setup(self, stage):

     def predict_dataloader(self):
         sampler = DistributedProxySampler(self.pd_set, self.world_size, self.rank)
-        return DataLoader(dataset=self.pd_set, collate_fn=self.batch_converter, sampler=sampler, num_workers=8, batch_size=self.batch_size)
+        return DataLoader(dataset=self.pd_set, collate_fn=self.batch_converter, sampler=sampler, num_workers=8, batch_size=self.batch_size, shuffle=False)

def get_filename(sel_path: str) -> List[str]:
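Note: shuffle=False is already the default, but making it explicit documents the ordering contract that do_retrieval.py relies on, since names[i] is paired with encoded[i] by position (PyTorch raises a ValueError if shuffle=True is combined with a custom sampler). A toy illustration of that contract; the data and names here are made up:

from torch.utils.data import DataLoader

data = [("q1", "SEQA"), ("q2", "SEQB"), ("q3", "SEQC")]
names = [n for n, _ in data]

# With shuffle=False, batch order matches dataset order, so the j-th
# embedding row can be written back to names[j] after encoding.
loader = DataLoader(data, batch_size=2, shuffle=False)
flat = [n for batch in loader for n in batch[0]]
assert flat == names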
