Merge pull request #135 from kathyxchen/pytorch-140-update
Minor changes and updating Selene to be compatible with PyTorch 1.4.0
kathyxchen committed Feb 25, 2020
2 parents ec315d9 + 8edb637 commit 6a9f2a0
Showing 10 changed files with 78 additions and 27 deletions.
4 changes: 3 additions & 1 deletion README.md
@@ -9,7 +9,9 @@ Selene is a Python library and command line interface for training deep neural n
We recommend using Selene with Python 3.6 or above.
Package installation should only take a few minutes (less than 10 minutes, typically ~2-3 minutes) with any of these methods (conda, pip, source).

**Install [PyTorch](https://pytorch.org/get-started/locally/).** If you have an NVIDIA GPU, install a version of PyTorch that supports it--Selene will run much faster with a discrete GPU.
**First, install [PyTorch](https://pytorch.org/get-started/locally/).** If you have an NVIDIA GPU, install a version of PyTorch that supports it--Selene will run much faster with a discrete GPU.
The library is currently compatible with PyTorch versions between 0.4.1 and 1.4.0.
We will continue to update Selene to be compatible with the latest version of PyTorch.

### Installing selene with [Anaconda](https://www.anaconda.com/download/) (for Linux):

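As a quick illustration of the version range the updated README now states, here is a minimal sketch of how a user might check their installed PyTorch before installing Selene; the check itself is illustrative and not part of Selene.

```python
# Minimal sketch: verify the installed PyTorch falls in the range the
# updated README declares (0.4.1 through 1.4.0). Illustrative only.
import torch
from distutils.version import LooseVersion

installed = LooseVersion(torch.__version__.split("+")[0])
lower, upper = LooseVersion("0.4.1"), LooseVersion("1.4.0")
if lower <= installed <= upper:
    print("PyTorch {0} is within Selene's supported range.".format(torch.__version__))
else:
    print("PyTorch {0} is outside the 0.4.1-1.4.0 range that "
          "Selene 0.4.5 targets.".format(torch.__version__))
```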
4 changes: 4 additions & 0 deletions docs/source/overview/cli.md
@@ -172,6 +172,9 @@ evaluate_model: !obj:selene_sdk.EvaluateModel {
features: !obj:selene_sdk.utils.load_features_list {
input_path: /path/to/features_list.txt
},
use_features_ord: !obj:selene_sdk.utils.load_features_list {
input_path: /path/to/features_subset_ordered.txt
},
trained_model_path: /path/to/trained/model.pth.tar,
batch_size: 64,
n_test_samples: 640000,
@@ -190,6 +193,7 @@ evaluate_model: !obj:selene_sdk.EvaluateModel {
- `report_gt_feature_n_positives`: Default is 10. In total, each class/feature must have more than `report_gt_feature_n_positives` positive examples in the test set to be considered in the performance computation. The output file that reports each class's performance will report 'NA' for classes that do not have enough positive samples.
- `use_cuda`: Default is False. Specify whether CUDA-enabled GPUs are available for torch to use.
- `data_parallel`: Default is False. Specify whether multiple GPUs are available for torch to use.
- `use_features_ord`: Default is None. Specify an ordered list of features for which to run the evaluation. The features in this list must be identical to or a subset of `features`, and in the order you want the resulting `test_targets.npz` and `test_predictions.npz` to be saved.

#### Additional notes
Similar to the `train_model` configuration, any arguments that you find in [the documentation](https://selene.flatironinstitute.org/selene.html#evaluatemodel) that are not present in the function-type value's arguments are automatically instantiated and passed in by Selene.
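To complement the new `use_features_ord` option documented above, a hedged sketch of one way to produce the ordered subset file the configuration points to; only `load_features_list` comes from Selene itself, and the feature names in the subset are hypothetical.

```python
# Sketch of producing /path/to/features_subset_ordered.txt for the
# `use_features_ord` option shown in the configuration above.
from selene_sdk.utils import load_features_list

features = load_features_list("/path/to/features_list.txt")

# Keep a subset of features, in the order the evaluation outputs
# (test_targets.npz / test_predictions.npz) should be saved.
desired_order = ["CTCF", "DNase", "H3K27ac"]  # hypothetical feature names
ordered_subset = [f for f in desired_order if f in features]

with open("/path/to/features_subset_ordered.txt", "w") as fh:
    for f in ordered_subset:
        fh.write("{0}\n".format(f))
```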
58 changes: 50 additions & 8 deletions selene_sdk/evaluate_model.py
@@ -3,6 +3,7 @@
"""
import logging
import os
import warnings

import numpy as np
import torch
@@ -57,6 +58,11 @@ class EvaluateModel(object):
data_parallel : bool, optional
Default is `False`. Specify whether multiple GPUs are available
for torch to use during training.
use_features_ord : list(str) or None, optional
Default is None. Specify an ordered list of features for which to
run the evaluation. The features in this list must be identical to or
a subset of `features`, and in the order you want the resulting
`test_targets.npz` and `test_predictions.npz` to be saved.
Attributes
----------
@@ -88,7 +94,8 @@ def __init__(self,
n_test_samples=None,
report_gt_feature_n_positives=10,
use_cuda=False,
data_parallel=False):
data_parallel=False,
use_features_ord=None):
self.criterion = criterion

trained_model = torch.load(
@@ -103,11 +110,26 @@ def __init__(self,

self.sampler = data_sampler

self.features = features

self.output_dir = output_dir
os.makedirs(output_dir, exist_ok=True)

self.features = features
self._use_ixs = list(range(len(features)))
if use_features_ord is not None:
feature_ixs = {f: ix for (ix, f) in enumerate(features)}
self._use_ixs = []
self.features = []

for f in use_features_ord:
if f in feature_ixs:
self._use_ixs.append(feature_ixs[f])
self.features.append(f)
else:
warnings.warn(("Feature {0} in `use_features_ord` "
"does not match any features in the list "
"`features` and will be skipped.").format(f))
self._write_features_ordered_to_file()

initialize_logger(
os.path.join(self.output_dir, "{0}.log".format(
__name__)),
@@ -130,11 +152,30 @@ def __init__(self,

self._test_data, self._all_test_targets = \
self.sampler.get_data_and_targets(self.batch_size, n_test_samples)
# TODO: we should be able to do this on the sampler end instead of
# here. the current workaround is problematic, since
# self._test_data still has the full featureset in it, and we
# select the subset during `evaluate`
self._all_test_targets = self._all_test_targets[:, self._use_ixs]

# reset Genome base ordering when applicable.
if (hasattr(self.sampler, "reference_sequence") and
isinstance(self.sampler.reference_sequence, Genome) and
_is_lua_trained_model(model)):
Genome.update_bases_order(['A', 'G', 'C', 'T'])
isinstance(self.sampler.reference_sequence, Genome)):
if _is_lua_trained_model(model):
Genome.update_bases_order(['A', 'G', 'C', 'T'])
else:
Genome.update_bases_order(['A', 'C', 'G', 'T'])

def _write_features_ordered_to_file(self):
"""
Write the feature ordering specified by `use_features_ord`
after matching it with the `features` list from the class
initialization parameters.
"""
fp = os.path.join(self.output_dir, 'use_features_ord.txt')
with open(fp, 'w+') as file_handle:
for f in self.features:
file_handle.write('{0}\n'.format(f))

def _get_feature_from_index(self, index):
"""
@@ -170,7 +211,7 @@ def evaluate(self):
all_predictions = []
for (inputs, targets) in self._test_data:
inputs = torch.Tensor(inputs)
targets = torch.Tensor(targets)
targets = torch.Tensor(targets[:, self._use_ixs])

if self.use_cuda:
inputs = inputs.cuda()
@@ -182,10 +223,11 @@
predictions = None
if _is_lua_trained_model(self.model):
predictions = self.model.forward(
inputs.transpose(1, 2).unsqueeze_(2))
inputs.transpose(1, 2).contiguous().unsqueeze_(2))
else:
predictions = self.model.forward(
inputs.transpose(1, 2))
predictions = predictions[:, self._use_ixs]
loss = self.criterion(predictions, targets)

all_predictions.append(predictions.data.cpu().numpy())
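Outside the diff, a standalone numpy sketch of the column-selection logic these changes introduce: feature names are mapped to indices, unknown names are skipped (Selene warns about them), and the requested columns are taken from the targets and predictions in the requested order.

```python
# Standalone sketch (plain numpy, not Selene code) of the subsetting
# logic added above: build a name -> index lookup, skip unknown names,
# then select those columns from the targets/predictions matrices.
import numpy as np

features = ["featA", "featB", "featC"]           # full feature list
use_features_ord = ["featC", "featA", "featX"]   # requested order; featX is unknown

feature_ixs = {f: ix for (ix, f) in enumerate(features)}
use_ixs = [feature_ixs[f] for f in use_features_ord if f in feature_ixs]
print(use_ixs)  # [2, 0]; Selene would warn about and skip "featX"

targets = np.arange(12).reshape(4, 3)            # 4 samples x 3 features
print(targets[:, use_ixs])                       # columns reordered: featC, featA
```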
3 changes: 2 additions & 1 deletion selene_sdk/predict/_common.py
@@ -93,7 +93,8 @@ def predict(model, batch_sequences, use_cuda=False):
inputs = Variable(inputs)

if _is_lua_trained_model(model):
outputs = model.forward(inputs.transpose(1, 2).unsqueeze_(2))
outputs = model.forward(
inputs.transpose(1, 2).contiguous().unsqueeze_(2))
else:
outputs = model.forward(inputs.transpose(1, 2))
return outputs.data.cpu().numpy()
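A small PyTorch snippet, separate from the diff, showing what the added `.contiguous()` call does: `transpose` returns a strided view, and `.contiguous()` materializes it in a compact memory layout before the in-place `unsqueeze_` used on the Lua-trained-model path. The shapes below are illustrative.

```python
# Illustrative shapes only: the transpose -> contiguous -> unsqueeze_
# sequence used for Lua-trained models in the updated call above.
import torch

inputs = torch.randn(2, 1000, 4)     # (batch, sequence length, alphabet size)
view = inputs.transpose(1, 2)
print(view.is_contiguous())          # False: transpose returns a strided view

compact = view.contiguous()          # copy into a compact memory layout
compact.unsqueeze_(2)                # in-place unsqueeze, as in the updated call
print(compact.shape)                 # torch.Size([2, 4, 1, 1000])
```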
1 change: 1 addition & 0 deletions selene_sdk/predict/_variant_effect_prediction.py
@@ -123,6 +123,7 @@ def read_vcf_file(input_path,
if not reference_sequence.coords_in_bounds(chrom, start, end):
na_rows.append(line)
continue
alt = alt.replace('.', ',') # consider '.' a valid delimiter
for a in alt.split(','):
variants.append((chrom, pos, name, ref, a, strand))

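A tiny sketch, not Selene code, of the ALT-field handling added above: a `.` separator is normalized to `,` before splitting multi-allelic entries. The example value is hypothetical.

```python
alt = "A.T"                   # hypothetical ALT value using '.' between alleles
alt = alt.replace('.', ',')   # consider '.' a valid delimiter, as in the diff
print(alt.split(','))         # ['A', 'T'] -> one variant record per allele
```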
12 changes: 7 additions & 5 deletions selene_sdk/predict/model_predict.py
@@ -176,6 +176,8 @@ def __init__(self,
if type(self.reference_sequence) == Genome and \
_is_lua_trained_model(model):
Genome.update_bases_order(['A', 'G', 'C', 'T'])
else: # even if not using Genome, I guess we can update?
Genome.update_bases_order(['A', 'C', 'G', 'T'])
self._write_mem_limit = write_mem_limit

def _initialize_reporters(self,
@@ -424,11 +426,11 @@ def get_predictions_for_bed_file(self,
batch_ids.append(label+(contains_unk,))
sequences[ i % self.batch_size, :, :] = encoding
if contains_unk:
warnings.warn("For region {0}, "
"reference sequence contains unknown base(s). "
"--will be marked `True` in the `contains_unk` column "
"of the .tsv or the row_labels .txt file.".format(
label))
warnings.warn(("For region {0}, "
"reference sequence contains unknown "
"base(s). --will be marked `True` in the "
"`contains_unk` column of the .tsv or "
"row_labels .txt file.").format(label))

if (batch_ids and i == 0) or i % self.batch_size != 0:
sequences = sequences[:i % self.batch_size + 1, :, :]
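A standalone numpy sketch (not Selene's implementation) of why the base ordering is reset here: the alphabet order decides which one-hot column each base occupies, so a model trained under Lua Torch with the A, G, C, T ordering expects different columns than the default A, C, G, T.

```python
# Plain-numpy illustration of how alphabet order permutes one-hot columns.
import numpy as np

def one_hot(seq, alphabet):
    index = {base: i for i, base in enumerate(alphabet)}
    encoding = np.zeros((len(seq), len(alphabet)))
    for i, base in enumerate(seq):
        encoding[i, index[base]] = 1
    return encoding

seq = "ACGT"
print(one_hot(seq, ['A', 'C', 'G', 'T']))  # default ordering (identity matrix)
print(one_hot(seq, ['A', 'G', 'C', 'T']))  # ordering expected by Lua-trained models
```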
12 changes: 6 additions & 6 deletions selene_sdk/sequences/genome.py
@@ -17,7 +17,7 @@
def _not_blacklist_region(chrom, start, end, blacklist_tabix):
"""
Check if the input coordinates are not overlapping with blacklist regions.
Parameters
----------
chrom : str
Expand All @@ -29,14 +29,14 @@ def _not_blacklist_region(chrom, start, end, blacklist_tabix):
blacklist_tabix : tabix.open or None, optional
Default is `None`. Tabix file handle if a file of blacklist regions
is available.
Returns
-------
bool
False if the coordinates are overlapping with blacklist regions
(if specified). Otherwise, return True.
"""
if blacklist_tabix is not None:
try:
@@ -203,7 +203,7 @@ class Genome(Sequence):
"""

BASES_ARR = np.array(['A', 'C', 'G', 'T'])
BASES_ARR = ['A', 'C', 'G', 'T']
"""
This is an array with the alphabet (i.e. all possible symbols
that may occur in a sequence). We expect that
@@ -463,7 +463,7 @@ def get_encoding_from_coords_check_unk(self,
strand='+',
pad=False):
"""Gets the one-hot encoding of the genomic sequence at the
queried coordinates and check whether the sequence contains
queried coordinates and check whether the sequence contains
unknown base(s).
Parameters
4 changes: 2 additions & 2 deletions selene_sdk/sequences/proteome.py
@@ -71,8 +71,8 @@ class around the `pyfaidx.Fasta` class.
"""

BASES_ARR = np.array(['A', 'R', 'N', 'D', 'C', 'E', 'Q', 'G', 'H', 'I',
'L', 'K', 'M', 'F', 'P', 'S', 'T', 'W', 'Y', 'V'])
BASES_ARR = ['A', 'R', 'N', 'D', 'C', 'E', 'Q', 'G', 'H', 'I',
'L', 'K', 'M', 'F', 'P', 'S', 'T', 'W', 'Y', 'V']
"""
This is an array with the alphabet (i.e. all possible symbols
that may occur in a sequence). We expect that
2 changes: 1 addition & 1 deletion selene_sdk/version.py
@@ -1 +1 @@
__version__ = "0.4.4"
__version__ = "0.4.5"
5 changes: 2 additions & 3 deletions setup.py
@@ -25,7 +25,7 @@
cmdclass = {'build_ext': build_ext}

setup(name="selene-sdk",
version="0.4.4",
version="0.4.5",
long_description=long_description,
long_description_content_type='text/markdown',
description=("framework for developing sequence-level "
@@ -62,6 +62,5 @@
"scipy",
"seaborn",
"statsmodels",
"torch>=0.4.1",
"torchvision"
"torch>=0.4.1, <=1.4.0",
])
