Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: pass kwargs to SimpleDirectoryReader with better type hinting #16858

Open
wants to merge 14 commits into
base: main
Choose a base branch
from
6 changes: 6 additions & 0 deletions llama-index-core/llama_index/core/readers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,17 @@
from llama_index.core.readers.file.base import (
SimpleDirectoryReader,
FileSystemReaderMixin,
DirectoryReaderArgs,
LoadFileArgs,
BaseDirectoryReaderArgs,
)
from llama_index.core.readers.string_iterable import StringIterableReader
from llama_index.core.schema import Document

__all__ = [
"DirectoryReaderArgs",
"LoadFileArgs",
"BaseDirectoryReaderArgs",
"SimpleDirectoryReader",
"FileSystemReaderMixin",
"ReaderConfig",
Expand Down
2 changes: 1 addition & 1 deletion llama-index-core/llama_index/core/readers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ def load_langchain_documents(self, **load_kwargs: Any) -> List["LCDocument"]:
class BasePydanticReader(BaseReader, BaseComponent):
"""Serialiable Data Loader with Pydantic."""

model_config = ConfigDict(arbitrary_types_allowed=True)
model_config = ConfigDict(arbitrary_types_allowed=True, extra="allow")
is_remote: bool = Field(
default=False,
description="Whether the data is loaded from a remote API or a local file.",
Expand Down
225 changes: 142 additions & 83 deletions llama-index-core/llama_index/core/readers/file/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,25 @@
from datetime import datetime
from functools import reduce
import asyncio
from itertools import repeat
from pathlib import Path, PurePosixPath
import fsspec
from fsspec.implementations.local import LocalFileSystem
from typing import Any, Callable, Dict, Generator, List, Optional, Set, Type, Union
from typing_extensions import (
Unpack,
NotRequired,
)
from typing import (
Any,
Callable,
Dict,
Generator,
List,
Optional,
Set,
Type,
Union,
TypedDict,
)

from llama_index.core.readers.base import BaseReader, ResourcesReaderMixin
from llama_index.core.async_utils import run_jobs, get_asyncio_module
Expand Down Expand Up @@ -174,7 +188,59 @@ def is_default_fs(fs: fsspec.AbstractFileSystem) -> bool:
logger = logging.getLogger(__name__)


class SimpleDirectoryReader(BaseReader, ResourcesReaderMixin, FileSystemReaderMixin):
class BaseDirectoryReaderArgs(TypedDict):
    """
    Base args for directory readers.

    Keyword arguments shared by ``SimpleDirectoryReader`` and its
    file-loading helpers. Every key is optional (``NotRequired``); defaults
    are applied by the consuming code, not here.
    """

    # Maps a file path (str) to a metadata dict attached to loaded documents.
    file_metadata: NotRequired[Callable[[str], Dict]]
    # Maps a key to the reader used for matching files.
    # NOTE(review): presumably keyed by file suffix, mirroring
    # supported_suffix_fn() usage in load_file — confirm against callers.
    file_extractor: NotRequired[Dict[str, BaseReader]]
    # Use the file name as the document id (consumers default this to False).
    filename_as_id: NotRequired[bool]
    # Text encoding used when reading files (consumers default to "utf-8").
    encoding: NotRequired[str]
    # Decode-error handling mode (consumers default to "ignore").
    errors: NotRequired[str]
    # When True, re-raise reader failures instead of continuing
    # (consumers default this to False).
    raise_on_error: NotRequired[bool]
    # Filesystem to read from; consumers fall back to the default local fs.
    fs: NotRequired[fsspec.AbstractFileSystem]


class LoadFileArgs(BaseDirectoryReaderArgs):
    """
    Args for `load_file` and `aload_file`.

    Extends the shared reader args with the path of the file to load.
    """

    # Required (not NotRequired): both load_file and aload_file read
    # kwargs["input_file"] unconditionally and would raise KeyError if it
    # were omitted, so the type must reflect that the key is mandatory.
    input_file: Path


class DirectoryReaderArgs(BaseDirectoryReaderArgs):
    """
    Args for `SimpleDirectoryReader`.

    Extends the shared reader args with directory-traversal options; all
    keys remain optional and are defaulted inside ``__init__``.
    """

    # Paths/patterns to skip; None is normalized to [] by __init__.
    # NOTE(review): assumes glob-style patterns — confirm in _add_files.
    exclude: NotRequired[Optional[List[str]]]
    # Skip hidden files and directories (__init__ defaults this to True).
    exclude_hidden: NotRequired[bool]
    # Recurse into subdirectories (__init__ defaults this to False).
    recursive: NotRequired[bool]
    # Only load files with one of these extensions; None means no filter.
    required_exts: NotRequired[Optional[List[str]]]
    # Stop after this many files; None means no limit.
    num_files_limit: NotRequired[Optional[int]]


class DirectoryReaderData:
    """
    Base data for directory readers.

    Mixin declaring class-level ``None`` defaults for the attributes that
    ``SimpleDirectoryReader.__init__`` populates from its kwargs, so the
    attributes exist (as typed class attributes) even before assignment.
    """

    # Paths/patterns to exclude; __init__ normalizes None to [].
    exclude: Optional[List[str]] = None
    # Skip hidden files/dirs; __init__ defaults this to True.
    exclude_hidden: Optional[bool] = None
    # Text encoding for reading files; __init__ defaults this to "utf-8".
    encoding: Optional[str] = None
    # Decode-error handling mode; __init__ defaults this to "ignore".
    errors: Optional[str] = None
    # Recurse into subdirectories; __init__ defaults this to False.
    recursive: Optional[bool] = None
    # Use file names as document ids; __init__ defaults this to False.
    filename_as_id: Optional[bool] = None
    # Extension allow-list; None means no filter.
    required_exts: Optional[List[str]] = None
    # Re-raise reader failures; __init__ defaults this to False.
    raise_on_error: Optional[bool] = None
    # Maximum number of files to load; None means no limit.
    num_files_limit: Optional[int] = None


class SimpleDirectoryReader(
BaseReader, ResourcesReaderMixin, FileSystemReaderMixin, DirectoryReaderData
):
"""
Simple directory reader.

Expand Down Expand Up @@ -217,35 +283,23 @@ def __init__(
self,
input_dir: Optional[Union[Path, str]] = None,
input_files: Optional[List] = None,
exclude: Optional[List] = None,
exclude_hidden: bool = True,
errors: str = "ignore",
recursive: bool = False,
encoding: str = "utf-8",
filename_as_id: bool = False,
required_exts: Optional[List[str]] = None,
file_extractor: Optional[Dict[str, BaseReader]] = None,
num_files_limit: Optional[int] = None,
file_metadata: Optional[Callable[[str], Dict]] = None,
raise_on_error: bool = False,
fs: Optional[fsspec.AbstractFileSystem] = None,
**kwargs: Unpack[DirectoryReaderArgs],
) -> None:
"""Initialize with parameters."""
super().__init__()

super()
if not input_dir and not input_files:
raise ValueError("Must provide either `input_dir` or `input_files`.")

self.fs = fs or get_default_fs()
self.errors = errors
self.encoding = encoding

self.exclude = exclude
self.recursive = recursive
self.exclude_hidden = exclude_hidden
self.required_exts = required_exts
self.num_files_limit = num_files_limit
self.raise_on_error = raise_on_error
self.fs = kwargs.get("fs") or get_default_fs()
self.errors = kwargs.get("errors", "ignore")
self.encoding = kwargs.get("encoding", "utf-8")

self.exclude = kwargs.get("exclude")
self.recursive = kwargs.get("recursive", False)
self.exclude_hidden = kwargs.get("exclude_hidden", True)
self.required_exts = kwargs.get("required_exts")
self.num_files_limit = kwargs.get("num_files_limit")
self.raise_on_error = kwargs.get("raise_on_error", False)
_Path = Path if is_default_fs(self.fs) else PurePosixPath

if input_files:
Expand All @@ -259,16 +313,18 @@ def __init__(
if not self.fs.isdir(input_dir):
raise ValueError(f"Directory {input_dir} does not exist.")
self.input_dir = _Path(input_dir)
self.exclude = exclude
self.exclude = self.exclude or []
self.input_files = self._add_files(self.input_dir)

if file_extractor is not None:
self.file_extractor = file_extractor
if kwargs.get("file_extractor") is not None:
self.file_extractor = kwargs["file_extractor"]
else:
self.file_extractor = {}

self.file_metadata = file_metadata or _DefaultFileMetadataFunc(self.fs)
self.filename_as_id = filename_as_id
self.file_metadata = kwargs.get("file_metadata") or _DefaultFileMetadataFunc(
self.fs
)
self.filename_as_id = kwargs.get("filename_as_id", False)

def is_hidden(self, path: Path) -> bool:
return any(
Expand Down Expand Up @@ -418,7 +474,7 @@ def get_resource_info(self, resource_id: str, *args: Any, **kwargs: Any) -> Dict
}

def load_resource(
self, resource_id: str, *args: Any, **kwargs: Any
self, resource_id: str, *args: Any, **kwargs: Unpack[BaseDirectoryReaderArgs]
) -> List[Document]:
file_metadata = kwargs.get("file_metadata", self.file_metadata)
file_extractor = kwargs.get("file_extractor", self.file_extractor)
Expand All @@ -439,11 +495,10 @@ def load_resource(
errors=errors,
raise_on_error=raise_on_error,
fs=fs,
**kwargs,
)

async def aload_resource(
self, resource_id: str, *args: Any, **kwargs: Any
self, resource_id: str, *args: Any, **kwargs: Unpack[BaseDirectoryReaderArgs]
) -> List[Document]:
file_metadata = kwargs.get("file_metadata", self.file_metadata)
file_extractor = kwargs.get("file_extractor", self.file_extractor)
Expand All @@ -462,7 +517,6 @@ async def aload_resource(
errors=errors,
raise_on_error=raise_on_error,
fs=fs,
**kwargs,
)

def read_file_content(self, input_file: Path, **kwargs: Any) -> bytes:
Expand All @@ -472,16 +526,7 @@ def read_file_content(self, input_file: Path, **kwargs: Any) -> bytes:
return f.read()

@staticmethod
def load_file(
input_file: Path,
file_metadata: Callable[[str], Dict],
file_extractor: Dict[str, BaseReader],
filename_as_id: bool = False,
encoding: str = "utf-8",
errors: str = "ignore",
raise_on_error: bool = False,
fs: Optional[fsspec.AbstractFileSystem] = None,
) -> List[Document]:
def load_file(**kwargs: Unpack[LoadFileArgs]) -> List[Document]:
"""
Static method for loading file.

Expand Down Expand Up @@ -514,6 +559,15 @@ def load_file(
Returns:
List[Document]: loaded documents
"""
input_file = kwargs["input_file"]
file_extractor = kwargs["file_extractor"]
file_metadata = kwargs.get("file_metadata", None)
filename_as_id = kwargs.get("filename_as_id", False)
encoding = kwargs.get("encoding", "utf-8")
errors = kwargs.get("errors", "ignore")
raise_on_error = kwargs.get("raise_on_error", False)
fs = kwargs.get("fs", None)

# TODO: make this less redundant
default_file_reader_cls = SimpleDirectoryReader.supported_suffix_fn()
default_file_reader_suffix = list(default_file_reader_cls.keys())
Expand All @@ -534,10 +588,10 @@ def load_file(

# load data -- catch all errors except for ImportError
try:
kwargs = {"extra_info": metadata}
reader_kwargs = {"extra_info": metadata, **kwargs}
if fs and not is_default_fs(fs):
kwargs["fs"] = fs
docs = reader.load_data(input_file, **kwargs)
reader_kwargs["fs"] = fs
docs = reader.load_data(**reader_kwargs)
except ImportError as e:
# ensure that ImportError is raised so user knows
# about missing dependencies
Expand Down Expand Up @@ -573,17 +627,17 @@ def load_file(
return documents

@staticmethod
async def aload_file(
input_file: Path,
file_metadata: Callable[[str], Dict],
file_extractor: Dict[str, BaseReader],
filename_as_id: bool = False,
encoding: str = "utf-8",
errors: str = "ignore",
raise_on_error: bool = False,
fs: Optional[fsspec.AbstractFileSystem] = None,
) -> List[Document]:
async def aload_file(**kwargs: Unpack[LoadFileArgs]) -> List[Document]:
"""Load file asynchronously."""
input_file = kwargs["input_file"]
file_extractor = kwargs["file_extractor"]
file_metadata = kwargs.get("file_metadata", None)
filename_as_id = kwargs.get("filename_as_id", False)
encoding = kwargs.get("encoding", "utf-8")
errors = kwargs.get("errors", "ignore")
raise_on_error = kwargs.get("raise_on_error", False)
fs = kwargs.get("fs")

# TODO: make this less redundant
default_file_reader_cls = SimpleDirectoryReader.supported_suffix_fn()
default_file_reader_suffix = list(default_file_reader_cls.keys())
Expand All @@ -604,10 +658,10 @@ async def aload_file(

# load data -- catch all errors except for ImportError
try:
kwargs = {"extra_info": metadata}
reader_kwargs = {"extra_info": metadata, **kwargs}
if fs and not is_default_fs(fs):
kwargs["fs"] = fs
docs = await reader.aload_data(input_file, **kwargs)
reader_kwargs["fs"] = fs
docs = await reader.aload_data(input_file, **reader_kwargs)
except ImportError as e:
# ensure that ImportError is raised so user knows
# about missing dependencies
Expand Down Expand Up @@ -675,19 +729,24 @@ def load_data(
num_workers = num_cpus

with multiprocessing.get_context("spawn").Pool(num_workers) as p:
results = p.starmap(
SimpleDirectoryReader.load_file,
zip(
files_to_process,
repeat(self.file_metadata),
repeat(self.file_extractor),
repeat(self.filename_as_id),
repeat(self.encoding),
repeat(self.errors),
repeat(self.raise_on_error),
repeat(fs),
),
)
async_results = []
for input_file in files_to_process:
async_results.append(
p.apply_async(
SimpleDirectoryReader.load_file,
kwds={
"input_file": input_file,
"file_metadata": self.file_metadata,
"file_extractor": self.file_extractor,
"filename_as_id": self.filename_as_id,
"encoding": self.encoding,
"errors": self.errors,
"raise_on_error": self.raise_on_error,
"fs": fs,
},
)
)
results = [res.get() for res in async_results]
documents = reduce(lambda x, y: x + y, results)

else:
Expand Down Expand Up @@ -734,14 +793,14 @@ async def aload_data(

coroutines = [
SimpleDirectoryReader.aload_file(
input_file,
self.file_metadata,
self.file_extractor,
self.filename_as_id,
self.encoding,
self.errors,
self.raise_on_error,
fs,
input_file=input_file,
file_metadata=self.file_metadata,
file_extractor=self.file_extractor,
filename_as_id=self.filename_as_id,
encoding=self.encoding,
errors=self.errors,
raise_on_error=self.raise_on_error,
fs=fs,
)
for input_file in files_to_process
]
Expand Down
Loading
Loading