diff --git a/experiments/compile_results.ipynb b/experiments/compile_results.ipynb new file mode 100644 index 0000000..b835fa4 --- /dev/null +++ b/experiments/compile_results.ipynb @@ -0,0 +1,2625 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 47, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/Users/kevinmaikjablonka/miniconda3/envs/chemlift/lib/python3.9/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", + " from .autonotebook import tqdm as notebook_tqdm\n" + ] + } + ], + "source": [ + "from glob import glob\n", + "import pandas as pd\n", + "from fastcore.xtras import load_pickle\n", + "from scipy.stats import sem\n", + "import dabest\n", + "\n", + "import matplotlib.pyplot as plt\n", + "\n", + "from gptchem.plotsettings import *\n", + "from gptchem.settings import (\n", + " ONE_COL_GOLDEN_RATIO_HEIGHT_INCH,\n", + " ONE_COL_WIDTH_INCH,\n", + " TWO_COL_GOLDEN_RATIO_HEIGHT_INCH,\n", + ")\n", + "\n", + "%matplotlib inline" + ] + }, + { + "cell_type": "code", + "execution_count": 48, + "metadata": {}, + "outputs": [], + "source": [ + "all_res = glob('results/*.pkl')" + ] + }, + { + "cell_type": "code", + "execution_count": 49, + "metadata": {}, + "outputs": [], + "source": [ + "compiled_results = []\n", + "\n", + "for res in all_res:\n", + " res = load_pickle(res)\n", + " summary = {\n", + " 'acc_macro': res['acc_macro'],\n", + " 'f1_macro': res['f1_macro'],\n", + " 'f1_micro': res['f1_micro'],\n", + " 'kappa': res['kappa'],\n", + " 'num_support_samples': res['num_support_samples'],\n", + " 'model': res['model'],\n", + " 'temperature': res['temperature'],\n", + " 'strategy': res['strategy'],\n", + " }\n", + " compiled_results.append(summary)" + ] + }, + { + "cell_type": "code", + "execution_count": 50, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + 
"{'accuracy': 0.56,\n", + " 'acc_macro': 0.56,\n", + " 'racc': 0.5,\n", + " 'kappa': 0.1200000000000001,\n", + " 'confusion_matrix': pycm.ConfusionMatrix(classes: [0, 1]),\n", + " 'roc_auc': 0.56,\n", + " 'f1_macro': 0.45436507936507936,\n", + " 'f1_micro': 0.56,\n", + " 'frac_valid': 1.0,\n", + " 'all_y_true': (#50) [0,1,1,0,1,0,1,1,0,0...],\n", + " 'all_y_pred': [1,\n", + " 1,\n", + " 1,\n", + " 1,\n", + " 1,\n", + " 1,\n", + " 1,\n", + " 1,\n", + " 0,\n", + " 1,\n", + " 1,\n", + " 1,\n", + " 1,\n", + " 1,\n", + " 1,\n", + " 1,\n", + " 1,\n", + " 1,\n", + " 1,\n", + " 1,\n", + " 1,\n", + " 0,\n", + " 1,\n", + " 1,\n", + " 1,\n", + " 1,\n", + " 1,\n", + " 1,\n", + " 1,\n", + " 1,\n", + " 1,\n", + " 1,\n", + " 1,\n", + " 1,\n", + " 1,\n", + " 1,\n", + " 1,\n", + " 1,\n", + " 0,\n", + " 1,\n", + " 1,\n", + " 1,\n", + " 1,\n", + " 1,\n", + " 1,\n", + " 1,\n", + " 1,\n", + " 1,\n", + " 1,\n", + " 1],\n", + " 'valid_indices': [0,\n", + " 1,\n", + " 2,\n", + " 3,\n", + " 4,\n", + " 5,\n", + " 6,\n", + " 7,\n", + " 8,\n", + " 9,\n", + " 10,\n", + " 11,\n", + " 12,\n", + " 13,\n", + " 14,\n", + " 15,\n", + " 16,\n", + " 17,\n", + " 18,\n", + " 19,\n", + " 20,\n", + " 21,\n", + " 22,\n", + " 23,\n", + " 24,\n", + " 25,\n", + " 26,\n", + " 27,\n", + " 28,\n", + " 29,\n", + " 30,\n", + " 31,\n", + " 32,\n", + " 33,\n", + " 34,\n", + " 35,\n", + " 36,\n", + " 37,\n", + " 38,\n", + " 39,\n", + " 40,\n", + " 41,\n", + " 42,\n", + " 43,\n", + " 44,\n", + " 45,\n", + " 46,\n", + " 47,\n", + " 48,\n", + " 49],\n", + " 'might_have_rounded_floats': False,\n", + " 'num_support_samples': 10,\n", + " 'strategy': 'diverse',\n", + " 'model': 'claude-instant-1',\n", + " 'num_test_points': 50,\n", + " 'random_state': 0,\n", + " 'predictions': [1,\n", + " 1,\n", + " 1,\n", + " 1,\n", + " 1,\n", + " 1,\n", + " 1,\n", + " 1,\n", + " 0,\n", + " 1,\n", + " 1,\n", + " 1,\n", + " 1,\n", + " 1,\n", + " 1,\n", + " 1,\n", + " 1,\n", + " 1,\n", + " 1,\n", + " 1,\n", + " 1,\n", + " 0,\n", + " 1,\n", + 
" 1,\n", + " 1,\n", + " 1,\n", + " 1,\n", + " 1,\n", + " 1,\n", + " 1,\n", + " 1,\n", + " 1,\n", + " 1,\n", + " 1,\n", + " 1,\n", + " 1,\n", + " 1,\n", + " 1,\n", + " 0,\n", + " 1,\n", + " 1,\n", + " 1,\n", + " 1,\n", + " 1,\n", + " 1,\n", + " 1,\n", + " 1,\n", + " 1,\n", + " 1,\n", + " 1],\n", + " 'targets': array([0, 1, 1, 0, 1, 0, 1, 1, 0, 0, 1, 0, 0, 1, 1, 1, 1, 0, 1, 0, 1, 0,\n", + " 0, 0, 1, 1, 1, 1, 1, 0, 0, 0, 1, 1, 1, 0, 1, 0, 0, 1, 1, 0, 0, 1,\n", + " 0, 0, 0, 1, 0, 0]),\n", + " 'max_test': 5,\n", + " 'temperature': 0.8}" + ] + }, + "execution_count": 50, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "res" + ] + }, + { + "cell_type": "code", + "execution_count": 70, + "metadata": {}, + "outputs": [], + "source": [ + "df = pd.DataFrame(compiled_results)" + ] + }, + { + "cell_type": "code", + "execution_count": 71, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/Users/kevinmaikjablonka/miniconda3/envs/chemlift/lib/python3.9/site-packages/numpy/core/_methods.py:206: RuntimeWarning: Degrees of freedom <= 0 for slice\n", + " ret = _var(a, axis=axis, dtype=dtype, out=out, ddof=ddof,\n", + "/Users/kevinmaikjablonka/miniconda3/envs/chemlift/lib/python3.9/site-packages/numpy/core/_methods.py:198: RuntimeWarning: invalid value encountered in scalar divide\n", + " ret = ret.dtype.type(ret / rcount)\n", + "/Users/kevinmaikjablonka/miniconda3/envs/chemlift/lib/python3.9/site-packages/numpy/core/_methods.py:206: RuntimeWarning: Degrees of freedom <= 0 for slice\n", + " ret = _var(a, axis=axis, dtype=dtype, out=out, ddof=ddof,\n", + "/Users/kevinmaikjablonka/miniconda3/envs/chemlift/lib/python3.9/site-packages/numpy/core/_methods.py:198: RuntimeWarning: invalid value encountered in scalar divide\n", + " ret = ret.dtype.type(ret / rcount)\n", + "/Users/kevinmaikjablonka/miniconda3/envs/chemlift/lib/python3.9/site-packages/numpy/core/_methods.py:206: RuntimeWarning: Degrees of freedom <= 0 
for slice\n", + " ret = _var(a, axis=axis, dtype=dtype, out=out, ddof=ddof,\n", + "/Users/kevinmaikjablonka/miniconda3/envs/chemlift/lib/python3.9/site-packages/numpy/core/_methods.py:198: RuntimeWarning: invalid value encountered in scalar divide\n", + " ret = ret.dtype.type(ret / rcount)\n", + "/Users/kevinmaikjablonka/miniconda3/envs/chemlift/lib/python3.9/site-packages/numpy/core/_methods.py:206: RuntimeWarning: Degrees of freedom <= 0 for slice\n", + " ret = _var(a, axis=axis, dtype=dtype, out=out, ddof=ddof,\n", + "/Users/kevinmaikjablonka/miniconda3/envs/chemlift/lib/python3.9/site-packages/numpy/core/_methods.py:198: RuntimeWarning: invalid value encountered in scalar divide\n", + " ret = ret.dtype.type(ret / rcount)\n" + ] + } + ], + "source": [ + "aggregated = df.groupby(['model','temperature', 'strategy', 'num_support_samples']).agg(['mean', sem, 'std', 'count']).sort_values(('f1_macro', 'mean'), ascending=False)" + ] + }, + { + "cell_type": "code", + "execution_count": 72, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " 
\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
acc_macrof1_macrof1_microkappa
meansemstdcountmeansemstdcountmeansemstdcountmeansemstdcount
modeltemperaturestrategynum_support_samples
gpt-40.8diverse1000.8133330.0183790.04501960.8121850.0184900.04529260.8133330.0183790.04501960.6266670.0367570.0900376
0.2diverse1000.8066670.0217050.05316660.8041660.0221900.05435460.8066670.0217050.05316660.6133330.0434100.1063336
0.8random1000.7766670.0227550.05573760.7683290.0254750.06240160.7766670.0227550.05573760.5533330.0455090.1114756
0.2random1000.7633330.0181960.04457260.7515800.0230220.05639260.7633330.0181960.04457260.5266670.0363930.0891446
500.7377780.0180880.05426390.7219330.0238790.07163890.7377780.0180880.05426390.4755560.0361750.1085259
............................................................
claude-instant-10.2diverse200.5266190.0204660.05013260.4319100.0326540.07998760.5266190.0204660.05013260.0558720.0416390.1019956
text-davinci-0030.2random200.4933330.0152020.069666210.4270140.0160190.073410210.4933330.0152020.06966621-0.0133330.0304050.13933221
text-ada-0010.2random50.5228740.0284090.05681840.4145700.0205750.04114940.5228740.0284090.0568184-0.0807270.0793020.1586034
text-davinci-0030.2random50.5018180.0108440.050863220.4133510.0161460.075730220.5018180.0108440.050863220.0036360.0216880.10172522
claude-20.2random50.5166670.0130810.03204260.4013660.0316960.07763960.5166670.0130810.03204260.0333330.0261620.0640836
\n", + "

75 rows × 16 columns

\n", + "
" + ], + "text/plain": [ + " acc_macro \\\n", + " mean sem \n", + "model temperature strategy num_support_samples \n", + "gpt-4 0.8 diverse 100 0.813333 0.018379 \n", + " 0.2 diverse 100 0.806667 0.021705 \n", + " 0.8 random 100 0.776667 0.022755 \n", + " 0.2 random 100 0.763333 0.018196 \n", + " 50 0.737778 0.018088 \n", + "... ... ... \n", + "claude-instant-1 0.2 diverse 20 0.526619 0.020466 \n", + "text-davinci-003 0.2 random 20 0.493333 0.015202 \n", + "text-ada-001 0.2 random 5 0.522874 0.028409 \n", + "text-davinci-003 0.2 random 5 0.501818 0.010844 \n", + "claude-2 0.2 random 5 0.516667 0.013081 \n", + "\n", + " \\\n", + " std count \n", + "model temperature strategy num_support_samples \n", + "gpt-4 0.8 diverse 100 0.045019 6 \n", + " 0.2 diverse 100 0.053166 6 \n", + " 0.8 random 100 0.055737 6 \n", + " 0.2 random 100 0.044572 6 \n", + " 50 0.054263 9 \n", + "... ... ... \n", + "claude-instant-1 0.2 diverse 20 0.050132 6 \n", + "text-davinci-003 0.2 random 20 0.069666 21 \n", + "text-ada-001 0.2 random 5 0.056818 4 \n", + "text-davinci-003 0.2 random 5 0.050863 22 \n", + "claude-2 0.2 random 5 0.032042 6 \n", + "\n", + " f1_macro \\\n", + " mean sem \n", + "model temperature strategy num_support_samples \n", + "gpt-4 0.8 diverse 100 0.812185 0.018490 \n", + " 0.2 diverse 100 0.804166 0.022190 \n", + " 0.8 random 100 0.768329 0.025475 \n", + " 0.2 random 100 0.751580 0.023022 \n", + " 50 0.721933 0.023879 \n", + "... ... ... 
\n", + "claude-instant-1 0.2 diverse 20 0.431910 0.032654 \n", + "text-davinci-003 0.2 random 20 0.427014 0.016019 \n", + "text-ada-001 0.2 random 5 0.414570 0.020575 \n", + "text-davinci-003 0.2 random 5 0.413351 0.016146 \n", + "claude-2 0.2 random 5 0.401366 0.031696 \n", + "\n", + " \\\n", + " std count \n", + "model temperature strategy num_support_samples \n", + "gpt-4 0.8 diverse 100 0.045292 6 \n", + " 0.2 diverse 100 0.054354 6 \n", + " 0.8 random 100 0.062401 6 \n", + " 0.2 random 100 0.056392 6 \n", + " 50 0.071638 9 \n", + "... ... ... \n", + "claude-instant-1 0.2 diverse 20 0.079987 6 \n", + "text-davinci-003 0.2 random 20 0.073410 21 \n", + "text-ada-001 0.2 random 5 0.041149 4 \n", + "text-davinci-003 0.2 random 5 0.075730 22 \n", + "claude-2 0.2 random 5 0.077639 6 \n", + "\n", + " f1_micro \\\n", + " mean sem \n", + "model temperature strategy num_support_samples \n", + "gpt-4 0.8 diverse 100 0.813333 0.018379 \n", + " 0.2 diverse 100 0.806667 0.021705 \n", + " 0.8 random 100 0.776667 0.022755 \n", + " 0.2 random 100 0.763333 0.018196 \n", + " 50 0.737778 0.018088 \n", + "... ... ... \n", + "claude-instant-1 0.2 diverse 20 0.526619 0.020466 \n", + "text-davinci-003 0.2 random 20 0.493333 0.015202 \n", + "text-ada-001 0.2 random 5 0.522874 0.028409 \n", + "text-davinci-003 0.2 random 5 0.501818 0.010844 \n", + "claude-2 0.2 random 5 0.516667 0.013081 \n", + "\n", + " \\\n", + " std count \n", + "model temperature strategy num_support_samples \n", + "gpt-4 0.8 diverse 100 0.045019 6 \n", + " 0.2 diverse 100 0.053166 6 \n", + " 0.8 random 100 0.055737 6 \n", + " 0.2 random 100 0.044572 6 \n", + " 50 0.054263 9 \n", + "... ... ... 
\n", + "claude-instant-1 0.2 diverse 20 0.050132 6 \n", + "text-davinci-003 0.2 random 20 0.069666 21 \n", + "text-ada-001 0.2 random 5 0.056818 4 \n", + "text-davinci-003 0.2 random 5 0.050863 22 \n", + "claude-2 0.2 random 5 0.032042 6 \n", + "\n", + " kappa \\\n", + " mean sem \n", + "model temperature strategy num_support_samples \n", + "gpt-4 0.8 diverse 100 0.626667 0.036757 \n", + " 0.2 diverse 100 0.613333 0.043410 \n", + " 0.8 random 100 0.553333 0.045509 \n", + " 0.2 random 100 0.526667 0.036393 \n", + " 50 0.475556 0.036175 \n", + "... ... ... \n", + "claude-instant-1 0.2 diverse 20 0.055872 0.041639 \n", + "text-davinci-003 0.2 random 20 -0.013333 0.030405 \n", + "text-ada-001 0.2 random 5 -0.080727 0.079302 \n", + "text-davinci-003 0.2 random 5 0.003636 0.021688 \n", + "claude-2 0.2 random 5 0.033333 0.026162 \n", + "\n", + " \n", + " std count \n", + "model temperature strategy num_support_samples \n", + "gpt-4 0.8 diverse 100 0.090037 6 \n", + " 0.2 diverse 100 0.106333 6 \n", + " 0.8 random 100 0.111475 6 \n", + " 0.2 random 100 0.089144 6 \n", + " 50 0.108525 9 \n", + "... ... ... 
\n", + "claude-instant-1 0.2 diverse 20 0.101995 6 \n", + "text-davinci-003 0.2 random 20 0.139332 21 \n", + "text-ada-001 0.2 random 5 0.158603 4 \n", + "text-davinci-003 0.2 random 5 0.101725 22 \n", + "claude-2 0.2 random 5 0.064083 6 \n", + "\n", + "[75 rows x 16 columns]" + ] + }, + "execution_count": 72, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "aggregated" + ] + }, + { + "cell_type": "code", + "execution_count": 73, + "metadata": {}, + "outputs": [], + "source": [ + "differences = []\n", + "\n", + "metrics = 'f1_macro'\n", + "for model in df['model'].unique(): \n", + " for temperature in df['temperature'].unique(): \n", + "\n", + " for num_support_samples in df['num_support_samples'].unique():\n", + " subset = df.query(f'model == \"{model}\" and temperature == {temperature} and num_support_samples == {num_support_samples}')\n", + "\n", + " diverse_res = subset.query('strategy == \"diverse\"')\n", + " random_res = subset.query('strategy == \"random\"')\n", + "\n", + " differences.append({\n", + " 'random': random_res[metrics].mean(),\n", + " 'diverse': diverse_res[metrics].mean(),\n", + " 'model': model,\n", + " 'temperature': temperature,\n", + " 'num_support_samples': num_support_samples, \n", + " })" + ] + }, + { + "cell_type": "code", + "execution_count": 74, + "metadata": {}, + "outputs": [], + "source": [ + "differences = pd.DataFrame(differences)" + ] + }, + { + "cell_type": "code", + "execution_count": 75, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " 
\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " 
\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
randomdiversemodeltemperaturenum_support_samples
00.467846NaNgpt-3.5-turbo0.85
10.541575NaNgpt-3.5-turbo0.8100
20.524538NaNgpt-3.5-turbo0.820
30.519985NaNgpt-3.5-turbo0.810
40.508387NaNgpt-3.5-turbo0.850
50.445453NaNgpt-3.5-turbo0.25
60.549613NaNgpt-3.5-turbo0.2100
70.493340NaNgpt-3.5-turbo0.220
80.486977NaNgpt-3.5-turbo0.210
90.530513NaNgpt-3.5-turbo0.250
100.4770050.461708claude-instant-10.85
110.5579130.490935claude-instant-10.8100
120.5065930.432338claude-instant-10.820
130.5168470.508217claude-instant-10.810
140.5829250.497253claude-instant-10.850
150.4572980.469855claude-instant-10.25
160.5404640.493734claude-instant-10.2100
170.4754010.431910claude-instant-10.220
180.4844720.490521claude-instant-10.210
190.6099730.446600claude-instant-10.250
200.469044NaNgpt-40.85
210.7683290.812185gpt-40.8100
220.656506NaNgpt-40.820
230.562016NaNgpt-40.810
240.679098NaNgpt-40.850
250.454896NaNgpt-40.25
260.7515800.804166gpt-40.2100
270.630447NaNgpt-40.220
280.588173NaNgpt-40.210
290.721933NaNgpt-40.250
300.438854NaNtext-davinci-0030.85
31NaNNaNtext-davinci-0030.8100
320.437460NaNtext-davinci-0030.820
330.458305NaNtext-davinci-0030.810
340.493742NaNtext-davinci-0030.850
350.413351NaNtext-davinci-0030.25
36NaNNaNtext-davinci-0030.2100
370.427014NaNtext-davinci-0030.220
380.465876NaNtext-davinci-0030.210
390.496884NaNtext-davinci-0030.250
400.4588280.490004claude-20.85
410.5083700.658772claude-20.8100
420.5175770.696970claude-20.820
430.5147910.529523claude-20.810
440.5374570.548816claude-20.850
450.4013660.560960claude-20.25
460.572401NaNclaude-20.2100
470.5450020.669170claude-20.220
480.5006680.597424claude-20.210
490.6028900.521686claude-20.250
500.450069NaNtext-ada-0010.85
51NaNNaNtext-ada-0010.8100
520.609821NaNtext-ada-0010.820
530.521528NaNtext-ada-0010.810
54NaNNaNtext-ada-0010.850
550.414570NaNtext-ada-0010.25
56NaNNaNtext-ada-0010.2100
570.436827NaNtext-ada-0010.220
580.548611NaNtext-ada-0010.210
59NaNNaNtext-ada-0010.250
\n", + "
" + ], + "text/plain": [ + " random diverse model temperature num_support_samples\n", + "0 0.467846 NaN gpt-3.5-turbo 0.8 5\n", + "1 0.541575 NaN gpt-3.5-turbo 0.8 100\n", + "2 0.524538 NaN gpt-3.5-turbo 0.8 20\n", + "3 0.519985 NaN gpt-3.5-turbo 0.8 10\n", + "4 0.508387 NaN gpt-3.5-turbo 0.8 50\n", + "5 0.445453 NaN gpt-3.5-turbo 0.2 5\n", + "6 0.549613 NaN gpt-3.5-turbo 0.2 100\n", + "7 0.493340 NaN gpt-3.5-turbo 0.2 20\n", + "8 0.486977 NaN gpt-3.5-turbo 0.2 10\n", + "9 0.530513 NaN gpt-3.5-turbo 0.2 50\n", + "10 0.477005 0.461708 claude-instant-1 0.8 5\n", + "11 0.557913 0.490935 claude-instant-1 0.8 100\n", + "12 0.506593 0.432338 claude-instant-1 0.8 20\n", + "13 0.516847 0.508217 claude-instant-1 0.8 10\n", + "14 0.582925 0.497253 claude-instant-1 0.8 50\n", + "15 0.457298 0.469855 claude-instant-1 0.2 5\n", + "16 0.540464 0.493734 claude-instant-1 0.2 100\n", + "17 0.475401 0.431910 claude-instant-1 0.2 20\n", + "18 0.484472 0.490521 claude-instant-1 0.2 10\n", + "19 0.609973 0.446600 claude-instant-1 0.2 50\n", + "20 0.469044 NaN gpt-4 0.8 5\n", + "21 0.768329 0.812185 gpt-4 0.8 100\n", + "22 0.656506 NaN gpt-4 0.8 20\n", + "23 0.562016 NaN gpt-4 0.8 10\n", + "24 0.679098 NaN gpt-4 0.8 50\n", + "25 0.454896 NaN gpt-4 0.2 5\n", + "26 0.751580 0.804166 gpt-4 0.2 100\n", + "27 0.630447 NaN gpt-4 0.2 20\n", + "28 0.588173 NaN gpt-4 0.2 10\n", + "29 0.721933 NaN gpt-4 0.2 50\n", + "30 0.438854 NaN text-davinci-003 0.8 5\n", + "31 NaN NaN text-davinci-003 0.8 100\n", + "32 0.437460 NaN text-davinci-003 0.8 20\n", + "33 0.458305 NaN text-davinci-003 0.8 10\n", + "34 0.493742 NaN text-davinci-003 0.8 50\n", + "35 0.413351 NaN text-davinci-003 0.2 5\n", + "36 NaN NaN text-davinci-003 0.2 100\n", + "37 0.427014 NaN text-davinci-003 0.2 20\n", + "38 0.465876 NaN text-davinci-003 0.2 10\n", + "39 0.496884 NaN text-davinci-003 0.2 50\n", + "40 0.458828 0.490004 claude-2 0.8 5\n", + "41 0.508370 0.658772 claude-2 0.8 100\n", + "42 0.517577 0.696970 claude-2 0.8 20\n", + 
"43 0.514791 0.529523 claude-2 0.8 10\n", + "44 0.537457 0.548816 claude-2 0.8 50\n", + "45 0.401366 0.560960 claude-2 0.2 5\n", + "46 0.572401 NaN claude-2 0.2 100\n", + "47 0.545002 0.669170 claude-2 0.2 20\n", + "48 0.500668 0.597424 claude-2 0.2 10\n", + "49 0.602890 0.521686 claude-2 0.2 50\n", + "50 0.450069 NaN text-ada-001 0.8 5\n", + "51 NaN NaN text-ada-001 0.8 100\n", + "52 0.609821 NaN text-ada-001 0.8 20\n", + "53 0.521528 NaN text-ada-001 0.8 10\n", + "54 NaN NaN text-ada-001 0.8 50\n", + "55 0.414570 NaN text-ada-001 0.2 5\n", + "56 NaN NaN text-ada-001 0.2 100\n", + "57 0.436827 NaN text-ada-001 0.2 20\n", + "58 0.548611 NaN text-ada-001 0.2 10\n", + "59 NaN NaN text-ada-001 0.2 50" + ] + }, + "execution_count": 75, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "differences" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Random low temperature" + ] + }, + { + "cell_type": "code", + "execution_count": 84, + "metadata": {}, + "outputs": [], + "source": [ + "random_low_temp = df.query('temperature == 0.2 and strategy == \"random\"')" + ] + }, + { + "cell_type": "code", + "execution_count": 85, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
acc_macrof1_macrof1_microkappanum_support_samplesmodeltemperaturestrategy
20.780.7756020.780.56100gpt-40.2random
30.500.3333330.500.005gpt-3.5-turbo0.2random
40.560.5484400.560.1220text-davinci-0030.2random
110.520.3762990.520.045text-davinci-0030.2random
150.280.2800000.28-0.4410text-davinci-0030.2random
...........................
7240.700.6847410.700.4050gpt-3.5-turbo0.2random
7250.500.4998000.500.005claude-instant-10.2random
7260.580.5757580.580.1610gpt-3.5-turbo0.2random
7270.500.3333330.500.005claude-20.2random
7290.400.3990380.40-0.2020text-davinci-0030.2random
\n", + "

341 rows × 8 columns

\n", + "
" + ], + "text/plain": [ + " acc_macro f1_macro f1_micro kappa num_support_samples \\\n", + "2 0.78 0.775602 0.78 0.56 100 \n", + "3 0.50 0.333333 0.50 0.00 5 \n", + "4 0.56 0.548440 0.56 0.12 20 \n", + "11 0.52 0.376299 0.52 0.04 5 \n", + "15 0.28 0.280000 0.28 -0.44 10 \n", + ".. ... ... ... ... ... \n", + "724 0.70 0.684741 0.70 0.40 50 \n", + "725 0.50 0.499800 0.50 0.00 5 \n", + "726 0.58 0.575758 0.58 0.16 10 \n", + "727 0.50 0.333333 0.50 0.00 5 \n", + "729 0.40 0.399038 0.40 -0.20 20 \n", + "\n", + " model temperature strategy \n", + "2 gpt-4 0.2 random \n", + "3 gpt-3.5-turbo 0.2 random \n", + "4 text-davinci-003 0.2 random \n", + "11 text-davinci-003 0.2 random \n", + "15 text-davinci-003 0.2 random \n", + ".. ... ... ... \n", + "724 gpt-3.5-turbo 0.2 random \n", + "725 claude-instant-1 0.2 random \n", + "726 gpt-3.5-turbo 0.2 random \n", + "727 claude-2 0.2 random \n", + "729 text-davinci-003 0.2 random \n", + "\n", + "[341 rows x 8 columns]" + ] + }, + "execution_count": 85, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "random_low_temp" + ] + }, + { + "cell_type": "code", + "execution_count": 86, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/var/folders/m9/_txh68y946s4pxy1x2wnd3lh0000gn/T/ipykernel_2652/2342559592.py:1: FutureWarning: ['strategy'] did not aggregate successfully. If any error is raised this will raise in a future version of pandas. 
Drop these columns/ops to avoid this warning.\n", + " random_low_temp = random_low_temp.groupby(['model', 'temperature', 'num_support_samples']).agg(['mean', sem, 'std', 'count']).sort_values(('f1_macro', 'mean'), ascending=False)\n" + ] + } + ], + "source": [ + "random_low_temp = random_low_temp.groupby(['model', 'temperature', 'num_support_samples']).agg(['mean', sem, 'std', 'count']).sort_values(('f1_macro', 'mean'), ascending=False)" + ] + }, + { + "cell_type": "code", + "execution_count": 87, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " 
\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " 
\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
acc_macrof1_macrof1_microkappa
meansemstdcountmeansemstdcountmeansemstdcountmeansemstdcount
modeltemperaturenum_support_samples
gpt-40.21000.7633330.0181960.04457260.7515800.0230220.05639260.7633330.0181960.04457260.5266670.0363930.0891446
500.7377780.0180880.05426390.7219330.0238790.07163890.7377780.0180880.05426390.4755560.0361750.1085259
200.6666670.0290590.08717890.6304470.0400280.12008590.6666670.0290590.08717890.3333330.0581190.1743569
claude-instant-10.2500.6320000.0500400.11189350.6099730.0581930.13012350.6320000.0500400.11189350.2640000.1000800.2237865
claude-20.2500.6100000.0500000.07071120.6028900.0457250.06466420.6100000.0500000.07071120.2200000.1000000.1414212
gpt-40.2100.6111110.0321650.09649490.5881730.0356740.10702190.6111110.0321650.09649490.2222220.0643290.1929889
claude-20.21000.5866670.0545690.09451630.5724010.0575460.09967330.5866670.0545690.09451630.1733330.1091380.1890333
gpt-3.5-turbo0.21000.5733330.0193540.106004300.5496130.0228420.125112300.5733330.0193540.106004300.1466670.0387070.21200730
text-ada-0010.2100.5558820.1441180.20381320.5486110.1388890.19641920.5558820.1441180.20381320.1004440.2809990.3973932
claude-20.2200.5500000.0500000.07071120.5450020.0549980.07777920.5500000.0500000.07071120.1000000.1000000.1414212
claude-instant-10.21000.5823670.0261840.05854950.5404640.0382990.08564050.5823670.0261840.05854950.1618980.0521880.1166965
gpt-3.5-turbo0.2500.5566670.0181710.099528300.5305130.0218450.119650300.5566670.0181710.099528300.1133330.0363420.19905530
claude-20.2100.5150000.0170780.03415740.5006680.0195460.03909240.5150000.0170780.03415740.0300000.0341570.0683134
text-davinci-0030.2500.5510000.0228250.102078200.4968840.0299920.134130200.5510000.0228250.102078200.1020000.0456510.20415720
gpt-3.5-turbo0.2200.5280000.0129580.070974300.4933400.0162180.088830300.5280000.0129580.070974300.0560000.0259160.14194730
100.5366670.0140630.077028300.4869770.0187290.102582300.5366670.0140630.077028300.0733330.0281270.15405630
claude-instant-10.2100.5266670.0190900.04676260.4844720.0403930.09894360.5266670.0190900.04676260.0533330.0381810.0935246
200.5300000.0184390.04516660.4754010.0327610.08024760.5300000.0184390.04516660.0600000.0368780.0903336
text-davinci-0030.2100.5103700.0188070.097723270.4658760.0205610.106837270.5103700.0188070.097723270.0207410.0376140.19544727
claude-instant-10.250.5377780.0139220.04176790.4572980.0329720.09891590.5377780.0139220.04176790.0755560.0278440.0835339
gpt-40.250.5333330.0216020.06480790.4548960.0416370.12491290.5333330.0216020.06480790.0666670.0432050.1296159
gpt-3.5-turbo0.250.5160000.0080460.044069300.4454530.0128920.070611300.5160000.0080460.044069300.0320000.0160920.08813830
text-ada-0010.2200.5244130.0722960.16165950.4368270.0941460.21051750.4789580.0825510.1845905-0.0146410.1429990.3197555
text-davinci-0030.2200.4933330.0152020.069666210.4270140.0160190.073410210.4933330.0152020.06966621-0.0133330.0304050.13933221
text-ada-0010.250.5228740.0284090.05681840.4145700.0205750.04114940.5228740.0284090.0568184-0.0807270.0793020.1586034
text-davinci-0030.250.5018180.0108440.050863220.4133510.0161460.075730220.5018180.0108440.050863220.0036360.0216880.10172522
claude-20.250.5166670.0130810.03204260.4013660.0316960.07763960.5166670.0130810.03204260.0333330.0261620.0640836
\n", + "
" + ], + "text/plain": [ + " acc_macro \\\n", + " mean sem \n", + "model temperature num_support_samples \n", + "gpt-4 0.2 100 0.763333 0.018196 \n", + " 50 0.737778 0.018088 \n", + " 20 0.666667 0.029059 \n", + "claude-instant-1 0.2 50 0.632000 0.050040 \n", + "claude-2 0.2 50 0.610000 0.050000 \n", + "gpt-4 0.2 10 0.611111 0.032165 \n", + "claude-2 0.2 100 0.586667 0.054569 \n", + "gpt-3.5-turbo 0.2 100 0.573333 0.019354 \n", + "text-ada-001 0.2 10 0.555882 0.144118 \n", + "claude-2 0.2 20 0.550000 0.050000 \n", + "claude-instant-1 0.2 100 0.582367 0.026184 \n", + "gpt-3.5-turbo 0.2 50 0.556667 0.018171 \n", + "claude-2 0.2 10 0.515000 0.017078 \n", + "text-davinci-003 0.2 50 0.551000 0.022825 \n", + "gpt-3.5-turbo 0.2 20 0.528000 0.012958 \n", + " 10 0.536667 0.014063 \n", + "claude-instant-1 0.2 10 0.526667 0.019090 \n", + " 20 0.530000 0.018439 \n", + "text-davinci-003 0.2 10 0.510370 0.018807 \n", + "claude-instant-1 0.2 5 0.537778 0.013922 \n", + "gpt-4 0.2 5 0.533333 0.021602 \n", + "gpt-3.5-turbo 0.2 5 0.516000 0.008046 \n", + "text-ada-001 0.2 20 0.524413 0.072296 \n", + "text-davinci-003 0.2 20 0.493333 0.015202 \n", + "text-ada-001 0.2 5 0.522874 0.028409 \n", + "text-davinci-003 0.2 5 0.501818 0.010844 \n", + "claude-2 0.2 5 0.516667 0.013081 \n", + "\n", + " f1_macro \\\n", + " std count mean \n", + "model temperature num_support_samples \n", + "gpt-4 0.2 100 0.044572 6 0.751580 \n", + " 50 0.054263 9 0.721933 \n", + " 20 0.087178 9 0.630447 \n", + "claude-instant-1 0.2 50 0.111893 5 0.609973 \n", + "claude-2 0.2 50 0.070711 2 0.602890 \n", + "gpt-4 0.2 10 0.096494 9 0.588173 \n", + "claude-2 0.2 100 0.094516 3 0.572401 \n", + "gpt-3.5-turbo 0.2 100 0.106004 30 0.549613 \n", + "text-ada-001 0.2 10 0.203813 2 0.548611 \n", + "claude-2 0.2 20 0.070711 2 0.545002 \n", + "claude-instant-1 0.2 100 0.058549 5 0.540464 \n", + "gpt-3.5-turbo 0.2 50 0.099528 30 0.530513 \n", + "claude-2 0.2 10 0.034157 4 0.500668 \n", + "text-davinci-003 0.2 50 0.102078 20 
0.496884 \n", + "gpt-3.5-turbo 0.2 20 0.070974 30 0.493340 \n", + " 10 0.077028 30 0.486977 \n", + "claude-instant-1 0.2 10 0.046762 6 0.484472 \n", + " 20 0.045166 6 0.475401 \n", + "text-davinci-003 0.2 10 0.097723 27 0.465876 \n", + "claude-instant-1 0.2 5 0.041767 9 0.457298 \n", + "gpt-4 0.2 5 0.064807 9 0.454896 \n", + "gpt-3.5-turbo 0.2 5 0.044069 30 0.445453 \n", + "text-ada-001 0.2 20 0.161659 5 0.436827 \n", + "text-davinci-003 0.2 20 0.069666 21 0.427014 \n", + "text-ada-001 0.2 5 0.056818 4 0.414570 \n", + "text-davinci-003 0.2 5 0.050863 22 0.413351 \n", + "claude-2 0.2 5 0.032042 6 0.401366 \n", + "\n", + " \\\n", + " sem std count \n", + "model temperature num_support_samples \n", + "gpt-4 0.2 100 0.023022 0.056392 6 \n", + " 50 0.023879 0.071638 9 \n", + " 20 0.040028 0.120085 9 \n", + "claude-instant-1 0.2 50 0.058193 0.130123 5 \n", + "claude-2 0.2 50 0.045725 0.064664 2 \n", + "gpt-4 0.2 10 0.035674 0.107021 9 \n", + "claude-2 0.2 100 0.057546 0.099673 3 \n", + "gpt-3.5-turbo 0.2 100 0.022842 0.125112 30 \n", + "text-ada-001 0.2 10 0.138889 0.196419 2 \n", + "claude-2 0.2 20 0.054998 0.077779 2 \n", + "claude-instant-1 0.2 100 0.038299 0.085640 5 \n", + "gpt-3.5-turbo 0.2 50 0.021845 0.119650 30 \n", + "claude-2 0.2 10 0.019546 0.039092 4 \n", + "text-davinci-003 0.2 50 0.029992 0.134130 20 \n", + "gpt-3.5-turbo 0.2 20 0.016218 0.088830 30 \n", + " 10 0.018729 0.102582 30 \n", + "claude-instant-1 0.2 10 0.040393 0.098943 6 \n", + " 20 0.032761 0.080247 6 \n", + "text-davinci-003 0.2 10 0.020561 0.106837 27 \n", + "claude-instant-1 0.2 5 0.032972 0.098915 9 \n", + "gpt-4 0.2 5 0.041637 0.124912 9 \n", + "gpt-3.5-turbo 0.2 5 0.012892 0.070611 30 \n", + "text-ada-001 0.2 20 0.094146 0.210517 5 \n", + "text-davinci-003 0.2 20 0.016019 0.073410 21 \n", + "text-ada-001 0.2 5 0.020575 0.041149 4 \n", + "text-davinci-003 0.2 5 0.016146 0.075730 22 \n", + "claude-2 0.2 5 0.031696 0.077639 6 \n", + "\n", + " f1_micro \\\n", + " mean sem \n", + "model 
temperature num_support_samples \n", + "gpt-4 0.2 100 0.763333 0.018196 \n", + " 50 0.737778 0.018088 \n", + " 20 0.666667 0.029059 \n", + "claude-instant-1 0.2 50 0.632000 0.050040 \n", + "claude-2 0.2 50 0.610000 0.050000 \n", + "gpt-4 0.2 10 0.611111 0.032165 \n", + "claude-2 0.2 100 0.586667 0.054569 \n", + "gpt-3.5-turbo 0.2 100 0.573333 0.019354 \n", + "text-ada-001 0.2 10 0.555882 0.144118 \n", + "claude-2 0.2 20 0.550000 0.050000 \n", + "claude-instant-1 0.2 100 0.582367 0.026184 \n", + "gpt-3.5-turbo 0.2 50 0.556667 0.018171 \n", + "claude-2 0.2 10 0.515000 0.017078 \n", + "text-davinci-003 0.2 50 0.551000 0.022825 \n", + "gpt-3.5-turbo 0.2 20 0.528000 0.012958 \n", + " 10 0.536667 0.014063 \n", + "claude-instant-1 0.2 10 0.526667 0.019090 \n", + " 20 0.530000 0.018439 \n", + "text-davinci-003 0.2 10 0.510370 0.018807 \n", + "claude-instant-1 0.2 5 0.537778 0.013922 \n", + "gpt-4 0.2 5 0.533333 0.021602 \n", + "gpt-3.5-turbo 0.2 5 0.516000 0.008046 \n", + "text-ada-001 0.2 20 0.478958 0.082551 \n", + "text-davinci-003 0.2 20 0.493333 0.015202 \n", + "text-ada-001 0.2 5 0.522874 0.028409 \n", + "text-davinci-003 0.2 5 0.501818 0.010844 \n", + "claude-2 0.2 5 0.516667 0.013081 \n", + "\n", + " kappa \\\n", + " std count mean \n", + "model temperature num_support_samples \n", + "gpt-4 0.2 100 0.044572 6 0.526667 \n", + " 50 0.054263 9 0.475556 \n", + " 20 0.087178 9 0.333333 \n", + "claude-instant-1 0.2 50 0.111893 5 0.264000 \n", + "claude-2 0.2 50 0.070711 2 0.220000 \n", + "gpt-4 0.2 10 0.096494 9 0.222222 \n", + "claude-2 0.2 100 0.094516 3 0.173333 \n", + "gpt-3.5-turbo 0.2 100 0.106004 30 0.146667 \n", + "text-ada-001 0.2 10 0.203813 2 0.100444 \n", + "claude-2 0.2 20 0.070711 2 0.100000 \n", + "claude-instant-1 0.2 100 0.058549 5 0.161898 \n", + "gpt-3.5-turbo 0.2 50 0.099528 30 0.113333 \n", + "claude-2 0.2 10 0.034157 4 0.030000 \n", + "text-davinci-003 0.2 50 0.102078 20 0.102000 \n", + "gpt-3.5-turbo 0.2 20 0.070974 30 0.056000 \n", + " 10 0.077028 
30 0.073333 \n", + "claude-instant-1 0.2 10 0.046762 6 0.053333 \n", + " 20 0.045166 6 0.060000 \n", + "text-davinci-003 0.2 10 0.097723 27 0.020741 \n", + "claude-instant-1 0.2 5 0.041767 9 0.075556 \n", + "gpt-4 0.2 5 0.064807 9 0.066667 \n", + "gpt-3.5-turbo 0.2 5 0.044069 30 0.032000 \n", + "text-ada-001 0.2 20 0.184590 5 -0.014641 \n", + "text-davinci-003 0.2 20 0.069666 21 -0.013333 \n", + "text-ada-001 0.2 5 0.056818 4 -0.080727 \n", + "text-davinci-003 0.2 5 0.050863 22 0.003636 \n", + "claude-2 0.2 5 0.032042 6 0.033333 \n", + "\n", + " \n", + " sem std count \n", + "model temperature num_support_samples \n", + "gpt-4 0.2 100 0.036393 0.089144 6 \n", + " 50 0.036175 0.108525 9 \n", + " 20 0.058119 0.174356 9 \n", + "claude-instant-1 0.2 50 0.100080 0.223786 5 \n", + "claude-2 0.2 50 0.100000 0.141421 2 \n", + "gpt-4 0.2 10 0.064329 0.192988 9 \n", + "claude-2 0.2 100 0.109138 0.189033 3 \n", + "gpt-3.5-turbo 0.2 100 0.038707 0.212007 30 \n", + "text-ada-001 0.2 10 0.280999 0.397393 2 \n", + "claude-2 0.2 20 0.100000 0.141421 2 \n", + "claude-instant-1 0.2 100 0.052188 0.116696 5 \n", + "gpt-3.5-turbo 0.2 50 0.036342 0.199055 30 \n", + "claude-2 0.2 10 0.034157 0.068313 4 \n", + "text-davinci-003 0.2 50 0.045651 0.204157 20 \n", + "gpt-3.5-turbo 0.2 20 0.025916 0.141947 30 \n", + " 10 0.028127 0.154056 30 \n", + "claude-instant-1 0.2 10 0.038181 0.093524 6 \n", + " 20 0.036878 0.090333 6 \n", + "text-davinci-003 0.2 10 0.037614 0.195447 27 \n", + "claude-instant-1 0.2 5 0.027844 0.083533 9 \n", + "gpt-4 0.2 5 0.043205 0.129615 9 \n", + "gpt-3.5-turbo 0.2 5 0.016092 0.088138 30 \n", + "text-ada-001 0.2 20 0.142999 0.319755 5 \n", + "text-davinci-003 0.2 20 0.030405 0.139332 21 \n", + "text-ada-001 0.2 5 0.079302 0.158603 4 \n", + "text-davinci-003 0.2 5 0.021688 0.101725 22 \n", + "claude-2 0.2 5 0.026162 0.064083 6 " + ] + }, + "execution_count": 87, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "random_low_temp" + ] + }, + { 
# Macro-F1 versus number of support samples for the aggregated
# `random_low_temp` table: one line per model, with a shaded band of
# mean +/- standard error of the mean.
fig, ax = plt.subplots()

# Enumerate over every model in `df` so that a model keeps the same colour
# index ("C0", "C1", ...) regardless of which models have data in this table.
for color_idx, model_name in enumerate(df['model'].unique()):
    line_color = f"C{color_idx}"
    per_model = random_low_temp.query(f'model == "{model_name}"')
    per_model = per_model.reset_index()
    per_model = per_model.sort_values('num_support_samples')

    support_sizes = per_model['num_support_samples']
    f1_mean = per_model[('f1_macro', 'mean')]
    f1_sem = per_model[('f1_macro', 'sem')]

    ax.plot(support_sizes, f1_mean, label=model_name, color=line_color, marker='o')
    ax.fill_between(
        support_sizes,
        f1_mean - f1_sem,
        f1_mean + f1_sem,
        alpha=0.2,
        color=line_color,
    )

# The frame's y-extent comes from the per-run macro-F1 values of the full
# `df`, not from the plotted group means.
x_bounds = np.array([5, 100])
y_bounds = np.array([df['f1_macro'].min(), df['f1_macro'].max()])
range_frame(ax, x_bounds, y_bounds)

ax.set_xlabel('number of support samples')
ax.set_ylabel('macro F1 score')

# Label each line directly at its end instead of using a legend box.
matplotx.line_labels(ax)
fig.tight_layout()

fig.savefig('icl_t02_random_sampling.pdf',bbox_inches='tight')

# ### Random high temp
# Keep only the runs at temperature 0.8 that used the "random" strategy.
random_high_temp = df[(df['temperature'] == 0.8) & (df['strategy'] == "random")]
# Summarise every metric per (model, num_support_samples) group: mean,
# standard error of the mean, standard deviation and number of runs.
#
# 'strategy' is a string column and cannot be aggregated numerically. Letting
# pandas discard it implicitly triggers the "did not aggregate successfully"
# FutureWarning (announced to become an error in a later pandas release), so
# it is dropped explicitly before grouping. All numeric columns are still
# aggregated exactly as before.
random_high_temp = (
    random_high_temp
    .drop(columns=['strategy'])
    .groupby(['model', 'num_support_samples'])
    .agg(['mean', sem, 'std', 'count'])
    # Best configurations (highest mean macro F1) first.
    .sort_values(('f1_macro', 'mean'), ascending=False)
)

# Macro-F1 versus number of support samples: one line per model, with a
# shaded band of mean +/- standard error of the mean.
fig, ax = plt.subplots()

# Enumerate over every model in `df` so that a model keeps the same colour
# index ("C0", "C1", ...) regardless of which models have data in this table.
for i, model in enumerate(df['model'].unique()):
    color = f"C{i}"
    subset = (
        random_high_temp
        .query(f'model == "{model}"')
        .reset_index()
        .sort_values('num_support_samples')
    )
    n_support = subset['num_support_samples']
    f1_mean = subset[('f1_macro', 'mean')]
    f1_sem = subset[('f1_macro', 'sem')]

    ax.plot(n_support, f1_mean, label=model, color=color, marker='o')
    ax.fill_between(n_support, f1_mean - f1_sem, f1_mean + f1_sem, alpha=0.2, color=color)

# The frame's y-extent comes from the per-run macro-F1 values of the full
# `df`, not from the plotted group means.
range_frame(
    ax,
    np.array([5, 100]),
    np.array([df['f1_macro'].min(), df['f1_macro'].max()]),
)

ax.set_xlabel('number of support samples')
ax.set_ylabel('macro F1 score')

# Label each line directly at its end instead of using a legend box.
matplotx.line_labels(ax)
fig.tight_layout()

fig.savefig('icl_t08_random_sampling.pdf', bbox_inches='tight')
def get_timestr():
    """Return the current local time formatted as ``YYYY-MM-DD_HH-MM-SS``."""
    return time.strftime("%Y-%m-%d_%H-%M-%S")


def train_test(
    num_support_samples,
    strategy,
    model,
    num_test_points,
    random_state=42,
    temperature=0.8,
    max_test=5,
):
    """Run one few-shot classification experiment and pickle its report.

    The photoswitch data is labelled 1 when the E-isomer pi-pi* wavelength is
    above the median of the non-missing rows and 0 otherwise. It is split
    into train and test with stratification on that label, a
    ``FewShotClassifier`` backed by a ``HuggingFaceHub`` LLM is fitted on the
    training SMILES and used to predict the test SMILES. The evaluation
    report, extended with the run configuration, predictions and targets, is
    written to ``results/<timestamp>_huggingface_report.pkl`` and printed.

    Args:
        num_support_samples: Forwarded to the classifier as ``n_support``.
        strategy: ``Strategy`` member forwarded to the classifier; its
            ``.value`` is stored in the report.
        model: Hugging Face repo id. Must be a key of ``context_sizes``.
        num_test_points: Forwarded to ``train_test_split`` as ``test_size``.
        random_state: Seed for both the classifier and the data split.
        temperature: Sampling temperature passed to the LLM.
        max_test: Forwarded unchanged to ``FewShotClassifier``.

    Raises:
        KeyError: If ``model`` has no entry in ``context_sizes``.
    """
    llm = HuggingFaceHub(
        repo_id=model,
        model_kwargs={"temperature": temperature, "max_length": context_sizes[model]},
    )
    classifier = FewShotClassifier(
        llm,
        property_name="class of the transition wavelength",
        n_support=num_support_samples,
        strategy=strategy,
        seed=random_state,
        prefix="You are an expert chemist. ",
        max_test=max_test,
    )

    wavelength_column = "E isomer pi-pi* wavelength in nm"
    data = get_photoswitch_data()
    data = data.dropna(subset=["SMILES", wavelength_column])

    # Binarise at the median. The median is computed once here rather than
    # once per row inside a lambda; the resulting labels are identical.
    median_wavelength = data[wavelength_column].median()
    data["label"] = (data[wavelength_column] > median_wavelength).astype(int)

    data_train, data_test = train_test_split(
        data, test_size=num_test_points, stratify=data["label"], random_state=random_state
    )

    classifier.fit(data_train["SMILES"].values, data_train["label"].values)
    predictions = classifier.predict(data_test["SMILES"].values)

    report = evaluate_classification(data_test["label"].values, predictions)

    # Store the full run configuration next to the metrics so that results
    # can be grouped later without parsing file names.
    report["num_support_samples"] = num_support_samples
    report["strategy"] = strategy.value
    report["model"] = model
    report["num_test_points"] = num_test_points
    report["random_state"] = random_state

    report["predictions"] = predictions
    report["targets"] = data_test["label"].values
    report["max_test"] = max_test
    report["temperature"] = temperature

    # exist_ok avoids the check-then-create race of exists() + makedirs().
    os.makedirs("results", exist_ok=True)

    # NOTE(review): the file name has one-second resolution; two runs that
    # finish within the same second would overwrite each other. Confirm this
    # cannot happen for this backend.
    save_pickle(f"results/{get_timestr()}_huggingface_report.pkl", report)
    print(report)


if __name__ == "__main__":
    from itertools import product

    # Full sweep. product() iterates the right-most factor fastest, which is
    # the same order as the equivalent nested for-loops.
    sweep = product(
        range(5),  # seeds
        number_support_samples,
        strategies,
        models,
        [50],  # number of test points
        [0.2, 0.8],  # temperatures
        [1, 5, 10],  # max_test values
    )
    for seed, num_support_samples, strategy, model, num_test_points, temperature, max_test in sweep:
        try:
            train_test(
                num_support_samples,
                strategy,
                model,
                num_test_points,
                random_state=seed,
                temperature=temperature,
                max_test=max_test,
            )
        except Exception as e:
            # Best-effort sweep: one failing configuration must not abort the
            # remaining ones, but report which configuration failed.
            print(
                f"Run failed for model={model!r}, strategy={strategy!r}, "
                f"num_support_samples={num_support_samples}, seed={seed}, "
                f"temperature={temperature}, max_test={max_test}: {e!r}"
            )
def get_timestr():
    """Return the current local time formatted as ``YYYY-MM-DD_HH-MM-SS``."""
    return time.strftime("%Y-%m-%d_%H-%M-%S")


def train_test(
    num_support_samples,
    strategy,
    model,
    num_test_points,
    random_state=42,
    temperature=0.8,
    max_test=5,
):
    """Run one few-shot classification experiment and pickle its report.

    The photoswitch data is labelled 1 when the E-isomer pi-pi* wavelength is
    above the median of the non-missing rows and 0 otherwise. It is split
    into train and test with stratification on that label, a
    ``FewShotClassifier`` backed by an OpenAI model is fitted on the training
    SMILES and used to predict the test SMILES. The evaluation report,
    extended with the run configuration, predictions and targets, is written
    to ``results/<timestamp>_openai_report.pkl`` and printed.

    Args:
        num_support_samples: Forwarded to the classifier as ``n_support``.
        strategy: ``Strategy`` member forwarded to the classifier; its
            ``.value`` is stored in the report.
        model: Model name; must be listed in ``openai_chat_models`` or
            ``openai_llm_models``.
        num_test_points: Forwarded to ``train_test_split`` as ``test_size``.
        random_state: Seed for both the classifier and the data split.
        temperature: Sampling temperature passed to the LLM.
        max_test: Forwarded unchanged to ``FewShotClassifier``.

    Raises:
        ValueError: If ``model`` is in neither model list.
    """
    if model in openai_chat_models:
        llm = LangChainChatModelWrapper(ChatOpenAI(model=model, temperature=temperature))
    elif model in openai_llm_models:
        # The temperature must be forwarded here as well. Previously the
        # completion models ran at the library default while the report
        # recorded the requested temperature, mislabelling those runs.
        llm = OpenAI(model_name=model, temperature=temperature)
    else:
        raise ValueError(f"Unknown model {model}")

    classifier = FewShotClassifier(
        llm,
        property_name="class of the transition wavelength",
        n_support=num_support_samples,
        strategy=strategy,
        seed=random_state,
        prefix="You are an expert chemist. ",
        max_test=max_test,
    )

    wavelength_column = "E isomer pi-pi* wavelength in nm"
    data = get_photoswitch_data()
    data = data.dropna(subset=["SMILES", wavelength_column])

    # Binarise at the median. The median is computed once here rather than
    # once per row inside a lambda; the resulting labels are identical.
    median_wavelength = data[wavelength_column].median()
    data["label"] = (data[wavelength_column] > median_wavelength).astype(int)

    data_train, data_test = train_test_split(
        data, test_size=num_test_points, stratify=data["label"], random_state=random_state
    )

    classifier.fit(data_train["SMILES"].values, data_train["label"].values)
    predictions = classifier.predict(data_test["SMILES"].values)

    report = evaluate_classification(data_test["label"].values, predictions)

    # Store the full run configuration next to the metrics so that results
    # can be grouped later without parsing file names.
    report["num_support_samples"] = num_support_samples
    report["strategy"] = strategy.value
    report["model"] = model
    report["num_test_points"] = num_test_points
    report["random_state"] = random_state

    report["predictions"] = predictions
    report["targets"] = data_test["label"].values
    report["max_test"] = max_test
    report["temperature"] = temperature

    # exist_ok avoids the check-then-create race of exists() + makedirs().
    os.makedirs("results", exist_ok=True)

    # NOTE(review): the file name has one-second resolution; two runs that
    # finish within the same second (plausible when responses come from the
    # LLM cache) would overwrite each other. Confirm whether this matters.
    save_pickle(f"results/{get_timestr()}_openai_report.pkl", report)
    print(report)


if __name__ == "__main__":
    from itertools import product

    # Full sweep. product() iterates the right-most factor fastest, which is
    # the same order as the equivalent nested for-loops.
    sweep = product(
        range(5),  # seeds (offset by 34 below)
        number_support_samples,
        strategies,
        openai_models,
        [50],  # number of test points
        [0.2, 0.8],  # temperatures
        [1, 5, 10],  # max_test values
    )
    for seed, num_support_samples, strategy, model, num_test_points, temperature, max_test in sweep:
        try:
            train_test(
                num_support_samples,
                strategy,
                model,
                num_test_points,
                random_state=seed + 34,
                temperature=temperature,
                max_test=max_test,
            )
        except Exception as e:
            # Best-effort sweep: one failing configuration must not abort the
            # remaining ones, but report which configuration failed.
            print(
                f"Run failed for model={model!r}, strategy={strategy!r}, "
                f"num_support_samples={num_support_samples}, seed={seed + 34}, "
                f"temperature={temperature}, max_test={max_test}: {e!r}"
            )