Basic prediction webserver
This commit is contained in:
parent
40c116b5cf
commit
e40dada301
36
app.py
Normal file
36
app.py
Normal file
@ -0,0 +1,36 @@
|
||||
from bottle import hook, request, response, route, run
|
||||
from sklearn.externals import joblib
|
||||
|
||||
mlb, classifier = joblib.load('offClassifier.pkl')
|
||||
|
||||
|
||||
@hook('after_request')
|
||||
def enable_cors():
|
||||
"""
|
||||
You need to add some headers to each request.
|
||||
Don't use the wildcard '*' for Access-Control-Allow-Origin in production.
|
||||
"""
|
||||
response.headers['Access-Control-Allow-Origin'] = '*'
|
||||
response.headers['Access-Control-Allow-Methods'] = 'PUT, GET, POST, DELETE, OPTIONS'
|
||||
response.headers['Access-Control-Allow-Headers'] = 'Origin, Accept, Content-Type, X-Requested-With, X-CSRF-Token'
|
||||
|
||||
|
||||
@route('/predict', method=['OPTIONS', 'POST'])
|
||||
def predict():
|
||||
if request.method == 'OPTIONS':
|
||||
return {}
|
||||
|
||||
products = request.json
|
||||
predictions = mlb.inverse_transform(
|
||||
classifier.predict([p['name'] for p in products])
|
||||
)
|
||||
return {
|
||||
'data': [
|
||||
product.update({'predictedCategories': categories}) or product
|
||||
for product, categories in zip(products, predictions)
|
||||
]
|
||||
}
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
run(host='localhost', port=4242)
|
2254
notebook.ipynb
2254
notebook.ipynb
File diff suppressed because it is too large
Load Diff
Loading…
Reference in New Issue
Block a user