Skip to content

Commit

Permalink
Add prophet to model selection for prediction
Browse files — browse the repository at this point in the history
  • Loading branch information
mariamills committed Nov 14, 2023
1 parent 06bbbad commit 536ccce
Show file tree
Hide file tree
Showing 8 changed files with 549 additions and 50 deletions.
127 changes: 81 additions & 46 deletions app.py
Original file line number | Diff line number | Diff line change
@@ -24,6 +24,11 @@
'data': 'data/arima/diff_oil_prices_arima.csv',
'is_time_series': True
},
'prophet': {
'model': joblib.load('models/prophet/prophet_model.pkl'),
'data': 'data/prophet/prophet_df.csv',
'is_time_series': True
}
}


@@ -97,7 +102,7 @@ def plot():
if df.isnull().values.any() or df.isin([np.inf, -np.inf]).values.any():
return jsonify({"error": "DataFrame contains NaN or Inf values"}), 400

target_name = 'Real Oil Prices' or 'Cushing, OK WTI Spot Price FOB (Dollars per Barrel)'
target_name = 'Real Oil Prices' or 'Cushing, OK WTI Spot Price FOB (Dollars per Barrel)' or 'y'
if target_name not in df.columns:
return jsonify({"error": "'Real Oil Prices' not found in DataFrame"}), 400

@@ -216,9 +221,10 @@ def evaluate():

print(f"Evaluating Model type: {model_type}, data: {data_path}")

target_column = 'Real Oil Prices' if 'Real Oil Prices' in data.columns else 'Cushing, OK WTI Spot Price FOB (Dollars per Barrel)'
if target_column not in data.columns:
raise ValueError(f"'{target_column}' not found in DataFrame")
target_column_options = ['Real Oil Prices', 'Cushing, OK WTI Spot Price FOB (Dollars per Barrel)', 'y']
target_column = next((col for col in target_column_options if col in data.columns), None)
if target_column is None:
raise ValueError(f"None of the target columns {target_column_options} were found in DataFrame")

# Check if the model is a time series model
if model_data[model_type]['is_time_series']:
@@ -238,49 +244,78 @@ def evaluate():
# This could be the length of a test set, or a specified future period
num_steps = 30 # 30 days

# Make predictions using get_forecast
forecast = model.get_forecast(steps=num_steps)
print(f"Forecast: {forecast}")

forecast_mean = forecast.predicted_mean
print(f"Forecast Mean: {forecast_mean}")

prediction = forecast_mean
print(f"Prediction: {prediction}")

# Calculate metrics - can't give metrics for the future
mse, mae, rmse, mape, r2 = [None] * 5
print(f"Metrics: mse [Time Series]={mse}, mae={mae}, rmse={rmse}, mape={mape}, r2={r2}")

# Generate and save the future forecast plot
plt.figure(figsize=(12, 6))
historical_data = pd.read_csv('data/Combined_Log_Clean.csv')[target_column]
future_forecast_mean = forecast.predicted_mean
future_conf_int = forecast.conf_int() # Assuming this method gives you confidence intervals

plt.plot(historical_data.index, historical_data, label='Historical')
plt.plot(np.arange(len(historical_data), len(historical_data) + len(future_forecast_mean)),
future_forecast_mean, color='green', label='Future Forecast')
plt.fill_between(np.arange(len(historical_data), len(historical_data) + len(future_forecast_mean)),
future_conf_int.iloc[:, 0], future_conf_int.iloc[:, 1], color='lightgreen', alpha=0.5,
label='95% Confidence Interval')
plt.xlabel('Time')
plt.ylabel('Real Oil Prices')
plt.title('Future Forecast with Confidence Intervals')
plt.legend()
future_forecast_plot_file = 'static/plots/future-forecast-plot.png'
plt.savefig(future_forecast_plot_file)
plt.close()
if model_type == 'arima':
# Make predictions using get_forecast
forecast = model.get_forecast(steps=num_steps)

# Generate and save the future forecast plot
plt.figure(figsize=(12, 6))
historical_data = pd.read_csv('data/Combined_Log_Clean.csv')[target_column]
future_forecast_mean = forecast.predicted_mean
future_conf_int = forecast.conf_int()

plt.plot(historical_data.index, historical_data, label='Historical')
future_index = np.arange(len(historical_data), len(historical_data) + len(future_forecast_mean))
plt.plot(future_index, future_forecast_mean, color='green', label='Future Forecast')
plt.fill_between(future_index, future_conf_int.iloc[:, 0], future_conf_int.iloc[:, 1],
color='lightgreen', alpha=0.5, label='95% Confidence Interval')
plt.xlabel('Time')
plt.ylabel('Real Oil Prices')
plt.title('Future Forecast with Confidence Intervals')
plt.legend()
future_forecast_plot_file = 'static/plots/arima-future-forecast-plot.png'
plt.savefig(future_forecast_plot_file)
plt.close()

# Calculate metrics - can't give metrics for the future
mse, mae, rmse, mape, r2 = [None] * 5

result = {
'mse': mse,
'mae': mae,
'rmse': rmse,
'mape': mape,
'r2': r2,
'future_forecast_plot': url_for('static', filename='plots/arima-future-forecast-plot.png'),
'prediction': future_forecast_mean.tolist(),
}
elif model_type == 'prophet':
print("hello")
# Create a DataFrame with future dates for forecasting
future = model.make_future_dataframe(periods=num_steps)
forecast = model.predict(future)

# Plotting code for Prophet
fig, ax = plt.subplots(figsize=(10, 6))
prophet_df = pd.read_csv('data/prophet/prophet_df.csv')
prophet_df['ds'] = pd.to_datetime(prophet_df['ds']).dt.tz_localize(None)
ax.plot(prophet_df['ds'], prophet_df['y'], 'k.', label='Historical Data')
ax.plot(forecast['ds'], forecast['yhat'], ls='-', color='blue', label='Forecast')
ax.fill_between(forecast['ds'], forecast['yhat_lower'], forecast['yhat_upper'], color='blue', alpha=0.2,
label='Uncertainty Interval')
ax.axvline(x=prophet_df['ds'].iloc[-1], color='red', linestyle='--', lw=1, label='Start of Forecast')
ax.set_xlabel('Date')
ax.set_ylabel('Oil Prices')
ax.set_title(f"{num_steps}-Day Forecast with Prophet")
ax.legend()
future_forecast_plot_file = 'static/plots/prophet-future-forecast-plot.png'
plt.savefig(future_forecast_plot_file)
plt.show()

# Calculate metrics - can't give metrics for the future
mse, mae, rmse, mape, r2 = [None] * 5
print(f"Metrics: mse [Time Series]={mse}, mae={mae}, rmse={rmse}, mape={mape}, r2={r2}")

result = {
'mse': mse,
'mae': mae,
'rmse': rmse,
'mape': mape,
'r2': r2,
'future_forecast_plot': url_for('static', filename='plots/prophet-future-forecast-plot.png'),
'prediction': forecast['yhat'].tolist(), # Keeping the raw prediction values in case they are needed,
}

result = {
'mse': mse,
'mae': mae,
'rmse': rmse,
'mape': mape,
'r2': r2,
'future_forecast_plot': url_for('static', filename='plots/future-forecast-plot.png'),
'prediction': prediction.tolist(), # Keeping the raw prediction values in case they are needed,
}
else:
request_data = request.get_json()
selected_features = request_data.get('selected_features')
Expand Down
Loading

0 comments on commit 536ccce

Please sign in to comment.