From a9562a81121dc6d55a870057cae5d9f3294b83f4 Mon Sep 17 00:00:00 2001 From: Kai Zhang Date: Sun, 18 Feb 2024 07:10:26 +0800 Subject: [PATCH] Fix #232: `tempdir` is not respected in `tl.macs3` --- docs/changelog.md | 1 + .../snapatac2/tools/_call_peaks.py | 73 +++++++++---------- 2 files changed, 36 insertions(+), 38 deletions(-) diff --git a/docs/changelog.md b/docs/changelog.md index fc02f36f2..f7406d476 100644 --- a/docs/changelog.md +++ b/docs/changelog.md @@ -6,6 +6,7 @@ - Fix: #221: 'pp.knn' with 'method=pynndescent' invalid csr matrix. - Fix: #226: Backed AnnData does not support numpy string array. +- Fix: #232: `tempdir` is not respected in `tl.macs3`. ## Release 2.5.3 (released Jan 16, 2024) diff --git a/snapatac2-python/snapatac2/tools/_call_peaks.py b/snapatac2-python/snapatac2/tools/_call_peaks.py index 5d92dd6aa..ac567a006 100644 --- a/snapatac2-python/snapatac2/tools/_call_peaks.py +++ b/snapatac2-python/snapatac2/tools/_call_peaks.py @@ -85,23 +85,14 @@ def macs3( """ from MACS3.Signal.PeakDetect import PeakDetect from math import log - from multiprocess import get_context - from tqdm import tqdm import tempfile if isinstance(groupby, str): groupby = list(adata.obs[groupby]) if replicate is not None and isinstance(replicate, str): replicate = list(adata.obs[replicate]) - if tempdir is None: - tempdir = Path(tempfile.mkdtemp()) - else: - tempdir = Path(tempfile.mkdtemp(dir=tempdir)) - - logging.info("Exporting fragments...") - fragments = _snapatac2.export_tags(adata, tempdir, groupby, replicate, max_frag_size, selections) - # General options + # MACS3 options options = type('MACS3_OPT', (), {})() options.info = lambda _: None options.debug = lambda _: None @@ -130,37 +121,43 @@ def macs3( options.d = extsize options.scanwindow = 2 * options.d - def _call_peaks(tags): - merged, reps = _snapatac2.create_fwtrack_obj(tags) - options.log_qvalue = log(qvalue, 10) * -1 - logging.getLogger().setLevel(logging.CRITICAL + 1) - peakdetect = PeakDetect(treat=merged, opt=options) - peakdetect.call_peaks() - peakdetect.peaks.filter_fc(fc_low = options.fecutoff) - merged = peakdetect.peaks - - others = [] - if replicate_qvalue is not None: - options.log_qvalue = log(replicate_qvalue, 10) * -1 - for x in reps: - peakdetect = PeakDetect(treat=x, opt=options) + with tempfile.TemporaryDirectory(dir=tempdir) as tmpdirname: + logging.info("Exporting fragments...") + fragments = _snapatac2.export_tags(adata, tmpdirname, groupby, replicate, max_frag_size, selections) + + def _call_peaks(tags): + import tempfile + tempfile.tempdir = tmpdirname # Overwrite the default tempdir in MACS3 + merged, reps = _snapatac2.create_fwtrack_obj(tags) + options.log_qvalue = log(qvalue, 10) * -1 + logging.getLogger().setLevel(logging.CRITICAL + 1) + peakdetect = PeakDetect(treat=merged, opt=options) peakdetect.call_peaks() peakdetect.peaks.filter_fc(fc_low = options.fecutoff) - others.append(peakdetect.peaks) - - logging.getLogger().setLevel(logging.INFO) - return _snapatac2.find_reproducible_peaks(merged, others, blacklist) - - logging.info("Calling peaks...") - peaks = _par_map(_call_peaks, [(x,) for x in fragments.values()], n_jobs) - peaks = {k: v for k, v in zip(fragments.keys(), peaks)} - if inplace: - if adata.isbacked: - adata.uns[key_added] = peaks + merged = peakdetect.peaks + + others = [] + if replicate_qvalue is not None: + options.log_qvalue = log(replicate_qvalue, 10) * -1 + for x in reps: + peakdetect = PeakDetect(treat=x, opt=options) + peakdetect.call_peaks() + peakdetect.peaks.filter_fc(fc_low = options.fecutoff) + others.append(peakdetect.peaks) + + logging.getLogger().setLevel(logging.INFO) + return _snapatac2.find_reproducible_peaks(merged, others, blacklist) + + logging.info("Calling peaks...") + peaks = _par_map(_call_peaks, [(x,) for x in fragments.values()], n_jobs) + peaks = {k: v for k, v in zip(fragments.keys(), peaks)} + if inplace: + if adata.isbacked: + adata.uns[key_added] = peaks + else: + adata.uns[key_added] = {k: v.to_pandas() for k, v in peaks.items()} else: - adata.uns[key_added] = {k: v.to_pandas() for k, v in peaks.items()} - else: - return peaks + return peaks def merge_peaks( peaks: dict[str, 'polars.DataFrame'],