-
Notifications
You must be signed in to change notification settings - Fork 7
/
predict.py
59 lines (52 loc) · 2.21 KB
/
predict.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
import json
from collections import Counter
from access_points import get_scanner
from get_data import get_train_data, get_external_sample, sample, aps_to_dict
from pipeline import get_model
from compat import cross_val_score
def predict_proba(input_path=None, model_path=None, device=""):
    """Print and return the class probabilities for a single sample.

    Args:
        input_path: Optional path passed to ``get_external_sample``. When
            ``None``, the sample is taken with ``sample(device)`` instead.
        model_path: Passed to ``get_model`` to load the classifier.
        device: Passed to ``sample`` when no ``input_path`` is given.

    Returns:
        dict mapping each entry of the model's ``classes_`` to the probability
        ``predict_proba`` assigns it for the first (only) row of the sample.
        The same mapping is printed as JSON to stdout.
    """
    lp = get_model(model_path)
    data_sample = sample(device) if input_path is None else get_external_sample(input_path)
    # predict_proba returns one row per sample; only the first row is used.
    probabilities = dict(zip(lp.classes_, lp.predict_proba(data_sample)[0]))
    print(json.dumps(probabilities))
    # Returning the mapping (previously None) makes this usable
    # programmatically, matching predict() which returns its result.
    return probabilities
def predict(input_path=None, model_path=None, device=""):
    """Return the predicted label for a single sample.

    The classifier is loaded with ``get_model(model_path)``. The sample is read
    with ``get_external_sample(input_path)`` when ``input_path`` is given,
    otherwise it is taken with ``sample(device)``. Only the first prediction
    is returned.
    """
    model = get_model(model_path)
    if input_path is not None:
        features = get_external_sample(input_path)
    else:
        features = sample(device)
    predictions = model.predict(features)
    return predictions[0]
def crossval(clf=None, X=None, y=None, folds=10, n=5, path=None):
    """Estimate model accuracy by repeating k-fold cross-validation ``n`` times.

    Args:
        clf: Estimator to evaluate. When ``None``, it is loaded with
            ``get_model(path)``.
        X: Training features. When ``X`` or ``y`` is ``None``, both are loaded
            with ``get_train_data(path)``.
        y: Training labels (see ``X``).
        folds: Number of folds passed as ``cv`` to ``cross_val_score``.
        n: Number of repetitions; must be at least 1.
        path: Passed to ``get_train_data`` / ``get_model`` when data or the
            model have to be loaded.

    Returns:
        The average over the ``n`` runs of each run's mean fold score. Each
        run's score and the overall average are also printed.

    Raises:
        ValueError: If ``n`` is less than 1, or if there are fewer samples
            than ``folds``.
    """
    if n < 1:
        # Without this guard the final average divides by zero (n == 0) or
        # silently reports -0.0 for a negative n.
        raise ValueError('n must be at least 1, got {}.'.format(n))
    if X is None or y is None:
        X, y = get_train_data(path)
    if len(X) < folds:
        raise ValueError('There are not enough samples ({}). Need at least {}.'.format(len(X), folds))
    # Compare against None instead of relying on truthiness: some
    # scikit-learn estimators define __len__ (e.g. ensembles), so `clf or ...`
    # could evaluate it -- and fail -- on an unfitted model.
    if clf is None:
        clf = get_model(path)
    tot = 0
    print("KFold folds={}, running {} times".format(folds, n))
    # NOTE(review): with an integer cv, scikit-learn uses unshuffled folds, so
    # all n runs may produce the same score unless compat.cross_val_score
    # shuffles -- confirm.
    for i in range(n):
        res = cross_val_score(clf, X, y, cv=folds).mean()
        tot += res
        print("{}/{}: {}".format(i + 1, n, res))
    average = tot / n
    print("-------- total --------")
    print(average)
    return average
def locations(path=None):
    """Print how many training samples exist for each learned location.

    Labels are read with ``get_train_data(path)``. If there are none, a hint
    on how to learn a location is printed instead. Returns ``None``.
    """
    _, labels = get_train_data(path)
    if not len(labels):
        print("No location samples available. First learn a location, e.g. with `cli.py learn -l floor number/name`.")
        return
    for location, count in Counter(labels).items():
        print("{}: {}".format(location, count))
class Predicter:
    """Keeps a loaded classifier and an access-point scanner for repeated predictions.

    ``model`` is passed to ``get_model`` and ``device`` to ``get_scanner``.
    Both are stored so that ``refresh`` can rebuild the classifier and scanner.
    """

    def __init__(self, model=None, device=""):
        self.model = model
        self.device = device
        self.clf = get_model(model)
        self.wifi_scanner = get_scanner(device)
        # Label from the most recent predict() call; None until one is made.
        self.predicted_value = None

    def predict(self):
        """Scan access points, predict a label, remember it and return it."""
        access_points = self.wifi_scanner.get_access_points()
        features = aps_to_dict(access_points)
        label = self.clf.predict(features)[0]
        self.predicted_value = label
        return label

    def refresh(self):
        """Reload the classifier and recreate the scanner from the stored settings."""
        self.clf = get_model(self.model)
        self.wifi_scanner = get_scanner(self.device)