Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feat/support extractor tools update [WIP] #11009

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions api/controllers/console/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@
external,
hit_testing,
website,
fta_test,
)

# Import explore controllers
Expand Down
145 changes: 145 additions & 0 deletions api/controllers/console/datasets/fta_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,145 @@
import json

import requests
from flask import Response
from flask_restful import Resource, reqparse
from sqlalchemy import text

from controllers.console import api
from extensions.ext_database import db
from extensions.ext_storage import storage
from models.fta import ComponentFailure, ComponentFailureStats


class FATTestApi(Resource):
def post(self):
parser = reqparse.RequestParser()
parser.add_argument("log_process_data", nullable=False, required=True, type=str, location="args")
args = parser.parse_args()
print(args["log_process_data"])
# Extract the JSON string from the text field
json_str = args["log_process_data"].strip("```json\\n").strip("```").strip().replace("\\n", "")
log_data = json.loads(json_str)
db.session.query(ComponentFailure).delete()
for data in log_data:
if not isinstance(data, dict):
raise TypeError("Data must be a dictionary.")

required_keys = {"Date", "Component", "FailureMode", "Cause", "RepairAction", "Technician"}
if not required_keys.issubset(data.keys()):
raise ValueError(f"Data dictionary must contain the following keys: {required_keys}")

try:
# Clear existing stats
component_failure = ComponentFailure(
Date=data["Date"],
Component=data["Component"],
FailureMode=data["FailureMode"],
Cause=data["Cause"],
RepairAction=data["RepairAction"],
Technician=data["Technician"],
)
db.session.add(component_failure)
db.session.commit()
except Exception as e:
print(e)
# Clear existing stats
db.session.query(ComponentFailureStats).delete()

# Insert calculated statistics
try:
db.session.execute(
text("""
INSERT INTO component_failure_stats ("Component", "FailureMode", "Cause", "PossibleAction", "Probability", "MTBF")
SELECT
cf."Component",
cf."FailureMode",
cf."Cause",
cf."RepairAction" as "PossibleAction",
COUNT(*) * 1.0 / (SELECT COUNT(*) FROM component_failure WHERE "Component" = cf."Component") AS "Probability",
COALESCE(AVG(EXTRACT(EPOCH FROM (next_failure_date::timestamp - cf."Date"::timestamp)) / 86400.0),0)AS "MTBF"
FROM (
SELECT
"Component",
"FailureMode",
"Cause",
"RepairAction",
"Date",
LEAD("Date") OVER (PARTITION BY "Component", "FailureMode", "Cause" ORDER BY "Date") AS next_failure_date
FROM
component_failure
) cf
GROUP BY
cf."Component", cf."FailureMode", cf."Cause", cf."RepairAction";
""")
)
db.session.commit()
except Exception as e:
db.session.rollback()
print(f"Error during stats calculation: {e}")
# output format
# [
# (17, 'Hydraulic system', 'Leak', 'Hose rupture', 'Replaced hydraulic hose', 0.3333333333333333, None),
# (18, 'Hydraulic system', 'Leak', 'Seal Wear', 'Replaced the faulty seal', 0.3333333333333333, None),
# (19, 'Hydraulic system', 'Pressure drop', 'Fluid leak', 'Replaced hydraulic fluid and seals', 0.3333333333333333, None)
# ]

component_failure_stats = db.session.query(ComponentFailureStats).all()
# Convert stats to list of tuples format
stats_list = []
for stat in component_failure_stats:
stats_list.append(
(
stat.StatID,
stat.Component,
stat.FailureMode,
stat.Cause,
stat.PossibleAction,
stat.Probability,
stat.MTBF,
)
)
return {"data": stats_list}, 200


# generate-fault-tree
class GenerateFaultTreeApi(Resource):
def post(self):
parser = reqparse.RequestParser()
parser.add_argument("llm_text", nullable=False, required=True, type=str, location="args")
args = parser.parse_args()
entities = args["llm_text"].replace("```", "").replace("\\n", "\n")
print(entities)
request_data = {"fault_tree_text": entities}
url = "https://fta.cognitech-dev.live/generate-fault-tree"
headers = {"accept": "application/json", "Content-Type": "application/json"}

response = requests.post(url, json=request_data, headers=headers)
print(response.json())
return {"data": response.json()}, 200


