forked from opentargets/genetics-finemapping
-
Notifications
You must be signed in to change notification settings - Fork 0
/
1_scan_input_parquets.py
90 lines (77 loc) · 2.7 KB
/
1_scan_input_parquets.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# Ed Mountjoy
#
'''
# Set SPARK_HOME and PYTHONPATH to use 2.4.0
export PYSPARK_SUBMIT_ARGS="--driver-memory 8g pyspark-shell"
export SPARK_HOME=/Users/em21/software/spark-2.4.0-bin-hadoop2.7
export PYTHONPATH=$SPARK_HOME/python:$SPARK_HOME/python/lib/py4j-2.4.0-src.zip:$PYTHONPATH
'''
import sys
from functools import reduce
from glob import glob

import pyspark.sql
from pyspark.sql.types import *
from pyspark.sql.functions import *
def main():
    '''Scan filtered GWAS and molecular-trait parquet datasets and write, as
    gzipped JSON, the distinct (type, study, phenotype, bio_feature, gene,
    chrom) records that pass their per-dataset p-value threshold.

    Returns:
        0 on completion.
    '''
    # Start a local Spark session using all available cores
    spark = (
        pyspark.sql.SparkSession.builder
        .config("spark.master", "local[*]")
        .getOrCreate()
    )
    # sc = spark.sparkContext
    print('Spark version: ', spark.version)

    # Genome-wide significance threshold applied to GWAS rows, and used as a
    # floor for the molecular-trait thresholds below
    gwas_pval_threshold = 5e-8

    # Input patterns and output location
    gwas_pattern = '/home/js29/genetics-finemapping/data/filtered/significant_window_2mb/gwas/*.parquet'
    mol_pattern = '/home/js29/genetics-finemapping/data/filtered/significant_window_2mb/molecular_trait/*.parquet'
    out_path = '/home/js29/genetics-finemapping/tmp/filtered_input'

    # GWAS datasets can be read in one pass. Derive each row's dataset name
    # from its source file path: drop the 'file:' scheme and the '/part-*'
    # suffix so all parts of one dataset share a single input_name.
    strip_path_gwas = udf(lambda x: x.replace('file:', '').split('/part-')[0], StringType())
    gwas_dfs = (
        spark.read.parquet(gwas_pattern)
        .withColumn('pval_threshold', lit(gwas_pval_threshold))
        .withColumn('input_name', strip_path_gwas(input_file_name()))
    )

    # Molecular-trait datasets must be read one at a time and combined with
    # unionByName afterwards, as the hive partitions differ across datasets
    # due to different tissues (bio_features) and chromosomes
    strip_path_mol = udf(lambda x: x.replace('file:', ''), StringType())
    mol_dfs = [
        (
            spark.read.parquet(dataset_path)
            # Bonferroni-style threshold per phenotype, floored at the
            # GWAS significance threshold
            .withColumn('pval_threshold', (0.05 / col('num_tests')))
            .withColumn('pval_threshold',
                        when(col('pval_threshold') > gwas_pval_threshold,
                             col('pval_threshold'))
                        .otherwise(gwas_pval_threshold))
            .drop('num_tests')
            .withColumn('input_name', strip_path_mol(lit(dataset_path)))
        )
        for dataset_path in glob(mol_pattern)
    ]

    # Stack GWAS and molecular-trait rows into one dataframe
    combined = reduce(
        pyspark.sql.DataFrame.unionByName,
        [gwas_dfs] + mol_dfs
    )

    # Keep only significant rows, then deduplicate on the key columns
    significant = (
        combined
        .filter(col('pval') < col('pval_threshold'))
        .select('type', 'study_id', 'phenotype_id', 'bio_feature', 'gene_id',
                'chrom', 'pval_threshold', 'input_name')
        .distinct()
    )

    # Write the result as gzipped JSON
    (
        significant
        .coalesce(300)
        .write.json(out_path,
                    compression='gzip',
                    mode='overwrite')
    )

    return 0
if __name__ == '__main__':
    # Propagate main()'s return value as the process exit status; previously
    # the status code was computed by main() but silently discarded here.
    sys.exit(main())