Skip to content

Commit

Permalink
Merge branch 'nanograv:master' into wx2pl
Browse files Browse the repository at this point in the history
  • Loading branch information
abhisrkckl authored Jan 24, 2024
2 parents bf7df7c + 258e83c commit 4a4036a
Show file tree
Hide file tree
Showing 8 changed files with 62 additions and 44 deletions.
8 changes: 7 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -132,4 +132,10 @@ docs/examples-rendered/*.py
*.tex
*.aux
*.log
*.pdf
*.pdf

# test output
tests/datafile/par_*.par
tests/datafile/fake_toas.tim
tests/datafile/*.converted.par
tests/datafile/_test_pintempo.out
1 change: 1 addition & 0 deletions CHANGELOG-unreleased.md
Original file line number Diff line number Diff line change
Expand Up @@ -42,4 +42,5 @@ the released changes.
- Consistent naming in `TimingModel.get_params_mapping()`
- Better exceptions for unsupported/unimplemented binary models (BTX, MSS, etc.)
- Emit warnings when `WaveX`/`DMWaveX` is used together with other representations of red/DM noise
- `get_observatory()` no longer overwrites `include_gps` and `include_bipm` of `Observatory` objects unless explicitly stated (BIPM and GPS clock corrections no longer incorrectly applied to BAT TOAs).
### Removed
74 changes: 35 additions & 39 deletions src/pint/observatory/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
necessary.
"""

from copy import deepcopy
import os
import textwrap
from collections import defaultdict
Expand Down Expand Up @@ -143,24 +144,6 @@ class Observatory:
# standard name.
_alias_map = {}

def __new__(cls, name, *args, **kwargs):
# Generates a new Observatory object instance, and adds it
# it the registry, using name as the key. Name must be unique,
# a new instance with a given name will over-write the existing
# one only if overwrite=True
obs = super().__new__(cls)
if name.lower() in cls._registry:
if "overwrite" not in kwargs or not kwargs["overwrite"]:
raise ValueError(
f"Observatory {name.lower()} already present and overwrite=False"
)
log.warning(f"Observatory '{name.lower()}' already present; overwriting...")

cls._register(obs, name)
return obs
cls._register(obs, name)
return obs

def __init__(
self,
name,
Expand All @@ -171,13 +154,26 @@ def __init__(
bipm_version=bipm_default,
overwrite=False,
):
self._name = name.lower()
self._aliases = (
list(set(map(str.lower, aliases))) if aliases is not None else []
)
if aliases is not None:
Observatory._add_aliases(self, aliases)
self.fullname = fullname if fullname is not None else name
self.include_gps = include_gps
self.include_bipm = include_bipm
self.bipm_version = bipm_version

if name.lower() in Observatory._registry:
if not overwrite:
raise ValueError(
f"Observatory {name.lower()} already present and overwrite=False"
)
log.warning(f"Observatory '{name.lower()}' already present; overwriting...")

Observatory._register(self, name)

@classmethod
def _register(cls, obs, name):
"""Add an observatory to the registry using the specified name
Expand All @@ -186,7 +182,6 @@ def _register(cls, obs, name):
The Observatory instance's name attribute will be updated for
consistency."""
cls._registry[name.lower()] = obs
obs._name = name.lower()

@classmethod
def _add_aliases(cls, obs, aliases):
Expand All @@ -199,10 +194,6 @@ def _add_aliases(cls, obs, aliases):
to ensure consistency."""
for a in aliases:
cls._alias_map[a.lower()] = obs.name
for o in cls._registry.values():
o._aliases = [
alias for alias, name in cls._alias_map.items() if name == o.name
]

@staticmethod
def gps_correction(t, limits="warn"):
Expand Down Expand Up @@ -460,7 +451,7 @@ def posvel(self, t, ephem, group=None):


def get_observatory(
name, include_gps=True, include_bipm=True, bipm_version=bipm_default
name, include_gps=None, include_bipm=None, bipm_version=bipm_default
):
"""Convenience function to get observatory object with options.
Expand All @@ -474,22 +465,30 @@ def get_observatory(
----------
name : str
The name of the observatory
include_gps : bool, optional
Set False to disable UTC(GPS)->UTC clock correction.
include_bipm : bool, optional
Set False to disable TAI TT(BIPM) clock correction.
include_gps : bool or None, optional
Override UTC(GPS)->UTC clock correction.
include_bipm : bool or None, optional
Override TAI TT(BIPM) clock correction.
bipm_version : str, optional
Set the version of TT BIPM clock correction files.
.. note:: This function can and should be expanded if more clock
file switches/options are added at a public API level.
"""
site = Observatory.get(name)
site.include_gps = include_gps
site.include_bipm = include_bipm
site.bipm_version = bipm_version
return site
if include_bipm is not None or include_gps is not None:
site = deepcopy(Observatory.get(name))

if include_gps is not None:
site.include_gps = include_gps

if include_bipm is not None:
site.include_bipm = include_bipm
site.bipm_version = bipm_version

return site

return Observatory.get(name)


def earth_location_distance(loc1, loc2):
Expand Down Expand Up @@ -533,8 +532,8 @@ def compare_t2_observatories_dat(t2dir=None):
for line in interesting_lines(f, comments="#"):
try:
x, y, z, full_name, short_name = line.split()
except ValueError:
raise ValueError(f"unrecognized line '{line}'")
except ValueError as e:
raise ValueError(f"unrecognized line '{line}'") from e
x, y, z = float(x), float(y), float(z)
full_name, short_name = full_name.lower(), short_name.lower()
topo_obs_entry = textwrap.dedent(
Expand Down Expand Up @@ -831,10 +830,7 @@ def find_clock_file(
"""
# Avoid import loop
from pint.observatory.clock_file import ClockFile, GlobalClockFile
from pint.observatory.global_clock_corrections import (
Index,
get_clock_correction_file,
)
from pint.observatory.global_clock_corrections import Index

if name == "":
raise ValueError("No filename supplied to find_clock_file")
Expand Down
2 changes: 1 addition & 1 deletion src/pint/observatory/satellite_obs.py
Original file line number Diff line number Diff line change
Expand Up @@ -311,7 +311,6 @@ def __init__(
overwrite=False,
):
super().__init__(
self,
name,
include_gps=include_gps,
include_bipm=include_bipm,
Expand Down Expand Up @@ -444,4 +443,5 @@ def get_satellite_observatory(name, ft2name, **kwargs):
# values as new observatories are added.
if "maxextrap" not in kwargs:
kwargs["maxextrap"] = 2

return SatelliteObs(name, ft2name, **kwargs)
3 changes: 2 additions & 1 deletion src/pint/observatory/topo_obs.py
Original file line number Diff line number Diff line change
Expand Up @@ -502,7 +502,8 @@ def load_observatories(filename=observatories_json, overwrite=False):
def load_observatories_from_usual_locations(clear=False):
"""Load observatories from the default JSON file as well as ``$PINT_OBS_OVERRIDE``, optionally clearing the registry
Running with ``clear=True`` will return PINT to the state it is on import. Running with ``clear=False`` may result in conflicting definitions if observatories have already been imported.
Running with ``clear=True`` will return PINT to the state it is on import.
Running with ``clear=False`` may result in conflicting definitions if observatories have already been imported.
Parameters
----------
Expand Down
Empty file modified tests/test_eventstats.py
100755 → 100644
Empty file.
1 change: 0 additions & 1 deletion tests/test_fermiphase.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@

import pint.models
import pint.scripts.fermiphase as fermiphase
import pint.toa as toa
from pint.fermi_toas import get_Fermi_TOAs, _default_uncertainty
from pint.observatory.satellite_obs import get_satellite_observatory
from pinttestdata import datadir
Expand Down
17 changes: 16 additions & 1 deletion tests/test_observatory.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,9 +276,11 @@ def test_json_observatory_output(sandbox):
gbt_reload = get_observatory("gbt")

for p in gbt_orig.__dict__:
if p not in ["_clock"]:
if p not in ["_clock", "_aliases"]:
assert getattr(gbt_orig, p) == getattr(gbt_reload, p)

assert set(gbt_orig._aliases) == set(gbt_reload._aliases)


def test_json_observatory_input_latlon(sandbox):
gbt_orig = get_observatory("gbt")
Expand Down Expand Up @@ -336,3 +338,16 @@ def test_compare_t2_observatories_dat():
def test_compare_tempo_obsys_dat():
s = compare_tempo_obsys_dat(testdatadir / "observatory")
assert isinstance(s, defaultdict)


def test_ssb_obs():
ssb = Observatory.get("@")
assert not ssb.include_bipm and not ssb.include_gps

ssb = get_observatory("@")
assert not ssb.include_bipm and not ssb.include_gps

# get_observatory changes the state of the registered
# Observatory objects. So this needs to be repeated.
ssb = Observatory.get("@")
assert not ssb.include_bipm and not ssb.include_gps

0 comments on commit 4a4036a

Please sign in to comment.