Merge branch 'fix/chore-fix' into dev/plugin-deploy

This commit is contained in:
Novice Lee 2025-01-09 08:48:00 +08:00
commit 00ad751a57
231 changed files with 3564 additions and 2382 deletions

View File

@ -23,6 +23,9 @@ FILES_ACCESS_TIMEOUT=300
# Access token expiration time in minutes # Access token expiration time in minutes
ACCESS_TOKEN_EXPIRE_MINUTES=60 ACCESS_TOKEN_EXPIRE_MINUTES=60
# Refresh token expiration time in days
REFRESH_TOKEN_EXPIRE_DAYS=30
# celery configuration # celery configuration
CELERY_BROKER_URL=redis://:difyai123456@localhost:6379/1 CELERY_BROKER_URL=redis://:difyai123456@localhost:6379/1

View File

@ -14,7 +14,10 @@ if is_db_command():
app = create_migrations_app() app = create_migrations_app()
else: else:
if os.environ.get("FLASK_DEBUG", "False") != "True": # It seems that JetBrains Python debugger does not work well with gevent,
# so we need to disable gevent in debug mode.
# If you are using debugpy and set GEVENT_SUPPORT=True, you can debug with gevent.
if (flask_debug := os.environ.get("FLASK_DEBUG", "0")) and flask_debug.lower() in {"false", "0", "no"}:
from gevent import monkey # type: ignore from gevent import monkey # type: ignore
# gevent # gevent

View File

@ -546,6 +546,11 @@ class AuthConfig(BaseSettings):
default=60, default=60,
) )
REFRESH_TOKEN_EXPIRE_DAYS: PositiveFloat = Field(
description="Expiration time for refresh tokens in days",
default=30,
)
LOGIN_LOCKOUT_DURATION: PositiveInt = Field( LOGIN_LOCKOUT_DURATION: PositiveInt = Field(
description="Time (in seconds) a user must wait before retrying login after exceeding the rate limit.", description="Time (in seconds) a user must wait before retrying login after exceeding the rate limit.",
default=86400, default=86400,
@ -725,6 +730,11 @@ class IndexingConfig(BaseSettings):
default=4000, default=4000,
) )
CHILD_CHUNKS_PREVIEW_NUMBER: PositiveInt = Field(
description="Maximum number of child chunks to preview",
default=50,
)
class MultiModalTransferConfig(BaseSettings): class MultiModalTransferConfig(BaseSettings):
MULTIMODAL_SEND_FORMAT: Literal["base64", "url"] = Field( MULTIMODAL_SEND_FORMAT: Literal["base64", "url"] = Field(

View File

@ -33,3 +33,9 @@ class MilvusConfig(BaseSettings):
description="Name of the Milvus database to connect to (default is 'default')", description="Name of the Milvus database to connect to (default is 'default')",
default="default", default="default",
) )
MILVUS_ENABLE_HYBRID_SEARCH: bool = Field(
description="Enable hybrid search features (requires Milvus >= 2.5.0). Set to false for compatibility with "
"older versions",
default=True,
)

View File

@ -9,7 +9,7 @@ class PackagingInfo(BaseSettings):
CURRENT_VERSION: str = Field( CURRENT_VERSION: str = Field(
description="Dify version", description="Dify version",
default="0.14.2", default="0.15.0",
) )
COMMIT_SHA: str = Field( COMMIT_SHA: str = Field(

View File

@ -57,12 +57,13 @@ class AppListApi(Resource):
) )
parser.add_argument("name", type=str, location="args", required=False) parser.add_argument("name", type=str, location="args", required=False)
parser.add_argument("tag_ids", type=uuid_list, location="args", required=False) parser.add_argument("tag_ids", type=uuid_list, location="args", required=False)
parser.add_argument("is_created_by_me", type=inputs.boolean, location="args", required=False)
args = parser.parse_args() args = parser.parse_args()
# get app list # get app list
app_service = AppService() app_service = AppService()
app_pagination = app_service.get_paginate_apps(current_user.current_tenant_id, args) app_pagination = app_service.get_paginate_apps(current_user.id, current_user.current_tenant_id, args)
if not app_pagination: if not app_pagination:
return {"data": [], "total": 0, "page": 1, "limit": 20, "has_more": False} return {"data": [], "total": 0, "page": 1, "limit": 20, "has_more": False}

View File

@ -20,7 +20,6 @@ from controllers.web.error import InvokeRateLimitError as InvokeRateLimitHttpErr
from core.app.apps.base_app_queue_manager import AppQueueManager from core.app.apps.base_app_queue_manager import AppQueueManager
from core.app.entities.app_invoke_entities import InvokeFrom from core.app.entities.app_invoke_entities import InvokeFrom
from core.errors.error import ( from core.errors.error import (
AppInvokeQuotaExceededError,
ModelCurrentlyNotSupportError, ModelCurrentlyNotSupportError,
ProviderTokenNotInitError, ProviderTokenNotInitError,
QuotaExceededError, QuotaExceededError,
@ -76,7 +75,7 @@ class CompletionMessageApi(Resource):
raise ProviderModelCurrentlyNotSupportError() raise ProviderModelCurrentlyNotSupportError()
except InvokeError as e: except InvokeError as e:
raise CompletionRequestError(e.description) raise CompletionRequestError(e.description)
except (ValueError, AppInvokeQuotaExceededError) as e: except ValueError as e:
raise e raise e
except Exception as e: except Exception as e:
logging.exception("internal server error.") logging.exception("internal server error.")
@ -141,7 +140,7 @@ class ChatMessageApi(Resource):
raise InvokeRateLimitHttpError(ex.description) raise InvokeRateLimitHttpError(ex.description)
except InvokeError as e: except InvokeError as e:
raise CompletionRequestError(e.description) raise CompletionRequestError(e.description)
except (ValueError, AppInvokeQuotaExceededError) as e: except ValueError as e:
raise e raise e
except Exception as e: except Exception as e:
logging.exception("internal server error.") logging.exception("internal server error.")

View File

@ -273,8 +273,7 @@ FROM
messages m messages m
ON c.id = m.conversation_id ON c.id = m.conversation_id
WHERE WHERE
c.override_model_configs IS NULL c.app_id = :app_id"""
AND c.app_id = :app_id"""
arg_dict = {"tz": account.timezone, "app_id": app_model.id} arg_dict = {"tz": account.timezone, "app_id": app_model.id}
timezone = pytz.timezone(account.timezone) timezone = pytz.timezone(account.timezone)

View File

@ -2,7 +2,7 @@ import base64
import secrets import secrets
from flask import request from flask import request
from flask_restful import Resource, reqparse from flask_restful import Resource, reqparse # type: ignore
from sqlalchemy import select from sqlalchemy import select
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
@ -129,7 +129,7 @@ class ForgotPasswordResetApi(Resource):
) )
except WorkSpaceNotAllowedCreateError: except WorkSpaceNotAllowedCreateError:
pass pass
except AccountRegisterError as are: except AccountRegisterError:
raise AccountInFreezeError() raise AccountInFreezeError()
return {"result": "success"} return {"result": "success"}

View File

@ -4,7 +4,7 @@ from typing import Optional
import requests import requests
from flask import current_app, redirect, request from flask import current_app, redirect, request
from flask_restful import Resource from flask_restful import Resource # type: ignore
from sqlalchemy import select from sqlalchemy import select
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from werkzeug.exceptions import Unauthorized from werkzeug.exceptions import Unauthorized

View File

@ -2,8 +2,8 @@ import datetime
import json import json
from flask import request from flask import request
from flask_login import current_user from flask_login import current_user # type: ignore
from flask_restful import Resource, marshal_with, reqparse from flask_restful import Resource, marshal_with, reqparse # type: ignore
from sqlalchemy import select from sqlalchemy import select
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from werkzeug.exceptions import NotFound from werkzeug.exceptions import NotFound

View File

@ -640,6 +640,7 @@ class DatasetRetrievalSettingApi(Resource):
| VectorType.MYSCALE | VectorType.MYSCALE
| VectorType.ORACLE | VectorType.ORACLE
| VectorType.ELASTICSEARCH | VectorType.ELASTICSEARCH
| VectorType.ELASTICSEARCH_JA
| VectorType.PGVECTOR | VectorType.PGVECTOR
| VectorType.TIDB_ON_QDRANT | VectorType.TIDB_ON_QDRANT
| VectorType.LINDORM | VectorType.LINDORM
@ -683,6 +684,7 @@ class DatasetRetrievalSettingMockApi(Resource):
| VectorType.MYSCALE | VectorType.MYSCALE
| VectorType.ORACLE | VectorType.ORACLE
| VectorType.ELASTICSEARCH | VectorType.ELASTICSEARCH
| VectorType.ELASTICSEARCH_JA
| VectorType.COUCHBASE | VectorType.COUCHBASE
| VectorType.PGVECTOR | VectorType.PGVECTOR
| VectorType.LINDORM | VectorType.LINDORM

View File

@ -269,7 +269,8 @@ class DatasetDocumentListApi(Resource):
parser.add_argument("original_document_id", type=str, required=False, location="json") parser.add_argument("original_document_id", type=str, required=False, location="json")
parser.add_argument("doc_form", type=str, default="text_model", required=False, nullable=False, location="json") parser.add_argument("doc_form", type=str, default="text_model", required=False, nullable=False, location="json")
parser.add_argument("retrieval_model", type=dict, required=False, nullable=False, location="json") parser.add_argument("retrieval_model", type=dict, required=False, nullable=False, location="json")
parser.add_argument("embedding_model", type=str, required=False, nullable=True, location="json")
parser.add_argument("embedding_model_provider", type=str, required=False, nullable=True, location="json")
parser.add_argument( parser.add_argument(
"doc_language", type=str, default="English", required=False, nullable=False, location="json" "doc_language", type=str, default="English", required=False, nullable=False, location="json"
) )

View File

@ -18,7 +18,11 @@ from controllers.console.explore.error import NotChatAppError, NotCompletionAppE
from controllers.console.explore.wraps import InstalledAppResource from controllers.console.explore.wraps import InstalledAppResource
from core.app.apps.base_app_queue_manager import AppQueueManager from core.app.apps.base_app_queue_manager import AppQueueManager
from core.app.entities.app_invoke_entities import InvokeFrom from core.app.entities.app_invoke_entities import InvokeFrom
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError from core.errors.error import (
ModelCurrentlyNotSupportError,
ProviderTokenNotInitError,
QuotaExceededError,
)
from core.model_runtime.errors.invoke import InvokeError from core.model_runtime.errors.invoke import InvokeError
from extensions.ext_database import db from extensions.ext_database import db
from libs import helper from libs import helper

View File

@ -13,7 +13,11 @@ from controllers.console.explore.error import NotWorkflowAppError
from controllers.console.explore.wraps import InstalledAppResource from controllers.console.explore.wraps import InstalledAppResource
from core.app.apps.base_app_queue_manager import AppQueueManager from core.app.apps.base_app_queue_manager import AppQueueManager
from core.app.entities.app_invoke_entities import InvokeFrom from core.app.entities.app_invoke_entities import InvokeFrom
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError from core.errors.error import (
ModelCurrentlyNotSupportError,
ProviderTokenNotInitError,
QuotaExceededError,
)
from core.model_runtime.errors.invoke import InvokeError from core.model_runtime.errors.invoke import InvokeError
from libs import helper from libs import helper
from libs.login import current_user from libs.login import current_user

View File

@ -1,7 +1,7 @@
import os import os
from flask import session from flask import session
from flask_restful import Resource, reqparse from flask_restful import Resource, reqparse # type: ignore
from sqlalchemy import select from sqlalchemy import select
from sqlalchemy.orm import Session from sqlalchemy.orm import Session

View File

@ -1,6 +1,6 @@
from functools import wraps from functools import wraps
from flask_login import current_user from flask_login import current_user # type: ignore
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from werkzeug.exceptions import Forbidden from werkzeug.exceptions import Forbidden

View File

@ -1,5 +1,5 @@
from flask_login import current_user from flask_login import current_user # type: ignore
from flask_restful import Resource from flask_restful import Resource # type: ignore
from controllers.console import api from controllers.console import api
from controllers.console.wraps import account_initialization_required, setup_required from controllers.console.wraps import account_initialization_required, setup_required

View File

@ -1,5 +1,5 @@
from flask_login import current_user from flask_login import current_user # type: ignore
from flask_restful import Resource, reqparse from flask_restful import Resource, reqparse # type: ignore
from werkzeug.exceptions import Forbidden from werkzeug.exceptions import Forbidden
from controllers.console import api from controllers.console import api

View File

@ -1,8 +1,8 @@
import io import io
from flask import request, send_file from flask import request, send_file
from flask_login import current_user from flask_login import current_user # type: ignore
from flask_restful import Resource, reqparse from flask_restful import Resource, reqparse # type: ignore
from werkzeug.exceptions import Forbidden from werkzeug.exceptions import Forbidden
from configs import dify_config from configs import dify_config

View File

@ -1,5 +1,5 @@
from flask import request from flask import request
from flask_restful import Resource, marshal_with from flask_restful import Resource, marshal_with # type: ignore
from werkzeug.exceptions import Forbidden from werkzeug.exceptions import Forbidden
import services import services

View File

@ -1,4 +1,4 @@
from flask_restful import Resource from flask_restful import Resource # type: ignore
from controllers.console.wraps import setup_required from controllers.console.wraps import setup_required
from controllers.inner_api import api from controllers.inner_api import api

View File

@ -3,7 +3,7 @@ from functools import wraps
from typing import Optional from typing import Optional
from flask import request from flask import request
from flask_restful import reqparse from flask_restful import reqparse # type: ignore
from pydantic import BaseModel from pydantic import BaseModel
from sqlalchemy.orm import Session from sqlalchemy.orm import Session

View File

@ -18,7 +18,6 @@ from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate
from core.app.apps.base_app_queue_manager import AppQueueManager from core.app.apps.base_app_queue_manager import AppQueueManager
from core.app.entities.app_invoke_entities import InvokeFrom from core.app.entities.app_invoke_entities import InvokeFrom
from core.errors.error import ( from core.errors.error import (
AppInvokeQuotaExceededError,
ModelCurrentlyNotSupportError, ModelCurrentlyNotSupportError,
ProviderTokenNotInitError, ProviderTokenNotInitError,
QuotaExceededError, QuotaExceededError,
@ -74,7 +73,7 @@ class CompletionApi(Resource):
raise ProviderModelCurrentlyNotSupportError() raise ProviderModelCurrentlyNotSupportError()
except InvokeError as e: except InvokeError as e:
raise CompletionRequestError(e.description) raise CompletionRequestError(e.description)
except (ValueError, AppInvokeQuotaExceededError) as e: except ValueError as e:
raise e raise e
except Exception as e: except Exception as e:
logging.exception("internal server error.") logging.exception("internal server error.")
@ -133,7 +132,7 @@ class ChatApi(Resource):
raise ProviderModelCurrentlyNotSupportError() raise ProviderModelCurrentlyNotSupportError()
except InvokeError as e: except InvokeError as e:
raise CompletionRequestError(e.description) raise CompletionRequestError(e.description)
except (ValueError, AppInvokeQuotaExceededError) as e: except ValueError as e:
raise e raise e
except Exception as e: except Exception as e:
logging.exception("internal server error.") logging.exception("internal server error.")

View File

@ -16,7 +16,6 @@ from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate
from core.app.apps.base_app_queue_manager import AppQueueManager from core.app.apps.base_app_queue_manager import AppQueueManager
from core.app.entities.app_invoke_entities import InvokeFrom from core.app.entities.app_invoke_entities import InvokeFrom
from core.errors.error import ( from core.errors.error import (
AppInvokeQuotaExceededError,
ModelCurrentlyNotSupportError, ModelCurrentlyNotSupportError,
ProviderTokenNotInitError, ProviderTokenNotInitError,
QuotaExceededError, QuotaExceededError,
@ -94,7 +93,7 @@ class WorkflowRunApi(Resource):
raise ProviderModelCurrentlyNotSupportError() raise ProviderModelCurrentlyNotSupportError()
except InvokeError as e: except InvokeError as e:
raise CompletionRequestError(e.description) raise CompletionRequestError(e.description)
except (ValueError, AppInvokeQuotaExceededError) as e: except ValueError as e:
raise e raise e
except Exception as e: except Exception as e:
logging.exception("internal server error.") logging.exception("internal server error.")

View File

@ -190,7 +190,10 @@ class DocumentAddByFileApi(DatasetApiResource):
user=current_user, user=current_user,
source="datasets", source="datasets",
) )
data_source = {"type": "upload_file", "info_list": {"file_info_list": {"file_ids": [upload_file.id]}}} data_source = {
"type": "upload_file",
"info_list": {"data_source_type": "upload_file", "file_info_list": {"file_ids": [upload_file.id]}},
}
args["data_source"] = data_source args["data_source"] = data_source
# validate args # validate args
knowledge_config = KnowledgeConfig(**args) knowledge_config = KnowledgeConfig(**args)
@ -254,7 +257,10 @@ class DocumentUpdateByFileApi(DatasetApiResource):
raise FileTooLargeError(file_too_large_error.description) raise FileTooLargeError(file_too_large_error.description)
except services.errors.file.UnsupportedFileTypeError: except services.errors.file.UnsupportedFileTypeError:
raise UnsupportedFileTypeError() raise UnsupportedFileTypeError()
data_source = {"type": "upload_file", "info_list": {"file_info_list": {"file_ids": [upload_file.id]}}} data_source = {
"type": "upload_file",
"info_list": {"data_source_type": "upload_file", "file_info_list": {"file_ids": [upload_file.id]}},
}
args["data_source"] = data_source args["data_source"] = data_source
# validate args # validate args
args["original_document_id"] = str(document_id) args["original_document_id"] = str(document_id)

View File

@ -1,5 +1,5 @@
from collections.abc import Callable from collections.abc import Callable
from datetime import UTC, datetime from datetime import UTC, datetime, timedelta
from enum import Enum from enum import Enum
from functools import wraps from functools import wraps
from typing import Optional from typing import Optional
@ -8,6 +8,8 @@ from flask import current_app, request
from flask_login import user_logged_in # type: ignore from flask_login import user_logged_in # type: ignore
from flask_restful import Resource # type: ignore from flask_restful import Resource # type: ignore
from pydantic import BaseModel from pydantic import BaseModel
from sqlalchemy import select, update
from sqlalchemy.orm import Session
from werkzeug.exceptions import Forbidden, Unauthorized from werkzeug.exceptions import Forbidden, Unauthorized
from extensions.ext_database import db from extensions.ext_database import db
@ -174,7 +176,7 @@ def validate_dataset_token(view=None):
return decorator return decorator
def validate_and_get_api_token(scope=None): def validate_and_get_api_token(scope: str | None = None):
""" """
Validate and get API token. Validate and get API token.
""" """
@ -188,20 +190,25 @@ def validate_and_get_api_token(scope=None):
if auth_scheme != "bearer": if auth_scheme != "bearer":
raise Unauthorized("Authorization scheme must be 'Bearer'") raise Unauthorized("Authorization scheme must be 'Bearer'")
api_token = ( current_time = datetime.now(UTC).replace(tzinfo=None)
db.session.query(ApiToken) cutoff_time = current_time - timedelta(minutes=1)
.filter( with Session(db.engine, expire_on_commit=False) as session:
ApiToken.token == auth_token, update_stmt = (
ApiToken.type == scope, update(ApiToken)
.where(ApiToken.token == auth_token, ApiToken.last_used_at < cutoff_time, ApiToken.type == scope)
.values(last_used_at=current_time)
.returning(ApiToken)
) )
.first() result = session.execute(update_stmt)
) api_token = result.scalar_one_or_none()
if not api_token: if not api_token:
raise Unauthorized("Access token is invalid") stmt = select(ApiToken).where(ApiToken.token == auth_token, ApiToken.type == scope)
api_token = session.scalar(stmt)
api_token.last_used_at = datetime.now(UTC).replace(tzinfo=None) if not api_token:
db.session.commit() raise Unauthorized("Access token is invalid")
else:
session.commit()
return api_token return api_token

View File

@ -19,7 +19,11 @@ from controllers.web.error import InvokeRateLimitError as InvokeRateLimitHttpErr
from controllers.web.wraps import WebApiResource from controllers.web.wraps import WebApiResource
from core.app.apps.base_app_queue_manager import AppQueueManager from core.app.apps.base_app_queue_manager import AppQueueManager
from core.app.entities.app_invoke_entities import InvokeFrom from core.app.entities.app_invoke_entities import InvokeFrom
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError from core.errors.error import (
ModelCurrentlyNotSupportError,
ProviderTokenNotInitError,
QuotaExceededError,
)
from core.model_runtime.errors.invoke import InvokeError from core.model_runtime.errors.invoke import InvokeError
from libs import helper from libs import helper
from libs.helper import uuid_value from libs.helper import uuid_value

View File

@ -14,7 +14,11 @@ from controllers.web.error import (
from controllers.web.wraps import WebApiResource from controllers.web.wraps import WebApiResource
from core.app.apps.base_app_queue_manager import AppQueueManager from core.app.apps.base_app_queue_manager import AppQueueManager
from core.app.entities.app_invoke_entities import InvokeFrom from core.app.entities.app_invoke_entities import InvokeFrom
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError from core.errors.error import (
ModelCurrentlyNotSupportError,
ProviderTokenNotInitError,
QuotaExceededError,
)
from core.model_runtime.errors.invoke import InvokeError from core.model_runtime.errors.invoke import InvokeError
from libs import helper from libs import helper
from models.model import App, AppMode, EndUser from models.model import App, AppMode, EndUser

View File

@ -119,7 +119,7 @@ class CotAgentRunner(BaseAgentRunner, ABC):
callbacks=[], callbacks=[],
) )
usage_dict = {} usage_dict: dict[str, Optional[LLMUsage]] = {}
react_chunks = CotAgentOutputParser.handle_react_stream_output(chunks, usage_dict) react_chunks = CotAgentOutputParser.handle_react_stream_output(chunks, usage_dict)
scratchpad = AgentScratchpadUnit( scratchpad = AgentScratchpadUnit(
agent_response="", agent_response="",

View File

@ -21,7 +21,7 @@ from core.app.apps.message_based_app_generator import MessageBasedAppGenerator
from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager
from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom
from core.app.entities.task_entities import ChatbotAppBlockingResponse, ChatbotAppStreamResponse from core.app.entities.task_entities import ChatbotAppBlockingResponse, ChatbotAppStreamResponse
from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError from core.model_runtime.errors.invoke import InvokeAuthorizationError
from core.ops.ops_trace_manager import TraceQueueManager from core.ops.ops_trace_manager import TraceQueueManager
from core.prompt.utils.get_thread_messages_length import get_thread_messages_length from core.prompt.utils.get_thread_messages_length import get_thread_messages_length
from extensions.ext_database import db from extensions.ext_database import db
@ -346,7 +346,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
except ValidationError as e: except ValidationError as e:
logger.exception("Validation Error when generating") logger.exception("Validation Error when generating")
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER) queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
except (ValueError, InvokeError) as e: except ValueError as e:
if dify_config.DEBUG: if dify_config.DEBUG:
logger.exception("Error when generating") logger.exception("Error when generating")
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER) queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)

