diff --git a/environment.yml b/environment.yml index 8865144..af9d24d 100644 --- a/environment.yml +++ b/environment.yml @@ -5,7 +5,7 @@ dependencies: - python=>3.9 - xarray - matplotlib>=3.1 - - numpy>=1.17 + - numpy>=2.0 - pandas>=2.2.1 - pip - scipy diff --git a/examples/example_report.py b/examples/example_report_AI.py similarity index 87% rename from examples/example_report.py rename to examples/example_report_AI.py index 5dc7101..8d5493f 100644 --- a/examples/example_report.py +++ b/examples/example_report_AI.py @@ -5,6 +5,10 @@ from docx.shared import Inches import warnings warnings.filterwarnings("ignore") +import google.generativeai as genai +import os +genai.configure(api_key=os.environ["GEMINI_API_KEY"]) +model = genai.GenerativeModel(model_name="gemini-1.5-flash") df = readNora10File('../tests/data/NORA_test.txt') @@ -49,10 +53,13 @@ # Add the first table df = tables.table_monthly_non_exceedance(ds,var= 'W10',step_var=2,output_file=None) -doc.add_heading('Table 1: Monthly non-exceedance', level=2) +response = model.generate_content("Write one sentence caption to the table:"+str(df)) +#doc.add_heading('Table 1: Monthly non-exceedance', level=2) +doc.add_heading('Table 1:'+response.text, level=2) table1 = doc.add_table(rows=df.shape[0] + 1, cols=df.shape[1]) table1.style = 'Table Grid' + # Add the header row for the first table hdr_cells = table1.rows[0].cells for i, column in enumerate(df.columns): diff --git a/examples/nb_hour_below_thr.csv b/examples/nb_hour_below_thr.csv deleted file mode 100644 index 60b95d2..0000000 --- a/examples/nb_hour_below_thr.csv +++ /dev/null @@ -1,11 +0,0 @@ -HS,Minimum,Mean,Maximum -<1,771,1227,1812 -<2,3783,4454,5025 -<3,6036,6582,7086 -<4,7260,7717,8055 -<5,7971,8279,8472 -<6,8397,8538,8670 -<7,8583,8676,8745 -<8,8694,8734,8778 -<9,8736,8756,8784 -<10,8742,8763,8784 diff --git a/examples/nb_hour_below_thr.png b/examples/nb_hour_below_thr.png deleted file mode 100644 index 6d7b6fa..0000000 Binary files a/examples/nb_hour_below_thr.png and /dev/null differ diff --git a/metocean_stats/stats/aux_funcs.py b/metocean_stats/stats/aux_funcs.py index da621aa..ddf2fbf 100644 --- a/metocean_stats/stats/aux_funcs.py +++ b/metocean_stats/stats/aux_funcs.py @@ -124,13 +124,17 @@ def Hs_Tp_curve(data,pdf_Hs,pdf_Hs_Tp,f_Hs_Tp,h,t,interval,X=100): # Find index of Hs=value epsilon = abs(h - rve_X) - param = find_peaks(1/epsilon) # to find the index of bottom - index = param[0][0] # the index of Hs=value + param, prov = find_peaks(1/epsilon) # to find the index of bottom + if len(param) == 0: + param = [np.argmax(1/epsilon)] + index = param[0] # the index of Hs=value # Find peak of pdf at Hs=RVE of X year pdf_Hs_Tp_X = pdf_Hs_Tp[index,:] # Find pdf at RVE of X year - param = find_peaks(pdf_Hs_Tp_X) # find the peak - index = param[0][0] + param, prov = find_peaks(pdf_Hs_Tp_X) # find the peak + if len(param) == 0: + param = [np.argmax(pdf_Hs_Tp_X)] + index = param[0] f_Hs_Tp_100=pdf_Hs_Tp_X[index] diff --git a/tests/test_plots.py b/tests/test_plots.py index 99503ac..3cd51f8 100644 --- a/tests/test_plots.py +++ b/tests/test_plots.py @@ -31,15 +31,6 @@ def test_plot_prob_non_exceedance_fitted_3p_weibull(ds=ds): else: raise ValueError("FigValue is not correct") -#def test_scatter_diagram(ds=ds): -# output_file = 'test_scatter_diagram.csv' -# df = tables.scatter_diagram(ds, var1='HS', step_var1=1, var2='TP', step_var2=1, output_file=output_file) -# if os.path.exists(output_file): -# os.remove(output_file) -# if df.shape[0] == 14: -# pass -# else: -# raise ValueError("Shape is not correct") def test_plot_monthly_stats(ds=ds): output_file = 'test_monthly_stats.png' diff --git a/tests/test_tables.py b/tests/test_tables.py index 3a0cc36..89c34b5 100644 --- a/tests/test_tables.py +++ b/tests/test_tables.py @@ -10,15 +10,13 @@ depth = ['0m', '1m', '2.5m', '5m', '10m', '15m', '20m', '25m', '30m', '40m', '50m', '75m', '100m', '150m', '200m'] -#def test_scatter_diagram(ds=ds): -# output_file = 'test_scatter_diagram.csv' -# df = tables.scatter_diagram(ds, var1='HS', step_var1=1, var2='TP', step_var2=1, output_file=output_file) -# if os.path.exists(output_file): -# os.remove(output_file) -# if df.shape[0] == 14: -# pass -# else: -# raise ValueError("Shape is not correct") +def test_scatter_diagram(ds=ds): + output_file = 'test_scatter_diagram.csv' + df = tables.scatter_diagram(ds, var1='HS', step_var1=1, var2='TP', step_var2=1, output_file=output_file) + if os.path.exists(output_file): + os.remove(output_file) + assert df.shape[0] == 14 + def test_table_var_sorted_by_hs(ds=ds): output_file = 'test_var_sorted_by_hs.csv'