Skip to content

Commit

Permalink
WIP: debugging ffi mode
Browse files Browse the repository at this point in the history
  • Loading branch information
hvasbath committed Jan 26, 2024
1 parent 7c754d1 commit da2dba4
Show file tree
Hide file tree
Showing 7 changed files with 20 additions and 101 deletions.
5 changes: 3 additions & 2 deletions beat/apps/beat.py
Original file line number Diff line number Diff line change
Expand Up @@ -677,7 +677,7 @@ def setup(parser):

reference_sources = bconfig.init_reference_sources(
source_points,
n_sources,
n_sources[0],
c.problem_config.source_types[0],
c.problem_config.stf_type,
event=c.event,
Expand All @@ -692,7 +692,7 @@ def setup(parser):
new_bounds = {}
for param in ["time"]:
new_bounds[param] = extract_bounds_from_summary(
summarydf, varname=param, shape=(n_sources,), roundto=0
summarydf, varname=param, shape=(n_sources[0],), roundto=0
)
new_bounds[param].append(point[param])

Expand Down Expand Up @@ -1599,6 +1599,7 @@ def setup(parser):

targets = heart.init_geodetic_targets(
datasets,
event=c.event,
earth_model_name=gf.earth_model_name,
interpolation=c.geodetic_config.interpolation,
crust_inds=[crust_ind],
Expand Down
6 changes: 3 additions & 3 deletions beat/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -2206,14 +2206,14 @@ def init_dataset_config(config, datatype, mode):
' "geometry" mode: "%s"!' % (source_types[0], geometry_source_type)
)

n_sources = gmc.problem_config.n_sources[0]
n_sources = gmc.problem_config.n_sources
point = {k: v.testvalue for k, v in gmc.problem_config.priors.items()}
point = utility.adjust_point_units(point)
source_points = utility.split_point(point, n_sources_total=n_sources)
source_points = utility.split_point(point, n_sources_total=n_sources[0])

reference_sources = init_reference_sources(
source_points,
n_sources,
n_sources[0],
geometry_source_type,
gmc.problem_config.stf_type,
event=gmc.event,
Expand Down
3 changes: 2 additions & 1 deletion beat/ffi/fault.py
Original file line number Diff line number Diff line change
Expand Up @@ -681,7 +681,8 @@ def point2sources(self, point, events=[]):

slips = self.get_total_slip(index, point)
rakes = num.arctan2(-ucomps["uperp"], ucomps["uparr"]) * r2d + sf.rake
opening_fractions = ucomps["utens"] / slips
opening_fractions = num.divide(
ucomps["utens"], slips, out=np.zeros_like(slips), where=slips!=0)

sf_point = {
"slip": slips,
Expand Down
1 change: 1 addition & 0 deletions beat/models/seismic.py
Original file line number Diff line number Diff line change
Expand Up @@ -693,6 +693,7 @@ def point2sources(self, point):
tpoint,
mapping=self.mapping,
n_sources_total=self.n_sources_total,
weed_params=True,
)

for i, source in enumerate(self.sources):
Expand Down
2 changes: 1 addition & 1 deletion beat/plotting/ffi.py
Original file line number Diff line number Diff line change
Expand Up @@ -839,6 +839,7 @@ def draw_3d_slip_distribution(problem, po):
mtrace = None

datatype, cconf = list(problem.composites.items())[0]
fault = cconf.load_fault_geometry()

if po.plot_projection in ["local", "latlon"]:
perspective = "135/30"
Expand All @@ -849,7 +850,6 @@ def draw_3d_slip_distribution(problem, po):
if gc:
for corr in gc.corrections_config.euler_poles:
if corr.enabled:
fault = cconf.load_fault_geometry()
if len(po.varnames) > 0 and po.varnames[0] in varname_choices:
from beat.ffi import euler_pole2slips

Expand Down
94 changes: 3 additions & 91 deletions beat/pytensorf.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,8 @@ def perform(self, node, inputs, output):
mpoint,
mapping=self.mapping,
n_sources_total=self.n_sources_total,
)
weed_params=True,
)

