Skip to content

Commit

Permalink
Added tests for ignore_upper_triangle option in distance import sub command.
Browse files Browse the repository at this point in the history
  • Loading branch information
sverhoeven committed May 31, 2016
1 parent 29f15bd commit db0f96f
Show file tree
Hide file tree
Showing 2 changed files with 66 additions and 2 deletions.
6 changes: 4 additions & 2 deletions kripodb/script.py
Original file line number Diff line number Diff line change
Expand Up @@ -464,6 +464,8 @@ def distmatrix_import_tsv(inputfile, fragmentsdb, distmatrixfn, nrrows, ignore_u
# distmatrix wants score as float instead of str
def csv_iter(rows):
for row in rows:
if row[0] == row[1]:
continue
if ignore_upper_triangle and row[0] > row[1]:
continue
row[2] = float(row[2])
Expand Down Expand Up @@ -584,17 +586,17 @@ def read_fpneighpairs_file(inputfile, ignore_upper_triangle=False):
Args:
inputfile (file): File object to read
ignore_upper_triangle (bool): Ignore upper triangle of input
Yields:
Tuple((Str,Str,Float)): List of (query fragment identifier, hit fragment identifier, distance score)
"""
current_query = None
reader = csv.reader(inputfile, delimiter=' ', skipinitialspace=True)

for row in reader:
if len(row) == 2 and current_query != row[0]:
if ignore_upper_triangle and current_query > row[1]:
if ignore_upper_triangle and current_query > row[0]:
continue
yield (current_query, row[0], float(row[1]))
elif len(row) == 4:
Expand Down
62 changes: 62 additions & 0 deletions tests/test_script.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,39 @@ def test_distmatrix_import_run():
os.remove(output_fn)


def test_distmatrix_import_run_ignore_upper_triangle():
    # Import a tsv that lists every pair in both orientations plus the
    # self-pairs, with ignore_upper_triangle=True, then read the distance
    # matrix back and check which pairs were stored.
    output_fn = tmpname()

    # NOTE(review): format='tsv' presumably expects tab-separated columns --
    # confirm the column separators in this literal are tabs.
    tsv = '''frag_id1 frag_id2 score
2mlm_2W7_frag1 2mlm_2W7_frag1 1.0000000000000000
2mlm_2W7_frag2 2mlm_2W7_frag2 1.0000000000000000
2mlm_2W7_frag1 2mlm_2W7_frag2 0.5877164873731594
2mlm_2W7_frag2 3wvm_STE_frag1 0.4633096818493935
2mlm_2W7_frag2 2mlm_2W7_frag1 0.5877164873731594
3wvm_STE_frag1 2mlm_2W7_frag2 0.4633096818493935
'''
    inputfile = StringIO(tsv)

    try:
        script.distmatrix_import_run(inputfile=inputfile,
                                     format='tsv',
                                     distmatrixfn=output_fn,
                                     fragmentsdb='data/fragments.sqlite',
                                     nrrows=2,
                                     ignore_upper_triangle=True)

        distmatrix = DistanceMatrix(output_fn)
        result = [r for r in distmatrix]
        distmatrix.close()
        # Self-pairs and pairs whose first identifier sorts after the second
        # are expected to be dropped, leaving one row per unordered pair.
        expected = [('2mlm_2W7_frag1', '2mlm_2W7_frag2', 0.5877),
                    ('2mlm_2W7_frag2', '3wvm_STE_frag1', 0.4633)]
        # Scores are compared approximately (3 decimals); the stored values
        # are not required to match the input to full precision.
        assert_array_almost_equal([r[2] for r in result],
                                  [r[2] for r in expected],
                                  3)
        # Identifiers must be compared against `expected`; comparing `result`
        # with itself can never fail.
        eq_([(r[0], r[1],) for r in result],
            [(r[0], r[1],) for r in expected])
    finally:
        # The output file may not exist if the import failed early.
        if os.path.exists(output_fn):
            os.remove(output_fn)


def test_distmatrix_export_run():
outputfile = StringIO()
script.distmatrix_export_run('data/distances.h5', outputfile)
Expand Down Expand Up @@ -136,6 +169,35 @@ def test_distmatrix_importfpneigh_run():
os.remove(output_fn)


def test_distmatrix_importfpneigh_run_ignore_upper_triangle():
    # Import fpneigh-formatted neighbours with ignore_upper_triangle=True,
    # then read the distance matrix back and check which pairs were stored.
    output_fn = tmpname()

    # Each "Compounds similar to <query>:" header is followed by
    # "<hit> <score>" lines; the frag1/frag2 pair appears in both directions.
    tsv = '''Compounds similar to 2mlm_2W7_frag1:
2mlm_2W7_frag1 1.0000
2mlm_2W7_frag2 0.5877
Compounds similar to 2mlm_2W7_frag2:
2mlm_2W7_frag2 1.0000
2mlm_2W7_frag1 0.5877
3wvm_STE_frag1 0.4633
'''
    inputfile = StringIO(tsv)

    try:
        script.distmatrix_importfpneigh_run(inputfile=inputfile,
                                            distmatrixfn=output_fn,
                                            fragmentsdb='data/fragments.sqlite',
                                            nrrows=3,
                                            ignore_upper_triangle=True)

        distmatrix = DistanceMatrix(output_fn)
        rows = [r for r in distmatrix]
        distmatrix.close()
        # Self-hits and the reversed frag2 -> frag1 pair should be absent.
        expected = [(u'2mlm_2W7_frag1', u'2mlm_2W7_frag2', 0.5877),
                    (u'2mlm_2W7_frag2', u'3wvm_STE_frag1', 0.4633)]
        eq_(rows, expected)
    finally:
        # Guard the cleanup: if the import failed before creating the output
        # file, an unconditional remove would raise and mask the real error.
        if os.path.exists(output_fn):
            os.remove(output_fn)


def test_fpneigh2tsv_run():
fpneigh_in = '''Compounds similar to 2mlm_2W7_frag1:
2mlm_2W7_frag1 1.0000
Expand Down

0 comments on commit db0f96f

Please sign in to comment.