Skip to content

Commit

Permalink
Improve logging (#235)
Browse files Browse the repository at this point in the history
* WIP: improve logging

* Remove unused code

* Cut long string, configure via env vars, restructure utils folder

* ruff

* Fix tests

* Update changelog

* There is no reason not to test that

* Rename

* Add tests

* Update log message

* Ruff

* Ruff again

* Spelling
  • Loading branch information
stellasia authored Jan 6, 2025
1 parent 39a4b73 commit feeddbb
Show file tree
Hide file tree
Showing 14 changed files with 245 additions and 26 deletions.
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@

### Fixed
- IDs for the Document and Chunk nodes in the lexical graph are now randomly generated and unique across multiple runs, fixing issues in the lexical graph where relationships were created between chunks that were created by different pipeline runs.

- Improved logging for a better debugging experience: long lists and strings are now truncated. The maximum lengths can be configured with the `LOGGING__MAX_LIST_LENGTH` (default: 5) and `LOGGING__MAX_STRING_LENGTH` (default: 200) environment variables.

## 1.3.0

Expand Down
6 changes: 6 additions & 0 deletions examples/build_graph/simple_kg_builder_from_text.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
"""

import asyncio
import logging

import neo4j
from neo4j_graphrag.embeddings import OpenAIEmbeddings
Expand All @@ -20,6 +21,11 @@
from neo4j_graphrag.llm import LLMInterface
from neo4j_graphrag.llm.openai_llm import OpenAILLM

logging.basicConfig()
logging.getLogger("neo4j_graphrag").setLevel(logging.DEBUG)
# logging.getLogger("neo4j_graphrag").setLevel(logging.INFO)


# Neo4j db infos
URI = "neo4j://localhost:7687"
AUTH = ("neo4j", "password")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
from neo4j_graphrag.experimental.pipeline.exceptions import InvalidJSONError
from neo4j_graphrag.generation.prompts import ERExtractionTemplate, PromptTemplate
from neo4j_graphrag.llm import LLMInterface
from neo4j_graphrag.utils.logging import prettify

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -216,25 +217,23 @@ async def extract_for_chunk(
result = json.loads(llm_generated_json)
except (json.JSONDecodeError, InvalidJSONError) as e:
if self.on_error == OnError.RAISE:
raise LLMGenerationError(
f"LLM response is not valid JSON {llm_result.content}: {e}"
)
raise LLMGenerationError("LLM response is not valid JSON") from e
else:
logger.error(
f"LLM response is not valid JSON {llm_result.content} for chunk_index={chunk.index}"
f"LLM response is not valid JSON for chunk_index={chunk.index}"
)
logger.debug(f"Invalid JSON: {llm_result.content}")
result = {"nodes": [], "relationships": []}
try:
chunk_graph = Neo4jGraph(**result)
except ValidationError as e:
if self.on_error == OnError.RAISE:
raise LLMGenerationError(
f"LLM response has improper format {result}: {e}"
)
raise LLMGenerationError("LLM response has improper format") from e
else:
logger.error(
f"LLM response has improper format {result} for chunk_index={chunk.index}"
f"LLM response has improper format for chunk_index={chunk.index}"
)
logger.debug(f"Invalid JSON format: {result}")
chunk_graph = Neo4jGraph()
return chunk_graph

Expand Down Expand Up @@ -336,5 +335,5 @@ async def run(
]
chunk_graphs: list[Neo4jGraph] = list(await asyncio.gather(*tasks))
graph = self.combine_chunk_graphs(lexical_graph, chunk_graphs)
logger.debug(f"{self.__class__.__name__}: {graph}")
logger.debug(f"Extracted graph: {prettify(graph)}")
return graph
8 changes: 7 additions & 1 deletion src/neo4j_graphrag/experimental/pipeline/config/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@
from neo4j_graphrag.experimental.pipeline.config.types import PipelineType
from neo4j_graphrag.experimental.pipeline.pipeline import PipelineResult
from neo4j_graphrag.experimental.pipeline.types import PipelineDefinition
from neo4j_graphrag.utils.logging import prettify

logger = logging.getLogger(__name__)

Expand All @@ -70,6 +71,7 @@ class PipelineConfigWrapper(BaseModel):
] = Field(discriminator=Discriminator(_get_discriminator_value))

def parse(self, resolved_data: dict[str, Any] | None = None) -> PipelineDefinition:
logger.debug("PIPELINE_CONFIG: start parsing config...")
return self.config.parse(resolved_data)

def get_run_params(self, user_input: dict[str, Any]) -> dict[str, Any]:
Expand Down Expand Up @@ -101,10 +103,14 @@ def from_config(
cls, config: AbstractPipelineConfig | dict[str, Any], do_cleaning: bool = False
) -> Self:
wrapper = PipelineConfigWrapper.model_validate({"config": config})
logger.debug(
f"PIPELINE_RUNNER: instantiating Pipeline from config type: {wrapper.config.template_}"
)
return cls(wrapper.parse(), config=wrapper.config, do_cleaning=do_cleaning)

@classmethod
def from_config_file(cls, file_path: Union[str, Path]) -> Self:
logger.info(f"PIPELINE_RUNNER: reading config file from {file_path}")
if not isinstance(file_path, str):
file_path = str(file_path)
data = ConfigReader().read(file_path)
Expand All @@ -119,7 +125,7 @@ async def run(self, user_input: dict[str, Any]) -> PipelineResult:
else:
run_param = deep_update(self.run_params, user_input)
logger.info(
f"PIPELINE_RUNNER: starting pipeline {self.pipeline} with run_params={run_param}"
f"PIPELINE_RUNNER: starting pipeline {self.pipeline} with run_params={prettify(run_param)}"
)
result = await self.pipeline.run(data=run_param)
if self.do_cleaning:
Expand Down
30 changes: 19 additions & 11 deletions src/neo4j_graphrag/experimental/pipeline/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@
from timeit import default_timer
from typing import Any, AsyncGenerator, Optional

from neo4j_graphrag.utils.logging import prettify

try:
import pygraphviz as pgv
except ImportError:
Expand Down Expand Up @@ -90,21 +92,21 @@ async def execute(self, **kwargs: Any) -> RunResult | None:
if the task run successfully, None if the status update
was unsuccessful.
"""
logger.debug(f"Running component {self.name} with {kwargs}")
start_time = default_timer()
component_result = await self.component.run(**kwargs)
run_result = RunResult(
result=component_result,
)
end_time = default_timer()
logger.debug(f"Component {self.name} finished in {end_time - start_time}s")
return run_result

async def run(self, inputs: dict[str, Any]) -> RunResult | None:
"""Main method to execute the task."""
logger.debug(f"TASK START {self.name=} {inputs=}")
logger.debug(f"TASK START {self.name=} input={prettify(inputs)}")
start_time = default_timer()
res = await self.execute(**inputs)
logger.debug(f"TASK RESULT {self.name=} {res=}")
end_time = default_timer()
logger.debug(
f"TASK FINISHED {self.name} in {end_time - start_time} res={prettify(res)}"
)
return res


Expand Down Expand Up @@ -141,7 +143,9 @@ async def run_task(self, task: TaskPipelineNode, data: dict[str, Any]) -> None:
try:
await self.set_task_status(task.name, RunStatus.RUNNING)
except PipelineStatusUpdateError:
logger.info(f"Component {task.name} already running or done")
logger.debug(
f"ORCHESTRATOR: TASK ABORTED: {task.name} is already running or done, aborting"
)
return None
res = await task.run(inputs)
await self.set_task_status(task.name, RunStatus.DONE)
Expand Down Expand Up @@ -198,7 +202,8 @@ async def check_dependencies_complete(self, task: TaskPipelineNode) -> None:
d_status = await self.get_status_for_component(d.start)
if d_status != RunStatus.DONE:
logger.debug(
f"Missing dependency {d.start} for {task.name} (status: {d_status}). "
f"ORCHESTRATOR {self.run_id}: TASK DELAYED: Missing dependency {d.start} for {task.name} "
f"(status: {d_status}). "
"Will try again when dependency is complete."
)
raise PipelineMissingDependencyError()
Expand Down Expand Up @@ -227,6 +232,9 @@ async def next(
await self.check_dependencies_complete(next_node)
except PipelineMissingDependencyError:
continue
logger.debug(
f"ORCHESTRATOR {self.run_id}: enqueuing next task: {next_node.name}"
)
yield next_node
return

Expand Down Expand Up @@ -315,7 +323,6 @@ async def run(self, data: dict[str, Any]) -> None:
(node without any parent). Then the callback on_task_complete
will handle the task dependencies.
"""
logger.debug(f"PIPELINE START {data=}")
tasks = [self.run_task(root, data) for root in self.pipeline.roots()]
await asyncio.gather(*tasks)

Expand Down Expand Up @@ -624,15 +631,16 @@ def validate_parameter_mapping_for_task(self, task: TaskPipelineNode) -> bool:
return True

async def run(self, data: dict[str, Any]) -> PipelineResult:
logger.debug("Starting pipeline")
logger.debug("PIPELINE START")
start_time = default_timer()
self.invalidate()
self.validate_input_data(data)
orchestrator = Orchestrator(self)
logger.debug(f"PIPELINE ORCHESTRATOR: {orchestrator.run_id}")
await orchestrator.run(data)
end_time = default_timer()
logger.debug(
f"Pipeline {orchestrator.run_id} finished in {end_time - start_time}s"
f"PIPELINE FINISHED {orchestrator.run_id} in {end_time - start_time}s"
)
return PipelineResult(
run_id=orchestrator.run_id,
Expand Down
2 changes: 1 addition & 1 deletion src/neo4j_graphrag/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
model_validator,
)

from neo4j_graphrag.utils import validate_search_query_input
from neo4j_graphrag.utils.validation import validate_search_query_input


class RawSearchResult(BaseModel):
Expand Down
Empty file.
80 changes: 80 additions & 0 deletions src/neo4j_graphrag/utils/logging.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
# Copyright (c) "Neo4j"
# Neo4j Sweden AB [https://neo4j.com]
# #
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# #
# https://www.apache.org/licenses/LICENSE-2.0
# #
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations

import os
from typing import Any

from pydantic import BaseModel

DEFAULT_MAX_LIST_LENGTH: int = 5
DEFAULT_MAX_STRING_LENGTH: int = 200


class Prettifier:
    """Prettify any object for logging.

    Long lists and strings are truncated, including the ones nested inside
    dicts, lists and pydantic models. The limits are read once, when the
    instance is created, from the following environment variables:

    - LOGGING__MAX_LIST_LENGTH (int)
    - LOGGING__MAX_STRING_LENGTH (int)

    A variable that is unset or is not a valid integer falls back to the
    module default; a negative value is treated as 0.
    """

    def __init__(self) -> None:
        self.max_list_length = self._parse_limit(
            os.environ.get("LOGGING__MAX_LIST_LENGTH"), DEFAULT_MAX_LIST_LENGTH
        )
        self.max_string_length = self._parse_limit(
            os.environ.get("LOGGING__MAX_STRING_LENGTH"), DEFAULT_MAX_STRING_LENGTH
        )

    @staticmethod
    def _parse_limit(raw: Any, default: int) -> int:
        """Convert a raw env value into a non-negative length limit.

        Args:
            raw: The value read from the environment (``None`` when unset).
            default: The limit to use when ``raw`` is missing or malformed.

        Returns:
            ``int(raw)`` clamped to a minimum of 0, or ``default`` when
            ``raw`` cannot be converted to an integer.
        """
        try:
            limit = int(raw)
        except (TypeError, ValueError):
            # A badly formatted logging setting must never prevent the
            # package from being imported: keep the default instead.
            return default
        # A negative limit would turn the `[:limit]` slices below into
        # "drop the last items", producing misleading output.
        return max(limit, 0)

    def _prettify_dict(self, value: dict[Any, Any]) -> dict[Any, Any]:
        """Return a copy of ``value`` with each value prettified (keys kept)."""
        return {k: self(v) for k, v in value.items()}

    def _prettify_list(self, value: list[Any]) -> list[Any]:
        """Keep at most ``max_list_length`` prettified items.

        When items are dropped, a trailing marker string states how many.
        """
        items = [self(v) for v in value[: self.max_list_length]]
        remaining_items = len(value) - len(items)
        if remaining_items > 0:
            items.append(f"... ({remaining_items} items)")
        return items

    def _prettify_str(self, value: str) -> str:
        """Keep at most ``max_string_length`` characters.

        When characters are dropped, a suffix states how many.
        """
        new_value = value[: self.max_string_length]
        remaining_chars = len(value) - len(new_value)
        if remaining_chars > 0:
            new_value += f"... ({remaining_chars} chars)"
        return new_value

    def __call__(self, value: Any) -> Any:
        """Take any value and return a prettified version for logging.

        Strings, lists and dicts are truncated (recursively); pydantic models
        are dumped to a dict first. Any other value is returned unchanged.
        """
        # Plain builtins are checked first: they are the most common inputs.
        if isinstance(value, str):
            return self._prettify_str(value)
        if isinstance(value, list):
            return self._prettify_list(value)
        if isinstance(value, dict):
            return self._prettify_dict(value)
        if isinstance(value, BaseModel):
            # The dumped model is a dict, handled by the branch above.
            return self(value.model_dump())
        return value


prettify = Prettifier()
File renamed without changes.
2 changes: 1 addition & 1 deletion tests/e2e/test_kg_writer_component_e2e.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ async def test_kg_writer(driver: neo4j.Driver) -> None:
if start_node.embedding_properties: # for mypy
for key, val in start_node.embedding_properties.items():
assert key in node_a.keys()
assert node_a.get(key) == [1.0, 2.0, 3.0]
assert val == node_a.get(key)

node_b = record["b"]
assert end_node.label in list(node_b.labels)
Expand Down
1 change: 0 additions & 1 deletion tests/e2e/test_simplekgpipeline_e2e.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
import neo4j
import pytest
from neo4j import Driver

from neo4j_graphrag.experimental.components.text_splitters.fixed_size_splitter import (
FixedSizeSplitter,
)
Expand Down
5 changes: 4 additions & 1 deletion tests/unit/experimental/components/test_embedder.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,10 @@

import pytest
from neo4j_graphrag.experimental.components.embedder import TextChunkEmbedder
from neo4j_graphrag.experimental.components.types import TextChunk, TextChunks
from neo4j_graphrag.experimental.components.types import (
TextChunk,
TextChunks,
)


@pytest.mark.asyncio
Expand Down
Empty file added tests/unit/utils/__init__.py
Empty file.
Loading

0 comments on commit feeddbb

Please sign in to comment.