From b4190515fb9cc042a113cf97b74efce99fb7df59 Mon Sep 17 00:00:00 2001 From: team-life Date: Mon, 2 Dec 2024 17:52:15 +0900 Subject: [PATCH] Make safe_pickle_load function. --- .../indices/managed/bge_m3/base.py | 20 +++++++++++++++++-- 1 file changed, 18 insertions(+), 2 deletions(-) diff --git a/llama-index-integrations/indices/llama-index-indices-managed-bge-m3/llama_index/indices/managed/bge_m3/base.py b/llama-index-integrations/indices/llama-index-indices-managed-bge-m3/llama_index/indices/managed/bge_m3/base.py index 9d233d1a97278..dafff907a481e 100644 --- a/llama-index-integrations/indices/llama-index-indices-managed-bge-m3/llama_index/indices/managed/bge_m3/base.py +++ b/llama-index-integrations/indices/llama-index-indices-managed-bge-m3/llama_index/indices/managed/bge_m3/base.py @@ -7,11 +7,27 @@ from llama_index.core.base.base_retriever import BaseRetriever from llama_index.core.data_structs.data_structs import IndexDict -from llama_index.core.indices.base import BaseIndex, IndexNode +from llama_index.core.indices.base import BaseIndex, IndexNode from llama_index.core.schema import BaseNode, NodeWithScore from llama_index.core.storage.docstore.types import RefDocInfo from llama_index.core.storage.storage_context import StorageContext +class SafeUnpickler(pickle.Unpickler): + ALLOWED_CLASSES = { + "builtins": {"list", "dict", "str", "int", "float", "tuple", "set"}, + "numpy.core.multiarray": {"_reconstruct", "scalar"}, + "numpy": {"ndarray", "dtype"}, + "collections": {"defaultdict"} + } + + def find_class(self, module, name): + if module in self.ALLOWED_CLASSES and name in self.ALLOWED_CLASSES[module]: + return super().find_class(module, name) + raise pickle.UnpicklingError(f"Unauthorized class: {module}.{name}") + +def safe_pickle_load(file): + return SafeUnpickler(file).load() + class BGEM3Index(BaseIndex[IndexDict]): """ @@ -153,7 +169,7 @@ def load_from_disk( int(k): v for k, v in index.index_struct.nodes_dict.items() } index._docs_pos_to_node_id = docs_pos_to_node_id - index._multi_embed_store = pickle.load( + index._multi_embed_store = safe_pickle_load( open(Path(persist_dir) / "multi_embed_store.pkl", "rb") ) return index