View File

@ -77,7 +77,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
graph, variable_pool = self._get_graph_and_variable_pool_of_single_iteration( graph, variable_pool = self._get_graph_and_variable_pool_of_single_iteration(
workflow=workflow, workflow=workflow,
node_id=self.application_generate_entity.single_iteration_run.node_id, node_id=self.application_generate_entity.single_iteration_run.node_id,
user_inputs=self.application_generate_entity.single_iteration_run.inputs, user_inputs=dict(self.application_generate_entity.single_iteration_run.inputs),
) )
else: else:
inputs = self.application_generate_entity.inputs inputs = self.application_generate_entity.inputs

View File

@ -68,24 +68,17 @@ from models.account import Account
from models.enums import CreatedByRole from models.enums import CreatedByRole
from models.workflow import ( from models.workflow import (
Workflow, Workflow,
WorkflowNodeExecution,
WorkflowRunStatus, WorkflowRunStatus,
) )
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleManage, MessageCycleManage): class AdvancedChatAppGenerateTaskPipeline:
""" """
AdvancedChatAppGenerateTaskPipeline is a class that generate stream output and state management for Application. AdvancedChatAppGenerateTaskPipeline is a class that generate stream output and state management for Application.
""" """
_task_state: WorkflowTaskState
_application_generate_entity: AdvancedChatAppGenerateEntity
_workflow_system_variables: dict[SystemVariableKey, Any]
_wip_workflow_node_executions: dict[str, WorkflowNodeExecution]
_conversation_name_generate_thread: Optional[Thread] = None
def __init__( def __init__(
self, self,
application_generate_entity: AdvancedChatAppGenerateEntity, application_generate_entity: AdvancedChatAppGenerateEntity,
@ -97,7 +90,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
stream: bool, stream: bool,
dialogue_count: int, dialogue_count: int,
) -> None: ) -> None:
super().__init__( self._base_task_pipeline = BasedGenerateTaskPipeline(
application_generate_entity=application_generate_entity, application_generate_entity=application_generate_entity,
queue_manager=queue_manager, queue_manager=queue_manager,
stream=stream, stream=stream,
@ -114,33 +107,35 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
else: else:
raise NotImplementedError(f"User type not supported: {type(user)}") raise NotImplementedError(f"User type not supported: {type(user)}")
self._workflow_id = workflow.id self._workflow_cycle_manager = WorkflowCycleManage(
self._workflow_features_dict = workflow.features_dict application_generate_entity=application_generate_entity,
workflow_system_variables={
self._conversation_id = conversation.id SystemVariableKey.QUERY: message.query,
self._conversation_mode = conversation.mode SystemVariableKey.FILES: application_generate_entity.files,
SystemVariableKey.CONVERSATION_ID: conversation.id,
self._message_id = message.id SystemVariableKey.USER_ID: user_session_id,
self._message_created_at = int(message.created_at.timestamp()) SystemVariableKey.DIALOGUE_COUNT: dialogue_count,
SystemVariableKey.APP_ID: application_generate_entity.app_config.app_id,
self._workflow_system_variables = { SystemVariableKey.WORKFLOW_ID: workflow.id,
SystemVariableKey.QUERY: message.query, SystemVariableKey.WORKFLOW_RUN_ID: application_generate_entity.workflow_run_id,
SystemVariableKey.FILES: application_generate_entity.files, },
SystemVariableKey.CONVERSATION_ID: conversation.id, )
SystemVariableKey.USER_ID: user_session_id,
SystemVariableKey.DIALOGUE_COUNT: dialogue_count,
SystemVariableKey.APP_ID: application_generate_entity.app_config.app_id,
SystemVariableKey.WORKFLOW_ID: workflow.id,
SystemVariableKey.WORKFLOW_RUN_ID: application_generate_entity.workflow_run_id,
}
self._task_state = WorkflowTaskState() self._task_state = WorkflowTaskState()
self._wip_workflow_node_executions = {} self._message_cycle_manager = MessageCycleManage(
self._wip_workflow_agent_logs = {} application_generate_entity=application_generate_entity, task_state=self._task_state
)
self._conversation_name_generate_thread = None self._application_generate_entity = application_generate_entity
self._workflow_id = workflow.id
self._workflow_features_dict = workflow.features_dict
self._conversation_id = conversation.id
self._conversation_mode = conversation.mode
self._message_id = message.id
self._message_created_at = int(message.created_at.timestamp())
self._conversation_name_generate_thread: Thread | None = None
self._recorded_files: list[Mapping[str, Any]] = [] self._recorded_files: list[Mapping[str, Any]] = []
self._workflow_run_id = "" self._workflow_run_id: str = ""
def process(self) -> Union[ChatbotAppBlockingResponse, Generator[ChatbotAppStreamResponse, None, None]]: def process(self) -> Union[ChatbotAppBlockingResponse, Generator[ChatbotAppStreamResponse, None, None]]:
""" """
@ -148,13 +143,13 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
:return: :return:
""" """
# start generate conversation name thread # start generate conversation name thread
self._conversation_name_generate_thread = self._generate_conversation_name( self._conversation_name_generate_thread = self._message_cycle_manager._generate_conversation_name(
conversation_id=self._conversation_id, query=self._application_generate_entity.query conversation_id=self._conversation_id, query=self._application_generate_entity.query
) )
generator = self._wrapper_process_stream_response(trace_manager=self._application_generate_entity.trace_manager) generator = self._wrapper_process_stream_response(trace_manager=self._application_generate_entity.trace_manager)
if self._stream: if self._base_task_pipeline._stream:
return self._to_stream_response(generator) return self._to_stream_response(generator)
else: else:
return self._to_blocking_response(generator) return self._to_blocking_response(generator)
@ -273,24 +268,26 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
# init fake graph runtime state # init fake graph runtime state
graph_runtime_state: Optional[GraphRuntimeState] = None graph_runtime_state: Optional[GraphRuntimeState] = None
for queue_message in self._queue_manager.listen(): for queue_message in self._base_task_pipeline._queue_manager.listen():
event = queue_message.event event = queue_message.event
if isinstance(event, QueuePingEvent): if isinstance(event, QueuePingEvent):
yield self._ping_stream_response() yield self._base_task_pipeline._ping_stream_response()
elif isinstance(event, QueueErrorEvent): elif isinstance(event, QueueErrorEvent):
with Session(db.engine) as session: with Session(db.engine, expire_on_commit=False) as session:
err = self._handle_error(event=event, session=session, message_id=self._message_id) err = self._base_task_pipeline._handle_error(
event=event, session=session, message_id=self._message_id
)
session.commit() session.commit()
yield self._error_to_stream_response(err) yield self._base_task_pipeline._error_to_stream_response(err)
break break
elif isinstance(event, QueueWorkflowStartedEvent): elif isinstance(event, QueueWorkflowStartedEvent):
# override graph runtime state # override graph runtime state
graph_runtime_state = event.graph_runtime_state graph_runtime_state = event.graph_runtime_state
with Session(db.engine) as session: with Session(db.engine, expire_on_commit=False) as session:
# init workflow run # init workflow run
workflow_run = self._handle_workflow_run_start( workflow_run = self._workflow_cycle_manager._handle_workflow_run_start(
session=session, session=session,
workflow_id=self._workflow_id, workflow_id=self._workflow_id,
user_id=self._user_id, user_id=self._user_id,
@ -301,7 +298,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
if not message: if not message:
raise ValueError(f"Message not found: {self._message_id}") raise ValueError(f"Message not found: {self._message_id}")
message.workflow_run_id = workflow_run.id message.workflow_run_id = workflow_run.id
workflow_start_resp = self._workflow_start_to_stream_response( workflow_start_resp = self._workflow_cycle_manager._workflow_start_to_stream_response(
session=session, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run session=session, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
) )
session.commit() session.commit()
@ -314,12 +311,14 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
if not self._workflow_run_id: if not self._workflow_run_id:
raise ValueError("workflow run not initialized.") raise ValueError("workflow run not initialized.")
with Session(db.engine) as session: with Session(db.engine, expire_on_commit=False) as session:
workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id) workflow_run = self._workflow_cycle_manager._get_workflow_run(
workflow_node_execution = self._handle_workflow_node_execution_retried( session=session, workflow_run_id=self._workflow_run_id
)
workflow_node_execution = self._workflow_cycle_manager._handle_workflow_node_execution_retried(
session=session, workflow_run=workflow_run, event=event session=session, workflow_run=workflow_run, event=event
) )
node_retry_resp = self._workflow_node_retry_to_stream_response( node_retry_resp = self._workflow_cycle_manager._workflow_node_retry_to_stream_response(
session=session, session=session,
event=event, event=event,
task_id=self._application_generate_entity.task_id, task_id=self._application_generate_entity.task_id,
@ -333,13 +332,15 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
if not self._workflow_run_id: if not self._workflow_run_id:
raise ValueError("workflow run not initialized.") raise ValueError("workflow run not initialized.")
with Session(db.engine) as session: with Session(db.engine, expire_on_commit=False) as session:
workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id) workflow_run = self._workflow_cycle_manager._get_workflow_run(
workflow_node_execution = self._handle_node_execution_start( session=session, workflow_run_id=self._workflow_run_id
)
workflow_node_execution = self._workflow_cycle_manager._handle_node_execution_start(
session=session, workflow_run=workflow_run, event=event session=session, workflow_run=workflow_run, event=event
) )
node_start_resp = self._workflow_node_start_to_stream_response( node_start_resp = self._workflow_cycle_manager._workflow_node_start_to_stream_response(
session=session, session=session,
event=event, event=event,
task_id=self._application_generate_entity.task_id, task_id=self._application_generate_entity.task_id,
@ -352,12 +353,16 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
elif isinstance(event, QueueNodeSucceededEvent): elif isinstance(event, QueueNodeSucceededEvent):
# Record files if it's an answer node or end node # Record files if it's an answer node or end node
if event.node_type in [NodeType.ANSWER, NodeType.END]: if event.node_type in [NodeType.ANSWER, NodeType.END]:
self._recorded_files.extend(self._fetch_files_from_node_outputs(event.outputs or {})) self._recorded_files.extend(
self._workflow_cycle_manager._fetch_files_from_node_outputs(event.outputs or {})
)
with Session(db.engine) as session: with Session(db.engine, expire_on_commit=False) as session:
workflow_node_execution = self._handle_workflow_node_execution_success(session=session, event=event) workflow_node_execution = self._workflow_cycle_manager._handle_workflow_node_execution_success(
session=session, event=event
)
node_finish_resp = self._workflow_node_finish_to_stream_response( node_finish_resp = self._workflow_cycle_manager._workflow_node_finish_to_stream_response(
session=session, session=session,
event=event, event=event,
task_id=self._application_generate_entity.task_id, task_id=self._application_generate_entity.task_id,
@ -368,32 +373,35 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
if node_finish_resp: if node_finish_resp:
yield node_finish_resp yield node_finish_resp
elif isinstance(event, QueueNodeFailedEvent | QueueNodeInIterationFailedEvent | QueueNodeExceptionEvent): elif isinstance(event, QueueNodeFailedEvent | QueueNodeInIterationFailedEvent | QueueNodeExceptionEvent):
with Session(db.engine) as session: with Session(db.engine, expire_on_commit=False) as session:
workflow_node_execution = self._handle_workflow_node_execution_failed(session=session, event=event) workflow_node_execution = self._workflow_cycle_manager._handle_workflow_node_execution_failed(
session=session, event=event
)
response_finish = self._workflow_node_finish_to_stream_response( node_finish_resp = self._workflow_cycle_manager._workflow_node_finish_to_stream_response(
session=session, session=session,
event=event, event=event,
task_id=self._application_generate_entity.task_id, task_id=self._application_generate_entity.task_id,
workflow_node_execution=workflow_node_execution, workflow_node_execution=workflow_node_execution,
) )
if response_finish:
yield response_finish
if node_finish_resp: if node_finish_resp:
yield node_finish_resp yield node_finish_resp
elif isinstance(event, QueueParallelBranchRunStartedEvent): elif isinstance(event, QueueParallelBranchRunStartedEvent):
if not self._workflow_run_id: if not self._workflow_run_id:
raise ValueError("workflow run not initialized.") raise ValueError("workflow run not initialized.")
with Session(db.engine) as session: with Session(db.engine, expire_on_commit=False) as session:
workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id) workflow_run = self._workflow_cycle_manager._get_workflow_run(
parallel_start_resp = self._workflow_parallel_branch_start_to_stream_response( session=session, workflow_run_id=self._workflow_run_id
session=session, )
task_id=self._application_generate_entity.task_id, parallel_start_resp = (
workflow_run=workflow_run, self._workflow_cycle_manager._workflow_parallel_branch_start_to_stream_response(
event=event, session=session,
task_id=self._application_generate_entity.task_id,
workflow_run=workflow_run,
event=event,
)
) )
yield parallel_start_resp yield parallel_start_resp
@ -401,13 +409,17 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
if not self._workflow_run_id: if not self._workflow_run_id:
raise ValueError("workflow run not initialized.") raise ValueError("workflow run not initialized.")
with Session(db.engine) as session: with Session(db.engine, expire_on_commit=False) as session:
workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id) workflow_run = self._workflow_cycle_manager._get_workflow_run(
parallel_finish_resp = self._workflow_parallel_branch_finished_to_stream_response( session=session, workflow_run_id=self._workflow_run_id
session=session, )
task_id=self._application_generate_entity.task_id, parallel_finish_resp = (
workflow_run=workflow_run, self._workflow_cycle_manager._workflow_parallel_branch_finished_to_stream_response(
event=event, session=session,
task_id=self._application_generate_entity.task_id,
workflow_run=workflow_run,
event=event,
)
) )
yield parallel_finish_resp yield parallel_finish_resp
@ -415,9 +427,11 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
if not self._workflow_run_id: if not self._workflow_run_id:
raise ValueError("workflow run not initialized.") raise ValueError("workflow run not initialized.")
with Session(db.engine) as session: with Session(db.engine, expire_on_commit=False) as session:
workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id) workflow_run = self._workflow_cycle_manager._get_workflow_run(
iter_start_resp = self._workflow_iteration_start_to_stream_response( session=session, workflow_run_id=self._workflow_run_id
)
iter_start_resp = self._workflow_cycle_manager._workflow_iteration_start_to_stream_response(
session=session, session=session,
task_id=self._application_generate_entity.task_id, task_id=self._application_generate_entity.task_id,
workflow_run=workflow_run, workflow_run=workflow_run,
@ -429,9 +443,11 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
if not self._workflow_run_id: if not self._workflow_run_id:
raise ValueError("workflow run not initialized.") raise ValueError("workflow run not initialized.")
with Session(db.engine) as session: with Session(db.engine, expire_on_commit=False) as session:
workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id) workflow_run = self._workflow_cycle_manager._get_workflow_run(
iter_next_resp = self._workflow_iteration_next_to_stream_response( session=session, workflow_run_id=self._workflow_run_id
)
iter_next_resp = self._workflow_cycle_manager._workflow_iteration_next_to_stream_response(
session=session, session=session,
task_id=self._application_generate_entity.task_id, task_id=self._application_generate_entity.task_id,
workflow_run=workflow_run, workflow_run=workflow_run,
@ -443,9 +459,11 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
if not self._workflow_run_id: if not self._workflow_run_id:
raise ValueError("workflow run not initialized.") raise ValueError("workflow run not initialized.")
with Session(db.engine) as session: with Session(db.engine, expire_on_commit=False) as session:
workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id) workflow_run = self._workflow_cycle_manager._get_workflow_run(
iter_finish_resp = self._workflow_iteration_completed_to_stream_response( session=session, workflow_run_id=self._workflow_run_id
)
iter_finish_resp = self._workflow_cycle_manager._workflow_iteration_completed_to_stream_response(
session=session, session=session,
task_id=self._application_generate_entity.task_id, task_id=self._application_generate_entity.task_id,
workflow_run=workflow_run, workflow_run=workflow_run,
@ -460,8 +478,8 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
if not graph_runtime_state: if not graph_runtime_state:
raise ValueError("workflow run not initialized.") raise ValueError("workflow run not initialized.")
with Session(db.engine) as session: with Session(db.engine, expire_on_commit=False) as session:
workflow_run = self._handle_workflow_run_success( workflow_run = self._workflow_cycle_manager._handle_workflow_run_success(
session=session, session=session,
workflow_run_id=self._workflow_run_id, workflow_run_id=self._workflow_run_id,
start_at=graph_runtime_state.start_at, start_at=graph_runtime_state.start_at,
@ -472,21 +490,23 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
trace_manager=trace_manager, trace_manager=trace_manager,
) )
workflow_finish_resp = self._workflow_finish_to_stream_response( workflow_finish_resp = self._workflow_cycle_manager._workflow_finish_to_stream_response(
session=session, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run session=session, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
) )
session.commit() session.commit()
yield workflow_finish_resp yield workflow_finish_resp
self._queue_manager.publish(QueueAdvancedChatMessageEndEvent(), PublishFrom.TASK_PIPELINE) self._base_task_pipeline._queue_manager.publish(
QueueAdvancedChatMessageEndEvent(), PublishFrom.TASK_PIPELINE
)
elif isinstance(event, QueueWorkflowPartialSuccessEvent): elif isinstance(event, QueueWorkflowPartialSuccessEvent):
if not self._workflow_run_id: if not self._workflow_run_id:
raise ValueError("workflow run not initialized.") raise ValueError("workflow run not initialized.")
if not graph_runtime_state: if not graph_runtime_state:
raise ValueError("graph runtime state not initialized.") raise ValueError("graph runtime state not initialized.")
with Session(db.engine) as session: with Session(db.engine, expire_on_commit=False) as session:
workflow_run = self._handle_workflow_run_partial_success( workflow_run = self._workflow_cycle_manager._handle_workflow_run_partial_success(
session=session, session=session,
workflow_run_id=self._workflow_run_id, workflow_run_id=self._workflow_run_id,
start_at=graph_runtime_state.start_at, start_at=graph_runtime_state.start_at,
@ -497,21 +517,23 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
conversation_id=None, conversation_id=None,
trace_manager=trace_manager, trace_manager=trace_manager,
) )
workflow_finish_resp = self._workflow_finish_to_stream_response( workflow_finish_resp = self._workflow_cycle_manager._workflow_finish_to_stream_response(
session=session, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run session=session, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
) )
session.commit() session.commit()
yield workflow_finish_resp yield workflow_finish_resp
self._queue_manager.publish(QueueAdvancedChatMessageEndEvent(), PublishFrom.TASK_PIPELINE) self._base_task_pipeline._queue_manager.publish(
QueueAdvancedChatMessageEndEvent(), PublishFrom.TASK_PIPELINE
)
elif isinstance(event, QueueWorkflowFailedEvent): elif isinstance(event, QueueWorkflowFailedEvent):
if not self._workflow_run_id: if not self._workflow_run_id:
raise ValueError("workflow run not initialized.") raise ValueError("workflow run not initialized.")
if not graph_runtime_state: if not graph_runtime_state:
raise ValueError("graph runtime state not initialized.") raise ValueError("graph runtime state not initialized.")
with Session(db.engine) as session: with Session(db.engine, expire_on_commit=False) as session:
workflow_run = self._handle_workflow_run_failed( workflow_run = self._workflow_cycle_manager._handle_workflow_run_failed(
session=session, session=session,
workflow_run_id=self._workflow_run_id, workflow_run_id=self._workflow_run_id,
start_at=graph_runtime_state.start_at, start_at=graph_runtime_state.start_at,
@ -523,20 +545,22 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
trace_manager=trace_manager, trace_manager=trace_manager,
exceptions_count=event.exceptions_count, exceptions_count=event.exceptions_count,
) )
workflow_finish_resp = self._workflow_finish_to_stream_response( workflow_finish_resp = self._workflow_cycle_manager._workflow_finish_to_stream_response(
session=session, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run session=session, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
) )
err_event = QueueErrorEvent(error=ValueError(f"Run failed: {workflow_run.error}")) err_event = QueueErrorEvent(error=ValueError(f"Run failed: {workflow_run.error}"))
err = self._handle_error(event=err_event, session=session, message_id=self._message_id) err = self._base_task_pipeline._handle_error(
event=err_event, session=session, message_id=self._message_id
)
session.commit() session.commit()
yield workflow_finish_resp yield workflow_finish_resp
yield self._error_to_stream_response(err) yield self._base_task_pipeline._error_to_stream_response(err)
break break
elif isinstance(event, QueueStopEvent): elif isinstance(event, QueueStopEvent):
if self._workflow_run_id and graph_runtime_state: if self._workflow_run_id and graph_runtime_state:
with Session(db.engine) as session: with Session(db.engine, expire_on_commit=False) as session:
workflow_run = self._handle_workflow_run_failed( workflow_run = self._workflow_cycle_manager._handle_workflow_run_failed(
session=session, session=session,
workflow_run_id=self._workflow_run_id, workflow_run_id=self._workflow_run_id,
start_at=graph_runtime_state.start_at, start_at=graph_runtime_state.start_at,
@ -547,7 +571,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
conversation_id=self._conversation_id, conversation_id=self._conversation_id,
trace_manager=trace_manager, trace_manager=trace_manager,
) )
workflow_finish_resp = self._workflow_finish_to_stream_response( workflow_finish_resp = self._workflow_cycle_manager._workflow_finish_to_stream_response(
session=session, session=session,
task_id=self._application_generate_entity.task_id, task_id=self._application_generate_entity.task_id,
workflow_run=workflow_run, workflow_run=workflow_run,
@ -561,18 +585,18 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
yield self._message_end_to_stream_response() yield self._message_end_to_stream_response()
break break
elif isinstance(event, QueueRetrieverResourcesEvent): elif isinstance(event, QueueRetrieverResourcesEvent):
self._handle_retriever_resources(event) self._message_cycle_manager._handle_retriever_resources(event)
with Session(db.engine) as session: with Session(db.engine, expire_on_commit=False) as session:
message = self._get_message(session=session) message = self._get_message(session=session)
message.message_metadata = ( message.message_metadata = (
json.dumps(jsonable_encoder(self._task_state.metadata)) if self._task_state.metadata else None json.dumps(jsonable_encoder(self._task_state.metadata)) if self._task_state.metadata else None
) )
session.commit() session.commit()
elif isinstance(event, QueueAnnotationReplyEvent): elif isinstance(event, QueueAnnotationReplyEvent):
self._handle_annotation_reply(event) self._message_cycle_manager._handle_annotation_reply(event)
with Session(db.engine) as session: with Session(db.engine, expire_on_commit=False) as session:
message = self._get_message(session=session) message = self._get_message(session=session)
message.message_metadata = ( message.message_metadata = (
json.dumps(jsonable_encoder(self._task_state.metadata)) if self._task_state.metadata else None json.dumps(jsonable_encoder(self._task_state.metadata)) if self._task_state.metadata else None
@ -593,29 +617,35 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
tts_publisher.publish(queue_message) tts_publisher.publish(queue_message)
self._task_state.answer += delta_text self._task_state.answer += delta_text
yield self._message_to_stream_response( yield self._message_cycle_manager._message_to_stream_response(
answer=delta_text, message_id=self._message_id, from_variable_selector=event.from_variable_selector answer=delta_text, message_id=self._message_id, from_variable_selector=event.from_variable_selector
) )
elif isinstance(event, QueueMessageReplaceEvent): elif isinstance(event, QueueMessageReplaceEvent):
# published by moderation # published by moderation
yield self._message_replace_to_stream_response(answer=event.text) yield self._message_cycle_manager._message_replace_to_stream_response(answer=event.text)
elif isinstance(event, QueueAdvancedChatMessageEndEvent): elif isinstance(event, QueueAdvancedChatMessageEndEvent):
if not graph_runtime_state: if not graph_runtime_state:
raise ValueError("graph runtime state not initialized.") raise ValueError("graph runtime state not initialized.")
output_moderation_answer = self._handle_output_moderation_when_task_finished(self._task_state.answer) output_moderation_answer = self._base_task_pipeline._handle_output_moderation_when_task_finished(
self._task_state.answer
)
if output_moderation_answer: if output_moderation_answer:
self._task_state.answer = output_moderation_answer self._task_state.answer = output_moderation_answer
yield self._message_replace_to_stream_response(answer=output_moderation_answer) yield self._message_cycle_manager._message_replace_to_stream_response(
answer=output_moderation_answer
)
# Save message # Save message
with Session(db.engine) as session: with Session(db.engine, expire_on_commit=False) as session:
self._save_message(session=session, graph_runtime_state=graph_runtime_state) self._save_message(session=session, graph_runtime_state=graph_runtime_state)
session.commit() session.commit()
yield self._message_end_to_stream_response() yield self._message_end_to_stream_response()
elif isinstance(event, QueueAgentLogEvent): elif isinstance(event, QueueAgentLogEvent):
yield self._handle_agent_log(task_id=self._application_generate_entity.task_id, event=event) yield self._workflow_cycle_manager._handle_agent_log(
task_id=self._application_generate_entity.task_id, event=event
)
else: else:
continue continue
@ -629,7 +659,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
def _save_message(self, *, session: Session, graph_runtime_state: Optional[GraphRuntimeState] = None) -> None: def _save_message(self, *, session: Session, graph_runtime_state: Optional[GraphRuntimeState] = None) -> None:
message = self._get_message(session=session) message = self._get_message(session=session)
message.answer = self._task_state.answer message.answer = self._task_state.answer
message.provider_response_latency = time.perf_counter() - self._start_at message.provider_response_latency = time.perf_counter() - self._base_task_pipeline._start_at
message.message_metadata = ( message.message_metadata = (
json.dumps(jsonable_encoder(self._task_state.metadata)) if self._task_state.metadata else None json.dumps(jsonable_encoder(self._task_state.metadata)) if self._task_state.metadata else None
) )
@ -693,20 +723,20 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
:param text: text :param text: text
:return: True if output moderation should direct output, otherwise False :return: True if output moderation should direct output, otherwise False
""" """
if self._output_moderation_handler: if self._base_task_pipeline._output_moderation_handler:
if self._output_moderation_handler.should_direct_output(): if self._base_task_pipeline._output_moderation_handler.should_direct_output():
# stop subscribe new token when output moderation should direct output # stop subscribe new token when output moderation should direct output
self._task_state.answer = self._output_moderation_handler.get_final_output() self._task_state.answer = self._base_task_pipeline._output_moderation_handler.get_final_output()
self._queue_manager.publish( self._base_task_pipeline._queue_manager.publish(
QueueTextChunkEvent(text=self._task_state.answer), PublishFrom.TASK_PIPELINE QueueTextChunkEvent(text=self._task_state.answer), PublishFrom.TASK_PIPELINE
) )
self._queue_manager.publish( self._base_task_pipeline._queue_manager.publish(
QueueStopEvent(stopped_by=QueueStopEvent.StopBy.OUTPUT_MODERATION), PublishFrom.TASK_PIPELINE QueueStopEvent(stopped_by=QueueStopEvent.StopBy.OUTPUT_MODERATION), PublishFrom.TASK_PIPELINE
) )
return True return True
else: else:
self._output_moderation_handler.append_new_token(text) self._base_task_pipeline._output_moderation_handler.append_new_token(text)
return False return False

View File

@ -19,7 +19,7 @@ from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskSt
from core.app.apps.message_based_app_generator import MessageBasedAppGenerator from core.app.apps.message_based_app_generator import MessageBasedAppGenerator
from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager
from core.app.entities.app_invoke_entities import AgentChatAppGenerateEntity, InvokeFrom from core.app.entities.app_invoke_entities import AgentChatAppGenerateEntity, InvokeFrom
from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError from core.model_runtime.errors.invoke import InvokeAuthorizationError
from core.ops.ops_trace_manager import TraceQueueManager from core.ops.ops_trace_manager import TraceQueueManager
from extensions.ext_database import db from extensions.ext_database import db
from factories import file_factory from factories import file_factory
@ -251,7 +251,7 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
except ValidationError as e: except ValidationError as e:
logger.exception("Validation Error when generating") logger.exception("Validation Error when generating")
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER) queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
except (ValueError, InvokeError) as e: except ValueError as e:
if dify_config.DEBUG: if dify_config.DEBUG:
logger.exception("Error when generating") logger.exception("Error when generating")
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER) queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)

View File

@ -1,6 +1,6 @@
import logging import logging
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from collections.abc import Generator from collections.abc import Generator, Mapping
from typing import Any, Union from typing import Any, Union
from core.app.entities.app_invoke_entities import InvokeFrom from core.app.entities.app_invoke_entities import InvokeFrom
@ -15,7 +15,7 @@ class AppGenerateResponseConverter(ABC):
@classmethod @classmethod
def convert( def convert(
cls, response: Union[AppBlockingResponse, Generator[AppStreamResponse, Any, None]], invoke_from: InvokeFrom cls, response: Union[AppBlockingResponse, Generator[AppStreamResponse, Any, None]], invoke_from: InvokeFrom
) -> dict[str, Any] | Generator[str | dict[str, Any], Any, None]: ) -> Mapping[str, Any] | Generator[str | Mapping[str, Any], Any, None]:
if invoke_from in {InvokeFrom.DEBUGGER, InvokeFrom.SERVICE_API}: if invoke_from in {InvokeFrom.DEBUGGER, InvokeFrom.SERVICE_API}:
if isinstance(response, AppBlockingResponse): if isinstance(response, AppBlockingResponse):
return cls.convert_blocking_full_response(response) return cls.convert_blocking_full_response(response)

View File

@ -18,7 +18,7 @@ from core.app.apps.chat.generate_response_converter import ChatAppGenerateRespon
from core.app.apps.message_based_app_generator import MessageBasedAppGenerator from core.app.apps.message_based_app_generator import MessageBasedAppGenerator
from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager
from core.app.entities.app_invoke_entities import ChatAppGenerateEntity, InvokeFrom from core.app.entities.app_invoke_entities import ChatAppGenerateEntity, InvokeFrom
from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError from core.model_runtime.errors.invoke import InvokeAuthorizationError
from core.ops.ops_trace_manager import TraceQueueManager from core.ops.ops_trace_manager import TraceQueueManager
from extensions.ext_database import db from extensions.ext_database import db
from factories import file_factory from factories import file_factory
@ -237,7 +237,7 @@ class ChatAppGenerator(MessageBasedAppGenerator):
except ValidationError as e: except ValidationError as e:
logger.exception("Validation Error when generating") logger.exception("Validation Error when generating")
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER) queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
except (ValueError, InvokeError) as e: except ValueError as e:
if dify_config.DEBUG: if dify_config.DEBUG:
logger.exception("Error when generating") logger.exception("Error when generating")
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER) queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)

View File

@ -17,7 +17,7 @@ from core.app.apps.completion.generate_response_converter import CompletionAppGe
from core.app.apps.message_based_app_generator import MessageBasedAppGenerator from core.app.apps.message_based_app_generator import MessageBasedAppGenerator
from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager
from core.app.entities.app_invoke_entities import CompletionAppGenerateEntity, InvokeFrom from core.app.entities.app_invoke_entities import CompletionAppGenerateEntity, InvokeFrom
from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError from core.model_runtime.errors.invoke import InvokeAuthorizationError
from core.ops.ops_trace_manager import TraceQueueManager from core.ops.ops_trace_manager import TraceQueueManager
from extensions.ext_database import db from extensions.ext_database import db
from factories import file_factory from factories import file_factory
@ -37,7 +37,7 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
args: Mapping[str, Any], args: Mapping[str, Any],
invoke_from: InvokeFrom, invoke_from: InvokeFrom,
streaming: Literal[True], streaming: Literal[True],
) -> Generator[str, None, None]: ... ) -> Generator[str | Mapping[str, Any], None, None]: ...
@overload @overload
def generate( def generate(
@ -57,7 +57,7 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
args: Mapping[str, Any], args: Mapping[str, Any],
invoke_from: InvokeFrom, invoke_from: InvokeFrom,
streaming: bool = False, streaming: bool = False,
) -> Union[Mapping[str, Any], Generator[str, None, None]]: ... ) -> Union[Mapping[str, Any], Generator[str | Mapping[str, Any], None, None]]: ...
def generate( def generate(
self, self,
@ -66,7 +66,7 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
args: Mapping[str, Any], args: Mapping[str, Any],
invoke_from: InvokeFrom, invoke_from: InvokeFrom,
streaming: bool = True, streaming: bool = True,
) -> Union[Mapping[str, Any], Generator[str, None, None]]: ) -> Union[Mapping[str, Any], Generator[str | Mapping[str, Any], None, None]]:
""" """
Generate App response. Generate App response.
@ -214,7 +214,7 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
except ValidationError as e: except ValidationError as e:
logger.exception("Validation Error when generating") logger.exception("Validation Error when generating")
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER) queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
except (ValueError, InvokeError) as e: except ValueError as e:
if dify_config.DEBUG: if dify_config.DEBUG:
logger.exception("Error when generating") logger.exception("Error when generating")
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER) queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
@ -231,7 +231,7 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
user: Union[Account, EndUser], user: Union[Account, EndUser],
invoke_from: InvokeFrom, invoke_from: InvokeFrom,
stream: bool = True, stream: bool = True,
) -> Union[Mapping, Generator[str, None, None]]: ) -> Union[Mapping, Generator[Mapping | str, None, None]]:
""" """
Generate App response. Generate App response.

View File

@ -20,7 +20,7 @@ from core.app.apps.workflow.generate_response_converter import WorkflowAppGenera
from core.app.apps.workflow.generate_task_pipeline import WorkflowAppGenerateTaskPipeline from core.app.apps.workflow.generate_task_pipeline import WorkflowAppGenerateTaskPipeline
from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity
from core.app.entities.task_entities import WorkflowAppBlockingResponse, WorkflowAppStreamResponse from core.app.entities.task_entities import WorkflowAppBlockingResponse, WorkflowAppStreamResponse
from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError from core.model_runtime.errors.invoke import InvokeAuthorizationError
from core.ops.ops_trace_manager import TraceQueueManager from core.ops.ops_trace_manager import TraceQueueManager
from extensions.ext_database import db from extensions.ext_database import db
from factories import file_factory from factories import file_factory
@ -235,6 +235,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
single_iteration_run=WorkflowAppGenerateEntity.SingleIterationRunEntity( single_iteration_run=WorkflowAppGenerateEntity.SingleIterationRunEntity(
node_id=node_id, inputs=args["inputs"] node_id=node_id, inputs=args["inputs"]
), ),
workflow_run_id=str(uuid.uuid4()),
) )
contexts.tenant_id.set(application_generate_entity.app_config.tenant_id) contexts.tenant_id.set(application_generate_entity.app_config.tenant_id)
contexts.plugin_tool_providers.set({}) contexts.plugin_tool_providers.set({})
@ -286,7 +287,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
except ValidationError as e: except ValidationError as e:
logger.exception("Validation Error when generating") logger.exception("Validation Error when generating")
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER) queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
except (ValueError, InvokeError) as e: except ValueError as e:
if dify_config.DEBUG: if dify_config.DEBUG:
logger.exception("Error when generating") logger.exception("Error when generating")
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER) queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)

