-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
add file monitoring with marshmallow/swagger
- Loading branch information
Showing
3 changed files
with
186 additions
and
46 deletions.
There are no files selected for viewing
89 changes: 62 additions & 27 deletions
89
regression_problems/predict_real_estate_prices_in_california/flask/main.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,49 +1,84 @@ | ||
from flask import Flask, request, jsonify, render_template | ||
from flask import Flask, request, jsonify, render_template, send_from_directory | ||
import pickle | ||
import pandas as pd | ||
from marshmallow import Schema, fields, ValidationError | ||
from flask_swagger_ui import get_swaggerui_blueprint | ||
from flask_wtf.csrf import CSRFProtect | ||
from flask_limiter import Limiter | ||
from flask_limiter.util import get_remote_address | ||
|
||
app = Flask(__name__) | ||
|
||
# Configuração da proteção CSRF | ||
csrf = CSRFProtect(app) | ||
|
||
# Configuração de rate limiting | ||
limiter = Limiter( | ||
get_remote_address, | ||
app=app, | ||
default_limits=["200 per day", "50 per hour"] | ||
) | ||
|
||
# Swagger setup http://localhost:5000/swagger | ||
SWAGGER_URL = '/swagger' | ||
API_URL = '/static_swagger/swagger.json' | ||
swaggerui_blueprint = get_swaggerui_blueprint(SWAGGER_URL, API_URL, config={'app_name': "Real Estate Price Prediction API"}) | ||
app.register_blueprint(swaggerui_blueprint, url_prefix=SWAGGER_URL) | ||
|
||
# Adicionar uma rota para servir o arquivo swagger.json | ||
@app.route('/static_swagger/<path:filename>') | ||
def serve_static(filename): | ||
return send_from_directory('static_swagger', filename) | ||
|
||
# Carregar o modelo treinado | ||
model = pickle.load(open('../model_trained/model_regression_immobile.pkl', 'rb')) | ||
|
||
# Carregar a base de dados | ||
#data = pd.read_csv('../data/arquivo.csv') | ||
# Definição do schema de validação usando marshmallow | ||
class InputSchema(Schema): | ||
MedInc = fields.Float(required=True) | ||
HouseAge = fields.Float(required=True) | ||
AveRooms = fields.Float(required=True) | ||
AveBedrms = fields.Float(required=True) | ||
Population = fields.Int(required=True) | ||
AveOccup = fields.Float(required=True) | ||
RoomDensity = fields.Float(required=True) | ||
Latitude = fields.Float(required=True) | ||
Longitude = fields.Float(required=True) | ||
|
||
@app.route('/') | ||
def index(): | ||
return render_template('index.html') | ||
|
||
@app.route('/predict', methods=['POST']) | ||
@limiter.limit("10 per minute") # Limite de 10 requisições por minuto para este endpoint | ||
@csrf.exempt # Exemplo: desabilitar CSRF para a rota de API (opcional, com cuidado) | ||
def predict(): | ||
# Receber dados do cliente | ||
req_data = request.get_json() | ||
|
||
# Verificar se todos os campos necessários estão presentes | ||
required_fields = ['MedInc', 'HouseAge', 'AveRooms', 'AveBedrms', 'Population', 'AveOccup', 'Latitude', 'Longitude'] | ||
if not all(field in req_data for field in required_fields): | ||
return jsonify({'error': 'Missing required fields'}), 400 | ||
|
||
# Criar DataFrame a partir dos dados recebidos | ||
input_data = { | ||
'MedInc': req_data['MedInc'], | ||
'HouseAge': req_data['HouseAge'], | ||
'AveRooms': req_data['AveRooms'], | ||
'AveBedrms': req_data['AveBedrms'], | ||
'Population': req_data['Population'], | ||
'AveOccup': req_data['AveOccup'], | ||
'Latitude': req_data['Latitude'], | ||
'Longitude': req_data['Longitude'] | ||
} | ||
|
||
input_df = pd.DataFrame([input_data]) | ||
|
||
try: | ||
# Receber dados do cliente | ||
req_data = request.get_json() | ||
|
||
# Validar entrada | ||
schema = InputSchema() | ||
validated_data = schema.load(req_data) | ||
|
||
# Criar DataFrame a partir dos dados validados | ||
input_df = pd.DataFrame([validated_data]) | ||
|
||
# Prever usando o modelo carregado | ||
prediction = model.predict(input_df) | ||
return jsonify({'prediction': prediction[0]}) | ||
except ValidationError as err: | ||
# Erros de validação do input | ||
return jsonify({'error': 'Validation Error', 'messages': err.messages}), 400 | ||
except KeyError as err: | ||
# Se algum campo esperado não estiver presente | ||
return jsonify({'error': f'Missing required field: {err}'}), 400 | ||
except ValueError as err: | ||
# Erros relacionados ao valor dos dados de entrada | ||
return jsonify({'error': f'Invalid input: {err}'}), 400 | ||
except Exception as e: | ||
return jsonify({'error': str(e)}), 400 | ||
# Erro geral | ||
return jsonify({'error': 'An error occurred', 'message': str(e)}), 500 | ||
|
||
if __name__ == "__main__": | ||
app.run(host='0.0.0.0', port=5000) | ||
app.run(host='0.0.0.0', port=5000) |
82 changes: 82 additions & 0 deletions
82
...ssion_problems/predict_real_estate_prices_in_california/flask/static_swagger/swagger.json
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,82 @@ | ||
{ | ||
"swagger": "2.0", | ||
"info": { | ||
"description": "API para prever preços de imóveis", | ||
"version": "1.0.0", | ||
"title": "Real Estate Price Prediction API" | ||
}, | ||
"host": "localhost:5000", | ||
"basePath": "/", | ||
"schemes": [ | ||
"http" | ||
], | ||
"paths": { | ||
"/predict": { | ||
"post": { | ||
"tags": ["prediction"], | ||
"summary": "Obter previsão de preço", | ||
"description": "Envia dados de entrada e obtém a previsão de preço do imóvel.", | ||
"consumes": ["application/json"], | ||
"produces": ["application/json"], | ||
"parameters": [ | ||
{ | ||
"in": "body", | ||
"name": "body", | ||
"description": "Dados de entrada", | ||
"required": true, | ||
"schema": { | ||
"$ref": "#/definitions/InputData" | ||
} | ||
} | ||
], | ||
"responses": { | ||
"200": { | ||
"description": "Previsão obtida com sucesso", | ||
"schema": { | ||
"type": "object", | ||
"properties": { | ||
"prediction": { | ||
"type": "number", | ||
"example": 150000.0 | ||
} | ||
} | ||
} | ||
}, | ||
"400": { | ||
"description": "Erro de validação de entrada" | ||
}, | ||
"500": { | ||
"description": "Erro interno do servidor" | ||
} | ||
} | ||
} | ||
} | ||
}, | ||
"definitions": { | ||
"InputData": { | ||
"type": "object", | ||
"required": [ | ||
"MedInc", | ||
"HouseAge", | ||
"AveRooms", | ||
"AveBedrms", | ||
"Population", | ||
"AveOccup", | ||
"RoomDensity", | ||
"Latitude", | ||
"Longitude" | ||
], | ||
"properties": { | ||
"MedInc": {"type": "number", "example": 8.3252}, | ||
"HouseAge": {"type": "number", "example": 41.0}, | ||
"AveRooms": {"type": "number", "example": 6.9841}, | ||
"AveBedrms": {"type": "number", "example": 1.0238}, | ||
"Population": {"type": "integer", "example": 322}, | ||
"AveOccup": {"type": "number", "example": 2.5556}, | ||
"Latitude": {"type": "number", "example": 37.88}, | ||
"Longitude": {"type": "number", "example": -122.23}, | ||
"RoomDensity": {"type": "number", "example": -122.23} | ||
} | ||
} | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters