Merge branch 'fix/chore-fix' into dev/plugin-deploy

Novice Lee 2025-01-09 08:48:00 +08:00
commit 00ad751a57
231 changed files with 3564 additions and 2382 deletions

View File

@@ -23,6 +23,9 @@ FILES_ACCESS_TIMEOUT=300
# Access token expiration time in minutes
ACCESS_TOKEN_EXPIRE_MINUTES=60
# Refresh token expiration time in days
REFRESH_TOKEN_EXPIRE_DAYS=30
# celery configuration
CELERY_BROKER_URL=redis://:difyai123456@localhost:6379/1
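The new Celery setting uses the standard Redis URL scheme, redis://:<password>@<host>:<port>/<db>, so the value above points the broker at database 1 of a password-protected local Redis. A minimal sketch of how a Celery app would consume it (the app name "dify" here is illustrative):

import os

from celery import Celery

broker_url = os.environ.get("CELERY_BROKER_URL", "redis://localhost:6379/1")
celery_app = Celery("dify", broker=broker_url)  # tasks are queued through Redis DB 1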

View File

@@ -14,7 +14,10 @@ if is_db_command():
app = create_migrations_app()
else:
if os.environ.get("FLASK_DEBUG", "False") != "True":
# It seems that the JetBrains Python debugger does not work well with gevent,
# so we need to disable gevent in debug mode.
# If you are using debugpy and set GEVENT_SUPPORT=True, you can debug with gevent.
if (flask_debug := os.environ.get("FLASK_DEBUG", "0")) and flask_debug.lower() in {"false", "0", "no"}:
from gevent import monkey # type: ignore
# gevent
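The reworked guard applies gevent only when FLASK_DEBUG is explicitly falsy; the walrus operator keeps the raw value around for the case-insensitive check. A sketch of the resulting logic, assuming the same environment contract as the hunk above:

import os

# Skip gevent in debug mode: the JetBrains debugger does not cooperate with it.
# debugpy users can keep gevent by exporting GEVENT_SUPPORT=True.
if (flask_debug := os.environ.get("FLASK_DEBUG", "0")) and flask_debug.lower() in {"false", "0", "no"}:
    from gevent import monkey

    monkey.patch_all()  # must run before modules that create sockets or threads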

View File

@@ -546,6 +546,11 @@ class AuthConfig(BaseSettings):
default=60,
)
REFRESH_TOKEN_EXPIRE_DAYS: PositiveFloat = Field(
description="Expiration time for refresh tokens in days",
default=30,
)
LOGIN_LOCKOUT_DURATION: PositiveInt = Field(
description="Time (in seconds) a user must wait before retrying login after exceeding the rate limit.",
default=86400,
@@ -725,6 +730,11 @@ class IndexingConfig(BaseSettings):
default=4000,
)
CHILD_CHUNKS_PREVIEW_NUMBER: PositiveInt = Field(
description="Maximum number of child chunks to preview",
default=50,
)
class MultiModalTransferConfig(BaseSettings):
MULTIMODAL_SEND_FORMAT: Literal["base64", "url"] = Field(

View File

@@ -33,3 +33,9 @@ class MilvusConfig(BaseSettings):
description="Name of the Milvus database to connect to (default is 'default')",
default="default",
)
MILVUS_ENABLE_HYBRID_SEARCH: bool = Field(
description="Enable hybrid search features (requires Milvus >= 2.5.0). Set to false for compatibility with "
"older versions",
default=True,
)
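MILVUS_ENABLE_HYBRID_SEARCH defaults to on but gives deployments stuck on Milvus below 2.5.0 an escape hatch. A hypothetical gate (function and method names assumed, not taken from this diff) shows how a vector store adapter might consult it:

def retrieve(store, query: str, dense_vector: list[float], enable_hybrid: bool):
    # Hybrid (dense + full-text) retrieval requires Milvus >= 2.5.0.
    if enable_hybrid:
        return store.hybrid_search(query, dense_vector)
    # Dense-only fallback keeps older servers working.
    return store.dense_search(dense_vector)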

View File

@@ -9,7 +9,7 @@ class PackagingInfo(BaseSettings):
CURRENT_VERSION: str = Field(
description="Dify version",
default="0.14.2",
default="0.15.0",
)
COMMIT_SHA: str = Field(

View File

@@ -57,12 +57,13 @@ class AppListApi(Resource):
)
parser.add_argument("name", type=str, location="args", required=False)
parser.add_argument("tag_ids", type=uuid_list, location="args", required=False)
parser.add_argument("is_created_by_me", type=inputs.boolean, location="args", required=False)
args = parser.parse_args()
# get app list
app_service = AppService()
app_pagination = app_service.get_paginate_apps(current_user.current_tenant_id, args)
app_pagination = app_service.get_paginate_apps(current_user.id, current_user.current_tenant_id, args)
if not app_pagination:
return {"data": [], "total": 0, "page": 1, "limit": 20, "has_more": False}

View File

@@ -20,7 +20,6 @@ from controllers.web.error import InvokeRateLimitError as InvokeRateLimitHttpErr
from core.app.apps.base_app_queue_manager import AppQueueManager
from core.app.entities.app_invoke_entities import InvokeFrom
from core.errors.error import (
AppInvokeQuotaExceededError,
ModelCurrentlyNotSupportError,
ProviderTokenNotInitError,
QuotaExceededError,
@@ -76,7 +75,7 @@ class CompletionMessageApi(Resource):
raise ProviderModelCurrentlyNotSupportError()
except InvokeError as e:
raise CompletionRequestError(e.description)
except (ValueError, AppInvokeQuotaExceededError) as e:
except ValueError as e:
raise e
except Exception as e:
logging.exception("internal server error.")
@@ -141,7 +140,7 @@ class ChatMessageApi(Resource):
raise InvokeRateLimitHttpError(ex.description)
except InvokeError as e:
raise CompletionRequestError(e.description)
except (ValueError, AppInvokeQuotaExceededError) as e:
except ValueError as e:
raise e
except Exception as e:
logging.exception("internal server error.")

View File

@@ -273,8 +273,7 @@ FROM
messages m
ON c.id = m.conversation_id
WHERE
c.override_model_configs IS NULL
AND c.app_id = :app_id"""
c.app_id = :app_id"""
arg_dict = {"tz": account.timezone, "app_id": app_model.id}
timezone = pytz.timezone(account.timezone)

View File

@@ -2,7 +2,7 @@ import base64
import secrets
from flask import request
from flask_restful import Resource, reqparse
from flask_restful import Resource, reqparse # type: ignore
from sqlalchemy import select
from sqlalchemy.orm import Session
@@ -129,7 +129,7 @@ class ForgotPasswordResetApi(Resource):
)
except WorkSpaceNotAllowedCreateError:
pass
except AccountRegisterError as are:
except AccountRegisterError:
raise AccountInFreezeError()
return {"result": "success"}

View File

@@ -4,7 +4,7 @@ from typing import Optional
import requests
from flask import current_app, redirect, request
from flask_restful import Resource
from flask_restful import Resource # type: ignore
from sqlalchemy import select
from sqlalchemy.orm import Session
from werkzeug.exceptions import Unauthorized

View File

@@ -2,8 +2,8 @@ import datetime
import json
from flask import request
from flask_login import current_user
from flask_restful import Resource, marshal_with, reqparse
from flask_login import current_user # type: ignore
from flask_restful import Resource, marshal_with, reqparse # type: ignore
from sqlalchemy import select
from sqlalchemy.orm import Session
from werkzeug.exceptions import NotFound

View File

@@ -640,6 +640,7 @@ class DatasetRetrievalSettingApi(Resource):
| VectorType.MYSCALE
| VectorType.ORACLE
| VectorType.ELASTICSEARCH
| VectorType.ELASTICSEARCH_JA
| VectorType.PGVECTOR
| VectorType.TIDB_ON_QDRANT
| VectorType.LINDORM
@@ -683,6 +684,7 @@ class DatasetRetrievalSettingMockApi(Resource):
| VectorType.MYSCALE
| VectorType.ORACLE
| VectorType.ELASTICSEARCH
| VectorType.ELASTICSEARCH_JA
| VectorType.COUCHBASE
| VectorType.PGVECTOR
| VectorType.LINDORM

View File

@@ -269,7 +269,8 @@ class DatasetDocumentListApi(Resource):
parser.add_argument("original_document_id", type=str, required=False, location="json")
parser.add_argument("doc_form", type=str, default="text_model", required=False, nullable=False, location="json")
parser.add_argument("retrieval_model", type=dict, required=False, nullable=False, location="json")
parser.add_argument("embedding_model", type=str, required=False, nullable=True, location="json")
parser.add_argument("embedding_model_provider", type=str, required=False, nullable=True, location="json")
parser.add_argument(
"doc_language", type=str, default="English", required=False, nullable=False, location="json"
)

View File

@@ -18,7 +18,11 @@ from controllers.console.explore.error import NotChatAppError, NotCompletionAppE
from controllers.console.explore.wraps import InstalledAppResource
from core.app.apps.base_app_queue_manager import AppQueueManager
from core.app.entities.app_invoke_entities import InvokeFrom
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
from core.errors.error import (
ModelCurrentlyNotSupportError,
ProviderTokenNotInitError,
QuotaExceededError,
)
from core.model_runtime.errors.invoke import InvokeError
from extensions.ext_database import db
from libs import helper

View File

@@ -13,7 +13,11 @@ from controllers.console.explore.error import NotWorkflowAppError
from controllers.console.explore.wraps import InstalledAppResource
from core.app.apps.base_app_queue_manager import AppQueueManager
from core.app.entities.app_invoke_entities import InvokeFrom
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
from core.errors.error import (
ModelCurrentlyNotSupportError,
ProviderTokenNotInitError,
QuotaExceededError,
)
from core.model_runtime.errors.invoke import InvokeError
from libs import helper
from libs.login import current_user

View File

@@ -1,7 +1,7 @@
import os
from flask import session
from flask_restful import Resource, reqparse
from flask_restful import Resource, reqparse # type: ignore
from sqlalchemy import select
from sqlalchemy.orm import Session

View File

@@ -1,6 +1,6 @@
from functools import wraps
from flask_login import current_user
from flask_login import current_user # type: ignore
from sqlalchemy.orm import Session
from werkzeug.exceptions import Forbidden

View File

@@ -1,5 +1,5 @@
from flask_login import current_user
from flask_restful import Resource
from flask_login import current_user # type: ignore
from flask_restful import Resource # type: ignore
from controllers.console import api
from controllers.console.wraps import account_initialization_required, setup_required

View File

@@ -1,5 +1,5 @@
from flask_login import current_user
from flask_restful import Resource, reqparse
from flask_login import current_user # type: ignore
from flask_restful import Resource, reqparse # type: ignore
from werkzeug.exceptions import Forbidden
from controllers.console import api

View File

@@ -1,8 +1,8 @@
import io
from flask import request, send_file
from flask_login import current_user
from flask_restful import Resource, reqparse
from flask_login import current_user # type: ignore
from flask_restful import Resource, reqparse # type: ignore
from werkzeug.exceptions import Forbidden
from configs import dify_config

View File

@@ -1,5 +1,5 @@
from flask import request
from flask_restful import Resource, marshal_with
from flask_restful import Resource, marshal_with # type: ignore
from werkzeug.exceptions import Forbidden
import services

View File

@@ -1,4 +1,4 @@
from flask_restful import Resource
from flask_restful import Resource # type: ignore
from controllers.console.wraps import setup_required
from controllers.inner_api import api

View File

@@ -3,7 +3,7 @@ from functools import wraps
from typing import Optional
from flask import request
from flask_restful import reqparse
from flask_restful import reqparse # type: ignore
from pydantic import BaseModel
from sqlalchemy.orm import Session

View File

@@ -18,7 +18,6 @@ from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate
from core.app.apps.base_app_queue_manager import AppQueueManager
from core.app.entities.app_invoke_entities import InvokeFrom
from core.errors.error import (
AppInvokeQuotaExceededError,
ModelCurrentlyNotSupportError,
ProviderTokenNotInitError,
QuotaExceededError,
@@ -74,7 +73,7 @@ class CompletionApi(Resource):
raise ProviderModelCurrentlyNotSupportError()
except InvokeError as e:
raise CompletionRequestError(e.description)
except (ValueError, AppInvokeQuotaExceededError) as e:
except ValueError as e:
raise e
except Exception as e:
logging.exception("internal server error.")
@@ -133,7 +132,7 @@ class ChatApi(Resource):
raise ProviderModelCurrentlyNotSupportError()
except InvokeError as e:
raise CompletionRequestError(e.description)
except (ValueError, AppInvokeQuotaExceededError) as e:
except ValueError as e:
raise e
except Exception as e:
logging.exception("internal server error.")

View File

@@ -16,7 +16,6 @@ from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate
from core.app.apps.base_app_queue_manager import AppQueueManager
from core.app.entities.app_invoke_entities import InvokeFrom
from core.errors.error import (
AppInvokeQuotaExceededError,
ModelCurrentlyNotSupportError,
ProviderTokenNotInitError,
QuotaExceededError,
@@ -94,7 +93,7 @@ class WorkflowRunApi(Resource):
raise ProviderModelCurrentlyNotSupportError()
except InvokeError as e:
raise CompletionRequestError(e.description)
except (ValueError, AppInvokeQuotaExceededError) as e:
except ValueError as e:
raise e
except Exception as e:
logging.exception("internal server error.")

View File

@@ -190,7 +190,10 @@ class DocumentAddByFileApi(DatasetApiResource):
user=current_user,
source="datasets",
)
data_source = {"type": "upload_file", "info_list": {"file_info_list": {"file_ids": [upload_file.id]}}}
data_source = {
"type": "upload_file",
"info_list": {"data_source_type": "upload_file", "file_info_list": {"file_ids": [upload_file.id]}},
}
args["data_source"] = data_source
# validate args
knowledge_config = KnowledgeConfig(**args)
@@ -254,7 +257,10 @@ class DocumentUpdateByFileApi(DatasetApiResource):
raise FileTooLargeError(file_too_large_error.description)
except services.errors.file.UnsupportedFileTypeError:
raise UnsupportedFileTypeError()
data_source = {"type": "upload_file", "info_list": {"file_info_list": {"file_ids": [upload_file.id]}}}
data_source = {
"type": "upload_file",
"info_list": {"data_source_type": "upload_file", "file_info_list": {"file_ids": [upload_file.id]}},
}
args["data_source"] = data_source
# validate args
args["original_document_id"] = str(document_id)

View File

@@ -1,5 +1,5 @@
from collections.abc import Callable
from datetime import UTC, datetime
from datetime import UTC, datetime, timedelta
from enum import Enum
from functools import wraps
from typing import Optional
@@ -8,6 +8,8 @@ from flask import current_app, request
from flask_login import user_logged_in # type: ignore
from flask_restful import Resource # type: ignore
from pydantic import BaseModel
from sqlalchemy import select, update
from sqlalchemy.orm import Session
from werkzeug.exceptions import Forbidden, Unauthorized
from extensions.ext_database import db
@@ -174,7 +176,7 @@ def validate_dataset_token(view=None):
return decorator
def validate_and_get_api_token(scope=None):
def validate_and_get_api_token(scope: str | None = None):
"""
Validate and get API token.
"""
@@ -188,20 +190,25 @@ def validate_and_get_api_token(scope=None):
if auth_scheme != "bearer":
raise Unauthorized("Authorization scheme must be 'Bearer'")
api_token = (
db.session.query(ApiToken)
.filter(
ApiToken.token == auth_token,
ApiToken.type == scope,
current_time = datetime.now(UTC).replace(tzinfo=None)
cutoff_time = current_time - timedelta(minutes=1)
with Session(db.engine, expire_on_commit=False) as session:
update_stmt = (
update(ApiToken)
.where(ApiToken.token == auth_token, ApiToken.last_used_at < cutoff_time, ApiToken.type == scope)
.values(last_used_at=current_time)
.returning(ApiToken)
)
.first()
)
result = session.execute(update_stmt)
api_token = result.scalar_one_or_none()
if not api_token:
raise Unauthorized("Access token is invalid")
api_token.last_used_at = datetime.now(UTC).replace(tzinfo=None)
db.session.commit()
if not api_token:
stmt = select(ApiToken).where(ApiToken.token == auth_token, ApiToken.type == scope)
api_token = session.scalar(stmt)
if not api_token:
raise Unauthorized("Access token is invalid")
else:
session.commit()
return api_token
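The rewritten lookup folds validation and the last_used_at refresh into a single UPDATE ... RETURNING, and only writes when the stamp is older than a minute, so a busy token no longer commits on every request; the plain SELECT fallback covers tokens touched within that window (or rejects unknown ones). A distilled sketch of the pattern, assuming the same ApiToken model is imported:

from datetime import UTC, datetime, timedelta

from sqlalchemy import select, update
from sqlalchemy.orm import Session

def validate_token(session: Session, token: str, scope: str):
    now = datetime.now(UTC).replace(tzinfo=None)
    stmt = (
        update(ApiToken)
        .where(
            ApiToken.token == token,
            ApiToken.type == scope,
            ApiToken.last_used_at < now - timedelta(minutes=1),
        )
        .values(last_used_at=now)
        .returning(ApiToken)
    )
    api_token = session.execute(stmt).scalar_one_or_none()
    if api_token is not None:
        session.commit()  # stamp refreshed at most once per minute per token
        return api_token
    # Recently-used token, or an invalid one: a plain read disambiguates.
    api_token = session.scalar(select(ApiToken).where(ApiToken.token == token, ApiToken.type == scope))
    if api_token is None:
        raise PermissionError("Access token is invalid")  # Unauthorized in the real code
    return api_token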

View File

@@ -19,7 +19,11 @@ from controllers.web.error import InvokeRateLimitError as InvokeRateLimitHttpErr
from controllers.web.wraps import WebApiResource
from core.app.apps.base_app_queue_manager import AppQueueManager
from core.app.entities.app_invoke_entities import InvokeFrom
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
from core.errors.error import (
ModelCurrentlyNotSupportError,
ProviderTokenNotInitError,
QuotaExceededError,
)
from core.model_runtime.errors.invoke import InvokeError
from libs import helper
from libs.helper import uuid_value

View File

@@ -14,7 +14,11 @@ from controllers.web.error import (
from controllers.web.wraps import WebApiResource
from core.app.apps.base_app_queue_manager import AppQueueManager
from core.app.entities.app_invoke_entities import InvokeFrom
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
from core.errors.error import (
ModelCurrentlyNotSupportError,
ProviderTokenNotInitError,
QuotaExceededError,
)
from core.model_runtime.errors.invoke import InvokeError
from libs import helper
from models.model import App, AppMode, EndUser

View File

@@ -119,7 +119,7 @@ class CotAgentRunner(BaseAgentRunner, ABC):
callbacks=[],
)
usage_dict = {}
usage_dict: dict[str, Optional[LLMUsage]] = {}
react_chunks = CotAgentOutputParser.handle_react_stream_output(chunks, usage_dict)
scratchpad = AgentScratchpadUnit(
agent_response="",

View File

@@ -21,7 +21,7 @@ from core.app.apps.message_based_app_generator import MessageBasedAppGenerator
from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager
from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom
from core.app.entities.task_entities import ChatbotAppBlockingResponse, ChatbotAppStreamResponse
from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
from core.model_runtime.errors.invoke import InvokeAuthorizationError
from core.ops.ops_trace_manager import TraceQueueManager
from core.prompt.utils.get_thread_messages_length import get_thread_messages_length
from extensions.ext_database import db
@@ -346,7 +346,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
except ValidationError as e:
logger.exception("Validation Error when generating")
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
except (ValueError, InvokeError) as e:
except ValueError as e:
if dify_config.DEBUG:
logger.exception("Error when generating")
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)

View File

@@ -77,7 +77,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
graph, variable_pool = self._get_graph_and_variable_pool_of_single_iteration(
workflow=workflow,
node_id=self.application_generate_entity.single_iteration_run.node_id,
user_inputs=self.application_generate_entity.single_iteration_run.inputs,
user_inputs=dict(self.application_generate_entity.single_iteration_run.inputs),
)
else:
inputs = self.application_generate_entity.inputs

View File

@@ -68,24 +68,17 @@ from models.account import Account
from models.enums import CreatedByRole
from models.workflow import (
Workflow,
WorkflowNodeExecution,
WorkflowRunStatus,
)
logger = logging.getLogger(__name__)
class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleManage, MessageCycleManage):
class AdvancedChatAppGenerateTaskPipeline:
"""
AdvancedChatAppGenerateTaskPipeline generates stream output and manages state for the application.
"""
_task_state: WorkflowTaskState
_application_generate_entity: AdvancedChatAppGenerateEntity
_workflow_system_variables: dict[SystemVariableKey, Any]
_wip_workflow_node_executions: dict[str, WorkflowNodeExecution]
_conversation_name_generate_thread: Optional[Thread] = None
def __init__(
self,
application_generate_entity: AdvancedChatAppGenerateEntity,
@@ -97,7 +90,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
stream: bool,
dialogue_count: int,
) -> None:
super().__init__(
self._base_task_pipeline = BasedGenerateTaskPipeline(
application_generate_entity=application_generate_entity,
queue_manager=queue_manager,
stream=stream,
@@ -114,33 +107,35 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
else:
raise NotImplementedError(f"User type not supported: {type(user)}")
self._workflow_id = workflow.id
self._workflow_features_dict = workflow.features_dict
self._conversation_id = conversation.id
self._conversation_mode = conversation.mode
self._message_id = message.id
self._message_created_at = int(message.created_at.timestamp())
self._workflow_system_variables = {
SystemVariableKey.QUERY: message.query,
SystemVariableKey.FILES: application_generate_entity.files,
SystemVariableKey.CONVERSATION_ID: conversation.id,
SystemVariableKey.USER_ID: user_session_id,
SystemVariableKey.DIALOGUE_COUNT: dialogue_count,
SystemVariableKey.APP_ID: application_generate_entity.app_config.app_id,
SystemVariableKey.WORKFLOW_ID: workflow.id,
SystemVariableKey.WORKFLOW_RUN_ID: application_generate_entity.workflow_run_id,
}
self._workflow_cycle_manager = WorkflowCycleManage(
application_generate_entity=application_generate_entity,
workflow_system_variables={
SystemVariableKey.QUERY: message.query,
SystemVariableKey.FILES: application_generate_entity.files,
SystemVariableKey.CONVERSATION_ID: conversation.id,
SystemVariableKey.USER_ID: user_session_id,
SystemVariableKey.DIALOGUE_COUNT: dialogue_count,
SystemVariableKey.APP_ID: application_generate_entity.app_config.app_id,
SystemVariableKey.WORKFLOW_ID: workflow.id,
SystemVariableKey.WORKFLOW_RUN_ID: application_generate_entity.workflow_run_id,
},
)
self._task_state = WorkflowTaskState()
self._wip_workflow_node_executions = {}
self._wip_workflow_agent_logs = {}
self._message_cycle_manager = MessageCycleManage(
application_generate_entity=application_generate_entity, task_state=self._task_state
)
self._conversation_name_generate_thread = None
self._application_generate_entity = application_generate_entity
self._workflow_id = workflow.id
self._workflow_features_dict = workflow.features_dict
self._conversation_id = conversation.id
self._conversation_mode = conversation.mode
self._message_id = message.id
self._message_created_at = int(message.created_at.timestamp())
self._conversation_name_generate_thread: Thread | None = None
self._recorded_files: list[Mapping[str, Any]] = []
self._workflow_run_id = ""
self._workflow_run_id: str = ""
def process(self) -> Union[ChatbotAppBlockingResponse, Generator[ChatbotAppStreamResponse, None, None]]:
"""
@@ -148,13 +143,13 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
:return:
"""
# start generate conversation name thread
self._conversation_name_generate_thread = self._generate_conversation_name(
self._conversation_name_generate_thread = self._message_cycle_manager._generate_conversation_name(
conversation_id=self._conversation_id, query=self._application_generate_entity.query
)
generator = self._wrapper_process_stream_response(trace_manager=self._application_generate_entity.trace_manager)
if self._stream:
if self._base_task_pipeline._stream:
return self._to_stream_response(generator)
else:
return self._to_blocking_response(generator)
@@ -273,24 +268,26 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
# init fake graph runtime state
graph_runtime_state: Optional[GraphRuntimeState] = None
for queue_message in self._queue_manager.listen():
for queue_message in self._base_task_pipeline._queue_manager.listen():
event = queue_message.event
if isinstance(event, QueuePingEvent):
yield self._ping_stream_response()
yield self._base_task_pipeline._ping_stream_response()
elif isinstance(event, QueueErrorEvent):
with Session(db.engine) as session:
err = self._handle_error(event=event, session=session, message_id=self._message_id)
with Session(db.engine, expire_on_commit=False) as session:
err = self._base_task_pipeline._handle_error(
event=event, session=session, message_id=self._message_id
)
session.commit()
yield self._error_to_stream_response(err)
yield self._base_task_pipeline._error_to_stream_response(err)
break
elif isinstance(event, QueueWorkflowStartedEvent):
# override graph runtime state
graph_runtime_state = event.graph_runtime_state
with Session(db.engine) as session:
with Session(db.engine, expire_on_commit=False) as session:
# init workflow run
workflow_run = self._handle_workflow_run_start(
workflow_run = self._workflow_cycle_manager._handle_workflow_run_start(
session=session,
workflow_id=self._workflow_id,
user_id=self._user_id,
@@ -301,7 +298,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
if not message:
raise ValueError(f"Message not found: {self._message_id}")
message.workflow_run_id = workflow_run.id
workflow_start_resp = self._workflow_start_to_stream_response(
workflow_start_resp = self._workflow_cycle_manager._workflow_start_to_stream_response(
session=session, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
)
session.commit()
@@ -314,12 +311,14 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
if not self._workflow_run_id:
raise ValueError("workflow run not initialized.")
with Session(db.engine) as session:
workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id)
workflow_node_execution = self._handle_workflow_node_execution_retried(
with Session(db.engine, expire_on_commit=False) as session:
workflow_run = self._workflow_cycle_manager._get_workflow_run(
session=session, workflow_run_id=self._workflow_run_id
)
workflow_node_execution = self._workflow_cycle_manager._handle_workflow_node_execution_retried(
session=session, workflow_run=workflow_run, event=event
)
node_retry_resp = self._workflow_node_retry_to_stream_response(
node_retry_resp = self._workflow_cycle_manager._workflow_node_retry_to_stream_response(
session=session,
event=event,
task_id=self._application_generate_entity.task_id,
@@ -333,13 +332,15 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
if not self._workflow_run_id:
raise ValueError("workflow run not initialized.")
with Session(db.engine) as session:
workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id)
workflow_node_execution = self._handle_node_execution_start(
with Session(db.engine, expire_on_commit=False) as session:
workflow_run = self._workflow_cycle_manager._get_workflow_run(
session=session, workflow_run_id=self._workflow_run_id
)
workflow_node_execution = self._workflow_cycle_manager._handle_node_execution_start(
session=session, workflow_run=workflow_run, event=event
)
node_start_resp = self._workflow_node_start_to_stream_response(
node_start_resp = self._workflow_cycle_manager._workflow_node_start_to_stream_response(
session=session,
event=event,
task_id=self._application_generate_entity.task_id,
@@ -352,12 +353,16 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
elif isinstance(event, QueueNodeSucceededEvent):
# Record files if it's an answer node or end node
if event.node_type in [NodeType.ANSWER, NodeType.END]:
self._recorded_files.extend(self._fetch_files_from_node_outputs(event.outputs or {}))
self._recorded_files.extend(
self._workflow_cycle_manager._fetch_files_from_node_outputs(event.outputs or {})
)
with Session(db.engine) as session:
workflow_node_execution = self._handle_workflow_node_execution_success(session=session, event=event)
with Session(db.engine, expire_on_commit=False) as session:
workflow_node_execution = self._workflow_cycle_manager._handle_workflow_node_execution_success(
session=session, event=event
)
node_finish_resp = self._workflow_node_finish_to_stream_response(
node_finish_resp = self._workflow_cycle_manager._workflow_node_finish_to_stream_response(
session=session,
event=event,
task_id=self._application_generate_entity.task_id,
@@ -368,32 +373,35 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
if node_finish_resp:
yield node_finish_resp
elif isinstance(event, QueueNodeFailedEvent | QueueNodeInIterationFailedEvent | QueueNodeExceptionEvent):
with Session(db.engine) as session:
workflow_node_execution = self._handle_workflow_node_execution_failed(session=session, event=event)
with Session(db.engine, expire_on_commit=False) as session:
workflow_node_execution = self._workflow_cycle_manager._handle_workflow_node_execution_failed(
session=session, event=event
)
response_finish = self._workflow_node_finish_to_stream_response(
node_finish_resp = self._workflow_cycle_manager._workflow_node_finish_to_stream_response(
session=session,
event=event,
task_id=self._application_generate_entity.task_id,
workflow_node_execution=workflow_node_execution,
)
if response_finish:
yield response_finish
if node_finish_resp:
yield node_finish_resp
elif isinstance(event, QueueParallelBranchRunStartedEvent):
if not self._workflow_run_id:
raise ValueError("workflow run not initialized.")
with Session(db.engine) as session:
workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id)
parallel_start_resp = self._workflow_parallel_branch_start_to_stream_response(
session=session,
task_id=self._application_generate_entity.task_id,
workflow_run=workflow_run,
event=event,
with Session(db.engine, expire_on_commit=False) as session:
workflow_run = self._workflow_cycle_manager._get_workflow_run(
session=session, workflow_run_id=self._workflow_run_id
)
parallel_start_resp = (
self._workflow_cycle_manager._workflow_parallel_branch_start_to_stream_response(
session=session,
task_id=self._application_generate_entity.task_id,
workflow_run=workflow_run,
event=event,
)
)
yield parallel_start_resp
@@ -401,13 +409,17 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
if not self._workflow_run_id:
raise ValueError("workflow run not initialized.")
with Session(db.engine) as session:
workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id)
parallel_finish_resp = self._workflow_parallel_branch_finished_to_stream_response(
session=session,
task_id=self._application_generate_entity.task_id,
workflow_run=workflow_run,
event=event,
with Session(db.engine, expire_on_commit=False) as session:
workflow_run = self._workflow_cycle_manager._get_workflow_run(
session=session, workflow_run_id=self._workflow_run_id
)
parallel_finish_resp = (
self._workflow_cycle_manager._workflow_parallel_branch_finished_to_stream_response(
session=session,
task_id=self._application_generate_entity.task_id,
workflow_run=workflow_run,
event=event,
)
)
yield parallel_finish_resp
@@ -415,9 +427,11 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
if not self._workflow_run_id:
raise ValueError("workflow run not initialized.")
with Session(db.engine) as session:
workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id)
iter_start_resp = self._workflow_iteration_start_to_stream_response(
with Session(db.engine, expire_on_commit=False) as session:
workflow_run = self._workflow_cycle_manager._get_workflow_run(
session=session, workflow_run_id=self._workflow_run_id
)
iter_start_resp = self._workflow_cycle_manager._workflow_iteration_start_to_stream_response(
session=session,
task_id=self._application_generate_entity.task_id,
workflow_run=workflow_run,
@@ -429,9 +443,11 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
if not self._workflow_run_id:
raise ValueError("workflow run not initialized.")
with Session(db.engine) as session:
workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id)
iter_next_resp = self._workflow_iteration_next_to_stream_response(
with Session(db.engine, expire_on_commit=False) as session:
workflow_run = self._workflow_cycle_manager._get_workflow_run(
session=session, workflow_run_id=self._workflow_run_id
)
iter_next_resp = self._workflow_cycle_manager._workflow_iteration_next_to_stream_response(
session=session,
task_id=self._application_generate_entity.task_id,
workflow_run=workflow_run,
@@ -443,9 +459,11 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
if not self._workflow_run_id:
raise ValueError("workflow run not initialized.")
with Session(db.engine) as session:
workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id)
iter_finish_resp = self._workflow_iteration_completed_to_stream_response(
with Session(db.engine, expire_on_commit=False) as session:
workflow_run = self._workflow_cycle_manager._get_workflow_run(
session=session, workflow_run_id=self._workflow_run_id
)
iter_finish_resp = self._workflow_cycle_manager._workflow_iteration_completed_to_stream_response(
session=session,
task_id=self._application_generate_entity.task_id,
workflow_run=workflow_run,
@@ -460,8 +478,8 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
if not graph_runtime_state:
raise ValueError("workflow run not initialized.")
with Session(db.engine) as session:
workflow_run = self._handle_workflow_run_success(
with Session(db.engine, expire_on_commit=False) as session:
workflow_run = self._workflow_cycle_manager._handle_workflow_run_success(
session=session,
workflow_run_id=self._workflow_run_id,
start_at=graph_runtime_state.start_at,
@@ -472,21 +490,23 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
trace_manager=trace_manager,
)
workflow_finish_resp = self._workflow_finish_to_stream_response(
workflow_finish_resp = self._workflow_cycle_manager._workflow_finish_to_stream_response(
session=session, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
)
session.commit()
yield workflow_finish_resp
self._queue_manager.publish(QueueAdvancedChatMessageEndEvent(), PublishFrom.TASK_PIPELINE)
self._base_task_pipeline._queue_manager.publish(
QueueAdvancedChatMessageEndEvent(), PublishFrom.TASK_PIPELINE
)
elif isinstance(event, QueueWorkflowPartialSuccessEvent):
if not self._workflow_run_id:
raise ValueError("workflow run not initialized.")
if not graph_runtime_state:
raise ValueError("graph runtime state not initialized.")
with Session(db.engine) as session:
workflow_run = self._handle_workflow_run_partial_success(
with Session(db.engine, expire_on_commit=False) as session:
workflow_run = self._workflow_cycle_manager._handle_workflow_run_partial_success(
session=session,
workflow_run_id=self._workflow_run_id,
start_at=graph_runtime_state.start_at,
@@ -497,21 +517,23 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
conversation_id=None,
trace_manager=trace_manager,
)
workflow_finish_resp = self._workflow_finish_to_stream_response(
workflow_finish_resp = self._workflow_cycle_manager._workflow_finish_to_stream_response(
session=session, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
)
session.commit()
yield workflow_finish_resp
self._queue_manager.publish(QueueAdvancedChatMessageEndEvent(), PublishFrom.TASK_PIPELINE)
self._base_task_pipeline._queue_manager.publish(
QueueAdvancedChatMessageEndEvent(), PublishFrom.TASK_PIPELINE
)
elif isinstance(event, QueueWorkflowFailedEvent):
if not self._workflow_run_id:
raise ValueError("workflow run not initialized.")
if not graph_runtime_state:
raise ValueError("graph runtime state not initialized.")
with Session(db.engine) as session:
workflow_run = self._handle_workflow_run_failed(
with Session(db.engine, expire_on_commit=False) as session:
workflow_run = self._workflow_cycle_manager._handle_workflow_run_failed(
session=session,
workflow_run_id=self._workflow_run_id,
start_at=graph_runtime_state.start_at,
@@ -523,20 +545,22 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
trace_manager=trace_manager,
exceptions_count=event.exceptions_count,
)
workflow_finish_resp = self._workflow_finish_to_stream_response(
workflow_finish_resp = self._workflow_cycle_manager._workflow_finish_to_stream_response(
session=session, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
)
err_event = QueueErrorEvent(error=ValueError(f"Run failed: {workflow_run.error}"))
err = self._handle_error(event=err_event, session=session, message_id=self._message_id)
err = self._base_task_pipeline._handle_error(
event=err_event, session=session, message_id=self._message_id
)
session.commit()
yield workflow_finish_resp
yield self._error_to_stream_response(err)
yield self._base_task_pipeline._error_to_stream_response(err)
break
elif isinstance(event, QueueStopEvent):
if self._workflow_run_id and graph_runtime_state:
with Session(db.engine) as session:
workflow_run = self._handle_workflow_run_failed(
with Session(db.engine, expire_on_commit=False) as session:
workflow_run = self._workflow_cycle_manager._handle_workflow_run_failed(
session=session,
workflow_run_id=self._workflow_run_id,
start_at=graph_runtime_state.start_at,
@@ -547,7 +571,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
conversation_id=self._conversation_id,
trace_manager=trace_manager,
)
workflow_finish_resp = self._workflow_finish_to_stream_response(
workflow_finish_resp = self._workflow_cycle_manager._workflow_finish_to_stream_response(
session=session,
task_id=self._application_generate_entity.task_id,
workflow_run=workflow_run,
@@ -561,18 +585,18 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
yield self._message_end_to_stream_response()
break
elif isinstance(event, QueueRetrieverResourcesEvent):
self._handle_retriever_resources(event)
self._message_cycle_manager._handle_retriever_resources(event)
with Session(db.engine) as session:
with Session(db.engine, expire_on_commit=False) as session:
message = self._get_message(session=session)
message.message_metadata = (
json.dumps(jsonable_encoder(self._task_state.metadata)) if self._task_state.metadata else None
)
session.commit()
elif isinstance(event, QueueAnnotationReplyEvent):
self._handle_annotation_reply(event)
self._message_cycle_manager._handle_annotation_reply(event)
with Session(db.engine) as session:
with Session(db.engine, expire_on_commit=False) as session:
message = self._get_message(session=session)
message.message_metadata = (
json.dumps(jsonable_encoder(self._task_state.metadata)) if self._task_state.metadata else None
@@ -593,29 +617,35 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
tts_publisher.publish(queue_message)
self._task_state.answer += delta_text
yield self._message_to_stream_response(
yield self._message_cycle_manager._message_to_stream_response(
answer=delta_text, message_id=self._message_id, from_variable_selector=event.from_variable_selector
)
elif isinstance(event, QueueMessageReplaceEvent):
# published by moderation
yield self._message_replace_to_stream_response(answer=event.text)
yield self._message_cycle_manager._message_replace_to_stream_response(answer=event.text)
elif isinstance(event, QueueAdvancedChatMessageEndEvent):
if not graph_runtime_state:
raise ValueError("graph runtime state not initialized.")
output_moderation_answer = self._handle_output_moderation_when_task_finished(self._task_state.answer)
output_moderation_answer = self._base_task_pipeline._handle_output_moderation_when_task_finished(
self._task_state.answer
)
if output_moderation_answer:
self._task_state.answer = output_moderation_answer
yield self._message_replace_to_stream_response(answer=output_moderation_answer)
yield self._message_cycle_manager._message_replace_to_stream_response(
answer=output_moderation_answer
)
# Save message
with Session(db.engine) as session:
with Session(db.engine, expire_on_commit=False) as session:
self._save_message(session=session, graph_runtime_state=graph_runtime_state)
session.commit()
yield self._message_end_to_stream_response()
elif isinstance(event, QueueAgentLogEvent):
yield self._handle_agent_log(task_id=self._application_generate_entity.task_id, event=event)
yield self._workflow_cycle_manager._handle_agent_log(
task_id=self._application_generate_entity.task_id, event=event
)
else:
continue
@@ -629,7 +659,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
def _save_message(self, *, session: Session, graph_runtime_state: Optional[GraphRuntimeState] = None) -> None:
message = self._get_message(session=session)
message.answer = self._task_state.answer
message.provider_response_latency = time.perf_counter() - self._start_at
message.provider_response_latency = time.perf_counter() - self._base_task_pipeline._start_at
message.message_metadata = (
json.dumps(jsonable_encoder(self._task_state.metadata)) if self._task_state.metadata else None
)
@@ -693,20 +723,20 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
:param text: text
:return: True if output moderation should direct output, otherwise False
"""
if self._output_moderation_handler:
if self._output_moderation_handler.should_direct_output():
if self._base_task_pipeline._output_moderation_handler:
if self._base_task_pipeline._output_moderation_handler.should_direct_output():
# stop subscribe new token when output moderation should direct output
self._task_state.answer = self._output_moderation_handler.get_final_output()
self._queue_manager.publish(
self._task_state.answer = self._base_task_pipeline._output_moderation_handler.get_final_output()
self._base_task_pipeline._queue_manager.publish(
QueueTextChunkEvent(text=self._task_state.answer), PublishFrom.TASK_PIPELINE
)
self._queue_manager.publish(
self._base_task_pipeline._queue_manager.publish(
QueueStopEvent(stopped_by=QueueStopEvent.StopBy.OUTPUT_MODERATION), PublishFrom.TASK_PIPELINE
)
return True
else:
self._output_moderation_handler.append_new_token(text)
self._base_task_pipeline._output_moderation_handler.append_new_token(text)
return False
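The largest change in this file swaps inheritance for composition: the pipeline used to mix in BasedGenerateTaskPipeline, WorkflowCycleManage, and MessageCycleManage, and now holds each as an explicit collaborator, so every call names the object that owns the behavior. A minimal sketch of the shape this refactor produces (constructor arguments elided):

class GenerateTaskPipeline:
    def __init__(self, base_pipeline, workflow_cycle_manager, message_cycle_manager):
        self._base_task_pipeline = base_pipeline                # queue, errors, moderation
        self._workflow_cycle_manager = workflow_cycle_manager   # workflow run lifecycle
        self._message_cycle_manager = message_cycle_manager     # message/conversation events

    def process(self):
        for queue_message in self._base_task_pipeline._queue_manager.listen():
            ...  # dispatch each event to whichever collaborator owns it

The same delegation is applied to WorkflowAppGenerateTaskPipeline further down.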

View File

@@ -19,7 +19,7 @@ from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskSt
from core.app.apps.message_based_app_generator import MessageBasedAppGenerator
from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager
from core.app.entities.app_invoke_entities import AgentChatAppGenerateEntity, InvokeFrom
from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
from core.model_runtime.errors.invoke import InvokeAuthorizationError
from core.ops.ops_trace_manager import TraceQueueManager
from extensions.ext_database import db
from factories import file_factory
@@ -251,7 +251,7 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
except ValidationError as e:
logger.exception("Validation Error when generating")
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
except (ValueError, InvokeError) as e:
except ValueError as e:
if dify_config.DEBUG:
logger.exception("Error when generating")
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)

View File

@@ -1,6 +1,6 @@
import logging
from abc import ABC, abstractmethod
from collections.abc import Generator
from collections.abc import Generator, Mapping
from typing import Any, Union
from core.app.entities.app_invoke_entities import InvokeFrom
@@ -15,7 +15,7 @@ class AppGenerateResponseConverter(ABC):
@classmethod
def convert(
cls, response: Union[AppBlockingResponse, Generator[AppStreamResponse, Any, None]], invoke_from: InvokeFrom
) -> dict[str, Any] | Generator[str | dict[str, Any], Any, None]:
) -> Mapping[str, Any] | Generator[str | Mapping[str, Any], Any, None]:
if invoke_from in {InvokeFrom.DEBUGGER, InvokeFrom.SERVICE_API}:
if isinstance(response, AppBlockingResponse):
return cls.convert_blocking_full_response(response)

View File

@@ -18,7 +18,7 @@ from core.app.apps.chat.generate_response_converter import ChatAppGenerateRespon
from core.app.apps.message_based_app_generator import MessageBasedAppGenerator
from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager
from core.app.entities.app_invoke_entities import ChatAppGenerateEntity, InvokeFrom
from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
from core.model_runtime.errors.invoke import InvokeAuthorizationError
from core.ops.ops_trace_manager import TraceQueueManager
from extensions.ext_database import db
from factories import file_factory
@@ -237,7 +237,7 @@ class ChatAppGenerator(MessageBasedAppGenerator):
except ValidationError as e:
logger.exception("Validation Error when generating")
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
except (ValueError, InvokeError) as e:
except ValueError as e:
if dify_config.DEBUG:
logger.exception("Error when generating")
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)

View File

@@ -17,7 +17,7 @@ from core.app.apps.completion.generate_response_converter import CompletionAppGe
from core.app.apps.message_based_app_generator import MessageBasedAppGenerator
from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager
from core.app.entities.app_invoke_entities import CompletionAppGenerateEntity, InvokeFrom
from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
from core.model_runtime.errors.invoke import InvokeAuthorizationError
from core.ops.ops_trace_manager import TraceQueueManager
from extensions.ext_database import db
from factories import file_factory
@@ -37,7 +37,7 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
args: Mapping[str, Any],
invoke_from: InvokeFrom,
streaming: Literal[True],
) -> Generator[str, None, None]: ...
) -> Generator[str | Mapping[str, Any], None, None]: ...
@overload
def generate(
@@ -57,7 +57,7 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
args: Mapping[str, Any],
invoke_from: InvokeFrom,
streaming: bool = False,
) -> Union[Mapping[str, Any], Generator[str, None, None]]: ...
) -> Union[Mapping[str, Any], Generator[str | Mapping[str, Any], None, None]]: ...
def generate(
self,
@@ -66,7 +66,7 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
args: Mapping[str, Any],
invoke_from: InvokeFrom,
streaming: bool = True,
) -> Union[Mapping[str, Any], Generator[str, None, None]]:
) -> Union[Mapping[str, Any], Generator[str | Mapping[str, Any], None, None]]:
"""
Generate App response.
@@ -214,7 +214,7 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
except ValidationError as e:
logger.exception("Validation Error when generating")
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
except (ValueError, InvokeError) as e:
except ValueError as e:
if dify_config.DEBUG:
logger.exception("Error when generating")
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
@@ -231,7 +231,7 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
user: Union[Account, EndUser],
invoke_from: InvokeFrom,
stream: bool = True,
) -> Union[Mapping, Generator[str, None, None]]:
) -> Union[Mapping, Generator[Mapping | str, None, None]]:
"""
Generate App response.

View File

@@ -20,7 +20,7 @@ from core.app.apps.workflow.generate_response_converter import WorkflowAppGenera
from core.app.apps.workflow.generate_task_pipeline import WorkflowAppGenerateTaskPipeline
from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity
from core.app.entities.task_entities import WorkflowAppBlockingResponse, WorkflowAppStreamResponse
from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
from core.model_runtime.errors.invoke import InvokeAuthorizationError
from core.ops.ops_trace_manager import TraceQueueManager
from extensions.ext_database import db
from factories import file_factory
@@ -235,6 +235,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
single_iteration_run=WorkflowAppGenerateEntity.SingleIterationRunEntity(
node_id=node_id, inputs=args["inputs"]
),
workflow_run_id=str(uuid.uuid4()),
)
contexts.tenant_id.set(application_generate_entity.app_config.tenant_id)
contexts.plugin_tool_providers.set({})
@@ -286,7 +287,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
except ValidationError as e:
logger.exception("Validation Error when generating")
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
except (ValueError, InvokeError) as e:
except ValueError as e:
if dify_config.DEBUG:
logger.exception("Error when generating")
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)

View File

@@ -1,7 +1,7 @@
import logging
import time
from collections.abc import Generator
from typing import Any, Optional, Union
from typing import Optional, Union
from sqlalchemy.orm import Session
@@ -59,7 +59,6 @@ from models.workflow import (
Workflow,
WorkflowAppLog,
WorkflowAppLogCreatedFrom,
WorkflowNodeExecution,
WorkflowRun,
WorkflowRunStatus,
)
@@ -67,16 +66,11 @@ from models.workflow import (
logger = logging.getLogger(__name__)
class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleManage):
class WorkflowAppGenerateTaskPipeline:
"""
WorkflowAppGenerateTaskPipeline generates stream output and manages state for the application.
"""
_task_state: WorkflowTaskState
_application_generate_entity: WorkflowAppGenerateEntity
_workflow_system_variables: dict[SystemVariableKey, Any]
_wip_workflow_node_executions: dict[str, WorkflowNodeExecution]
def __init__(
self,
application_generate_entity: WorkflowAppGenerateEntity,
@@ -85,7 +79,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
user: Union[Account, EndUser],
stream: bool,
) -> None:
super().__init__(
self._base_task_pipeline = BasedGenerateTaskPipeline(
application_generate_entity=application_generate_entity,
queue_manager=queue_manager,
stream=stream,
@@ -102,17 +96,20 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
else:
raise ValueError(f"Invalid user type: {type(user)}")
self._workflow_cycle_manager = WorkflowCycleManage(
application_generate_entity=application_generate_entity,
workflow_system_variables={
SystemVariableKey.FILES: application_generate_entity.files,
SystemVariableKey.USER_ID: user_session_id,
SystemVariableKey.APP_ID: application_generate_entity.app_config.app_id,
SystemVariableKey.WORKFLOW_ID: workflow.id,
SystemVariableKey.WORKFLOW_RUN_ID: application_generate_entity.workflow_run_id,
},
)
self._application_generate_entity = application_generate_entity
self._workflow_id = workflow.id
self._workflow_features_dict = workflow.features_dict
self._workflow_system_variables = {
SystemVariableKey.FILES: application_generate_entity.files,
SystemVariableKey.USER_ID: user_session_id,
SystemVariableKey.APP_ID: application_generate_entity.app_config.app_id,
SystemVariableKey.WORKFLOW_ID: workflow.id,
SystemVariableKey.WORKFLOW_RUN_ID: application_generate_entity.workflow_run_id,
}
self._task_state = WorkflowTaskState()
self._workflow_run_id = ""
@@ -122,7 +119,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
:return:
"""
generator = self._wrapper_process_stream_response(trace_manager=self._application_generate_entity.trace_manager)
if self._stream:
if self._base_task_pipeline._stream:
return self._to_stream_response(generator)
else:
return self._to_blocking_response(generator)
@@ -239,29 +236,29 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
"""
graph_runtime_state = None
for queue_message in self._queue_manager.listen():
for queue_message in self._base_task_pipeline._queue_manager.listen():
event = queue_message.event
if isinstance(event, QueuePingEvent):
yield self._ping_stream_response()
yield self._base_task_pipeline._ping_stream_response()
elif isinstance(event, QueueErrorEvent):
err = self._handle_error(event=event)
yield self._error_to_stream_response(err)
err = self._base_task_pipeline._handle_error(event=event)
yield self._base_task_pipeline._error_to_stream_response(err)
break
elif isinstance(event, QueueWorkflowStartedEvent):
# override graph runtime state
graph_runtime_state = event.graph_runtime_state
with Session(db.engine) as session:
with Session(db.engine, expire_on_commit=False) as session:
# init workflow run
workflow_run = self._handle_workflow_run_start(
workflow_run = self._workflow_cycle_manager._handle_workflow_run_start(
session=session,
workflow_id=self._workflow_id,
user_id=self._user_id,
created_by_role=self._created_by_role,
)
self._workflow_run_id = workflow_run.id
start_resp = self._workflow_start_to_stream_response(
start_resp = self._workflow_cycle_manager._workflow_start_to_stream_response(
session=session, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
)
session.commit()
@@ -273,12 +270,14 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
):
if not self._workflow_run_id:
raise ValueError("workflow run not initialized.")
with Session(db.engine) as session:
workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id)
workflow_node_execution = self._handle_workflow_node_execution_retried(
with Session(db.engine, expire_on_commit=False) as session:
workflow_run = self._workflow_cycle_manager._get_workflow_run(
session=session, workflow_run_id=self._workflow_run_id
)
workflow_node_execution = self._workflow_cycle_manager._handle_workflow_node_execution_retried(
session=session, workflow_run=workflow_run, event=event
)
response = self._workflow_node_retry_to_stream_response(
response = self._workflow_cycle_manager._workflow_node_retry_to_stream_response(
session=session,
event=event,
task_id=self._application_generate_entity.task_id,
@@ -292,12 +291,14 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
if not self._workflow_run_id:
raise ValueError("workflow run not initialized.")
with Session(db.engine) as session:
workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id)
workflow_node_execution = self._handle_node_execution_start(
with Session(db.engine, expire_on_commit=False) as session:
workflow_run = self._workflow_cycle_manager._get_workflow_run(
session=session, workflow_run_id=self._workflow_run_id
)
workflow_node_execution = self._workflow_cycle_manager._handle_node_execution_start(
session=session, workflow_run=workflow_run, event=event
)
node_start_response = self._workflow_node_start_to_stream_response(
node_start_response = self._workflow_cycle_manager._workflow_node_start_to_stream_response(
session=session,
event=event,
task_id=self._application_generate_entity.task_id,
@@ -308,9 +309,11 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
if node_start_response:
yield node_start_response
elif isinstance(event, QueueNodeSucceededEvent):
with Session(db.engine) as session:
workflow_node_execution = self._handle_workflow_node_execution_success(session=session, event=event)
node_success_response = self._workflow_node_finish_to_stream_response(
with Session(db.engine, expire_on_commit=False) as session:
workflow_node_execution = self._workflow_cycle_manager._handle_workflow_node_execution_success(
session=session, event=event
)
node_success_response = self._workflow_cycle_manager._workflow_node_finish_to_stream_response(
session=session,
event=event,
task_id=self._application_generate_entity.task_id,
@@ -321,12 +324,12 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
if node_success_response:
yield node_success_response
elif isinstance(event, QueueNodeFailedEvent | QueueNodeInIterationFailedEvent | QueueNodeExceptionEvent):
with Session(db.engine) as session:
workflow_node_execution = self._handle_workflow_node_execution_failed(
with Session(db.engine, expire_on_commit=False) as session:
workflow_node_execution = self._workflow_cycle_manager._handle_workflow_node_execution_failed(
session=session,
event=event,
)
node_failed_response = self._workflow_node_finish_to_stream_response(
node_failed_response = self._workflow_cycle_manager._workflow_node_finish_to_stream_response(
session=session,
event=event,
task_id=self._application_generate_entity.task_id,
@@ -341,13 +344,17 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
if not self._workflow_run_id:
raise ValueError("workflow run not initialized.")
with Session(db.engine) as session:
workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id)
parallel_start_resp = self._workflow_parallel_branch_start_to_stream_response(
session=session,
task_id=self._application_generate_entity.task_id,
workflow_run=workflow_run,
event=event,
with Session(db.engine, expire_on_commit=False) as session:
workflow_run = self._workflow_cycle_manager._get_workflow_run(
session=session, workflow_run_id=self._workflow_run_id
)
parallel_start_resp = (
self._workflow_cycle_manager._workflow_parallel_branch_start_to_stream_response(
session=session,
task_id=self._application_generate_entity.task_id,
workflow_run=workflow_run,
event=event,
)
)
yield parallel_start_resp
@@ -356,13 +363,17 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
if not self._workflow_run_id:
raise ValueError("workflow run not initialized.")
with Session(db.engine) as session:
workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id)
parallel_finish_resp = self._workflow_parallel_branch_finished_to_stream_response(
session=session,
task_id=self._application_generate_entity.task_id,
workflow_run=workflow_run,
event=event,
with Session(db.engine, expire_on_commit=False) as session:
workflow_run = self._workflow_cycle_manager._get_workflow_run(
session=session, workflow_run_id=self._workflow_run_id
)
parallel_finish_resp = (
self._workflow_cycle_manager._workflow_parallel_branch_finished_to_stream_response(
session=session,
task_id=self._application_generate_entity.task_id,
workflow_run=workflow_run,
event=event,
)
)
yield parallel_finish_resp
@@ -371,9 +382,11 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
if not self._workflow_run_id:
raise ValueError("workflow run not initialized.")
with Session(db.engine) as session:
workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id)
iter_start_resp = self._workflow_iteration_start_to_stream_response(
with Session(db.engine, expire_on_commit=False) as session:
workflow_run = self._workflow_cycle_manager._get_workflow_run(
session=session, workflow_run_id=self._workflow_run_id
)
iter_start_resp = self._workflow_cycle_manager._workflow_iteration_start_to_stream_response(
session=session,
task_id=self._application_generate_entity.task_id,
workflow_run=workflow_run,
@ -386,9 +399,11 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
if not self._workflow_run_id:
raise ValueError("workflow run not initialized.")
with Session(db.engine) as session:
workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id)
iter_next_resp = self._workflow_iteration_next_to_stream_response(
with Session(db.engine, expire_on_commit=False) as session:
workflow_run = self._workflow_cycle_manager._get_workflow_run(
session=session, workflow_run_id=self._workflow_run_id
)
iter_next_resp = self._workflow_cycle_manager._workflow_iteration_next_to_stream_response(
session=session,
task_id=self._application_generate_entity.task_id,
workflow_run=workflow_run,
@ -401,9 +416,11 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
if not self._workflow_run_id:
raise ValueError("workflow run not initialized.")
with Session(db.engine) as session:
workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id)
iter_finish_resp = self._workflow_iteration_completed_to_stream_response(
with Session(db.engine, expire_on_commit=False) as session:
workflow_run = self._workflow_cycle_manager._get_workflow_run(
session=session, workflow_run_id=self._workflow_run_id
)
iter_finish_resp = self._workflow_cycle_manager._workflow_iteration_completed_to_stream_response(
session=session,
task_id=self._application_generate_entity.task_id,
workflow_run=workflow_run,
@ -418,8 +435,8 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
if not graph_runtime_state:
raise ValueError("graph runtime state not initialized.")
with Session(db.engine) as session:
workflow_run = self._handle_workflow_run_success(
with Session(db.engine, expire_on_commit=False) as session:
workflow_run = self._workflow_cycle_manager._handle_workflow_run_success(
session=session,
workflow_run_id=self._workflow_run_id,
start_at=graph_runtime_state.start_at,
@ -433,7 +450,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
# save workflow app log
self._save_workflow_app_log(session=session, workflow_run=workflow_run)
workflow_finish_resp = self._workflow_finish_to_stream_response(
workflow_finish_resp = self._workflow_cycle_manager._workflow_finish_to_stream_response(
session=session,
task_id=self._application_generate_entity.task_id,
workflow_run=workflow_run,
@ -447,8 +464,8 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
if not graph_runtime_state:
raise ValueError("graph runtime state not initialized.")
with Session(db.engine) as session:
workflow_run = self._handle_workflow_run_partial_success(
with Session(db.engine, expire_on_commit=False) as session:
workflow_run = self._workflow_cycle_manager._handle_workflow_run_partial_success(
session=session,
workflow_run_id=self._workflow_run_id,
start_at=graph_runtime_state.start_at,
@ -463,7 +480,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
# save workflow app log
self._save_workflow_app_log(session=session, workflow_run=workflow_run)
workflow_finish_resp = self._workflow_finish_to_stream_response(
workflow_finish_resp = self._workflow_cycle_manager._workflow_finish_to_stream_response(
session=session, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
)
session.commit()
@ -475,8 +492,8 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
if not graph_runtime_state:
raise ValueError("graph runtime state not initialized.")
with Session(db.engine) as session:
workflow_run = self._handle_workflow_run_failed(
with Session(db.engine, expire_on_commit=False) as session:
workflow_run = self._workflow_cycle_manager._handle_workflow_run_failed(
session=session,
workflow_run_id=self._workflow_run_id,
start_at=graph_runtime_state.start_at,
@ -494,7 +511,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
# save workflow app log
self._save_workflow_app_log(session=session, workflow_run=workflow_run)
workflow_finish_resp = self._workflow_finish_to_stream_response(
workflow_finish_resp = self._workflow_cycle_manager._workflow_finish_to_stream_response(
session=session, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
)
session.commit()
@ -514,7 +531,9 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
delta_text, from_variable_selector=event.from_variable_selector
)
elif isinstance(event, QueueAgentLogEvent):
yield self._handle_agent_log(task_id=self._application_generate_entity.task_id, event=event)
yield self._workflow_cycle_manager._handle_agent_log(
task_id=self._application_generate_entity.task_id, event=event
)
else:
continue
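
Throughout this pipeline the sessions are now opened with expire_on_commit=False, which is what lets the returned ORM objects (workflow runs, node executions) be read after the session closes. A minimal standalone sketch of the difference, using a toy model rather than the real WorkflowRun:

from sqlalchemy import String, create_engine
from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column

class Base(DeclarativeBase):
    pass

class Run(Base):
    __tablename__ = "runs"
    id: Mapped[int] = mapped_column(primary_key=True)
    status: Mapped[str] = mapped_column(String(32))

engine = create_engine("sqlite://")
Base.metadata.create_all(engine)
with Session(engine) as session:
    session.add(Run(id=1, status="running"))
    session.commit()

with Session(engine, expire_on_commit=False) as session:
    run = session.get(Run, 1)
    run.status = "succeeded"
    session.commit()  # the default expire_on_commit=True would expire `run` here

# Detached but not expired: attribute access still works after the session closes.
print(run.status)  # succeeded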

View File

@ -195,7 +195,7 @@ class WorkflowAppGenerateEntity(AppGenerateEntity):
# app config
app_config: WorkflowUIBasedAppConfig
workflow_run_id: Optional[str] = None
workflow_run_id: str
class SingleIterationRunEntity(BaseModel):
"""

View File

@ -329,6 +329,7 @@ class QueueAgentLogEvent(AppQueueEvent):
error: str | None
status: str
data: Mapping[str, Any]
metadata: Optional[Mapping[str, Any]] = None
class QueueNodeRetryEvent(QueueNodeStartedEvent):

View File

@ -716,6 +716,7 @@ class AgentLogStreamResponse(StreamResponse):
error: str | None
status: str
data: Mapping[str, Any]
metadata: Optional[Mapping[str, Any]] = None
event: StreamEvent = StreamEvent.AGENT_LOG
data: Data

View File

@ -15,7 +15,6 @@ from core.app.entities.queue_entities import (
from core.app.entities.task_entities import (
ErrorStreamResponse,
PingStreamResponse,
TaskState,
)
from core.errors.error import QuotaExceededError
from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
@ -30,22 +29,12 @@ class BasedGenerateTaskPipeline:
BasedGenerateTaskPipeline is a class that generate stream output and state management for Application.
"""
_task_state: TaskState
_application_generate_entity: AppGenerateEntity
def __init__(
self,
application_generate_entity: AppGenerateEntity,
queue_manager: AppQueueManager,
stream: bool,
) -> None:
"""
Initialize GenerateTaskPipeline.
:param application_generate_entity: application generate entity
:param queue_manager: queue manager
:param user: user
:param stream: stream
"""
self._application_generate_entity = application_generate_entity
self._queue_manager = queue_manager
self._start_at = time.perf_counter()

View File

@ -31,10 +31,19 @@ from services.annotation_service import AppAnnotationService
class MessageCycleManage:
_application_generate_entity: Union[
ChatAppGenerateEntity, CompletionAppGenerateEntity, AgentChatAppGenerateEntity, AdvancedChatAppGenerateEntity
]
_task_state: Union[EasyUITaskState, WorkflowTaskState]
def __init__(
self,
*,
application_generate_entity: Union[
ChatAppGenerateEntity,
CompletionAppGenerateEntity,
AgentChatAppGenerateEntity,
AdvancedChatAppGenerateEntity,
],
task_state: Union[EasyUITaskState, WorkflowTaskState],
) -> None:
self._application_generate_entity = application_generate_entity
self._task_state = task_state
def _generate_conversation_name(self, *, conversation_id: str, query: str) -> Optional[Thread]:
"""

View File

@ -36,7 +36,6 @@ from core.app.entities.task_entities import (
ParallelBranchStartStreamResponse,
WorkflowFinishStreamResponse,
WorkflowStartStreamResponse,
WorkflowTaskState,
)
from core.file import FILE_MODEL_IDENTITY, File
from core.model_runtime.utils.encoders import jsonable_encoder
@ -60,13 +59,20 @@ from models.workflow import (
WorkflowRunStatus,
)
from .exc import WorkflowNodeExecutionNotFoundError, WorkflowRunNotFoundError
from .exc import WorkflowRunNotFoundError
class WorkflowCycleManage:
_application_generate_entity: Union[AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity]
_task_state: WorkflowTaskState
_workflow_system_variables: dict[SystemVariableKey, Any]
def __init__(
self,
*,
application_generate_entity: Union[AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity],
workflow_system_variables: dict[SystemVariableKey, Any],
) -> None:
self._workflow_run: WorkflowRun | None = None
self._workflow_node_executions: dict[str, WorkflowNodeExecution] = {}
self._application_generate_entity = application_generate_entity
self._workflow_system_variables = workflow_system_variables
def _handle_workflow_run_start(
self,
@ -104,7 +110,8 @@ class WorkflowCycleManage:
inputs = dict(WorkflowEntry.handle_special_values(inputs) or {})
# init workflow run
workflow_run_id = str(self._workflow_system_variables.get(SystemVariableKey.WORKFLOW_RUN_ID, uuid4()))
        # TODO: workflow_run_id should never be None here; consider a more elegant way to enforce this
workflow_run_id = str(self._workflow_system_variables.get(SystemVariableKey.WORKFLOW_RUN_ID) or uuid4())
workflow_run = WorkflowRun()
workflow_run.id = workflow_run_id
@ -241,7 +248,7 @@ class WorkflowCycleManage:
workflow_run.finished_at = datetime.now(UTC).replace(tzinfo=None)
workflow_run.exceptions_count = exceptions_count
stmt = select(WorkflowNodeExecution).where(
stmt = select(WorkflowNodeExecution.node_execution_id).where(
WorkflowNodeExecution.tenant_id == workflow_run.tenant_id,
WorkflowNodeExecution.app_id == workflow_run.app_id,
WorkflowNodeExecution.workflow_id == workflow_run.workflow_id,
@ -249,15 +256,18 @@ class WorkflowCycleManage:
WorkflowNodeExecution.workflow_run_id == workflow_run.id,
WorkflowNodeExecution.status == WorkflowNodeExecutionStatus.RUNNING.value,
)
running_workflow_node_executions = session.scalars(stmt).all()
ids = session.scalars(stmt).all()
# Use self._get_workflow_node_execution here to make sure the cache is updated
running_workflow_node_executions = [
self._get_workflow_node_execution(session=session, node_execution_id=id) for id in ids if id
]
for workflow_node_execution in running_workflow_node_executions:
now = datetime.now(UTC).replace(tzinfo=None)
workflow_node_execution.status = WorkflowNodeExecutionStatus.FAILED.value
workflow_node_execution.error = error
finish_at = datetime.now(UTC).replace(tzinfo=None)
workflow_node_execution.finished_at = finish_at
workflow_node_execution.elapsed_time = (finish_at - workflow_node_execution.created_at).total_seconds()
workflow_node_execution.finished_at = now
workflow_node_execution.elapsed_time = (now - workflow_node_execution.created_at).total_seconds()
if trace_manager:
trace_manager.add_trace_task(
@ -299,6 +309,8 @@ class WorkflowCycleManage:
workflow_node_execution.created_at = datetime.now(UTC).replace(tzinfo=None)
session.add(workflow_node_execution)
self._workflow_node_executions[event.node_execution_id] = workflow_node_execution
return workflow_node_execution
def _handle_workflow_node_execution_success(
@ -325,6 +337,7 @@ class WorkflowCycleManage:
workflow_node_execution.finished_at = finished_at
workflow_node_execution.elapsed_time = elapsed_time
workflow_node_execution = session.merge(workflow_node_execution)
return workflow_node_execution
def _handle_workflow_node_execution_failed(
@ -364,6 +377,7 @@ class WorkflowCycleManage:
workflow_node_execution.elapsed_time = elapsed_time
workflow_node_execution.execution_metadata = execution_metadata
workflow_node_execution = session.merge(workflow_node_execution)
return workflow_node_execution
def _handle_workflow_node_execution_retried(
@ -415,6 +429,8 @@ class WorkflowCycleManage:
workflow_node_execution.index = event.node_run_index
session.add(workflow_node_execution)
self._workflow_node_executions[event.node_execution_id] = workflow_node_execution
return workflow_node_execution
#################################################
@ -811,25 +827,23 @@ class WorkflowCycleManage:
return None
def _get_workflow_run(self, *, session: Session, workflow_run_id: str) -> WorkflowRun:
"""
Refetch workflow run
:param workflow_run_id: workflow run id
:return:
"""
if self._workflow_run and self._workflow_run.id == workflow_run_id:
cached_workflow_run = self._workflow_run
cached_workflow_run = session.merge(cached_workflow_run)
return cached_workflow_run
stmt = select(WorkflowRun).where(WorkflowRun.id == workflow_run_id)
workflow_run = session.scalar(stmt)
if not workflow_run:
raise WorkflowRunNotFoundError(workflow_run_id)
self._workflow_run = workflow_run
return workflow_run
def _get_workflow_node_execution(self, session: Session, node_execution_id: str) -> WorkflowNodeExecution:
stmt = select(WorkflowNodeExecution).where(WorkflowNodeExecution.node_execution_id == node_execution_id)
workflow_node_execution = session.scalar(stmt)
if not workflow_node_execution:
raise WorkflowNodeExecutionNotFoundError(node_execution_id)
return workflow_node_execution
if node_execution_id not in self._workflow_node_executions:
raise ValueError(f"Workflow node execution not found: {node_execution_id}")
cached_workflow_node_execution = self._workflow_node_executions[node_execution_id]
return cached_workflow_node_execution
def _handle_agent_log(self, task_id: str, event: QueueAgentLogEvent) -> AgentLogStreamResponse:
"""
@ -848,5 +862,6 @@ class WorkflowCycleManage:
error=event.error,
status=event.status,
data=event.data,
metadata=event.metadata,
),
)
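
The _get_workflow_run rewrite caches the last-seen run on the instance and re-attaches it to each new session with session.merge(); node executions are likewise served from an in-memory dict instead of being re-queried. A reduced sketch of the merge-based caching, with the model passed in since the real WorkflowRun is not imported here:

from sqlalchemy.orm import Session

class RunCache:
    def __init__(self, model):
        self._model = model
        self._cached = None  # instance loaded by an earlier, now-closed session

    def get(self, session: Session, run_id):
        if self._cached is not None and self._cached.id == run_id:
            # merge() copies the detached state onto an instance owned by this
            # session, so lazy loads and updates work again.
            return session.merge(self._cached)
        run = session.get(self._model, run_id)
        if run is None:
            raise ValueError(f"workflow run not found: {run_id}")
        self._cached = run
        return run

Combined with expire_on_commit=False, this keeps one authoritative instance per run across the many short-lived sessions the pipeline opens.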

View File

@ -178,7 +178,7 @@ class ModelInstance:
def get_llm_num_tokens(
self, prompt_messages: list[PromptMessage], tools: Optional[list[PromptMessageTool]] = None
) -> list[int]:
) -> int:
"""
Get number of tokens for llm
@ -191,7 +191,7 @@ class ModelInstance:
self.model_type_instance = cast(LargeLanguageModel, self.model_type_instance)
return cast(
list[int],
int,
self._round_robin_invoke(
function=self.model_type_instance.get_num_tokens,
model=self.model,

View File

@ -1,9 +1,47 @@
import tiktoken
from threading import Lock
from typing import Any
_tokenizer: Any = None
_lock = Lock()
class GPT2Tokenizer:
@staticmethod
def _get_num_tokens_by_gpt2(text: str) -> int:
"""
use gpt2 tokenizer to get num tokens
"""
_tokenizer = GPT2Tokenizer.get_encoder()
tokens = _tokenizer.encode(text)
return len(tokens)
@staticmethod
def get_num_tokens(text: str) -> int:
encoding = tiktoken.encoding_for_model("gpt2")
tiktoken_vec = encoding.encode(text)
return len(tiktoken_vec)
        # Because this operation is CPU-intensive, we fall back to the direct call until we find a better way to handle it.
#
# future = _executor.submit(GPT2Tokenizer._get_num_tokens_by_gpt2, text)
# result = future.result()
# return cast(int, result)
return GPT2Tokenizer._get_num_tokens_by_gpt2(text)
@staticmethod
def get_encoder() -> Any:
global _tokenizer, _lock
with _lock:
if _tokenizer is None:
# Try to use tiktoken to get the tokenizer because it is faster
#
try:
import tiktoken
_tokenizer = tiktoken.get_encoding("gpt2")
except Exception:
from os.path import abspath, dirname, join
from transformers import GPT2Tokenizer as TransformerGPT2Tokenizer # type: ignore
base_path = abspath(__file__)
gpt2_tokenizer_path = join(dirname(base_path), "gpt2")
_tokenizer = TransformerGPT2Tokenizer.from_pretrained(gpt2_tokenizer_path)
return _tokenizer
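
The encoder is a lazily built module-level singleton guarded by a lock, preferring tiktoken and falling back to the bundled transformers tokenizer. Counting tokens is then a single call (import path assumed):

# from <the module above> import GPT2Tokenizer  # import path assumed
n = GPT2Tokenizer.get_num_tokens("Hello, world!")
print(n)  # token count under the GPT-2 vocabulary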

View File

@ -119,7 +119,7 @@ class PluginAppBackwardsInvocation(BaseBackwardsInvocation):
stream: bool,
inputs: Mapping,
files: list[dict],
):
) -> Generator[Mapping | str, None, None] | Mapping:
"""
invoke workflow app
"""
@ -146,7 +146,7 @@ class PluginAppBackwardsInvocation(BaseBackwardsInvocation):
stream: bool,
inputs: Mapping,
files: list[dict],
):
) -> Generator[Mapping | str, None, None] | Mapping:
"""
invoke completion app
"""

View File

@ -268,7 +268,7 @@ Here is the extra instruction you need to follow:
return summary.message.content
lines = content.split("\n")
new_lines = []
new_lines: list[str] = []
# split long line into multiple lines
for i in range(len(lines)):
line = lines[i]
@ -286,16 +286,16 @@ Here is the extra instruction you need to follow:
# merge lines into messages with max tokens
messages: list[str] = []
for i in new_lines:
for i in new_lines: # type: ignore
if len(messages) == 0:
messages.append(i)
messages.append(i) # type: ignore
else:
if len(messages[-1]) + len(i) < max_tokens * 0.5:
messages[-1] += i
if get_prompt_tokens(messages[-1] + i) > max_tokens * 0.7:
messages.append(i)
if len(messages[-1]) + len(i) < max_tokens * 0.5: # type: ignore
messages[-1] += i # type: ignore
if get_prompt_tokens(messages[-1] + i) > max_tokens * 0.7: # type: ignore
messages.append(i) # type: ignore
else:
messages[-1] += i
messages[-1] += i # type: ignore
summaries = []
for i in range(len(messages)):

View File

@ -103,7 +103,7 @@ class BasePluginManager:
Make a stream request to the plugin daemon inner API and yield the response as a model.
"""
for line in self._stream_request(method, path, params, headers, data, files):
yield type(**json.loads(line))
yield type(**json.loads(line)) # type: ignore
def _request_with_model(
self,

View File

@ -113,6 +113,8 @@ class BaiduVector(BaseVector):
return False
def delete_by_ids(self, ids: list[str]) -> None:
if not ids:
return
quoted_ids = [f"'{id}'" for id in ids]
self._db.table(self._collection_name).delete(filter=f"id IN({', '.join(quoted_ids)})")

View File

@ -83,6 +83,8 @@ class ChromaVector(BaseVector):
self._client.delete_collection(self._collection_name)
def delete_by_ids(self, ids: list[str]) -> None:
if not ids:
return
collection = self._client.get_or_create_collection(self._collection_name)
collection.delete(ids=ids)

View File

@ -0,0 +1,104 @@
import json
import logging
from typing import Any, Optional
from flask import current_app
from core.rag.datasource.vdb.elasticsearch.elasticsearch_vector import (
ElasticSearchConfig,
ElasticSearchVector,
ElasticSearchVectorFactory,
)
from core.rag.datasource.vdb.field import Field
from core.rag.datasource.vdb.vector_type import VectorType
from core.rag.embedding.embedding_base import Embeddings
from extensions.ext_redis import redis_client
from models.dataset import Dataset
logger = logging.getLogger(__name__)
class ElasticSearchJaVector(ElasticSearchVector):
def create_collection(
self,
embeddings: list[list[float]],
metadatas: Optional[list[dict[Any, Any]]] = None,
index_params: Optional[dict] = None,
):
lock_name = f"vector_indexing_lock_{self._collection_name}"
with redis_client.lock(lock_name, timeout=20):
collection_exist_cache_key = f"vector_indexing_{self._collection_name}"
if redis_client.get(collection_exist_cache_key):
logger.info(f"Collection {self._collection_name} already exists.")
return
if not self._client.indices.exists(index=self._collection_name):
dim = len(embeddings[0])
settings = {
"analysis": {
"analyzer": {
"ja_analyzer": {
"type": "custom",
"char_filter": [
"icu_normalizer",
"kuromoji_iteration_mark",
],
"tokenizer": "kuromoji_tokenizer",
"filter": [
"kuromoji_baseform",
"kuromoji_part_of_speech",
"ja_stop",
"kuromoji_number",
"kuromoji_stemmer",
],
}
}
}
}
mappings = {
"properties": {
Field.CONTENT_KEY.value: {
"type": "text",
"analyzer": "ja_analyzer",
"search_analyzer": "ja_analyzer",
},
Field.VECTOR.value: { # Make sure the dimension is correct here
"type": "dense_vector",
"dims": dim,
"index": True,
"similarity": "cosine",
},
Field.METADATA_KEY.value: {
"type": "object",
"properties": {
"doc_id": {"type": "keyword"} # Map doc_id to keyword type
},
},
}
}
self._client.indices.create(index=self._collection_name, settings=settings, mappings=mappings)
redis_client.set(collection_exist_cache_key, 1, ex=3600)
class ElasticSearchJaVectorFactory(ElasticSearchVectorFactory):
def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> ElasticSearchJaVector:
if dataset.index_struct_dict:
class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"]
collection_name = class_prefix
else:
dataset_id = dataset.id
collection_name = Dataset.gen_collection_name_by_id(dataset_id)
dataset.index_struct = json.dumps(self.gen_index_struct_dict(VectorType.ELASTICSEARCH, collection_name))
config = current_app.config
return ElasticSearchJaVector(
index_name=collection_name,
config=ElasticSearchConfig(
host=config.get("ELASTICSEARCH_HOST", "localhost"),
port=config.get("ELASTICSEARCH_PORT", 9200),
username=config.get("ELASTICSEARCH_USERNAME", ""),
password=config.get("ELASTICSEARCH_PASSWORD", ""),
),
attributes=[],
)
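
To verify the kuromoji pipeline behaves as intended, the index analyzer can be exercised directly. A hedged sketch with the elasticsearch-py client; the host and index name are placeholders:

from elasticsearch import Elasticsearch

client = Elasticsearch("http://localhost:9200")
resp = client.indices.analyze(
    index="vector_index_example",  # hypothetical collection name
    analyzer="ja_analyzer",
    text="関西国際空港へ行きました",
)
print([t["token"] for t in resp["tokens"]])  # normalized base-form tokens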

View File

@ -98,6 +98,8 @@ class ElasticSearchVector(BaseVector):
return bool(self._client.exists(index=self._collection_name, id=id))
def delete_by_ids(self, ids: list[str]) -> None:
if not ids:
return
for id in ids:
self._client.delete(index=self._collection_name, id=id)

View File

@ -6,6 +6,8 @@ class Field(Enum):
METADATA_KEY = "metadata"
GROUP_KEY = "group_id"
VECTOR = "vector"
# Sparse Vector aims to support full text search
SPARSE_VECTOR = "sparse_vector"
TEXT_KEY = "text"
PRIMARY_KEY = "id"
DOC_ID = "metadata.doc_id"

View File

@ -2,6 +2,7 @@ import json
import logging
from typing import Any, Optional
from packaging import version
from pydantic import BaseModel, model_validator
from pymilvus import MilvusClient, MilvusException # type: ignore
from pymilvus.milvus_client import IndexParams # type: ignore
@ -20,16 +21,25 @@ logger = logging.getLogger(__name__)
class MilvusConfig(BaseModel):
uri: str
token: Optional[str] = None
user: str
password: str
batch_size: int = 100
database: str = "default"
"""
Configuration class for Milvus connection.
"""
uri: str # Milvus server URI
token: Optional[str] = None # Optional token for authentication
user: str # Username for authentication
password: str # Password for authentication
batch_size: int = 100 # Batch size for operations
database: str = "default" # Database name
enable_hybrid_search: bool = False # Flag to enable hybrid search
@model_validator(mode="before")
@classmethod
def validate_config(cls, values: dict) -> dict:
"""
Validate the configuration values.
Raises ValueError if required fields are missing.
"""
if not values.get("uri"):
raise ValueError("config MILVUS_URI is required")
if not values.get("user"):
@ -39,6 +49,9 @@ class MilvusConfig(BaseModel):
return values
def to_milvus_params(self):
"""
Convert the configuration to a dictionary of Milvus connection parameters.
"""
return {
"uri": self.uri,
"token": self.token,
@ -49,26 +62,57 @@ class MilvusConfig(BaseModel):
class MilvusVector(BaseVector):
"""
Milvus vector storage implementation.
"""
def __init__(self, collection_name: str, config: MilvusConfig):
super().__init__(collection_name)
self._client_config = config
self._client = self._init_client(config)
self._consistency_level = "Session"
self._fields: list[str] = []
self._consistency_level = "Session" # Consistency level for Milvus operations
self._fields: list[str] = [] # List of fields in the collection
self._hybrid_search_enabled = self._check_hybrid_search_support() # Check if hybrid search is supported
def _check_hybrid_search_support(self) -> bool:
"""
Check if the current Milvus version supports hybrid search.
Returns True if the version is >= 2.5.0, otherwise False.
"""
if not self._client_config.enable_hybrid_search:
return False
try:
milvus_version = self._client.get_server_version()
return version.parse(milvus_version).release >= version.parse("2.5.0").release
except Exception as e:
logger.warning(f"Failed to check Milvus version: {str(e)}. Disabling hybrid search.")
return False
def get_type(self) -> str:
"""
Get the type of vector storage (Milvus).
"""
return VectorType.MILVUS
def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
"""
Create a collection and add texts with embeddings.
"""
index_params = {"metric_type": "IP", "index_type": "HNSW", "params": {"M": 8, "efConstruction": 64}}
metadatas = [d.metadata if d.metadata is not None else {} for d in texts]
self.create_collection(embeddings, metadatas, index_params)
self.add_texts(texts, embeddings)
def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
"""
Add texts and their embeddings to the collection.
"""
insert_dict_list = []
for i in range(len(documents)):
insert_dict = {
# No need to insert the sparse_vector field separately; the text_bm25_emb
# function automatically converts the raw text into a sparse vector for us.
Field.CONTENT_KEY.value: documents[i].page_content,
Field.VECTOR.value: embeddings[i],
Field.METADATA_KEY.value: documents[i].metadata,
@ -76,12 +120,11 @@ class MilvusVector(BaseVector):
insert_dict_list.append(insert_dict)
# Total insert count
total_count = len(insert_dict_list)
pks: list[str] = []
for i in range(0, total_count, 1000):
batch_insert_list = insert_dict_list[i : i + 1000]
# Insert into the collection.
batch_insert_list = insert_dict_list[i : i + 1000]
try:
ids = self._client.insert(collection_name=self._collection_name, data=batch_insert_list)
pks.extend(ids)
@ -91,6 +134,9 @@ class MilvusVector(BaseVector):
return pks
def get_ids_by_metadata_field(self, key: str, value: str):
"""
Get document IDs by metadata field key and value.
"""
result = self._client.query(
collection_name=self._collection_name, filter=f'metadata["{key}"] == "{value}"', output_fields=["id"]
)
@ -100,12 +146,18 @@ class MilvusVector(BaseVector):
return None
def delete_by_metadata_field(self, key: str, value: str):
"""
Delete documents by metadata field key and value.
"""
if self._client.has_collection(self._collection_name):
ids = self.get_ids_by_metadata_field(key, value)
if ids:
self._client.delete(collection_name=self._collection_name, pks=ids)
def delete_by_ids(self, ids: list[str]) -> None:
"""
Delete documents by their IDs.
"""
if self._client.has_collection(self._collection_name):
result = self._client.query(
collection_name=self._collection_name, filter=f'metadata["doc_id"] in {ids}', output_fields=["id"]
@ -115,10 +167,16 @@ class MilvusVector(BaseVector):
self._client.delete(collection_name=self._collection_name, pks=ids)
def delete(self) -> None:
"""
Delete the entire collection.
"""
if self._client.has_collection(self._collection_name):
self._client.drop_collection(self._collection_name, None)
def text_exists(self, id: str) -> bool:
"""
Check if a text with the given ID exists in the collection.
"""
if not self._client.has_collection(self._collection_name):
return False
@ -128,32 +186,80 @@ class MilvusVector(BaseVector):
return len(result) > 0
def field_exists(self, field: str) -> bool:
"""
Check if a field exists in the collection.
"""
return field in self._fields
def _process_search_results(
self, results: list[Any], output_fields: list[str], score_threshold: float = 0.0
) -> list[Document]:
"""
Common method to process search results
:param results: Search results
:param output_fields: Fields to be output
:param score_threshold: Score threshold for filtering
:return: List of documents
"""
docs = []
for result in results[0]:
metadata = result["entity"].get(output_fields[1], {})
metadata["score"] = result["distance"]
if result["distance"] > score_threshold:
doc = Document(page_content=result["entity"].get(output_fields[0], ""), metadata=metadata)
docs.append(doc)
return docs
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
# Set search parameters.
"""
Search for documents by vector similarity.
"""
results = self._client.search(
collection_name=self._collection_name,
data=[query_vector],
anns_field=Field.VECTOR.value,
limit=kwargs.get("top_k", 4),
output_fields=[Field.CONTENT_KEY.value, Field.METADATA_KEY.value],
)
# Organize results.
docs = []
for result in results[0]:
metadata = result["entity"].get(Field.METADATA_KEY.value)
metadata["score"] = result["distance"]
score_threshold = float(kwargs.get("score_threshold") or 0.0)
if result["distance"] > score_threshold:
doc = Document(page_content=result["entity"].get(Field.CONTENT_KEY.value), metadata=metadata)
docs.append(doc)
return docs
return self._process_search_results(
results,
output_fields=[Field.CONTENT_KEY.value, Field.METADATA_KEY.value],
score_threshold=float(kwargs.get("score_threshold") or 0.0),
)
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
# milvus/zilliz doesn't support bm25 search
return []
"""
Search for documents by full-text search (if hybrid search is enabled).
"""
if not self._hybrid_search_enabled or not self.field_exists(Field.SPARSE_VECTOR.value):
logger.warning("Full-text search is not supported in the current Milvus version (requires >= 2.5.0)")
return []
results = self._client.search(
collection_name=self._collection_name,
data=[query],
anns_field=Field.SPARSE_VECTOR.value,
limit=kwargs.get("top_k", 4),
output_fields=[Field.CONTENT_KEY.value, Field.METADATA_KEY.value],
)
return self._process_search_results(
results,
output_fields=[Field.CONTENT_KEY.value, Field.METADATA_KEY.value],
score_threshold=float(kwargs.get("score_threshold") or 0.0),
)
def create_collection(
self, embeddings: list, metadatas: Optional[list[dict]] = None, index_params: Optional[dict] = None
):
"""
Create a new collection in Milvus with the specified schema and index parameters.
"""
lock_name = "vector_indexing_lock_{}".format(self._collection_name)
with redis_client.lock(lock_name, timeout=20):
collection_exist_cache_key = "vector_indexing_{}".format(self._collection_name)
@ -161,7 +267,7 @@ class MilvusVector(BaseVector):
return
# Grab the existing collection if it exists
if not self._client.has_collection(self._collection_name):
from pymilvus import CollectionSchema, DataType, FieldSchema # type: ignore
from pymilvus import CollectionSchema, DataType, FieldSchema, Function, FunctionType # type: ignore
from pymilvus.orm.types import infer_dtype_bydata # type: ignore
# Determine embedding dim
@ -170,16 +276,36 @@ class MilvusVector(BaseVector):
if metadatas:
fields.append(FieldSchema(Field.METADATA_KEY.value, DataType.JSON, max_length=65_535))
# Create the text field
fields.append(FieldSchema(Field.CONTENT_KEY.value, DataType.VARCHAR, max_length=65_535))
# Create the text field; enable_analyzer is set to True so that Milvus can automatically
# convert the text into a sparse vector. Reference: https://milvus.io/docs/full-text-search.md
fields.append(
FieldSchema(
Field.CONTENT_KEY.value,
DataType.VARCHAR,
max_length=65_535,
enable_analyzer=self._hybrid_search_enabled,
)
)
# Create the primary key field
fields.append(FieldSchema(Field.PRIMARY_KEY.value, DataType.INT64, is_primary=True, auto_id=True))
# Create the vector field, supports binary or float vectors
fields.append(FieldSchema(Field.VECTOR.value, infer_dtype_bydata(embeddings[0]), dim=dim))
# Create Sparse Vector Index for the collection
if self._hybrid_search_enabled:
fields.append(FieldSchema(Field.SPARSE_VECTOR.value, DataType.SPARSE_FLOAT_VECTOR))
# Create the schema for the collection
schema = CollectionSchema(fields)
# Create custom function to support text to sparse vector by BM25
if self._hybrid_search_enabled:
bm25_function = Function(
name="text_bm25_emb",
input_field_names=[Field.CONTENT_KEY.value],
output_field_names=[Field.SPARSE_VECTOR.value],
function_type=FunctionType.BM25,
)
schema.add_function(bm25_function)
for x in schema.fields:
self._fields.append(x.name)
# Since primary field is auto-id, no need to track it
@ -189,10 +315,15 @@ class MilvusVector(BaseVector):
index_params_obj = IndexParams()
index_params_obj.add_index(field_name=Field.VECTOR.value, **index_params)
# Create Sparse Vector Index for the collection
if self._hybrid_search_enabled:
index_params_obj.add_index(
field_name=Field.SPARSE_VECTOR.value, index_type="AUTOINDEX", metric_type="BM25"
)
# Create the collection
collection_name = self._collection_name
self._client.create_collection(
collection_name=collection_name,
collection_name=self._collection_name,
schema=schema,
index_params=index_params_obj,
consistency_level=self._consistency_level,
@ -200,12 +331,22 @@ class MilvusVector(BaseVector):
redis_client.set(collection_exist_cache_key, 1, ex=3600)
def _init_client(self, config) -> MilvusClient:
"""
Initialize and return a Milvus client.
"""
client = MilvusClient(uri=config.uri, user=config.user, password=config.password, db_name=config.database)
return client
class MilvusVectorFactory(AbstractVectorFactory):
"""
Factory class for creating MilvusVector instances.
"""
def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> MilvusVector:
"""
Initialize a MilvusVector instance for the given dataset.
"""
if dataset.index_struct_dict:
class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"]
collection_name = class_prefix
@ -222,5 +363,6 @@ class MilvusVectorFactory(AbstractVectorFactory):
user=dify_config.MILVUS_USER or "",
password=dify_config.MILVUS_PASSWORD or "",
database=dify_config.MILVUS_DATABASE or "",
enable_hybrid_search=dify_config.MILVUS_ENABLE_HYBRID_SEARCH or False,
),
)
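
The version gate in _check_hybrid_search_support compares parsed versions rather than raw strings, since lexicographic comparison mishandles multi-digit components. A quick illustration:

from packaging import version

print("2.10.0" >= "2.5.0")                                # False: string comparison
print(version.parse("2.10.0") >= version.parse("2.5.0"))  # True: numeric comparison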

View File

@ -100,6 +100,8 @@ class MyScaleVector(BaseVector):
return results.row_count > 0
def delete_by_ids(self, ids: list[str]) -> None:
if not ids:
return
self._client.command(
f"DELETE FROM {self._config.database}.{self._collection_name} WHERE id IN {str(tuple(ids))}"
)

View File

@ -134,6 +134,8 @@ class OceanBaseVector(BaseVector):
return bool(cur.rowcount != 0)
def delete_by_ids(self, ids: list[str]) -> None:
if not ids:
return
self._client.delete(table_name=self._collection_name, ids=ids)
def get_ids_by_metadata_field(self, key: str, value: str) -> list[str]:

View File

@ -167,6 +167,8 @@ class OracleVector(BaseVector):
return docs
def delete_by_ids(self, ids: list[str]) -> None:
if not ids:
return
with self._get_cursor() as cur:
cur.execute(f"DELETE FROM {self.table_name} WHERE id IN %s" % (tuple(ids),))

View File

@ -129,6 +129,11 @@ class PGVector(BaseVector):
return docs
def delete_by_ids(self, ids: list[str]) -> None:
# Avoid crashes caused by delete operations on empty lists in certain scenarios.
# Scenario 1: extracting a document fails, so its table is never created;
# clicking the retry button then triggers a delete on an empty list.
if not ids:
return
with self._get_cursor() as cur:
cur.execute(f"DELETE FROM {self.table_name} WHERE id IN %s", (tuple(ids),))
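
This commit adds the same empty-list guard to the Baidu, Chroma, Elasticsearch, MyScale, OceanBase, Oracle, PGVector, and Tencent stores. The shape of the fix, as a standalone sketch over an assumed DB-API connection:

def delete_by_ids(conn, table_name: str, ids: list[str]) -> None:
    # An empty list must short-circuit: rendering it into the query would
    # produce e.g. "DELETE ... WHERE id IN ()", which is invalid SQL.
    if not ids:
        return
    with conn.cursor() as cur:
        cur.execute(f"DELETE FROM {table_name} WHERE id IN %s", (tuple(ids),))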

View File

@ -140,6 +140,8 @@ class TencentVector(BaseVector):
return False
def delete_by_ids(self, ids: list[str]) -> None:
if not ids:
return
self._db.collection(self._collection_name).delete(document_ids=ids)
def delete_by_metadata_field(self, key: str, value: str) -> None:

View File

@ -409,27 +409,27 @@ class TidbOnQdrantVectorFactory(AbstractVectorFactory):
db.session.query(TidbAuthBinding).filter(TidbAuthBinding.tenant_id == dataset.tenant_id).one_or_none()
)
if not tidb_auth_binding:
idle_tidb_auth_binding = (
db.session.query(TidbAuthBinding)
.filter(TidbAuthBinding.active == False, TidbAuthBinding.status == "ACTIVE")
.limit(1)
.one_or_none()
)
if idle_tidb_auth_binding:
idle_tidb_auth_binding.active = True
idle_tidb_auth_binding.tenant_id = dataset.tenant_id
db.session.commit()
TIDB_ON_QDRANT_API_KEY = f"{idle_tidb_auth_binding.account}:{idle_tidb_auth_binding.password}"
else:
with redis_client.lock("create_tidb_serverless_cluster_lock", timeout=900):
tidb_auth_binding = (
with redis_client.lock("create_tidb_serverless_cluster_lock", timeout=900):
tidb_auth_binding = (
db.session.query(TidbAuthBinding)
.filter(TidbAuthBinding.tenant_id == dataset.tenant_id)
.one_or_none()
)
if tidb_auth_binding:
TIDB_ON_QDRANT_API_KEY = f"{tidb_auth_binding.account}:{tidb_auth_binding.password}"
else:
idle_tidb_auth_binding = (
db.session.query(TidbAuthBinding)
.filter(TidbAuthBinding.tenant_id == dataset.tenant_id)
.filter(TidbAuthBinding.active == False, TidbAuthBinding.status == "ACTIVE")
.limit(1)
.one_or_none()
)
if tidb_auth_binding:
TIDB_ON_QDRANT_API_KEY = f"{tidb_auth_binding.account}:{tidb_auth_binding.password}"
if idle_tidb_auth_binding:
idle_tidb_auth_binding.active = True
idle_tidb_auth_binding.tenant_id = dataset.tenant_id
db.session.commit()
TIDB_ON_QDRANT_API_KEY = f"{idle_tidb_auth_binding.account}:{idle_tidb_auth_binding.password}"
else:
new_cluster = TidbService.create_tidb_serverless_cluster(
dify_config.TIDB_PROJECT_ID or "",
@ -451,7 +451,6 @@ class TidbOnQdrantVectorFactory(AbstractVectorFactory):
db.session.add(new_tidb_auth_binding)
db.session.commit()
TIDB_ON_QDRANT_API_KEY = f"{new_tidb_auth_binding.account}:{new_tidb_auth_binding.password}"
else:
TIDB_ON_QDRANT_API_KEY = f"{tidb_auth_binding.account}:{tidb_auth_binding.password}"

View File

@ -90,6 +90,12 @@ class Vector:
from core.rag.datasource.vdb.elasticsearch.elasticsearch_vector import ElasticSearchVectorFactory
return ElasticSearchVectorFactory
case VectorType.ELASTICSEARCH_JA:
from core.rag.datasource.vdb.elasticsearch.elasticsearch_ja_vector import (
ElasticSearchJaVectorFactory,
)
return ElasticSearchJaVectorFactory
case VectorType.TIDB_VECTOR:
from core.rag.datasource.vdb.tidb_vector.tidb_vector import TiDBVectorFactory

View File

@ -16,6 +16,7 @@ class VectorType(StrEnum):
TENCENT = "tencent"
ORACLE = "oracle"
ELASTICSEARCH = "elasticsearch"
ELASTICSEARCH_JA = "elasticsearch-ja"
LINDORM = "lindorm"
COUCHBASE = "couchbase"
BAIDU = "baidu"

View File

@ -23,7 +23,6 @@ class PdfExtractor(BaseExtractor):
self._file_cache_key = file_cache_key
def extract(self) -> list[Document]:
plaintext_file_key = ""
plaintext_file_exists = False
if self._file_cache_key:
try:
@ -39,8 +38,8 @@ class PdfExtractor(BaseExtractor):
text = "\n\n".join(text_list)
# save plaintext file for caching
if not plaintext_file_exists and plaintext_file_key:
storage.save(plaintext_file_key, text.encode("utf-8"))
if not plaintext_file_exists and self._file_cache_key:
storage.save(self._file_cache_key, text.encode("utf-8"))
return documents

View File

@ -3,6 +3,7 @@
import uuid
from typing import Optional
from configs import dify_config
from core.model_manager import ModelInstance
from core.rag.cleaner.clean_processor import CleanProcessor
from core.rag.datasource.retrieval_service import RetrievalService
@ -80,6 +81,10 @@ class ParentChildIndexProcessor(BaseIndexProcessor):
child_nodes = self._split_child_nodes(
document, rules, process_rule.get("mode"), kwargs.get("embedding_model_instance")
)
if kwargs.get("preview"):
if len(child_nodes) > dify_config.CHILD_CHUNKS_PREVIEW_NUMBER:
child_nodes = child_nodes[: dify_config.CHILD_CHUNKS_PREVIEW_NUMBER]
document.children = child_nodes
doc_id = str(uuid.uuid4())
hash = helper.generate_text_hash(document.page_content)

View File

@ -54,7 +54,12 @@ class ASRTool(BuiltinTool):
items.append((provider, model.model))
return items
def get_runtime_parameters(self) -> list[ToolParameter]:
def get_runtime_parameters(
self,
conversation_id: Optional[str] = None,
app_id: Optional[str] = None,
message_id: Optional[str] = None,
) -> list[ToolParameter]:
parameters = []
options = []

View File

@ -62,7 +62,12 @@ class TTSTool(BuiltinTool):
items.append((provider, model.model, voices))
return items
def get_runtime_parameters(self) -> list[ToolParameter]:
def get_runtime_parameters(
self,
conversation_id: Optional[str] = None,
app_id: Optional[str] = None,
message_id: Optional[str] = None,
) -> list[ToolParameter]:
parameters = []
options = []

View File

@ -212,8 +212,23 @@ class ApiTool(Tool):
else:
body = body
if method in {"get", "head", "post", "put", "delete", "patch"}:
response: httpx.Response = getattr(ssrf_proxy, method)(
if method in {
"get",
"head",
"post",
"put",
"delete",
"patch",
"options",
"GET",
"POST",
"PUT",
"PATCH",
"DELETE",
"HEAD",
"OPTIONS",
}:
response: httpx.Response = getattr(ssrf_proxy, method.lower())(
url,
params=params,
headers=headers,

View File

@ -147,7 +147,7 @@ class ToolInvokeMessage(BaseModel):
@field_validator("variable_name", mode="before")
@classmethod
def transform_variable_name(cls, value) -> str:
def transform_variable_name(cls, value: str) -> str:
"""
The variable name must be a string.
"""
@ -167,6 +167,7 @@ class ToolInvokeMessage(BaseModel):
error: Optional[str] = Field(default=None, description="The error message")
status: LogStatus = Field(..., description="The status of the log")
data: Mapping[str, Any] = Field(..., description="Detailed log data")
metadata: Optional[Mapping[str, Any]] = Field(default=None, description="The metadata of the log")
class MessageType(Enum):
TEXT = "text"

View File

@ -203,6 +203,7 @@ class AgentLogEvent(BaseAgentEvent):
error: str | None = Field(..., description="error")
status: str = Field(..., description="status")
data: Mapping[str, Any] = Field(..., description="data")
metadata: Optional[Mapping[str, Any]] = Field(default=None, description="metadata")
InNodeEvent = BaseNodeEvent | BaseParallelBranchEvent | BaseIterationEvent | BaseAgentEvent

View File

@ -89,7 +89,11 @@ class AgentNode(ToolNode):
try:
# convert tool messages
yield from self._transform_message(message_stream, {}, parameters_for_log)
yield from self._transform_message(
message_stream,
{"provider": (cast(AgentNodeData, self.node_data)).agent_strategy_provider_name},
parameters_for_log,
)
except PluginDaemonClientSideError as e:
yield RunCompletedEvent(
run_result=NodeRunResult(
@ -170,7 +174,12 @@ class AgentNode(ToolNode):
extra.get("descrption", "") or tool_runtime.entity.description.llm
)
tool_value.append(tool_runtime.entity.model_dump(mode="json"))
tool_value.append(
{
**tool_runtime.entity.model_dump(mode="json"),
"runtime_parameters": tool_runtime.runtime.runtime_parameters,
}
)
value = tool_value
if parameter.type == "model-selector":
value = cast(dict[str, Any], value)

View File

@ -2,14 +2,18 @@ import csv
import io
import json
import logging
import operator
import os
import tempfile
from typing import cast
from collections.abc import Mapping, Sequence
from typing import Any, cast
import docx
import pandas as pd
import pypdfium2 # type: ignore
import yaml # type: ignore
from docx.table import Table
from docx.text.paragraph import Paragraph
from configs import dify_config
from core.file import File, FileTransferMethod, file_manager
@ -78,6 +82,23 @@ class DocumentExtractorNode(BaseNode[DocumentExtractorNodeData]):
process_data=process_data,
)
@classmethod
def _extract_variable_selector_to_variable_mapping(
cls,
*,
graph_config: Mapping[str, Any],
node_id: str,
node_data: DocumentExtractorNodeData,
) -> Mapping[str, Sequence[str]]:
"""
Extract variable selector to variable mapping
:param graph_config: graph config
:param node_id: node id
:param node_data: node data
:return:
"""
return {node_id + ".files": node_data.variable_selector}
def _extract_text_by_mime_type(*, file_content: bytes, mime_type: str) -> str:
"""Extract text from a file based on its MIME type."""
@ -189,35 +210,56 @@ def _extract_text_from_doc(file_content: bytes) -> str:
doc_file = io.BytesIO(file_content)
doc = docx.Document(doc_file)
text = []
# Process paragraphs
for paragraph in doc.paragraphs:
if paragraph.text.strip():
text.append(paragraph.text)
# Process tables
for table in doc.tables:
# Table header
try:
# tables can raise errors, so ignore failures here.
if len(table.rows) > 0 and table.rows[0].cells is not None:
# Keep track of paragraph and table positions
content_items: list[tuple[int, str, Table | Paragraph]] = []
# Process paragraphs and tables
for i, paragraph in enumerate(doc.paragraphs):
if paragraph.text.strip():
content_items.append((i, "paragraph", paragraph))
for i, table in enumerate(doc.tables):
content_items.append((i, "table", table))
# Sort content items based on their original position
content_items.sort(key=operator.itemgetter(0))
# Process sorted content
for _, item_type, item in content_items:
if item_type == "paragraph":
if isinstance(item, Table):
continue
text.append(item.text)
elif item_type == "table":
# Process tables
if not isinstance(item, Table):
continue
try:
# Check if any cell in the table has text
has_content = False
for row in table.rows:
for row in item.rows:
if any(cell.text.strip() for cell in row.cells):
has_content = True
break
if has_content:
markdown_table = "| " + " | ".join(cell.text for cell in table.rows[0].cells) + " |\n"
markdown_table += "| " + " | ".join(["---"] * len(table.rows[0].cells)) + " |\n"
for row in table.rows[1:]:
markdown_table += "| " + " | ".join(cell.text for cell in row.cells) + " |\n"
cell_texts = [cell.text.replace("\n", "<br>") for cell in item.rows[0].cells]
markdown_table = f"| {' | '.join(cell_texts)} |\n"
markdown_table += f"| {' | '.join(['---'] * len(item.rows[0].cells))} |\n"
for row in item.rows[1:]:
# Replace newlines with <br> in each cell
row_cells = [cell.text.replace("\n", "<br>") for cell in row.cells]
markdown_table += "| " + " | ".join(row_cells) + " |\n"
text.append(markdown_table)
except Exception as e:
logger.warning(f"Failed to extract table from DOC/DOCX: {e}")
continue
except Exception as e:
logger.warning(f"Failed to extract table from DOC/DOCX: {e}")
continue
return "\n".join(text)
except Exception as e:
raise TextExtractionError(f"Failed to extract text from DOC/DOCX: {str(e)}") from e
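
The table handling renders each DOCX table as a Markdown block, escaping in-cell newlines as <br> so every row stays on a single line. A reduced, standalone sketch over plain lists of cell strings:

def rows_to_markdown(rows: list[list[str]]) -> str:
    header = [c.replace("\n", "<br>") for c in rows[0]]
    out = f"| {' | '.join(header)} |\n"
    out += f"| {' | '.join(['---'] * len(header))} |\n"
    for row in rows[1:]:
        cells = [c.replace("\n", "<br>") for c in row]
        out += f"| {' | '.join(cells)} |\n"
    return out

print(rows_to_markdown([["Name", "Role"], ["Ada", "Engineer\nLead"]]))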

View File

@ -68,7 +68,22 @@ class HttpRequestNodeData(BaseNodeData):
Code Node Data.
"""
method: Literal["get", "post", "put", "patch", "delete", "head"]
method: Literal[
"get",
"post",
"put",
"patch",
"delete",
"head",
"options",
"GET",
"POST",
"PUT",
"PATCH",
"DELETE",
"HEAD",
"OPTIONS",
]
url: str
authorization: HttpRequestNodeAuthorization
headers: str

View File

@ -37,7 +37,22 @@ BODY_TYPE_TO_CONTENT_TYPE = {
class Executor:
method: Literal["get", "head", "post", "put", "delete", "patch"]
method: Literal[
"get",
"head",
"post",
"put",
"delete",
"patch",
"options",
"GET",
"POST",
"PUT",
"PATCH",
"DELETE",
"HEAD",
"OPTIONS",
]
url: str
params: list[tuple[str, str]] | None
content: str | bytes | None
@ -67,12 +82,6 @@ class Executor:
node_data.authorization.config.api_key
).text
# check if node_data.url is a valid URL
if not node_data.url:
raise InvalidURLError("url is required")
if not node_data.url.startswith(("http://", "https://")):
raise InvalidURLError("url should start with http:// or https://")
self.url: str = node_data.url
self.method = node_data.method
self.auth = node_data.authorization
@ -99,6 +108,12 @@ class Executor:
def _init_url(self):
self.url = self.variable_pool.convert_template(self.node_data.url).text
# check if url is a valid URL
if not self.url:
raise InvalidURLError("url is required")
if not self.url.startswith(("http://", "https://")):
raise InvalidURLError("url should start with http:// or https://")
def _init_params(self):
"""
Almost same as _init_headers(), difference:
@ -158,7 +173,10 @@ class Executor:
if len(data) != 1:
raise RequestBodyError("json body type should have exactly one item")
json_string = self.variable_pool.convert_template(data[0].value).text
json_object = json.loads(json_string, strict=False)
try:
json_object = json.loads(json_string, strict=False)
except json.JSONDecodeError as e:
raise RequestBodyError(f"Failed to parse JSON: {json_string}") from e
self.json = json_object
# self.json = self._parse_object_contains_variables(json_object)
case "binary":
@ -246,7 +264,22 @@ class Executor:
"""
do http request depending on api bundle
"""
if self.method not in {"get", "head", "post", "put", "delete", "patch"}:
if self.method not in {
"get",
"head",
"post",
"put",
"delete",
"patch",
"options",
"GET",
"POST",
"PUT",
"PATCH",
"DELETE",
"HEAD",
"OPTIONS",
}:
raise InvalidHttpMethodError(f"Invalid http method {self.method}")
request_args = {
@ -263,7 +296,7 @@ class Executor:
}
# request_args = {k: v for k, v in request_args.items() if v is not None}
try:
response = getattr(ssrf_proxy, self.method)(**request_args)
response = getattr(ssrf_proxy, self.method.lower())(**request_args)
except (ssrf_proxy.MaxRetriesExceededError, httpx.RequestError) as e:
raise HttpRequestNodeError(str(e))
# FIXME: fix the type: ignore; this may be an httpx typing issue
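
Accepting both lower- and upper-case verbs and dispatching with method.lower() keeps the lowercase ssrf_proxy helpers working however the spec spells the method. A reduced sketch of that dispatch, using plain httpx in place of ssrf_proxy:

import httpx

ALLOWED = {"get", "head", "post", "put", "delete", "patch", "options"}

def do_request(method: str, url: str, **kwargs) -> httpx.Response:
    # Normalize once, validate against the lowercase set, then dispatch.
    m = method.lower()
    if m not in ALLOWED:
        raise ValueError(f"Invalid http method {method}")
    return getattr(httpx, m)(url, **kwargs)

print(do_request("GET", "https://example.com").status_code)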

View File

@ -1,5 +1,5 @@
from collections.abc import Generator, Mapping, Sequence
from typing import Any, cast
from typing import Any, Optional, cast
from sqlalchemy import select
from sqlalchemy.orm import Session
@ -197,6 +197,7 @@ class ToolNode(BaseNode[ToolNodeData]):
json: list[dict] = []
agent_logs: list[AgentLogEvent] = []
agent_execution_metadata: Optional[Mapping[NodeRunMetadataKey, Any]] = {}
variables: dict[str, Any] = {}
@ -264,6 +265,11 @@ class ToolNode(BaseNode[ToolNodeData]):
)
elif message.type == ToolInvokeMessage.MessageType.JSON:
assert isinstance(message.message, ToolInvokeMessage.JsonMessage)
if self.node_type == NodeType.AGENT:
msg_metadata = message.message.json_object.pop("execution_metadata", {})
agent_execution_metadata = {
key: value for key, value in msg_metadata.items() if key in NodeRunMetadataKey
}
json.append(message.message.json_object)
elif message.type == ToolInvokeMessage.MessageType.LINK:
assert isinstance(message.message, ToolInvokeMessage.TextMessage)
@ -299,6 +305,7 @@ class ToolNode(BaseNode[ToolNodeData]):
status=message.message.status.value,
data=message.message.data,
label=message.message.label,
metadata=message.message.metadata,
)
# check if the agent log is already in the list
@ -309,6 +316,7 @@ class ToolNode(BaseNode[ToolNodeData]):
log.status = agent_log.status
log.error = agent_log.error
log.label = agent_log.label
log.metadata = agent_log.metadata
break
else:
agent_logs.append(agent_log)
@ -319,7 +327,11 @@ class ToolNode(BaseNode[ToolNodeData]):
run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
outputs={"text": text, "files": files, "json": json, **variables},
metadata={NodeRunMetadataKey.TOOL_INFO: tool_info, NodeRunMetadataKey.AGENT_LOG: agent_logs},
metadata={
**agent_execution_metadata,
NodeRunMetadataKey.TOOL_INFO: tool_info,
NodeRunMetadataKey.AGENT_LOG: agent_logs,
},
inputs=parameters_for_log,
)
)
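
The new agent_execution_metadata pass-through keeps only the keys that are valid NodeRunMetadataKey members. One way to express the same filter in isolation (a reduced, assumed shape of the enum):

from enum import StrEnum

class NodeRunMetadataKey(StrEnum):  # assumed, reduced shape of the real enum
    TOTAL_TOKENS = "total_tokens"
    CURRENCY = "currency"

raw = {"total_tokens": 123, "currency": "USD", "unrelated": True}
valid = {m.value for m in NodeRunMetadataKey}
filtered = {k: v for k, v in raw.items() if k in valid}
print(filtered)  # {'total_tokens': 123, 'currency': 'USD'}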

View File

@ -340,6 +340,10 @@ class WorkflowEntry:
):
raise ValueError(f"Variable key {node_variable} not found in user inputs.")
# environment variable already exist in variable pool, not from user inputs
if variable_pool.get(variable_selector):
continue
# fetch variable node id from variable selector
variable_node_id = variable_selector[0]
variable_key_list = variable_selector[1:]

View File

@ -33,6 +33,7 @@ else
--bind "${DIFY_BIND_ADDRESS:-0.0.0.0}:${DIFY_PORT:-5001}" \
--workers ${SERVER_WORKER_AMOUNT:-1} \
--worker-class ${SERVER_WORKER_CLASS:-gevent} \
--worker-connections ${SERVER_WORKER_CONNECTIONS:-10} \
--timeout ${GUNICORN_TIMEOUT:-200} \
app:app
fi

View File

@ -46,7 +46,7 @@ def init_app(app: DifyApp):
timezone = pytz.timezone(log_tz)
def time_converter(seconds):
return datetime.utcfromtimestamp(seconds).astimezone(timezone).timetuple()
return datetime.fromtimestamp(seconds, tz=timezone).timetuple()
for handler in logging.root.handlers:
if handler.formatter:

View File

@ -158,7 +158,7 @@ def _build_from_remote_url(
tenant_id: str,
transfer_method: FileTransferMethod,
) -> File:
url = mapping.get("url")
url = mapping.get("url") or mapping.get("remote_url")
if not url:
raise ValueError("Invalid file url")

View File

@ -6,7 +6,7 @@ import string
import subprocess
import time
import uuid
from collections.abc import Generator
from collections.abc import Generator, Mapping
from datetime import datetime
from hashlib import sha256
from typing import TYPE_CHECKING, Any, Optional, Union, cast
@ -182,7 +182,7 @@ def generate_text_hash(text: str) -> str:
return sha256(hash_text.encode()).hexdigest()
def compact_generate_response(response: Union[dict, Generator, RateLimitGenerator]) -> Response:
def compact_generate_response(response: Union[Mapping, Generator, RateLimitGenerator]) -> Response:
if isinstance(response, dict):
return Response(response=json.dumps(response), status=200, mimetype="application/json")
else:

View File

@ -255,7 +255,8 @@ class NotionOAuth(OAuthDataSource):
response = requests.get(url=f"{self._NOTION_BLOCK_SEARCH}/{block_id}", headers=headers)
response_json = response.json()
if response.status_code != 200:
raise ValueError(f"Error fetching block parent page ID: {response_json.message}")
message = response_json.get("message", "unknown error")
raise ValueError(f"Error fetching block parent page ID: {message}")
parent = response_json["parent"]
parent_type = parent["type"]
if parent_type == "block_id":

View File

@ -0,0 +1,41 @@
"""change workflow_runs.total_tokens to bigint
Revision ID: a91b476a53de
Revises: 923752d42eb6
Create Date: 2025-01-01 20:00:01.207369
"""
from alembic import op
import models as models
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = 'a91b476a53de'
down_revision = '923752d42eb6'
branch_labels = None
depends_on = None
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('workflow_runs', schema=None) as batch_op:
batch_op.alter_column('total_tokens',
existing_type=sa.INTEGER(),
type_=sa.BigInteger(),
existing_nullable=False,
existing_server_default=sa.text('0'))
# ### end Alembic commands ###
def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('workflow_runs', schema=None) as batch_op:
batch_op.alter_column('total_tokens',
existing_type=sa.BigInteger(),
type_=sa.INTEGER(),
existing_nullable=False,
existing_server_default=sa.text('0'))
# ### end Alembic commands ###

View File

@ -415,8 +415,8 @@ class WorkflowRun(Base):
status: Mapped[str] = mapped_column(db.String(255)) # running, succeeded, failed, stopped, partial-succeeded
outputs: Mapped[Optional[str]] = mapped_column(sa.Text, default="{}")
error: Mapped[Optional[str]] = mapped_column(db.Text)
elapsed_time = db.Column(db.Float, nullable=False, server_default=db.text("0"))
total_tokens: Mapped[int] = mapped_column(server_default=db.text("0"))
elapsed_time = db.Column(db.Float, nullable=False, server_default=sa.text("0"))
total_tokens: Mapped[int] = mapped_column(sa.BigInteger, server_default=sa.text("0"))
total_steps = db.Column(db.Integer, server_default=db.text("0"))
created_by_role: Mapped[str] = mapped_column(db.String(255)) # account, end_user
created_by = db.Column(StringUUID, nullable=False)

api/poetry.lock (generated): 2397 lines changed

File diff suppressed because it is too large.

View File

@ -71,7 +71,7 @@ pyjwt = "~2.8.0"
pypdfium2 = "~4.30.0"
python = ">=3.11,<3.13"
python-docx = "~1.1.0"
python-dotenv = "1.0.0"
python-dotenv = "1.0.1"
pyyaml = "~6.0.1"
readabilipy = "0.2.0"
redis = { version = "~5.0.3", extras = ["hiredis"] }
@ -82,7 +82,7 @@ scikit-learn = "~1.5.1"
sentry-sdk = { version = "~1.44.1", extras = ["flask"] }
sqlalchemy = "~2.0.29"
starlette = "0.41.0"
tencentcloud-sdk-python-hunyuan = "~3.0.1158"
tencentcloud-sdk-python-hunyuan = "~3.0.1294"
tiktoken = "~0.8.0"
tokenizers = "~0.15.0"
transformers = "~4.35.0"
@ -92,7 +92,7 @@ validators = "0.21.0"
volcengine-python-sdk = {extras = ["ark"], version = "~1.0.98"}
websocket-client = "~1.7.0"
xinference-client = "0.15.2"
yarl = "~1.9.4"
yarl = "~1.18.3"
youtube-transcript-api = "~0.6.2"
zhipuai = "~2.1.5"
# Before adding new dependency, consider place it in alphabet order (a-z) and suitable group.
@ -157,7 +157,7 @@ opensearch-py = "2.4.0"
oracledb = "~2.2.1"
pgvecto-rs = { version = "~0.2.1", extras = ['sqlalchemy'] }
pgvector = "0.2.5"
pymilvus = "~2.4.4"
pymilvus = "~2.5.0"
pymochow = "1.3.1"
pyobvector = "~0.1.6"
qdrant-client = "1.7.3"

View File

@ -168,23 +168,6 @@ def clean_unused_datasets_task():
else:
plan = plan_cache.decode()
if plan == "sandbox":
# add auto disable log
documents = (
db.session.query(Document)
.filter(
Document.dataset_id == dataset.id,
Document.enabled == True,
Document.archived == False,
)
.all()
)
for document in documents:
dataset_auto_disable_log = DatasetAutoDisableLog(
tenant_id=dataset.tenant_id,
dataset_id=dataset.id,
document_id=document.id,
)
db.session.add(dataset_auto_disable_log)
# remove index
index_processor = IndexProcessorFactory(dataset.doc_form).init_index_processor()
index_processor.clean(dataset, None)

View File

@ -67,7 +67,7 @@ class TokenPair(BaseModel):
REFRESH_TOKEN_PREFIX = "refresh_token:"
ACCOUNT_REFRESH_TOKEN_PREFIX = "account_refresh_token:"
REFRESH_TOKEN_EXPIRY = timedelta(days=30)
REFRESH_TOKEN_EXPIRY = timedelta(days=dify_config.REFRESH_TOKEN_EXPIRE_DAYS)
class AccountService:
@ -921,6 +921,9 @@ class RegisterService:
def invite_new_member(
cls, tenant: Tenant, email: str, language: str, role: str = "normal", inviter: Account | None = None
) -> str:
        """Invite new member"""
        if not inviter:
            raise ValueError("Inviter is required")
with Session(db.engine) as session:
account = session.query(Account).filter_by(email=email).first()
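
With the expiry now read from REFRESH_TOKEN_EXPIRE_DAYS (a PositiveFloat), fractional day values are also valid; the timedelta arithmetic is direct:

from datetime import timedelta

REFRESH_TOKEN_EXPIRE_DAYS = 0.5  # hypothetical override, i.e. twelve hours
print(timedelta(days=REFRESH_TOKEN_EXPIRE_DAYS))  # 12:00:00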

View File

@ -2,6 +2,7 @@ import logging
import uuid
from enum import StrEnum
from typing import Optional, cast
from urllib.parse import urlparse
from uuid import uuid4
import yaml # type: ignore
@ -124,7 +125,7 @@ class AppDslService:
raise ValueError(f"Invalid import_mode: {import_mode}")
# Get YAML content
content: bytes | str = b""
content: str = ""
if mode == ImportMode.YAML_URL:
if not yaml_url:
return Import(
@ -133,13 +134,17 @@ class AppDslService:
error="yaml_url is required when import_mode is yaml-url",
)
try:
# tricky way to handle url from github to github raw url
if yaml_url.startswith("https://github.com") and yaml_url.endswith((".yml", ".yaml")):
parsed_url = urlparse(yaml_url)
if (
parsed_url.scheme == "https"
and parsed_url.netloc == "github.com"
and parsed_url.path.endswith((".yml", ".yaml"))
):
yaml_url = yaml_url.replace("https://github.com", "https://raw.githubusercontent.com")
yaml_url = yaml_url.replace("/blob/", "/")
response = ssrf_proxy.get(yaml_url.strip(), follow_redirects=True, timeout=(10, 10))
response.raise_for_status()
content = response.content
content = response.content.decode()
if len(content) > DSL_MAX_SIZE:
return Import(

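The urlparse rewrite matters because startswith("https://github.com") also matches lookalike hosts such as https://github.com.attacker.example, while comparing netloc pins the host exactly. A minimal sketch of the rewrite (the helper name is hypothetical):

from urllib.parse import urlparse

def to_raw_github_url(yaml_url: str) -> str:
    # Rewrite only when the host is exactly github.com and the path
    # points at a YAML file; everything else passes through untouched.
    parsed = urlparse(yaml_url)
    if (
        parsed.scheme == "https"
        and parsed.netloc == "github.com"
        and parsed.path.endswith((".yml", ".yaml"))
    ):
        yaml_url = yaml_url.replace("https://github.com", "https://raw.githubusercontent.com")
        yaml_url = yaml_url.replace("/blob/", "/")
    return yaml_url

print(to_raw_github_url("https://github.com/org/repo/blob/main/app.yml"))
# -> https://raw.githubusercontent.com/org/repo/main/app.yml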
View File

@ -26,9 +26,10 @@ from tasks.remove_app_and_related_data_task import remove_app_and_related_data_t
class AppService:
def get_paginate_apps(self, tenant_id: str, args: dict) -> Pagination | None:
def get_paginate_apps(self, user_id: str, tenant_id: str, args: dict) -> Pagination | None:
"""
Get app list with pagination
:param user_id: user id
:param tenant_id: tenant id
:param args: request args
:return:
@ -44,6 +45,8 @@ class AppService:
elif args["mode"] == "channel":
filters.append(App.mode == AppMode.CHANNEL.value)
if args.get("is_created_by_me", False):
filters.append(App.created_by == user_id)
if args.get("name"):
name = args["name"][:30]
filters.append(App.name.ilike(f"%{name}%"))

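A minimal sketch of how the new flag composes with the existing filters, assuming the App model, args, tenant_id, and user_id from the hunk above:

from sqlalchemy import and_, select

filters = [App.tenant_id == tenant_id]
if args.get("is_created_by_me", False):
    filters.append(App.created_by == user_id)  # only apps this user created
stmt = select(App).where(and_(*filters))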
View File

@ -1,5 +1,5 @@
import os
from typing import Optional
from typing import Literal, Optional
import httpx
from tenacity import retry, retry_if_exception_type, stop_before_delay, wait_fixed
@ -17,7 +17,6 @@ class BillingService:
params = {"tenant_id": tenant_id}
billing_info = cls._send_request("GET", "/subscription/info", params=params)
return billing_info
@classmethod
@ -47,12 +46,13 @@ class BillingService:
retry=retry_if_exception_type(httpx.RequestError),
reraise=True,
)
def _send_request(cls, method, endpoint, json=None, params=None):
def _send_request(cls, method: Literal["GET", "POST", "DELETE"], endpoint: str, json=None, params=None):
headers = {"Content-Type": "application/json", "Billing-Api-Secret-Key": cls.secret_key}
url = f"{cls.base_url}{endpoint}"
response = httpx.request(method, url, json=json, params=params, headers=headers)
if method == "GET" and response.status_code != httpx.codes.OK:
raise ValueError("Unable to retrieve billing information. Please try again later or contact support.")
return response.json()
@staticmethod

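Typing method as a Literal lets mypy or pyright reject unsupported HTTP verbs at check time without changing runtime behaviour. An illustrative stub:

from typing import Literal

def send(method: Literal["GET", "POST", "DELETE"], endpoint: str) -> None:
    print(method, endpoint)

send("GET", "/subscription/info")  # accepted
send("PUT", "/subscription/info")  # type: ignore[arg-type]  # flagged by type checkers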
View File

@ -86,7 +86,7 @@ class DatasetService:
else:
return [], 0
else:
if user.current_role not in (TenantAccountRole.OWNER, TenantAccountRole.ADMIN):
if user.current_role != TenantAccountRole.OWNER:
# show all datasets that the user has permission to access
if permitted_dataset_ids:
query = query.filter(
@ -382,7 +382,7 @@ class DatasetService:
if dataset.tenant_id != user.current_tenant_id:
logging.debug(f"User {user.id} does not have permission to access dataset {dataset.id}")
raise NoPermissionError("You do not have permission to access this dataset.")
if user.current_role not in (TenantAccountRole.OWNER, TenantAccountRole.ADMIN):
if user.current_role != TenantAccountRole.OWNER:
if dataset.permission == DatasetPermissionEnum.ONLY_ME and dataset.created_by != user.id:
logging.debug(f"User {user.id} does not have permission to access dataset {dataset.id}")
raise NoPermissionError("You do not have permission to access this dataset.")
@ -404,7 +404,7 @@ class DatasetService:
if not user:
raise ValueError("User not found")
if user.current_role not in (TenantAccountRole.OWNER, TenantAccountRole.ADMIN):
if user.current_role != TenantAccountRole.OWNER:
if dataset.permission == DatasetPermissionEnum.ONLY_ME:
if dataset.created_by != user.id:
raise NoPermissionError("You do not have permission to access this dataset.")
@ -434,6 +434,12 @@ class DatasetService:
@staticmethod
def get_dataset_auto_disable_logs(dataset_id: str) -> dict:
features = FeatureService.get_features(current_user.current_tenant_id)
if not features.billing.enabled or features.billing.subscription.plan == "sandbox":
return {
"document_ids": [],
"count": 0,
}
# get auto-disable logs from the last 30 days
start_date = datetime.datetime.now() - datetime.timedelta(days=30)
dataset_auto_disable_logs = DatasetAutoDisableLog.query.filter(
@ -786,13 +792,19 @@ class DocumentService:
dataset.indexing_technique = knowledge_config.indexing_technique
if knowledge_config.indexing_technique == "high_quality":
model_manager = ModelManager()
embedding_model = model_manager.get_default_model_instance(
tenant_id=current_user.current_tenant_id, model_type=ModelType.TEXT_EMBEDDING
)
dataset.embedding_model = embedding_model.model
dataset.embedding_model_provider = embedding_model.provider
if knowledge_config.embedding_model and knowledge_config.embedding_model_provider:
dataset_embedding_model = knowledge_config.embedding_model
dataset_embedding_model_provider = knowledge_config.embedding_model_provider
else:
embedding_model = model_manager.get_default_model_instance(
tenant_id=current_user.current_tenant_id, model_type=ModelType.TEXT_EMBEDDING
)
dataset_embedding_model = embedding_model.model
dataset_embedding_model_provider = embedding_model.provider
dataset.embedding_model = dataset_embedding_model
dataset.embedding_model_provider = dataset_embedding_model_provider
dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding(
embedding_model.provider, embedding_model.model
dataset_embedding_model_provider, dataset_embedding_model
)
dataset.collection_binding_id = dataset_collection_binding.id
if not dataset.retrieval_model:
@ -804,7 +816,11 @@ class DocumentService:
"score_threshold_enabled": False,
}
dataset.retrieval_model = knowledge_config.retrieval_model.model_dump() or default_retrieval_model # type: ignore
dataset.retrieval_model = (
knowledge_config.retrieval_model.model_dump()
if knowledge_config.retrieval_model
else default_retrieval_model
) # type: ignore
documents = []
if knowledge_config.original_document_id:

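One subtle fix in this hunk: the old expression knowledge_config.retrieval_model.model_dump() or default_retrieval_model raises AttributeError whenever retrieval_model is None, whereas the conditional short-circuits first. A self-contained sketch of the difference (model and field names are illustrative):

from typing import Optional
from pydantic import BaseModel

class RetrievalModel(BaseModel):
    top_k: int = 2

def pick(configured: Optional[RetrievalModel], default: dict) -> dict:
    # `configured.model_dump() or default` would raise AttributeError
    # when configured is None; checking truthiness first avoids that.
    return configured.model_dump() if configured else default

assert pick(None, {"top_k": 2}) == {"top_k": 2}
assert pick(RetrievalModel(top_k=5), {"top_k": 2}) == {"top_k": 5}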
View File

@ -27,7 +27,7 @@ class WorkflowAppService:
query = query.join(WorkflowRun, WorkflowRun.id == WorkflowAppLog.workflow_run_id)
if keyword:
keyword_like_val = f"%{args['keyword'][:30]}%"
keyword_like_val = f"%{keyword[:30].encode('unicode_escape').decode('utf-8')}%".replace(r"\u", r"\\u")
keyword_conditions = [
WorkflowRun.inputs.ilike(keyword_like_val),
WorkflowRun.outputs.ilike(keyword_like_val),

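Background for the keyword change: WorkflowRun.inputs and outputs hold JSON text, in which non-ASCII characters are stored as \uXXXX escape sequences, and the backslash itself has to be doubled so the LIKE pattern treats it as a literal character. A minimal sketch of the transformation:

import json

stored = json.dumps({"query": "你好"})
print(stored)   # {"query": "\u4f60\u597d"}  <- one backslash per escape on disk
keyword = "你好"
pattern = f"%{keyword[:30].encode('unicode_escape').decode('utf-8')}%".replace(r"\u", r"\\u")
print(pattern)  # %\\u4f60\\u597d%  <- backslash doubled for LIKE escaping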
View File

@ -298,7 +298,7 @@ class WorkflowService:
start_at: float,
tenant_id: str,
node_id: str,
):
) -> WorkflowNodeExecution:
"""
Handle node run result

View File

@ -28,7 +28,7 @@ def deal_dataset_vector_index_task(dataset_id: str, action: str):
if not dataset:
raise Exception("Dataset not found")
index_type = dataset.doc_form
index_type = dataset.doc_form or IndexType.PARAGRAPH_INDEX
index_processor = IndexProcessorFactory(index_type).init_index_processor()
if action == "remove":
index_processor.clean(dataset, None, with_keywords=False)
@ -157,6 +157,9 @@ def deal_dataset_vector_index_task(dataset_id: str, action: str):
{"indexing_status": "error", "error": str(e)}, synchronize_session=False
)
db.session.commit()
else:
# clean collection
index_processor.clean(dataset, None, with_keywords=False, delete_child_chunks=False)
end_at = time.perf_counter()
logging.info(

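The doc_form fallback uses or, which catches both None and an empty string before the index processor factory runs. A tiny sketch (the string value is a hypothetical stand-in for IndexType.PARAGRAPH_INDEX):

def resolve_index_type(doc_form: str | None) -> str:
    PARAGRAPH_INDEX = "paragraph_index"  # stand-in value for illustration
    return doc_form or PARAGRAPH_INDEX

assert resolve_index_type(None) == "paragraph_index"
assert resolve_index_type("") == "paragraph_index"
assert resolve_index_type("qa_index") == "qa_index"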
View File

@ -0,0 +1,55 @@
import os
from pathlib import Path
import pytest
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.gpustack.speech2text.speech2text import GPUStackSpeech2TextModel
def test_validate_credentials():
model = GPUStackSpeech2TextModel()
with pytest.raises(CredentialsValidateFailedError):
model.validate_credentials(
model="faster-whisper-medium",
credentials={
"endpoint_url": "invalid_url",
"api_key": "invalid_api_key",
},
)
model.validate_credentials(
model="faster-whisper-medium",
credentials={
"endpoint_url": os.environ.get("GPUSTACK_SERVER_URL"),
"api_key": os.environ.get("GPUSTACK_API_KEY"),
},
)
def test_invoke_model():
model = GPUStackSpeech2TextModel()
# Get the directory of the current file
current_dir = os.path.dirname(os.path.abspath(__file__))
# Get assets directory
assets_dir = os.path.join(os.path.dirname(current_dir), "assets")
# Construct the path to the audio file
audio_file_path = os.path.join(assets_dir, "audio.mp3")
file = Path(audio_file_path).read_bytes()
result = model.invoke(
model="faster-whisper-medium",
credentials={
"endpoint_url": os.environ.get("GPUSTACK_SERVER_URL"),
"api_key": os.environ.get("GPUSTACK_API_KEY"),
},
file=file,
)
assert isinstance(result, str)
assert result == "1, 2, 3, 4, 5, 6, 7, 8, 9, 10"

View File

@ -0,0 +1,24 @@
import os
from core.model_runtime.model_providers.gpustack.tts.tts import GPUStackText2SpeechModel
def test_invoke_model():
model = GPUStackText2SpeechModel()
result = model.invoke(
model="cosyvoice-300m-sft",
tenant_id="test",
credentials={
"endpoint_url": os.environ.get("GPUSTACK_SERVER_URL"),
"api_key": os.environ.get("GPUSTACK_API_KEY"),
},
content_text="Hello world",
voice="Chinese Female",
)
content = b""
for chunk in result:
content += chunk
assert content != b""

View File

@ -19,9 +19,9 @@ class MilvusVectorTest(AbstractVectorTest):
)
def search_by_full_text(self):
# milvus does not support full text searching yet in < 2.3.x
# milvus support BM25 full text search after version 2.5.0-beta
hits_by_full_text = self.vector.search_by_full_text(query=get_example_text())
assert len(hits_by_full_text) == 0
assert len(hits_by_full_text) >= 0
def get_ids_by_metadata_field(self):
ids = self.vector.get_ids_by_metadata_field(key="document_id", value=self.example_doc_id)

View File

@ -2,7 +2,7 @@ version: '3'
services:
# API service
api:
image: langgenius/dify-api:0.14.2
image: langgenius/dify-api:0.15.0
restart: always
environment:
# Startup mode, 'api' starts the API server.
@ -227,7 +227,7 @@ services:
# worker service
# The Celery worker for processing the queue.
worker:
image: langgenius/dify-api:0.14.2
image: langgenius/dify-api:0.15.0
restart: always
environment:
CONSOLE_WEB_URL: ''
@ -397,7 +397,7 @@ services:
# Frontend web application.
web:
image: langgenius/dify-web:0.14.2
image: langgenius/dify-web:0.15.0
restart: always
environment:
# The base URL of console application api server, refers to the Console base URL of WEB service if console domain is

Some files were not shown because too many files have changed in this diff.