class ExtractSVGApi(Resource):
def post(self):
parser = reqparse.RequestParser()
parser.add_argument("svg_text", nullable=False, required=True, type=str, location="args")
args = parser.parse_args()
# svg_text = ''.join(args["svg_text"].splitlines())
svg_text = args["svg_text"].replace("\n", "")
svg_text = svg_text.replace('"', '"')
print(svg_text)
svg_text_json = json.loads(svg_text)
svg_content = svg_text_json.get("data").get("svg_content")[0]
svg_content = svg_content.replace("\n", "").replace('"', '"')
file_key = "fta_svg/" + "fat.svg"
if storage.exists(file_key):
storage.delete(file_key)
storage.save(file_key, svg_content.encode("utf-8"))
generator = storage.load(file_key, stream=True)

return Response(generator, mimetype="image/svg+xml")


api.add_resource(FATTestApi, "/fta/db-handler")
api.add_resource(GenerateFaultTreeApi, "/fta/generate-fault-tree")
api.add_resource(ExtractSVGApi, "/fta/extract-svg")
34 changes: 34 additions & 0 deletions api/core/file/file_manager.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import base64
import tempfile
from pathlib import Path

from configs import dify_config
from core.file import file_repository
Expand All @@ -18,6 +20,38 @@
from .tool_file_parser import ToolFileParser


def download_to_target_path(f: File, temp_dir: str, /):
if f.transfer_method == FileTransferMethod.TOOL_FILE:
tool_file = file_repository.get_tool_file(session=db.session(), file=f)
suffix = Path(tool_file.file_key).suffix
target_path = f"{temp_dir}/{next(tempfile._get_candidate_names())}{suffix}"
_download_file_to_target_path(tool_file.file_key, target_path)
return target_path
elif f.transfer_method == FileTransferMethod.LOCAL_FILE:
upload_file = file_repository.get_upload_file(session=db.session(), file=f)
suffix = Path(upload_file.key).suffix
target_path = f"{temp_dir}/{next(tempfile._get_candidate_names())}{suffix}"
_download_file_to_target_path(upload_file.key, target_path)
return target_path
else:
raise ValueError(f"Unsupported transfer method: {f.transfer_method}")


def _download_file_to_target_path(path: str, target_path: str, /):
"""
Download and return the contents of a file as bytes.

This function loads the file from storage and ensures it's in bytes format.

Args:
path (str): The path to the file in storage.
target_path (str): The path to the target file.
Raises:
ValueError: If the loaded file is not a bytes object.
"""
storage.download(path, target_path)


def get_attr(*, file: File, attr: FileAttribute):
match attr:
case FileAttribute.TYPE:
Expand Down
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
from typing import Any

from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController


class FileExtractorProvider(BuiltinToolProviderController):
def _validate_credentials(self, credentials: dict[str, Any]) -> None:
pass
15 changes: 15 additions & 0 deletions api/core/tools/provider/builtin/file_extractor/file_extractor.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
identity:
author: Jyong
name: file_extractor
label:
en_US: File Extractor
zh_Hans: 文件提取
pt_BR: File Extractor
description:
en_US: Extract text from file
zh_Hans: 从文件中提取文本
pt_BR: Extract text from file
icon: icon.png
tags:
- utilities
- productivity
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
import tempfile
from typing import Any, Union

from core.file.enums import FileType
from core.file.file_manager import download_to_target_path
from core.rag.extractor.text_extractor import TextExtractor
from core.rag.splitter.fixed_text_splitter import FixedRecursiveCharacterTextSplitter
from core.tools.entities.tool_entities import ToolInvokeMessage
from core.tools.errors import ToolParameterValidationError
from core.tools.tool.builtin_tool import BuiltinTool


