From f6eca044f359c59073f22d8bef15565ed7b2ac40 Mon Sep 17 00:00:00 2001 From: Dest1n1 Date: Thu, 13 Jun 2024 11:08:59 +0800 Subject: [PATCH 1/4] refactor(config): change to compositional config --- pdm.lock | 556 +++++++++++++++++- pyproject.toml | 4 + src/lm_saes/activation/activation_dataset.py | 14 +- src/lm_saes/activation/activation_source.py | 6 +- src/lm_saes/activation/activation_store.py | 2 +- src/lm_saes/activation/token_source.py | 37 +- src/lm_saes/analysis/auto_interp.py | 8 +- src/lm_saes/analysis/features_to_logits.py | 2 +- .../analysis/sample_feature_activations.py | 30 +- src/lm_saes/config.py | 105 ++-- src/lm_saes/evals.py | 22 +- src/lm_saes/runner.py | 207 ++++--- src/lm_saes/sae_training.py | 32 +- src/lm_saes/utils/config.py | 214 +++++++ 14 files changed, 1008 insertions(+), 231 deletions(-) create mode 100644 src/lm_saes/utils/config.py diff --git a/pdm.lock b/pdm.lock index b13fe5f..a4474c2 100644 --- a/pdm.lock +++ b/pdm.lock @@ -5,7 +5,7 @@ groups = ["default", "dev"] strategy = ["cross_platform", "inherit_metadata"] lock_version = "4.4.1" -content_hash = "sha256:5266c91187a20b13682380660c9795b4ea9c2f2c2ad5370e97ab83ec920ece84" +content_hash = "sha256:c80acaac6865f83ab37f1a93ef1622a1cc51a94c3b9e7ea959d11f146d5acbf0" [[package]] name = "accelerate" @@ -102,6 +102,31 @@ files = [ {file = "anyio-4.4.0.tar.gz", hash = "sha256:5aadc6a1bbb7cdb0bede386cac5e2940f5e2ff3aa20277e991cf028e0585ce94"}, ] +[[package]] +name = "appnope" +version = "0.1.4" +requires_python = ">=3.6" +summary = "Disable App Nap on macOS >= 10.9" +groups = ["dev"] +marker = "platform_system == \"Darwin\"" +files = [ + {file = "appnope-0.1.4-py2.py3-none-any.whl", hash = "sha256:502575ee11cd7a28c0205f379b525beefebab9d161b7c964670864014ed7213c"}, + {file = "appnope-0.1.4.tar.gz", hash = "sha256:1de3860566df9caf38f01f86f65e0e13e379af54f9e4bee1e66b48f2efffd1ee"}, +] + +[[package]] +name = "asttokens" +version = "2.4.1" +summary = "Annotate AST trees with source code positions" +groups = ["dev"] +dependencies = [ + "six>=1.12.0", +] +files = [ + {file = "asttokens-2.4.1-py2.py3-none-any.whl", hash = "sha256:051ed49c3dcae8913ea7cd08e46a606dba30b79993209636c4875bc1d637bc24"}, + {file = "asttokens-2.4.1.tar.gz", hash = "sha256:b03869718ba9a6eb027e134bfdf69f38a236d681c83c160d510768af11254ba0"}, +] + [[package]] name = "async-timeout" version = "4.0.3" @@ -157,6 +182,31 @@ files = [ {file = "certifi-2024.2.2.tar.gz", hash = "sha256:0569859f95fc761b18b45ef421b1290a0f65f147e92a1e5eb3e635f9a5e4e66f"}, ] +[[package]] +name = "cffi" +version = "1.16.0" +requires_python = ">=3.8" +summary = "Foreign Function Interface for Python calling C code." 
+groups = ["dev"] +marker = "implementation_name == \"pypy\"" +dependencies = [ + "pycparser", +] +files = [ + {file = "cffi-1.16.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:6b3d6606d369fc1da4fd8c357d026317fbb9c9b75d36dc16e90e84c26854b088"}, + {file = "cffi-1.16.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:ac0f5edd2360eea2f1daa9e26a41db02dd4b0451b48f7c318e217ee092a213e9"}, + {file = "cffi-1.16.0-cp310-cp310-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:7e61e3e4fa664a8588aa25c883eab612a188c725755afff6289454d6362b9673"}, + {file = "cffi-1.16.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a72e8961a86d19bdb45851d8f1f08b041ea37d2bd8d4fd19903bc3083d80c896"}, + {file = "cffi-1.16.0-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:5b50bf3f55561dac5438f8e70bfcdfd74543fd60df5fa5f62d94e5867deca684"}, + {file = "cffi-1.16.0-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:7651c50c8c5ef7bdb41108b7b8c5a83013bfaa8a935590c5d74627c047a583c7"}, + {file = "cffi-1.16.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e4108df7fe9b707191e55f33efbcb2d81928e10cea45527879a4749cbe472614"}, + {file = "cffi-1.16.0-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:32c68ef735dbe5857c810328cb2481e24722a59a2003018885514d4c09af9743"}, + {file = "cffi-1.16.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:673739cb539f8cdaa07d92d02efa93c9ccf87e345b9a0b556e3ecc666718468d"}, + {file = "cffi-1.16.0-cp310-cp310-win32.whl", hash = "sha256:9f90389693731ff1f659e55c7d1640e2ec43ff725cc61b04b2f9c6d8d017df6a"}, + {file = "cffi-1.16.0-cp310-cp310-win_amd64.whl", hash = "sha256:e6024675e67af929088fda399b2094574609396b1decb609c55fa58b028a32a1"}, + {file = "cffi-1.16.0.tar.gz", hash = "sha256:bcb3ef43e58665bbda2fb198698fcae6776483e0c4a631aa5647806c25e02cc0"}, +] + [[package]] name = "charset-normalizer" version = "3.3.2" @@ -209,6 +259,20 @@ files = [ {file = "colorama-0.4.6.tar.gz", hash = "sha256:08695f5cb7ed6e0531a20572697297273c47b8cae5a63ffc6d6ed5c201be6e44"}, ] +[[package]] +name = "comm" +version = "0.2.2" +requires_python = ">=3.8" +summary = "Jupyter Python Comm implementation, for usage in ipykernel, xeus-python etc." 
+groups = ["dev"] +dependencies = [ + "traitlets>=4", +] +files = [ + {file = "comm-0.2.2-py3-none-any.whl", hash = "sha256:e6fb86cb70ff661ee8c9c14e7d36d6de3b4066f1441be4063df9c5009f0a64d3"}, + {file = "comm-0.2.2.tar.gz", hash = "sha256:3fd7a84065306e07bea1773df6eb8282de51ba82f77c72f9c85716ab11fe980e"}, +] + [[package]] name = "contourpy" version = "1.2.1" @@ -274,6 +338,32 @@ files = [ {file = "datasets-2.19.1.tar.gz", hash = "sha256:0df9ef6c5e9138cdb996a07385220109ff203c204245578b69cca905eb151d3a"}, ] +[[package]] +name = "debugpy" +version = "1.8.1" +requires_python = ">=3.8" +summary = "An implementation of the Debug Adapter Protocol for Python" +groups = ["dev"] +files = [ + {file = "debugpy-1.8.1-cp310-cp310-macosx_11_0_x86_64.whl", hash = "sha256:3bda0f1e943d386cc7a0e71bfa59f4137909e2ed947fb3946c506e113000f741"}, + {file = "debugpy-1.8.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:dda73bf69ea479c8577a0448f8c707691152e6c4de7f0c4dec5a4bc11dee516e"}, + {file = "debugpy-1.8.1-cp310-cp310-win32.whl", hash = "sha256:3a79c6f62adef994b2dbe9fc2cc9cc3864a23575b6e387339ab739873bea53d0"}, + {file = "debugpy-1.8.1-cp310-cp310-win_amd64.whl", hash = "sha256:7eb7bd2b56ea3bedb009616d9e2f64aab8fc7000d481faec3cd26c98a964bcdd"}, + {file = "debugpy-1.8.1-py2.py3-none-any.whl", hash = "sha256:28acbe2241222b87e255260c76741e1fbf04fdc3b6d094fcf57b6c6f75ce1242"}, + {file = "debugpy-1.8.1.zip", hash = "sha256:f696d6be15be87aef621917585f9bb94b1dc9e8aced570db1b8a6fc14e8f9b42"}, +] + +[[package]] +name = "decorator" +version = "5.1.1" +requires_python = ">=3.5" +summary = "Decorators for Humans" +groups = ["dev"] +files = [ + {file = "decorator-5.1.1-py3-none-any.whl", hash = "sha256:b8c3f85900b9dc423225913c5aace94729fe1fa9763b38939a95226f02d37186"}, + {file = "decorator-5.1.1.tar.gz", hash = "sha256:637996211036b6385ef91435e4fae22989472f9d571faba8927ba8253acbc330"}, +] + [[package]] name = "dill" version = "0.3.8" @@ -351,13 +441,24 @@ name = "exceptiongroup" version = "1.2.1" requires_python = ">=3.7" summary = "Backport of PEP 654 (exception groups)" -groups = ["default"] +groups = ["default", "dev"] marker = "python_version < \"3.11\"" files = [ {file = "exceptiongroup-1.2.1-py3-none-any.whl", hash = "sha256:5258b9ed329c5bbdd31a309f53cbfb0b155341807f6ff7606a1e801a891b29ad"}, {file = "exceptiongroup-1.2.1.tar.gz", hash = "sha256:a4785e48b045528f5bfe627b6ad554ff32def154f42372786903b7abcfe1aa16"}, ] +[[package]] +name = "executing" +version = "2.0.1" +requires_python = ">=3.5" +summary = "Get the currently executing AST node of a frame, and other information" +groups = ["dev"] +files = [ + {file = "executing-2.0.1-py2.py3-none-any.whl", hash = "sha256:eac49ca94516ccc753f9fb5ce82603156e590b27525a8bc32cce8ae302eb61bc"}, + {file = "executing-2.0.1.tar.gz", hash = "sha256:35afe2ce3affba8ee97f2d69927fa823b08b472b7b994e36a52a964b93d16147"}, +] + [[package]] name = "fancy-einsum" version = "0.0.3" @@ -407,6 +508,16 @@ files = [ {file = "fastapi_cli-0.0.4.tar.gz", hash = "sha256:e2e9ffaffc1f7767f488d6da34b6f5a377751c996f397902eb6abb99a67bde32"}, ] +[[package]] +name = "fastjsonschema" +version = "2.19.1" +summary = "Fastest Python implementation of JSON schema" +groups = ["dev"] +files = [ + {file = "fastjsonschema-2.19.1-py3-none-any.whl", hash = "sha256:3672b47bc94178c9f23dbb654bf47440155d4db9df5f7bc47643315f9c405cd0"}, + {file = "fastjsonschema-2.19.1.tar.gz", hash = "sha256:e3126a94bdc4623d3de4485f8d468a12f02a67921315ddc87836d6e456dc789d"}, +] + [[package]] name = "filelock" 
version = "3.14.0" @@ -635,6 +746,56 @@ files = [ {file = "intel_openmp-2021.4.0-py2.py3-none-win_amd64.whl", hash = "sha256:eef4c8bcc8acefd7f5cd3b9384dbf73d59e2c99fc56545712ded913f43c4a94f"}, ] +[[package]] +name = "ipykernel" +version = "6.29.4" +requires_python = ">=3.8" +summary = "IPython Kernel for Jupyter" +groups = ["dev"] +dependencies = [ + "appnope; platform_system == \"Darwin\"", + "comm>=0.1.1", + "debugpy>=1.6.5", + "ipython>=7.23.1", + "jupyter-client>=6.1.12", + "jupyter-core!=5.0.*,>=4.12", + "matplotlib-inline>=0.1", + "nest-asyncio", + "packaging", + "psutil", + "pyzmq>=24", + "tornado>=6.1", + "traitlets>=5.4.0", +] +files = [ + {file = "ipykernel-6.29.4-py3-none-any.whl", hash = "sha256:1181e653d95c6808039c509ef8e67c4126b3b3af7781496c7cbfb5ed938a27da"}, + {file = "ipykernel-6.29.4.tar.gz", hash = "sha256:3d44070060f9475ac2092b760123fadf105d2e2493c24848b6691a7c4f42af5c"}, +] + +[[package]] +name = "ipython" +version = "8.25.0" +requires_python = ">=3.10" +summary = "IPython: Productive Interactive Computing" +groups = ["dev"] +dependencies = [ + "colorama; sys_platform == \"win32\"", + "decorator", + "exceptiongroup; python_version < \"3.11\"", + "jedi>=0.16", + "matplotlib-inline", + "pexpect>4.3; sys_platform != \"win32\" and sys_platform != \"emscripten\"", + "prompt-toolkit<3.1.0,>=3.0.41", + "pygments>=2.4.0", + "stack-data", + "traitlets>=5.13.0", + "typing-extensions>=4.6; python_version < \"3.12\"", +] +files = [ + {file = "ipython-8.25.0-py3-none-any.whl", hash = "sha256:53eee7ad44df903a06655871cbab66d156a051fd86f3ec6750470ac9604ac1ab"}, + {file = "ipython-8.25.0.tar.gz", hash = "sha256:c6ed726a140b6e725b911528f80439c534fac915246af3efc39440a6b0f9d716"}, +] + [[package]] name = "jaxtyping" version = "0.2.29" @@ -649,6 +810,20 @@ files = [ {file = "jaxtyping-0.2.29.tar.gz", hash = "sha256:e1cd916ed0196e40402b0638449e7d051571562b2cd68d8b94961a383faeb409"}, ] +[[package]] +name = "jedi" +version = "0.19.1" +requires_python = ">=3.6" +summary = "An autocompletion tool for Python that can be used for text editors." 
+groups = ["dev"] +dependencies = [ + "parso<0.9.0,>=0.8.3", +] +files = [ + {file = "jedi-0.19.1-py2.py3-none-any.whl", hash = "sha256:e983c654fe5c02867aef4cdfce5a2fbb4a50adc0af145f70504238f18ef5e7e0"}, + {file = "jedi-0.19.1.tar.gz", hash = "sha256:cf0496f3651bc65d7174ac1b7d043eff454892c708a87d1b683e57b569927ffd"}, +] + [[package]] name = "jinja2" version = "3.1.4" @@ -663,6 +838,85 @@ files = [ {file = "jinja2-3.1.4.tar.gz", hash = "sha256:4a3aee7acbbe7303aede8e9648d13b8bf88a429282aa6122a993f0ac800cb369"}, ] +[[package]] +name = "jsonschema" +version = "4.22.0" +requires_python = ">=3.8" +summary = "An implementation of JSON Schema validation for Python" +groups = ["dev"] +dependencies = [ + "attrs>=22.2.0", + "jsonschema-specifications>=2023.03.6", + "referencing>=0.28.4", + "rpds-py>=0.7.1", +] +files = [ + {file = "jsonschema-4.22.0-py3-none-any.whl", hash = "sha256:ff4cfd6b1367a40e7bc6411caec72effadd3db0bbe5017de188f2d6108335802"}, + {file = "jsonschema-4.22.0.tar.gz", hash = "sha256:5b22d434a45935119af990552c862e5d6d564e8f6601206b305a61fdf661a2b7"}, +] + +[[package]] +name = "jsonschema-specifications" +version = "2023.12.1" +requires_python = ">=3.8" +summary = "The JSON Schema meta-schemas and vocabularies, exposed as a Registry" +groups = ["dev"] +dependencies = [ + "referencing>=0.31.0", +] +files = [ + {file = "jsonschema_specifications-2023.12.1-py3-none-any.whl", hash = "sha256:87e4fdf3a94858b8a2ba2778d9ba57d8a9cafca7c7489c46ba0d30a8bc6a9c3c"}, + {file = "jsonschema_specifications-2023.12.1.tar.gz", hash = "sha256:48a76787b3e70f5ed53f1160d2b81f586e4ca6d1548c5de7085d1682674764cc"}, +] + +[[package]] +name = "jupyter-client" +version = "8.6.2" +requires_python = ">=3.8" +summary = "Jupyter protocol implementation and client libraries" +groups = ["dev"] +dependencies = [ + "jupyter-core!=5.0.*,>=4.12", + "python-dateutil>=2.8.2", + "pyzmq>=23.0", + "tornado>=6.2", + "traitlets>=5.3", +] +files = [ + {file = "jupyter_client-8.6.2-py3-none-any.whl", hash = "sha256:50cbc5c66fd1b8f65ecb66bc490ab73217993632809b6e505687de18e9dea39f"}, + {file = "jupyter_client-8.6.2.tar.gz", hash = "sha256:2bda14d55ee5ba58552a8c53ae43d215ad9868853489213f37da060ced54d8df"}, +] + +[[package]] +name = "jupyter-core" +version = "5.7.2" +requires_python = ">=3.8" +summary = "Jupyter core package. A base package on which Jupyter projects rely." 
+groups = ["dev"] +dependencies = [ + "platformdirs>=2.5", + "pywin32>=300; sys_platform == \"win32\" and platform_python_implementation != \"PyPy\"", + "traitlets>=5.3", +] +files = [ + {file = "jupyter_core-5.7.2-py3-none-any.whl", hash = "sha256:4f7315d2f6b4bcf2e3e7cb6e46772eba760ae459cd1f59d29eb57b0a01bd7409"}, + {file = "jupyter_core-5.7.2.tar.gz", hash = "sha256:aa5f8d32bbf6b431ac830496da7392035d6f61b4f54872f15c4bd2a9c3f536d9"}, +] + +[[package]] +name = "kaleido" +version = "0.2.1" +summary = "Static image export for web-based visualization libraries with zero dependencies" +groups = ["dev"] +files = [ + {file = "kaleido-0.2.1-py2.py3-none-macosx_10_11_x86_64.whl", hash = "sha256:ca6f73e7ff00aaebf2843f73f1d3bacde1930ef5041093fe76b83a15785049a7"}, + {file = "kaleido-0.2.1-py2.py3-none-macosx_11_0_arm64.whl", hash = "sha256:bb9a5d1f710357d5d432ee240ef6658a6d124c3e610935817b4b42da9c787c05"}, + {file = "kaleido-0.2.1-py2.py3-none-manylinux1_x86_64.whl", hash = "sha256:aa21cf1bf1c78f8fa50a9f7d45e1003c387bd3d6fe0a767cfbbf344b95bdc3a8"}, + {file = "kaleido-0.2.1-py2.py3-none-manylinux2014_aarch64.whl", hash = "sha256:845819844c8082c9469d9c17e42621fbf85c2b237ef8a86ec8a8527f98b6512a"}, + {file = "kaleido-0.2.1-py2.py3-none-win32.whl", hash = "sha256:ecc72635860be616c6b7161807a65c0dbd9b90c6437ac96965831e2e24066552"}, + {file = "kaleido-0.2.1-py2.py3-none-win_amd64.whl", hash = "sha256:4670985f28913c2d063c5734d125ecc28e40810141bdb0a46f15b76c1d45f23c"}, +] + [[package]] name = "kiwisolver" version = "1.4.5" @@ -768,6 +1022,20 @@ files = [ {file = "matplotlib-3.9.0.tar.gz", hash = "sha256:e6d29ea6c19e34b30fb7d88b7081f869a03014f66fe06d62cc77d5a6ea88ed7a"}, ] +[[package]] +name = "matplotlib-inline" +version = "0.1.7" +requires_python = ">=3.8" +summary = "Inline Matplotlib backend for Jupyter" +groups = ["dev"] +dependencies = [ + "traitlets", +] +files = [ + {file = "matplotlib_inline-0.1.7-py3-none-any.whl", hash = "sha256:df192d39a4ff8f21b1895d72e6a13f5fcc5099f00fa84384e0ea28c2cc0653ca"}, + {file = "matplotlib_inline-0.1.7.tar.gz", hash = "sha256:8423b23ec666be3d16e16b60bdd8ac4e86e840ebd1dd11a30b9f117f2fa0ab90"}, +] + [[package]] name = "mdurl" version = "0.1.2" @@ -906,6 +1174,34 @@ files = [ {file = "mypy_extensions-1.0.0.tar.gz", hash = "sha256:75dbf8955dc00442a438fc4d0666508a9a97b6bd41aa2f0ffe9d2f2725af0782"}, ] +[[package]] +name = "nbformat" +version = "5.10.4" +requires_python = ">=3.8" +summary = "The Jupyter Notebook format" +groups = ["dev"] +dependencies = [ + "fastjsonschema>=2.15", + "jsonschema>=2.6", + "jupyter-core!=5.0.*,>=4.12", + "traitlets>=5.1", +] +files = [ + {file = "nbformat-5.10.4-py3-none-any.whl", hash = "sha256:3b48d6c8fbca4b299bf3982ea7db1af21580e4fec269ad087b9e81588891200b"}, + {file = "nbformat-5.10.4.tar.gz", hash = "sha256:322168b14f937a5d11362988ecac2a4952d3d8e3a2cbeb2319584631226d5b3a"}, +] + +[[package]] +name = "nest-asyncio" +version = "1.6.0" +requires_python = ">=3.5" +summary = "Patch asyncio to allow nested event loops" +groups = ["dev"] +files = [ + {file = "nest_asyncio-1.6.0-py3-none-any.whl", hash = "sha256:87af6efd6b5e897c81050477ef65c62e2b2f35d51703cae01aff2905b1852e1c"}, + {file = "nest_asyncio-1.6.0.tar.gz", hash = "sha256:6f172d5449aca15afd6c646851f4e31e02c598d553a667e38cafa997cfec55fe"}, +] + [[package]] name = "networkx" version = "3.1" @@ -1165,6 +1461,31 @@ files = [ {file = "pandas-2.2.2.tar.gz", hash = "sha256:9e79019aba43cb4fda9e4d983f8e88ca0373adbb697ae9c6c43093218de28b54"}, ] +[[package]] +name = "parso" +version = "0.8.4" 
+requires_python = ">=3.6" +summary = "A Python Parser" +groups = ["dev"] +files = [ + {file = "parso-0.8.4-py2.py3-none-any.whl", hash = "sha256:a418670a20291dacd2dddc80c377c5c3791378ee1e8d12bffc35420643d43f18"}, + {file = "parso-0.8.4.tar.gz", hash = "sha256:eb3a7b58240fb99099a345571deecc0f9540ea5f4dd2fe14c2a99d6b281ab92d"}, +] + +[[package]] +name = "pexpect" +version = "4.9.0" +summary = "Pexpect allows easy control of interactive console applications." +groups = ["dev"] +marker = "sys_platform != \"win32\" and sys_platform != \"emscripten\"" +dependencies = [ + "ptyprocess>=0.5", +] +files = [ + {file = "pexpect-4.9.0-py2.py3-none-any.whl", hash = "sha256:7236d1e080e4936be2dc3e326cec0af72acf9212a7e1d060210e70a47e253523"}, + {file = "pexpect-4.9.0.tar.gz", hash = "sha256:ee7d41123f3c9911050ea2c2dac107568dc43b2d3b0c7557a33212c398ead30f"}, +] + [[package]] name = "pillow" version = "10.3.0" @@ -1237,6 +1558,20 @@ files = [ {file = "pluggy-1.5.0.tar.gz", hash = "sha256:2cffa88e94fdc978c4c574f15f9e59b7f4201d439195c3715ca9e2486f1d0cf1"}, ] +[[package]] +name = "prompt-toolkit" +version = "3.0.47" +requires_python = ">=3.7.0" +summary = "Library for building powerful interactive command lines in Python" +groups = ["dev"] +dependencies = [ + "wcwidth", +] +files = [ + {file = "prompt_toolkit-3.0.47-py3-none-any.whl", hash = "sha256:0d7bfa67001d5e39d02c224b663abc33687405033a8c422d0d675a5a13361d10"}, + {file = "prompt_toolkit-3.0.47.tar.gz", hash = "sha256:1e1b29cb58080b1e69f207c893a1a7bf16d127a5c30c9d17a25a5d77792e5360"}, +] + [[package]] name = "protobuf" version = "4.25.3" @@ -1269,6 +1604,27 @@ files = [ {file = "psutil-5.9.8.tar.gz", hash = "sha256:6be126e3225486dff286a8fb9a06246a5253f4c7c53b475ea5f5ac934e64194c"}, ] +[[package]] +name = "ptyprocess" +version = "0.7.0" +summary = "Run a subprocess in a pseudo terminal" +groups = ["dev"] +marker = "sys_platform != \"win32\" and sys_platform != \"emscripten\"" +files = [ + {file = "ptyprocess-0.7.0-py2.py3-none-any.whl", hash = "sha256:4b41f3967fce3af57cc7e94b888626c18bf37a083e3651ca8feeb66d492fef35"}, + {file = "ptyprocess-0.7.0.tar.gz", hash = "sha256:5c5d0a3b48ceee0b48485e0c26037c0acd7d29765ca3fbb5cb3831d347423220"}, +] + +[[package]] +name = "pure-eval" +version = "0.2.2" +summary = "Safely evaluate AST nodes without side effects" +groups = ["dev"] +files = [ + {file = "pure_eval-0.2.2-py3-none-any.whl", hash = "sha256:01eaab343580944bc56080ebe0a674b39ec44a945e6d09ba7db3cb8cec289350"}, + {file = "pure_eval-0.2.2.tar.gz", hash = "sha256:2b45320af6dfaa1750f543d714b6d1c520a1688dec6fd24d339063ce0aaa9ac3"}, +] + [[package]] name = "pyarrow" version = "16.1.0" @@ -1300,6 +1656,18 @@ files = [ {file = "pyarrow_hotfix-0.6.tar.gz", hash = "sha256:79d3e030f7ff890d408a100ac16d6f00b14d44a502d7897cd9fc3e3a534e9945"}, ] +[[package]] +name = "pycparser" +version = "2.22" +requires_python = ">=3.8" +summary = "C parser in Python" +groups = ["dev"] +marker = "implementation_name == \"pypy\"" +files = [ + {file = "pycparser-2.22-py3-none-any.whl", hash = "sha256:c3702b6d3dd8c7abc1afa565d7e63d53a1d0bd86cdc24edd75470f4de499cfcc"}, + {file = "pycparser-2.22.tar.gz", hash = "sha256:491c8be9c040f5390f5bf44a5b07752bd07f56edf992381b05c701439eec10f6"}, +] + [[package]] name = "pydantic" version = "2.7.3" @@ -1466,6 +1834,17 @@ files = [ {file = "pytz-2024.1.tar.gz", hash = "sha256:2a29735ea9c18baf14b448846bde5a48030ed267578472d8955cd0e7443a9812"}, ] +[[package]] +name = "pywin32" +version = "306" +summary = "Python for Window Extensions" +groups = ["dev"] +marker 
= "sys_platform == \"win32\" and platform_python_implementation != \"PyPy\"" +files = [ + {file = "pywin32-306-cp310-cp310-win32.whl", hash = "sha256:06d3420a5155ba65f0b72f2699b5bacf3109f36acbe8923765c22938a69dfc8d"}, + {file = "pywin32-306-cp310-cp310-win_amd64.whl", hash = "sha256:84f4471dbca1887ea3803d8848a1616429ac94a4a8d05f4bc9c5dcfd42ca99c8"}, +] + [[package]] name = "pyyaml" version = "6.0.1" @@ -1484,6 +1863,67 @@ files = [ {file = "PyYAML-6.0.1.tar.gz", hash = "sha256:bfdf460b1736c775f2ba9f6a92bca30bc2095067b8a9d77876d1fad6cc3b4a43"}, ] +[[package]] +name = "pyzmq" +version = "26.0.3" +requires_python = ">=3.7" +summary = "Python bindings for 0MQ" +groups = ["dev"] +dependencies = [ + "cffi; implementation_name == \"pypy\"", +] +files = [ + {file = "pyzmq-26.0.3-cp310-cp310-macosx_10_15_universal2.whl", hash = "sha256:44dd6fc3034f1eaa72ece33588867df9e006a7303725a12d64c3dff92330f625"}, + {file = "pyzmq-26.0.3-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:acb704195a71ac5ea5ecf2811c9ee19ecdc62b91878528302dd0be1b9451cc90"}, + {file = "pyzmq-26.0.3-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5dbb9c997932473a27afa93954bb77a9f9b786b4ccf718d903f35da3232317de"}, + {file = "pyzmq-26.0.3-cp310-cp310-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:6bcb34f869d431799c3ee7d516554797f7760cb2198ecaa89c3f176f72d062be"}, + {file = "pyzmq-26.0.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:38ece17ec5f20d7d9b442e5174ae9f020365d01ba7c112205a4d59cf19dc38ee"}, + {file = "pyzmq-26.0.3-cp310-cp310-manylinux_2_28_x86_64.whl", hash = "sha256:ba6e5e6588e49139a0979d03a7deb9c734bde647b9a8808f26acf9c547cab1bf"}, + {file = "pyzmq-26.0.3-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:3bf8b000a4e2967e6dfdd8656cd0757d18c7e5ce3d16339e550bd462f4857e59"}, + {file = "pyzmq-26.0.3-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:2136f64fbb86451dbbf70223635a468272dd20075f988a102bf8a3f194a411dc"}, + {file = "pyzmq-26.0.3-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:e8918973fbd34e7814f59143c5f600ecd38b8038161239fd1a3d33d5817a38b8"}, + {file = "pyzmq-26.0.3-cp310-cp310-win32.whl", hash = "sha256:0aaf982e68a7ac284377d051c742610220fd06d330dcd4c4dbb4cdd77c22a537"}, + {file = "pyzmq-26.0.3-cp310-cp310-win_amd64.whl", hash = "sha256:f1a9b7d00fdf60b4039f4455afd031fe85ee8305b019334b72dcf73c567edc47"}, + {file = "pyzmq-26.0.3-cp310-cp310-win_arm64.whl", hash = "sha256:80b12f25d805a919d53efc0a5ad7c0c0326f13b4eae981a5d7b7cc343318ebb7"}, + {file = "pyzmq-26.0.3-pp310-pypy310_pp73-macosx_10_9_x86_64.whl", hash = "sha256:2c18645ef6294d99b256806e34653e86236eb266278c8ec8112622b61db255de"}, + {file = "pyzmq-26.0.3-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7e6bc96ebe49604df3ec2c6389cc3876cabe475e6bfc84ced1bf4e630662cb35"}, + {file = "pyzmq-26.0.3-pp310-pypy310_pp73-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:971e8990c5cc4ddcff26e149398fc7b0f6a042306e82500f5e8db3b10ce69f84"}, + {file = "pyzmq-26.0.3-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d8416c23161abd94cc7da80c734ad7c9f5dbebdadfdaa77dad78244457448223"}, + {file = "pyzmq-26.0.3-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:082a2988364b60bb5de809373098361cf1dbb239623e39e46cb18bc035ed9c0c"}, + {file = "pyzmq-26.0.3-pp37-pypy37_pp73-macosx_10_9_x86_64.whl", hash = "sha256:d57dfbf9737763b3a60d26e6800e02e04284926329aee8fb01049635e957fe81"}, + {file = 
"pyzmq-26.0.3-pp37-pypy37_pp73-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:77a85dca4c2430ac04dc2a2185c2deb3858a34fe7f403d0a946fa56970cf60a1"}, + {file = "pyzmq-26.0.3-pp37-pypy37_pp73-manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:4c82a6d952a1d555bf4be42b6532927d2a5686dd3c3e280e5f63225ab47ac1f5"}, + {file = "pyzmq-26.0.3-pp37-pypy37_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4496b1282c70c442809fc1b151977c3d967bfb33e4e17cedbf226d97de18f709"}, + {file = "pyzmq-26.0.3-pp37-pypy37_pp73-win_amd64.whl", hash = "sha256:e4946d6bdb7ba972dfda282f9127e5756d4f299028b1566d1245fa0d438847e6"}, + {file = "pyzmq-26.0.3-pp38-pypy38_pp73-macosx_10_9_x86_64.whl", hash = "sha256:03c0ae165e700364b266876d712acb1ac02693acd920afa67da2ebb91a0b3c09"}, + {file = "pyzmq-26.0.3-pp38-pypy38_pp73-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:3e3070e680f79887d60feeda051a58d0ac36622e1759f305a41059eff62c6da7"}, + {file = "pyzmq-26.0.3-pp38-pypy38_pp73-manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:6ca08b840fe95d1c2bd9ab92dac5685f949fc6f9ae820ec16193e5ddf603c3b2"}, + {file = "pyzmq-26.0.3-pp38-pypy38_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e76654e9dbfb835b3518f9938e565c7806976c07b37c33526b574cc1a1050480"}, + {file = "pyzmq-26.0.3-pp38-pypy38_pp73-win_amd64.whl", hash = "sha256:871587bdadd1075b112e697173e946a07d722459d20716ceb3d1bd6c64bd08ce"}, + {file = "pyzmq-26.0.3-pp39-pypy39_pp73-macosx_10_9_x86_64.whl", hash = "sha256:d0a2d1bd63a4ad79483049b26514e70fa618ce6115220da9efdff63688808b17"}, + {file = "pyzmq-26.0.3-pp39-pypy39_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0270b49b6847f0d106d64b5086e9ad5dc8a902413b5dbbb15d12b60f9c1747a4"}, + {file = "pyzmq-26.0.3-pp39-pypy39_pp73-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:703c60b9910488d3d0954ca585c34f541e506a091a41930e663a098d3b794c67"}, + {file = "pyzmq-26.0.3-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:74423631b6be371edfbf7eabb02ab995c2563fee60a80a30829176842e71722a"}, + {file = "pyzmq-26.0.3-pp39-pypy39_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:4adfbb5451196842a88fda3612e2c0414134874bffb1c2ce83ab4242ec9e027d"}, + {file = "pyzmq-26.0.3-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:3516119f4f9b8671083a70b6afaa0a070f5683e431ab3dc26e9215620d7ca1ad"}, + {file = "pyzmq-26.0.3.tar.gz", hash = "sha256:dba7d9f2e047dfa2bca3b01f4f84aa5246725203d6284e3790f2ca15fba6b40a"}, +] + +[[package]] +name = "referencing" +version = "0.35.1" +requires_python = ">=3.8" +summary = "JSON Referencing + Python" +groups = ["dev"] +dependencies = [ + "attrs>=22.2.0", + "rpds-py>=0.7.0", +] +files = [ + {file = "referencing-0.35.1-py3-none-any.whl", hash = "sha256:eda6d3234d62814d1c64e305c1331c9a3a6132da475ab6382eaa997b21ee75de"}, + {file = "referencing-0.35.1.tar.gz", hash = "sha256:25b42124a6c8b632a425174f24087783efb348a6f1e0008e63cd4466fedf703c"}, +] + [[package]] name = "regex" version = "2024.5.15" @@ -1542,6 +1982,62 @@ files = [ {file = "rich-13.7.1.tar.gz", hash = "sha256:9be308cb1fe2f1f57d67ce99e95af38a1e2bc71ad9813b0e247cf7ffbcc3a432"}, ] +[[package]] +name = "rpds-py" +version = "0.18.1" +requires_python = ">=3.8" +summary = "Python bindings to Rust's persistent data structures (rpds)" +groups = ["dev"] +files = [ + {file = "rpds_py-0.18.1-cp310-cp310-macosx_10_12_x86_64.whl", hash = "sha256:d31dea506d718693b6b2cffc0648a8929bdc51c70a311b2770f09611caa10d53"}, + {file = 
"rpds_py-0.18.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:732672fbc449bab754e0b15356c077cc31566df874964d4801ab14f71951ea80"}, + {file = "rpds_py-0.18.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4a98a1f0552b5f227a3d6422dbd61bc6f30db170939bd87ed14f3c339aa6c7c9"}, + {file = "rpds_py-0.18.1-cp310-cp310-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:7f1944ce16401aad1e3f7d312247b3d5de7981f634dc9dfe90da72b87d37887d"}, + {file = "rpds_py-0.18.1-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:38e14fb4e370885c4ecd734f093a2225ee52dc384b86fa55fe3f74638b2cfb09"}, + {file = "rpds_py-0.18.1-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:08d74b184f9ab6289b87b19fe6a6d1a97fbfea84b8a3e745e87a5de3029bf944"}, + {file = "rpds_py-0.18.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d70129cef4a8d979caa37e7fe957202e7eee8ea02c5e16455bc9808a59c6b2f0"}, + {file = "rpds_py-0.18.1-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:ce0bb20e3a11bd04461324a6a798af34d503f8d6f1aa3d2aa8901ceaf039176d"}, + {file = "rpds_py-0.18.1-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:81c5196a790032e0fc2464c0b4ab95f8610f96f1f2fa3d4deacce6a79852da60"}, + {file = "rpds_py-0.18.1-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:f3027be483868c99b4985fda802a57a67fdf30c5d9a50338d9db646d590198da"}, + {file = "rpds_py-0.18.1-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:d44607f98caa2961bab4fa3c4309724b185b464cdc3ba6f3d7340bac3ec97cc1"}, + {file = "rpds_py-0.18.1-cp310-none-win32.whl", hash = "sha256:c273e795e7a0f1fddd46e1e3cb8be15634c29ae8ff31c196debb620e1edb9333"}, + {file = "rpds_py-0.18.1-cp310-none-win_amd64.whl", hash = "sha256:8352f48d511de5f973e4f2f9412736d7dea76c69faa6d36bcf885b50c758ab9a"}, + {file = "rpds_py-0.18.1-pp310-pypy310_pp73-macosx_10_12_x86_64.whl", hash = "sha256:cbfbea39ba64f5e53ae2915de36f130588bba71245b418060ec3330ebf85678e"}, + {file = "rpds_py-0.18.1-pp310-pypy310_pp73-macosx_11_0_arm64.whl", hash = "sha256:a3d456ff2a6a4d2adcdf3c1c960a36f4fd2fec6e3b4902a42a384d17cf4e7a65"}, + {file = "rpds_py-0.18.1-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7700936ef9d006b7ef605dc53aa364da2de5a3aa65516a1f3ce73bf82ecfc7ae"}, + {file = "rpds_py-0.18.1-pp310-pypy310_pp73-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:51584acc5916212e1bf45edd17f3a6b05fe0cbb40482d25e619f824dccb679de"}, + {file = "rpds_py-0.18.1-pp310-pypy310_pp73-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:942695a206a58d2575033ff1e42b12b2aece98d6003c6bc739fbf33d1773b12f"}, + {file = "rpds_py-0.18.1-pp310-pypy310_pp73-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:b906b5f58892813e5ba5c6056d6a5ad08f358ba49f046d910ad992196ea61397"}, + {file = "rpds_py-0.18.1-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f6f8e3fecca256fefc91bb6765a693d96692459d7d4c644660a9fff32e517843"}, + {file = "rpds_py-0.18.1-pp310-pypy310_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:7732770412bab81c5a9f6d20aeb60ae943a9b36dcd990d876a773526468e7163"}, + {file = "rpds_py-0.18.1-pp310-pypy310_pp73-musllinux_1_2_aarch64.whl", hash = "sha256:bd1105b50ede37461c1d51b9698c4f4be6e13e69a908ab7751e3807985fc0346"}, + {file = "rpds_py-0.18.1-pp310-pypy310_pp73-musllinux_1_2_i686.whl", hash = "sha256:618916f5535784960f3ecf8111581f4ad31d347c3de66d02e728de460a46303c"}, + {file = 
"rpds_py-0.18.1-pp310-pypy310_pp73-musllinux_1_2_x86_64.whl", hash = "sha256:17c6d2155e2423f7e79e3bb18151c686d40db42d8645e7977442170c360194d4"}, + {file = "rpds_py-0.18.1-pp38-pypy38_pp73-macosx_10_12_x86_64.whl", hash = "sha256:6c4c4c3f878df21faf5fac86eda32671c27889e13570645a9eea0a1abdd50922"}, + {file = "rpds_py-0.18.1-pp38-pypy38_pp73-macosx_11_0_arm64.whl", hash = "sha256:fab6ce90574645a0d6c58890e9bcaac8d94dff54fb51c69e5522a7358b80ab64"}, + {file = "rpds_py-0.18.1-pp38-pypy38_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:531796fb842b53f2695e94dc338929e9f9dbf473b64710c28af5a160b2a8927d"}, + {file = "rpds_py-0.18.1-pp38-pypy38_pp73-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:740884bc62a5e2bbb31e584f5d23b32320fd75d79f916f15a788d527a5e83644"}, + {file = "rpds_py-0.18.1-pp38-pypy38_pp73-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:998125738de0158f088aef3cb264a34251908dd2e5d9966774fdab7402edfab7"}, + {file = "rpds_py-0.18.1-pp38-pypy38_pp73-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:e2be6e9dd4111d5b31ba3b74d17da54a8319d8168890fbaea4b9e5c3de630ae5"}, + {file = "rpds_py-0.18.1-pp38-pypy38_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d0cee71bc618cd93716f3c1bf56653740d2d13ddbd47673efa8bf41435a60daa"}, + {file = "rpds_py-0.18.1-pp38-pypy38_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:2c3caec4ec5cd1d18e5dd6ae5194d24ed12785212a90b37f5f7f06b8bedd7139"}, + {file = "rpds_py-0.18.1-pp38-pypy38_pp73-musllinux_1_2_aarch64.whl", hash = "sha256:27bba383e8c5231cd559affe169ca0b96ec78d39909ffd817f28b166d7ddd4d8"}, + {file = "rpds_py-0.18.1-pp38-pypy38_pp73-musllinux_1_2_i686.whl", hash = "sha256:a888e8bdb45916234b99da2d859566f1e8a1d2275a801bb8e4a9644e3c7e7909"}, + {file = "rpds_py-0.18.1-pp38-pypy38_pp73-musllinux_1_2_x86_64.whl", hash = "sha256:6031b25fb1b06327b43d841f33842b383beba399884f8228a6bb3df3088485ff"}, + {file = "rpds_py-0.18.1-pp39-pypy39_pp73-macosx_10_12_x86_64.whl", hash = "sha256:48c2faaa8adfacefcbfdb5f2e2e7bdad081e5ace8d182e5f4ade971f128e6bb3"}, + {file = "rpds_py-0.18.1-pp39-pypy39_pp73-macosx_11_0_arm64.whl", hash = "sha256:d85164315bd68c0806768dc6bb0429c6f95c354f87485ee3593c4f6b14def2bd"}, + {file = "rpds_py-0.18.1-pp39-pypy39_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6afd80f6c79893cfc0574956f78a0add8c76e3696f2d6a15bca2c66c415cf2d4"}, + {file = "rpds_py-0.18.1-pp39-pypy39_pp73-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:fa242ac1ff583e4ec7771141606aafc92b361cd90a05c30d93e343a0c2d82a89"}, + {file = "rpds_py-0.18.1-pp39-pypy39_pp73-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:d21be4770ff4e08698e1e8e0bce06edb6ea0626e7c8f560bc08222880aca6a6f"}, + {file = "rpds_py-0.18.1-pp39-pypy39_pp73-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:5c45a639e93a0c5d4b788b2613bd637468edd62f8f95ebc6fcc303d58ab3f0a8"}, + {file = "rpds_py-0.18.1-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:910e71711d1055b2768181efa0a17537b2622afeb0424116619817007f8a2b10"}, + {file = "rpds_py-0.18.1-pp39-pypy39_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:b9bb1f182a97880f6078283b3505a707057c42bf55d8fca604f70dedfdc0772a"}, + {file = "rpds_py-0.18.1-pp39-pypy39_pp73-musllinux_1_2_aarch64.whl", hash = "sha256:1d54f74f40b1f7aaa595a02ff42ef38ca654b1469bef7d52867da474243cc633"}, + {file = "rpds_py-0.18.1-pp39-pypy39_pp73-musllinux_1_2_i686.whl", hash = 
"sha256:8d2e182c9ee01135e11e9676e9a62dfad791a7a467738f06726872374a83db49"}, + {file = "rpds_py-0.18.1-pp39-pypy39_pp73-musllinux_1_2_x86_64.whl", hash = "sha256:636a15acc588f70fda1661234761f9ed9ad79ebed3f2125d44be0862708b666e"}, + {file = "rpds_py-0.18.1.tar.gz", hash = "sha256:dc48b479d540770c811fbd1eb9ba2bb66951863e448efec2e2c102625328e92f"}, +] + [[package]] name = "safetensors" version = "0.4.3" @@ -1716,6 +2212,21 @@ files = [ {file = "sniffio-1.3.1.tar.gz", hash = "sha256:f4324edc670a0f49750a81b895f35c3adb843cca46f0530f79fc1babb23789dc"}, ] +[[package]] +name = "stack-data" +version = "0.6.3" +summary = "Extract data from python stack frames and tracebacks for informative displays" +groups = ["dev"] +dependencies = [ + "asttokens>=2.1.0", + "executing>=1.2.0", + "pure-eval", +] +files = [ + {file = "stack_data-0.6.3-py3-none-any.whl", hash = "sha256:d5558e0c25a4cb0853cddad3d77da9891a08cb85dd9f9f91b9f8cd66e511e695"}, + {file = "stack_data-0.6.3.tar.gz", hash = "sha256:836a778de4fec4dcd1dcd89ed8abff8a221f58308462e1c4aa2a3cf30148f0b9"}, +] + [[package]] name = "starlette" version = "0.37.2" @@ -1902,6 +2413,26 @@ files = [ {file = "torch-2.3.0-cp310-none-macosx_11_0_arm64.whl", hash = "sha256:758ef938de87a2653bba74b91f703458c15569f1562bf4b6c63c62d9c5a0c1f5"}, ] +[[package]] +name = "tornado" +version = "6.4.1" +requires_python = ">=3.8" +summary = "Tornado is a Python web framework and asynchronous networking library, originally developed at FriendFeed." +groups = ["dev"] +files = [ + {file = "tornado-6.4.1-cp38-abi3-macosx_10_9_universal2.whl", hash = "sha256:163b0aafc8e23d8cdc3c9dfb24c5368af84a81e3364745ccb4427669bf84aec8"}, + {file = "tornado-6.4.1-cp38-abi3-macosx_10_9_x86_64.whl", hash = "sha256:6d5ce3437e18a2b66fbadb183c1d3364fb03f2be71299e7d10dbeeb69f4b2a14"}, + {file = "tornado-6.4.1-cp38-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e2e20b9113cd7293f164dc46fffb13535266e713cdb87bd2d15ddb336e96cfc4"}, + {file = "tornado-6.4.1-cp38-abi3-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:8ae50a504a740365267b2a8d1a90c9fbc86b780a39170feca9bcc1787ff80842"}, + {file = "tornado-6.4.1-cp38-abi3-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:613bf4ddf5c7a95509218b149b555621497a6cc0d46ac341b30bd9ec19eac7f3"}, + {file = "tornado-6.4.1-cp38-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:25486eb223babe3eed4b8aecbac33b37e3dd6d776bc730ca14e1bf93888b979f"}, + {file = "tornado-6.4.1-cp38-abi3-musllinux_1_2_i686.whl", hash = "sha256:454db8a7ecfcf2ff6042dde58404164d969b6f5d58b926da15e6b23817950fc4"}, + {file = "tornado-6.4.1-cp38-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:a02a08cc7a9314b006f653ce40483b9b3c12cda222d6a46d4ac63bb6c9057698"}, + {file = "tornado-6.4.1-cp38-abi3-win32.whl", hash = "sha256:d9a566c40b89757c9aa8e6f032bcdb8ca8795d7c1a9762910c722b1635c9de4d"}, + {file = "tornado-6.4.1-cp38-abi3-win_amd64.whl", hash = "sha256:b24b8982ed444378d7f21d563f4180a2de31ced9d8d84443907a0a64da2072e7"}, + {file = "tornado-6.4.1.tar.gz", hash = "sha256:92d3ab53183d8c50f8204a51e6f91d18a15d5ef261e84d452800d4ff6fc504e9"}, +] + [[package]] name = "tqdm" version = "4.66.4" @@ -1916,6 +2447,17 @@ files = [ {file = "tqdm-4.66.4.tar.gz", hash = "sha256:e4d936c9de8727928f3be6079590e97d9abfe8d39a590be678eb5919ffc186bb"}, ] +[[package]] +name = "traitlets" +version = "5.14.3" +requires_python = ">=3.8" +summary = "Traitlets Python configuration system" +groups = ["dev"] +files = [ + {file 
= "traitlets-5.14.3-py3-none-any.whl", hash = "sha256:b74e89e397b1ed28cc831db7aea759ba6640cb3de13090ca145426688ff1ac4f"}, + {file = "traitlets-5.14.3.tar.gz", hash = "sha256:9ed0579d3502c94b4b3732ac120375cda96f923114522847de4b3bb98b96b6b7"}, +] + [[package]] name = "transformer-lens" version = "0.0.0" @@ -2200,6 +2742,16 @@ files = [ {file = "watchfiles-0.22.0.tar.gz", hash = "sha256:988e981aaab4f3955209e7e28c7794acdb690be1efa7f16f8ea5aba7ffdadacb"}, ] +[[package]] +name = "wcwidth" +version = "0.2.13" +summary = "Measures the displayed width of unicode strings in a terminal" +groups = ["dev"] +files = [ + {file = "wcwidth-0.2.13-py2.py3-none-any.whl", hash = "sha256:3da69048e4540d84af32131829ff948f1e022c1c6bdb8d6102117aac784f6859"}, + {file = "wcwidth-0.2.13.tar.gz", hash = "sha256:72ea0c06399eb286d978fdedb6923a9eb47e1c486ce63e9b4e64fc18303972b5"}, +] + [[package]] name = "websockets" version = "12.0" diff --git a/pyproject.toml b/pyproject.toml index 608ce12..b96f09c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -32,6 +32,7 @@ dependencies = [ "python-dotenv>=1.0.1", "jaxtyping>=0.2.25", "safetensors>=0.4.3", + "pydantic>=2.7.3", ] requires-python = "==3.10.*" readme = "README.md" @@ -41,6 +42,9 @@ license = {text = "MIT"} dev = [ "-e file:///${PROJECT_ROOT}/TransformerLens#egg=transformer-lens", "mypy>=1.10.0", + "ipykernel>=6.29.4", + "nbformat>=5.10.4", + "kaleido==0.2.1", ] [tool.mypy] diff --git a/src/lm_saes/activation/activation_dataset.py b/src/lm_saes/activation/activation_dataset.py index 72e33d4..9caba45 100644 --- a/src/lm_saes/activation/activation_dataset.py +++ b/src/lm_saes/activation/activation_dataset.py @@ -15,12 +15,12 @@ def make_activation_dataset( model: HookedTransformer, cfg: ActivationGenerationConfig ): - element_size = torch.finfo(cfg.dtype).bits / 8 - token_act_size = element_size * cfg.d_model + element_size = torch.finfo(cfg.lm.dtype).bits / 8 + token_act_size = element_size * cfg.lm.d_model max_tokens_per_chunk = cfg.chunk_size // token_act_size print_once(f"Making activation dataset with approximately {max_tokens_per_chunk} tokens per chunk") - token_source = TokenSource.from_config(model=model, cfg=cfg) + token_source = TokenSource.from_config(model=model, cfg=cfg.dataset) if not cfg.use_ddp or cfg.rank == 0: for hook_point in cfg.hook_points: @@ -37,13 +37,13 @@ def make_activation_dataset( pbar = tqdm(total=total_generating_tokens, desc=f"Activation dataset Rank {cfg.rank}" if cfg.use_ddp else "Activation dataset") while n_tokens < total_generating_tokens: - act_dict = {hook_point: torch.empty((0, cfg.context_size, cfg.d_model), dtype=cfg.dtype, device=cfg.device) for hook_point in cfg.hook_points} - context = torch.empty((0, cfg.context_size), dtype=torch.long, device=cfg.device) + act_dict = {hook_point: torch.empty((0, cfg.dataset.context_size, cfg.lm.d_model), dtype=cfg.lm.dtype, device=cfg.lm.device) for hook_point in cfg.hook_points} + context = torch.empty((0, cfg.dataset.context_size), dtype=torch.long, device=cfg.lm.device) n_tokens_in_chunk = 0 while n_tokens_in_chunk < max_tokens_per_chunk: - tokens = token_source.next(cfg.store_batch_size) + tokens = token_source.next(cfg.dataset.store_batch_size) _, cache = model.run_with_cache_until(tokens, names_filter=cfg.hook_points, until=cfg.hook_points[-1]) for hook_point in cfg.hook_points: act = cache[hook_point] @@ -54,7 +54,7 @@ def make_activation_dataset( pbar.update(tokens.size(0) * tokens.size(1)) - position = torch.arange(cfg.context_size, device=cfg.device, 
dtype=torch.long).unsqueeze(0).expand(context.size(0), -1) + position = torch.arange(cfg.dataset.context_size, device=cfg.lm.device, dtype=torch.long).unsqueeze(0).expand(context.size(0), -1) for hook_point in cfg.hook_points: torch.save( diff --git a/src/lm_saes/activation/activation_source.py b/src/lm_saes/activation/activation_source.py index b19d6bd..ce8a01c 100644 --- a/src/lm_saes/activation/activation_source.py +++ b/src/lm_saes/activation/activation_source.py @@ -35,12 +35,12 @@ class TokenActivationSource(ActivationSource): An activation source that generates activations from a token source. """ def __init__(self, model: HookedTransformer, cfg: ActivationStoreConfig): - self.token_source = TokenSource.from_config(model=model, cfg=cfg) + self.token_source = TokenSource.from_config(model=model, cfg=cfg.dataset) self.model = model self.cfg = cfg def next(self) -> Dict[str, torch.Tensor] | None: - tokens = self.token_source.next(self.cfg.store_batch_size) + tokens = self.token_source.next(self.cfg.dataset.store_batch_size) if tokens is None: return None @@ -70,7 +70,7 @@ def __init__(self, cfg: ActivationStoreConfig): self.chunk_paths = [p for i, p in enumerate(self.chunk_paths) if i % cfg.world_size == cfg.rank] random.shuffle(self.chunk_paths) - self.token_buffer = torch.empty((0, cfg.context_size), dtype=torch.long, device=cfg.device) + self.token_buffer = torch.empty((0, cfg.dataset.context_size), dtype=torch.long, device=cfg.device) def _load_next_chunk(self): if len(self.chunk_paths) == 0: diff --git a/src/lm_saes/activation/activation_store.py b/src/lm_saes/activation/activation_store.py index 75527fa..ed2c49e 100644 --- a/src/lm_saes/activation/activation_store.py +++ b/src/lm_saes/activation/activation_store.py @@ -85,7 +85,7 @@ def from_config(model: HookedTransformer, cfg: ActivationStoreConfig): ) return ActivationStore( act_source=act_source, - d_model=cfg.d_model, + d_model=cfg.lm.d_model, n_tokens_in_buffer=cfg.n_tokens_in_buffer, device=cfg.device, use_ddp=cfg.use_ddp, diff --git a/src/lm_saes/activation/token_source.py b/src/lm_saes/activation/token_source.py index 6c94e73..a215c02 100644 --- a/src/lm_saes/activation/token_source.py +++ b/src/lm_saes/activation/token_source.py @@ -16,7 +16,6 @@ def __init__( is_dataset_tokenized: bool, concat_tokens: list[bool], seq_len: int, - device: str, sample_probs: list[float], ): self.dataloader = dataloader @@ -24,7 +23,7 @@ def __init__( self.is_dataset_tokenized = is_dataset_tokenized self.concat_tokens = concat_tokens self.seq_len = seq_len - self.device = device + self.device = model.cfg.device self.data_iter = [iter(dataloader) for dataloader in self.dataloader] @@ -121,37 +120,5 @@ def from_config(model: HookedTransformer, cfg: TextDatasetConfig): is_dataset_tokenized=cfg.is_dataset_tokenized, concat_tokens=cfg.concat_tokens, seq_len=cfg.context_size, - device=cfg.device, sample_probs=cfg.sample_probs, - ) - - -if __name__ == "__main__": - from lm_saes.config import LanguageModelSAETrainingConfig - from transformer_lens import HookedTransformer - import os - - - if os.path.exists("./results/test"): - import shutil - - shutil.rmtree("./results/test") - - - cfg = LanguageModelSAETrainingConfig( - dataset_path=[], - concat_tokens=[True, False], - sample_probs=[0.5, 0.5], - is_dataset_on_disk=True, - is_dataset_tokenized=False, - store_batch_size=1, - context_size=16, - device="cuda", - use_ddp=False, - ) - - model = HookedTransformer.from_pretrained("gpt2", cfg.device) - token_source = TokenSource.from_config(model, cfg) - 
- for i in range(5): - print(model.tokenizer.batch_decode(token_source.next(2).cpu().numpy())) \ No newline at end of file + ) \ No newline at end of file diff --git a/src/lm_saes/analysis/auto_interp.py b/src/lm_saes/analysis/auto_interp.py index da847f7..13e191c 100644 --- a/src/lm_saes/analysis/auto_interp.py +++ b/src/lm_saes/analysis/auto_interp.py @@ -118,7 +118,7 @@ def generate_description( cfg: AutoInterpConfig, ): tokenizer = model.tokenizer - client = OpenAI(api_key=cfg.openai_api_key, base_url=cfg.openai_base_url) + client = OpenAI(api_key=cfg.openai.openai_api_key, base_url=cfg.openai.openai_base_url) prompt = _sample_sentences( cfg, tokenizer, feature_activation ) @@ -151,7 +151,7 @@ def check_description( Otherwise, a `feature_activations` dataset is required for further processing. """ tokenizer = model.tokenizer - client = OpenAI(api_key=cfg.openai_api_key, base_url=cfg.openai_base_url) + client = OpenAI(api_key=cfg.openai.openai_api_key, base_url=cfg.openai.openai_base_url) if using_sae: assert sae is not None, "Sparse Auto Encoder is not provided." prompt_prefix = "We are analyzing the activation levels of features in a neural network, where each feature activates certain tokens in a text. Each token's activation value indicates its relevance to the feature, with higher values showing stronger association. We will describe a feature's meaning and traits. Your output must be multiple sentences that activates the feature." @@ -162,8 +162,8 @@ def check_description( cost = _calculate_cost(input_tokens, output_tokens) input_index, input_text = index, response input_token = model.to_tokens(input_text) - _, cache = model.run_with_cache_until(input_token, names_filter=[cfg.hook_point_in, cfg.hook_point_out], until=cfg.hook_point_out) - activation_in, activation_out = cache[cfg.hook_point_in][0], cache[cfg.hook_point_out][0] + _, cache = model.run_with_cache_until(input_token, names_filter=[cfg.sae.hook_point_in, cfg.sae.hook_point_out], until=cfg.sae.hook_point_out) + activation_in, activation_out = cache[cfg.sae.hook_point_in][0], cache[cfg.sae.hook_point_out][0] feature_acts = sae.encode(activation_in, label=activation_out) max_value, max_pos = torch.max(feature_acts, dim=0) passed = torch.max(feature_acts) > 1 diff --git a/src/lm_saes/analysis/features_to_logits.py b/src/lm_saes/analysis/features_to_logits.py index 8bd7284..f9236b2 100644 --- a/src/lm_saes/analysis/features_to_logits.py +++ b/src/lm_saes/analysis/features_to_logits.py @@ -8,7 +8,7 @@ def features_to_logits(sae: SparseAutoEncoder, model: HookedTransformer, cfg: FeaturesDecoderConfig): num_ones = int(torch.sum(sae.feature_act_mask).item()) - feature_acts = torch.zeros(num_ones, cfg.d_sae).to(cfg.device) + feature_acts = torch.zeros(num_ones, cfg.sae.d_sae).to(cfg.sae.device) index = 0 for i in range(len(sae.feature_act_mask)): diff --git a/src/lm_saes/analysis/sample_feature_activations.py b/src/lm_saes/analysis/sample_feature_activations.py index 2c24940..7308c00 100644 --- a/src/lm_saes/analysis/sample_feature_activations.py +++ b/src/lm_saes/analysis/sample_feature_activations.py @@ -28,10 +28,10 @@ def sample_feature_activations( ): if cfg.use_ddp: raise ValueError("Sampling feature activations does not support DDP yet") - assert cfg.d_sae is not None # Make mypy happy + assert cfg.sae.d_sae is not None # Make mypy happy total_analyzing_tokens = cfg.total_analyzing_tokens - total_analyzing_steps = total_analyzing_tokens // cfg.store_batch_size // cfg.context_size + total_analyzing_steps = 
total_analyzing_tokens // cfg.act_store.dataset.store_batch_size // cfg.act_store.dataset.context_size print_once(f"Total Analyzing Tokens: {total_analyzing_tokens}") print_once(f"Total Analyzing Steps: {total_analyzing_steps}") @@ -43,27 +43,27 @@ def sample_feature_activations( pbar = tqdm(total=total_analyzing_tokens, desc=f"Sampling activations of chunk {sae_chunk_id} of {n_sae_chunks}", smoothing=0.01) - d_sae = cfg.d_sae // n_sae_chunks + d_sae = cfg.sae.d_sae // n_sae_chunks start_index = sae_chunk_id * d_sae end_index = (sae_chunk_id + 1) * d_sae sample_result = {k: { - "elt": torch.empty((0, d_sae), dtype=cfg.dtype, device=cfg.device), - "feature_acts": torch.empty((0, d_sae, cfg.context_size), dtype=cfg.dtype, device=cfg.device), - "contexts": torch.empty((0, d_sae, cfg.context_size), dtype=torch.int32, device=cfg.device), + "elt": torch.empty((0, d_sae), dtype=cfg.sae.dtype, device=cfg.sae.device), + "feature_acts": torch.empty((0, d_sae, cfg.act_store.dataset.context_size), dtype=cfg.sae.dtype, device=cfg.sae.device), + "contexts": torch.empty((0, d_sae, cfg.act_store.dataset.context_size), dtype=torch.int32, device=cfg.sae.device), } for k in cfg.subsample.keys()} - act_times = torch.zeros((d_sae,), dtype=torch.long, device=cfg.device) - feature_acts_all = [torch.empty((0,), dtype=cfg.dtype, device=cfg.device) for _ in range(d_sae)] - max_feature_acts = torch.zeros((d_sae,), dtype=cfg.dtype, device=cfg.device) + act_times = torch.zeros((d_sae,), dtype=torch.long, device=cfg.sae.device) + feature_acts_all = [torch.empty((0,), dtype=cfg.sae.dtype, device=cfg.sae.device) for _ in range(d_sae)] + max_feature_acts = torch.zeros((d_sae,), dtype=cfg.sae.dtype, device=cfg.sae.device) while n_training_tokens < total_analyzing_tokens: - batch = activation_store.next_tokens(cfg.store_batch_size) + batch = activation_store.next_tokens(cfg.act_store.dataset.store_batch_size) if batch is None: raise ValueError("Not enough tokens to sample") - _, cache = model.run_with_cache_until(batch, names_filter=[cfg.hook_point_in, cfg.hook_point_out], until=cfg.hook_point_out) - activation_in, activation_out = cache[cfg.hook_point_in], cache[cfg.hook_point_out] + _, cache = model.run_with_cache_until(batch, names_filter=[cfg.sae.hook_point_in, cfg.sae.hook_point_out], until=cfg.sae.hook_point_out) + activation_in, activation_out = cache[cfg.sae.hook_point_in], cache[cfg.sae.hook_point_out] feature_acts = sae.encode(activation_in, label=activation_out)[..., start_index: end_index] @@ -73,7 +73,7 @@ def sample_feature_activations( if cfg.enable_sampling: weights = feature_acts.clamp(min=0.0).pow(cfg.sample_weight_exponent).max(dim=1).values - elt = torch.rand(batch.size(0), d_sae, device=cfg.device, dtype=cfg.dtype).log() / weights + elt = torch.rand(batch.size(0), d_sae, device=cfg.sae.device, dtype=cfg.sae.dtype).log() / weights elt[weights == 0.0] = -torch.inf else: elt = feature_acts.clamp(min=0.0).max(dim=1).values @@ -107,7 +107,7 @@ def sample_feature_activations( max_feature_acts = torch.max(max_feature_acts, feature_acts.max(dim=0).values.max(dim=0).values) - n_tokens_current = torch.tensor(batch.size(0) * batch.size(1), device=cfg.device, dtype=torch.int) + n_tokens_current = torch.tensor(batch.size(0) * batch.size(1), device=cfg.sae.device, dtype=torch.int) n_training_tokens += cast(int, n_tokens_current.item()) n_training_steps += 1 @@ -120,7 +120,7 @@ def sample_feature_activations( } for k1, v1 in sample_result.items()} result = { - "index": torch.arange(start_index, end_index, 
device=cfg.device, dtype=torch.int32), + "index": torch.arange(start_index, end_index, device=cfg.sae.device, dtype=torch.int32), "act_times": act_times, "feature_acts_all": feature_acts_all, "max_feature_acts": max_feature_acts, diff --git a/src/lm_saes/config.py b/src/lm_saes/config.py index 205d698..110da85 100644 --- a/src/lm_saes/config.py +++ b/src/lm_saes/config.py @@ -8,18 +8,19 @@ import os +from lm_saes.utils.config import FlattenableModel from lm_saes.utils.huggingface import parse_pretrained_name_or_path from lm_saes.utils.misc import print_once from transformer_lens.loading_from_pretrained import get_official_model_name -@dataclass -class BaseConfig: +@dataclass(kw_only=True) +class BaseConfig(FlattenableModel): def __post_init__(self): pass -@dataclass +@dataclass(kw_only=True) class BaseModelConfig(BaseConfig): device: str = "cpu" seed: int = 42 @@ -37,7 +38,7 @@ def from_dict(cls, d: Dict[str, Any], **kwargs): d = {k: v for k, v in d.items() if k in [field.name for field in fields(cls)]} return cls(**d, **kwargs) -@dataclass +@dataclass(kw_only=True) class RunnerConfig(BaseConfig): use_ddp: bool = False @@ -59,7 +60,7 @@ def __post_init__(self): os.makedirs(os.path.join(self.exp_result_dir, self.exp_name), exist_ok=True) -@dataclass +@dataclass(kw_only=True) class LanguageModelConfig(BaseModelConfig): model_name: str = "gpt2" model_from_pretrained_path: Optional[str] = None @@ -84,19 +85,14 @@ def from_pretrained_sae(pretrained_name_or_path: str, **kwargs): lm_config = json.load(f) return LanguageModelConfig.from_dict(lm_config, **kwargs) - def save_lm_config(self, sae_path: Optional[str] = None): - if sae_path is None: - if isinstance(self, RunnerConfig): - sae_path = os.path.join(self.exp_result_dir, self.exp_name) - else: - raise ValueError("sae_path must be specified if not called from a RunnerConfig.") + def save_lm_config(self, sae_path: str): assert os.path.exists(sae_path), f"{sae_path} does not exist. Unable to save LanguageModelConfig." with open(os.path.join(sae_path, "lm_config.json"), "w") as f: json.dump(self.to_dict(), f, indent=4) -@dataclass -class TextDatasetConfig(BaseModelConfig, RunnerConfig): +@dataclass(kw_only=True) +class TextDatasetConfig(RunnerConfig): dataset_path: List[str] = 'openwebtext' # type: ignore cache_dir: Optional[str] = None is_dataset_tokenized: bool = False @@ -120,8 +116,10 @@ def __post_init__(self): assert len(self.concat_tokens) == len(self.dataset_path), "Number of concat_tokens must match number of dataset paths" -@dataclass -class ActivationStoreConfig(LanguageModelConfig, TextDatasetConfig): +@dataclass(kw_only=True) +class ActivationStoreConfig(BaseModelConfig, RunnerConfig): + lm: LanguageModelConfig + dataset: TextDatasetConfig hook_points: List[str] = field(default_factory=lambda: ["blocks.0.hook_resid_pre"]) """ Hook points to store activations from, i.e. the layer output of which is used for training/evaluating the dictionary. Will run until the last hook point in the list, so make sure to order them correctly. 
""" @@ -135,25 +133,19 @@ def __post_init__(self): # Autofill cached_activations_path unless the user overrode it if self.cached_activations_path is None: self.cached_activations_path = [ - f"activations/{path.split('/')[-1]}/{self.model_name.replace('/', '_')}_{self.context_size}" - for path in self.dataset_path + f"activations/{path.split('/')[-1]}/{self.lm.model_name.replace('/', '_')}_{self.dataset.context_size}" + for path in self.dataset.dataset_path ] -@dataclass +@dataclass(kw_only=True) class WandbConfig(BaseConfig): log_to_wandb: bool = True wandb_project: str = "gpt2-sae-training" - run_name: Optional[str] = None + exp_name: Optional[str] = None wandb_entity: Optional[str] = None - def __post_init__(self): - super().__post_init__() - if self.run_name is None and isinstance(self, RunnerConfig): - self.run_name = self.exp_name - - -@dataclass +@dataclass(kw_only=True) class SAEConfig(BaseModelConfig): """ Configuration for training or running a sparse autoencoder. @@ -223,12 +215,7 @@ def get_hyperparameters( } return hyperparams - def save_hyperparameters(self, sae_path: Optional[str] = None, remove_loading_info: bool = True): - if sae_path is None: - if isinstance(self, RunnerConfig): - sae_path = os.path.join(self.exp_result_dir, self.exp_name) - else: - raise ValueError("sae_path must be specified if not called from a RunnerConfig.") + def save_hyperparameters(self, sae_path: str, remove_loading_info: bool = True): assert os.path.exists(sae_path), f"{sae_path} does not exist. Unable to save hyperparameters." d = self.to_dict() if remove_loading_info: @@ -237,25 +224,30 @@ def save_hyperparameters(self, sae_path: Optional[str] = None, remove_loading_in with open(os.path.join(sae_path, "hyperparams.json"), "w") as f: json.dump(d, f, indent=4) -@dataclass +@dataclass(kw_only=True) class OpenAIConfig(BaseConfig): openai_api_key: str openai_base_url: str -@dataclass -class AutoInterpConfig(SAEConfig, LanguageModelConfig, OpenAIConfig): +@dataclass(kw_only=True) +class AutoInterpConfig(BaseConfig): + sae: SAEConfig + lm: LanguageModelConfig + openai: OpenAIConfig num_sample: int = 10 p: float = 0.7 num_left_token: int = 10 num_right_token: int = 5 -@dataclass -class LanguageModelSAERunnerConfig(SAEConfig, WandbConfig, ActivationStoreConfig, RunnerConfig): - pass - +@dataclass(kw_only=True) +class LanguageModelSAERunnerConfig(RunnerConfig): + sae: SAEConfig + lm: LanguageModelConfig + act_store: ActivationStoreConfig + wandb: WandbConfig -@dataclass +@dataclass(kw_only=True) class LanguageModelSAETrainingConfig(LanguageModelSAERunnerConfig): """ Configuration for training a sparse autoencoder on a language model. @@ -312,11 +304,7 @@ def __post_init__(self): total_training_steps = self.total_training_tokens // self.effective_batch_size print_once(f"Total training steps: {total_training_steps}") - if self.use_ghost_grads: - print_once("Using Ghost Grads.") - - -@dataclass +@dataclass(kw_only=True) class LanguageModelSAEPruningConfig(LanguageModelSAERunnerConfig): """ Configuration for pruning a sparse autoencoder on a language model. 
@@ -330,8 +318,11 @@ class LanguageModelSAEPruningConfig(LanguageModelSAERunnerConfig):
     decoder_norm_threshold: float = 0.99
 
-@dataclass
-class ActivationGenerationConfig(LanguageModelConfig, TextDatasetConfig):
+@dataclass(kw_only=True)
+class ActivationGenerationConfig(RunnerConfig):
+    lm: LanguageModelConfig
+    dataset: TextDatasetConfig
+
     hook_points: list[str] = field(default_factory=list)
 
     activation_save_path: str = None  # type: ignore
@@ -347,17 +338,22 @@ def __post_init__(self):
-            self.activation_save_path = f"activations/{self.dataset_path[0].split('/')[-1]}/{self.model_name.replace('/', '_')}_{self.context_size}"
+            self.activation_save_path = f"activations/{self.dataset.dataset_path[0].split('/')[-1]}/{self.lm.model_name.replace('/', '_')}_{self.dataset.context_size}"
         os.makedirs(self.activation_save_path, exist_ok=True)
 
-@dataclass
+@dataclass(kw_only=True)
 class MongoConfig(BaseConfig):
     mongo_uri: str = "mongodb://localhost:27017"
     mongo_db: str = "mechinterp"
 
-@dataclass
-class LanguageModelSAEAnalysisConfig(SAEConfig, ActivationStoreConfig, MongoConfig, RunnerConfig):
+@dataclass(kw_only=True)
+class LanguageModelSAEAnalysisConfig(RunnerConfig):
     """
     Configuration for analyzing a sparse autoencoder on a language model.
     """
+    sae: SAEConfig
+    lm: LanguageModelConfig
+    act_store: ActivationStoreConfig
+    mongo: MongoConfig
+
     total_analyzing_tokens: int = 300_000_000
     enable_sampling: bool = (
         False  # If True, we will sample the activations based on weights. Otherwise, top n_samples activations will be used.
@@ -369,9 +365,12 @@ class LanguageModelSAEAnalysisConfig(SAEConfig, ActivationStoreConfig, MongoConf
 
     def __post_init__(self):
         super().__post_init__()
-        assert self.d_sae % self.n_sae_chunks == 0, f"d_sae ({self.d_sae}) must be divisible by n_sae_chunks ({self.n_sae_chunks})"
+        assert self.sae.d_sae % self.n_sae_chunks == 0, f"d_sae ({self.sae.d_sae}) must be divisible by n_sae_chunks ({self.n_sae_chunks})"
 
-@dataclass
-class FeaturesDecoderConfig(SAEConfig, LanguageModelConfig, MongoConfig, RunnerConfig):
+@dataclass(kw_only=True)
+class FeaturesDecoderConfig(RunnerConfig):
+    sae: SAEConfig
+    lm: LanguageModelConfig
+    mongo: MongoConfig
     top: int = 10
 
diff --git a/src/lm_saes/evals.py b/src/lm_saes/evals.py
index ea2e286..5ac765d 100644
--- a/src/lm_saes/evals.py
+++ b/src/lm_saes/evals.py
@@ -22,7 +22,7 @@ def run_evals(
     n_training_steps: int,
 ):
     ### Evals
-    eval_tokens = activation_store.next_tokens(cfg.store_batch_size)
+    eval_tokens = activation_store.next_tokens(cfg.act_store.dataset.store_batch_size)
 
     # Get Reconstruction Score
     losses_df = recons_loss_batched(
@@ -42,12 +42,12 @@ def run_evals(
     _, cache = model.run_with_cache_until(
         eval_tokens,
         prepend_bos=False,
-        names_filter=[cfg.hook_point_in, cfg.hook_point_out],
-        until=cfg.hook_point_out,
+        names_filter=[cfg.sae.hook_point_in, cfg.sae.hook_point_out],
+        until=cfg.sae.hook_point_out,
     )
 
     # get act
-    original_act_in, original_act_out = cache[cfg.hook_point_in], cache[cfg.hook_point_out]
+    original_act_in, original_act_out = cache[cfg.sae.hook_point_in], cache[cfg.sae.hook_point_out]
 
     feature_acts = sae.encode(original_act_in, label=original_act_out)
     reconstructed = sae.decode(feature_acts)
@@ -81,7 +81,7 @@ def run_evals(
         "metrics/ce_loss_with_ablation": zero_abl_loss,
     }
 
-    if cfg.log_to_wandb and (not cfg.use_ddp or cfg.rank == 0):
+    if cfg.wandb.log_to_wandb and (not cfg.use_ddp or cfg.rank == 0):
         wandb.log(
             metrics,
             step=n_training_steps + 1,
@@ -100,7 +100,7 @@ def recons_loss_batched(
     if (not cfg.use_ddp or cfg.rank == 0):
         pbar = tqdm(total=n_batches, desc="Evaluation", smoothing=0.01)
     for _ in range(n_batches):
-        batch_tokens = 
activation_store.next_tokens(cfg.store_batch_size) + batch_tokens = activation_store.next_tokens(cfg.act_store.dataset.store_batch_size) assert batch_tokens is not None, "Not enough tokens in the store" score, loss, recons_loss, zero_abl_loss = get_recons_loss( model, sae, cfg, batch_tokens @@ -144,10 +144,10 @@ def get_recons_loss( _, cache = model.run_with_cache_until( batch_tokens, prepend_bos=False, - names_filter=[cfg.hook_point_in, cfg.hook_point_out], - until=cfg.hook_point_out, + names_filter=[cfg.sae.hook_point_in, cfg.sae.hook_point_out], + until=cfg.sae.hook_point_out, ) - activations_in, activations_out = cache[cfg.hook_point_in], cache[cfg.hook_point_out] + activations_in, activations_out = cache[cfg.sae.hook_point_in], cache[cfg.sae.hook_point_out] replacements = sae.forward(activations_in, label=activations_out).to(activations_out.dtype) def replacement_hook(activations: torch.Tensor, hook: Any): @@ -156,11 +156,11 @@ def replacement_hook(activations: torch.Tensor, hook: Any): recons_loss: torch.Tensor = model.run_with_hooks( batch_tokens, return_type="loss", - fwd_hooks=[(cfg.hook_point_out, replacement_hook)], + fwd_hooks=[(cfg.sae.hook_point_out, replacement_hook)], ) zero_abl_loss: torch.Tensor = model.run_with_hooks( - batch_tokens, return_type="loss", fwd_hooks=[(cfg.hook_point_out, zero_ablate_hook)] + batch_tokens, return_type="loss", fwd_hooks=[(cfg.sae.hook_point_out, zero_ablate_hook)] ) score = (zero_abl_loss - recons_loss) / (zero_abl_loss - loss) diff --git a/src/lm_saes/runner.py b/src/lm_saes/runner.py index 336ebb9..881bf47 100644 --- a/src/lm_saes/runner.py +++ b/src/lm_saes/runner.py @@ -3,6 +3,8 @@ import wandb +from dataclasses import asdict + import torch from transformers import AutoModelForCausalLM, AutoTokenizer @@ -29,9 +31,9 @@ def language_model_sae_runner(cfg: LanguageModelSAETrainingConfig): - cfg.save_hyperparameters() - cfg.save_lm_config() - sae = SparseAutoEncoder.from_config(cfg=cfg) + cfg.sae.save_hyperparameters(os.path.join(cfg.exp_result_dir, cfg.exp_name)) + cfg.lm.save_lm_config(os.path.join(cfg.exp_result_dir, cfg.exp_name)) + sae = SparseAutoEncoder.from_config(cfg=cfg.sae) if cfg.finetuning: # Fine-tune SAE with frozen encoder weights and bias @@ -39,42 +41,50 @@ def language_model_sae_runner(cfg: LanguageModelSAETrainingConfig): hf_model = AutoModelForCausalLM.from_pretrained( ( - cfg.model_name - if cfg.model_from_pretrained_path is None - else cfg.model_from_pretrained_path + cfg.lm.model_name + if cfg.lm.model_from_pretrained_path is None + else cfg.lm.model_from_pretrained_path ), - cache_dir=cfg.cache_dir, - local_files_only=cfg.local_files_only, - torch_dtype=cfg.dtype, + cache_dir=cfg.lm.cache_dir, + local_files_only=cfg.lm.local_files_only, + torch_dtype=cfg.lm.dtype, ) hf_tokenizer = AutoTokenizer.from_pretrained( ( - cfg.model_name - if cfg.model_from_pretrained_path is None - else cfg.model_from_pretrained_path + cfg.lm.model_name + if cfg.lm.model_from_pretrained_path is None + else cfg.lm.model_from_pretrained_path ), trust_remote_code=True, use_fast=True, add_bos_token=True, ) + model = HookedTransformer.from_pretrained( - cfg.model_name, - device=cfg.device, - cache_dir=cfg.cache_dir, + cfg.lm.model_name, + device=cfg.lm.device, + cache_dir=cfg.lm.cache_dir, hf_model=hf_model, tokenizer=hf_tokenizer, - dtype=cfg.dtype, + dtype=cfg.lm.dtype, ) model.eval() - activation_store = ActivationStore.from_config(model=model, cfg=cfg) - - if cfg.log_to_wandb and (not cfg.use_ddp or cfg.rank == 0): + activation_store = 
ActivationStore.from_config(model=model, cfg=cfg.act_store) + + if cfg.wandb.log_to_wandb and (not cfg.use_ddp or cfg.rank == 0): + wandb_config: dict = { + **asdict(cfg), + **asdict(cfg.sae), + **asdict(cfg.lm), + } + del wandb_config["sae"] + del wandb_config["lm"] wandb_run = wandb.init( - project=cfg.wandb_project, - config=cast(Any, cfg), - name=cfg.run_name, - entity=cfg.wandb_entity, + project=cfg.wandb.wandb_project, + config=wandb_config, + name=cfg.wandb.exp_name, + entity=cfg.wandb.wandb_entity, ) with open( os.path.join(cfg.exp_result_dir, cfg.exp_name, "train_wandb_id.txt"), "w" @@ -90,35 +100,45 @@ def language_model_sae_runner(cfg: LanguageModelSAETrainingConfig): cfg, ) - if cfg.log_to_wandb and (not cfg.use_ddp or cfg.rank == 0): + if cfg.wandb.log_to_wandb and (not cfg.use_ddp or cfg.rank == 0): wandb.finish() return sae def language_model_sae_prune_runner(cfg: LanguageModelSAEPruningConfig): - sae = SparseAutoEncoder.from_config(cfg=cfg) + sae = SparseAutoEncoder.from_config(cfg=cfg.sae) hf_model = AutoModelForCausalLM.from_pretrained( - cfg.model_name, - cache_dir=cfg.cache_dir, - local_files_only=cfg.local_files_only, - torch_dtype=cfg.dtype, + ( + cfg.lm.model_name + if cfg.lm.model_from_pretrained_path is None + else cfg.lm.model_from_pretrained_path + ), + cache_dir=cfg.lm.cache_dir, + local_files_only=cfg.lm.local_files_only, ) model = HookedTransformer.from_pretrained( - cfg.model_name, - device=cfg.device, - cache_dir=cfg.cache_dir, + cfg.lm.model_name, + device=cfg.lm.device, + cache_dir=cfg.lm.cache_dir, hf_model=hf_model, - dtype=cfg.dtype, + dtype=cfg.lm.dtype, ) model.eval() - activation_store = ActivationStore.from_config(model=model, cfg=cfg) - if cfg.log_to_wandb and (not cfg.use_ddp or cfg.rank == 0): + activation_store = ActivationStore.from_config(model=model, cfg=cfg.act_store) + if cfg.wandb.log_to_wandb and (not cfg.use_ddp or cfg.rank == 0): + wandb_config: dict = { + **asdict(cfg), + **asdict(cfg.sae), + **asdict(cfg.lm), + } + del wandb_config["sae"] + del wandb_config["lm"] wandb_run = wandb.init( - project=cfg.wandb_project, - config=cast(Any, cfg), - name=cfg.run_name, - entity=cfg.wandb_entity, + project=cfg.wandb.wandb_project, + config=wandb_config, + name=cfg.wandb.exp_name, + entity=cfg.wandb.wandb_entity, ) with open( os.path.join(cfg.exp_result_dir, cfg.exp_name, "prune_wandb_id.txt"), "w" @@ -138,31 +158,44 @@ def language_model_sae_prune_runner(cfg: LanguageModelSAEPruningConfig): for key, value in result.items(): print(f"{key}: {value}") - if cfg.log_to_wandb and (not cfg.use_ddp or cfg.rank == 0): + if cfg.wandb.log_to_wandb and (not cfg.use_ddp or cfg.rank == 0): wandb.finish() def language_model_sae_eval_runner(cfg: LanguageModelSAERunnerConfig): - sae = SparseAutoEncoder.from_config(cfg=cfg) + sae = SparseAutoEncoder.from_config(cfg=cfg.sae) hf_model = AutoModelForCausalLM.from_pretrained( - cfg.model_name, cache_dir=cfg.cache_dir, local_files_only=cfg.local_files_only + ( + cfg.lm.model_name + if cfg.lm.model_from_pretrained_path is None + else cfg.lm.model_from_pretrained_path + ), + cache_dir=cfg.lm.cache_dir, + local_files_only=cfg.lm.local_files_only, ) model = HookedTransformer.from_pretrained( - cfg.model_name, - device=cfg.device, - cache_dir=cfg.cache_dir, + cfg.lm.model_name, + device=cfg.lm.device, + cache_dir=cfg.lm.cache_dir, hf_model=hf_model, - dtype=cfg.dtype, + dtype=cfg.lm.dtype, ) model.eval() - activation_store = ActivationStore.from_config(model=model, cfg=cfg) - - if cfg.log_to_wandb and (not cfg.use_ddp or 
cfg.rank == 0): + activation_store = ActivationStore.from_config(model=model, cfg=cfg.act_store) + + if cfg.wandb.log_to_wandb and (not cfg.use_ddp or cfg.rank == 0): + wandb_config: dict = { + **asdict(cfg), + **asdict(cfg.sae), + **asdict(cfg.lm), + } + del wandb_config["sae"] + del wandb_config["lm"] wandb_run = wandb.init( - project=cfg.wandb_project, - config=cast(Any, cfg), - name=cfg.run_name, - entity=cfg.wandb_entity, + project=cfg.wandb.wandb_project, + config=wandb_config, + name=cfg.wandb.exp_name, + entity=cfg.wandb.wandb_entity, ) with open( os.path.join(cfg.exp_result_dir, cfg.exp_name, "eval_wandb_id.txt"), "w" @@ -176,18 +209,28 @@ def language_model_sae_eval_runner(cfg: LanguageModelSAERunnerConfig): for key, value in result.items(): print(f"{key}: {value}") - if cfg.log_to_wandb and (not cfg.use_ddp or cfg.rank == 0): + if cfg.wandb.log_to_wandb and (not cfg.use_ddp or cfg.rank == 0): wandb.finish() return sae def activation_generation_runner(cfg: ActivationGenerationConfig): + hf_model = AutoModelForCausalLM.from_pretrained( + ( + cfg.lm.model_name + if cfg.lm.model_from_pretrained_path is None + else cfg.lm.model_from_pretrained_path + ), + cache_dir=cfg.lm.cache_dir, + local_files_only=cfg.lm.local_files_only, + ) model = HookedTransformer.from_pretrained( - cfg.model_name, - device=cfg.device, - cache_dir=cfg.cache_dir, - dtype=cfg.dtype, + cfg.lm.model_name, + device=cfg.lm.device, + cache_dir=cfg.lm.cache_dir, + hf_model=hf_model, + dtype=cfg.lm.dtype, ) model.eval() @@ -195,31 +238,31 @@ def activation_generation_runner(cfg: ActivationGenerationConfig): def sample_feature_activations_runner(cfg: LanguageModelSAEAnalysisConfig): - sae = SparseAutoEncoder.from_config(cfg=cfg) + sae = SparseAutoEncoder.from_config(cfg=cfg.sae) hf_model = AutoModelForCausalLM.from_pretrained( ( - cfg.model_name - if cfg.model_from_pretrained_path is None - else cfg.model_from_pretrained_path + cfg.lm.model_name + if cfg.lm.model_from_pretrained_path is None + else cfg.lm.model_from_pretrained_path ), - cache_dir=cfg.cache_dir, - local_files_only=cfg.local_files_only, + cache_dir=cfg.lm.cache_dir, + local_files_only=cfg.lm.local_files_only, ) model = HookedTransformer.from_pretrained( - cfg.model_name, - device=cfg.device, - cache_dir=cfg.cache_dir, + cfg.lm.model_name, + device=cfg.lm.device, + cache_dir=cfg.lm.cache_dir, hf_model=hf_model, - dtype=cfg.dtype, + dtype=cfg.lm.dtype, ) model.eval() - client = MongoClient(cfg.mongo_uri, cfg.mongo_db) - client.create_dictionary(cfg.exp_name, cfg.d_sae, cfg.exp_series) + client = MongoClient(cfg.mongo.mongo_uri, cfg.mongo.mongo_db) + client.create_dictionary(cfg.exp_name, cfg.sae.d_sae, cfg.exp_series) for chunk_id in range(cfg.n_sae_chunks): - activation_store = ActivationStore.from_config(model=model, cfg=cfg) + activation_store = ActivationStore.from_config(model=model, cfg=cfg.act_store) result = sample_feature_activations(sae, model, activation_store, cfg, chunk_id, cfg.n_sae_chunks) for i in range(len(result["index"].cpu().numpy().tolist())): @@ -246,29 +289,29 @@ def sample_feature_activations_runner(cfg: LanguageModelSAEAnalysisConfig): @torch.no_grad() def features_to_logits_runner(cfg: FeaturesDecoderConfig): - sae = SparseAutoEncoder.from_config(cfg=cfg) + sae = SparseAutoEncoder.from_config(cfg=cfg.sae) hf_model = AutoModelForCausalLM.from_pretrained( ( - cfg.model_name - if cfg.model_from_pretrained_path is None - else cfg.model_from_pretrained_path + cfg.lm.model_name + if cfg.lm.model_from_pretrained_path is None + else 
cfg.lm.model_from_pretrained_path ), - cache_dir=cfg.cache_dir, - local_files_only=cfg.local_files_only, + cache_dir=cfg.lm.cache_dir, + local_files_only=cfg.lm.local_files_only, ) model = HookedTransformer.from_pretrained( - cfg.model_name, - device=cfg.device, - cache_dir=cfg.cache_dir, + cfg.lm.model_name, + device=cfg.lm.device, + cache_dir=cfg.lm.cache_dir, hf_model=hf_model, - dtype=cfg.dtype, + dtype=cfg.lm.dtype, ) model.eval() result_dict = features_to_logits(sae, model, cfg) - client = MongoClient(cfg.mongo_uri, cfg.mongo_db) + client = MongoClient(cfg.mongo.mongo_uri, cfg.mongo.mongo_db) for feature_index, logits in result_dict.items(): sorted_indeces = torch.argsort(logits) diff --git a/src/lm_saes/sae_training.py b/src/lm_saes/sae_training.py index bb7d3db..682b5e9 100644 --- a/src/lm_saes/sae_training.py +++ b/src/lm_saes/sae_training.py @@ -50,12 +50,11 @@ def train_sae( # sae.initialize_decoder_bias(activation_store._store[cfg.hook_point_in]) if cfg.use_ddp: - ddp = DDP(sae, device_ids=[cfg.rank], output_device=cfg.device) + ddp = DDP(sae, device_ids=[cfg.rank], output_device=cfg.sae.device) - assert cfg.d_sae is not None - act_freq_scores = torch.zeros(cfg.d_sae, device=cfg.device, dtype=cfg.dtype) - n_forward_passes_since_fired = torch.zeros(cfg.d_sae, device=cfg.device, dtype=cfg.dtype) - n_frac_active_tokens = torch.tensor([0], device=cfg.device, dtype=torch.int) + act_freq_scores = torch.zeros(cfg.sae.d_sae, device=cfg.sae.device, dtype=cfg.sae.dtype) + n_forward_passes_since_fired = torch.zeros(cfg.sae.d_sae, device=cfg.sae.device, dtype=cfg.sae.dtype) + n_frac_active_tokens = torch.tensor([0], device=cfg.sae.device, dtype=torch.int) optimizer = Adam(sae.parameters(), lr=cfg.lr, betas=cfg.betas) @@ -77,7 +76,7 @@ def train_sae( # Get the next batch of activations batch = activation_store.next(batch_size=cfg.train_batch_size) assert batch is not None, "Activation store is empty" - activation_in, activation_out = batch[cfg.hook_point_in], batch[cfg.hook_point_out] + activation_in, activation_out = batch[cfg.sae.hook_point_in], batch[cfg.sae.hook_point_out] scheduler.step() optimizer.zero_grad() @@ -117,7 +116,7 @@ def train_sae( act_freq_scores += (aux_data["feature_acts"].abs() > 0).float().sum(0) n_frac_active_tokens += activation_in.size(0) - n_tokens_current = torch.tensor(activation_in.size(0), device=cfg.device, dtype=torch.int) + n_tokens_current = torch.tensor(activation_in.size(0), device=cfg.sae.device, dtype=torch.int) if cfg.use_ddp: dist.reduce(n_tokens_current, dst=0) n_training_tokens += cast(int, n_tokens_current.item()) @@ -127,7 +126,7 @@ def train_sae( if cfg.use_ddp: dist.reduce(act_freq_scores, dst=0) dist.reduce(n_frac_active_tokens, dst=0) - if cfg.log_to_wandb and (not cfg.use_ddp or cfg.rank == 0): + if cfg.wandb.log_to_wandb and (not cfg.use_ddp or cfg.rank == 0): feature_sparsity = act_freq_scores / n_frac_active_tokens log_feature_sparsity = torch.log10(feature_sparsity + 1e-10) wandb_histogram = wandb.Histogram(log_feature_sparsity.detach().cpu().float().numpy()) @@ -141,8 +140,8 @@ def train_sae( step=n_training_steps + 1, ) - act_freq_scores = torch.zeros(cfg.d_sae, device=cfg.device) - n_frac_active_tokens = torch.tensor([0], device=cfg.device, dtype=torch.int) + act_freq_scores = torch.zeros(cfg.sae.d_sae, device=cfg.sae.device) + n_frac_active_tokens = torch.tensor([0], device=cfg.sae.device, dtype=torch.int) if ((n_training_steps + 1) % cfg.log_frequency == 0): # metrics for currents acts @@ -184,7 +183,7 @@ def train_sae( 
current_learning_rate = optimizer.param_groups[0]["lr"]
 
-            if cfg.log_to_wandb and (not cfg.use_ddp or cfg.rank == 0):
+            if cfg.wandb.log_to_wandb and (not cfg.use_ddp or cfg.rank == 0):
                 wandb.log(
                     {
                         # losses
@@ -265,13 +264,12 @@
 ):
     sae.eval()
     n_training_tokens = 0
-    assert cfg.d_sae is not None  # Make mypy happy
-    act_times = torch.zeros(cfg.d_sae, device=cfg.device, dtype=torch.int)
-    max_acts = torch.zeros(cfg.d_sae, device=cfg.device, dtype=cfg.dtype)
+    act_times = torch.zeros(cfg.sae.d_sae, device=cfg.sae.device, dtype=torch.int)
+    max_acts = torch.zeros(cfg.sae.d_sae, device=cfg.sae.device, dtype=cfg.sae.dtype)
     activation_store.initialize()
 
     if cfg.use_ddp:
-        ddp = DDP(sae, device_ids=[cfg.rank], output_device=cfg.device)
+        ddp = DDP(sae, device_ids=[cfg.rank], output_device=cfg.sae.device)
 
     if not cfg.use_ddp or cfg.rank == 0:
         pbar = tqdm(total=cfg.total_training_tokens, desc="Pruning SAE", smoothing=0.01)
@@ -279,7 +277,7 @@
         # Get the next batch of activations
         batch = activation_store.next(batch_size=cfg.train_batch_size)
         assert batch is not None, "Activation store is empty"
-        activation_in, activation_out = batch[cfg.hook_point_in], batch[cfg.hook_point_out]
+        activation_in, activation_out = batch[cfg.sae.hook_point_in], batch[cfg.sae.hook_point_out]
 
         feature_acts = sae.encode(activation_in, label=activation_out)
 
@@ -307,7 +305,7 @@
         ) & (max_acts > cfg.dead_feature_max_act_threshold) & (sae.decoder.norm(p=2, dim=1) >= cfg.decoder_norm_threshold)).float()
         sae.feature_act_mask.requires_grad_(False)
 
-    if cfg.log_to_wandb:
+    if cfg.wandb.log_to_wandb:
         wandb.log(
             {
                 "sparsity/dead_features": (act_times < cfg.dead_feature_threshold * cfg.total_training_tokens).sum().item(),
diff --git a/src/lm_saes/utils/config.py b/src/lm_saes/utils/config.py
new file mode 100644
index 0000000..211c135
--- /dev/null
+++ b/src/lm_saes/utils/config.py
@@ -0,0 +1,214 @@
+from typing import Any, get_origin, get_args
+from dataclasses import fields as dataclass_fields, is_dataclass
+import inspect
+
+from typing_extensions import TypedDict, is_typeddict, Self
+
+class Field(TypedDict):
+    name: str
+    type: Any
+
+def fields(cls) -> list[Field]:
+    assert is_dataclass(cls) or is_typeddict(cls), f"{cls} is not a dataclass or TypedDict"
+    if is_dataclass(cls):
+        return [Field(name=f.name, type=f.type) for f in dataclass_fields(cls)]
+    else:
+        return [Field(name=k, type=v) for k, v in inspect.get_annotations(cls).items()]
+
+def flattened_fields(cls) -> list[Field]:
+    if is_dataclass(cls) or is_typeddict(cls):
+        f = []
+        for field in fields(cls):
+            f.extend(flattened_fields(field["type"]))
+            f.append(field)
+        return f
+    elif get_origin(cls) == list:
+        return flattened_fields(get_args(cls)[0])
+    elif get_origin(cls) == dict:
+        return flattened_fields(get_args(cls)[1])
+    return []
+
+def is_flattenable(cls) -> bool:
+    return len(flattened_fields(cls)) > 0
+
+def from_flattened(cls, data: Any, context: dict | None = None, path: str = "obj"):
+    """Construct an object, especially a dataclass or TypedDict, from a flat structure.
+    This is a superset of a traditional deserialization aimed at conveniently constructing nested dataclasses.
+
+    The difference between this function and a traditional deserialization is that this function will further
+    pass the fields in the outer dataclass to the inner dataclasses. This is useful when we want to construct
+    nested dataclasses, where different subclasses hold fields with the same name and we want these fields to
+    be exactly the same.
+
+    Args:
+        cls: The class to construct.
+        data: The data to construct the object.
+        context: The context to pass to the inner dataclasses. Does not need to be specified manually.
+        path: The path of the current field, mainly for debugging. Does not need to be specified manually.
+
+    Returns:
+        The constructed object.
+
+    Examples:
+
+        Construct a dataclass from a flat structure:
+        >>> from dataclasses import dataclass
+        >>>
+        >>> @dataclass
+        ... class A:
+        ...     a1: int
+        ...     a2: str
+        ...
+        >>> @dataclass
+        ... class B:
+        ...     a: A
+        ...     b: int
+        ...
+        >>> from_flattened(B, {"a1": 1, "a2": "2", "b": 3}) # Construct the object B with the fields in A and B. Fields in A will automatically be passed to construct A.
+        B(a=A(a1=1, a2='2'), b=3)
+
+        Construct a dataclass with a list of dataclasses:
+        >>> from dataclasses import dataclass
+        >>>
+        >>> @dataclass
+        ... class A:
+        ...     a1: int
+        ...     a2: str
+        ...
+        >>> @dataclass
+        ... class B:
+        ...     a: list[A]
+        ...     b: int
+        ...
+        >>> from_flattened(B, {"a": [{"a1": 1}, {"a1": 2}], "a2": "3", "b": 4}) # Construct the object B with the fields in A and B. Fields in A will automatically be passed to all elements in the list.
+        B(a=[A(a1=1, a2='3'), A(a1=2, a2='3')], b=4)
+
+        Construct a deep nested dataclass with default values:
+        >>> from dataclasses import dataclass
+        >>>
+        >>> @dataclass
+        ... class A:
+        ...     a1: int
+        ...     a2: str
+        ...
+        >>> @dataclass
+        ... class B:
+        ...     a: A
+        ...     b1: int = 3
+        ...
+        >>> @dataclass
+        ... class C:
+        ...     b: B
+        ...     c: int
+        ...
+        >>> from_flattened(C, {"a1": 1, "a2": "2", "c": 4}) # Deep nested dataclasses are also supported.
+        C(b=B(a=A(a1=1, a2='2'), b1=3), c=4)
+    """
+
+    if context is None:
+        context = {}
+    if not is_flattenable(cls):
+        # Skip further checking for non-flattenable classes
+        return data
+    if is_dataclass(cls) or is_typeddict(cls):
+        if data == "__missing__": # Accept unspecified fields and construct the object from the context alone.
+            data = {}
+        if is_dataclass(cls) and isinstance(data, cls):
+            return data
+        assert isinstance(data, dict), f"Field {path} is not a dict"
+        data = {**context, **data}
+        context = {**context, **data}
+        for field in fields(cls):
+            # We have to further transform the specified data (if it exists) into the specified type
+            specified_data = data[field["name"]] if field["name"] in data else "__missing__" # We use __missing__ to indicate that the field is not specified. Unspecified fields may behave differently depending on their types and default values.
+            f = from_flattened(field["type"], specified_data, context, f"{path}.{field['name']}")
+            if f != "__missing__": # Don't update the field if it is still regarded as not specified, so that the default value can be used.
+                data[field["name"]] = f
+        # Remove keys that are not fields of the current class. Such keys may have been merged in
+        # from the context for nested child configs, and we don't want to pass them to the constructor of the current class.
+        data = {k: v for k, v in data.items() if k in [f["name"] for f in fields(cls)]}
+        return cls(**data)
+    elif get_origin(cls) == list:
+        if data == "__missing__":
+            return "__missing__"
+        assert isinstance(data, list), f"Field {path} is not a list"
+        return [from_flattened(get_args(cls)[0], d, context, f"{path}.{i}") for i, d in enumerate(data)]
+    elif get_origin(cls) == dict:
+        if data == "__missing__":
+            return "__missing__"
+        assert isinstance(data, dict), f"Field {path} is not a dict"
+        return {k: from_flattened(get_args(cls)[1], v, context, f"{path}.{k}") for k, v in data.items()}
+    raise ValueError(f"Unexpected flattenable type {cls}. This is an internal error. Please report this issue to the developers.")
+
+class FlattenableModel:
+    @classmethod
+    def from_flattened(cls, data: Any) -> Self:
+        """Construct from a flat structure. This is a superset of a traditional deserialization aimed at conveniently constructing nested dataclasses.
+
+        The difference between this function and a traditional deserialization is that this function will further
+        pass the fields in the outer dataclass to the inner dataclasses. This is useful when we want to construct
+        nested dataclasses, where different subclasses hold fields with the same name and we want these fields to
+        be exactly the same.
+
+        Args:
+            data: The data to construct the object.
+
+        Returns:
+            The constructed object.
+
+        Examples:
+
+            Construct a dataclass from a flat structure:
+            >>> from dataclasses import dataclass
+            >>>
+            >>> @dataclass
+            ... class A(FlattenableModel):
+            ...     a1: int
+            ...     a2: str
+            ...
+            >>> @dataclass
+            ... class B(FlattenableModel):
+            ...     a: A
+            ...     b: int
+            ...
+            >>> B.from_flattened({"a1": 1, "a2": "2", "b": 3}) # Construct the object B with the fields in A and B. Fields in A will automatically be passed to construct A.
+            B(a=A(a1=1, a2='2'), b=3)
+
+            Construct a dataclass with a list of dataclasses:
+            >>> from dataclasses import dataclass
+            >>>
+            >>> @dataclass
+            ... class A(FlattenableModel):
+            ...     a1: int
+            ...     a2: str
+            ...
+            >>> @dataclass
+            ... class B(FlattenableModel):
+            ...     a: list[A]
+            ...     b: int
+            ...
+            >>> B.from_flattened({"a": [{"a1": 1}, {"a1": 2}], "a2": "3", "b": 4}) # Construct the object B with the fields in A and B. Fields in A will automatically be passed to all elements in the list.
+            B(a=[A(a1=1, a2='3'), A(a1=2, a2='3')], b=4)
+
+            Construct a deep nested dataclass with default values:
+            >>> from dataclasses import dataclass
+            >>>
+            >>> @dataclass
+            ... class A(FlattenableModel):
+            ...     a1: int
+            ...     a2: str
+            ...
+            >>> @dataclass
+            ... class B(FlattenableModel):
+            ...     a: A
+            ...     b1: int = 3
+            ...
+            >>> @dataclass
+            ... class C(FlattenableModel):
+            ...     b: B
+            ...     c: int
+            ...
+            >>> C.from_flattened({"a1": 1, "a2": "2", "c": 4}) # Deep nested dataclasses are also supported.
+ C(b=B(a=A(a1=1, a2='2'), b1=3), c=4) + """ + return from_flattened(cls, data) \ No newline at end of file From 08c7f880908310a09f8c7c703b26fad885d4cab7 Mon Sep 17 00:00:00 2001 From: Dest1n1 Date: Thu, 13 Jun 2024 16:13:09 +0800 Subject: [PATCH 2/4] feat(entrypoint): add entry point for lm_saes --- examples/configuration/train.toml | 60 +++++++++++++++++++ examples/{ => programmatic}/analyze.py | 14 +---- examples/{ => programmatic}/train.py | 18 +----- pdm.lock | 34 ++++++++++- pyproject.toml | 7 +++ src/lm_saes/config.py | 7 ++- src/lm_saes/entrypoint.py | 83 ++++++++++++++++++++++++++ src/lm_saes/utils/config.py | 3 +- src/lm_saes/utils/misc.py | 34 +++++++++++ 9 files changed, 228 insertions(+), 32 deletions(-) create mode 100644 examples/configuration/train.toml rename examples/{ => programmatic}/analyze.py (84%) rename examples/{ => programmatic}/train.py (91%) create mode 100644 src/lm_saes/entrypoint.py diff --git a/examples/configuration/train.toml b/examples/configuration/train.toml new file mode 100644 index 0000000..1392280 --- /dev/null +++ b/examples/configuration/train.toml @@ -0,0 +1,60 @@ +use_ddp = false +exp_name = "L3M" +exp_result_dir = "results" +device = "cuda" +seed = 42 +dtype = "torch.float32" +total_training_tokens = 1_600_000_000 +lr = 4e-4 +betas = [ 0.0, 0.9999,] +lr_scheduler_name = "constantwithwarmup" +lr_warm_up_steps = 5000 +lr_cool_down_steps = 10000 +train_batch_size = 4096 +finetuning = false +feature_sampling_window = 1000 +dead_feature_window = 5000 +dead_feature_threshold = 1e-6 +eval_frequency = 1000 +log_frequency = 100 +n_checkpoints = 10 + + +[sae] +hook_point_in = "blocks.3.hook_mlp_out" +hook_point_out = "blocks.3.hook_mlp_out" +strict_loading = true +use_decoder_bias = false +apply_decoder_bias_to_pre_encoder = true +decoder_bias_init_method = "geometric_median" +expansion_factor = 32 +d_model = 768 +norm_activation = "token-wise" +decoder_exactly_unit_norm = false +use_glu_encoder = false +l1_coefficient = 1.2e-4 +lp = 1 +use_ghost_grads = true + +[lm] +model_name = "gpt2" +d_model = 768 + +[dataset] +dataset_path = "openwebtext" +is_dataset_on_disk = false +concat_tokens = false +context_size = 256 +store_batch_size = 32 + +[act_store] +device = "cuda" +seed = 42 +dtype = "torch.float32" +hook_points = [ "blocks.3.hook_mlp_out",] +use_cached_activations = false +n_tokens_in_buffer = 500000 + +[wandb] +log_to_wandb = true +wandb_project = "gpt2-sae" \ No newline at end of file diff --git a/examples/analyze.py b/examples/programmatic/analyze.py similarity index 84% rename from examples/analyze.py rename to examples/programmatic/analyze.py index e4aa5b1..0646c9c 100644 --- a/examples/analyze.py +++ b/examples/programmatic/analyze.py @@ -4,13 +4,6 @@ from lm_saes.config import LanguageModelSAEAnalysisConfig, SAEConfig from lm_saes.runner import sample_feature_activations_runner -use_ddp = False - -if use_ddp: - os.environ["TOKENIZERS_PARALLELISM"] = "false" - dist.init_process_group(backend='nccl') - torch.cuda.set_device(dist.get_rank()) - cfg = LanguageModelSAEAnalysisConfig( # LanguageModelConfig model_name = "gpt2", @@ -44,7 +37,6 @@ mongo_uri="mongodb://localhost:27017", # MongoDB URI. 
    # RunnerConfig
-    use_ddp = use_ddp,
     device = "cuda",
     seed = 42,
     dtype = torch.float32,
@@ -54,8 +46,4 @@
     exp_result_dir = "results",
 )
 
-sample_feature_activations_runner(cfg)
-
-if use_ddp:
-    dist.destroy_process_group()
-    torch.cuda.empty_cache()
\ No newline at end of file
+sample_feature_activations_runner(cfg)
\ No newline at end of file
diff --git a/examples/train.py b/examples/programmatic/train.py
similarity index 91%
rename from examples/train.py
rename to examples/programmatic/train.py
index aa844e6..7febf03 100644
--- a/examples/train.py
+++ b/examples/programmatic/train.py
@@ -1,24 +1,14 @@
-import os
-import sys
 import torch
-import torch.distributed as dist
 from lm_saes.config import LanguageModelSAETrainingConfig
 from lm_saes.runner import language_model_sae_runner
 
-use_ddp = False
-
-if use_ddp:
-    os.environ["TOKENIZERS_PARALLELISM"] = "false"
-    dist.init_process_group(backend='nccl')
-    torch.cuda.set_device(dist.get_rank())
-
 cfg = LanguageModelSAETrainingConfig(
     # LanguageModelConfig
     model_name = "gpt2",                            # The model name or path for the pre-trained model.
     d_model = 768,                                  # The hidden size of the model.
 
     # TextDatasetConfig
-    dataset_path = "data/openwebtext",              # The corpus name or path. Each of a data record should contain (and may only contain) a "text" field.
+    dataset_path = "openwebtext",                   # The corpus name or path. Each data record should contain (and may only contain) a "text" field.
     is_dataset_tokenized = False,                   # Whether the dataset is tokenized.
     is_dataset_on_disk = True,                      # Whether the dataset is on disk. If not on disk, `datasets.load_dataset` will be used to load the dataset, and the train split will be used for training.
     concat_tokens = False,                          # Whether to concatenate tokens into a single sequence. If False, only data record with length of non-padding tokens larger than `context_size` will be used.
@@ -61,7 +51,6 @@
     wandb_project= "gpt2-sae",                      # The wandb project name.
 
     # RunnerConfig
-    use_ddp = use_ddp,                              # Whether to use the DistributedDataParallel.
     device = "cuda",                                # The device to place all torch tensors.
     seed = 42,                                      # The random seed.
     dtype = torch.float32,                          # The torch data type of non-integer tensors.
@@ -71,7 +60,4 @@ exp_result_dir = "results" ) -sparse_autoencoder = language_model_sae_runner(cfg) - -if use_ddp: - dist.destroy_process_group() \ No newline at end of file +sparse_autoencoder = language_model_sae_runner(cfg) \ No newline at end of file diff --git a/pdm.lock b/pdm.lock index a4474c2..f849a30 100644 --- a/pdm.lock +++ b/pdm.lock @@ -5,7 +5,7 @@ groups = ["default", "dev"] strategy = ["cross_platform", "inherit_metadata"] lock_version = "4.4.1" -content_hash = "sha256:c80acaac6865f83ab37f1a93ef1622a1cc51a94c3b9e7ea959d11f146d5acbf0" +content_hash = "sha256:a3c5d63d1687646068d8b7958440fbce60aa7f49398550eb479d665ad95128c2" [[package]] name = "accelerate" @@ -114,6 +114,16 @@ files = [ {file = "appnope-0.1.4.tar.gz", hash = "sha256:1de3860566df9caf38f01f86f65e0e13e379af54f9e4bee1e66b48f2efffd1ee"}, ] +[[package]] +name = "argparse" +version = "1.4.0" +summary = "Python command-line parsing library" +groups = ["default"] +files = [ + {file = "argparse-1.4.0-py2.py3-none-any.whl", hash = "sha256:c31647edb69fd3d465a847ea3157d37bed1f95f19760b11a47aa91c04b666314"}, + {file = "argparse-1.4.0.tar.gz", hash = "sha256:62b089a55be1d8949cd2bc7e0df0bddb9e028faefc8c32038cc84862aefdd6e4"}, +] + [[package]] name = "asttokens" version = "2.4.1" @@ -2379,6 +2389,17 @@ files = [ {file = "tomli-2.0.1.tar.gz", hash = "sha256:de526c12914f0c550d15924c62d72abc48d6fe7364aa87328337a31007fe8a4f"}, ] +[[package]] +name = "tomlkit" +version = "0.12.5" +requires_python = ">=3.7" +summary = "Style preserving TOML library" +groups = ["default"] +files = [ + {file = "tomlkit-0.12.5-py3-none-any.whl", hash = "sha256:af914f5a9c59ed9d0762c7b64d3b5d5df007448eb9cd2edc8a46b1eafead172f"}, + {file = "tomlkit-0.12.5.tar.gz", hash = "sha256:eef34fba39834d4d6b73c9ba7f3e4d1c417a4e56f89a7e96e090dd0d24b8fb3c"}, +] + [[package]] name = "torch" version = "2.3.0" @@ -2550,6 +2571,17 @@ files = [ {file = "typer-0.12.3.tar.gz", hash = "sha256:49e73131481d804288ef62598d97a1ceef3058905aa536a1134f90891ba35482"}, ] +[[package]] +name = "types-pyyaml" +version = "6.0.12.20240311" +requires_python = ">=3.8" +summary = "Typing stubs for PyYAML" +groups = ["default"] +files = [ + {file = "types-PyYAML-6.0.12.20240311.tar.gz", hash = "sha256:a9e0f0f88dc835739b0c1ca51ee90d04ca2a897a71af79de9aec5f38cb0a5342"}, + {file = "types_PyYAML-6.0.12.20240311-py3-none-any.whl", hash = "sha256:b845b06a1c7e54b8e5b4c683043de0d9caf205e7434b3edc678ff2411979b8f6"}, +] + [[package]] name = "typing-extensions" version = "4.12.0" diff --git a/pyproject.toml b/pyproject.toml index b96f09c..f043816 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -33,11 +33,18 @@ dependencies = [ "jaxtyping>=0.2.25", "safetensors>=0.4.3", "pydantic>=2.7.3", + "argparse>=1.4.0", + "pyyaml>=6.0.1", + "types-pyyaml>=6.0.12.20240311", + "tomlkit>=0.12.5", ] requires-python = "==3.10.*" readme = "README.md" license = {text = "MIT"} +[project.scripts] +lm-saes = "lm_saes.entrypoint:entrypoint" + [tool.pdm.dev-dependencies] dev = [ "-e file:///${PROJECT_ROOT}/TransformerLens#egg=transformer-lens", diff --git a/src/lm_saes/config.py b/src/lm_saes/config.py index 110da85..9e51230 100644 --- a/src/lm_saes/config.py +++ b/src/lm_saes/config.py @@ -10,7 +10,7 @@ from lm_saes.utils.config import FlattenableModel from lm_saes.utils.huggingface import parse_pretrained_name_or_path -from lm_saes.utils.misc import print_once +from lm_saes.utils.misc import convert_str_to_torch_dtype, print_once from transformer_lens.loading_from_pretrained import get_official_model_name @@ -38,6 +38,11 @@ 
def from_dict(cls, d: Dict[str, Any], **kwargs):
         d = {k: v for k, v in d.items() if k in [field.name for field in fields(cls)]}
         return cls(**d, **kwargs)
 
+    def __post_init__(self):
+        super().__post_init__()
+        if isinstance(self.dtype, str):
+            self.dtype = convert_str_to_torch_dtype(self.dtype)
+
 
 @dataclass(kw_only=True)
 class RunnerConfig(BaseConfig):
     use_ddp: bool = False
diff --git a/src/lm_saes/entrypoint.py b/src/lm_saes/entrypoint.py
new file mode 100644
index 0000000..91d7278
--- /dev/null
+++ b/src/lm_saes/entrypoint.py
@@ -0,0 +1,82 @@
+import argparse
+from enum import Enum
+
+import torch
+
+
+class SupportedRunner(Enum):
+    TRAIN = 'train'
+    EVAL = 'eval'
+    ANALYZE = 'analyze'
+    PRUNE = 'prune'
+
+    def __str__(self):
+        return self.value
+
+def entrypoint():
+    parser = argparse.ArgumentParser(description='Launch runners from given configuration.')
+    parser.add_argument('runner', type=SupportedRunner, help=f'The runner to launch. Supported runners: {", ".join([str(runner) for runner in SupportedRunner])}.', choices=list(SupportedRunner), metavar='runner')
+    parser.add_argument('config', type=str, help='The configuration to use.')
+    parser.add_argument('--sae', type=str, help='The path to the pretrained SAE model.')
+    args = parser.parse_args()
+
+    config_file: str = args.config
+    if config_file.endswith('.json'):
+        import json
+        with open(config_file, 'r') as f:
+            config = json.load(f)
+    elif config_file.endswith('.yaml') or config_file.endswith('.yml'):
+        import yaml
+        with open(config_file, 'r') as f:
+            config = yaml.safe_load(f)
+    elif config_file.endswith('.toml'):
+        import tomlkit
+        with open(config_file, 'r') as f:
+            config = tomlkit.load(f).unwrap()
+    elif config_file.endswith('.py'):
+        import importlib.util
+        spec = importlib.util.spec_from_file_location("__lm_sae_config__", config_file)
+        assert spec is not None and spec.loader is not None, f'Failed to load configuration file: {config_file}.'
+        module = importlib.util.module_from_spec(spec)
+        spec.loader.exec_module(module)
+        config = module.config
+    else:
+        raise ValueError(f'Unsupported configuration file format: {config_file}. 
Supported formats: json, yaml, toml, py.') + + if args.sae is not None: + from lm_saes.config import SAEConfig + config['sae'] = SAEConfig.from_pretrained(args.sae).to_dict() + + use_ddp = "use_ddp" in config and config["use_ddp"] + if use_ddp: + import os + import torch.distributed as dist + os.environ["TOKENIZERS_PARALLELISM"] = "false" + dist.init_process_group(backend='nccl') + torch.cuda.set_device(dist.get_rank()) + + if args.runner == SupportedRunner.TRAIN: + from lm_saes.runner import language_model_sae_runner + from lm_saes.config import LanguageModelSAETrainingConfig + config = LanguageModelSAETrainingConfig.from_flattened(config) + language_model_sae_runner(config) + elif args.runner == SupportedRunner.EVAL: + from lm_saes.runner import language_model_sae_eval_runner + from lm_saes.config import LanguageModelSAERunnerConfig + config = LanguageModelSAERunnerConfig.from_flattened(config) + language_model_sae_eval_runner(config) + elif args.runner == SupportedRunner.ANALYZE: + from lm_saes.runner import sample_feature_activations_runner + from lm_saes.config import LanguageModelSAEAnalysisConfig + config = LanguageModelSAEAnalysisConfig.from_flattened(config) + sample_feature_activations_runner(config) + elif args.runner == SupportedRunner.PRUNE: + from lm_saes.runner import language_model_sae_prune_runner + from lm_saes.config import LanguageModelSAEPruningConfig + config = LanguageModelSAEPruningConfig.from_flattened(config) + language_model_sae_prune_runner(config) + else: + raise ValueError(f'Unsupported runner: {args.runner}.') + + if use_ddp: + dist.destroy_process_group() diff --git a/src/lm_saes/utils/config.py b/src/lm_saes/utils/config.py index 211c135..29735d4 100644 --- a/src/lm_saes/utils/config.py +++ b/src/lm_saes/utils/config.py @@ -211,4 +211,5 @@ def from_flattened(cls, data: Any) -> Self: >>> C.from_flattened({"a1": 1, "a2": "2", "c": 4}) # Deep nested dataclasses are also supported. C(b=B(a=A(a1=1, a2='2'), b1=3), c=4) """ - return from_flattened(cls, data) \ No newline at end of file + return from_flattened(cls, data) + \ No newline at end of file diff --git a/src/lm_saes/utils/misc.py b/src/lm_saes/utils/misc.py index 51981f2..b99de82 100644 --- a/src/lm_saes/utils/misc.py +++ b/src/lm_saes/utils/misc.py @@ -26,3 +26,37 @@ def check_file_path_unused(file_path): print(f"Error: File {file_path} already exists. Please choose a different file path.") exit() +str_dtype_map = { + "float16": torch.float16, + "float32": torch.float32, + "float64": torch.float64, + "int8": torch.int8, + "int16": torch.int16, + "int32": torch.int32, + "int64": torch.int64, + "bool": torch.bool, + "bfloat16": torch.bfloat16, + "bf16": torch.bfloat16, + "float": torch.float, + "fp16": torch.float16, + "fp32": torch.float32, + "fp64": torch.float64, + "int": torch.int, + "torch.float16": torch.float16, + "torch.float32": torch.float32, + "torch.float64": torch.float64, + "torch.int8": torch.int8, + "torch.int16": torch.int16, + "torch.int32": torch.int32, + "torch.int64": torch.int64, + "torch.bool": torch.bool, + "torch.bfloat16": torch.bfloat16, + "torch.float": torch.float, + "torch.int": torch.int, +} + +def convert_str_to_torch_dtype(str_dtype: str) -> torch.dtype: + if str_dtype in str_dtype_map: + return str_dtype_map[str_dtype] + else: + raise ValueError(f"Unsupported data type: {str_dtype}. 
Supported data types: {list(str_dtype_map.keys())}.")
\ No newline at end of file

From a1e1e418b06bd95f02b4f63a90fe944f6271aa0f Mon Sep 17 00:00:00 2001
From: Dest1n1
Date: Thu, 13 Jun 2024 20:03:17 +0800
Subject: [PATCH 3/4] docs(config): add docs for configuration-based launcher

---
 README.md                           | 22 ++++++++++++++--
 examples/configuration/analyze.toml | 41 +++++++++++++++++++++++++++++
 2 files changed, 61 insertions(+), 2 deletions(-)
 create mode 100644 examples/configuration/analyze.toml

diff --git a/README.md b/README.md
index 77ab2d3..b6dc144 100644
--- a/README.md
+++ b/README.md
@@ -37,9 +37,27 @@ bun install
 
 It's worth noting that `bun` is not well-supported on Windows, so you may need to use WSL or other Linux-based solutions to run the frontend, or consider using a different package manager, such as `pnpm` or `yarn`.
 
-## Training/Analyzing a Dictionary
+## Launch an Experiment
 
-We give some basic examples to show how to train a dictionary and analyze the learned dictionary in the [examples](https://github.com/OpenMOSS/Language-Model-SAEs/exapmles). You can copy the example scripts to the `exp` directory and modify them to fit your needs. More examples will be added in the future.
+We provide both a programmatic and a configuration-based way to launch an experiment. The configuration-based way is more convenient and recommended for most users. You can find the configuration files in the [examples/configuration](https://github.com/OpenMOSS/Language-Model-SAEs/examples/configuration) directory, and modify them to fit your needs. The programmatic way is more suitable for advanced users who want to customize the training process, and you can find the example scripts in the [examples/programmatic](https://github.com/OpenMOSS/Language-Model-SAEs/examples/programmatic) directory.
+
+To begin a training process, you can simply run the following command:
+
+```bash
+lm-saes train examples/configuration/train.toml
+```
+
+which will start the training process using the configuration file [examples/configuration/train.toml](https://github.com/OpenMOSS/Language-Model-SAEs/examples/configuration/train.toml).
+
+To analyze a trained dictionary, you can run the following command:
+
+```bash
+lm-saes analyze examples/configuration/analyze.toml --sae <path_to_pretrained_sae>
+```
+
+which will start the analysis process using the configuration file [examples/configuration/analyze.toml](https://github.com/OpenMOSS/Language-Model-SAEs/examples/configuration/analyze.toml). The analysis process requires a trained SAE model, which can be obtained from the training process. You may also need to launch a MongoDB server to store the analysis results, and you can modify the MongoDB settings in the configuration file.
+
+Generally, our configuration-based pipeline uses outer-layer settings as defaults for the inner-layer settings. This makes it easy to build deeply nested configurations in which common settings (such as device and dtype) can be shared across sub-configurations. More details are provided in the configuration files.
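+
+For example, the following snippet (an illustrative sketch, not one of the shipped configuration files) relies on this mechanism: the top-level `device` and `dtype` are passed down to the `[sae]` and `[act_store]` sections as defaults, while any section may still override them explicitly:
+
+```toml
+device = "cuda"
+dtype = "torch.float32"
+
+[sae]
+# Inherits device = "cuda" and dtype = "torch.float32" from the top level.
+hook_point_in = "blocks.3.hook_mlp_out"
+hook_point_out = "blocks.3.hook_mlp_out"
+
+[act_store]
+device = "cpu" # Explicitly overrides the top-level device for the activation store only.
+```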
## Visualizing the Learned Dictionary diff --git a/examples/configuration/analyze.toml b/examples/configuration/analyze.toml new file mode 100644 index 0000000..76d11c4 --- /dev/null +++ b/examples/configuration/analyze.toml @@ -0,0 +1,41 @@ +total_analyzing_tokens = 20_000_000 + +use_ddp = false +device = "cuda" +seed = 42 +dtype = "torch.float32" + +exp_name = "L3M" +exp_series = "default" +exp_result_dir = "results" + +[subsample] +"top_activations" = { "proportion" = 1.0, "n_samples" = 80 } +"subsample-0.9" = { "proportion" = 0.9, "n_samples" = 20} +"subsample-0.8" = { "proportion" = 0.8, "n_samples" = 20} +"subsample-0.7" = { "proportion" = 0.7, "n_samples" = 20} +"subsample-0.5" = { "proportion" = 0.5, "n_samples" = 20} + +[lm] +model_name = "gpt2" +d_model = 768 + +[dataset] +dataset_path = "openwebtext" +is_dataset_tokenized = false +is_dataset_on_disk = true +concat_tokens = false +context_size = 256 +store_batch_size = 32 + +[act_store] +device = "cuda" +seed = 42 +dtype = "torch.float32" +hook_points = [ "blocks.3.hook_mlp_out",] +use_cached_activations = false +n_tokens_in_buffer = 500000 + +[mongo] +mongo_db = "mechinterp" +mongo_uri = "mongodb://localhost:27017" From 4861daf49b47dd0ef4b5bc6a40b602a00cfa09c3 Mon Sep 17 00:00:00 2001 From: Dest1n1 Date: Thu, 13 Jun 2024 20:07:45 +0800 Subject: [PATCH 4/4] fix(examples): fix programmatic runners --- examples/programmatic/analyze.py | 6 ++---- examples/programmatic/train.py | 4 ++-- 2 files changed, 4 insertions(+), 6 deletions(-) diff --git a/examples/programmatic/analyze.py b/examples/programmatic/analyze.py index 0646c9c..d60f259 100644 --- a/examples/programmatic/analyze.py +++ b/examples/programmatic/analyze.py @@ -1,10 +1,8 @@ import torch -import os -import torch.distributed as dist from lm_saes.config import LanguageModelSAEAnalysisConfig, SAEConfig from lm_saes.runner import sample_feature_activations_runner -cfg = LanguageModelSAEAnalysisConfig( +cfg = LanguageModelSAEAnalysisConfig.from_flattened(dict( # LanguageModelConfig model_name = "gpt2", @@ -44,6 +42,6 @@ exp_name = "L3M", exp_series = "default", exp_result_dir = "results", -) +)) sample_feature_activations_runner(cfg) \ No newline at end of file diff --git a/examples/programmatic/train.py b/examples/programmatic/train.py index 7febf03..dd7e695 100644 --- a/examples/programmatic/train.py +++ b/examples/programmatic/train.py @@ -2,7 +2,7 @@ from lm_saes.config import LanguageModelSAETrainingConfig from lm_saes.runner import language_model_sae_runner -cfg = LanguageModelSAETrainingConfig( +cfg = LanguageModelSAETrainingConfig.from_flattened(dict( # LanguageModelConfig model_name = "gpt2", # The model name or path for the pre-trained model. d_model = 768, # The hidden size of the model. @@ -58,6 +58,6 @@ exp_name = "L3M", # The experiment name. Would be used for creating exp folder (which may contain checkpoints and analysis results) and setting wandb run name. exp_series = "default", exp_result_dir = "results" -) +)) sparse_autoencoder = language_model_sae_runner(cfg) \ No newline at end of file
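The doctests in `src/lm_saes/utils/config.py` show how flat fields are routed into nested configs; the minimal, self-contained sketch below (the `Inner`/`Outer` names are illustrative, not part of the patches) additionally shows how conflicting values are resolved: an outer value is propagated into nested configs as a default, and an explicit inner value overrides it.

```python
from dataclasses import dataclass

from lm_saes.utils.config import FlattenableModel


@dataclass
class Inner(FlattenableModel):
    device: str


@dataclass
class Outer(FlattenableModel):
    inner: Inner
    device: str


# "device" at the top level is used as a default for the nested config,
# while the explicit value inside "inner" takes precedence over it.
cfg = Outer.from_flattened({"device": "cuda", "inner": {"device": "cpu"}})
print(cfg)  # Outer(inner=Inner(device='cpu'), device='cuda')
```

This outer-as-default behavior is what allows the TOML examples above to keep shared settings such as `device`, `seed`, and `dtype` at the top level without repeating them in every section.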