Skip to content

Commit

Permalink
Update score components script
Browse files Browse the repository at this point in the history
  • Loading branch information
omrazCZ authored Apr 13, 2023
1 parent 6d3c448 commit d4be993
Showing 1 changed file with 23 additions and 11 deletions.
34 changes: 23 additions & 11 deletions demos/score_components.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import matplotlib.gridspec as gridspec
import pandas as pd
import matplotlib.patches as mpatches
import string

os.chdir('C://Users//Oto//Documents//GitHub//drought_impact_forecasting')

Expand All @@ -17,8 +18,8 @@

ALL['epoch'] = ALL['epoch']+1

ALL.columns = ['epoch', 'zero mad', 'last-frame mad', 'zero ols', 'last-frame ols',
'zero emd', 'last-frame emd', 'zero ssim', 'last-frame ssim']
ALL.columns = ['epoch', 'zero mad', 'last-frame mad', 'SGEDConvLSTM mad', 'zero ols', 'last-frame ols', 'SGEDConvLSTM ols',
'zero emd', 'last-frame emd', 'SGEDConvLSTM emd', 'zero ssim', 'last-frame ssim', 'SGEDConvLSTM ssim']

#fig = plt.figure(figsize=(6, 3))
#fig, ax = plt.subplots()
Expand All @@ -31,8 +32,10 @@
ax1 = plt.subplot(gs[0, 0:2])
ax1.plot(ALL['epoch'], ALL['zero mad'], color='r', linestyle='-', label='no baseline')
ax1.plot(ALL['epoch'], ALL['last-frame mad'], color='b', linestyle='-', label='baseline')
ax1.legend(loc='lower right')
ax1.plot(ALL['epoch'], ALL['SGEDConvLSTM mad'], color='b', linestyle='--', label='SGEDConvLSTM')
#ax1.legend(loc='lower right')
ax1.grid()
ax1.set_title('(a)', loc='left')

ax1.set_xlabel('Epoch')
ax1.set_ylabel('MAD')
Expand All @@ -43,8 +46,10 @@
ax2 = plt.subplot(gs[0, 2:])
ax2.plot(ALL['epoch'], ALL['zero ols'], color='r', linestyle='-', label='no baseline')
ax2.plot(ALL['epoch'], ALL['last-frame ols'], color='b', linestyle='-', label='baseline')
ax2.legend(loc='lower right')
ax2.plot(ALL['epoch'], ALL['SGEDConvLSTM ols'], color='b', linestyle='--', label='SGEDConvLSTM')
#ax2.legend(loc='lower right')
ax2.grid()
ax2.set_title('(b)', loc='left')

ax2.set_xlabel('Epoch')
ax2.set_ylabel('OLS')
Expand All @@ -55,8 +60,10 @@
ax3 = plt.subplot(gs[1, 0:2])
ax3.plot(ALL['epoch'], ALL['zero emd'], color='r', linestyle='-', label='no baseline')
ax3.plot(ALL['epoch'], ALL['last-frame emd'], color='b', linestyle='-', label='baseline')
ax3.legend(loc='lower right')
ax3.plot(ALL['epoch'], ALL['SGEDConvLSTM emd'], color='b', linestyle='--', label='SGEDConvLSTM')
#ax3.legend(loc='lower right')
ax3.grid()
ax3.set_title('(c)', loc='left')

ax3.set_xlabel('Epoch')
ax3.set_ylabel('EMD')
Expand All @@ -67,16 +74,18 @@
ax4 = plt.subplot(gs[1,2:])
ax4.plot(ALL['epoch'], ALL['zero ssim'], color='r', linestyle='-', label='no baseline')
ax4.plot(ALL['epoch'], ALL['last-frame ssim'], color='b', linestyle='-', label='baseline')
ax4.legend(loc='lower right')
ax4.plot(ALL['epoch'], ALL['SGEDConvLSTM ssim'], color='b', linestyle='--', label='SGEDConvLSTM')
#ax4.legend(loc='lower right')
ax4.grid()
ax4.set_title('(d)', loc='left')

ax4.set_xlabel('Epoch')
ax4.set_ylabel('SSIM')

ax4.set_xlim(0, 35)

# ENS
ax5 = plt.subplot(gs[2,1:3])
ax5 = plt.subplot(gs[2,0:4])
file = 'Data/lastframe_vs_zero.csv'
data = pd.read_csv(os.path.join(os.getcwd(), file), delimiter=',', index_col=0)
final_epoch = 35
Expand All @@ -87,25 +96,28 @@

#ax_last = plt.subplot2grid((4,8), (4//2, 2), colspan=4)

ax5.plot(data['zero'], label='no baseline', color='r')
ax5.plot(data['last_frame'], label='baseline', color='b')
ax5.plot(data['zero'], label='SGConvLSTM, no baseline', color='r')
ax5.plot(data['last_frame'], label='SGConvLSTM, with baseline', color='b')
ax5.plot(data['SGEDConvLSTM'], color='b', linestyle='--', label='SGEDConvLSTM, with baseline')
ax5.plot(35, 0.2902, 'o', color='lime', label='U-Net')
ax5.plot(35, 0.2803, 'o', color='c', label='Arcon')

ax5.grid(True)
ax5.set_title('(e)', loc='left')

ax5.set_ylabel("ENS")
ax5.set_xlabel("Epoch")

ax5.set_xlim(0, 36)
ax5.legend(loc='lower center')
ax5.legend(loc='lower center', ncol=2)

ax5.plot(list(range(1, final_epoch+1)), (final_epoch) * [0.31], '--', color='gray')

plt.show()

gs.tight_layout(fig)

outputFile = 'ndvi_separated_components'
outputFile = 'ens_separated_components'
plt.savefig(outputFile+'.jpg', dpi=2000)
plt.savefig(outputFile+'.pdf')

Expand Down

0 comments on commit d4be993

Please sign in to comment.