-
Notifications
You must be signed in to change notification settings - Fork 127
/
main.py
122 lines (91 loc) · 3.23 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
import sys
import os
import shutil
import time
import traceback
from flask import Flask, request, jsonify
import pandas as pd
from sklearn.externals import joblib
# The Flask application object; the route handlers below register on it
# through the @app.route decorator.
app = Flask(import_name=__name__)
# ---- inputs ---------------------------------------------------------------
# Training CSV read by the /train endpoint.
training_data = 'data/titanic.csv'

# Columns kept from the CSV. The final entry is the target label; every other
# column becomes a (possibly one-hot encoded) feature during training.
include = ['Age', 'Sex', 'Embarked', 'Survived']
dependent_variable = include[-1]

# Where the fitted model and its feature-column list are persisted.
model_directory = 'model'
model_file_name = model_directory + '/model.pkl'
model_columns_file_name = model_directory + '/model_columns.pkl'

# Runtime state: populated by /train, or loaded from the files above when the
# script starts.
model_columns = clf = None
@app.route('/predict', methods=['POST'])
def predict():
    """Return model predictions for the JSON records in the request body.

    The request JSON is turned into a ``pd.DataFrame``, one-hot encoded with
    ``pd.get_dummies`` and aligned to ``model_columns`` (the feature columns
    captured at training time). Training columns missing from the query are
    added as 0; query columns unknown to the model are dropped.

    Returns:
        ``{"prediction": [int, ...]}`` on success,
        ``{"error": ..., "trace": ...}`` if any step of the pipeline raises, or
        the plain string ``'no model here'`` when no model is available.
    """
    # Compare against None explicitly: truthiness of an sklearn ensemble goes
    # through BaseEnsemble.__len__, which raises on an unfitted estimator, and
    # that would happen outside the try below.
    if clf is None:
        print('train first')
        return 'no model here'
    try:
        json_ = request.json
        query = pd.get_dummies(pd.DataFrame(json_))
        # https://github.com/amirziai/sklearnflask/issues/3
        # Thanks to @lorenzori
        # After aligning to the training columns, fill any remaining NaNs
        # (numeric fields absent from some records) with 0. This mirrors the
        # fillna(0) that /train applies to its non-categorical columns.
        query = query.reindex(columns=model_columns, fill_value=0).fillna(0)
        # Converting to int from int64 so the values are JSON-serializable.
        prediction = [int(p) for p in clf.predict(query)]
        return jsonify({"prediction": prediction})
    except Exception as e:
        # NOTE(review): sending the full traceback to the client exposes
        # server internals; consider logging it server-side instead.
        return jsonify({'error': str(e), 'trace': traceback.format_exc()})
@app.route('/train', methods=['GET'])
def train():
    """Train a random forest on ``training_data`` and persist it.

    Steps:
    1. Read the CSV and keep the ``include`` columns.
    2. Fill missing values in non-string columns with 0.
    3. One-hot encode string-typed columns, including a NaN indicator column.
    4. Fit a ``RandomForestClassifier`` against ``dependent_variable``.
    5. Write the fitted model and its feature-column list to
       ``model_file_name`` / ``model_columns_file_name``.

    The module globals ``clf`` and ``model_columns`` are only replaced after
    fitting succeeds, so a failed run leaves the previously served model (if
    any) in place.

    Returns:
        A plain-text message with the fit time and training-set score.
    """
    global clf, model_columns
    # using random forest as an example
    # can do the training separately and just update the pickles
    from sklearn.ensemble import RandomForestClassifier as rf

    df = pd.read_csv(training_data)
    # .copy() makes df_ an independent frame. Without it, the column updates
    # below would be chained assignment on a slice of df (SettingWithCopy);
    # under pandas Copy-on-Write such updates would not reach df_ at all.
    df_ = df[include].copy()

    categoricals = []  # going to one-hot encode categorical variables
    for col, col_type in df_.dtypes.items():
        # object columns and pandas' dedicated string dtype are categorical
        if pd.api.types.is_string_dtype(col_type):
            categoricals.append(col)
        else:
            # fill NA's with 0 for ints/floats, too generic
            df_[col] = df_[col].fillna(0)

    # get_dummies effectively creates one-hot encoded variables
    df_ohe = pd.get_dummies(df_, columns=categoricals, dummy_na=True)
    x = df_ohe[df_ohe.columns.difference([dependent_variable])]
    y = df_ohe[dependent_variable]

    # Fit into a local first so a failure cannot leave an unfitted estimator
    # in the global that /predict serves.
    new_clf = rf()
    start = time.time()
    new_clf.fit(x, y)
    elapsed = time.time() - start

    # capture a list of columns that will be used for prediction, and persist
    # it together with the model only after fitting has succeeded
    new_columns = list(x.columns)
    os.makedirs(model_directory, exist_ok=True)
    joblib.dump(new_columns, model_columns_file_name)
    joblib.dump(new_clf, model_file_name)

    model_columns = new_columns
    clf = new_clf

    message1 = 'Trained in %.5f seconds' % elapsed
    message2 = 'Model training score: %s' % clf.score(x, y)
    return_message = 'Success. \n{0}. \n{1}.'.format(message1, message2)
    return return_message
@app.route('/wipe', methods=['GET'])
def wipe():
    """Delete the persisted model artifacts and forget the in-memory model.

    Removes ``model_directory`` (if it exists) and recreates it empty. It then
    resets ``clf`` and ``model_columns`` so ``/predict`` stops serving the
    wiped model.

    Returns:
        ``'Model wiped'`` on success, or an explanatory message if the
        directory could not be removed or recreated.
    """
    global clf, model_columns
    try:
        try:
            shutil.rmtree(model_directory)
        except FileNotFoundError:
            pass  # nothing persisted yet; still create the empty directory
        os.makedirs(model_directory)
    except OSError as e:
        print(str(e))
        return 'Could not remove and recreate the model directory'
    clf = None
    model_columns = None
    return 'Model wiped'
if __name__ == '__main__':
    # Optional first CLI argument is the port; fall back to 80 when it is
    # missing (IndexError) or not an integer (ValueError).
    try:
        port = int(sys.argv[1])
    except (IndexError, ValueError):
        port = 80

    # Load a previously trained model and its feature columns, if present.
    try:
        clf = joblib.load(model_file_name)
        print('model loaded')
        model_columns = joblib.load(model_columns_file_name)
        print('model columns loaded')
    except Exception as e:
        # Any load failure means "no model": /predict will ask for /train.
        print('No model here')
        print('Train first')
        print(str(e))
        clf = None
        model_columns = None

    # NOTE(review): debug=True enables the interactive Werkzeug debugger, and
    # host='0.0.0.0' exposes it on every network interface. Flask's docs warn
    # against running like this anywhere untrusted clients can reach.
    app.run(host='0.0.0.0', port=port, debug=True)