forked from nasa/prog_algs
-
Notifications
You must be signed in to change notification settings - Fork 0
/
predictor_template.py
75 lines (61 loc) · 3.4 KB
/
predictor_template.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
# Copyright © 2021 United States Government as represented by the Administrator of the National Aeronautics and Space Administration. All Rights Reserved.
from copy import deepcopy
from prog_algs.predictors import Predictor, Prediction
# Replace the following with whatever form of UncertainData you would like to use to represent ToE
from prog_algs.uncertain_data import ScalarData
class TemplatePredictor(Predictor):
    """
    Template class for performing model-based prediction.

    Copy this class, rename it, and replace the sections marked in capitals
    (e.g., "PERFORM PREDICTION HERE") with the logic of your predictor.
    """

    # REPLACE THE FOLLOWING LIST WITH CONFIGURED PARAMETERS
    # Default parameters for this predictor. NOTE(review): presumably these are
    # copied into self.parameters by the Predictor base class — confirm there.
    # Keyword arguments passed to predict() override self.parameters for that
    # call only (see predict below).
    default_parameters = {
        'Example Parameter': 0.0
    }

    def __init__(self, model, **kwargs):
        """
        Constructor (optional).

        Parameters
        ----------
        model
            Model used for prediction; passed unchanged to Predictor.__init__.
        **kwargs
            Additional configuration; passed unchanged to Predictor.__init__.
        """
        super().__init__(model, **kwargs)

        # ADD PARAMETER CHECKS HERE
        # e.g., self.parameters['some_value'] < 0

        # INITIALIZE PREDICTOR

    def predict(self, state, future_loading_eqn, **kwargs):
        """
        Perform a single prediction.

        Parameters
        ----------
        state : UncertainData
            Estimate of the state at the time of prediction, represented by UncertainData
        future_loading_eqn : function (t, x) -> z
            Function to generate an estimate of loading at future time t and state x
        **kwargs : optional
            Any additional configuration values. They override the predictor's
            parameters (see default_parameters, above) for this call only.

        Returns
        -------
        tuple of (times, inputs, states, outputs, event_states, time_of_event)

        times : List[float]
            Times for each savepoint such that inputs.snapshot(i), states.snapshot(i), outputs.snapshot(i), and event_states.snapshot(i) are all at times[i]
        inputs : Prediction
            Inputs at each savepoint such that inputs.snapshot(i) is the input distribution (type UncertainData) at times[i]
        states : Prediction
            States at each savepoint such that states.snapshot(i) is the state distribution (type UncertainData) at times[i]
        outputs : Prediction
            Outputs at each savepoint such that outputs.snapshot(i) is the output distribution (type UncertainData) at times[i]
        event_states : Prediction
            Event states at each savepoint such that event_states.snapshot(i) is the event state distribution (type UncertainData) at times[i]
        time_of_event : UncertainData
            Distribution of predicted Time of Event (ToE) for each predicted event, represented by some subclass of UncertainData (e.g., MultivariateNormalDist)
        """
        # Deep-copy so per-call overrides never mutate the stored configuration
        params = deepcopy(self.parameters)
        params.update(kwargs)

        # PERFORM PREDICTION HERE (using params), REPLACE THE FOLLOWING LISTS

        # Times of each savepoint (specified by savepts and save_freq)
        times = []  # array of float (e.g., [0.0, 0.5, 1.0, ...])

        # Inputs, State, Outputs, and Event States at each savepoint are stored by type Prediction
        # Replace [] with estimates of the appropriate property in the form of a subclass of UncertainData (e.g, ScalarData)
        inputs = Prediction(times, [])
        states = Prediction(times, [])
        outputs = Prediction(times, [])
        event_states = Prediction(times, [])

        # Time of event is represented by some type of UncertainData (e.g., MultivariateNormalDist)
        # Placeholder values below; replace with the real per-event ToE estimate.
        time_of_event = ScalarData({'event1': 748, 'event2': 300})

        return (times, inputs, states, outputs, event_states, time_of_event)