Skip to content

Commit

Permalink
ffi: running again, tested static
Browse files Browse the repository at this point in the history
- cleaning directory also for last stage
- heart.init_geodetic targets add updating of local coordinate system
- backend.ArrayStepSharedLLK rearrange input point to adhere to value_vars
- ffi.get_magnitude check if moment/slips are zero
  • Loading branch information
hvasbath committed Mar 6, 2024
1 parent da2dba4 commit e30789b
Show file tree
Hide file tree
Showing 9 changed files with 42 additions and 16 deletions.
2 changes: 2 additions & 0 deletions beat/apps/beat.py
Original file line number Diff line number Diff line change
Expand Up @@ -1241,8 +1241,10 @@ def setup(parser):
)
store = engine.get_store(target.store_id)

logger.debug("n_chains %i", len(chains))
for chain in tqdm(chains):
for idx in idxs:
logger.debug("chain %i idx %i", chain, idx)
point = stage.mtrace.point(idx=idx, chain=chain)
reference.update(point)
# normalize MT source, TODO put into get_derived_params
Expand Down
5 changes: 5 additions & 0 deletions beat/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,11 @@ def step(self, point):
for name, shared_var in self.shared.items():
shared_var.set_value(point[name])

# print("point", point)

# assure order and content of RVs consistent to value_vars
point = {val_var.name:point[val_var.name] for val_var in self.value_vars}

q = self.bij.map(point)
# print("before", q.data)
apoint, alist = self.astep(q.data)
Expand Down
2 changes: 1 addition & 1 deletion beat/ffi/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -804,7 +804,7 @@ def filename(self):
def _process_patch_geodetic(engine, gfs, targets, patch, patchidx, los_vectors, odws):
logger.debug("Patch Number %i", patchidx)
logger.debug("Calculating synthetics ...")

logger.debug(patch.__str__())
disp = heart.geo_synthetics(
engine=engine, targets=targets, sources=[patch], outmode="stacked_array"
)
Expand Down
17 changes: 11 additions & 6 deletions beat/ffi/fault.py
Original file line number Diff line number Diff line change
Expand Up @@ -306,7 +306,10 @@ def get_subfault_patch_moments(
if slips is not None:
rs.update(slip=slips[i])

pm = rs.get_moment(target=target, store=store)
if slips[i] != 0.:
pm = rs.get_moment(target=target, store=store)
else:
pm = 0.
moments.append(pm)

return moments
Expand All @@ -318,7 +321,6 @@ def get_moment(self, point=None, store=None, target=None, datatype="geodetic"):
moments = []
for index in range(self.nsubfaults):
slips = self.get_total_slip(index, point)

sf_moments = self.get_subfault_patch_moments(
index=index, slips=slips, store=store, target=target, datatype=datatype
)
Expand All @@ -330,9 +332,12 @@ def get_magnitude(self, point=None, store=None, target=None, datatype="geodetic"
"""
Get total moment magnitude after Hanks and Kanamori 1979
"""
return moment_to_magnitude(
self.get_moment(point=point, store=store, target=target, datatype=datatype)
)
moment = self.get_moment(
point=point, store=store, target=target, datatype=datatype)
if moment:
return moment_to_magnitude(moment)
else:
return moment

def get_total_slip(self, index=None, point={}, components=None):
"""
Expand Down Expand Up @@ -682,7 +687,7 @@ def point2sources(self, point, events=[]):
slips = self.get_total_slip(index, point)
rakes = num.arctan2(-ucomps["uperp"], ucomps["uparr"]) * r2d + sf.rake
opening_fractions = num.divide(
ucomps["utens"], slips, out=np.zeros_like(slips), where=slips!=0)
ucomps["utens"], slips, out=num.zeros_like(slips), where=slips!=0)

sf_point = {
"slip": slips,
Expand Down
11 changes: 11 additions & 0 deletions beat/heart.py
Original file line number Diff line number Diff line change
Expand Up @@ -1642,6 +1642,9 @@ def init_geodetic_targets(

em_name = get_earth_model_prefix(earth_model_name)

for data in datasets:
data.update_local_coords(event)

targets = [
gf.StaticTarget(
lons=num.full_like(d.lons, event.lon),
Expand Down Expand Up @@ -4169,6 +4172,14 @@ def geo_synthetics(
returns Nan in displacements if result is invalid!
"""

if False:
# for debugging
for source in sources:
print(source)

for target in targets:
print(target)

response = engine.process(sources, targets)

ns = len(sources)
Expand Down
5 changes: 4 additions & 1 deletion beat/models/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,7 +221,10 @@ def sample(step, problem):
start = []
for i in tqdm(range(step.n_chains)):
point = problem.get_random_point()
start.append(problem.lsq_solution(point))
# print(point)
lsq_point = problem.lsq_solution(point)
# print("lsq", lsq_point)
start.append(lsq_point)
else:
start = None

Expand Down
3 changes: 2 additions & 1 deletion beat/sampler/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -359,7 +359,7 @@ def _iter_sample(

logger.debug("Step: Chain_%i step_%i" % (chain, i))
point, out_list = step.step(point)

# print("before buffer", out_list, point)
try:
trace.buffer_write(out_list, step.cumulative_samples)
except BufferError: # buffer full
Expand Down Expand Up @@ -608,6 +608,7 @@ def logp_forw(point, out_vars, in_vars, shared):
shared : List
containing :class:`pytensor.tensor.Tensor` for dependent shared data
"""
logger.debug("Compiling PyTensor function")
out_list, inarray0 = join_nonshared_inputs(point, out_vars, in_vars, shared)
f = compile_pymc([inarray0], out_list) # , on_unused_input="ignore")
f.trust_input = True
Expand Down
8 changes: 5 additions & 3 deletions beat/sampler/metropolis.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,10 @@ def __init__(
self.backend = backend

# initial point comes in reversed order for whatever reason
self.test_point = OrderedDict(reversed(list(model.initial_point().items())))
# rearrange to order of value_vars
init_point = model.initial_point()
self.test_point = {
val_var.name:init_point[val_var.name] for val_var in self.value_vars}

self.initialize_population(model)
self.compile_model_graph(model)
Expand All @@ -139,7 +142,6 @@ def initialize_population(self, model):
return_inferencedata=False,
)

# print(prior_draws)
self.array_population = num.zeros(self.n_chains)
self.population = []
for i in range(self.n_chains):
Expand Down Expand Up @@ -171,7 +173,7 @@ def compile_model_graph(self, model):
shared=shared,
)

self.prior_logp_func = logp_forw(
self.prior_logp_func = logp_forw(
point=self.test_point,
out_vars=[model.varlogp],
in_vars=self.value_vars, # logp of dists
Expand Down
5 changes: 1 addition & 4 deletions beat/sampler/smc.py
Original file line number Diff line number Diff line change
Expand Up @@ -509,10 +509,7 @@ def smc_sample(
step.beta = 1.0
save_sampler_state(step, update, stage_handler)

if stage == -1:
chains = []
else:
chains = None
chains = stage_handler.clean_directory(-1, chains, rm_flag)
else:
step.covariance = step.calc_covariance()
step.proposal_dist = choose_proposal(
Expand Down

0 comments on commit e30789b

Please sign in to comment.