Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update Run Manager for New Infrastructure #139

Open
wants to merge 36 commits into
base: jim-dev
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 38 commits
Commits
Show all changes
36 commits
Select commit Hold shift + click to select a range
ee08171
summary output in run manager
zipengwang98 Jul 31, 2024
9b8de58
Merge branch '98-moving-naming-tracking-into-jim-class-from-prior-cla…
thomasckng Jul 31, 2024
e96b34e
Merge pull request #117 from thomasckng/run-manager
thomasckng Jul 31, 2024
38bb277
Merge pull request #134 from kazewong/98-moving-naming-tracking-into-…
thomasckng Aug 21, 2024
1c21605
Remove legacy test files
thomasckng Aug 21, 2024
19c4a29
Update utils.py
thomasckng Aug 21, 2024
169aaa2
Reformat
thomasckng Aug 21, 2024
0f687f7
Update runmanger
thomasckng Aug 21, 2024
1dbbd55
Update runmanager
thomasckng Aug 21, 2024
cb07d4e
Reformat
thomasckng Aug 21, 2024
e5b6e84
Change default path
thomasckng Aug 21, 2024
b1133d3
Update runManager.py
xuyuon Aug 22, 2024
2a9d696
Update runManager.py
xuyuon Aug 22, 2024
15a103d
Merge pull request #135 from thomasckng/run-manager
ThibeauWouters Aug 22, 2024
7050ca8
Merge pull request #12 from kazewong/run-manager
xuyuon Aug 29, 2024
ec1c90f
Fix bug in initializing likelihood
xuyuon Aug 29, 2024
fd63c8b
Fix bug in save_summary()
xuyuon Aug 29, 2024
77a75dc
Fix bug
thomasckng Aug 30, 2024
863f34c
Merge pull request #140 from thomasckng/run-manager-dev
xuyuon Aug 30, 2024
9b19b9f
Merge pull request #13 from kazewong/run-manager
xuyuon Aug 30, 2024
55a79e8
Fix transform initialization
xuyuon Aug 30, 2024
f046a70
Fix stdout issue
xuyuon Aug 30, 2024
5d1458e
Reformatted
xuyuon Aug 30, 2024
ca9f906
Merge pull request #141 from xuyuon/fix-run-manager-bugs
xuyuon Aug 30, 2024
522d7c0
Added MultipleEventPERunManager in runManager.py
xuyuon Aug 30, 2024
06e22b7
Added Multiple_event_runManager.py
xuyuon Aug 30, 2024
186b9d9
Fix bug in load_from_path
xuyuon Aug 31, 2024
de18eda
Added error log in MultipleEventPERunManager
xuyuon Aug 31, 2024
3ddc1fc
Fixed typo in Multiple_event_runManager.py
xuyuon Aug 31, 2024
2d3976f
reformatted
xuyuon Aug 31, 2024
c906c67
Added documentation in Multiple_event_runManager.py
xuyuon Sep 2, 2024
7309e44
Added try block in load_from_path()
xuyuon Sep 2, 2024
337c21e
Fixed try block in load_from_path()
xuyuon Sep 2, 2024
1fb0b29
reformatted
xuyuon Sep 2, 2024
3a02bdc
reformatted
xuyuon Sep 2, 2024
e1af4a9
Merge pull request #145 from xuyuon/run-manager
thomasckng Sep 3, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
112 changes: 58 additions & 54 deletions example/Single_event_runManager.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@

import jax
import jax.numpy as jnp

Expand All @@ -12,58 +11,50 @@
mass_matrix = mass_matrix.at[5, 5].set(1e-3)
mass_matrix = mass_matrix * 3e-3
local_sampler_arg = {"step_size": mass_matrix}
bounds = jnp.array(
[
[10.0, 40.0],
[0.125, 1.0],
[-1.0, 1.0],
[-1.0, 1.0],
[0.0, 2000.0],
[-0.05, 0.05],
[0.0, 2 * jnp.pi],
[-1.0, 1.0],
[0.0, jnp.pi],
[0.0, 2 * jnp.pi],
[-1.0, 1.0],
]
)