for i, source in enumerate(self.sources):
utility.update_source(source, **source_points[i])
Expand Down Expand Up @@ -222,96 +223,6 @@ def infer_shape(self, node, input_shapes):
return [(len(self.lats), 3)]


class GeoInterseismicSynthesizer(tt.Op):
"""
DEPRECATED!
pytensor wrapper to transform the parameters of block model to
parameters of a fault.
"""

__props__ = ("lats", "lons", "engine", "targets", "sources", "reference")

def __init__(self, lats, lons, engine, targets, sources, reference):
self.lats = tuple(lats)
self.lons = tuple(lons)
self.engine = engine
self.targets = tuple(targets)
self.sources = tuple(sources)
self.reference = reference

def __getstate__(self):
return self.__dict__

def __setstate__(self, state):
self.__dict__.update(state)

def make_node(self, inputs):
"""
Transforms pytensor tensors to node and allocates variables accordingly.
Parameters
----------
inputs : dict
keys being strings of source attributes of the
:class:`pyrocko.gf.seismosizer.RectangularSource` that was used
to initialise the Operator.
values are :class:`pytensor.tensor.Tensor`
"""
inlist = []

self.fixed_values = {}
self.varnames = []

for k, v in inputs.items():
if isinstance(v, tt.TensorVariable):
self.varnames.append(k)
inlist.append(tt.as_tensor_variable(v))
else:
self.fixed_values[k] = v

out = tt.as_tensor_variable(num.zeros((2, 2)))
outlist = [out.type()]
return Apply(self, inlist, outlist)

def perform(self, node, inputs, output):
"""
Perform method of the Operator to calculate synthetic displacements.
Parameters
----------
inputs : list
of :class:`numpy.ndarray`
output : list
of synthetic displacements of :class:`numpy.ndarray` (n x 3)
"""
z = output[0]

point = {vname: i for vname, i in zip(self.varnames, inputs)}
point.update(self.fixed_values)

point = utility.adjust_point_units(point)
spoint, bpoint = interseismic.seperate_point(point)

source_points = utility.split_point(spoint)

for i, source_point in enumerate(source_points):
self.sources[i].update(**source_point)

z[0] = interseismic.geo_backslip_synthetics(
engine=self.engine,
targets=self.targets,
sources=self.sources,
lons=num.array(self.lons),
lats=num.array(self.lats),
reference=self.reference,
**bpoint,
)

def infer_shape(self, node, input_shapes):
return [(len(self.lats), 3)]


class SeisSynthesizer(tt.Op):
"""
pytensor wrapper for a seismic forward model with synthetic waveforms.
Expand Down Expand Up @@ -449,6 +360,7 @@ def perform(self, node, inputs, output):
mpoint,
mapping=self.mapping,
n_sources_total=self.n_sources_total,
weed_params=True,
)

for i, source in enumerate(self.sources):
Expand Down
10 changes: 7 additions & 3 deletions beat/utility.py
Original file line number Diff line number Diff line change
Expand Up @@ -413,7 +413,8 @@ def weed_input_rvs(input_rvs, mode, datatype):
"duration",
"peak_ratio",
] + burian

elif mode == "ffi":
tobeweeded = []
else:
raise TypeError(f"Mode {mode} not supported!")

Expand Down Expand Up @@ -671,7 +672,7 @@ def adjust_point_units(point):
return mpoint


def split_point(point, mapping=None, n_sources_total=1, weed_params=True):
def split_point(point, mapping=None, n_sources_total=1, weed_params=False):
"""
Split point in solution space into List of dictionaries with source
parameters for each source.
Expand All @@ -692,8 +693,11 @@ def split_point(point, mapping=None, n_sources_total=1, weed_params=True):
source_points : list
of :func:`pymc.model.Point`
"""
if mapping is not None:
point_to_sources = mapping.point_to_sources_mapping()
else:
point_to_sources = None

point_to_sources = mapping.point_to_sources_mapping()
if weed_params:
source_parameter_names = mapping.point_variable_names()
for param in list(point.keys()):
Expand Down

0 comments on commit da2dba4

Please sign in to comment.