View File

@ -1,7 +1,7 @@
import logging import logging
import time import time
from collections.abc import Generator from collections.abc import Generator
from typing import Any, Optional, Union from typing import Optional, Union
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
@ -59,7 +59,6 @@ from models.workflow import (
Workflow, Workflow,
WorkflowAppLog, WorkflowAppLog,
WorkflowAppLogCreatedFrom, WorkflowAppLogCreatedFrom,
WorkflowNodeExecution,
WorkflowRun, WorkflowRun,
WorkflowRunStatus, WorkflowRunStatus,
) )
@ -67,16 +66,11 @@ from models.workflow import (
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleManage): class WorkflowAppGenerateTaskPipeline:
""" """
WorkflowAppGenerateTaskPipeline is a class that generate stream output and state management for Application. WorkflowAppGenerateTaskPipeline is a class that generate stream output and state management for Application.
""" """
_task_state: WorkflowTaskState
_application_generate_entity: WorkflowAppGenerateEntity
_workflow_system_variables: dict[SystemVariableKey, Any]
_wip_workflow_node_executions: dict[str, WorkflowNodeExecution]
def __init__( def __init__(
self, self,
application_generate_entity: WorkflowAppGenerateEntity, application_generate_entity: WorkflowAppGenerateEntity,
@ -85,7 +79,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
user: Union[Account, EndUser], user: Union[Account, EndUser],
stream: bool, stream: bool,
) -> None: ) -> None:
super().__init__( self._base_task_pipeline = BasedGenerateTaskPipeline(
application_generate_entity=application_generate_entity, application_generate_entity=application_generate_entity,
queue_manager=queue_manager, queue_manager=queue_manager,
stream=stream, stream=stream,
@ -102,17 +96,20 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
else: else:
raise ValueError(f"Invalid user type: {type(user)}") raise ValueError(f"Invalid user type: {type(user)}")
self._workflow_cycle_manager = WorkflowCycleManage(
application_generate_entity=application_generate_entity,
workflow_system_variables={
SystemVariableKey.FILES: application_generate_entity.files,
SystemVariableKey.USER_ID: user_session_id,
SystemVariableKey.APP_ID: application_generate_entity.app_config.app_id,
SystemVariableKey.WORKFLOW_ID: workflow.id,
SystemVariableKey.WORKFLOW_RUN_ID: application_generate_entity.workflow_run_id,
},
)
self._application_generate_entity = application_generate_entity
self._workflow_id = workflow.id self._workflow_id = workflow.id
self._workflow_features_dict = workflow.features_dict self._workflow_features_dict = workflow.features_dict
self._workflow_system_variables = {
SystemVariableKey.FILES: application_generate_entity.files,
SystemVariableKey.USER_ID: user_session_id,
SystemVariableKey.APP_ID: application_generate_entity.app_config.app_id,
SystemVariableKey.WORKFLOW_ID: workflow.id,
SystemVariableKey.WORKFLOW_RUN_ID: application_generate_entity.workflow_run_id,
}
self._task_state = WorkflowTaskState() self._task_state = WorkflowTaskState()
self._workflow_run_id = "" self._workflow_run_id = ""
@ -122,7 +119,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
:return: :return:
""" """
generator = self._wrapper_process_stream_response(trace_manager=self._application_generate_entity.trace_manager) generator = self._wrapper_process_stream_response(trace_manager=self._application_generate_entity.trace_manager)
if self._stream: if self._base_task_pipeline._stream:
return self._to_stream_response(generator) return self._to_stream_response(generator)
else: else:
return self._to_blocking_response(generator) return self._to_blocking_response(generator)
@ -239,29 +236,29 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
""" """
graph_runtime_state = None graph_runtime_state = None
for queue_message in self._queue_manager.listen(): for queue_message in self._base_task_pipeline._queue_manager.listen():
event = queue_message.event event = queue_message.event
if isinstance(event, QueuePingEvent): if isinstance(event, QueuePingEvent):
yield self._ping_stream_response() yield self._base_task_pipeline._ping_stream_response()
elif isinstance(event, QueueErrorEvent): elif isinstance(event, QueueErrorEvent):
err = self._handle_error(event=event) err = self._base_task_pipeline._handle_error(event=event)
yield self._error_to_stream_response(err) yield self._base_task_pipeline._error_to_stream_response(err)
break break
elif isinstance(event, QueueWorkflowStartedEvent): elif isinstance(event, QueueWorkflowStartedEvent):
# override graph runtime state # override graph runtime state
graph_runtime_state = event.graph_runtime_state graph_runtime_state = event.graph_runtime_state
with Session(db.engine) as session: with Session(db.engine, expire_on_commit=False) as session:
# init workflow run # init workflow run
workflow_run = self._handle_workflow_run_start( workflow_run = self._workflow_cycle_manager._handle_workflow_run_start(
session=session, session=session,
workflow_id=self._workflow_id, workflow_id=self._workflow_id,
user_id=self._user_id, user_id=self._user_id,
created_by_role=self._created_by_role, created_by_role=self._created_by_role,
) )
self._workflow_run_id = workflow_run.id self._workflow_run_id = workflow_run.id
start_resp = self._workflow_start_to_stream_response( start_resp = self._workflow_cycle_manager._workflow_start_to_stream_response(
session=session, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run session=session, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
) )
session.commit() session.commit()
@ -273,12 +270,14 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
): ):
if not self._workflow_run_id: if not self._workflow_run_id:
raise ValueError("workflow run not initialized.") raise ValueError("workflow run not initialized.")
with Session(db.engine) as session: with Session(db.engine, expire_on_commit=False) as session:
workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id) workflow_run = self._workflow_cycle_manager._get_workflow_run(
workflow_node_execution = self._handle_workflow_node_execution_retried( session=session, workflow_run_id=self._workflow_run_id
)
workflow_node_execution = self._workflow_cycle_manager._handle_workflow_node_execution_retried(
session=session, workflow_run=workflow_run, event=event session=session, workflow_run=workflow_run, event=event
) )
response = self._workflow_node_retry_to_stream_response( response = self._workflow_cycle_manager._workflow_node_retry_to_stream_response(
session=session, session=session,
event=event, event=event,
task_id=self._application_generate_entity.task_id, task_id=self._application_generate_entity.task_id,
@ -292,12 +291,14 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
if not self._workflow_run_id: if not self._workflow_run_id:
raise ValueError("workflow run not initialized.") raise ValueError("workflow run not initialized.")
with Session(db.engine) as session: with Session(db.engine, expire_on_commit=False) as session:
workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id) workflow_run = self._workflow_cycle_manager._get_workflow_run(
workflow_node_execution = self._handle_node_execution_start( session=session, workflow_run_id=self._workflow_run_id
)
workflow_node_execution = self._workflow_cycle_manager._handle_node_execution_start(
session=session, workflow_run=workflow_run, event=event session=session, workflow_run=workflow_run, event=event
) )
node_start_response = self._workflow_node_start_to_stream_response( node_start_response = self._workflow_cycle_manager._workflow_node_start_to_stream_response(
session=session, session=session,
event=event, event=event,
task_id=self._application_generate_entity.task_id, task_id=self._application_generate_entity.task_id,
@ -308,9 +309,11 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
if node_start_response: if node_start_response:
yield node_start_response yield node_start_response
elif isinstance(event, QueueNodeSucceededEvent): elif isinstance(event, QueueNodeSucceededEvent):
with Session(db.engine) as session: with Session(db.engine, expire_on_commit=False) as session:
workflow_node_execution = self._handle_workflow_node_execution_success(session=session, event=event) workflow_node_execution = self._workflow_cycle_manager._handle_workflow_node_execution_success(
node_success_response = self._workflow_node_finish_to_stream_response( session=session, event=event
)
node_success_response = self._workflow_cycle_manager._workflow_node_finish_to_stream_response(
session=session, session=session,
event=event, event=event,
task_id=self._application_generate_entity.task_id, task_id=self._application_generate_entity.task_id,
@ -321,12 +324,12 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
if node_success_response: if node_success_response:
yield node_success_response yield node_success_response
elif isinstance(event, QueueNodeFailedEvent | QueueNodeInIterationFailedEvent | QueueNodeExceptionEvent): elif isinstance(event, QueueNodeFailedEvent | QueueNodeInIterationFailedEvent | QueueNodeExceptionEvent):
with Session(db.engine) as session: with Session(db.engine, expire_on_commit=False) as session:
workflow_node_execution = self._handle_workflow_node_execution_failed( workflow_node_execution = self._workflow_cycle_manager._handle_workflow_node_execution_failed(
session=session, session=session,
event=event, event=event,
) )
node_failed_response = self._workflow_node_finish_to_stream_response( node_failed_response = self._workflow_cycle_manager._workflow_node_finish_to_stream_response(
session=session, session=session,
event=event, event=event,
task_id=self._application_generate_entity.task_id, task_id=self._application_generate_entity.task_id,
@ -341,13 +344,17 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
if not self._workflow_run_id: if not self._workflow_run_id:
raise ValueError("workflow run not initialized.") raise ValueError("workflow run not initialized.")
with Session(db.engine) as session: with Session(db.engine, expire_on_commit=False) as session:
workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id) workflow_run = self._workflow_cycle_manager._get_workflow_run(
parallel_start_resp = self._workflow_parallel_branch_start_to_stream_response( session=session, workflow_run_id=self._workflow_run_id
session=session, )
task_id=self._application_generate_entity.task_id, parallel_start_resp = (
workflow_run=workflow_run, self._workflow_cycle_manager._workflow_parallel_branch_start_to_stream_response(
event=event, session=session,
task_id=self._application_generate_entity.task_id,
workflow_run=workflow_run,
event=event,
)
) )
yield parallel_start_resp yield parallel_start_resp
@ -356,13 +363,17 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
if not self._workflow_run_id: if not self._workflow_run_id:
raise ValueError("workflow run not initialized.") raise ValueError("workflow run not initialized.")
with Session(db.engine) as session: with Session(db.engine, expire_on_commit=False) as session:
workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id) workflow_run = self._workflow_cycle_manager._get_workflow_run(
parallel_finish_resp = self._workflow_parallel_branch_finished_to_stream_response( session=session, workflow_run_id=self._workflow_run_id
session=session, )
task_id=self._application_generate_entity.task_id, parallel_finish_resp = (
workflow_run=workflow_run, self._workflow_cycle_manager._workflow_parallel_branch_finished_to_stream_response(
event=event, session=session,
task_id=self._application_generate_entity.task_id,
workflow_run=workflow_run,
event=event,
)
) )
yield parallel_finish_resp yield parallel_finish_resp
@ -371,9 +382,11 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
if not self._workflow_run_id: if not self._workflow_run_id:
raise ValueError("workflow run not initialized.") raise ValueError("workflow run not initialized.")
with Session(db.engine) as session: with Session(db.engine, expire_on_commit=False) as session:
workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id) workflow_run = self._workflow_cycle_manager._get_workflow_run(
iter_start_resp = self._workflow_iteration_start_to_stream_response( session=session, workflow_run_id=self._workflow_run_id
)
iter_start_resp = self._workflow_cycle_manager._workflow_iteration_start_to_stream_response(
session=session, session=session,
task_id=self._application_generate_entity.task_id, task_id=self._application_generate_entity.task_id,
workflow_run=workflow_run, workflow_run=workflow_run,
@ -386,9 +399,11 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
if not self._workflow_run_id: if not self._workflow_run_id:
raise ValueError("workflow run not initialized.") raise ValueError("workflow run not initialized.")
with Session(db.engine) as session: with Session(db.engine, expire_on_commit=False) as session:
workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id) workflow_run = self._workflow_cycle_manager._get_workflow_run(
iter_next_resp = self._workflow_iteration_next_to_stream_response( session=session, workflow_run_id=self._workflow_run_id
)
iter_next_resp = self._workflow_cycle_manager._workflow_iteration_next_to_stream_response(
session=session, session=session,
task_id=self._application_generate_entity.task_id, task_id=self._application_generate_entity.task_id,
workflow_run=workflow_run, workflow_run=workflow_run,
@ -401,9 +416,11 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
if not self._workflow_run_id: if not self._workflow_run_id:
raise ValueError("workflow run not initialized.") raise ValueError("workflow run not initialized.")
with Session(db.engine) as session: with Session(db.engine, expire_on_commit=False) as session:
workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id) workflow_run = self._workflow_cycle_manager._get_workflow_run(
iter_finish_resp = self._workflow_iteration_completed_to_stream_response( session=session, workflow_run_id=self._workflow_run_id
)
iter_finish_resp = self._workflow_cycle_manager._workflow_iteration_completed_to_stream_response(
session=session, session=session,
task_id=self._application_generate_entity.task_id, task_id=self._application_generate_entity.task_id,
workflow_run=workflow_run, workflow_run=workflow_run,
@ -418,8 +435,8 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
if not graph_runtime_state: if not graph_runtime_state:
raise ValueError("graph runtime state not initialized.") raise ValueError("graph runtime state not initialized.")
with Session(db.engine) as session: with Session(db.engine, expire_on_commit=False) as session:
workflow_run = self._handle_workflow_run_success( workflow_run = self._workflow_cycle_manager._handle_workflow_run_success(
session=session, session=session,
workflow_run_id=self._workflow_run_id, workflow_run_id=self._workflow_run_id,
start_at=graph_runtime_state.start_at, start_at=graph_runtime_state.start_at,
@ -433,7 +450,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
# save workflow app log # save workflow app log
self._save_workflow_app_log(session=session, workflow_run=workflow_run) self._save_workflow_app_log(session=session, workflow_run=workflow_run)
workflow_finish_resp = self._workflow_finish_to_stream_response( workflow_finish_resp = self._workflow_cycle_manager._workflow_finish_to_stream_response(
session=session, session=session,
task_id=self._application_generate_entity.task_id, task_id=self._application_generate_entity.task_id,
workflow_run=workflow_run, workflow_run=workflow_run,
@ -447,8 +464,8 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
if not graph_runtime_state: if not graph_runtime_state:
raise ValueError("graph runtime state not initialized.") raise ValueError("graph runtime state not initialized.")
with Session(db.engine) as session: with Session(db.engine, expire_on_commit=False) as session:
workflow_run = self._handle_workflow_run_partial_success( workflow_run = self._workflow_cycle_manager._handle_workflow_run_partial_success(
session=session, session=session,
workflow_run_id=self._workflow_run_id, workflow_run_id=self._workflow_run_id,
start_at=graph_runtime_state.start_at, start_at=graph_runtime_state.start_at,
@ -463,7 +480,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
# save workflow app log # save workflow app log
self._save_workflow_app_log(session=session, workflow_run=workflow_run) self._save_workflow_app_log(session=session, workflow_run=workflow_run)
workflow_finish_resp = self._workflow_finish_to_stream_response( workflow_finish_resp = self._workflow_cycle_manager._workflow_finish_to_stream_response(
session=session, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run session=session, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
) )
session.commit() session.commit()
@ -475,8 +492,8 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
if not graph_runtime_state: if not graph_runtime_state:
raise ValueError("graph runtime state not initialized.") raise ValueError("graph runtime state not initialized.")
with Session(db.engine) as session: with Session(db.engine, expire_on_commit=False) as session:
workflow_run = self._handle_workflow_run_failed( workflow_run = self._workflow_cycle_manager._handle_workflow_run_failed(
session=session, session=session,
workflow_run_id=self._workflow_run_id, workflow_run_id=self._workflow_run_id,
start_at=graph_runtime_state.start_at, start_at=graph_runtime_state.start_at,
@ -494,7 +511,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
# save workflow app log # save workflow app log
self._save_workflow_app_log(session=session, workflow_run=workflow_run) self._save_workflow_app_log(session=session, workflow_run=workflow_run)
workflow_finish_resp = self._workflow_finish_to_stream_response( workflow_finish_resp = self._workflow_cycle_manager._workflow_finish_to_stream_response(
session=session, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run session=session, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
) )
session.commit() session.commit()
@ -514,7 +531,9 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
delta_text, from_variable_selector=event.from_variable_selector delta_text, from_variable_selector=event.from_variable_selector
) )
elif isinstance(event, QueueAgentLogEvent): elif isinstance(event, QueueAgentLogEvent):
yield self._handle_agent_log(task_id=self._application_generate_entity.task_id, event=event) yield self._workflow_cycle_manager._handle_agent_log(
task_id=self._application_generate_entity.task_id, event=event
)
else: else:
continue continue

View File

@ -195,7 +195,7 @@ class WorkflowAppGenerateEntity(AppGenerateEntity):
# app config # app config
app_config: WorkflowUIBasedAppConfig app_config: WorkflowUIBasedAppConfig
workflow_run_id: Optional[str] = None workflow_run_id: str
class SingleIterationRunEntity(BaseModel): class SingleIterationRunEntity(BaseModel):
""" """

View File

@ -329,6 +329,7 @@ class QueueAgentLogEvent(AppQueueEvent):
error: str | None error: str | None
status: str status: str
data: Mapping[str, Any] data: Mapping[str, Any]
metadata: Optional[Mapping[str, Any]] = None
class QueueNodeRetryEvent(QueueNodeStartedEvent): class QueueNodeRetryEvent(QueueNodeStartedEvent):

View File

@ -716,6 +716,7 @@ class AgentLogStreamResponse(StreamResponse):
error: str | None error: str | None
status: str status: str
data: Mapping[str, Any] data: Mapping[str, Any]
metadata: Optional[Mapping[str, Any]] = None
event: StreamEvent = StreamEvent.AGENT_LOG event: StreamEvent = StreamEvent.AGENT_LOG
data: Data data: Data

View File

@ -15,7 +15,6 @@ from core.app.entities.queue_entities import (
from core.app.entities.task_entities import ( from core.app.entities.task_entities import (
ErrorStreamResponse, ErrorStreamResponse,
PingStreamResponse, PingStreamResponse,
TaskState,
) )
from core.errors.error import QuotaExceededError from core.errors.error import QuotaExceededError
from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
@ -30,22 +29,12 @@ class BasedGenerateTaskPipeline:
BasedGenerateTaskPipeline is a class that generate stream output and state management for Application. BasedGenerateTaskPipeline is a class that generate stream output and state management for Application.
""" """
_task_state: TaskState
_application_generate_entity: AppGenerateEntity
def __init__( def __init__(
self, self,
application_generate_entity: AppGenerateEntity, application_generate_entity: AppGenerateEntity,
queue_manager: AppQueueManager, queue_manager: AppQueueManager,
stream: bool, stream: bool,
) -> None: ) -> None:
"""
Initialize GenerateTaskPipeline.
:param application_generate_entity: application generate entity
:param queue_manager: queue manager
:param user: user
:param stream: stream
"""
self._application_generate_entity = application_generate_entity self._application_generate_entity = application_generate_entity
self._queue_manager = queue_manager self._queue_manager = queue_manager
self._start_at = time.perf_counter() self._start_at = time.perf_counter()

View File

@ -31,10 +31,19 @@ from services.annotation_service import AppAnnotationService
class MessageCycleManage: class MessageCycleManage:
_application_generate_entity: Union[ def __init__(
ChatAppGenerateEntity, CompletionAppGenerateEntity, AgentChatAppGenerateEntity, AdvancedChatAppGenerateEntity self,
] *,
_task_state: Union[EasyUITaskState, WorkflowTaskState] application_generate_entity: Union[
ChatAppGenerateEntity,
CompletionAppGenerateEntity,
AgentChatAppGenerateEntity,
AdvancedChatAppGenerateEntity,
],
task_state: Union[EasyUITaskState, WorkflowTaskState],
) -> None:
self._application_generate_entity = application_generate_entity
self._task_state = task_state
def _generate_conversation_name(self, *, conversation_id: str, query: str) -> Optional[Thread]: def _generate_conversation_name(self, *, conversation_id: str, query: str) -> Optional[Thread]:
""" """

View File

@ -36,7 +36,6 @@ from core.app.entities.task_entities import (
ParallelBranchStartStreamResponse, ParallelBranchStartStreamResponse,
WorkflowFinishStreamResponse, WorkflowFinishStreamResponse,
WorkflowStartStreamResponse, WorkflowStartStreamResponse,
WorkflowTaskState,
) )
from core.file import FILE_MODEL_IDENTITY, File from core.file import FILE_MODEL_IDENTITY, File
from core.model_runtime.utils.encoders import jsonable_encoder from core.model_runtime.utils.encoders import jsonable_encoder
@ -60,13 +59,20 @@ from models.workflow import (
WorkflowRunStatus, WorkflowRunStatus,
) )
from .exc import WorkflowNodeExecutionNotFoundError, WorkflowRunNotFoundError from .exc import WorkflowRunNotFoundError
class WorkflowCycleManage: class WorkflowCycleManage:
_application_generate_entity: Union[AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity] def __init__(
_task_state: WorkflowTaskState self,
_workflow_system_variables: dict[SystemVariableKey, Any] *,
application_generate_entity: Union[AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity],
workflow_system_variables: dict[SystemVariableKey, Any],
) -> None:
self._workflow_run: WorkflowRun | None = None
self._workflow_node_executions: dict[str, WorkflowNodeExecution] = {}
self._application_generate_entity = application_generate_entity
self._workflow_system_variables = workflow_system_variables
def _handle_workflow_run_start( def _handle_workflow_run_start(
self, self,
@ -104,7 +110,8 @@ class WorkflowCycleManage:
inputs = dict(WorkflowEntry.handle_special_values(inputs) or {}) inputs = dict(WorkflowEntry.handle_special_values(inputs) or {})
# init workflow run # init workflow run
workflow_run_id = str(self._workflow_system_variables.get(SystemVariableKey.WORKFLOW_RUN_ID, uuid4())) # TODO: This workflow_run_id should always not be None, maybe we can use a more elegant way to handle this
workflow_run_id = str(self._workflow_system_variables.get(SystemVariableKey.WORKFLOW_RUN_ID) or uuid4())
workflow_run = WorkflowRun() workflow_run = WorkflowRun()
workflow_run.id = workflow_run_id workflow_run.id = workflow_run_id
@ -241,7 +248,7 @@ class WorkflowCycleManage:
workflow_run.finished_at = datetime.now(UTC).replace(tzinfo=None) workflow_run.finished_at = datetime.now(UTC).replace(tzinfo=None)
workflow_run.exceptions_count = exceptions_count workflow_run.exceptions_count = exceptions_count
stmt = select(WorkflowNodeExecution).where( stmt = select(WorkflowNodeExecution.node_execution_id).where(
WorkflowNodeExecution.tenant_id == workflow_run.tenant_id, WorkflowNodeExecution.tenant_id == workflow_run.tenant_id,
WorkflowNodeExecution.app_id == workflow_run.app_id, WorkflowNodeExecution.app_id == workflow_run.app_id,
WorkflowNodeExecution.workflow_id == workflow_run.workflow_id, WorkflowNodeExecution.workflow_id == workflow_run.workflow_id,
@ -249,15 +256,18 @@ class WorkflowCycleManage:
WorkflowNodeExecution.workflow_run_id == workflow_run.id, WorkflowNodeExecution.workflow_run_id == workflow_run.id,
WorkflowNodeExecution.status == WorkflowNodeExecutionStatus.RUNNING.value, WorkflowNodeExecution.status == WorkflowNodeExecutionStatus.RUNNING.value,
) )
ids = session.scalars(stmt).all()
running_workflow_node_executions = session.scalars(stmt).all() # Use self._get_workflow_node_execution here to make sure the cache is updated
running_workflow_node_executions = [
self._get_workflow_node_execution(session=session, node_execution_id=id) for id in ids if id
]
for workflow_node_execution in running_workflow_node_executions: for workflow_node_execution in running_workflow_node_executions:
now = datetime.now(UTC).replace(tzinfo=None)
workflow_node_execution.status = WorkflowNodeExecutionStatus.FAILED.value workflow_node_execution.status = WorkflowNodeExecutionStatus.FAILED.value
workflow_node_execution.error = error workflow_node_execution.error = error
finish_at = datetime.now(UTC).replace(tzinfo=None) workflow_node_execution.finished_at = now
workflow_node_execution.finished_at = finish_at workflow_node_execution.elapsed_time = (now - workflow_node_execution.created_at).total_seconds()
workflow_node_execution.elapsed_time = (finish_at - workflow_node_execution.created_at).total_seconds()
if trace_manager: if trace_manager:
trace_manager.add_trace_task( trace_manager.add_trace_task(
@ -299,6 +309,8 @@ class WorkflowCycleManage:
workflow_node_execution.created_at = datetime.now(UTC).replace(tzinfo=None) workflow_node_execution.created_at = datetime.now(UTC).replace(tzinfo=None)
session.add(workflow_node_execution) session.add(workflow_node_execution)
self._workflow_node_executions[event.node_execution_id] = workflow_node_execution
return workflow_node_execution return workflow_node_execution
def _handle_workflow_node_execution_success( def _handle_workflow_node_execution_success(
@ -325,6 +337,7 @@ class WorkflowCycleManage:
workflow_node_execution.finished_at = finished_at workflow_node_execution.finished_at = finished_at
workflow_node_execution.elapsed_time = elapsed_time workflow_node_execution.elapsed_time = elapsed_time
workflow_node_execution = session.merge(workflow_node_execution)
return workflow_node_execution return workflow_node_execution
def _handle_workflow_node_execution_failed( def _handle_workflow_node_execution_failed(
@ -364,6 +377,7 @@ class WorkflowCycleManage:
workflow_node_execution.elapsed_time = elapsed_time workflow_node_execution.elapsed_time = elapsed_time
workflow_node_execution.execution_metadata = execution_metadata workflow_node_execution.execution_metadata = execution_metadata
workflow_node_execution = session.merge(workflow_node_execution)
return workflow_node_execution return workflow_node_execution
def _handle_workflow_node_execution_retried( def _handle_workflow_node_execution_retried(
@ -415,6 +429,8 @@ class WorkflowCycleManage:
workflow_node_execution.index = event.node_run_index workflow_node_execution.index = event.node_run_index
session.add(workflow_node_execution) session.add(workflow_node_execution)
self._workflow_node_executions[event.node_execution_id] = workflow_node_execution
return workflow_node_execution return workflow_node_execution
################################################# #################################################
@ -811,25 +827,23 @@ class WorkflowCycleManage:
return None return None
def _get_workflow_run(self, *, session: Session, workflow_run_id: str) -> WorkflowRun: def _get_workflow_run(self, *, session: Session, workflow_run_id: str) -> WorkflowRun:
""" if self._workflow_run and self._workflow_run.id == workflow_run_id:
Refetch workflow run cached_workflow_run = self._workflow_run
:param workflow_run_id: workflow run id cached_workflow_run = session.merge(cached_workflow_run)
:return: return cached_workflow_run
"""
stmt = select(WorkflowRun).where(WorkflowRun.id == workflow_run_id) stmt = select(WorkflowRun).where(WorkflowRun.id == workflow_run_id)
workflow_run = session.scalar(stmt) workflow_run = session.scalar(stmt)
if not workflow_run: if not workflow_run:
raise WorkflowRunNotFoundError(workflow_run_id) raise WorkflowRunNotFoundError(workflow_run_id)
self._workflow_run = workflow_run
return workflow_run return workflow_run
def _get_workflow_node_execution(self, session: Session, node_execution_id: str) -> WorkflowNodeExecution: def _get_workflow_node_execution(self, session: Session, node_execution_id: str) -> WorkflowNodeExecution:
stmt = select(WorkflowNodeExecution).where(WorkflowNodeExecution.node_execution_id == node_execution_id) if node_execution_id not in self._workflow_node_executions:
workflow_node_execution = session.scalar(stmt) raise ValueError(f"Workflow node execution not found: {node_execution_id}")
if not workflow_node_execution: cached_workflow_node_execution = self._workflow_node_executions[node_execution_id]
raise WorkflowNodeExecutionNotFoundError(node_execution_id) return cached_workflow_node_execution
return workflow_node_execution
def _handle_agent_log(self, task_id: str, event: QueueAgentLogEvent) -> AgentLogStreamResponse: def _handle_agent_log(self, task_id: str, event: QueueAgentLogEvent) -> AgentLogStreamResponse:
""" """
@ -848,5 +862,6 @@ class WorkflowCycleManage:
error=event.error, error=event.error,
status=event.status, status=event.status,
data=event.data, data=event.data,
metadata=event.metadata,
), ),
) )

View File

@ -178,7 +178,7 @@ class ModelInstance:
def get_llm_num_tokens( def get_llm_num_tokens(
self, prompt_messages: list[PromptMessage], tools: Optional[list[PromptMessageTool]] = None self, prompt_messages: list[PromptMessage], tools: Optional[list[PromptMessageTool]] = None
) -> list[int]: ) -> int:
""" """
Get number of tokens for llm Get number of tokens for llm
@ -191,7 +191,7 @@ class ModelInstance:
self.model_type_instance = cast(LargeLanguageModel, self.model_type_instance) self.model_type_instance = cast(LargeLanguageModel, self.model_type_instance)
return cast( return cast(
list[int], int,
self._round_robin_invoke( self._round_robin_invoke(
function=self.model_type_instance.get_num_tokens, function=self.model_type_instance.get_num_tokens,
model=self.model, model=self.model,

View File

@ -1,9 +1,47 @@
import tiktoken from threading import Lock
from typing import Any
_tokenizer: Any = None
_lock = Lock()
class GPT2Tokenizer: class GPT2Tokenizer:
@staticmethod
def _get_num_tokens_by_gpt2(text: str) -> int:
"""
use gpt2 tokenizer to get num tokens
"""
_tokenizer = GPT2Tokenizer.get_encoder()
tokens = _tokenizer.encode(text)
return len(tokens)
@staticmethod @staticmethod
def get_num_tokens(text: str) -> int: def get_num_tokens(text: str) -> int:
encoding = tiktoken.encoding_for_model("gpt2") # Because this process needs more cpu resource, we turn this back before we find a better way to handle it.
tiktoken_vec = encoding.encode(text) #
return len(tiktoken_vec) # future = _executor.submit(GPT2Tokenizer._get_num_tokens_by_gpt2, text)
# result = future.result()
# return cast(int, result)
return GPT2Tokenizer._get_num_tokens_by_gpt2(text)
@staticmethod
def get_encoder() -> Any:
global _tokenizer, _lock
with _lock:
if _tokenizer is None:
# Try to use tiktoken to get the tokenizer because it is faster
#
try:
import tiktoken
_tokenizer = tiktoken.get_encoding("gpt2")
except Exception:
from os.path import abspath, dirname, join
from transformers import GPT2Tokenizer as TransformerGPT2Tokenizer # type: ignore
base_path = abspath(__file__)
gpt2_tokenizer_path = join(dirname(base_path), "gpt2")
_tokenizer = TransformerGPT2Tokenizer.from_pretrained(gpt2_tokenizer_path)
return _tokenizer

View File

@ -119,7 +119,7 @@ class PluginAppBackwardsInvocation(BaseBackwardsInvocation):
stream: bool, stream: bool,
inputs: Mapping, inputs: Mapping,
files: list[dict], files: list[dict],
): ) -> Generator[Mapping | str, None, None] | Mapping:
""" """
invoke workflow app invoke workflow app
""" """
@ -146,7 +146,7 @@ class PluginAppBackwardsInvocation(BaseBackwardsInvocation):
stream: bool, stream: bool,
inputs: Mapping, inputs: Mapping,
files: list[dict], files: list[dict],
): ) -> Generator[Mapping | str, None, None] | Mapping:
""" """
invoke completion app invoke completion app
""" """

View File

@ -268,7 +268,7 @@ Here is the extra instruction you need to follow:
return summary.message.content return summary.message.content
lines = content.split("\n") lines = content.split("\n")
new_lines = [] new_lines: list[str] = []
# split long line into multiple lines # split long line into multiple lines
for i in range(len(lines)): for i in range(len(lines)):
line = lines[i] line = lines[i]
@ -286,16 +286,16 @@ Here is the extra instruction you need to follow:
# merge lines into messages with max tokens # merge lines into messages with max tokens
messages: list[str] = [] messages: list[str] = []
for i in new_lines: for i in new_lines: # type: ignore
if len(messages) == 0: if len(messages) == 0:
messages.append(i) messages.append(i) # type: ignore
else: else:
if len(messages[-1]) + len(i) < max_tokens * 0.5: if len(messages[-1]) + len(i) < max_tokens * 0.5: # type: ignore
messages[-1] += i messages[-1] += i # type: ignore
if get_prompt_tokens(messages[-1] + i) > max_tokens * 0.7: if get_prompt_tokens(messages[-1] + i) > max_tokens * 0.7: # type: ignore
messages.append(i) messages.append(i) # type: ignore
else: else:
messages[-1] += i messages[-1] += i # type: ignore
summaries = [] summaries = []
for i in range(len(messages)): for i in range(len(messages)):

View File

@ -103,7 +103,7 @@ class BasePluginManager:
Make a stream request to the plugin daemon inner API and yield the response as a model. Make a stream request to the plugin daemon inner API and yield the response as a model.
""" """
for line in self._stream_request(method, path, params, headers, data, files): for line in self._stream_request(method, path, params, headers, data, files):
yield type(**json.loads(line)) yield type(**json.loads(line)) # type: ignore
def _request_with_model( def _request_with_model(
self, self,

View File

@ -113,6 +113,8 @@ class BaiduVector(BaseVector):
return False return False
def delete_by_ids(self, ids: list[str]) -> None: def delete_by_ids(self, ids: list[str]) -> None:
if not ids:
return
quoted_ids = [f"'{id}'" for id in ids] quoted_ids = [f"'{id}'" for id in ids]
self._db.table(self._collection_name).delete(filter=f"id IN({', '.join(quoted_ids)})") self._db.table(self._collection_name).delete(filter=f"id IN({', '.join(quoted_ids)})")

View File

@ -83,6 +83,8 @@ class ChromaVector(BaseVector):
self._client.delete_collection(self._collection_name) self._client.delete_collection(self._collection_name)
def delete_by_ids(self, ids: list[str]) -> None: def delete_by_ids(self, ids: list[str]) -> None:
if not ids:
return
collection = self._client.get_or_create_collection(self._collection_name) collection = self._client.get_or_create_collection(self._collection_name)
collection.delete(ids=ids) collection.delete(ids=ids)

View File

@ -0,0 +1,104 @@
import json
import logging
from typing import Any, Optional
from flask import current_app
from core.rag.datasource.vdb.elasticsearch.elasticsearch_vector import (
ElasticSearchConfig,
ElasticSearchVector,
ElasticSearchVectorFactory,
)
from core.rag.datasource.vdb.field import Field
from core.rag.datasource.vdb.vector_type import VectorType
from core.rag.embedding.embedding_base import Embeddings
from extensions.ext_redis import redis_client
from models.dataset import Dataset
logger = logging.getLogger(__name__)
class ElasticSearchJaVector(ElasticSearchVector):
def create_collection(
self,
embeddings: list[list[float]],
metadatas: Optional[list[dict[Any, Any]]] = None,
index_params: Optional[dict] = None,
):
lock_name = f"vector_indexing_lock_{self._collection_name}"
with redis_client.lock(lock_name, timeout=20):
collection_exist_cache_key = f"vector_indexing_{self._collection_name}"
if redis_client.get(collection_exist_cache_key):
logger.info(f"Collection {self._collection_name} already exists.")
return
if not self._client.indices.exists(index=self._collection_name):
dim = len(embeddings[0])
settings = {
"analysis": {
"analyzer": {
"ja_analyzer": {
"type": "custom",
"char_filter": [
"icu_normalizer",
"kuromoji_iteration_mark",
],
"tokenizer": "kuromoji_tokenizer",
"filter": [
"kuromoji_baseform",
"kuromoji_part_of_speech",
"ja_stop",
"kuromoji_number",
"kuromoji_stemmer",
],
}
}
}
}
mappings = {
"properties": {
Field.CONTENT_KEY.value: {
"type": "text",
"analyzer": "ja_analyzer",
"search_analyzer": "ja_analyzer",
},
Field.VECTOR.value: { # Make sure the dimension is correct here
"type": "dense_vector",
"dims": dim,
"index": True,
"similarity": "cosine",
},
Field.METADATA_KEY.value: {
"type": "object",
"properties": {
"doc_id": {"type": "keyword"} # Map doc_id to keyword type
},
},
}
}
self._client.indices.create(index=self._collection_name, settings=settings, mappings=mappings)
redis_client.set(collection_exist_cache_key, 1, ex=3600)
class ElasticSearchJaVectorFactory(ElasticSearchVectorFactory):
def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> ElasticSearchJaVector:
if dataset.index_struct_dict:
class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"]
collection_name = class_prefix
else:
dataset_id = dataset.id
collection_name = Dataset.gen_collection_name_by_id(dataset_id)
dataset.index_struct = json.dumps(self.gen_index_struct_dict(VectorType.ELASTICSEARCH, collection_name))
config = current_app.config
return ElasticSearchJaVector(
index_name=collection_name,
config=ElasticSearchConfig(
host=config.get("ELASTICSEARCH_HOST", "localhost"),
port=config.get("ELASTICSEARCH_PORT", 9200),
username=config.get("ELASTICSEARCH_USERNAME", ""),
password=config.get("ELASTICSEARCH_PASSWORD", ""),
),
attributes=[],
)

View File

@ -98,6 +98,8 @@ class ElasticSearchVector(BaseVector):
return bool(self._client.exists(index=self._collection_name, id=id)) return bool(self._client.exists(index=self._collection_name, id=id))
def delete_by_ids(self, ids: list[str]) -> None: def delete_by_ids(self, ids: list[str]) -> None:
if not ids:
return
for id in ids: for id in ids:
self._client.delete(index=self._collection_name, id=id) self._client.delete(index=self._collection_name, id=id)

View File

@ -6,6 +6,8 @@ class Field(Enum):
METADATA_KEY = "metadata" METADATA_KEY = "metadata"
GROUP_KEY = "group_id" GROUP_KEY = "group_id"
VECTOR = "vector" VECTOR = "vector"
# Sparse Vector aims to support full text search
SPARSE_VECTOR = "sparse_vector"
TEXT_KEY = "text" TEXT_KEY = "text"
PRIMARY_KEY = "id" PRIMARY_KEY = "id"
DOC_ID = "metadata.doc_id" DOC_ID = "metadata.doc_id"

View File

@ -2,6 +2,7 @@ import json
import logging import logging
from typing import Any, Optional from typing import Any, Optional
from packaging import version
from pydantic import BaseModel, model_validator from pydantic import BaseModel, model_validator
from pymilvus import MilvusClient, MilvusException # type: ignore from pymilvus import MilvusClient, MilvusException # type: ignore
from pymilvus.milvus_client import IndexParams # type: ignore from pymilvus.milvus_client import IndexParams # type: ignore
@ -20,16 +21,25 @@ logger = logging.getLogger(__name__)
class MilvusConfig(BaseModel): class MilvusConfig(BaseModel):
uri: str """
token: Optional[str] = None Configuration class for Milvus connection.
user: str """
password: str
batch_size: int = 100 uri: str # Milvus server URI
database: str = "default" token: Optional[str] = None # Optional token for authentication
user: str # Username for authentication
password: str # Password for authentication
batch_size: int = 100 # Batch size for operations
database: str = "default" # Database name
enable_hybrid_search: bool = False # Flag to enable hybrid search
@model_validator(mode="before") @model_validator(mode="before")
@classmethod @classmethod
def validate_config(cls, values: dict) -> dict: def validate_config(cls, values: dict) -> dict:
"""
Validate the configuration values.
Raises ValueError if required fields are missing.
"""
if not values.get("uri"): if not values.get("uri"):
raise ValueError("config MILVUS_URI is required") raise ValueError("config MILVUS_URI is required")
if not values.get("user"): if not values.get("user"):
@ -39,6 +49,9 @@ class MilvusConfig(BaseModel):
return values return values
def to_milvus_params(self): def to_milvus_params(self):
"""
Convert the configuration to a dictionary of Milvus connection parameters.
"""
return { return {
"uri": self.uri, "uri": self.uri,
"token": self.token, "token": self.token,
@ -49,26 +62,57 @@ class MilvusConfig(BaseModel):
class MilvusVector(BaseVector): class MilvusVector(BaseVector):
"""
Milvus vector storage implementation.
"""
def __init__(self, collection_name: str, config: MilvusConfig): def __init__(self, collection_name: str, config: MilvusConfig):
super().__init__(collection_name) super().__init__(collection_name)
self._client_config = config self._client_config = config
self._client = self._init_client(config) self._client = self._init_client(config)
self._consistency_level = "Session" self._consistency_level = "Session" # Consistency level for Milvus operations
self._fields: list[str] = [] self._fields: list[str] = [] # List of fields in the collection
self._hybrid_search_enabled = self._check_hybrid_search_support() # Check if hybrid search is supported
def _check_hybrid_search_support(self) -> bool:
"""
Check if the current Milvus version supports hybrid search.
Returns True if the version is >= 2.5.0, otherwise False.
"""
if not self._client_config.enable_hybrid_search:
return False
try:
milvus_version = self._client.get_server_version()
return version.parse(milvus_version).base_version >= version.parse("2.5.0").base_version
except Exception as e:
logger.warning(f"Failed to check Milvus version: {str(e)}. Disabling hybrid search.")
return False
def get_type(self) -> str: def get_type(self) -> str:
"""
Get the type of vector storage (Milvus).
"""
return VectorType.MILVUS return VectorType.MILVUS
def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs): def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
"""
Create a collection and add texts with embeddings.
"""
index_params = {"metric_type": "IP", "index_type": "HNSW", "params": {"M": 8, "efConstruction": 64}} index_params = {"metric_type": "IP", "index_type": "HNSW", "params": {"M": 8, "efConstruction": 64}}
metadatas = [d.metadata if d.metadata is not None else {} for d in texts] metadatas = [d.metadata if d.metadata is not None else {} for d in texts]
self.create_collection(embeddings, metadatas, index_params) self.create_collection(embeddings, metadatas, index_params)
self.add_texts(texts, embeddings) self.add_texts(texts, embeddings)
def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs): def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
"""
Add texts and their embeddings to the collection.
"""
insert_dict_list = [] insert_dict_list = []
for i in range(len(documents)): for i in range(len(documents)):
insert_dict = { insert_dict = {
# Do not need to insert the sparse_vector field separately, as the text_bm25_emb
# function will automatically convert the native text into a sparse vector for us.
Field.CONTENT_KEY.value: documents[i].page_content, Field.CONTENT_KEY.value: documents[i].page_content,
Field.VECTOR.value: embeddings[i], Field.VECTOR.value: embeddings[i],
Field.METADATA_KEY.value: documents[i].metadata, Field.METADATA_KEY.value: documents[i].metadata,
@ -76,12 +120,11 @@ class MilvusVector(BaseVector):
insert_dict_list.append(insert_dict) insert_dict_list.append(insert_dict)
# Total insert count # Total insert count
total_count = len(insert_dict_list) total_count = len(insert_dict_list)
pks: list[str] = [] pks: list[str] = []
for i in range(0, total_count, 1000): for i in range(0, total_count, 1000):
batch_insert_list = insert_dict_list[i : i + 1000]
# Insert into the collection. # Insert into the collection.
batch_insert_list = insert_dict_list[i : i + 1000]
try: try:
ids = self._client.insert(collection_name=self._collection_name, data=batch_insert_list) ids = self._client.insert(collection_name=self._collection_name, data=batch_insert_list)
pks.extend(ids) pks.extend(ids)
@ -91,6 +134,9 @@ class MilvusVector(BaseVector):
return pks return pks
def get_ids_by_metadata_field(self, key: str, value: str): def get_ids_by_metadata_field(self, key: str, value: str):
"""
Get document IDs by metadata field key and value.
"""
result = self._client.query( result = self._client.query(
collection_name=self._collection_name, filter=f'metadata["{key}"] == "{value}"', output_fields=["id"] collection_name=self._collection_name, filter=f'metadata["{key}"] == "{value}"', output_fields=["id"]
) )
@ -100,12 +146,18 @@ class MilvusVector(BaseVector):
return None return None
def delete_by_metadata_field(self, key: str, value: str): def delete_by_metadata_field(self, key: str, value: str):
"""
Delete documents by metadata field key and value.
"""
if self._client.has_collection(self._collection_name): if self._client.has_collection(self._collection_name):
ids = self.get_ids_by_metadata_field(key, value) ids = self.get_ids_by_metadata_field(key, value)
if ids: if ids:
self._client.delete(collection_name=self._collection_name, pks=ids) self._client.delete(collection_name=self._collection_name, pks=ids)
def delete_by_ids(self, ids: list[str]) -> None: def delete_by_ids(self, ids: list[str]) -> None:
"""
Delete documents by their IDs.
"""
if self._client.has_collection(self._collection_name): if self._client.has_collection(self._collection_name):
result = self._client.query( result = self._client.query(
collection_name=self._collection_name, filter=f'metadata["doc_id"] in {ids}', output_fields=["id"] collection_name=self._collection_name, filter=f'metadata["doc_id"] in {ids}', output_fields=["id"]
@ -115,10 +167,16 @@ class MilvusVector(BaseVector):
self._client.delete(collection_name=self._collection_name, pks=ids) self._client.delete(collection_name=self._collection_name, pks=ids)
def delete(self) -> None: def delete(self) -> None:
"""
Delete the entire collection.
"""
if self._client.has_collection(self._collection_name): if self._client.has_collection(self._collection_name):
self._client.drop_collection(self._collection_name, None) self._client.drop_collection(self._collection_name, None)
def text_exists(self, id: str) -> bool: def text_exists(self, id: str) -> bool:
"""
Check if a text with the given ID exists in the collection.
"""
if not self._client.has_collection(self._collection_name): if not self._client.has_collection(self._collection_name):
return False return False
@ -128,32 +186,80 @@ class MilvusVector(BaseVector):
return len(result) > 0 return len(result) > 0
def field_exists(self, field: str) -> bool:
"""
Check if a field exists in the collection.
"""
return field in self._fields
def _process_search_results(
self, results: list[Any], output_fields: list[str], score_threshold: float = 0.0
) -> list[Document]:
"""
Common method to process search results
:param results: Search results
:param output_fields: Fields to be output
:param score_threshold: Score threshold for filtering
:return: List of documents
"""
docs = []
for result in results[0]:
metadata = result["entity"].get(output_fields[1], {})
metadata["score"] = result["distance"]
if result["distance"] > score_threshold:
doc = Document(page_content=result["entity"].get(output_fields[0], ""), metadata=metadata)
docs.append(doc)
return docs
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]: def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
# Set search parameters. """
Search for documents by vector similarity.
"""
results = self._client.search( results = self._client.search(
collection_name=self._collection_name, collection_name=self._collection_name,
data=[query_vector], data=[query_vector],
anns_field=Field.VECTOR.value,
limit=kwargs.get("top_k", 4), limit=kwargs.get("top_k", 4),
output_fields=[Field.CONTENT_KEY.value, Field.METADATA_KEY.value], output_fields=[Field.CONTENT_KEY.value, Field.METADATA_KEY.value],
) )
# Organize results.
docs = [] return self._process_search_results(
for result in results[0]: results,
metadata = result["entity"].get(Field.METADATA_KEY.value) output_fields=[Field.CONTENT_KEY.value, Field.METADATA_KEY.value],
metadata["score"] = result["distance"] score_threshold=float(kwargs.get("score_threshold") or 0.0),
score_threshold = float(kwargs.get("score_threshold") or 0.0) )
if result["distance"] > score_threshold:
doc = Document(page_content=result["entity"].get(Field.CONTENT_KEY.value), metadata=metadata)
docs.append(doc)
return docs
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]: def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
# milvus/zilliz doesn't support bm25 search """
return [] Search for documents by full-text search (if hybrid search is enabled).
"""
if not self._hybrid_search_enabled or not self.field_exists(Field.SPARSE_VECTOR.value):
logger.warning("Full-text search is not supported in current Milvus version (requires >= 2.5.0)")
return []
results = self._client.search(
collection_name=self._collection_name,
data=[query],
anns_field=Field.SPARSE_VECTOR.value,
limit=kwargs.get("top_k", 4),
output_fields=[Field.CONTENT_KEY.value, Field.METADATA_KEY.value],
)
return self._process_search_results(
results,
output_fields=[Field.CONTENT_KEY.value, Field.METADATA_KEY.value],
score_threshold=float(kwargs.get("score_threshold") or 0.0),
)
def create_collection( def create_collection(
self, embeddings: list, metadatas: Optional[list[dict]] = None, index_params: Optional[dict] = None self, embeddings: list, metadatas: Optional[list[dict]] = None, index_params: Optional[dict] = None
): ):
"""
Create a new collection in Milvus with the specified schema and index parameters.
"""
lock_name = "vector_indexing_lock_{}".format(self._collection_name) lock_name = "vector_indexing_lock_{}".format(self._collection_name)
with redis_client.lock(lock_name, timeout=20): with redis_client.lock(lock_name, timeout=20):
collection_exist_cache_key = "vector_indexing_{}".format(self._collection_name) collection_exist_cache_key = "vector_indexing_{}".format(self._collection_name)
@ -161,7 +267,7 @@ class MilvusVector(BaseVector):
return return
# Grab the existing collection if it exists # Grab the existing collection if it exists
if not self._client.has_collection(self._collection_name): if not self._client.has_collection(self._collection_name):
from pymilvus import CollectionSchema, DataType, FieldSchema # type: ignore from pymilvus import CollectionSchema, DataType, FieldSchema, Function, FunctionType # type: ignore
from pymilvus.orm.types import infer_dtype_bydata # type: ignore from pymilvus.orm.types import infer_dtype_bydata # type: ignore
# Determine embedding dim # Determine embedding dim
@ -170,16 +276,36 @@ class MilvusVector(BaseVector):
if metadatas: if metadatas:
fields.append(FieldSchema(Field.METADATA_KEY.value, DataType.JSON, max_length=65_535)) fields.append(FieldSchema(Field.METADATA_KEY.value, DataType.JSON, max_length=65_535))
# Create the text field # Create the text field, enable_analyzer will be set True to support milvus automatically
fields.append(FieldSchema(Field.CONTENT_KEY.value, DataType.VARCHAR, max_length=65_535)) # transfer text to sparse_vector, reference: https://milvus.io/docs/full-text-search.md
fields.append(
FieldSchema(
Field.CONTENT_KEY.value,
DataType.VARCHAR,
max_length=65_535,
enable_analyzer=self._hybrid_search_enabled,
)
)
# Create the primary key field # Create the primary key field
fields.append(FieldSchema(Field.PRIMARY_KEY.value, DataType.INT64, is_primary=True, auto_id=True)) fields.append(FieldSchema(Field.PRIMARY_KEY.value, DataType.INT64, is_primary=True, auto_id=True))
# Create the vector field, supports binary or float vectors # Create the vector field, supports binary or float vectors
fields.append(FieldSchema(Field.VECTOR.value, infer_dtype_bydata(embeddings[0]), dim=dim)) fields.append(FieldSchema(Field.VECTOR.value, infer_dtype_bydata(embeddings[0]), dim=dim))
# Create Sparse Vector Index for the collection
if self._hybrid_search_enabled:
fields.append(FieldSchema(Field.SPARSE_VECTOR.value, DataType.SPARSE_FLOAT_VECTOR))
# Create the schema for the collection
schema = CollectionSchema(fields) schema = CollectionSchema(fields)
# Create custom function to support text to sparse vector by BM25
if self._hybrid_search_enabled:
bm25_function = Function(
name="text_bm25_emb",
input_field_names=[Field.CONTENT_KEY.value],
output_field_names=[Field.SPARSE_VECTOR.value],
function_type=FunctionType.BM25,
)
schema.add_function(bm25_function)
for x in schema.fields: for x in schema.fields:
self._fields.append(x.name) self._fields.append(x.name)
# Since primary field is auto-id, no need to track it # Since primary field is auto-id, no need to track it
@ -189,10 +315,15 @@ class MilvusVector(BaseVector):
index_params_obj = IndexParams() index_params_obj = IndexParams()
index_params_obj.add_index(field_name=Field.VECTOR.value, **index_params) index_params_obj.add_index(field_name=Field.VECTOR.value, **index_params)
# Create Sparse Vector Index for the collection
if self._hybrid_search_enabled:
index_params_obj.add_index(
field_name=Field.SPARSE_VECTOR.value, index_type="AUTOINDEX", metric_type="BM25"
)
# Create the collection # Create the collection
collection_name = self._collection_name
self._client.create_collection( self._client.create_collection(
collection_name=collection_name, collection_name=self._collection_name,
schema=schema, schema=schema,
index_params=index_params_obj, index_params=index_params_obj,
consistency_level=self._consistency_level, consistency_level=self._consistency_level,
@ -200,12 +331,22 @@ class MilvusVector(BaseVector):
redis_client.set(collection_exist_cache_key, 1, ex=3600) redis_client.set(collection_exist_cache_key, 1, ex=3600)
def _init_client(self, config) -> MilvusClient: def _init_client(self, config) -> MilvusClient:
"""
Initialize and return a Milvus client.
"""
client = MilvusClient(uri=config.uri, user=config.user, password=config.password, db_name=config.database) client = MilvusClient(uri=config.uri, user=config.user, password=config.password, db_name=config.database)
return client return client
class MilvusVectorFactory(AbstractVectorFactory): class MilvusVectorFactory(AbstractVectorFactory):
"""
Factory class for creating MilvusVector instances.
"""
def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> MilvusVector: def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> MilvusVector:
"""
Initialize a MilvusVector instance for the given dataset.
"""
if dataset.index_struct_dict: if dataset.index_struct_dict:
class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"] class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"]
collection_name = class_prefix collection_name = class_prefix
@ -222,5 +363,6 @@ class MilvusVectorFactory(AbstractVectorFactory):
user=dify_config.MILVUS_USER or "", user=dify_config.MILVUS_USER or "",
password=dify_config.MILVUS_PASSWORD or "", password=dify_config.MILVUS_PASSWORD or "",
database=dify_config.MILVUS_DATABASE or "", database=dify_config.MILVUS_DATABASE or "",
enable_hybrid_search=dify_config.MILVUS_ENABLE_HYBRID_SEARCH or False,
), ),
) )

View File

@ -100,6 +100,8 @@ class MyScaleVector(BaseVector):
return results.row_count > 0 return results.row_count > 0
def delete_by_ids(self, ids: list[str]) -> None: def delete_by_ids(self, ids: list[str]) -> None:
if not ids:
return
self._client.command( self._client.command(
f"DELETE FROM {self._config.database}.{self._collection_name} WHERE id IN {str(tuple(ids))}" f"DELETE FROM {self._config.database}.{self._collection_name} WHERE id IN {str(tuple(ids))}"
) )

View File

@ -134,6 +134,8 @@ class OceanBaseVector(BaseVector):
return bool(cur.rowcount != 0) return bool(cur.rowcount != 0)
def delete_by_ids(self, ids: list[str]) -> None: def delete_by_ids(self, ids: list[str]) -> None:
if not ids:
return
self._client.delete(table_name=self._collection_name, ids=ids) self._client.delete(table_name=self._collection_name, ids=ids)
def get_ids_by_metadata_field(self, key: str, value: str) -> list[str]: def get_ids_by_metadata_field(self, key: str, value: str) -> list[str]:

View File

@ -167,6 +167,8 @@ class OracleVector(BaseVector):
return docs return docs
def delete_by_ids(self, ids: list[str]) -> None: def delete_by_ids(self, ids: list[str]) -> None:
if not ids:
return
with self._get_cursor() as cur: with self._get_cursor() as cur:
cur.execute(f"DELETE FROM {self.table_name} WHERE id IN %s" % (tuple(ids),)) cur.execute(f"DELETE FROM {self.table_name} WHERE id IN %s" % (tuple(ids),))

View File

@ -129,6 +129,11 @@ class PGVector(BaseVector):
return docs return docs
def delete_by_ids(self, ids: list[str]) -> None: def delete_by_ids(self, ids: list[str]) -> None:
# Avoiding crashes caused by performing delete operations on empty lists in certain scenarios
# Scenario 1: extract a document fails, resulting in a table not being created.
# Then clicking the retry button triggers a delete operation on an empty list.
if not ids:
return
with self._get_cursor() as cur: with self._get_cursor() as cur:
cur.execute(f"DELETE FROM {self.table_name} WHERE id IN %s", (tuple(ids),)) cur.execute(f"DELETE FROM {self.table_name} WHERE id IN %s", (tuple(ids),))

View File

@ -140,6 +140,8 @@ class TencentVector(BaseVector):
return False return False
def delete_by_ids(self, ids: list[str]) -> None: def delete_by_ids(self, ids: list[str]) -> None:
if not ids:
return
self._db.collection(self._collection_name).delete(document_ids=ids) self._db.collection(self._collection_name).delete(document_ids=ids)
def delete_by_metadata_field(self, key: str, value: str) -> None: def delete_by_metadata_field(self, key: str, value: str) -> None:

View File

@ -409,27 +409,27 @@ class TidbOnQdrantVectorFactory(AbstractVectorFactory):
db.session.query(TidbAuthBinding).filter(TidbAuthBinding.tenant_id == dataset.tenant_id).one_or_none() db.session.query(TidbAuthBinding).filter(TidbAuthBinding.tenant_id == dataset.tenant_id).one_or_none()
) )
if not tidb_auth_binding: if not tidb_auth_binding:
idle_tidb_auth_binding = ( with redis_client.lock("create_tidb_serverless_cluster_lock", timeout=900):
db.session.query(TidbAuthBinding) tidb_auth_binding = (
.filter(TidbAuthBinding.active == False, TidbAuthBinding.status == "ACTIVE") db.session.query(TidbAuthBinding)
.limit(1) .filter(TidbAuthBinding.tenant_id == dataset.tenant_id)
.one_or_none() .one_or_none()
) )
if idle_tidb_auth_binding: if tidb_auth_binding:
idle_tidb_auth_binding.active = True TIDB_ON_QDRANT_API_KEY = f"{tidb_auth_binding.account}:{tidb_auth_binding.password}"
idle_tidb_auth_binding.tenant_id = dataset.tenant_id
db.session.commit() else:
TIDB_ON_QDRANT_API_KEY = f"{idle_tidb_auth_binding.account}:{idle_tidb_auth_binding.password}" idle_tidb_auth_binding = (
else:
with redis_client.lock("create_tidb_serverless_cluster_lock", timeout=900):
tidb_auth_binding = (
db.session.query(TidbAuthBinding) db.session.query(TidbAuthBinding)
.filter(TidbAuthBinding.tenant_id == dataset.tenant_id) .filter(TidbAuthBinding.active == False, TidbAuthBinding.status == "ACTIVE")
.limit(1)
.one_or_none() .one_or_none()
) )
if tidb_auth_binding: if idle_tidb_auth_binding:
TIDB_ON_QDRANT_API_KEY = f"{tidb_auth_binding.account}:{tidb_auth_binding.password}" idle_tidb_auth_binding.active = True
idle_tidb_auth_binding.tenant_id = dataset.tenant_id
db.session.commit()
TIDB_ON_QDRANT_API_KEY = f"{idle_tidb_auth_binding.account}:{idle_tidb_auth_binding.password}"
else: else:
new_cluster = TidbService.create_tidb_serverless_cluster( new_cluster = TidbService.create_tidb_serverless_cluster(
dify_config.TIDB_PROJECT_ID or "", dify_config.TIDB_PROJECT_ID or "",
@ -451,7 +451,6 @@ class TidbOnQdrantVectorFactory(AbstractVectorFactory):
db.session.add(new_tidb_auth_binding) db.session.add(new_tidb_auth_binding)
db.session.commit() db.session.commit()
TIDB_ON_QDRANT_API_KEY = f"{new_tidb_auth_binding.account}:{new_tidb_auth_binding.password}" TIDB_ON_QDRANT_API_KEY = f"{new_tidb_auth_binding.account}:{new_tidb_auth_binding.password}"
else: else:
TIDB_ON_QDRANT_API_KEY = f"{tidb_auth_binding.account}:{tidb_auth_binding.password}" TIDB_ON_QDRANT_API_KEY = f"{tidb_auth_binding.account}:{tidb_auth_binding.password}"

View File

@ -90,6 +90,12 @@ class Vector:
from core.rag.datasource.vdb.elasticsearch.elasticsearch_vector import ElasticSearchVectorFactory from core.rag.datasource.vdb.elasticsearch.elasticsearch_vector import ElasticSearchVectorFactory
return ElasticSearchVectorFactory return ElasticSearchVectorFactory
case VectorType.ELASTICSEARCH_JA:
from core.rag.datasource.vdb.elasticsearch.elasticsearch_ja_vector import (
ElasticSearchJaVectorFactory,
)
return ElasticSearchJaVectorFactory
case VectorType.TIDB_VECTOR: case VectorType.TIDB_VECTOR:
from core.rag.datasource.vdb.tidb_vector.tidb_vector import TiDBVectorFactory from core.rag.datasource.vdb.tidb_vector.tidb_vector import TiDBVectorFactory

View File

@ -16,6 +16,7 @@ class VectorType(StrEnum):
TENCENT = "tencent" TENCENT = "tencent"
ORACLE = "oracle" ORACLE = "oracle"
ELASTICSEARCH = "elasticsearch" ELASTICSEARCH = "elasticsearch"
ELASTICSEARCH_JA = "elasticsearch-ja"
LINDORM = "lindorm" LINDORM = "lindorm"
COUCHBASE = "couchbase" COUCHBASE = "couchbase"
BAIDU = "baidu" BAIDU = "baidu"

View File

@ -23,7 +23,6 @@ class PdfExtractor(BaseExtractor):
self._file_cache_key = file_cache_key self._file_cache_key = file_cache_key
def extract(self) -> list[Document]: def extract(self) -> list[Document]:
plaintext_file_key = ""
plaintext_file_exists = False plaintext_file_exists = False
if self._file_cache_key: if self._file_cache_key:
try: try:
@ -39,8 +38,8 @@ class PdfExtractor(BaseExtractor):
text = "\n\n".join(text_list) text = "\n\n".join(text_list)
# save plaintext file for caching # save plaintext file for caching
if not plaintext_file_exists and plaintext_file_key: if not plaintext_file_exists and self._file_cache_key:
storage.save(plaintext_file_key, text.encode("utf-8")) storage.save(self._file_cache_key, text.encode("utf-8"))
return documents return documents

View File

@ -3,6 +3,7 @@
import uuid import uuid
from typing import Optional from typing import Optional
from configs import dify_config
from core.model_manager import ModelInstance from core.model_manager import ModelInstance
from core.rag.cleaner.clean_processor import CleanProcessor from core.rag.cleaner.clean_processor import CleanProcessor
from core.rag.datasource.retrieval_service import RetrievalService from core.rag.datasource.retrieval_service import RetrievalService
@ -80,6 +81,10 @@ class ParentChildIndexProcessor(BaseIndexProcessor):
child_nodes = self._split_child_nodes( child_nodes = self._split_child_nodes(
document, rules, process_rule.get("mode"), kwargs.get("embedding_model_instance") document, rules, process_rule.get("mode"), kwargs.get("embedding_model_instance")
) )
if kwargs.get("preview"):
if len(child_nodes) > dify_config.CHILD_CHUNKS_PREVIEW_NUMBER:
child_nodes = child_nodes[: dify_config.CHILD_CHUNKS_PREVIEW_NUMBER]
document.children = child_nodes document.children = child_nodes
doc_id = str(uuid.uuid4()) doc_id = str(uuid.uuid4())
hash = helper.generate_text_hash(document.page_content) hash = helper.generate_text_hash(document.page_content)

View File

@ -54,7 +54,12 @@ class ASRTool(BuiltinTool):
items.append((provider, model.model)) items.append((provider, model.model))
return items return items
def get_runtime_parameters(self) -> list[ToolParameter]: def get_runtime_parameters(
self,
conversation_id: Optional[str] = None,
app_id: Optional[str] = None,
message_id: Optional[str] = None,
) -> list[ToolParameter]:
parameters = [] parameters = []
options = [] options = []

View File

@ -62,7 +62,12 @@ class TTSTool(BuiltinTool):
items.append((provider, model.model, voices)) items.append((provider, model.model, voices))
return items return items
def get_runtime_parameters(self) -> list[ToolParameter]: def get_runtime_parameters(
self,
conversation_id: Optional[str] = None,
app_id: Optional[str] = None,
message_id: Optional[str] = None,
) -> list[ToolParameter]:
parameters = [] parameters = []
options = [] options = []

View File

@ -212,8 +212,23 @@ class ApiTool(Tool):
else: else:
body = body body = body
if method in {"get", "head", "post", "put", "delete", "patch"}: if method in {
response: httpx.Response = getattr(ssrf_proxy, method)( "get",
"head",
"post",
"put",
"delete",
"patch",
"options",
"GET",
"POST",
"PUT",
"PATCH",
"DELETE",
"HEAD",
"OPTIONS",
}:
response: httpx.Response = getattr(ssrf_proxy, method.lower())(
url, url,
params=params, params=params,
headers=headers, headers=headers,

View File

@ -147,7 +147,7 @@ class ToolInvokeMessage(BaseModel):
@field_validator("variable_name", mode="before") @field_validator("variable_name", mode="before")
@classmethod @classmethod
def transform_variable_name(cls, value) -> str: def transform_variable_name(cls, value: str) -> str:
""" """
The variable name must be a string. The variable name must be a string.
""" """
@ -167,6 +167,7 @@ class ToolInvokeMessage(BaseModel):
error: Optional[str] = Field(default=None, description="The error message") error: Optional[str] = Field(default=None, description="The error message")
status: LogStatus = Field(..., description="The status of the log") status: LogStatus = Field(..., description="The status of the log")
data: Mapping[str, Any] = Field(..., description="Detailed log data") data: Mapping[str, Any] = Field(..., description="Detailed log data")
metadata: Optional[Mapping[str, Any]] = Field(default=None, description="The metadata of the log")
class MessageType(Enum): class MessageType(Enum):
TEXT = "text" TEXT = "text"

View File

@ -203,6 +203,7 @@ class AgentLogEvent(BaseAgentEvent):
error: str | None = Field(..., description="error") error: str | None = Field(..., description="error")
status: str = Field(..., description="status") status: str = Field(..., description="status")
data: Mapping[str, Any] = Field(..., description="data") data: Mapping[str, Any] = Field(..., description="data")
metadata: Optional[Mapping[str, Any]] = Field(default=None, description="metadata")
InNodeEvent = BaseNodeEvent | BaseParallelBranchEvent | BaseIterationEvent | BaseAgentEvent InNodeEvent = BaseNodeEvent | BaseParallelBranchEvent | BaseIterationEvent | BaseAgentEvent

View File

@ -89,7 +89,11 @@ class AgentNode(ToolNode):
try: try:
# convert tool messages # convert tool messages
yield from self._transform_message(message_stream, {}, parameters_for_log) yield from self._transform_message(
message_stream,
{"provider": (cast(AgentNodeData, self.node_data)).agent_strategy_provider_name},
parameters_for_log,
)
except PluginDaemonClientSideError as e: except PluginDaemonClientSideError as e:
yield RunCompletedEvent( yield RunCompletedEvent(
run_result=NodeRunResult( run_result=NodeRunResult(
@ -170,7 +174,12 @@ class AgentNode(ToolNode):
extra.get("descrption", "") or tool_runtime.entity.description.llm extra.get("descrption", "") or tool_runtime.entity.description.llm
) )
tool_value.append(tool_runtime.entity.model_dump(mode="json")) tool_value.append(
{
**tool_runtime.entity.model_dump(mode="json"),
"runtime_parameters": tool_runtime.runtime.runtime_parameters,
}
)
value = tool_value value = tool_value
if parameter.type == "model-selector": if parameter.type == "model-selector":
value = cast(dict[str, Any], value) value = cast(dict[str, Any], value)

View File

@ -2,14 +2,18 @@ import csv
import io import io
import json import json
import logging import logging
import operator
import os import os
import tempfile import tempfile
from typing import cast from collections.abc import Mapping, Sequence
from typing import Any, cast
import docx import docx
import pandas as pd import pandas as pd
import pypdfium2 # type: ignore import pypdfium2 # type: ignore
import yaml # type: ignore import yaml # type: ignore
from docx.table import Table
from docx.text.paragraph import Paragraph
from configs import dify_config from configs import dify_config
from core.file import File, FileTransferMethod, file_manager from core.file import File, FileTransferMethod, file_manager
@ -78,6 +82,23 @@ class DocumentExtractorNode(BaseNode[DocumentExtractorNodeData]):
process_data=process_data, process_data=process_data,
) )
@classmethod
def _extract_variable_selector_to_variable_mapping(
cls,
*,
graph_config: Mapping[str, Any],
node_id: str,
node_data: DocumentExtractorNodeData,
) -> Mapping[str, Sequence[str]]:
"""
Extract variable selector to variable mapping
:param graph_config: graph config
:param node_id: node id
:param node_data: node data
:return:
"""
return {node_id + ".files": node_data.variable_selector}
def _extract_text_by_mime_type(*, file_content: bytes, mime_type: str) -> str: def _extract_text_by_mime_type(*, file_content: bytes, mime_type: str) -> str:
"""Extract text from a file based on its MIME type.""" """Extract text from a file based on its MIME type."""
@ -189,35 +210,56 @@ def _extract_text_from_doc(file_content: bytes) -> str:
doc_file = io.BytesIO(file_content) doc_file = io.BytesIO(file_content)
doc = docx.Document(doc_file) doc = docx.Document(doc_file)
text = [] text = []
# Process paragraphs
for paragraph in doc.paragraphs:
if paragraph.text.strip():
text.append(paragraph.text)
# Process tables # Keep track of paragraph and table positions
for table in doc.tables: content_items: list[tuple[int, str, Table | Paragraph]] = []
# Table header
try: # Process paragraphs and tables
# table maybe cause errors so ignore it. for i, paragraph in enumerate(doc.paragraphs):
if len(table.rows) > 0 and table.rows[0].cells is not None: if paragraph.text.strip():
content_items.append((i, "paragraph", paragraph))
for i, table in enumerate(doc.tables):
content_items.append((i, "table", table))
# Sort content items based on their original position
content_items.sort(key=operator.itemgetter(0))
# Process sorted content
for _, item_type, item in content_items:
if item_type == "paragraph":
if isinstance(item, Table):
continue
text.append(item.text)
elif item_type == "table":
# Process tables
if not isinstance(item, Table):
continue
try:
# Check if any cell in the table has text # Check if any cell in the table has text
has_content = False has_content = False
for row in table.rows: for row in item.rows:
if any(cell.text.strip() for cell in row.cells): if any(cell.text.strip() for cell in row.cells):
has_content = True has_content = True
break break
if has_content: if has_content:
markdown_table = "| " + " | ".join(cell.text for cell in table.rows[0].cells) + " |\n" cell_texts = [cell.text.replace("\n", "<br>") for cell in item.rows[0].cells]
markdown_table += "| " + " | ".join(["---"] * len(table.rows[0].cells)) + " |\n" markdown_table = f"| {' | '.join(cell_texts)} |\n"
for row in table.rows[1:]: markdown_table += f"| {' | '.join(['---'] * len(item.rows[0].cells))} |\n"
markdown_table += "| " + " | ".join(cell.text for cell in row.cells) + " |\n"
for row in item.rows[1:]:
# Replace newlines with <br> in each cell
row_cells = [cell.text.replace("\n", "<br>") for cell in row.cells]
markdown_table += "| " + " | ".join(row_cells) + " |\n"
text.append(markdown_table) text.append(markdown_table)
except Exception as e: except Exception as e:
logger.warning(f"Failed to extract table from DOC/DOCX: {e}") logger.warning(f"Failed to extract table from DOC/DOCX: {e}")
continue continue
return "\n".join(text) return "\n".join(text)
except Exception as e: except Exception as e:
raise TextExtractionError(f"Failed to extract text from DOC/DOCX: {str(e)}") from e raise TextExtractionError(f"Failed to extract text from DOC/DOCX: {str(e)}") from e

View File

@ -68,7 +68,22 @@ class HttpRequestNodeData(BaseNodeData):
Code Node Data. Code Node Data.
""" """
method: Literal["get", "post", "put", "patch", "delete", "head"] method: Literal[
"get",
"post",
"put",
"patch",
"delete",
"head",
"options",
"GET",
"POST",
"PUT",
"PATCH",
"DELETE",
"HEAD",
"OPTIONS",
]
url: str url: str
authorization: HttpRequestNodeAuthorization authorization: HttpRequestNodeAuthorization
headers: str headers: str

View File

@ -37,7 +37,22 @@ BODY_TYPE_TO_CONTENT_TYPE = {
class Executor: class Executor:
method: Literal["get", "head", "post", "put", "delete", "patch"] method: Literal[
"get",
"head",
"post",
"put",
"delete",
"patch",
"options",
"GET",
"POST",
"PUT",
"PATCH",
"DELETE",
"HEAD",
"OPTIONS",
]
url: str url: str
params: list[tuple[str, str]] | None params: list[tuple[str, str]] | None
content: str | bytes | None content: str | bytes | None
@ -67,12 +82,6 @@ class Executor:
node_data.authorization.config.api_key node_data.authorization.config.api_key
).text ).text
# check if node_data.url is a valid URL
if not node_data.url:
raise InvalidURLError("url is required")
if not node_data.url.startswith(("http://", "https://")):
raise InvalidURLError("url should start with http:// or https://")
self.url: str = node_data.url self.url: str = node_data.url
self.method = node_data.method self.method = node_data.method
self.auth = node_data.authorization self.auth = node_data.authorization
@ -99,6 +108,12 @@ class Executor:
def _init_url(self): def _init_url(self):
self.url = self.variable_pool.convert_template(self.node_data.url).text self.url = self.variable_pool.convert_template(self.node_data.url).text
# check if url is a valid URL
if not self.url:
raise InvalidURLError("url is required")
if not self.url.startswith(("http://", "https://")):
raise InvalidURLError("url should start with http:// or https://")
def _init_params(self): def _init_params(self):
""" """
Almost same as _init_headers(), difference: Almost same as _init_headers(), difference:
@ -158,7 +173,10 @@ class Executor:
if len(data) != 1: if len(data) != 1:
raise RequestBodyError("json body type should have exactly one item") raise RequestBodyError("json body type should have exactly one item")
json_string = self.variable_pool.convert_template(data[0].value).text json_string = self.variable_pool.convert_template(data[0].value).text
json_object = json.loads(json_string, strict=False) try:
json_object = json.loads(json_string, strict=False)
except json.JSONDecodeError as e:
raise RequestBodyError(f"Failed to parse JSON: {json_string}") from e
self.json = json_object self.json = json_object
# self.json = self._parse_object_contains_variables(json_object) # self.json = self._parse_object_contains_variables(json_object)
case "binary": case "binary":
@ -246,7 +264,22 @@ class Executor:
""" """
do http request depending on api bundle do http request depending on api bundle
""" """
if self.method not in {"get", "head", "post", "put", "delete", "patch"}: if self.method not in {
"get",
"head",
"post",
"put",
"delete",
"patch",
"options",
"GET",
"POST",
"PUT",
"PATCH",
"DELETE",
"HEAD",
"OPTIONS",
}:
raise InvalidHttpMethodError(f"Invalid http method {self.method}") raise InvalidHttpMethodError(f"Invalid http method {self.method}")
request_args = { request_args = {
@ -263,7 +296,7 @@ class Executor:
} }
# request_args = {k: v for k, v in request_args.items() if v is not None} # request_args = {k: v for k, v in request_args.items() if v is not None}
try: try:
response = getattr(ssrf_proxy, self.method)(**request_args) response = getattr(ssrf_proxy, self.method.lower())(**request_args)
except (ssrf_proxy.MaxRetriesExceededError, httpx.RequestError) as e: except (ssrf_proxy.MaxRetriesExceededError, httpx.RequestError) as e:
raise HttpRequestNodeError(str(e)) raise HttpRequestNodeError(str(e))
# FIXME: fix type ignore, this maybe httpx type issue # FIXME: fix type ignore, this maybe httpx type issue

View File

@ -1,5 +1,5 @@
from collections.abc import Generator, Mapping, Sequence from collections.abc import Generator, Mapping, Sequence
from typing import Any, cast from typing import Any, Optional, cast
from sqlalchemy import select from sqlalchemy import select
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
@ -197,6 +197,7 @@ class ToolNode(BaseNode[ToolNodeData]):
json: list[dict] = [] json: list[dict] = []
agent_logs: list[AgentLogEvent] = [] agent_logs: list[AgentLogEvent] = []
agent_execution_metadata: Optional[Mapping[NodeRunMetadataKey, Any]] = {}
variables: dict[str, Any] = {} variables: dict[str, Any] = {}
@ -264,6 +265,11 @@ class ToolNode(BaseNode[ToolNodeData]):
) )
elif message.type == ToolInvokeMessage.MessageType.JSON: elif message.type == ToolInvokeMessage.MessageType.JSON:
assert isinstance(message.message, ToolInvokeMessage.JsonMessage) assert isinstance(message.message, ToolInvokeMessage.JsonMessage)
if self.node_type == NodeType.AGENT:
msg_metadata = message.message.json_object.pop("execution_metadata", {})
agent_execution_metadata = {
key: value for key, value in msg_metadata.items() if key in NodeRunMetadataKey
}
json.append(message.message.json_object) json.append(message.message.json_object)
elif message.type == ToolInvokeMessage.MessageType.LINK: elif message.type == ToolInvokeMessage.MessageType.LINK:
assert isinstance(message.message, ToolInvokeMessage.TextMessage) assert isinstance(message.message, ToolInvokeMessage.TextMessage)
@ -299,6 +305,7 @@ class ToolNode(BaseNode[ToolNodeData]):
status=message.message.status.value, status=message.message.status.value,
data=message.message.data, data=message.message.data,
label=message.message.label, label=message.message.label,
metadata=message.message.metadata,
) )
# check if the agent log is already in the list # check if the agent log is already in the list
@ -309,6 +316,7 @@ class ToolNode(BaseNode[ToolNodeData]):
log.status = agent_log.status log.status = agent_log.status
log.error = agent_log.error log.error = agent_log.error
log.label = agent_log.label log.label = agent_log.label
log.metadata = agent_log.metadata
break break
else: else:
agent_logs.append(agent_log) agent_logs.append(agent_log)
@ -319,7 +327,11 @@ class ToolNode(BaseNode[ToolNodeData]):
run_result=NodeRunResult( run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED, status=WorkflowNodeExecutionStatus.SUCCEEDED,
outputs={"text": text, "files": files, "json": json, **variables}, outputs={"text": text, "files": files, "json": json, **variables},
metadata={NodeRunMetadataKey.TOOL_INFO: tool_info, NodeRunMetadataKey.AGENT_LOG: agent_logs}, metadata={
**agent_execution_metadata,
NodeRunMetadataKey.TOOL_INFO: tool_info,
NodeRunMetadataKey.AGENT_LOG: agent_logs,
},
inputs=parameters_for_log, inputs=parameters_for_log,
) )
) )

View File

@ -340,6 +340,10 @@ class WorkflowEntry:
): ):
raise ValueError(f"Variable key {node_variable} not found in user inputs.") raise ValueError(f"Variable key {node_variable} not found in user inputs.")
# environment variable already exist in variable pool, not from user inputs
if variable_pool.get(variable_selector):
continue
# fetch variable node id from variable selector # fetch variable node id from variable selector
variable_node_id = variable_selector[0] variable_node_id = variable_selector[0]
variable_key_list = variable_selector[1:] variable_key_list = variable_selector[1:]

View File

@ -33,6 +33,7 @@ else
--bind "${DIFY_BIND_ADDRESS:-0.0.0.0}:${DIFY_PORT:-5001}" \ --bind "${DIFY_BIND_ADDRESS:-0.0.0.0}:${DIFY_PORT:-5001}" \
--workers ${SERVER_WORKER_AMOUNT:-1} \ --workers ${SERVER_WORKER_AMOUNT:-1} \
--worker-class ${SERVER_WORKER_CLASS:-gevent} \ --worker-class ${SERVER_WORKER_CLASS:-gevent} \
--worker-connections ${SERVER_WORKER_CONNECTIONS:-10} \
--timeout ${GUNICORN_TIMEOUT:-200} \ --timeout ${GUNICORN_TIMEOUT:-200} \
app:app app:app
fi fi

View File

@ -46,7 +46,7 @@ def init_app(app: DifyApp):
timezone = pytz.timezone(log_tz) timezone = pytz.timezone(log_tz)
def time_converter(seconds): def time_converter(seconds):
return datetime.utcfromtimestamp(seconds).astimezone(timezone).timetuple() return datetime.fromtimestamp(seconds, tz=timezone).timetuple()
for handler in logging.root.handlers: for handler in logging.root.handlers:
if handler.formatter: if handler.formatter:

View File

@ -158,7 +158,7 @@ def _build_from_remote_url(
tenant_id: str, tenant_id: str,
transfer_method: FileTransferMethod, transfer_method: FileTransferMethod,
) -> File: ) -> File:
url = mapping.get("url") url = mapping.get("url") or mapping.get("remote_url")
if not url: if not url:
raise ValueError("Invalid file url") raise ValueError("Invalid file url")

View File

@ -6,7 +6,7 @@ import string
import subprocess import subprocess
import time import time
import uuid import uuid
from collections.abc import Generator from collections.abc import Generator, Mapping
from datetime import datetime from datetime import datetime
from hashlib import sha256 from hashlib import sha256
from typing import TYPE_CHECKING, Any, Optional, Union, cast from typing import TYPE_CHECKING, Any, Optional, Union, cast
@ -182,7 +182,7 @@ def generate_text_hash(text: str) -> str:
return sha256(hash_text.encode()).hexdigest() return sha256(hash_text.encode()).hexdigest()
def compact_generate_response(response: Union[dict, Generator, RateLimitGenerator]) -> Response: def compact_generate_response(response: Union[Mapping, Generator, RateLimitGenerator]) -> Response:
if isinstance(response, dict): if isinstance(response, dict):
return Response(response=json.dumps(response), status=200, mimetype="application/json") return Response(response=json.dumps(response), status=200, mimetype="application/json")
else: else:

View File

@ -255,7 +255,8 @@ class NotionOAuth(OAuthDataSource):
response = requests.get(url=f"{self._NOTION_BLOCK_SEARCH}/{block_id}", headers=headers) response = requests.get(url=f"{self._NOTION_BLOCK_SEARCH}/{block_id}", headers=headers)
response_json = response.json() response_json = response.json()
if response.status_code != 200: if response.status_code != 200:
raise ValueError(f"Error fetching block parent page ID: {response_json.message}") message = response_json.get("message", "unknown error")
raise ValueError(f"Error fetching block parent page ID: {message}")
parent = response_json["parent"] parent = response_json["parent"]
parent_type = parent["type"] parent_type = parent["type"]
if parent_type == "block_id": if parent_type == "block_id":

View File

@ -0,0 +1,41 @@
"""change workflow_runs.total_tokens to bigint
Revision ID: a91b476a53de
Revises: 923752d42eb6
Create Date: 2025-01-01 20:00:01.207369
"""
from alembic import op
import models as models
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = 'a91b476a53de'
down_revision = '923752d42eb6'
branch_labels = None
depends_on = None
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('workflow_runs', schema=None) as batch_op:
batch_op.alter_column('total_tokens',
existing_type=sa.INTEGER(),
type_=sa.BigInteger(),
existing_nullable=False,
existing_server_default=sa.text('0'))
# ### end Alembic commands ###
def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('workflow_runs', schema=None) as batch_op:
batch_op.alter_column('total_tokens',
existing_type=sa.BigInteger(),
type_=sa.INTEGER(),
existing_nullable=False,
existing_server_default=sa.text('0'))
# ### end Alembic commands ###

View File

@ -415,8 +415,8 @@ class WorkflowRun(Base):
status: Mapped[str] = mapped_column(db.String(255)) # running, succeeded, failed, stopped, partial-succeeded status: Mapped[str] = mapped_column(db.String(255)) # running, succeeded, failed, stopped, partial-succeeded
outputs: Mapped[Optional[str]] = mapped_column(sa.Text, default="{}") outputs: Mapped[Optional[str]] = mapped_column(sa.Text, default="{}")
error: Mapped[Optional[str]] = mapped_column(db.Text) error: Mapped[Optional[str]] = mapped_column(db.Text)
elapsed_time = db.Column(db.Float, nullable=False, server_default=db.text("0")) elapsed_time = db.Column(db.Float, nullable=False, server_default=sa.text("0"))
total_tokens: Mapped[int] = mapped_column(server_default=db.text("0")) total_tokens: Mapped[int] = mapped_column(sa.BigInteger, server_default=sa.text("0"))
total_steps = db.Column(db.Integer, server_default=db.text("0")) total_steps = db.Column(db.Integer, server_default=db.text("0"))
created_by_role: Mapped[str] = mapped_column(db.String(255)) # account, end_user created_by_role: Mapped[str] = mapped_column(db.String(255)) # account, end_user
created_by = db.Column(StringUUID, nullable=False) created_by = db.Column(StringUUID, nullable=False)

2397
api/poetry.lock generated

File diff suppressed because it is too large Load Diff

View File

@ -71,7 +71,7 @@ pyjwt = "~2.8.0"
pypdfium2 = "~4.30.0" pypdfium2 = "~4.30.0"
python = ">=3.11,<3.13" python = ">=3.11,<3.13"
python-docx = "~1.1.0" python-docx = "~1.1.0"
python-dotenv = "1.0.0" python-dotenv = "1.0.1"
pyyaml = "~6.0.1" pyyaml = "~6.0.1"
readabilipy = "0.2.0" readabilipy = "0.2.0"
redis = { version = "~5.0.3", extras = ["hiredis"] } redis = { version = "~5.0.3", extras = ["hiredis"] }
@ -82,7 +82,7 @@ scikit-learn = "~1.5.1"
sentry-sdk = { version = "~1.44.1", extras = ["flask"] } sentry-sdk = { version = "~1.44.1", extras = ["flask"] }
sqlalchemy = "~2.0.29" sqlalchemy = "~2.0.29"
starlette = "0.41.0" starlette = "0.41.0"
tencentcloud-sdk-python-hunyuan = "~3.0.1158" tencentcloud-sdk-python-hunyuan = "~3.0.1294"
tiktoken = "~0.8.0" tiktoken = "~0.8.0"
tokenizers = "~0.15.0" tokenizers = "~0.15.0"
transformers = "~4.35.0" transformers = "~4.35.0"
@ -92,7 +92,7 @@ validators = "0.21.0"
volcengine-python-sdk = {extras = ["ark"], version = "~1.0.98"} volcengine-python-sdk = {extras = ["ark"], version = "~1.0.98"}
websocket-client = "~1.7.0" websocket-client = "~1.7.0"
xinference-client = "0.15.2" xinference-client = "0.15.2"
yarl = "~1.9.4" yarl = "~1.18.3"
youtube-transcript-api = "~0.6.2" youtube-transcript-api = "~0.6.2"
zhipuai = "~2.1.5" zhipuai = "~2.1.5"
# Before adding new dependency, consider place it in alphabet order (a-z) and suitable group. # Before adding new dependency, consider place it in alphabet order (a-z) and suitable group.
@ -157,7 +157,7 @@ opensearch-py = "2.4.0"
oracledb = "~2.2.1" oracledb = "~2.2.1"
pgvecto-rs = { version = "~0.2.1", extras = ['sqlalchemy'] } pgvecto-rs = { version = "~0.2.1", extras = ['sqlalchemy'] }
pgvector = "0.2.5" pgvector = "0.2.5"
pymilvus = "~2.4.4" pymilvus = "~2.5.0"
pymochow = "1.3.1" pymochow = "1.3.1"
pyobvector = "~0.1.6" pyobvector = "~0.1.6"
qdrant-client = "1.7.3" qdrant-client = "1.7.3"

View File

@ -168,23 +168,6 @@ def clean_unused_datasets_task():
else: else:
plan = plan_cache.decode() plan = plan_cache.decode()
if plan == "sandbox": if plan == "sandbox":
# add auto disable log
documents = (
db.session.query(Document)
.filter(
Document.dataset_id == dataset.id,
Document.enabled == True,
Document.archived == False,
)
.all()
)
for document in documents:
dataset_auto_disable_log = DatasetAutoDisableLog(
tenant_id=dataset.tenant_id,
dataset_id=dataset.id,
document_id=document.id,
)
db.session.add(dataset_auto_disable_log)
# remove index # remove index
index_processor = IndexProcessorFactory(dataset.doc_form).init_index_processor() index_processor = IndexProcessorFactory(dataset.doc_form).init_index_processor()
index_processor.clean(dataset, None) index_processor.clean(dataset, None)

View File

@ -67,7 +67,7 @@ class TokenPair(BaseModel):
REFRESH_TOKEN_PREFIX = "refresh_token:" REFRESH_TOKEN_PREFIX = "refresh_token:"
ACCOUNT_REFRESH_TOKEN_PREFIX = "account_refresh_token:" ACCOUNT_REFRESH_TOKEN_PREFIX = "account_refresh_token:"
REFRESH_TOKEN_EXPIRY = timedelta(days=30) REFRESH_TOKEN_EXPIRY = timedelta(days=dify_config.REFRESH_TOKEN_EXPIRE_DAYS)
class AccountService: class AccountService:
@ -921,6 +921,9 @@ class RegisterService:
def invite_new_member( def invite_new_member(
cls, tenant: Tenant, email: str, language: str, role: str = "normal", inviter: Account | None = None cls, tenant: Tenant, email: str, language: str, role: str = "normal", inviter: Account | None = None
) -> str: ) -> str:
if not inviter:
raise ValueError("Inviter is required")
"""Invite new member""" """Invite new member"""
with Session(db.engine) as session: with Session(db.engine) as session:
account = session.query(Account).filter_by(email=email).first() account = session.query(Account).filter_by(email=email).first()

View File

@ -2,6 +2,7 @@ import logging
import uuid import uuid
from enum import StrEnum from enum import StrEnum
from typing import Optional, cast from typing import Optional, cast
from urllib.parse import urlparse
from uuid import uuid4 from uuid import uuid4
import yaml # type: ignore import yaml # type: ignore
@ -124,7 +125,7 @@ class AppDslService:
raise ValueError(f"Invalid import_mode: {import_mode}") raise ValueError(f"Invalid import_mode: {import_mode}")
# Get YAML content # Get YAML content
content: bytes | str = b"" content: str = ""
if mode == ImportMode.YAML_URL: if mode == ImportMode.YAML_URL:
if not yaml_url: if not yaml_url:
return Import( return Import(
@ -133,13 +134,17 @@ class AppDslService:
error="yaml_url is required when import_mode is yaml-url", error="yaml_url is required when import_mode is yaml-url",
) )
try: try:
# tricky way to handle url from github to github raw url parsed_url = urlparse(yaml_url)
if yaml_url.startswith("https://github.com") and yaml_url.endswith((".yml", ".yaml")): if (
parsed_url.scheme == "https"
and parsed_url.netloc == "github.com"
and parsed_url.path.endswith((".yml", ".yaml"))
):
yaml_url = yaml_url.replace("https://github.com", "https://raw.githubusercontent.com") yaml_url = yaml_url.replace("https://github.com", "https://raw.githubusercontent.com")
yaml_url = yaml_url.replace("/blob/", "/") yaml_url = yaml_url.replace("/blob/", "/")
response = ssrf_proxy.get(yaml_url.strip(), follow_redirects=True, timeout=(10, 10)) response = ssrf_proxy.get(yaml_url.strip(), follow_redirects=True, timeout=(10, 10))
response.raise_for_status() response.raise_for_status()
content = response.content content = response.content.decode()
if len(content) > DSL_MAX_SIZE: if len(content) > DSL_MAX_SIZE:
return Import( return Import(

View File

@ -26,9 +26,10 @@ from tasks.remove_app_and_related_data_task import remove_app_and_related_data_t
class AppService: class AppService:
def get_paginate_apps(self, tenant_id: str, args: dict) -> Pagination | None: def get_paginate_apps(self, user_id: str, tenant_id: str, args: dict) -> Pagination | None:
""" """
Get app list with pagination Get app list with pagination
:param user_id: user id
:param tenant_id: tenant id :param tenant_id: tenant id
:param args: request args :param args: request args
:return: :return:
@ -44,6 +45,8 @@ class AppService:
elif args["mode"] == "channel": elif args["mode"] == "channel":
filters.append(App.mode == AppMode.CHANNEL.value) filters.append(App.mode == AppMode.CHANNEL.value)
if args.get("is_created_by_me", False):
filters.append(App.created_by == user_id)
if args.get("name"): if args.get("name"):
name = args["name"][:30] name = args["name"][:30]
filters.append(App.name.ilike(f"%{name}%")) filters.append(App.name.ilike(f"%{name}%"))

View File

@ -1,5 +1,5 @@
import os import os
from typing import Optional from typing import Literal, Optional
import httpx import httpx
from tenacity import retry, retry_if_exception_type, stop_before_delay, wait_fixed from tenacity import retry, retry_if_exception_type, stop_before_delay, wait_fixed
@ -17,7 +17,6 @@ class BillingService:
params = {"tenant_id": tenant_id} params = {"tenant_id": tenant_id}
billing_info = cls._send_request("GET", "/subscription/info", params=params) billing_info = cls._send_request("GET", "/subscription/info", params=params)
return billing_info return billing_info
@classmethod @classmethod
@ -47,12 +46,13 @@ class BillingService:
retry=retry_if_exception_type(httpx.RequestError), retry=retry_if_exception_type(httpx.RequestError),
reraise=True, reraise=True,
) )
def _send_request(cls, method, endpoint, json=None, params=None): def _send_request(cls, method: Literal["GET", "POST", "DELETE"], endpoint: str, json=None, params=None):
headers = {"Content-Type": "application/json", "Billing-Api-Secret-Key": cls.secret_key} headers = {"Content-Type": "application/json", "Billing-Api-Secret-Key": cls.secret_key}
url = f"{cls.base_url}{endpoint}" url = f"{cls.base_url}{endpoint}"
response = httpx.request(method, url, json=json, params=params, headers=headers) response = httpx.request(method, url, json=json, params=params, headers=headers)
if method == "GET" and response.status_code != httpx.codes.OK:
raise ValueError("Unable to retrieve billing information. Please try again later or contact support.")
return response.json() return response.json()
@staticmethod @staticmethod

View File

@ -86,7 +86,7 @@ class DatasetService:
else: else:
return [], 0 return [], 0
else: else:
if user.current_role not in (TenantAccountRole.OWNER, TenantAccountRole.ADMIN): if user.current_role != TenantAccountRole.OWNER:
# show all datasets that the user has permission to access # show all datasets that the user has permission to access
if permitted_dataset_ids: if permitted_dataset_ids:
query = query.filter( query = query.filter(
@ -382,7 +382,7 @@ class DatasetService:
if dataset.tenant_id != user.current_tenant_id: if dataset.tenant_id != user.current_tenant_id:
logging.debug(f"User {user.id} does not have permission to access dataset {dataset.id}") logging.debug(f"User {user.id} does not have permission to access dataset {dataset.id}")
raise NoPermissionError("You do not have permission to access this dataset.") raise NoPermissionError("You do not have permission to access this dataset.")
if user.current_role not in (TenantAccountRole.OWNER, TenantAccountRole.ADMIN): if user.current_role != TenantAccountRole.OWNER:
if dataset.permission == DatasetPermissionEnum.ONLY_ME and dataset.created_by != user.id: if dataset.permission == DatasetPermissionEnum.ONLY_ME and dataset.created_by != user.id:
logging.debug(f"User {user.id} does not have permission to access dataset {dataset.id}") logging.debug(f"User {user.id} does not have permission to access dataset {dataset.id}")
raise NoPermissionError("You do not have permission to access this dataset.") raise NoPermissionError("You do not have permission to access this dataset.")
@ -404,7 +404,7 @@ class DatasetService:
if not user: if not user:
raise ValueError("User not found") raise ValueError("User not found")
if user.current_role not in (TenantAccountRole.OWNER, TenantAccountRole.ADMIN): if user.current_role != TenantAccountRole.OWNER:
if dataset.permission == DatasetPermissionEnum.ONLY_ME: if dataset.permission == DatasetPermissionEnum.ONLY_ME:
if dataset.created_by != user.id: if dataset.created_by != user.id:
raise NoPermissionError("You do not have permission to access this dataset.") raise NoPermissionError("You do not have permission to access this dataset.")
@ -434,6 +434,12 @@ class DatasetService:
@staticmethod @staticmethod
def get_dataset_auto_disable_logs(dataset_id: str) -> dict: def get_dataset_auto_disable_logs(dataset_id: str) -> dict:
features = FeatureService.get_features(current_user.current_tenant_id)
if not features.billing.enabled or features.billing.subscription.plan == "sandbox":
return {
"document_ids": [],
"count": 0,
}
# get recent 30 days auto disable logs # get recent 30 days auto disable logs
start_date = datetime.datetime.now() - datetime.timedelta(days=30) start_date = datetime.datetime.now() - datetime.timedelta(days=30)
dataset_auto_disable_logs = DatasetAutoDisableLog.query.filter( dataset_auto_disable_logs = DatasetAutoDisableLog.query.filter(
@ -786,13 +792,19 @@ class DocumentService:
dataset.indexing_technique = knowledge_config.indexing_technique dataset.indexing_technique = knowledge_config.indexing_technique
if knowledge_config.indexing_technique == "high_quality": if knowledge_config.indexing_technique == "high_quality":
model_manager = ModelManager() model_manager = ModelManager()
embedding_model = model_manager.get_default_model_instance( if knowledge_config.embedding_model and knowledge_config.embedding_model_provider:
tenant_id=current_user.current_tenant_id, model_type=ModelType.TEXT_EMBEDDING dataset_embedding_model = knowledge_config.embedding_model
) dataset_embedding_model_provider = knowledge_config.embedding_model_provider
dataset.embedding_model = embedding_model.model else:
dataset.embedding_model_provider = embedding_model.provider embedding_model = model_manager.get_default_model_instance(
tenant_id=current_user.current_tenant_id, model_type=ModelType.TEXT_EMBEDDING
)
dataset_embedding_model = embedding_model.model
dataset_embedding_model_provider = embedding_model.provider
dataset.embedding_model = dataset_embedding_model
dataset.embedding_model_provider = dataset_embedding_model_provider
dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding( dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding(
embedding_model.provider, embedding_model.model dataset_embedding_model_provider, dataset_embedding_model
) )
dataset.collection_binding_id = dataset_collection_binding.id dataset.collection_binding_id = dataset_collection_binding.id
if not dataset.retrieval_model: if not dataset.retrieval_model:
@ -804,7 +816,11 @@ class DocumentService:
"score_threshold_enabled": False, "score_threshold_enabled": False,
} }
dataset.retrieval_model = knowledge_config.retrieval_model.model_dump() or default_retrieval_model # type: ignore dataset.retrieval_model = (
knowledge_config.retrieval_model.model_dump()
if knowledge_config.retrieval_model
else default_retrieval_model
) # type: ignore
documents = [] documents = []
if knowledge_config.original_document_id: if knowledge_config.original_document_id:

View File

@ -27,7 +27,7 @@ class WorkflowAppService:
query = query.join(WorkflowRun, WorkflowRun.id == WorkflowAppLog.workflow_run_id) query = query.join(WorkflowRun, WorkflowRun.id == WorkflowAppLog.workflow_run_id)
if keyword: if keyword:
keyword_like_val = f"%{args['keyword'][:30]}%" keyword_like_val = f"%{keyword[:30].encode('unicode_escape').decode('utf-8')}%".replace(r"\u", r"\\u")
keyword_conditions = [ keyword_conditions = [
WorkflowRun.inputs.ilike(keyword_like_val), WorkflowRun.inputs.ilike(keyword_like_val),
WorkflowRun.outputs.ilike(keyword_like_val), WorkflowRun.outputs.ilike(keyword_like_val),

View File

@ -298,7 +298,7 @@ class WorkflowService:
start_at: float, start_at: float,
tenant_id: str, tenant_id: str,
node_id: str, node_id: str,
): ) -> WorkflowNodeExecution:
""" """
Handle node run result Handle node run result

View File

@ -28,7 +28,7 @@ def deal_dataset_vector_index_task(dataset_id: str, action: str):
if not dataset: if not dataset:
raise Exception("Dataset not found") raise Exception("Dataset not found")
index_type = dataset.doc_form index_type = dataset.doc_form or IndexType.PARAGRAPH_INDEX
index_processor = IndexProcessorFactory(index_type).init_index_processor() index_processor = IndexProcessorFactory(index_type).init_index_processor()
if action == "remove": if action == "remove":
index_processor.clean(dataset, None, with_keywords=False) index_processor.clean(dataset, None, with_keywords=False)
@ -157,6 +157,9 @@ def deal_dataset_vector_index_task(dataset_id: str, action: str):
{"indexing_status": "error", "error": str(e)}, synchronize_session=False {"indexing_status": "error", "error": str(e)}, synchronize_session=False
) )
db.session.commit() db.session.commit()
else:
# clean collection
index_processor.clean(dataset, None, with_keywords=False, delete_child_chunks=False)
end_at = time.perf_counter() end_at = time.perf_counter()
logging.info( logging.info(

View File

@ -0,0 +1,55 @@
import os
from pathlib import Path
import pytest
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.gpustack.speech2text.speech2text import GPUStackSpeech2TextModel
def test_validate_credentials():
model = GPUStackSpeech2TextModel()
with pytest.raises(CredentialsValidateFailedError):
model.validate_credentials(
model="faster-whisper-medium",
credentials={
"endpoint_url": "invalid_url",
"api_key": "invalid_api_key",
},
)
model.validate_credentials(
model="faster-whisper-medium",
credentials={
"endpoint_url": os.environ.get("GPUSTACK_SERVER_URL"),
"api_key": os.environ.get("GPUSTACK_API_KEY"),
},
)
def test_invoke_model():
model = GPUStackSpeech2TextModel()
# Get the directory of the current file
current_dir = os.path.dirname(os.path.abspath(__file__))
# Get assets directory
assets_dir = os.path.join(os.path.dirname(current_dir), "assets")
# Construct the path to the audio file
audio_file_path = os.path.join(assets_dir, "audio.mp3")
file = Path(audio_file_path).read_bytes()
result = model.invoke(
model="faster-whisper-medium",
credentials={
"endpoint_url": os.environ.get("GPUSTACK_SERVER_URL"),
"api_key": os.environ.get("GPUSTACK_API_KEY"),
},
file=file,
)
assert isinstance(result, str)
assert result == "1, 2, 3, 4, 5, 6, 7, 8, 9, 10"

View File

@ -0,0 +1,24 @@
import os
from core.model_runtime.model_providers.gpustack.tts.tts import GPUStackText2SpeechModel
def test_invoke_model():
model = GPUStackText2SpeechModel()
result = model.invoke(
model="cosyvoice-300m-sft",
tenant_id="test",
credentials={
"endpoint_url": os.environ.get("GPUSTACK_SERVER_URL"),
"api_key": os.environ.get("GPUSTACK_API_KEY"),
},
content_text="Hello world",
voice="Chinese Female",
)
content = b""
for chunk in result:
content += chunk
assert content != b""

View File

@ -19,9 +19,9 @@ class MilvusVectorTest(AbstractVectorTest):
) )
def search_by_full_text(self): def search_by_full_text(self):
# milvus dos not support full text searching yet in < 2.3.x # milvus support BM25 full text search after version 2.5.0-beta
hits_by_full_text = self.vector.search_by_full_text(query=get_example_text()) hits_by_full_text = self.vector.search_by_full_text(query=get_example_text())
assert len(hits_by_full_text) == 0 assert len(hits_by_full_text) >= 0
def get_ids_by_metadata_field(self): def get_ids_by_metadata_field(self):
ids = self.vector.get_ids_by_metadata_field(key="document_id", value=self.example_doc_id) ids = self.vector.get_ids_by_metadata_field(key="document_id", value=self.example_doc_id)

View File

@ -2,7 +2,7 @@ version: '3'
services: services:
# API service # API service
api: api:
image: langgenius/dify-api:0.14.2 image: langgenius/dify-api:0.15.0
restart: always restart: always
environment: environment:
# Startup mode, 'api' starts the API server. # Startup mode, 'api' starts the API server.
@ -227,7 +227,7 @@ services:
# worker service # worker service
# The Celery worker for processing the queue. # The Celery worker for processing the queue.
worker: worker:
image: langgenius/dify-api:0.14.2 image: langgenius/dify-api:0.15.0
restart: always restart: always
environment: environment:
CONSOLE_WEB_URL: '' CONSOLE_WEB_URL: ''
@ -397,7 +397,7 @@ services:
# Frontend web application. # Frontend web application.
web: web:
image: langgenius/dify-web:0.14.2 image: langgenius/dify-web:0.15.0
restart: always restart: always
environment: environment:
# The base URL of console application api server, refers to the Console base URL of WEB service if console domain is # The base URL of console application api server, refers to the Console base URL of WEB service if console domain is

Some files were not shown because too many files have changed in this diff Show More