class FileExtractorTool(BuiltinTool):
def _invoke(
self,
user_id: str,
tool_parameters: dict[str, Any],
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
"""
invoke tools
"""
# image file for workflow mode
file = tool_parameters.get("text_file")
if file and file.type != FileType.DOCUMENT:
raise ToolParameterValidationError("Not a valid document")

if file:
with tempfile.TemporaryDirectory() as temp_dir:
file_path = download_to_target_path(file, temp_dir)
extractor = TextExtractor(file_path, autodetect_encoding=True)
documents = extractor.extract()
character_splitter = FixedRecursiveCharacterTextSplitter.from_encoder(
chunk_size=tool_parameters.get("max_token", 500),
chunk_overlap=0,
fixed_separator=tool_parameters.get("separator", "\n\n"),
separators=["\n\n", "。", ". ", " ", ""],
embedding_model_instance=None,
)
chunks = character_splitter.split_documents(documents)

content = "\n".join([chunk.page_content for chunk in chunks])
return self.create_text_message(content)

else:
raise ToolParameterValidationError("Please provide either file")
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
identity:
name: text extractor
author: Jyong
label:
en_US: Text extractor
zh_Hans: Text 文本解析
description:
en_US: Extract content from text file and support split to chunks by split characters and token length
zh_Hans: 支持从文本文件中提取内容并支持通过分割字符和令牌长度分割成块
pt_BR: Extract content from text file and support split to chunks by split characters and token length
description:
human:
en_US: Text extractor is a text extract tool
zh_Hans: Text extractor 是一个文本提取工具
pt_BR: Text extractor is a text extract tool
llm: Text extractor is a tool used to extract text file
parameters:
- name: text_file
type: file
label:
en_US: Text file
human_description:
en_US: The text file to be extracted.
zh_Hans: 要提取的 text 文档。
llm_description: you should not input this parameter. just input the image_id.
form: llm
- name: separator
type: string
required: false
label:
en_US: split character
zh_Hans: 分隔符号
human_description:
en_US: Text content split character
zh_Hans: 用于文档分隔的符号
llm_description: it is used for split content to chunks
form: form
- name: max_token
type: number
required: false
label:
en_US: Maximum chunk length
zh_Hans: 最大分段长度
human_description:
en_US: Maximum chunk length
zh_Hans: 最大分段长度
llm_description: it is used for limit chunk's max length
form: form

78 changes: 78 additions & 0 deletions api/models/fta.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
from extensions.ext_database import db


class ComponentFailure(db.Model):
__tablename__ = "component_failure"
__table_args__ = (
db.UniqueConstraint("Date", "Component", "FailureMode", "Cause", "Technician", name="unique_failure_entry"),
)

FailureID = db.Column(db.Integer, primary_key=True, autoincrement=True)
Date = db.Column(db.Date, nullable=False)
Component = db.Column(db.String(255), nullable=False)
FailureMode = db.Column(db.String(255), nullable=False)
Cause = db.Column(db.String(255), nullable=False)
RepairAction = db.Column(db.Text, nullable=True)
Technician = db.Column(db.String(255), nullable=False)


class Maintenance(db.Model):
__tablename__ = "maintenance"

MaintenanceID = db.Column(db.Integer, primary_key=True, autoincrement=True)
MaintenanceType = db.Column(db.String(255), nullable=False)
MaintenanceDate = db.Column(db.Date, nullable=False)
ServiceDescription = db.Column(db.Text, nullable=True)
PartsReplaced = db.Column(db.Text, nullable=True)
Technician = db.Column(db.String(255), nullable=False)


class OperationalData(db.Model):
__tablename__ = "operational_data"

OperationID = db.Column(db.Integer, primary_key=True, autoincrement=True)
CraneUsage = db.Column(db.Integer, nullable=False)
LoadWeight = db.Column(db.Float, nullable=False)
LoadFrequency = db.Column(db.Integer, nullable=False)
EnvironmentalConditions = db.Column(db.Text, nullable=True)


class IncidentData(db.Model):
__tablename__ = "incident_data"

IncidentID = db.Column(db.Integer, primary_key=True, autoincrement=True)
IncidentDescription = db.Column(db.Text, nullable=False)
IncidentDate = db.Column(db.Date, nullable=False)
Consequences = db.Column(db.Text, nullable=True)
ResponseActions = db.Column(db.Text, nullable=True)


class ReliabilityData(db.Model):
__tablename__ = "reliability_data"

ComponentID = db.Column(db.Integer, primary_key=True, autoincrement=True)
ComponentName = db.Column(db.String(255), nullable=False)
MTBF = db.Column(db.Float, nullable=False)
FailureRate = db.Column(db.Float, nullable=False)


class SafetyData(db.Model):
__tablename__ = "safety_data"

SafetyID = db.Column(db.Integer, primary_key=True, autoincrement=True)
SafetyInspectionDate = db.Column(db.Date, nullable=False)
SafetyFindings = db.Column(db.Text, nullable=True)
SafetyIncidentDescription = db.Column(db.Text, nullable=True)
ComplianceStatus = db.Column(db.String(50), nullable=False)


class ComponentFailureStats(db.Model):
__tablename__ = "component_failure_stats"

StatID = db.Column(db.Integer, primary_key=True, autoincrement=True)
Component = db.Column(db.String(255), nullable=False)
FailureMode = db.Column(db.String(255), nullable=False)
Cause = db.Column(db.String(255), nullable=False)
PossibleAction = db.Column(db.Text, nullable=True)
Probability = db.Column(db.Float, nullable=False)
MTBF = db.Column(db.Float, nullable=False)
Loading
Loading