run = SingleEventRun(
seed=0,
detectors=["H1", "L1"],
data_parameters={
"trigger_time": 1126259462.4,
"duration": 4,
"post_trigger_duration": 2,
"f_min": 20.0,
"f_max": 1024.0,
"tukey_alpha": 0.2,
"f_sampling": 4096.0,
},
priors={
"M_c": {"name": "Uniform", "xmin": 10.0, "xmax": 80.0},
"q": {"name": "MassRatio"},
"s1_z": {"name": "Uniform", "xmin": -1.0, "xmax": 1.0},
"s2_z": {"name": "Uniform", "xmin": -1.0, "xmax": 1.0},
"d_L": {"name": "Uniform", "xmin": 0.0, "xmax": 2000.0},
"t_c": {"name": "Uniform", "xmin": -0.05, "xmax": 0.05},
"phase_c": {"name": "Uniform", "xmin": 0.0, "xmax": 2 * jnp.pi},
"cos_iota": {"name": "CosIota"},
"psi": {"name": "Uniform", "xmin": 0.0, "xmax": jnp.pi},
"ra": {"name": "Uniform", "xmin": 0.0, "xmax": 2 * jnp.pi},
"sin_dec": {"name": "SinDec"},
"M_c": {"name": "UniformPrior", "xmin": 10.0, "xmax": 80.0},
"q": {"name": "UniformPrior", "xmin": 0.0, "xmax": 1.0},
"s1_z": {"name": "UniformPrior", "xmin": -1.0, "xmax": 1.0},
"s2_z": {"name": "UniformPrior", "xmin": -1.0, "xmax": 1.0},
"d_L": {"name": "UniformPrior", "xmin": 1.0, "xmax": 2000.0},
"t_c": {"name": "UniformPrior", "xmin": -0.05, "xmax": 0.05},
"phase_c": {"name": "UniformPrior", "xmin": 0.0, "xmax": 2 * jnp.pi},
"iota": {"name": "SinePrior"},
"psi": {"name": "UniformPrior", "xmin": 0.0, "xmax": jnp.pi},
"ra": {"name": "UniformPrior", "xmin": 0.0, "xmax": 2 * jnp.pi},
"dec": {"name": "CosinePrior"},
},
waveform_parameters={"name": "RippleIMRPhenomD", "f_ref": 20.0},
jim_parameters={
"n_loop_training": 10,
"n_loop_production": 10,
"n_local_steps": 150,
"n_global_steps": 150,
"n_chains": 500,
"n_epochs": 50,
"learning_rate": 0.001,
"n_max_examples": 45000,
"momentum": 0.9,
"batch_size": 50000,
"use_global": True,
"keep_quantile": 0.0,
"train_thinning": 1,
"output_thinning": 10,
"local_sampler_arg": local_sampler_arg,
},
likelihood_parameters={"name": "HeterodynedTransientLikelihoodFD", "bounds": bounds},
likelihood_parameters={"name": "TransientLikelihoodFD"},
sample_transforms=[
{"name": "BoundToUnbound", "name_mapping": [["M_c"], ["M_c_unbounded"]], "original_lower_bound": 10.0, "original_upper_bound": 80.0,},
{"name": "BoundToUnbound", "name_mapping": [["q"], ["q_unbounded"]], "original_lower_bound": 0.0, "original_upper_bound": 1.0,},
{"name": "BoundToUnbound", "name_mapping": [["s1_z"], ["s1_z_unbounded"]], "original_lower_bound": -1.0, "original_upper_bound": 1.0,},
{"name": "BoundToUnbound", "name_mapping": [["s2_z"], ["s2_z_unbounded"]], "original_lower_bound": -1.0, "original_upper_bound": 1.0,},
{"name": "BoundToUnbound", "name_mapping": [["d_L"], ["d_L_unbounded"]], "original_lower_bound": 1.0, "original_upper_bound": 2000.0,},
{"name": "BoundToUnbound", "name_mapping": [["t_c"], ["t_c_unbounded"]], "original_lower_bound": -0.05, "original_upper_bound": 0.05,},
{"name": "BoundToUnbound", "name_mapping": [["phase_c"], ["phase_c_unbounded"]], "original_lower_bound": 0.0, "original_upper_bound": 2 * jnp.pi,},
{"name": "BoundToUnbound", "name_mapping": [["iota"], ["iota_unbounded"]], "original_lower_bound": 0.0, "original_upper_bound": jnp.pi,},
{"name": "BoundToUnbound", "name_mapping": [["psi"], ["psi_unbounded"]], "original_lower_bound": 0.0, "original_upper_bound": jnp.pi,},
{"name": "BoundToUnbound", "name_mapping": [["ra"], ["ra_unbounded"]], "original_lower_bound": 0.0, "original_upper_bound": 2 * jnp.pi,},
{"name": "BoundToUnbound", "name_mapping": [["dec"], ["dec_unbounded"]], "original_lower_bound": 0.0, "original_upper_bound": jnp.pi,},
],
likelihood_transforms=[
{"name": "MassRatioToSymmetricMassRatioTransform", "name_mapping": [["q"], ["eta"]]},
],
injection=True,
injection_parameters={
"M_c": 28.6,
Expand All @@ -78,15 +69,28 @@
"ra": 1.2,
"dec": 0.3,
},
data_parameters={
"trigger_time": 1126259462.4,
"duration": 4,
"post_trigger_duration": 2,
"f_min": 20.0,
"f_max": 1024.0,
"tukey_alpha": 0.2,
"f_sampling": 4096.0,
jim_parameters={
"n_loop_training": 100,
"n_loop_production": 20,
"n_local_steps": 10,
"n_global_steps": 1000,
"n_chains": 500,
"n_epochs": 30,
"learning_rate": 1e-4,
"n_max_examples": 30000,
"momentum": 0.9,
"batch_size": 30000,
"use_global": True,
"train_thinning": 1,
"output_thinning": 10,
"local_sampler_arg": local_sampler_arg,
},
)

run_manager = SingleEventPERunManager(run=run)
run_manager.sample()

# plot the corner plot and diagnostic plot
run_manager.plot_corner()
run_manager.plot_diagnostic()
run_manager.save_summary()
6 changes: 4 additions & 2 deletions src/jimgw/single_event/detector.py
Original file line number Diff line number Diff line change
Expand Up @@ -373,7 +373,7 @@ def inject_signal(
h_sky: dict[str, Float[Array, " n_sample"]],
params: dict[str, Float],
psd_file: str = "",
) -> None:
) -> tuple[Float, Float]:
"""
Inject a signal into the detector data.

Expand All @@ -392,7 +392,7 @@ def inject_signal(

Returns
-------
None
SNR
"""
self.frequencies = freqs
self.psd = self.load_psd(freqs, psd_file)
Expand All @@ -415,6 +415,8 @@ def inject_signal(
print(f"The injected optimal SNR is {optimal_SNR}")
print(f"The injected match filter SNR is {match_filter_SNR}")

return optimal_SNR, match_filter_SNR

@jaxtyped(typechecker=typechecker)
def load_psd(
self, freqs: Float[Array, " n_sample"], psd_file: str = ""
Expand Down
Loading
Loading