Skip to content

Commit

Permalink
Merge pull request #82 from thepetabyteproject/fits_bw
Browse files Browse the repository at this point in the history
Freq parameters in fits files
  • Loading branch information
devanshkv authored Nov 13, 2021
2 parents 6ab8c72 + 02fa5e4 commit c6c94f0
Show file tree
Hide file tree
Showing 7 changed files with 83 additions and 23 deletions.
1 change: 1 addition & 0 deletions bin/your_heimdall.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,6 +214,7 @@
progress=~values.no_progress,
c_min=values.channel_start,
c_max=c_max,
highest_frequency_first=True,
savgol_sigma=values.savgol_sigma,
spectral_kurtosis_sigma=values.spectral_kurtosis_sigma,
savgol_frequency_window=values.savgol_frequency_window,
Expand Down
8 changes: 8 additions & 0 deletions bin/your_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,13 @@
type=int,
default=1,
)
parser.add_argument(
"--highest_frequency_first",
help="Write highest frequency first",
required=False,
action="store_true",
default=False,
)
parser.add_argument(
"-npsub",
"--nspectra_per_subint",
Expand Down Expand Up @@ -227,6 +234,7 @@
frequency_decimation_factor=values.frequency_decimation_factor,
replacement_policy=values.replacement_policy,
npoln=values.num_polarisation,
highest_frequency_first=values.highest_frequency_first,
)

if values.type == "fits":
Expand Down
58 changes: 47 additions & 11 deletions tests/test_candidate.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@


