Skip to content

Commit

Permalink
Simplified collectri function
Browse files Browse the repository at this point in the history
  • Loading branch information
PauBadiaM committed Nov 17, 2023
1 parent 3f8c8f1 commit 9da10c9
Show file tree
Hide file tree
Showing 3 changed files with 39 additions and 35 deletions.
48 changes: 21 additions & 27 deletions decoupler/omnip.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
]

import os
import sys
import builtins
from types import ModuleType
from typing import Iterable
Expand Down Expand Up @@ -110,8 +109,7 @@ def _static_fallback(
query: str,
resource: str,
organism: int | str,
**kwargs
) -> pd.DataFrame:
**kwargs) -> pd.DataFrame:
"""
Fallback for static tables.
"""
Expand Down Expand Up @@ -177,7 +175,7 @@ def get_progeny(

try:
p = op.requests.Annotations.get(resources='PROGENy', **kwargs)
except:
except Exception:
p = _static_fallback(
query='annotations',
resource='PROGENy',
Expand Down Expand Up @@ -280,7 +278,7 @@ def get_resource(
entity_type='protein',
**kwargs
)
except:
except Exception:
df = _static_fallback(
query='annotations',
resource=name,
Expand Down Expand Up @@ -389,7 +387,7 @@ def get_dorothea(
genesymbols=True,
organism=_organism,
)
except:
except Exception:
do = _static_fallback(
query='interactions',
resource='DoRothEA',
Expand Down Expand Up @@ -467,6 +465,18 @@ def get_dorothea(
return do.reset_index(drop=True)


def merge_genes_to_complexes(ct_cmplx: pd.DataFrame) -> None:
    """
    Rename gene-level sources to the name of the complex they are merged into.

    Entries of the ``source_genesymbol`` column that start with ``JUN`` or
    ``FOS`` are replaced by ``AP1``; entries that start with ``REL`` or
    ``NFKB`` are replaced by ``NFKB``. Every other entry is left untouched.

    Parameters
    ----------
    ct_cmplx : DataFrame
        Table with a ``source_genesymbol`` column. It is modified in place.

    Returns
    -------
    None
        The input table is updated in place; nothing is returned.
    """

    merged = []
    for symbol in ct_cmplx['source_genesymbol']:
        if not isinstance(symbol, str):
            # Missing values (NaN/None) have no prefix to test; keep them
            # as they are instead of failing on ``.startswith``.
            merged.append(symbol)
        elif symbol.startswith(('JUN', 'FOS')):
            merged.append('AP1')
        elif symbol.startswith(('REL', 'NFKB')):
            merged.append('NFKB')
        else:
            merged.append(symbol)

    # Write back through ``.loc`` so the caller's table is updated in place
    ct_cmplx.loc[:, 'source_genesymbol'] = merged


def get_collectri(
organism: str | int = 'human',
split_complexes=False,
Expand Down Expand Up @@ -516,7 +526,7 @@ def get_collectri(
loops=True,
**kwargs
)
except:
except Exception:
ct = _static_fallback(
query='interactions',
resource='CollecTRI',
Expand All @@ -531,7 +541,7 @@ def get_collectri(
strict_evidences=True,
)
ct = pd.concat([ct, mirna], ignore_index=True)
except:
except Exception:
_warn_failure('TF-miRNA interaction', static_fallback=False)

# Separate gene_pairs from normal interactions
Expand All @@ -542,15 +552,7 @@ def get_collectri(

# Merge gene_pairs into complexes
if not split_complexes:
cmpl_gsym = []
for s in ct_cmplx['source_genesymbol']:
if s.startswith('JUN') or s.startswith('FOS'):
cmpl_gsym.append('AP1')
elif s.startswith('REL') or s.startswith('NFKB'):
cmpl_gsym.append('NFKB')
else:
cmpl_gsym.append(s)
ct_cmplx.loc[:, 'source_genesymbol'] = cmpl_gsym
merge_genes_to_complexes(ct_cmplx)

# Merge
ct = pd.concat([ct_inter, ct_cmplx])
Expand All @@ -559,15 +561,7 @@ def get_collectri(
ct = ct.drop_duplicates(['source_genesymbol', 'target_genesymbol'])

# Add weight
weights = []
for is_stimulation, is_inhibition in zip(ct['is_stimulation'], ct['is_inhibition']):
if is_stimulation:
weights.append(1)
elif is_inhibition:
weights.append(-1)
else:
weights.append(1)
ct['weight'] = weights
ct['weight'] = np.where(ct['is_inhibition'], -1, 1)

# Select and rename columns
ct = ct.rename(columns={'source_genesymbol': 'source', 'target_genesymbol': 'target', 'references_stripped': 'PMID'})
Expand Down Expand Up @@ -766,7 +760,7 @@ def get_ksn_omnipath(
ksn = ksn.drop_duplicates(['source', 'target', 'weight'])

# If duplicates remain, keep dephosphorylation
ksn = ksn.groupby(['source', 'target']).min().reset_index()
ksn = ksn.groupby(['source', 'target'], observed=True).min().reset_index()

return ksn

Expand Down
18 changes: 10 additions & 8 deletions decoupler/plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -678,7 +678,7 @@ def plot_metrics_scatter(df, x='auroc', y='auprc', groupby=None, show_text=True,
# Reformat df
sub = (
df[msk]
.groupby(['method', 'metric'])
.groupby(['method', 'metric'], observed=True)
.mean(numeric_only=True).reset_index()
.pivot(index='method', columns='metric', values='score').reset_index()
)
Expand Down Expand Up @@ -861,10 +861,10 @@ def plot_metrics_boxplot(df, metric, groupby=None, figsize=(5, 5), dpi=100, ax=N
# Compute order
order = (
df
.groupby(['method', groupby])
.groupby(['method', groupby], observed=True)
.mean(numeric_only=True)
.reset_index()
.groupby('method')
.groupby('method', observed=True)
.max()
.sort_values('score')
.index
Expand Down Expand Up @@ -1709,7 +1709,8 @@ def get_source_idxs(n_sources, act, by_abs):
else:
s_idx = np.argsort(-act.values[0])[:n_sources]
else:
raise ValueError('n_sources needs to be a list of source names or an integer number, {0} was passed.'.format(type(n_sources)))
raise ValueError('n_sources needs to be a list of source names or an \
integer number, {0} was passed.'.format(type(n_sources)))
return s_idx


Expand All @@ -1725,13 +1726,14 @@ def get_target_idxs(n_targets, obs, net, by_abs):
t_idx = (
net
.sort_values(['source', 'prod'], ascending=[True, False])
.groupby(['source'])
.groupby(['source'], observed=True)
.head(n_targets)
.index
.values
)
else:
raise ValueError('n_targets needs to be a list of target names or an integer number, {0} was passed.'.format(type(n_targets)))
raise ValueError('n_targets needs to be a list of target names or an \
integer number, {0} was passed.'.format(type(n_targets)))
return t_idx


Expand Down Expand Up @@ -1766,8 +1768,8 @@ def get_obs_act_net(act, obs, net, n_sources, n_targets, by_abs):
def add_colors(g, act, obs, s_norm, t_norm, s_cmap, t_cmap):

mpl = check_if_matplotlib(return_mpl=True)
s_cmap = mpl.cm.get_cmap(s_cmap)
t_cmap = mpl.cm.get_cmap(t_cmap)
s_cmap = mpl.colormaps.get_cmap(s_cmap)
t_cmap = mpl.colormaps.get_cmap(t_cmap)

color = []
for i, k in enumerate(g.vs['label']):
Expand Down
8 changes: 8 additions & 0 deletions decoupler/tests/test_omnip.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
get_resource,
show_resources,
get_dorothea,
merge_genes_to_complexes,
get_collectri,
get_ksn_omnipath
)
Expand Down Expand Up @@ -36,6 +37,13 @@ def test_get_dorothea():
get_dorothea(organism='mouse')


def test_():
    # JUN*/FOS* sources should collapse into AP1 and REL*/NFKB* into NFKB,
    # while unrelated symbols (STAT1) are kept as they are.
    df = pd.DataFrame({
        'source_genesymbol': ['JUN1', 'JUN2', 'RELA', 'NFKB3', 'STAT1'],
    })
    merge_genes_to_complexes(df)
    # Check the renamed values themselves, not only how many distinct ones remain
    assert df['source_genesymbol'].tolist() == ['AP1', 'AP1', 'NFKB', 'NFKB', 'STAT1']
    assert df['source_genesymbol'].unique().size == 3


def test_get_collectri():
df = get_collectri(organism='human', split_complexes=False)
assert type(df) is pd.DataFrame
Expand Down

0 comments on commit 9da10c9

Please sign in to comment.