@pytest.fixture(scope="function", autouse=True)
def cand():
def cand_fil():
fil_file = os.path.join(_install_dir, "data/28.fil")
cand = Candidate(
fp=fil_file,
Expand All @@ -25,6 +25,22 @@ def cand():
return cand


@pytest.fixture(scope="function", autouse=True)
def cand_fits():
fil_file = os.path.join(_install_dir, "data/28.fits")
cand = Candidate(
fp=fil_file,
dm=475.28400,
tcand=2.0288800,
width=2,
label=-1,
snr=16.8128,
min_samp=256,
device=0,
)
return cand


def test_Candidate():
fits_file = os.path.join(_install_dir, "data/28.fits")
cand = Candidate(
Expand All @@ -40,32 +56,44 @@ def test_Candidate():
assert np.isclose(cand.dispersion_delay(), 0.6254989199749227, atol=1e-3)


def test_candidate_chunk(cand):
@pytest.mark.parametrize("cand", ["cand_fil", "cand_fits"])
def test_candidate_chunk(cand, request):
cand = request.getfixturevalue(cand)
cand.get_chunk()
assert np.isclose(np.mean(cand.data), 128, atol=1)


def test_dedispersion_none(cand):
@pytest.mark.parametrize("cand", ["cand_fil", "cand_fits"])
def test_dedispersion_none(cand, request):
cand = request.getfixturevalue(cand)
cand.dedisperse()
assert cand.dedispersed == None


def test_dedisperse(cand):
@pytest.mark.parametrize("cand", ["cand_fil", "cand_fits"])
def test_dedisperse(cand, request):
cand = request.getfixturevalue(cand)
cand.get_chunk()
cand.dedisperse()
assert np.isclose(np.max(cand.dedispersed.T.sum(0)), 47527, atol=1)
assert np.isclose(np.max(cand.dedispersets()), 47527, atol=1)


def test_snr_none(cand):
@pytest.mark.parametrize("cand", ["cand_fil", "cand_fits"])
def test_snr_none(cand, request):
cand = request.getfixturevalue(cand)
assert cand.get_snr() == None


def test_optimize_dm(cand):
@pytest.mark.parametrize("cand", ["cand_fil", "cand_fits"])
def test_optimize_dm(cand, request):
cand = request.getfixturevalue(cand)
assert cand.optimize_dm() == None


def test_dmtime_snr_opt_snr(cand):
@pytest.mark.parametrize("cand", ["cand_fil", "cand_fits"])
def test_dmtime_snr_opt_snr(cand, request):
cand = request.getfixturevalue(cand)
cand.get_chunk()
cand.dedisperse()
cand.dmtime()
Expand All @@ -79,7 +107,9 @@ def test_dmtime_snr_opt_snr(cand):
assert pytest.approx(cand.optimize_dm()[0], rel=2) == 475


def test_h5(cand):
@pytest.mark.parametrize("cand", ["cand_fil", "cand_fits"])
def test_h5(cand, request):
cand = request.getfixturevalue(cand)
cand.get_chunk()
cand.dedisperse()
cand.dmtime()
Expand All @@ -88,7 +118,9 @@ def test_h5(cand):
os.remove(str(cand.id) + ".h5")


def test_decimate_on_dedispersed(cand):
@pytest.mark.parametrize("cand", ["cand_fil", "cand_fits"])
def test_decimate_on_dedispersed(cand, request):
cand = request.getfixturevalue(cand)
cand.get_chunk()
cand.dedisperse()
ts_orig = cand.dedispersed.T.mean(0)
Expand Down Expand Up @@ -119,7 +151,9 @@ def test_decimate_on_dedispersed(cand):
assert (bp_orig - cand.dedispersed[0, :]).sum() == 0


def test_decimate_on_dmt(cand):
@pytest.mark.parametrize("cand", ["cand_fil", "cand_fits"])
def test_decimate_on_dmt(cand, request):
cand = request.getfixturevalue(cand)
cand.get_chunk()
cand.dmtime()
dmt_orig = cand.dmt
Expand All @@ -138,7 +172,9 @@ def test_decimate_on_dmt(cand):
cand.decimate(key="at", axis=0, pad=True, decimate_factor=4, mode="median")


def test_resize(cand):
@pytest.mark.parametrize("cand", ["cand_fil", "cand_fits"])
def test_resize(cand, request):
cand = request.getfixturevalue(cand)
cand.get_chunk()
cand.dedisperse()
cand.dmtime()
Expand Down
2 changes: 1 addition & 1 deletion your/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from your.writer import *
from your.your import *

__version__ = "0.6.5"
__version__ = "0.6.6"
2 changes: 1 addition & 1 deletion your/formats/filwriter.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def sigproc_object_from_writer(your_writer):
fil_obj.data_type = 0

fil_obj.nchans = your_writer.nchans
fil_obj.foff = your_writer.your_object.your_header.foff
fil_obj.foff = your_writer.foff
fil_obj.fch1 = your_writer.chan_freqs[0]
fil_obj.nbeams = 1
fil_obj.ibeam = 0
Expand Down
16 changes: 8 additions & 8 deletions your/formats/psrfits.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,9 +124,9 @@ def __init__(self, psrfitslist):

# Unifying properties with pysigproc
self.npol = self.npoln
self.bw = self.header["OBSBW"]
self.bw = self.specinfo.BW
self.cfreq = self.header["OBSFREQ"]
self.fch1 = self.cfreq - self.bw / 2.0 # Verify
self.fch1 = self.freqs[0] # self.cfreq - self.bw / 2.0 # Verify
self.foff = self.bw / self.nchan
self.nchans = self.nchan
self.tstart = self.specinfo.start_MJD[0]
Expand Down Expand Up @@ -766,12 +766,12 @@ def __init__(self, filenames):
self.bytes_per_subint = self.bytes_per_spectra * self.spectra_per_subint

# Flip the band?
if self.hi_freq < self.lo_freq:
tmp = self.hi_freq
self.hi_freq = self.lo_freq
self.lo_freq = tmp
self.df *= -1.0
self.need_flipband = True
# if self.hi_freq < self.lo_freq:
# tmp = self.hi_freq
# self.hi_freq = self.lo_freq
# self.lo_freq = tmp
# self.df *= -1.0
# self.need_flipband = True
# Compute the bandwidth
self.BW = self.num_channels * self.df
self.mjd = int(self.start_MJD[0])
Expand Down
19 changes: 17 additions & 2 deletions your/writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ class Writer:
c_min (int): Starting channel index (default: 0)
c_max (int): End channel index (default: total number of frequencies)
npoln (int): Number of output polarisations (default: 1)
highest_frequency_first (bool): The output should have highest frequency first (default: False)
outdir (str): Output directory for file
outname (str): Name of the file to write to (without the file extension)
progress (bool): Set to it to false to disable progress bars
Expand All @@ -51,6 +52,7 @@ def __init__(
c_min=None,
c_max=None,
npoln=1,
highest_frequency_first=False,
outdir=None,
outname=None,
flag_rfi=False,
Expand All @@ -75,6 +77,7 @@ def __init__(
self.c_min = c_min
self.c_max = c_max
self.npoln = npoln
self.highest_frequency_first = highest_frequency_first

self.time_decimation_factor = time_decimation_factor
self.frequency_decimation_factor = frequency_decimation_factor
Expand Down Expand Up @@ -163,7 +166,10 @@ def chan_max(self):

@property
def chan_freqs(self):
return self.your_object.chan_freqs[self.chan_min : self.chan_max]
chan_freqs = self.your_object.chan_freqs[self.chan_min : self.chan_max]
if self.highest_frequency_first and chan_freqs[0] < chan_freqs[-1]:
chan_freqs = chan_freqs[::-1]
return chan_freqs

@property
def nchans(self):
Expand All @@ -176,6 +182,13 @@ def tstart(self):
+ self.nstart * self.your_object.your_header.tsamp / (60 * 60 * 24)
)

@property
def foff(self):
if self.highest_frequency_first and self.your_object.your_header.foff > 0:
return -self.your_object.your_header.foff
else:
return self.your_object.your_header.foff

@property
def poln_order(self):
if self.npoln == 1:
Expand Down Expand Up @@ -252,6 +265,8 @@ def get_data_to_write(self, start_sample, nsamp):
data = data.astype(self.your_object.your_header.dtype)

# shape of data is (nt, npoln, nf)
if self.highest_frequency_first and self.your_object.your_header.foff > 0:
data = data[:, :, ::-1]
self.data = data

def to_fil(self, data=None):
Expand Down Expand Up @@ -438,7 +453,7 @@ def dada_header(self):
"""
header = dict()
header["BW"] = str(self.nchans * self.your_object.your_header.foff)
header["BW"] = str(self.nchans * self.foff)
header["FREQ"] = str((self.chan_freqs[0] + self.chan_freqs[-1]) / 2)
tstart = Time(self.tstart, format="mjd")
header["MJD_START"] = str(self.tstart)
Expand Down

0 comments on commit c6c94f0

Please sign in to comment.