mirror of https://github.com/langgenius/dify.git

commit 00ad751a57
Merge branch 'fix/chore-fix' into dev/plugin-deploy
@@ -23,6 +23,9 @@ FILES_ACCESS_TIMEOUT=300
# Access token expiration time in minutes
ACCESS_TOKEN_EXPIRE_MINUTES=60

# Refresh token expiration time in days
REFRESH_TOKEN_EXPIRE_DAYS=30

# celery configuration
CELERY_BROKER_URL=redis://:difyai123456@localhost:6379/1
@@ -14,7 +14,10 @@ if is_db_command():
app = create_migrations_app()
else:
if os.environ.get("FLASK_DEBUG", "False") != "True":
# It seems that JetBrains Python debugger does not work well with gevent,
# so we need to disable gevent in debug mode.
# If you are using debugpy and set GEVENT_SUPPORT=True, you can debug with gevent.
if (flask_debug := os.environ.get("FLASK_DEBUG", "0")) and flask_debug.lower() in {"false", "0", "no"}:
from gevent import monkey # type: ignore

# gevent
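For context, the hunk above replaces a plain string comparison of FLASK_DEBUG with a walrus-assignment check, so gevent monkey-patching only happens when the flag is some spelling of "off". A minimal sketch of the same pattern follows; the hunk shows only the import and a "# gevent" comment, so the patch_all() call and the else branch are assumptions, not Dify's actual code:

    import os

    # Treat "false", "0", "no" (any case) as debug-off and only then monkey-patch.
    if (flask_debug := os.environ.get("FLASK_DEBUG", "0")) and flask_debug.lower() in {"false", "0", "no"}:
        from gevent import monkey

        monkey.patch_all()  # assumption: the hunk does not show what follows the import
    else:
        print(f"Debug mode detected (FLASK_DEBUG={flask_debug}); skipping gevent monkey-patching.")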
@@ -546,6 +546,11 @@ class AuthConfig(BaseSettings):
default=60,
)

REFRESH_TOKEN_EXPIRE_DAYS: PositiveFloat = Field(
description="Expiration time for refresh tokens in days",
default=30,
)

LOGIN_LOCKOUT_DURATION: PositiveInt = Field(
description="Time (in seconds) a user must wait before retrying login after exceeding the rate limit.",
default=86400,

@@ -725,6 +730,11 @@ class IndexingConfig(BaseSettings):
default=4000,
)

CHILD_CHUNKS_PREVIEW_NUMBER: PositiveInt = Field(
description="Maximum number of child chunks to preview",
default=50,
)


class MultiModalTransferConfig(BaseSettings):
MULTIMODAL_SEND_FORMAT: Literal["base64", "url"] = Field(
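The two new settings (ACCESS_TOKEN_EXPIRE_MINUTES in the .env example, REFRESH_TOKEN_EXPIRE_DAYS as a PositiveFloat in AuthConfig) express token lifetimes in different units. A minimal sketch of how a consumer might turn them into absolute expiry timestamps; the helper below is illustrative only and is not Dify's token service:

    from datetime import UTC, datetime, timedelta

    ACCESS_TOKEN_EXPIRE_MINUTES = 60   # mirrors the new .env default
    REFRESH_TOKEN_EXPIRE_DAYS = 30.0   # PositiveFloat, so fractional days are allowed

    def token_expiry(now: datetime | None = None) -> tuple[datetime, datetime]:
        """Return (access_expiry, refresh_expiry) computed from the two settings."""
        now = now or datetime.now(UTC)
        return (
            now + timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES),
            now + timedelta(days=REFRESH_TOKEN_EXPIRE_DAYS),
        )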
@@ -33,3 +33,9 @@ class MilvusConfig(BaseSettings):
description="Name of the Milvus database to connect to (default is 'default')",
default="default",
)

MILVUS_ENABLE_HYBRID_SEARCH: bool = Field(
description="Enable hybrid search features (requires Milvus >= 2.5.0). Set to false for compatibility with "
"older versions",
default=True,
)

@@ -9,7 +9,7 @@ class PackagingInfo(BaseSettings):

CURRENT_VERSION: str = Field(
description="Dify version",
default="0.14.2",
default="0.15.0",
)

COMMIT_SHA: str = Field(
@@ -57,12 +57,13 @@ class AppListApi(Resource):
)
parser.add_argument("name", type=str, location="args", required=False)
parser.add_argument("tag_ids", type=uuid_list, location="args", required=False)
parser.add_argument("is_created_by_me", type=inputs.boolean, location="args", required=False)

args = parser.parse_args()

# get app list
app_service = AppService()
app_pagination = app_service.get_paginate_apps(current_user.current_tenant_id, args)
app_pagination = app_service.get_paginate_apps(current_user.id, current_user.current_tenant_id, args)
if not app_pagination:
return {"data": [], "total": 0, "page": 1, "limit": 20, "has_more": False}
@@ -20,7 +20,6 @@ from controllers.web.error import InvokeRateLimitError as InvokeRateLimitHttpErr
from core.app.apps.base_app_queue_manager import AppQueueManager
from core.app.entities.app_invoke_entities import InvokeFrom
from core.errors.error import (
AppInvokeQuotaExceededError,
ModelCurrentlyNotSupportError,
ProviderTokenNotInitError,
QuotaExceededError,

@@ -76,7 +75,7 @@ class CompletionMessageApi(Resource):
raise ProviderModelCurrentlyNotSupportError()
except InvokeError as e:
raise CompletionRequestError(e.description)
except (ValueError, AppInvokeQuotaExceededError) as e:
except ValueError as e:
raise e
except Exception as e:
logging.exception("internal server error.")

@@ -141,7 +140,7 @@ class ChatMessageApi(Resource):
raise InvokeRateLimitHttpError(ex.description)
except InvokeError as e:
raise CompletionRequestError(e.description)
except (ValueError, AppInvokeQuotaExceededError) as e:
except ValueError as e:
raise e
except Exception as e:
logging.exception("internal server error.")
@@ -273,8 +273,7 @@ FROM
messages m
ON c.id = m.conversation_id
WHERE
c.override_model_configs IS NULL
AND c.app_id = :app_id"""
c.app_id = :app_id"""
arg_dict = {"tz": account.timezone, "app_id": app_model.id}

timezone = pytz.timezone(account.timezone)
@@ -2,7 +2,7 @@ import base64
import secrets

from flask import request
from flask_restful import Resource, reqparse
from flask_restful import Resource, reqparse # type: ignore
from sqlalchemy import select
from sqlalchemy.orm import Session

@@ -129,7 +129,7 @@ class ForgotPasswordResetApi(Resource):
)
except WorkSpaceNotAllowedCreateError:
pass
except AccountRegisterError as are:
except AccountRegisterError:
raise AccountInFreezeError()

return {"result": "success"}
@@ -4,7 +4,7 @@ from typing import Optional

import requests
from flask import current_app, redirect, request
from flask_restful import Resource
from flask_restful import Resource # type: ignore
from sqlalchemy import select
from sqlalchemy.orm import Session
from werkzeug.exceptions import Unauthorized

@@ -2,8 +2,8 @@ import datetime
import json

from flask import request
from flask_login import current_user
from flask_restful import Resource, marshal_with, reqparse
from flask_login import current_user # type: ignore
from flask_restful import Resource, marshal_with, reqparse # type: ignore
from sqlalchemy import select
from sqlalchemy.orm import Session
from werkzeug.exceptions import NotFound
@@ -640,6 +640,7 @@ class DatasetRetrievalSettingApi(Resource):
| VectorType.MYSCALE
| VectorType.ORACLE
| VectorType.ELASTICSEARCH
| VectorType.ELASTICSEARCH_JA
| VectorType.PGVECTOR
| VectorType.TIDB_ON_QDRANT
| VectorType.LINDORM

@@ -683,6 +684,7 @@ class DatasetRetrievalSettingMockApi(Resource):
| VectorType.MYSCALE
| VectorType.ORACLE
| VectorType.ELASTICSEARCH
| VectorType.ELASTICSEARCH_JA
| VectorType.COUCHBASE
| VectorType.PGVECTOR
| VectorType.LINDORM

@@ -269,7 +269,8 @@ class DatasetDocumentListApi(Resource):
parser.add_argument("original_document_id", type=str, required=False, location="json")
parser.add_argument("doc_form", type=str, default="text_model", required=False, nullable=False, location="json")
parser.add_argument("retrieval_model", type=dict, required=False, nullable=False, location="json")

parser.add_argument("embedding_model", type=str, required=False, nullable=True, location="json")
parser.add_argument("embedding_model_provider", type=str, required=False, nullable=True, location="json")
parser.add_argument(
"doc_language", type=str, default="English", required=False, nullable=False, location="json"
)
@@ -18,7 +18,11 @@ from controllers.console.explore.error import NotChatAppError, NotCompletionAppE
from controllers.console.explore.wraps import InstalledAppResource
from core.app.apps.base_app_queue_manager import AppQueueManager
from core.app.entities.app_invoke_entities import InvokeFrom
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
from core.errors.error import (
ModelCurrentlyNotSupportError,
ProviderTokenNotInitError,
QuotaExceededError,
)
from core.model_runtime.errors.invoke import InvokeError
from extensions.ext_database import db
from libs import helper

@@ -13,7 +13,11 @@ from controllers.console.explore.error import NotWorkflowAppError
from controllers.console.explore.wraps import InstalledAppResource
from core.app.apps.base_app_queue_manager import AppQueueManager
from core.app.entities.app_invoke_entities import InvokeFrom
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
from core.errors.error import (
ModelCurrentlyNotSupportError,
ProviderTokenNotInitError,
QuotaExceededError,
)
from core.model_runtime.errors.invoke import InvokeError
from libs import helper
from libs.login import current_user
@@ -1,7 +1,7 @@
import os

from flask import session
from flask_restful import Resource, reqparse
from flask_restful import Resource, reqparse # type: ignore
from sqlalchemy import select
from sqlalchemy.orm import Session

@@ -1,6 +1,6 @@
from functools import wraps

from flask_login import current_user
from flask_login import current_user # type: ignore
from sqlalchemy.orm import Session
from werkzeug.exceptions import Forbidden

@@ -1,5 +1,5 @@
from flask_login import current_user
from flask_restful import Resource
from flask_login import current_user # type: ignore
from flask_restful import Resource # type: ignore

from controllers.console import api
from controllers.console.wraps import account_initialization_required, setup_required

@@ -1,5 +1,5 @@
from flask_login import current_user
from flask_restful import Resource, reqparse
from flask_login import current_user # type: ignore
from flask_restful import Resource, reqparse # type: ignore
from werkzeug.exceptions import Forbidden

from controllers.console import api
@@ -1,8 +1,8 @@
import io

from flask import request, send_file
from flask_login import current_user
from flask_restful import Resource, reqparse
from flask_login import current_user # type: ignore
from flask_restful import Resource, reqparse # type: ignore
from werkzeug.exceptions import Forbidden

from configs import dify_config

@@ -1,5 +1,5 @@
from flask import request
from flask_restful import Resource, marshal_with
from flask_restful import Resource, marshal_with # type: ignore
from werkzeug.exceptions import Forbidden

import services

@@ -1,4 +1,4 @@
from flask_restful import Resource
from flask_restful import Resource # type: ignore

from controllers.console.wraps import setup_required
from controllers.inner_api import api
@@ -3,7 +3,7 @@ from functools import wraps
from typing import Optional

from flask import request
from flask_restful import reqparse
from flask_restful import reqparse # type: ignore
from pydantic import BaseModel
from sqlalchemy.orm import Session
@@ -18,7 +18,6 @@ from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate
from core.app.apps.base_app_queue_manager import AppQueueManager
from core.app.entities.app_invoke_entities import InvokeFrom
from core.errors.error import (
AppInvokeQuotaExceededError,
ModelCurrentlyNotSupportError,
ProviderTokenNotInitError,
QuotaExceededError,

@@ -74,7 +73,7 @@ class CompletionApi(Resource):
raise ProviderModelCurrentlyNotSupportError()
except InvokeError as e:
raise CompletionRequestError(e.description)
except (ValueError, AppInvokeQuotaExceededError) as e:
except ValueError as e:
raise e
except Exception as e:
logging.exception("internal server error.")

@@ -133,7 +132,7 @@ class ChatApi(Resource):
raise ProviderModelCurrentlyNotSupportError()
except InvokeError as e:
raise CompletionRequestError(e.description)
except (ValueError, AppInvokeQuotaExceededError) as e:
except ValueError as e:
raise e
except Exception as e:
logging.exception("internal server error.")
@@ -16,7 +16,6 @@ from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate
from core.app.apps.base_app_queue_manager import AppQueueManager
from core.app.entities.app_invoke_entities import InvokeFrom
from core.errors.error import (
AppInvokeQuotaExceededError,
ModelCurrentlyNotSupportError,
ProviderTokenNotInitError,
QuotaExceededError,

@@ -94,7 +93,7 @@ class WorkflowRunApi(Resource):
raise ProviderModelCurrentlyNotSupportError()
except InvokeError as e:
raise CompletionRequestError(e.description)
except (ValueError, AppInvokeQuotaExceededError) as e:
except ValueError as e:
raise e
except Exception as e:
logging.exception("internal server error.")
@@ -190,7 +190,10 @@ class DocumentAddByFileApi(DatasetApiResource):
user=current_user,
source="datasets",
)
data_source = {"type": "upload_file", "info_list": {"file_info_list": {"file_ids": [upload_file.id]}}}
data_source = {
"type": "upload_file",
"info_list": {"data_source_type": "upload_file", "file_info_list": {"file_ids": [upload_file.id]}},
}
args["data_source"] = data_source
# validate args
knowledge_config = KnowledgeConfig(**args)

@@ -254,7 +257,10 @@ class DocumentUpdateByFileApi(DatasetApiResource):
raise FileTooLargeError(file_too_large_error.description)
except services.errors.file.UnsupportedFileTypeError:
raise UnsupportedFileTypeError()
data_source = {"type": "upload_file", "info_list": {"file_info_list": {"file_ids": [upload_file.id]}}}
data_source = {
"type": "upload_file",
"info_list": {"data_source_type": "upload_file", "file_info_list": {"file_ids": [upload_file.id]}},
}
args["data_source"] = data_source
# validate args
args["original_document_id"] = str(document_id)
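Both hunks above make the same fix: the "info_list" entry now carries a "data_source_type" key alongside the file ID list. A small restatement of the corrected payload shape, with a placeholder ID (the keys come directly from the hunks; the ID value is illustrative):

    # Corrected data_source payload for upload-file document creation/update.
    upload_file_id = "example-upload-file-id"  # placeholder, normally upload_file.id
    data_source = {
        "type": "upload_file",
        "info_list": {
            "data_source_type": "upload_file",
            "file_info_list": {"file_ids": [upload_file_id]},
        },
    }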
@@ -1,5 +1,5 @@
from collections.abc import Callable
from datetime import UTC, datetime
from datetime import UTC, datetime, timedelta
from enum import Enum
from functools import wraps
from typing import Optional

@@ -8,6 +8,8 @@ from flask import current_app, request
from flask_login import user_logged_in # type: ignore
from flask_restful import Resource # type: ignore
from pydantic import BaseModel
from sqlalchemy import select, update
from sqlalchemy.orm import Session
from werkzeug.exceptions import Forbidden, Unauthorized

from extensions.ext_database import db

@@ -174,7 +176,7 @@ def validate_dataset_token(view=None):
return decorator


def validate_and_get_api_token(scope=None):
def validate_and_get_api_token(scope: str | None = None):
"""
Validate and get API token.
"""
@@ -188,20 +190,25 @@ def validate_and_get_api_token(scope=None):
if auth_scheme != "bearer":
raise Unauthorized("Authorization scheme must be 'Bearer'")

api_token = (
db.session.query(ApiToken)
.filter(
ApiToken.token == auth_token,
ApiToken.type == scope,
current_time = datetime.now(UTC).replace(tzinfo=None)
cutoff_time = current_time - timedelta(minutes=1)
with Session(db.engine, expire_on_commit=False) as session:
update_stmt = (
update(ApiToken)
.where(ApiToken.token == auth_token, ApiToken.last_used_at < cutoff_time, ApiToken.type == scope)
.values(last_used_at=current_time)
.returning(ApiToken)
)
.first()
)
result = session.execute(update_stmt)
api_token = result.scalar_one_or_none()

if not api_token:
raise Unauthorized("Access token is invalid")

api_token.last_used_at = datetime.now(UTC).replace(tzinfo=None)
db.session.commit()
if not api_token:
stmt = select(ApiToken).where(ApiToken.token == auth_token, ApiToken.type == scope)
api_token = session.scalar(stmt)
if not api_token:
raise Unauthorized("Access token is invalid")
else:
session.commit()

return api_token
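The hunk above replaces a read-then-write of ApiToken.last_used_at with a single UPDATE ... RETURNING that only fires when the recorded timestamp is more than a minute old, falling back to a plain SELECT when nothing was updated. A standalone sketch of the same pattern, assuming an engine and a mapped ApiToken model are passed in (this is not the Dify decorator itself, and the exception type is simplified):

    from datetime import UTC, datetime, timedelta

    from sqlalchemy import select, update
    from sqlalchemy.orm import Session

    def touch_and_fetch_token(engine, ApiToken, auth_token: str, scope: str | None):
        current_time = datetime.now(UTC).replace(tzinfo=None)
        cutoff_time = current_time - timedelta(minutes=1)

        with Session(engine, expire_on_commit=False) as session:
            # Atomically bump last_used_at, but only if it is stale by more than a minute.
            stmt = (
                update(ApiToken)
                .where(ApiToken.token == auth_token, ApiToken.last_used_at < cutoff_time, ApiToken.type == scope)
                .values(last_used_at=current_time)
                .returning(ApiToken)
            )
            api_token = session.execute(stmt).scalar_one_or_none()

            if not api_token:
                # Nothing updated: either the token was used recently or it does not exist.
                api_token = session.scalar(
                    select(ApiToken).where(ApiToken.token == auth_token, ApiToken.type == scope)
                )
                if not api_token:
                    raise ValueError("Access token is invalid")
            else:
                session.commit()

            return api_token

This keeps the hot path to one statement per request and avoids rewriting the row on every call, which is the apparent intent of the one-minute cutoff.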
@@ -19,7 +19,11 @@ from controllers.web.error import InvokeRateLimitError as InvokeRateLimitHttpErr
from controllers.web.wraps import WebApiResource
from core.app.apps.base_app_queue_manager import AppQueueManager
from core.app.entities.app_invoke_entities import InvokeFrom
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
from core.errors.error import (
ModelCurrentlyNotSupportError,
ProviderTokenNotInitError,
QuotaExceededError,
)
from core.model_runtime.errors.invoke import InvokeError
from libs import helper
from libs.helper import uuid_value

@@ -14,7 +14,11 @@ from controllers.web.error import (
from controllers.web.wraps import WebApiResource
from core.app.apps.base_app_queue_manager import AppQueueManager
from core.app.entities.app_invoke_entities import InvokeFrom
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
from core.errors.error import (
ModelCurrentlyNotSupportError,
ProviderTokenNotInitError,
QuotaExceededError,
)
from core.model_runtime.errors.invoke import InvokeError
from libs import helper
from models.model import App, AppMode, EndUser
@@ -119,7 +119,7 @@ class CotAgentRunner(BaseAgentRunner, ABC):
callbacks=[],
)

usage_dict = {}
usage_dict: dict[str, Optional[LLMUsage]] = {}
react_chunks = CotAgentOutputParser.handle_react_stream_output(chunks, usage_dict)
scratchpad = AgentScratchpadUnit(
agent_response="",

@@ -21,7 +21,7 @@ from core.app.apps.message_based_app_generator import MessageBasedAppGenerator
from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager
from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom
from core.app.entities.task_entities import ChatbotAppBlockingResponse, ChatbotAppStreamResponse
from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
from core.model_runtime.errors.invoke import InvokeAuthorizationError
from core.ops.ops_trace_manager import TraceQueueManager
from core.prompt.utils.get_thread_messages_length import get_thread_messages_length
from extensions.ext_database import db

@@ -346,7 +346,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
except ValidationError as e:
logger.exception("Validation Error when generating")
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
except (ValueError, InvokeError) as e:
except ValueError as e:
if dify_config.DEBUG:
logger.exception("Error when generating")
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
@@ -77,7 +77,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
graph, variable_pool = self._get_graph_and_variable_pool_of_single_iteration(
workflow=workflow,
node_id=self.application_generate_entity.single_iteration_run.node_id,
user_inputs=self.application_generate_entity.single_iteration_run.inputs,
user_inputs=dict(self.application_generate_entity.single_iteration_run.inputs),
)
else:
inputs = self.application_generate_entity.inputs
@@ -68,24 +68,17 @@ from models.account import Account
from models.enums import CreatedByRole
from models.workflow import (
Workflow,
WorkflowNodeExecution,
WorkflowRunStatus,
)

logger = logging.getLogger(__name__)


class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleManage, MessageCycleManage):
class AdvancedChatAppGenerateTaskPipeline:
"""
AdvancedChatAppGenerateTaskPipeline is a class that generate stream output and state management for Application.
"""

_task_state: WorkflowTaskState
_application_generate_entity: AdvancedChatAppGenerateEntity
_workflow_system_variables: dict[SystemVariableKey, Any]
_wip_workflow_node_executions: dict[str, WorkflowNodeExecution]
_conversation_name_generate_thread: Optional[Thread] = None

def __init__(
self,
application_generate_entity: AdvancedChatAppGenerateEntity,

@@ -97,7 +90,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
stream: bool,
dialogue_count: int,
) -> None:
super().__init__(
self._base_task_pipeline = BasedGenerateTaskPipeline(
application_generate_entity=application_generate_entity,
queue_manager=queue_manager,
stream=stream,
@@ -114,33 +107,35 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
else:
raise NotImplementedError(f"User type not supported: {type(user)}")

self._workflow_id = workflow.id
self._workflow_features_dict = workflow.features_dict

self._conversation_id = conversation.id
self._conversation_mode = conversation.mode

self._message_id = message.id
self._message_created_at = int(message.created_at.timestamp())

self._workflow_system_variables = {
SystemVariableKey.QUERY: message.query,
SystemVariableKey.FILES: application_generate_entity.files,
SystemVariableKey.CONVERSATION_ID: conversation.id,
SystemVariableKey.USER_ID: user_session_id,
SystemVariableKey.DIALOGUE_COUNT: dialogue_count,
SystemVariableKey.APP_ID: application_generate_entity.app_config.app_id,
SystemVariableKey.WORKFLOW_ID: workflow.id,
SystemVariableKey.WORKFLOW_RUN_ID: application_generate_entity.workflow_run_id,
}
self._workflow_cycle_manager = WorkflowCycleManage(
application_generate_entity=application_generate_entity,
workflow_system_variables={
SystemVariableKey.QUERY: message.query,
SystemVariableKey.FILES: application_generate_entity.files,
SystemVariableKey.CONVERSATION_ID: conversation.id,
SystemVariableKey.USER_ID: user_session_id,
SystemVariableKey.DIALOGUE_COUNT: dialogue_count,
SystemVariableKey.APP_ID: application_generate_entity.app_config.app_id,
SystemVariableKey.WORKFLOW_ID: workflow.id,
SystemVariableKey.WORKFLOW_RUN_ID: application_generate_entity.workflow_run_id,
},
)

self._task_state = WorkflowTaskState()
self._wip_workflow_node_executions = {}
self._wip_workflow_agent_logs = {}
self._message_cycle_manager = MessageCycleManage(
application_generate_entity=application_generate_entity, task_state=self._task_state
)

self._conversation_name_generate_thread = None
self._application_generate_entity = application_generate_entity
self._workflow_id = workflow.id
self._workflow_features_dict = workflow.features_dict
self._conversation_id = conversation.id
self._conversation_mode = conversation.mode
self._message_id = message.id
self._message_created_at = int(message.created_at.timestamp())
self._conversation_name_generate_thread: Thread | None = None
self._recorded_files: list[Mapping[str, Any]] = []
self._workflow_run_id = ""
self._workflow_run_id: str = ""

def process(self) -> Union[ChatbotAppBlockingResponse, Generator[ChatbotAppStreamResponse, None, None]]:
"""
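The refactor above (and the matching one for WorkflowAppGenerateTaskPipeline further down) drops the mixin inheritance from BasedGenerateTaskPipeline, WorkflowCycleManage and MessageCycleManage and instead composes those helpers as attributes, so every former self._handle_*() call is routed through an explicit collaborator. A minimal sketch of the composition pattern only; the class below is illustrative, with constructor arguments and bodies heavily simplified relative to the real pipeline:

    # Composition instead of mixin inheritance (names follow the hunk, bodies elided).
    class PipelineSketch:
        def __init__(self, base_pipeline, workflow_cycle_manager, message_cycle_manager):
            self._base_task_pipeline = base_pipeline              # queue handling, error/ping responses
            self._workflow_cycle_manager = workflow_cycle_manager  # workflow run / node bookkeeping
            self._message_cycle_manager = message_cycle_manager    # conversation naming, message chunks

        def on_error(self, event, session, message_id):
            # Formerly an inherited self._handle_error(...); now delegated explicitly.
            return self._base_task_pipeline._handle_error(
                event=event, session=session, message_id=message_id
            )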
@@ -148,13 +143,13 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
:return:
"""
# start generate conversation name thread
self._conversation_name_generate_thread = self._generate_conversation_name(
self._conversation_name_generate_thread = self._message_cycle_manager._generate_conversation_name(
conversation_id=self._conversation_id, query=self._application_generate_entity.query
)

generator = self._wrapper_process_stream_response(trace_manager=self._application_generate_entity.trace_manager)

if self._stream:
if self._base_task_pipeline._stream:
return self._to_stream_response(generator)
else:
return self._to_blocking_response(generator)

@@ -273,24 +268,26 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
# init fake graph runtime state
graph_runtime_state: Optional[GraphRuntimeState] = None

for queue_message in self._queue_manager.listen():
for queue_message in self._base_task_pipeline._queue_manager.listen():
event = queue_message.event

if isinstance(event, QueuePingEvent):
yield self._ping_stream_response()
yield self._base_task_pipeline._ping_stream_response()
elif isinstance(event, QueueErrorEvent):
with Session(db.engine) as session:
err = self._handle_error(event=event, session=session, message_id=self._message_id)
with Session(db.engine, expire_on_commit=False) as session:
err = self._base_task_pipeline._handle_error(
event=event, session=session, message_id=self._message_id
)
session.commit()
yield self._error_to_stream_response(err)
yield self._base_task_pipeline._error_to_stream_response(err)
break
elif isinstance(event, QueueWorkflowStartedEvent):
# override graph runtime state
graph_runtime_state = event.graph_runtime_state

with Session(db.engine) as session:
with Session(db.engine, expire_on_commit=False) as session:
# init workflow run
workflow_run = self._handle_workflow_run_start(
workflow_run = self._workflow_cycle_manager._handle_workflow_run_start(
session=session,
workflow_id=self._workflow_id,
user_id=self._user_id,
@@ -301,7 +298,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
if not message:
raise ValueError(f"Message not found: {self._message_id}")
message.workflow_run_id = workflow_run.id
workflow_start_resp = self._workflow_start_to_stream_response(
workflow_start_resp = self._workflow_cycle_manager._workflow_start_to_stream_response(
session=session, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
)
session.commit()

@@ -314,12 +311,14 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
if not self._workflow_run_id:
raise ValueError("workflow run not initialized.")

with Session(db.engine) as session:
workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id)
workflow_node_execution = self._handle_workflow_node_execution_retried(
with Session(db.engine, expire_on_commit=False) as session:
workflow_run = self._workflow_cycle_manager._get_workflow_run(
session=session, workflow_run_id=self._workflow_run_id
)
workflow_node_execution = self._workflow_cycle_manager._handle_workflow_node_execution_retried(
session=session, workflow_run=workflow_run, event=event
)
node_retry_resp = self._workflow_node_retry_to_stream_response(
node_retry_resp = self._workflow_cycle_manager._workflow_node_retry_to_stream_response(
session=session,
event=event,
task_id=self._application_generate_entity.task_id,
@@ -333,13 +332,15 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
if not self._workflow_run_id:
raise ValueError("workflow run not initialized.")

with Session(db.engine) as session:
workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id)
workflow_node_execution = self._handle_node_execution_start(
with Session(db.engine, expire_on_commit=False) as session:
workflow_run = self._workflow_cycle_manager._get_workflow_run(
session=session, workflow_run_id=self._workflow_run_id
)
workflow_node_execution = self._workflow_cycle_manager._handle_node_execution_start(
session=session, workflow_run=workflow_run, event=event
)

node_start_resp = self._workflow_node_start_to_stream_response(
node_start_resp = self._workflow_cycle_manager._workflow_node_start_to_stream_response(
session=session,
event=event,
task_id=self._application_generate_entity.task_id,

@@ -352,12 +353,16 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
elif isinstance(event, QueueNodeSucceededEvent):
# Record files if it's an answer node or end node
if event.node_type in [NodeType.ANSWER, NodeType.END]:
self._recorded_files.extend(self._fetch_files_from_node_outputs(event.outputs or {}))
self._recorded_files.extend(
self._workflow_cycle_manager._fetch_files_from_node_outputs(event.outputs or {})
)

with Session(db.engine) as session:
workflow_node_execution = self._handle_workflow_node_execution_success(session=session, event=event)
with Session(db.engine, expire_on_commit=False) as session:
workflow_node_execution = self._workflow_cycle_manager._handle_workflow_node_execution_success(
session=session, event=event
)

node_finish_resp = self._workflow_node_finish_to_stream_response(
node_finish_resp = self._workflow_cycle_manager._workflow_node_finish_to_stream_response(
session=session,
event=event,
task_id=self._application_generate_entity.task_id,
@@ -368,32 +373,35 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
if node_finish_resp:
yield node_finish_resp
elif isinstance(event, QueueNodeFailedEvent | QueueNodeInIterationFailedEvent | QueueNodeExceptionEvent):
with Session(db.engine) as session:
workflow_node_execution = self._handle_workflow_node_execution_failed(session=session, event=event)
with Session(db.engine, expire_on_commit=False) as session:
workflow_node_execution = self._workflow_cycle_manager._handle_workflow_node_execution_failed(
session=session, event=event
)

response_finish = self._workflow_node_finish_to_stream_response(
node_finish_resp = self._workflow_cycle_manager._workflow_node_finish_to_stream_response(
session=session,
event=event,
task_id=self._application_generate_entity.task_id,
workflow_node_execution=workflow_node_execution,
)

if response_finish:
yield response_finish

if node_finish_resp:
yield node_finish_resp
elif isinstance(event, QueueParallelBranchRunStartedEvent):
if not self._workflow_run_id:
raise ValueError("workflow run not initialized.")

with Session(db.engine) as session:
workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id)
parallel_start_resp = self._workflow_parallel_branch_start_to_stream_response(
session=session,
task_id=self._application_generate_entity.task_id,
workflow_run=workflow_run,
event=event,
with Session(db.engine, expire_on_commit=False) as session:
workflow_run = self._workflow_cycle_manager._get_workflow_run(
session=session, workflow_run_id=self._workflow_run_id
)
parallel_start_resp = (
self._workflow_cycle_manager._workflow_parallel_branch_start_to_stream_response(
session=session,
task_id=self._application_generate_entity.task_id,
workflow_run=workflow_run,
event=event,
)
)

yield parallel_start_resp
@@ -401,13 +409,17 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
if not self._workflow_run_id:
raise ValueError("workflow run not initialized.")

with Session(db.engine) as session:
workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id)
parallel_finish_resp = self._workflow_parallel_branch_finished_to_stream_response(
session=session,
task_id=self._application_generate_entity.task_id,
workflow_run=workflow_run,
event=event,
with Session(db.engine, expire_on_commit=False) as session:
workflow_run = self._workflow_cycle_manager._get_workflow_run(
session=session, workflow_run_id=self._workflow_run_id
)
parallel_finish_resp = (
self._workflow_cycle_manager._workflow_parallel_branch_finished_to_stream_response(
session=session,
task_id=self._application_generate_entity.task_id,
workflow_run=workflow_run,
event=event,
)
)

yield parallel_finish_resp

@@ -415,9 +427,11 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
if not self._workflow_run_id:
raise ValueError("workflow run not initialized.")

with Session(db.engine) as session:
workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id)
iter_start_resp = self._workflow_iteration_start_to_stream_response(
with Session(db.engine, expire_on_commit=False) as session:
workflow_run = self._workflow_cycle_manager._get_workflow_run(
session=session, workflow_run_id=self._workflow_run_id
)
iter_start_resp = self._workflow_cycle_manager._workflow_iteration_start_to_stream_response(
session=session,
task_id=self._application_generate_entity.task_id,
workflow_run=workflow_run,
@@ -429,9 +443,11 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
if not self._workflow_run_id:
raise ValueError("workflow run not initialized.")

with Session(db.engine) as session:
workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id)
iter_next_resp = self._workflow_iteration_next_to_stream_response(
with Session(db.engine, expire_on_commit=False) as session:
workflow_run = self._workflow_cycle_manager._get_workflow_run(
session=session, workflow_run_id=self._workflow_run_id
)
iter_next_resp = self._workflow_cycle_manager._workflow_iteration_next_to_stream_response(
session=session,
task_id=self._application_generate_entity.task_id,
workflow_run=workflow_run,

@@ -443,9 +459,11 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
if not self._workflow_run_id:
raise ValueError("workflow run not initialized.")

with Session(db.engine) as session:
workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id)
iter_finish_resp = self._workflow_iteration_completed_to_stream_response(
with Session(db.engine, expire_on_commit=False) as session:
workflow_run = self._workflow_cycle_manager._get_workflow_run(
session=session, workflow_run_id=self._workflow_run_id
)
iter_finish_resp = self._workflow_cycle_manager._workflow_iteration_completed_to_stream_response(
session=session,
task_id=self._application_generate_entity.task_id,
workflow_run=workflow_run,
@@ -460,8 +478,8 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
if not graph_runtime_state:
raise ValueError("workflow run not initialized.")

with Session(db.engine) as session:
workflow_run = self._handle_workflow_run_success(
with Session(db.engine, expire_on_commit=False) as session:
workflow_run = self._workflow_cycle_manager._handle_workflow_run_success(
session=session,
workflow_run_id=self._workflow_run_id,
start_at=graph_runtime_state.start_at,

@@ -472,21 +490,23 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
trace_manager=trace_manager,
)

workflow_finish_resp = self._workflow_finish_to_stream_response(
workflow_finish_resp = self._workflow_cycle_manager._workflow_finish_to_stream_response(
session=session, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
)
session.commit()

yield workflow_finish_resp
self._queue_manager.publish(QueueAdvancedChatMessageEndEvent(), PublishFrom.TASK_PIPELINE)
self._base_task_pipeline._queue_manager.publish(
QueueAdvancedChatMessageEndEvent(), PublishFrom.TASK_PIPELINE
)
elif isinstance(event, QueueWorkflowPartialSuccessEvent):
if not self._workflow_run_id:
raise ValueError("workflow run not initialized.")
if not graph_runtime_state:
raise ValueError("graph runtime state not initialized.")

with Session(db.engine) as session:
workflow_run = self._handle_workflow_run_partial_success(
with Session(db.engine, expire_on_commit=False) as session:
workflow_run = self._workflow_cycle_manager._handle_workflow_run_partial_success(
session=session,
workflow_run_id=self._workflow_run_id,
start_at=graph_runtime_state.start_at,
@@ -497,21 +517,23 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
conversation_id=None,
trace_manager=trace_manager,
)
workflow_finish_resp = self._workflow_finish_to_stream_response(
workflow_finish_resp = self._workflow_cycle_manager._workflow_finish_to_stream_response(
session=session, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
)
session.commit()

yield workflow_finish_resp
self._queue_manager.publish(QueueAdvancedChatMessageEndEvent(), PublishFrom.TASK_PIPELINE)
self._base_task_pipeline._queue_manager.publish(
QueueAdvancedChatMessageEndEvent(), PublishFrom.TASK_PIPELINE
)
elif isinstance(event, QueueWorkflowFailedEvent):
if not self._workflow_run_id:
raise ValueError("workflow run not initialized.")
if not graph_runtime_state:
raise ValueError("graph runtime state not initialized.")

with Session(db.engine) as session:
workflow_run = self._handle_workflow_run_failed(
with Session(db.engine, expire_on_commit=False) as session:
workflow_run = self._workflow_cycle_manager._handle_workflow_run_failed(
session=session,
workflow_run_id=self._workflow_run_id,
start_at=graph_runtime_state.start_at,

@@ -523,20 +545,22 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
trace_manager=trace_manager,
exceptions_count=event.exceptions_count,
)
workflow_finish_resp = self._workflow_finish_to_stream_response(
workflow_finish_resp = self._workflow_cycle_manager._workflow_finish_to_stream_response(
session=session, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
)
err_event = QueueErrorEvent(error=ValueError(f"Run failed: {workflow_run.error}"))
err = self._handle_error(event=err_event, session=session, message_id=self._message_id)
err = self._base_task_pipeline._handle_error(
event=err_event, session=session, message_id=self._message_id
)
session.commit()

yield workflow_finish_resp
yield self._error_to_stream_response(err)
yield self._base_task_pipeline._error_to_stream_response(err)
break
elif isinstance(event, QueueStopEvent):
if self._workflow_run_id and graph_runtime_state:
with Session(db.engine) as session:
workflow_run = self._handle_workflow_run_failed(
with Session(db.engine, expire_on_commit=False) as session:
workflow_run = self._workflow_cycle_manager._handle_workflow_run_failed(
session=session,
workflow_run_id=self._workflow_run_id,
start_at=graph_runtime_state.start_at,
@@ -547,7 +571,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
conversation_id=self._conversation_id,
trace_manager=trace_manager,
)
workflow_finish_resp = self._workflow_finish_to_stream_response(
workflow_finish_resp = self._workflow_cycle_manager._workflow_finish_to_stream_response(
session=session,
task_id=self._application_generate_entity.task_id,
workflow_run=workflow_run,

@@ -561,18 +585,18 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
yield self._message_end_to_stream_response()
break
elif isinstance(event, QueueRetrieverResourcesEvent):
self._handle_retriever_resources(event)
self._message_cycle_manager._handle_retriever_resources(event)

with Session(db.engine) as session:
with Session(db.engine, expire_on_commit=False) as session:
message = self._get_message(session=session)
message.message_metadata = (
json.dumps(jsonable_encoder(self._task_state.metadata)) if self._task_state.metadata else None
)
session.commit()
elif isinstance(event, QueueAnnotationReplyEvent):
self._handle_annotation_reply(event)
self._message_cycle_manager._handle_annotation_reply(event)

with Session(db.engine) as session:
with Session(db.engine, expire_on_commit=False) as session:
message = self._get_message(session=session)
message.message_metadata = (
json.dumps(jsonable_encoder(self._task_state.metadata)) if self._task_state.metadata else None
@@ -593,29 +617,35 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
tts_publisher.publish(queue_message)

self._task_state.answer += delta_text
yield self._message_to_stream_response(
yield self._message_cycle_manager._message_to_stream_response(
answer=delta_text, message_id=self._message_id, from_variable_selector=event.from_variable_selector
)
elif isinstance(event, QueueMessageReplaceEvent):
# published by moderation
yield self._message_replace_to_stream_response(answer=event.text)
yield self._message_cycle_manager._message_replace_to_stream_response(answer=event.text)
elif isinstance(event, QueueAdvancedChatMessageEndEvent):
if not graph_runtime_state:
raise ValueError("graph runtime state not initialized.")

output_moderation_answer = self._handle_output_moderation_when_task_finished(self._task_state.answer)
output_moderation_answer = self._base_task_pipeline._handle_output_moderation_when_task_finished(
self._task_state.answer
)
if output_moderation_answer:
self._task_state.answer = output_moderation_answer
yield self._message_replace_to_stream_response(answer=output_moderation_answer)
yield self._message_cycle_manager._message_replace_to_stream_response(
answer=output_moderation_answer
)

# Save message
with Session(db.engine) as session:
with Session(db.engine, expire_on_commit=False) as session:
self._save_message(session=session, graph_runtime_state=graph_runtime_state)
session.commit()

yield self._message_end_to_stream_response()
elif isinstance(event, QueueAgentLogEvent):
yield self._handle_agent_log(task_id=self._application_generate_entity.task_id, event=event)
yield self._workflow_cycle_manager._handle_agent_log(
task_id=self._application_generate_entity.task_id, event=event
)
else:
continue

@@ -629,7 +659,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
def _save_message(self, *, session: Session, graph_runtime_state: Optional[GraphRuntimeState] = None) -> None:
message = self._get_message(session=session)
message.answer = self._task_state.answer
message.provider_response_latency = time.perf_counter() - self._start_at
message.provider_response_latency = time.perf_counter() - self._base_task_pipeline._start_at
message.message_metadata = (
json.dumps(jsonable_encoder(self._task_state.metadata)) if self._task_state.metadata else None
)
@@ -693,20 +723,20 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
:param text: text
:return: True if output moderation should direct output, otherwise False
"""
if self._output_moderation_handler:
if self._output_moderation_handler.should_direct_output():
if self._base_task_pipeline._output_moderation_handler:
if self._base_task_pipeline._output_moderation_handler.should_direct_output():
# stop subscribe new token when output moderation should direct output
self._task_state.answer = self._output_moderation_handler.get_final_output()
self._queue_manager.publish(
self._task_state.answer = self._base_task_pipeline._output_moderation_handler.get_final_output()
self._base_task_pipeline._queue_manager.publish(
QueueTextChunkEvent(text=self._task_state.answer), PublishFrom.TASK_PIPELINE
)

self._queue_manager.publish(
self._base_task_pipeline._queue_manager.publish(
QueueStopEvent(stopped_by=QueueStopEvent.StopBy.OUTPUT_MODERATION), PublishFrom.TASK_PIPELINE
)
return True
else:
self._output_moderation_handler.append_new_token(text)
self._base_task_pipeline._output_moderation_handler.append_new_token(text)

return False
@@ -19,7 +19,7 @@ from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskSt
from core.app.apps.message_based_app_generator import MessageBasedAppGenerator
from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager
from core.app.entities.app_invoke_entities import AgentChatAppGenerateEntity, InvokeFrom
from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
from core.model_runtime.errors.invoke import InvokeAuthorizationError
from core.ops.ops_trace_manager import TraceQueueManager
from extensions.ext_database import db
from factories import file_factory

@@ -251,7 +251,7 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
except ValidationError as e:
logger.exception("Validation Error when generating")
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
except (ValueError, InvokeError) as e:
except ValueError as e:
if dify_config.DEBUG:
logger.exception("Error when generating")
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
@@ -1,6 +1,6 @@
import logging
from abc import ABC, abstractmethod
from collections.abc import Generator
from collections.abc import Generator, Mapping
from typing import Any, Union

from core.app.entities.app_invoke_entities import InvokeFrom

@@ -15,7 +15,7 @@ class AppGenerateResponseConverter(ABC):
@classmethod
def convert(
cls, response: Union[AppBlockingResponse, Generator[AppStreamResponse, Any, None]], invoke_from: InvokeFrom
) -> dict[str, Any] | Generator[str | dict[str, Any], Any, None]:
) -> Mapping[str, Any] | Generator[str | Mapping[str, Any], Any, None]:
if invoke_from in {InvokeFrom.DEBUGGER, InvokeFrom.SERVICE_API}:
if isinstance(response, AppBlockingResponse):
return cls.convert_blocking_full_response(response)
@@ -18,7 +18,7 @@ from core.app.apps.chat.generate_response_converter import ChatAppGenerateRespon
from core.app.apps.message_based_app_generator import MessageBasedAppGenerator
from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager
from core.app.entities.app_invoke_entities import ChatAppGenerateEntity, InvokeFrom
from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
from core.model_runtime.errors.invoke import InvokeAuthorizationError
from core.ops.ops_trace_manager import TraceQueueManager
from extensions.ext_database import db
from factories import file_factory

@@ -237,7 +237,7 @@ class ChatAppGenerator(MessageBasedAppGenerator):
except ValidationError as e:
logger.exception("Validation Error when generating")
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
except (ValueError, InvokeError) as e:
except ValueError as e:
if dify_config.DEBUG:
logger.exception("Error when generating")
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
@@ -17,7 +17,7 @@ from core.app.apps.completion.generate_response_converter import CompletionAppGe
from core.app.apps.message_based_app_generator import MessageBasedAppGenerator
from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager
from core.app.entities.app_invoke_entities import CompletionAppGenerateEntity, InvokeFrom
from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
from core.model_runtime.errors.invoke import InvokeAuthorizationError
from core.ops.ops_trace_manager import TraceQueueManager
from extensions.ext_database import db
from factories import file_factory

@@ -37,7 +37,7 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
args: Mapping[str, Any],
invoke_from: InvokeFrom,
streaming: Literal[True],
) -> Generator[str, None, None]: ...
) -> Generator[str | Mapping[str, Any], None, None]: ...

@overload
def generate(

@@ -57,7 +57,7 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
args: Mapping[str, Any],
invoke_from: InvokeFrom,
streaming: bool = False,
) -> Union[Mapping[str, Any], Generator[str, None, None]]: ...
) -> Union[Mapping[str, Any], Generator[str | Mapping[str, Any], None, None]]: ...

def generate(
self,
@@ -66,7 +66,7 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
args: Mapping[str, Any],
invoke_from: InvokeFrom,
streaming: bool = True,
) -> Union[Mapping[str, Any], Generator[str, None, None]]:
) -> Union[Mapping[str, Any], Generator[str | Mapping[str, Any], None, None]]:
"""
Generate App response.

@@ -214,7 +214,7 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
except ValidationError as e:
logger.exception("Validation Error when generating")
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
except (ValueError, InvokeError) as e:
except ValueError as e:
if dify_config.DEBUG:
logger.exception("Error when generating")
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)

@@ -231,7 +231,7 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
user: Union[Account, EndUser],
invoke_from: InvokeFrom,
stream: bool = True,
) -> Union[Mapping, Generator[str, None, None]]:
) -> Union[Mapping, Generator[Mapping | str, None, None]]:
"""
Generate App response.
@@ -20,7 +20,7 @@ from core.app.apps.workflow.generate_response_converter import WorkflowAppGenera
from core.app.apps.workflow.generate_task_pipeline import WorkflowAppGenerateTaskPipeline
from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity
from core.app.entities.task_entities import WorkflowAppBlockingResponse, WorkflowAppStreamResponse
from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
from core.model_runtime.errors.invoke import InvokeAuthorizationError
from core.ops.ops_trace_manager import TraceQueueManager
from extensions.ext_database import db
from factories import file_factory

@@ -235,6 +235,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
single_iteration_run=WorkflowAppGenerateEntity.SingleIterationRunEntity(
node_id=node_id, inputs=args["inputs"]
),
workflow_run_id=str(uuid.uuid4()),
)
contexts.tenant_id.set(application_generate_entity.app_config.tenant_id)
contexts.plugin_tool_providers.set({})

@@ -286,7 +287,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
except ValidationError as e:
logger.exception("Validation Error when generating")
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
except (ValueError, InvokeError) as e:
except ValueError as e:
if dify_config.DEBUG:
logger.exception("Error when generating")
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
@@ -1,7 +1,7 @@
import logging
import time
from collections.abc import Generator
from typing import Any, Optional, Union
from typing import Optional, Union

from sqlalchemy.orm import Session

@@ -59,7 +59,6 @@ from models.workflow import (
Workflow,
WorkflowAppLog,
WorkflowAppLogCreatedFrom,
WorkflowNodeExecution,
WorkflowRun,
WorkflowRunStatus,
)
@@ -67,16 +66,11 @@ from models.workflow import (
logger = logging.getLogger(__name__)


class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleManage):
class WorkflowAppGenerateTaskPipeline:
"""
WorkflowAppGenerateTaskPipeline is a class that generate stream output and state management for Application.
"""

_task_state: WorkflowTaskState
_application_generate_entity: WorkflowAppGenerateEntity
_workflow_system_variables: dict[SystemVariableKey, Any]
_wip_workflow_node_executions: dict[str, WorkflowNodeExecution]

def __init__(
self,
application_generate_entity: WorkflowAppGenerateEntity,

@@ -85,7 +79,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
user: Union[Account, EndUser],
stream: bool,
) -> None:
super().__init__(
self._base_task_pipeline = BasedGenerateTaskPipeline(
application_generate_entity=application_generate_entity,
queue_manager=queue_manager,
stream=stream,
@ -102,17 +96,20 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
|
|||
else:
|
||||
raise ValueError(f"Invalid user type: {type(user)}")
|
||||
|
||||
self._workflow_cycle_manager = WorkflowCycleManage(
|
||||
application_generate_entity=application_generate_entity,
|
||||
workflow_system_variables={
|
||||
SystemVariableKey.FILES: application_generate_entity.files,
|
||||
SystemVariableKey.USER_ID: user_session_id,
|
||||
SystemVariableKey.APP_ID: application_generate_entity.app_config.app_id,
|
||||
SystemVariableKey.WORKFLOW_ID: workflow.id,
|
||||
SystemVariableKey.WORKFLOW_RUN_ID: application_generate_entity.workflow_run_id,
|
||||
},
|
||||
)
|
||||
|
||||
self._application_generate_entity = application_generate_entity
|
||||
self._workflow_id = workflow.id
|
||||
self._workflow_features_dict = workflow.features_dict
|
||||
|
||||
self._workflow_system_variables = {
|
||||
SystemVariableKey.FILES: application_generate_entity.files,
|
||||
SystemVariableKey.USER_ID: user_session_id,
|
||||
SystemVariableKey.APP_ID: application_generate_entity.app_config.app_id,
|
||||
SystemVariableKey.WORKFLOW_ID: workflow.id,
|
||||
SystemVariableKey.WORKFLOW_RUN_ID: application_generate_entity.workflow_run_id,
|
||||
}
|
||||
|
||||
self._task_state = WorkflowTaskState()
|
||||
self._workflow_run_id = ""
|
||||
|
||||
|
|
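The constructor above is the heart of this refactor: instead of inheriting from `BasedGenerateTaskPipeline` and `WorkflowCycleManage`, the workflow pipeline now builds both as collaborators and keeps its own copy of the system variables. A simplified sketch of the composition-over-inheritance shape (stand-in classes only; the real constructors take more arguments):

```python
class BasePipeline:
    def __init__(self, *, queue_manager, stream: bool) -> None:
        self._queue_manager = queue_manager
        self._stream = stream


class CycleManager:
    def __init__(self, *, system_variables: dict) -> None:
        self._system_variables = system_variables


class WorkflowPipeline:
    """Holds its helpers as attributes instead of inheriting from them."""

    def __init__(self, *, queue_manager, stream: bool, system_variables: dict) -> None:
        # Composition keeps each collaborator's state private and makes the
        # call sites explicit (self._base._stream, self._cycle...), which is
        # what the later hunks in this file switch over to.
        self._base = BasePipeline(queue_manager=queue_manager, stream=stream)
        self._cycle = CycleManager(system_variables=system_variables)
        self._workflow_run_id = ""
```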
@@ -122,7 +119,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
:return:
"""
generator = self._wrapper_process_stream_response(trace_manager=self._application_generate_entity.trace_manager)
if self._stream:
if self._base_task_pipeline._stream:
return self._to_stream_response(generator)
else:
return self._to_blocking_response(generator)
@@ -239,29 +236,29 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
"""
graph_runtime_state = None

for queue_message in self._queue_manager.listen():
for queue_message in self._base_task_pipeline._queue_manager.listen():
event = queue_message.event

if isinstance(event, QueuePingEvent):
yield self._ping_stream_response()
yield self._base_task_pipeline._ping_stream_response()
elif isinstance(event, QueueErrorEvent):
err = self._handle_error(event=event)
yield self._error_to_stream_response(err)
err = self._base_task_pipeline._handle_error(event=event)
yield self._base_task_pipeline._error_to_stream_response(err)
break
elif isinstance(event, QueueWorkflowStartedEvent):
# override graph runtime state
graph_runtime_state = event.graph_runtime_state

with Session(db.engine) as session:
with Session(db.engine, expire_on_commit=False) as session:
# init workflow run
workflow_run = self._handle_workflow_run_start(
workflow_run = self._workflow_cycle_manager._handle_workflow_run_start(
session=session,
workflow_id=self._workflow_id,
user_id=self._user_id,
created_by_role=self._created_by_role,
)
self._workflow_run_id = workflow_run.id
start_resp = self._workflow_start_to_stream_response(
start_resp = self._workflow_cycle_manager._workflow_start_to_stream_response(
session=session, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
)
session.commit()
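Every database block in this listener now opens its session with `expire_on_commit=False`, so ORM objects such as `workflow_run` can still be read (for example `workflow_run.id`) after `session.commit()` without a lazy refresh against an already-closed session. A small self-contained SQLAlchemy sketch of the difference, using a throwaway in-memory model rather than Dify's tables:

```python
from sqlalchemy import String, create_engine
from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column


class Base(DeclarativeBase):
    pass


class Run(Base):
    __tablename__ = "runs"
    id: Mapped[str] = mapped_column(String, primary_key=True)


engine = create_engine("sqlite://")
Base.metadata.create_all(engine)

with Session(engine, expire_on_commit=False) as session:
    run = Run(id="run-1")
    session.add(run)
    session.commit()

# With expire_on_commit=False the instance keeps its loaded state, so this
# access needs neither a live session nor an extra SELECT.
print(run.id)
```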
@ -273,12 +270,14 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
|
|||
):
|
||||
if not self._workflow_run_id:
|
||||
raise ValueError("workflow run not initialized.")
|
||||
with Session(db.engine) as session:
|
||||
workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id)
|
||||
workflow_node_execution = self._handle_workflow_node_execution_retried(
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
workflow_run = self._workflow_cycle_manager._get_workflow_run(
|
||||
session=session, workflow_run_id=self._workflow_run_id
|
||||
)
|
||||
workflow_node_execution = self._workflow_cycle_manager._handle_workflow_node_execution_retried(
|
||||
session=session, workflow_run=workflow_run, event=event
|
||||
)
|
||||
response = self._workflow_node_retry_to_stream_response(
|
||||
response = self._workflow_cycle_manager._workflow_node_retry_to_stream_response(
|
||||
session=session,
|
||||
event=event,
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
|
|
@ -292,12 +291,14 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
|
|||
if not self._workflow_run_id:
|
||||
raise ValueError("workflow run not initialized.")
|
||||
|
||||
with Session(db.engine) as session:
|
||||
workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id)
|
||||
workflow_node_execution = self._handle_node_execution_start(
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
workflow_run = self._workflow_cycle_manager._get_workflow_run(
|
||||
session=session, workflow_run_id=self._workflow_run_id
|
||||
)
|
||||
workflow_node_execution = self._workflow_cycle_manager._handle_node_execution_start(
|
||||
session=session, workflow_run=workflow_run, event=event
|
||||
)
|
||||
node_start_response = self._workflow_node_start_to_stream_response(
|
||||
node_start_response = self._workflow_cycle_manager._workflow_node_start_to_stream_response(
|
||||
session=session,
|
||||
event=event,
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
|
|
@ -308,9 +309,11 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
|
|||
if node_start_response:
|
||||
yield node_start_response
|
||||
elif isinstance(event, QueueNodeSucceededEvent):
|
||||
with Session(db.engine) as session:
|
||||
workflow_node_execution = self._handle_workflow_node_execution_success(session=session, event=event)
|
||||
node_success_response = self._workflow_node_finish_to_stream_response(
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
workflow_node_execution = self._workflow_cycle_manager._handle_workflow_node_execution_success(
|
||||
session=session, event=event
|
||||
)
|
||||
node_success_response = self._workflow_cycle_manager._workflow_node_finish_to_stream_response(
|
||||
session=session,
|
||||
event=event,
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
|
|
@ -321,12 +324,12 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
|
|||
if node_success_response:
|
||||
yield node_success_response
|
||||
elif isinstance(event, QueueNodeFailedEvent | QueueNodeInIterationFailedEvent | QueueNodeExceptionEvent):
|
||||
with Session(db.engine) as session:
|
||||
workflow_node_execution = self._handle_workflow_node_execution_failed(
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
workflow_node_execution = self._workflow_cycle_manager._handle_workflow_node_execution_failed(
|
||||
session=session,
|
||||
event=event,
|
||||
)
|
||||
node_failed_response = self._workflow_node_finish_to_stream_response(
|
||||
node_failed_response = self._workflow_cycle_manager._workflow_node_finish_to_stream_response(
|
||||
session=session,
|
||||
event=event,
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
|
|
@ -341,13 +344,17 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
|
|||
if not self._workflow_run_id:
|
||||
raise ValueError("workflow run not initialized.")
|
||||
|
||||
with Session(db.engine) as session:
|
||||
workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id)
|
||||
parallel_start_resp = self._workflow_parallel_branch_start_to_stream_response(
|
||||
session=session,
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
workflow_run=workflow_run,
|
||||
event=event,
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
workflow_run = self._workflow_cycle_manager._get_workflow_run(
|
||||
session=session, workflow_run_id=self._workflow_run_id
|
||||
)
|
||||
parallel_start_resp = (
|
||||
self._workflow_cycle_manager._workflow_parallel_branch_start_to_stream_response(
|
||||
session=session,
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
workflow_run=workflow_run,
|
||||
event=event,
|
||||
)
|
||||
)
|
||||
|
||||
yield parallel_start_resp
|
||||
|
|
@ -356,13 +363,17 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
|
|||
if not self._workflow_run_id:
|
||||
raise ValueError("workflow run not initialized.")
|
||||
|
||||
with Session(db.engine) as session:
|
||||
workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id)
|
||||
parallel_finish_resp = self._workflow_parallel_branch_finished_to_stream_response(
|
||||
session=session,
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
workflow_run=workflow_run,
|
||||
event=event,
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
workflow_run = self._workflow_cycle_manager._get_workflow_run(
|
||||
session=session, workflow_run_id=self._workflow_run_id
|
||||
)
|
||||
parallel_finish_resp = (
|
||||
self._workflow_cycle_manager._workflow_parallel_branch_finished_to_stream_response(
|
||||
session=session,
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
workflow_run=workflow_run,
|
||||
event=event,
|
||||
)
|
||||
)
|
||||
|
||||
yield parallel_finish_resp
|
||||
|
|
@ -371,9 +382,11 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
|
|||
if not self._workflow_run_id:
|
||||
raise ValueError("workflow run not initialized.")
|
||||
|
||||
with Session(db.engine) as session:
|
||||
workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id)
|
||||
iter_start_resp = self._workflow_iteration_start_to_stream_response(
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
workflow_run = self._workflow_cycle_manager._get_workflow_run(
|
||||
session=session, workflow_run_id=self._workflow_run_id
|
||||
)
|
||||
iter_start_resp = self._workflow_cycle_manager._workflow_iteration_start_to_stream_response(
|
||||
session=session,
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
workflow_run=workflow_run,
|
||||
|
|
@ -386,9 +399,11 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
|
|||
if not self._workflow_run_id:
|
||||
raise ValueError("workflow run not initialized.")
|
||||
|
||||
with Session(db.engine) as session:
|
||||
workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id)
|
||||
iter_next_resp = self._workflow_iteration_next_to_stream_response(
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
workflow_run = self._workflow_cycle_manager._get_workflow_run(
|
||||
session=session, workflow_run_id=self._workflow_run_id
|
||||
)
|
||||
iter_next_resp = self._workflow_cycle_manager._workflow_iteration_next_to_stream_response(
|
||||
session=session,
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
workflow_run=workflow_run,
|
||||
|
|
@ -401,9 +416,11 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
|
|||
if not self._workflow_run_id:
|
||||
raise ValueError("workflow run not initialized.")
|
||||
|
||||
with Session(db.engine) as session:
|
||||
workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id)
|
||||
iter_finish_resp = self._workflow_iteration_completed_to_stream_response(
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
workflow_run = self._workflow_cycle_manager._get_workflow_run(
|
||||
session=session, workflow_run_id=self._workflow_run_id
|
||||
)
|
||||
iter_finish_resp = self._workflow_cycle_manager._workflow_iteration_completed_to_stream_response(
|
||||
session=session,
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
workflow_run=workflow_run,
|
||||
|
|
@ -418,8 +435,8 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
|
|||
if not graph_runtime_state:
|
||||
raise ValueError("graph runtime state not initialized.")
|
||||
|
||||
with Session(db.engine) as session:
|
||||
workflow_run = self._handle_workflow_run_success(
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
workflow_run = self._workflow_cycle_manager._handle_workflow_run_success(
|
||||
session=session,
|
||||
workflow_run_id=self._workflow_run_id,
|
||||
start_at=graph_runtime_state.start_at,
|
||||
|
|
@ -433,7 +450,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
|
|||
# save workflow app log
|
||||
self._save_workflow_app_log(session=session, workflow_run=workflow_run)
|
||||
|
||||
workflow_finish_resp = self._workflow_finish_to_stream_response(
|
||||
workflow_finish_resp = self._workflow_cycle_manager._workflow_finish_to_stream_response(
|
||||
session=session,
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
workflow_run=workflow_run,
|
||||
|
|
@ -447,8 +464,8 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
|
|||
if not graph_runtime_state:
|
||||
raise ValueError("graph runtime state not initialized.")
|
||||
|
||||
with Session(db.engine) as session:
|
||||
workflow_run = self._handle_workflow_run_partial_success(
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
workflow_run = self._workflow_cycle_manager._handle_workflow_run_partial_success(
|
||||
session=session,
|
||||
workflow_run_id=self._workflow_run_id,
|
||||
start_at=graph_runtime_state.start_at,
|
||||
|
|
@ -463,7 +480,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
|
|||
# save workflow app log
|
||||
self._save_workflow_app_log(session=session, workflow_run=workflow_run)
|
||||
|
||||
workflow_finish_resp = self._workflow_finish_to_stream_response(
|
||||
workflow_finish_resp = self._workflow_cycle_manager._workflow_finish_to_stream_response(
|
||||
session=session, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
|
||||
)
|
||||
session.commit()
|
||||
|
|
@ -475,8 +492,8 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
|
|||
if not graph_runtime_state:
|
||||
raise ValueError("graph runtime state not initialized.")
|
||||
|
||||
with Session(db.engine) as session:
|
||||
workflow_run = self._handle_workflow_run_failed(
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
workflow_run = self._workflow_cycle_manager._handle_workflow_run_failed(
|
||||
session=session,
|
||||
workflow_run_id=self._workflow_run_id,
|
||||
start_at=graph_runtime_state.start_at,
|
||||
|
|
@ -494,7 +511,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
|
|||
# save workflow app log
|
||||
self._save_workflow_app_log(session=session, workflow_run=workflow_run)
|
||||
|
||||
workflow_finish_resp = self._workflow_finish_to_stream_response(
|
||||
workflow_finish_resp = self._workflow_cycle_manager._workflow_finish_to_stream_response(
|
||||
session=session, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
|
||||
)
|
||||
session.commit()
|
||||
|
|
@ -514,7 +531,9 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
|
|||
delta_text, from_variable_selector=event.from_variable_selector
|
||||
)
|
||||
elif isinstance(event, QueueAgentLogEvent):
|
||||
yield self._handle_agent_log(task_id=self._application_generate_entity.task_id, event=event)
|
||||
yield self._workflow_cycle_manager._handle_agent_log(
|
||||
task_id=self._application_generate_entity.task_id, event=event
|
||||
)
|
||||
else:
|
||||
continue
|
||||
|
||||
|
|
|
|||
|
|
@@ -195,7 +195,7 @@ class WorkflowAppGenerateEntity(AppGenerateEntity):

# app config
app_config: WorkflowUIBasedAppConfig
workflow_run_id: Optional[str] = None
workflow_run_id: str

class SingleIterationRunEntity(BaseModel):
"""
@@ -329,6 +329,7 @@ class QueueAgentLogEvent(AppQueueEvent):
error: str | None
status: str
data: Mapping[str, Any]
metadata: Optional[Mapping[str, Any]] = None


class QueueNodeRetryEvent(QueueNodeStartedEvent):
@@ -716,6 +716,7 @@ class AgentLogStreamResponse(StreamResponse):
error: str | None
status: str
data: Mapping[str, Any]
metadata: Optional[Mapping[str, Any]] = None

event: StreamEvent = StreamEvent.AGENT_LOG
data: Data
@@ -15,7 +15,6 @@ from core.app.entities.queue_entities import (
from core.app.entities.task_entities import (
ErrorStreamResponse,
PingStreamResponse,
TaskState,
)
from core.errors.error import QuotaExceededError
from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
@@ -30,22 +29,12 @@ class BasedGenerateTaskPipeline:
BasedGenerateTaskPipeline is a class that generate stream output and state management for Application.
"""

_task_state: TaskState
_application_generate_entity: AppGenerateEntity

def __init__(
self,
application_generate_entity: AppGenerateEntity,
queue_manager: AppQueueManager,
stream: bool,
) -> None:
"""
Initialize GenerateTaskPipeline.
:param application_generate_entity: application generate entity
:param queue_manager: queue manager
:param user: user
:param stream: stream
"""
self._application_generate_entity = application_generate_entity
self._queue_manager = queue_manager
self._start_at = time.perf_counter()
@ -31,10 +31,19 @@ from services.annotation_service import AppAnnotationService
|
|||
|
||||
|
||||
class MessageCycleManage:
|
||||
_application_generate_entity: Union[
|
||||
ChatAppGenerateEntity, CompletionAppGenerateEntity, AgentChatAppGenerateEntity, AdvancedChatAppGenerateEntity
|
||||
]
|
||||
_task_state: Union[EasyUITaskState, WorkflowTaskState]
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
application_generate_entity: Union[
|
||||
ChatAppGenerateEntity,
|
||||
CompletionAppGenerateEntity,
|
||||
AgentChatAppGenerateEntity,
|
||||
AdvancedChatAppGenerateEntity,
|
||||
],
|
||||
task_state: Union[EasyUITaskState, WorkflowTaskState],
|
||||
) -> None:
|
||||
self._application_generate_entity = application_generate_entity
|
||||
self._task_state = task_state
|
||||
|
||||
def _generate_conversation_name(self, *, conversation_id: str, query: str) -> Optional[Thread]:
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -36,7 +36,6 @@ from core.app.entities.task_entities import (
|
|||
ParallelBranchStartStreamResponse,
|
||||
WorkflowFinishStreamResponse,
|
||||
WorkflowStartStreamResponse,
|
||||
WorkflowTaskState,
|
||||
)
|
||||
from core.file import FILE_MODEL_IDENTITY, File
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
|
|
@ -60,13 +59,20 @@ from models.workflow import (
|
|||
WorkflowRunStatus,
|
||||
)
|
||||
|
||||
from .exc import WorkflowNodeExecutionNotFoundError, WorkflowRunNotFoundError
|
||||
from .exc import WorkflowRunNotFoundError
|
||||
|
||||
|
||||
class WorkflowCycleManage:
|
||||
_application_generate_entity: Union[AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity]
|
||||
_task_state: WorkflowTaskState
|
||||
_workflow_system_variables: dict[SystemVariableKey, Any]
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
application_generate_entity: Union[AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity],
|
||||
workflow_system_variables: dict[SystemVariableKey, Any],
|
||||
) -> None:
|
||||
self._workflow_run: WorkflowRun | None = None
|
||||
self._workflow_node_executions: dict[str, WorkflowNodeExecution] = {}
|
||||
self._application_generate_entity = application_generate_entity
|
||||
self._workflow_system_variables = workflow_system_variables
|
||||
|
||||
def _handle_workflow_run_start(
|
||||
self,
|
||||
|
|
@ -104,7 +110,8 @@ class WorkflowCycleManage:
|
|||
inputs = dict(WorkflowEntry.handle_special_values(inputs) or {})
|
||||
|
||||
# init workflow run
|
||||
workflow_run_id = str(self._workflow_system_variables.get(SystemVariableKey.WORKFLOW_RUN_ID, uuid4()))
|
||||
# TODO: This workflow_run_id should always not be None, maybe we can use a more elegant way to handle this
|
||||
workflow_run_id = str(self._workflow_system_variables.get(SystemVariableKey.WORKFLOW_RUN_ID) or uuid4())
|
||||
|
||||
workflow_run = WorkflowRun()
|
||||
workflow_run.id = workflow_run_id
|
||||
|
|
@ -241,7 +248,7 @@ class WorkflowCycleManage:
|
|||
workflow_run.finished_at = datetime.now(UTC).replace(tzinfo=None)
|
||||
workflow_run.exceptions_count = exceptions_count
|
||||
|
||||
stmt = select(WorkflowNodeExecution).where(
|
||||
stmt = select(WorkflowNodeExecution.node_execution_id).where(
|
||||
WorkflowNodeExecution.tenant_id == workflow_run.tenant_id,
|
||||
WorkflowNodeExecution.app_id == workflow_run.app_id,
|
||||
WorkflowNodeExecution.workflow_id == workflow_run.workflow_id,
|
||||
|
|
@ -249,15 +256,18 @@ class WorkflowCycleManage:
|
|||
WorkflowNodeExecution.workflow_run_id == workflow_run.id,
|
||||
WorkflowNodeExecution.status == WorkflowNodeExecutionStatus.RUNNING.value,
|
||||
)
|
||||
|
||||
running_workflow_node_executions = session.scalars(stmt).all()
|
||||
ids = session.scalars(stmt).all()
|
||||
# Use self._get_workflow_node_execution here to make sure the cache is updated
|
||||
running_workflow_node_executions = [
|
||||
self._get_workflow_node_execution(session=session, node_execution_id=id) for id in ids if id
|
||||
]
|
||||
|
||||
for workflow_node_execution in running_workflow_node_executions:
|
||||
now = datetime.now(UTC).replace(tzinfo=None)
|
||||
workflow_node_execution.status = WorkflowNodeExecutionStatus.FAILED.value
|
||||
workflow_node_execution.error = error
|
||||
finish_at = datetime.now(UTC).replace(tzinfo=None)
|
||||
workflow_node_execution.finished_at = finish_at
|
||||
workflow_node_execution.elapsed_time = (finish_at - workflow_node_execution.created_at).total_seconds()
|
||||
workflow_node_execution.finished_at = now
|
||||
workflow_node_execution.elapsed_time = (now - workflow_node_execution.created_at).total_seconds()
|
||||
|
||||
if trace_manager:
|
||||
trace_manager.add_trace_task(
|
||||
|
|
@ -299,6 +309,8 @@ class WorkflowCycleManage:
|
|||
workflow_node_execution.created_at = datetime.now(UTC).replace(tzinfo=None)
|
||||
|
||||
session.add(workflow_node_execution)
|
||||
|
||||
self._workflow_node_executions[event.node_execution_id] = workflow_node_execution
|
||||
return workflow_node_execution
|
||||
|
||||
def _handle_workflow_node_execution_success(
|
||||
|
|
@ -325,6 +337,7 @@ class WorkflowCycleManage:
|
|||
workflow_node_execution.finished_at = finished_at
|
||||
workflow_node_execution.elapsed_time = elapsed_time
|
||||
|
||||
workflow_node_execution = session.merge(workflow_node_execution)
|
||||
return workflow_node_execution
|
||||
|
||||
def _handle_workflow_node_execution_failed(
|
||||
|
|
@ -364,6 +377,7 @@ class WorkflowCycleManage:
|
|||
workflow_node_execution.elapsed_time = elapsed_time
|
||||
workflow_node_execution.execution_metadata = execution_metadata
|
||||
|
||||
workflow_node_execution = session.merge(workflow_node_execution)
|
||||
return workflow_node_execution
|
||||
|
||||
def _handle_workflow_node_execution_retried(
|
||||
|
|
@ -415,6 +429,8 @@ class WorkflowCycleManage:
|
|||
workflow_node_execution.index = event.node_run_index
|
||||
|
||||
session.add(workflow_node_execution)
|
||||
|
||||
self._workflow_node_executions[event.node_execution_id] = workflow_node_execution
|
||||
return workflow_node_execution
|
||||
|
||||
#################################################
|
||||
|
|
@@ -811,25 +827,23 @@ class WorkflowCycleManage:
return None

def _get_workflow_run(self, *, session: Session, workflow_run_id: str) -> WorkflowRun:
"""
Refetch workflow run
:param workflow_run_id: workflow run id
:return:
"""
if self._workflow_run and self._workflow_run.id == workflow_run_id:
cached_workflow_run = self._workflow_run
cached_workflow_run = session.merge(cached_workflow_run)
return cached_workflow_run
stmt = select(WorkflowRun).where(WorkflowRun.id == workflow_run_id)
workflow_run = session.scalar(stmt)
if not workflow_run:
raise WorkflowRunNotFoundError(workflow_run_id)
self._workflow_run = workflow_run

return workflow_run

def _get_workflow_node_execution(self, session: Session, node_execution_id: str) -> WorkflowNodeExecution:
stmt = select(WorkflowNodeExecution).where(WorkflowNodeExecution.node_execution_id == node_execution_id)
workflow_node_execution = session.scalar(stmt)
if not workflow_node_execution:
raise WorkflowNodeExecutionNotFoundError(node_execution_id)

return workflow_node_execution
if node_execution_id not in self._workflow_node_executions:
raise ValueError(f"Workflow node execution not found: {node_execution_id}")
cached_workflow_node_execution = self._workflow_node_executions[node_execution_id]
return cached_workflow_node_execution

def _handle_agent_log(self, task_id: str, event: QueueAgentLogEvent) -> AgentLogStreamResponse:
"""
@@ -848,5 +862,6 @@ class WorkflowCycleManage:
error=event.error,
status=event.status,
data=event.data,
metadata=event.metadata,
),
)
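`_get_workflow_run` above now answers repeated lookups from an instance-level cache and re-attaches the detached object to the current short-lived session with `session.merge()`, while `_get_workflow_node_execution` reads from the dictionary populated by the start/retry handlers. A simplified sketch of the cache-then-merge idea, generic over any mapped class with an `id` column (not Dify's models):

```python
from sqlalchemy import select
from sqlalchemy.orm import Session


class RunCache:
    """Serve one frequently re-read ORM row across short-lived sessions."""

    def __init__(self) -> None:
        self._cached = None  # the last row object we loaded, possibly detached

    def get(self, session: Session, model, run_id: str):
        if self._cached is not None and self._cached.id == run_id:
            # merge() attaches the cached instance to this session so later
            # reads and updates go through the new session, not a stale one.
            return session.merge(self._cached)
        row = session.scalar(select(model).where(model.id == run_id))
        if row is None:
            raise ValueError(f"run not found: {run_id}")
        self._cached = row
        return row
```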
@@ -178,7 +178,7 @@ class ModelInstance:

def get_llm_num_tokens(
self, prompt_messages: list[PromptMessage], tools: Optional[list[PromptMessageTool]] = None
) -> list[int]:
) -> int:
"""
Get number of tokens for llm

@@ -191,7 +191,7 @@ class ModelInstance:

self.model_type_instance = cast(LargeLanguageModel, self.model_type_instance)
return cast(
list[int],
int,
self._round_robin_invoke(
function=self.model_type_instance.get_num_tokens,
model=self.model,
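`get_llm_num_tokens` above now returns a single `int`, and the inner `cast` target is updated to match; the round-robin invoke already produced one count, so the old `list[int]` annotation was simply wrong. A tiny sketch of keeping the `cast` in sync with the declared return type, with a stand-in for the round-robin wrapper:

```python
from typing import Any, cast


def _round_robin_invoke(function, *args: Any, **kwargs: Any) -> Any:
    # Stand-in for ModelInstance._round_robin_invoke: it returns whatever the
    # underlying model method returns, typed as Any.
    return function(*args, **kwargs)


def get_num_tokens(text: str) -> int:
    # The cast target must match the annotated return type (int, not list[int]).
    return cast(int, _round_robin_invoke(len, text))
```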
@@ -1,9 +1,47 @@
import tiktoken
from threading import Lock
from typing import Any

_tokenizer: Any = None
_lock = Lock()


class GPT2Tokenizer:
@staticmethod
def _get_num_tokens_by_gpt2(text: str) -> int:
"""
use gpt2 tokenizer to get num tokens
"""
_tokenizer = GPT2Tokenizer.get_encoder()
tokens = _tokenizer.encode(text)
return len(tokens)

@staticmethod
def get_num_tokens(text: str) -> int:
encoding = tiktoken.encoding_for_model("gpt2")
tiktoken_vec = encoding.encode(text)
return len(tiktoken_vec)
# Because this process needs more cpu resource, we turn this back before we find a better way to handle it.
#
# future = _executor.submit(GPT2Tokenizer._get_num_tokens_by_gpt2, text)
# result = future.result()
# return cast(int, result)
return GPT2Tokenizer._get_num_tokens_by_gpt2(text)

@staticmethod
def get_encoder() -> Any:
global _tokenizer, _lock
with _lock:
if _tokenizer is None:
# Try to use tiktoken to get the tokenizer because it is faster
#
try:
import tiktoken

_tokenizer = tiktoken.get_encoding("gpt2")
except Exception:
from os.path import abspath, dirname, join

from transformers import GPT2Tokenizer as TransformerGPT2Tokenizer # type: ignore

base_path = abspath(__file__)
gpt2_tokenizer_path = join(dirname(base_path), "gpt2")
_tokenizer = TransformerGPT2Tokenizer.from_pretrained(gpt2_tokenizer_path)

return _tokenizer
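The rewritten module above builds a single process-wide encoder lazily behind a lock, preferring `tiktoken` and falling back to the bundled `transformers` GPT-2 vocabulary; the commented-out executor variant records that offloading the count to a worker cost more CPU than it saved. Counting tokens stays a one-liner; a short usage sketch, assuming `tiktoken` is installed:

```python
import tiktoken

# Building the encoder is the expensive step; encoding afterwards is cheap,
# which is why the module caches one shared instance.
encoder = tiktoken.get_encoding("gpt2")
print(len(encoder.encode("How many GPT-2 tokens is this sentence?")))
```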
@ -119,7 +119,7 @@ class PluginAppBackwardsInvocation(BaseBackwardsInvocation):
|
|||
stream: bool,
|
||||
inputs: Mapping,
|
||||
files: list[dict],
|
||||
):
|
||||
) -> Generator[Mapping | str, None, None] | Mapping:
|
||||
"""
|
||||
invoke workflow app
|
||||
"""
|
||||
|
|
@ -146,7 +146,7 @@ class PluginAppBackwardsInvocation(BaseBackwardsInvocation):
|
|||
stream: bool,
|
||||
inputs: Mapping,
|
||||
files: list[dict],
|
||||
):
|
||||
) -> Generator[Mapping | str, None, None] | Mapping:
|
||||
"""
|
||||
invoke completion app
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -268,7 +268,7 @@ Here is the extra instruction you need to follow:
|
|||
return summary.message.content
|
||||
|
||||
lines = content.split("\n")
|
||||
new_lines = []
|
||||
new_lines: list[str] = []
|
||||
# split long line into multiple lines
|
||||
for i in range(len(lines)):
|
||||
line = lines[i]
|
||||
|
|
@ -286,16 +286,16 @@ Here is the extra instruction you need to follow:
|
|||
|
||||
# merge lines into messages with max tokens
|
||||
messages: list[str] = []
|
||||
for i in new_lines:
|
||||
for i in new_lines: # type: ignore
|
||||
if len(messages) == 0:
|
||||
messages.append(i)
|
||||
messages.append(i) # type: ignore
|
||||
else:
|
||||
if len(messages[-1]) + len(i) < max_tokens * 0.5:
|
||||
messages[-1] += i
|
||||
if get_prompt_tokens(messages[-1] + i) > max_tokens * 0.7:
|
||||
messages.append(i)
|
||||
if len(messages[-1]) + len(i) < max_tokens * 0.5: # type: ignore
|
||||
messages[-1] += i # type: ignore
|
||||
if get_prompt_tokens(messages[-1] + i) > max_tokens * 0.7: # type: ignore
|
||||
messages.append(i) # type: ignore
|
||||
else:
|
||||
messages[-1] += i
|
||||
messages[-1] += i # type: ignore
|
||||
|
||||
summaries = []
|
||||
for i in range(len(messages)):
|
||||
|
|
|
|||
|
|
@ -103,7 +103,7 @@ class BasePluginManager:
|
|||
Make a stream request to the plugin daemon inner API and yield the response as a model.
|
||||
"""
|
||||
for line in self._stream_request(method, path, params, headers, data, files):
|
||||
yield type(**json.loads(line))
|
||||
yield type(**json.loads(line)) # type: ignore
|
||||
|
||||
def _request_with_model(
|
||||
self,
|
||||
|
|
|
|||
|
|
@ -113,6 +113,8 @@ class BaiduVector(BaseVector):
|
|||
return False
|
||||
|
||||
def delete_by_ids(self, ids: list[str]) -> None:
|
||||
if not ids:
|
||||
return
|
||||
quoted_ids = [f"'{id}'" for id in ids]
|
||||
self._db.table(self._collection_name).delete(filter=f"id IN({', '.join(quoted_ids)})")
|
||||
|
||||
|
|
|
|||
|
|
@ -83,6 +83,8 @@ class ChromaVector(BaseVector):
|
|||
self._client.delete_collection(self._collection_name)
|
||||
|
||||
def delete_by_ids(self, ids: list[str]) -> None:
|
||||
if not ids:
|
||||
return
|
||||
collection = self._client.get_or_create_collection(self._collection_name)
|
||||
collection.delete(ids=ids)
|
||||
|
||||
|
|
|
|||
|
|
@ -0,0 +1,104 @@
|
|||
import json
|
||||
import logging
|
||||
from typing import Any, Optional
|
||||
|
||||
from flask import current_app
|
||||
|
||||
from core.rag.datasource.vdb.elasticsearch.elasticsearch_vector import (
|
||||
ElasticSearchConfig,
|
||||
ElasticSearchVector,
|
||||
ElasticSearchVectorFactory,
|
||||
)
|
||||
from core.rag.datasource.vdb.field import Field
|
||||
from core.rag.datasource.vdb.vector_type import VectorType
|
||||
from core.rag.embedding.embedding_base import Embeddings
|
||||
from extensions.ext_redis import redis_client
|
||||
from models.dataset import Dataset
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ElasticSearchJaVector(ElasticSearchVector):
|
||||
def create_collection(
|
||||
self,
|
||||
embeddings: list[list[float]],
|
||||
metadatas: Optional[list[dict[Any, Any]]] = None,
|
||||
index_params: Optional[dict] = None,
|
||||
):
|
||||
lock_name = f"vector_indexing_lock_{self._collection_name}"
|
||||
with redis_client.lock(lock_name, timeout=20):
|
||||
collection_exist_cache_key = f"vector_indexing_{self._collection_name}"
|
||||
if redis_client.get(collection_exist_cache_key):
|
||||
logger.info(f"Collection {self._collection_name} already exists.")
|
||||
return
|
||||
|
||||
if not self._client.indices.exists(index=self._collection_name):
|
||||
dim = len(embeddings[0])
|
||||
settings = {
|
||||
"analysis": {
|
||||
"analyzer": {
|
||||
"ja_analyzer": {
|
||||
"type": "custom",
|
||||
"char_filter": [
|
||||
"icu_normalizer",
|
||||
"kuromoji_iteration_mark",
|
||||
],
|
||||
"tokenizer": "kuromoji_tokenizer",
|
||||
"filter": [
|
||||
"kuromoji_baseform",
|
||||
"kuromoji_part_of_speech",
|
||||
"ja_stop",
|
||||
"kuromoji_number",
|
||||
"kuromoji_stemmer",
|
||||
],
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
mappings = {
|
||||
"properties": {
|
||||
Field.CONTENT_KEY.value: {
|
||||
"type": "text",
|
||||
"analyzer": "ja_analyzer",
|
||||
"search_analyzer": "ja_analyzer",
|
||||
},
|
||||
Field.VECTOR.value: { # Make sure the dimension is correct here
|
||||
"type": "dense_vector",
|
||||
"dims": dim,
|
||||
"index": True,
|
||||
"similarity": "cosine",
|
||||
},
|
||||
Field.METADATA_KEY.value: {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"doc_id": {"type": "keyword"} # Map doc_id to keyword type
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
self._client.indices.create(index=self._collection_name, settings=settings, mappings=mappings)
|
||||
|
||||
redis_client.set(collection_exist_cache_key, 1, ex=3600)
|
||||
|
||||
|
||||
class ElasticSearchJaVectorFactory(ElasticSearchVectorFactory):
|
||||
def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> ElasticSearchJaVector:
|
||||
if dataset.index_struct_dict:
|
||||
class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"]
|
||||
collection_name = class_prefix
|
||||
else:
|
||||
dataset_id = dataset.id
|
||||
collection_name = Dataset.gen_collection_name_by_id(dataset_id)
|
||||
dataset.index_struct = json.dumps(self.gen_index_struct_dict(VectorType.ELASTICSEARCH, collection_name))
|
||||
|
||||
config = current_app.config
|
||||
return ElasticSearchJaVector(
|
||||
index_name=collection_name,
|
||||
config=ElasticSearchConfig(
|
||||
host=config.get("ELASTICSEARCH_HOST", "localhost"),
|
||||
port=config.get("ELASTICSEARCH_PORT", 9200),
|
||||
username=config.get("ELASTICSEARCH_USERNAME", ""),
|
||||
password=config.get("ELASTICSEARCH_PASSWORD", ""),
|
||||
),
|
||||
attributes=[],
|
||||
)
|
||||
|
|
@ -98,6 +98,8 @@ class ElasticSearchVector(BaseVector):
|
|||
return bool(self._client.exists(index=self._collection_name, id=id))
|
||||
|
||||
def delete_by_ids(self, ids: list[str]) -> None:
|
||||
if not ids:
|
||||
return
|
||||
for id in ids:
|
||||
self._client.delete(index=self._collection_name, id=id)
|
||||
|
||||
|
|
|
|||
|
|
@ -6,6 +6,8 @@ class Field(Enum):
|
|||
METADATA_KEY = "metadata"
|
||||
GROUP_KEY = "group_id"
|
||||
VECTOR = "vector"
|
||||
# Sparse Vector aims to support full text search
|
||||
SPARSE_VECTOR = "sparse_vector"
|
||||
TEXT_KEY = "text"
|
||||
PRIMARY_KEY = "id"
|
||||
DOC_ID = "metadata.doc_id"
|
||||
|
|
|
|||
|
|
@@ -2,6 +2,7 @@ import json
import logging
from typing import Any, Optional

from packaging import version
from pydantic import BaseModel, model_validator
from pymilvus import MilvusClient, MilvusException # type: ignore
from pymilvus.milvus_client import IndexParams # type: ignore
@@ -20,16 +21,25 @@ logger = logging.getLogger(__name__)


class MilvusConfig(BaseModel):
uri: str
token: Optional[str] = None
user: str
password: str
batch_size: int = 100
database: str = "default"
"""
Configuration class for Milvus connection.
"""

uri: str # Milvus server URI
token: Optional[str] = None # Optional token for authentication
user: str # Username for authentication
password: str # Password for authentication
batch_size: int = 100 # Batch size for operations
database: str = "default" # Database name
enable_hybrid_search: bool = False # Flag to enable hybrid search

@model_validator(mode="before")
@classmethod
def validate_config(cls, values: dict) -> dict:
"""
Validate the configuration values.
Raises ValueError if required fields are missing.
"""
if not values.get("uri"):
raise ValueError("config MILVUS_URI is required")
if not values.get("user"):
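The reworked `MilvusConfig` above documents each connection field and adds an `enable_hybrid_search` flag (wired to `MILVUS_ENABLE_HYBRID_SEARCH` elsewhere in this commit), while the validator still rejects missing required values. A brief sketch of constructing the model, assuming it runs in the same module as the `MilvusConfig` class shown above; the connection values are placeholders:

```python
from pydantic import ValidationError

config = MilvusConfig(
    uri="http://localhost:19530",  # placeholder server address
    user="root",
    password="Milvus",
    enable_hybrid_search=True,  # full-text search also needs a >= 2.5.0 server
)
print(config.to_milvus_params()["uri"])

try:
    MilvusConfig(user="root", password="Milvus")  # no uri given
except ValidationError as exc:
    print(exc)  # wraps the "config MILVUS_URI is required" ValueError
```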
@ -39,6 +49,9 @@ class MilvusConfig(BaseModel):
|
|||
return values
|
||||
|
||||
def to_milvus_params(self):
|
||||
"""
|
||||
Convert the configuration to a dictionary of Milvus connection parameters.
|
||||
"""
|
||||
return {
|
||||
"uri": self.uri,
|
||||
"token": self.token,
|
||||
|
|
@ -49,26 +62,57 @@ class MilvusConfig(BaseModel):
|
|||
|
||||
|
||||
class MilvusVector(BaseVector):
|
||||
"""
|
||||
Milvus vector storage implementation.
|
||||
"""
|
||||
|
||||
def __init__(self, collection_name: str, config: MilvusConfig):
|
||||
super().__init__(collection_name)
|
||||
self._client_config = config
|
||||
self._client = self._init_client(config)
|
||||
self._consistency_level = "Session"
|
||||
self._fields: list[str] = []
|
||||
self._consistency_level = "Session" # Consistency level for Milvus operations
|
||||
self._fields: list[str] = [] # List of fields in the collection
|
||||
self._hybrid_search_enabled = self._check_hybrid_search_support() # Check if hybrid search is supported
|
||||
|
||||
def _check_hybrid_search_support(self) -> bool:
|
||||
"""
|
||||
Check if the current Milvus version supports hybrid search.
|
||||
Returns True if the version is >= 2.5.0, otherwise False.
|
||||
"""
|
||||
if not self._client_config.enable_hybrid_search:
|
||||
return False
|
||||
|
||||
try:
|
||||
milvus_version = self._client.get_server_version()
|
||||
return version.parse(milvus_version).base_version >= version.parse("2.5.0").base_version
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to check Milvus version: {str(e)}. Disabling hybrid search.")
|
||||
return False
|
||||
|
||||
def get_type(self) -> str:
|
||||
"""
|
||||
Get the type of vector storage (Milvus).
|
||||
"""
|
||||
return VectorType.MILVUS
|
||||
|
||||
def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
|
||||
"""
|
||||
Create a collection and add texts with embeddings.
|
||||
"""
|
||||
index_params = {"metric_type": "IP", "index_type": "HNSW", "params": {"M": 8, "efConstruction": 64}}
|
||||
metadatas = [d.metadata if d.metadata is not None else {} for d in texts]
|
||||
self.create_collection(embeddings, metadatas, index_params)
|
||||
self.add_texts(texts, embeddings)
|
||||
|
||||
def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
|
||||
"""
|
||||
Add texts and their embeddings to the collection.
|
||||
"""
|
||||
insert_dict_list = []
|
||||
for i in range(len(documents)):
|
||||
insert_dict = {
|
||||
# Do not need to insert the sparse_vector field separately, as the text_bm25_emb
|
||||
# function will automatically convert the native text into a sparse vector for us.
|
||||
Field.CONTENT_KEY.value: documents[i].page_content,
|
||||
Field.VECTOR.value: embeddings[i],
|
||||
Field.METADATA_KEY.value: documents[i].metadata,
|
||||
|
|
@ -76,12 +120,11 @@ class MilvusVector(BaseVector):
|
|||
insert_dict_list.append(insert_dict)
|
||||
# Total insert count
|
||||
total_count = len(insert_dict_list)
|
||||
|
||||
pks: list[str] = []
|
||||
|
||||
for i in range(0, total_count, 1000):
|
||||
batch_insert_list = insert_dict_list[i : i + 1000]
|
||||
# Insert into the collection.
|
||||
batch_insert_list = insert_dict_list[i : i + 1000]
|
||||
try:
|
||||
ids = self._client.insert(collection_name=self._collection_name, data=batch_insert_list)
|
||||
pks.extend(ids)
|
||||
|
|
@ -91,6 +134,9 @@ class MilvusVector(BaseVector):
|
|||
return pks
|
||||
|
||||
def get_ids_by_metadata_field(self, key: str, value: str):
|
||||
"""
|
||||
Get document IDs by metadata field key and value.
|
||||
"""
|
||||
result = self._client.query(
|
||||
collection_name=self._collection_name, filter=f'metadata["{key}"] == "{value}"', output_fields=["id"]
|
||||
)
|
||||
|
|
@ -100,12 +146,18 @@ class MilvusVector(BaseVector):
|
|||
return None
|
||||
|
||||
def delete_by_metadata_field(self, key: str, value: str):
|
||||
"""
|
||||
Delete documents by metadata field key and value.
|
||||
"""
|
||||
if self._client.has_collection(self._collection_name):
|
||||
ids = self.get_ids_by_metadata_field(key, value)
|
||||
if ids:
|
||||
self._client.delete(collection_name=self._collection_name, pks=ids)
|
||||
|
||||
def delete_by_ids(self, ids: list[str]) -> None:
|
||||
"""
|
||||
Delete documents by their IDs.
|
||||
"""
|
||||
if self._client.has_collection(self._collection_name):
|
||||
result = self._client.query(
|
||||
collection_name=self._collection_name, filter=f'metadata["doc_id"] in {ids}', output_fields=["id"]
|
||||
|
|
@ -115,10 +167,16 @@ class MilvusVector(BaseVector):
|
|||
self._client.delete(collection_name=self._collection_name, pks=ids)
|
||||
|
||||
def delete(self) -> None:
|
||||
"""
|
||||
Delete the entire collection.
|
||||
"""
|
||||
if self._client.has_collection(self._collection_name):
|
||||
self._client.drop_collection(self._collection_name, None)
|
||||
|
||||
def text_exists(self, id: str) -> bool:
|
||||
"""
|
||||
Check if a text with the given ID exists in the collection.
|
||||
"""
|
||||
if not self._client.has_collection(self._collection_name):
|
||||
return False
|
||||
|
||||
|
|
@ -128,32 +186,80 @@ class MilvusVector(BaseVector):
|
|||
|
||||
return len(result) > 0
|
||||
|
||||
def field_exists(self, field: str) -> bool:
|
||||
"""
|
||||
Check if a field exists in the collection.
|
||||
"""
|
||||
return field in self._fields
|
||||
|
||||
def _process_search_results(
|
||||
self, results: list[Any], output_fields: list[str], score_threshold: float = 0.0
|
||||
) -> list[Document]:
|
||||
"""
|
||||
Common method to process search results
|
||||
|
||||
:param results: Search results
|
||||
:param output_fields: Fields to be output
|
||||
:param score_threshold: Score threshold for filtering
|
||||
:return: List of documents
|
||||
"""
|
||||
docs = []
|
||||
for result in results[0]:
|
||||
metadata = result["entity"].get(output_fields[1], {})
|
||||
metadata["score"] = result["distance"]
|
||||
|
||||
if result["distance"] > score_threshold:
|
||||
doc = Document(page_content=result["entity"].get(output_fields[0], ""), metadata=metadata)
|
||||
docs.append(doc)
|
||||
|
||||
return docs
|
||||
|
||||
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
|
||||
# Set search parameters.
|
||||
"""
|
||||
Search for documents by vector similarity.
|
||||
"""
|
||||
results = self._client.search(
|
||||
collection_name=self._collection_name,
|
||||
data=[query_vector],
|
||||
anns_field=Field.VECTOR.value,
|
||||
limit=kwargs.get("top_k", 4),
|
||||
output_fields=[Field.CONTENT_KEY.value, Field.METADATA_KEY.value],
|
||||
)
|
||||
# Organize results.
|
||||
docs = []
|
||||
for result in results[0]:
|
||||
metadata = result["entity"].get(Field.METADATA_KEY.value)
|
||||
metadata["score"] = result["distance"]
|
||||
score_threshold = float(kwargs.get("score_threshold") or 0.0)
|
||||
if result["distance"] > score_threshold:
|
||||
doc = Document(page_content=result["entity"].get(Field.CONTENT_KEY.value), metadata=metadata)
|
||||
docs.append(doc)
|
||||
return docs
|
||||
|
||||
return self._process_search_results(
|
||||
results,
|
||||
output_fields=[Field.CONTENT_KEY.value, Field.METADATA_KEY.value],
|
||||
score_threshold=float(kwargs.get("score_threshold") or 0.0),
|
||||
)
|
||||
|
||||
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
|
||||
# milvus/zilliz doesn't support bm25 search
|
||||
return []
|
||||
"""
|
||||
Search for documents by full-text search (if hybrid search is enabled).
|
||||
"""
|
||||
if not self._hybrid_search_enabled or not self.field_exists(Field.SPARSE_VECTOR.value):
|
||||
logger.warning("Full-text search is not supported in current Milvus version (requires >= 2.5.0)")
|
||||
return []
|
||||
|
||||
results = self._client.search(
|
||||
collection_name=self._collection_name,
|
||||
data=[query],
|
||||
anns_field=Field.SPARSE_VECTOR.value,
|
||||
limit=kwargs.get("top_k", 4),
|
||||
output_fields=[Field.CONTENT_KEY.value, Field.METADATA_KEY.value],
|
||||
)
|
||||
|
||||
return self._process_search_results(
|
||||
results,
|
||||
output_fields=[Field.CONTENT_KEY.value, Field.METADATA_KEY.value],
|
||||
score_threshold=float(kwargs.get("score_threshold") or 0.0),
|
||||
)
|
||||
|
||||
def create_collection(
|
||||
self, embeddings: list, metadatas: Optional[list[dict]] = None, index_params: Optional[dict] = None
|
||||
):
|
||||
"""
|
||||
Create a new collection in Milvus with the specified schema and index parameters.
|
||||
"""
|
||||
lock_name = "vector_indexing_lock_{}".format(self._collection_name)
|
||||
with redis_client.lock(lock_name, timeout=20):
|
||||
collection_exist_cache_key = "vector_indexing_{}".format(self._collection_name)
|
||||
|
|
@ -161,7 +267,7 @@ class MilvusVector(BaseVector):
|
|||
return
|
||||
# Grab the existing collection if it exists
|
||||
if not self._client.has_collection(self._collection_name):
|
||||
from pymilvus import CollectionSchema, DataType, FieldSchema # type: ignore
|
||||
from pymilvus import CollectionSchema, DataType, FieldSchema, Function, FunctionType # type: ignore
|
||||
from pymilvus.orm.types import infer_dtype_bydata # type: ignore
|
||||
|
||||
# Determine embedding dim
|
||||
|
|
@ -170,16 +276,36 @@ class MilvusVector(BaseVector):
|
|||
if metadatas:
|
||||
fields.append(FieldSchema(Field.METADATA_KEY.value, DataType.JSON, max_length=65_535))
|
||||
|
||||
# Create the text field
|
||||
fields.append(FieldSchema(Field.CONTENT_KEY.value, DataType.VARCHAR, max_length=65_535))
|
||||
# Create the text field, enable_analyzer will be set True to support milvus automatically
|
||||
# transfer text to sparse_vector, reference: https://milvus.io/docs/full-text-search.md
|
||||
fields.append(
|
||||
FieldSchema(
|
||||
Field.CONTENT_KEY.value,
|
||||
DataType.VARCHAR,
|
||||
max_length=65_535,
|
||||
enable_analyzer=self._hybrid_search_enabled,
|
||||
)
|
||||
)
|
||||
# Create the primary key field
|
||||
fields.append(FieldSchema(Field.PRIMARY_KEY.value, DataType.INT64, is_primary=True, auto_id=True))
|
||||
# Create the vector field, supports binary or float vectors
|
||||
fields.append(FieldSchema(Field.VECTOR.value, infer_dtype_bydata(embeddings[0]), dim=dim))
|
||||
# Create Sparse Vector Index for the collection
|
||||
if self._hybrid_search_enabled:
|
||||
fields.append(FieldSchema(Field.SPARSE_VECTOR.value, DataType.SPARSE_FLOAT_VECTOR))
|
||||
|
||||
# Create the schema for the collection
|
||||
schema = CollectionSchema(fields)
|
||||
|
||||
# Create custom function to support text to sparse vector by BM25
|
||||
if self._hybrid_search_enabled:
|
||||
bm25_function = Function(
|
||||
name="text_bm25_emb",
|
||||
input_field_names=[Field.CONTENT_KEY.value],
|
||||
output_field_names=[Field.SPARSE_VECTOR.value],
|
||||
function_type=FunctionType.BM25,
|
||||
)
|
||||
schema.add_function(bm25_function)
|
||||
|
||||
for x in schema.fields:
|
||||
self._fields.append(x.name)
|
||||
# Since primary field is auto-id, no need to track it
|
||||
|
|
@ -189,10 +315,15 @@ class MilvusVector(BaseVector):
|
|||
index_params_obj = IndexParams()
|
||||
index_params_obj.add_index(field_name=Field.VECTOR.value, **index_params)
|
||||
|
||||
# Create Sparse Vector Index for the collection
|
||||
if self._hybrid_search_enabled:
|
||||
index_params_obj.add_index(
|
||||
field_name=Field.SPARSE_VECTOR.value, index_type="AUTOINDEX", metric_type="BM25"
|
||||
)
|
||||
|
||||
# Create the collection
|
||||
collection_name = self._collection_name
|
||||
self._client.create_collection(
|
||||
collection_name=collection_name,
|
||||
collection_name=self._collection_name,
|
||||
schema=schema,
|
||||
index_params=index_params_obj,
|
||||
consistency_level=self._consistency_level,
|
||||
|
|
@ -200,12 +331,22 @@ class MilvusVector(BaseVector):
|
|||
redis_client.set(collection_exist_cache_key, 1, ex=3600)
|
||||
|
||||
def _init_client(self, config) -> MilvusClient:
|
||||
"""
|
||||
Initialize and return a Milvus client.
|
||||
"""
|
||||
client = MilvusClient(uri=config.uri, user=config.user, password=config.password, db_name=config.database)
|
||||
return client
|
||||
|
||||
|
||||
class MilvusVectorFactory(AbstractVectorFactory):
|
||||
"""
|
||||
Factory class for creating MilvusVector instances.
|
||||
"""
|
||||
|
||||
def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> MilvusVector:
|
||||
"""
|
||||
Initialize a MilvusVector instance for the given dataset.
|
||||
"""
|
||||
if dataset.index_struct_dict:
|
||||
class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"]
|
||||
collection_name = class_prefix
|
||||
|
|
@ -222,5 +363,6 @@ class MilvusVectorFactory(AbstractVectorFactory):
|
|||
user=dify_config.MILVUS_USER or "",
|
||||
password=dify_config.MILVUS_PASSWORD or "",
|
||||
database=dify_config.MILVUS_DATABASE or "",
|
||||
enable_hybrid_search=dify_config.MILVUS_ENABLE_HYBRID_SEARCH or False,
|
||||
),
|
||||
)
|
||||
|
|
|
|||
|
|
@ -100,6 +100,8 @@ class MyScaleVector(BaseVector):
|
|||
return results.row_count > 0
|
||||
|
||||
def delete_by_ids(self, ids: list[str]) -> None:
|
||||
if not ids:
|
||||
return
|
||||
self._client.command(
|
||||
f"DELETE FROM {self._config.database}.{self._collection_name} WHERE id IN {str(tuple(ids))}"
|
||||
)
|
||||
|
|
|
|||
|
|
@ -134,6 +134,8 @@ class OceanBaseVector(BaseVector):
|
|||
return bool(cur.rowcount != 0)
|
||||
|
||||
def delete_by_ids(self, ids: list[str]) -> None:
|
||||
if not ids:
|
||||
return
|
||||
self._client.delete(table_name=self._collection_name, ids=ids)
|
||||
|
||||
def get_ids_by_metadata_field(self, key: str, value: str) -> list[str]:
|
||||
|
|
|
|||
|
|
@ -167,6 +167,8 @@ class OracleVector(BaseVector):
|
|||
return docs
|
||||
|
||||
def delete_by_ids(self, ids: list[str]) -> None:
|
||||
if not ids:
|
||||
return
|
||||
with self._get_cursor() as cur:
|
||||
cur.execute(f"DELETE FROM {self.table_name} WHERE id IN %s" % (tuple(ids),))
|
||||
|
||||
|
|
|
|||
|
|
@@ -129,6 +129,11 @@ class PGVector(BaseVector):
return docs

def delete_by_ids(self, ids: list[str]) -> None:
# Avoiding crashes caused by performing delete operations on empty lists in certain scenarios
# Scenario 1: extract a document fails, resulting in a table not being created.
# Then clicking the retry button triggers a delete operation on an empty list.
if not ids:
return
with self._get_cursor() as cur:
cur.execute(f"DELETE FROM {self.table_name} WHERE id IN %s", (tuple(ids),))
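The same early-return guard lands in every vector store client touched by this commit (Baidu, Chroma, Elasticsearch, MyScale, OceanBase, Oracle, PGVector, Tencent): an empty `ids` list now returns immediately instead of sending a malformed `IN ()` delete, which is exactly the retry-after-failed-extraction scenario described in the comment above. A generic sketch of the guard, with a placeholder `execute` callable standing in for the store-specific delete:

```python
from collections.abc import Callable, Sequence


def delete_by_ids(ids: Sequence[str], execute: Callable[..., None]) -> None:
    """Delete rows by id, skipping the call entirely for an empty list."""
    if not ids:
        # Without this guard several backends would receive an empty IN clause.
        return
    execute("DELETE FROM embeddings WHERE id IN %s", (tuple(ids),))
```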
@ -140,6 +140,8 @@ class TencentVector(BaseVector):
|
|||
return False
|
||||
|
||||
def delete_by_ids(self, ids: list[str]) -> None:
|
||||
if not ids:
|
||||
return
|
||||
self._db.collection(self._collection_name).delete(document_ids=ids)
|
||||
|
||||
def delete_by_metadata_field(self, key: str, value: str) -> None:
|
||||
|
|
|
|||
|
|
@ -409,27 +409,27 @@ class TidbOnQdrantVectorFactory(AbstractVectorFactory):
|
|||
db.session.query(TidbAuthBinding).filter(TidbAuthBinding.tenant_id == dataset.tenant_id).one_or_none()
|
||||
)
|
||||
if not tidb_auth_binding:
|
||||
idle_tidb_auth_binding = (
|
||||
db.session.query(TidbAuthBinding)
|
||||
.filter(TidbAuthBinding.active == False, TidbAuthBinding.status == "ACTIVE")
|
||||
.limit(1)
|
||||
.one_or_none()
|
||||
)
|
||||
if idle_tidb_auth_binding:
|
||||
idle_tidb_auth_binding.active = True
|
||||
idle_tidb_auth_binding.tenant_id = dataset.tenant_id
|
||||
db.session.commit()
|
||||
TIDB_ON_QDRANT_API_KEY = f"{idle_tidb_auth_binding.account}:{idle_tidb_auth_binding.password}"
|
||||
else:
|
||||
with redis_client.lock("create_tidb_serverless_cluster_lock", timeout=900):
|
||||
tidb_auth_binding = (
|
||||
with redis_client.lock("create_tidb_serverless_cluster_lock", timeout=900):
|
||||
tidb_auth_binding = (
|
||||
db.session.query(TidbAuthBinding)
|
||||
.filter(TidbAuthBinding.tenant_id == dataset.tenant_id)
|
||||
.one_or_none()
|
||||
)
|
||||
if tidb_auth_binding:
|
||||
TIDB_ON_QDRANT_API_KEY = f"{tidb_auth_binding.account}:{tidb_auth_binding.password}"
|
||||
|
||||
else:
|
||||
idle_tidb_auth_binding = (
|
||||
db.session.query(TidbAuthBinding)
|
||||
.filter(TidbAuthBinding.tenant_id == dataset.tenant_id)
|
||||
.filter(TidbAuthBinding.active == False, TidbAuthBinding.status == "ACTIVE")
|
||||
.limit(1)
|
||||
.one_or_none()
|
||||
)
|
||||
if tidb_auth_binding:
|
||||
TIDB_ON_QDRANT_API_KEY = f"{tidb_auth_binding.account}:{tidb_auth_binding.password}"
|
||||
|
||||
if idle_tidb_auth_binding:
|
||||
idle_tidb_auth_binding.active = True
|
||||
idle_tidb_auth_binding.tenant_id = dataset.tenant_id
|
||||
db.session.commit()
|
||||
TIDB_ON_QDRANT_API_KEY = f"{idle_tidb_auth_binding.account}:{idle_tidb_auth_binding.password}"
|
||||
else:
|
||||
new_cluster = TidbService.create_tidb_serverless_cluster(
|
||||
dify_config.TIDB_PROJECT_ID or "",
|
||||
|
|
@ -451,7 +451,6 @@ class TidbOnQdrantVectorFactory(AbstractVectorFactory):
|
|||
db.session.add(new_tidb_auth_binding)
|
||||
db.session.commit()
|
||||
TIDB_ON_QDRANT_API_KEY = f"{new_tidb_auth_binding.account}:{new_tidb_auth_binding.password}"
|
||||
|
||||
else:
|
||||
TIDB_ON_QDRANT_API_KEY = f"{tidb_auth_binding.account}:{tidb_auth_binding.password}"
|
||||
|
||||
|
|
|
|||
|
|
@@ -90,6 +90,12 @@ class Vector:
from core.rag.datasource.vdb.elasticsearch.elasticsearch_vector import ElasticSearchVectorFactory

return ElasticSearchVectorFactory
case VectorType.ELASTICSEARCH_JA:
from core.rag.datasource.vdb.elasticsearch.elasticsearch_ja_vector import (
ElasticSearchJaVectorFactory,
)

return ElasticSearchJaVectorFactory
case VectorType.TIDB_VECTOR:
from core.rag.datasource.vdb.tidb_vector.tidb_vector import TiDBVectorFactory
@ -16,6 +16,7 @@ class VectorType(StrEnum):
TENCENT = "tencent"
ORACLE = "oracle"
ELASTICSEARCH = "elasticsearch"
ELASTICSEARCH_JA = "elasticsearch-ja"
LINDORM = "lindorm"
COUCHBASE = "couchbase"
BAIDU = "baidu"
|
|||
|
|
@ -23,7 +23,6 @@ class PdfExtractor(BaseExtractor):
|
|||
self._file_cache_key = file_cache_key
|
||||
|
||||
def extract(self) -> list[Document]:
|
||||
plaintext_file_key = ""
|
||||
plaintext_file_exists = False
|
||||
if self._file_cache_key:
|
||||
try:
|
||||
|
|
@ -39,8 +38,8 @@ class PdfExtractor(BaseExtractor):
|
|||
text = "\n\n".join(text_list)
|
||||
|
||||
# save plaintext file for caching
|
||||
if not plaintext_file_exists and plaintext_file_key:
|
||||
storage.save(plaintext_file_key, text.encode("utf-8"))
|
||||
if not plaintext_file_exists and self._file_cache_key:
|
||||
storage.save(self._file_cache_key, text.encode("utf-8"))
|
||||
|
||||
return documents
|
||||
|
||||
|
|
|
|||
|
|
@@ -3,6 +3,7 @@
import uuid
from typing import Optional

from configs import dify_config
from core.model_manager import ModelInstance
from core.rag.cleaner.clean_processor import CleanProcessor
from core.rag.datasource.retrieval_service import RetrievalService
@@ -80,6 +81,10 @@ class ParentChildIndexProcessor(BaseIndexProcessor):
child_nodes = self._split_child_nodes(
document, rules, process_rule.get("mode"), kwargs.get("embedding_model_instance")
)
if kwargs.get("preview"):
if len(child_nodes) > dify_config.CHILD_CHUNKS_PREVIEW_NUMBER:
child_nodes = child_nodes[: dify_config.CHILD_CHUNKS_PREVIEW_NUMBER]

document.children = child_nodes
doc_id = str(uuid.uuid4())
hash = helper.generate_text_hash(document.page_content)
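In preview mode the processor above now caps the generated child chunks at `CHILD_CHUNKS_PREVIEW_NUMBER`, so a preview request no longer has to materialize every child node of a large parent document. A minimal sketch of the truncation step, with a plain integer standing in for the config value:

```python
CHILD_CHUNKS_PREVIEW_NUMBER = 50  # illustrative stand-in for dify_config.CHILD_CHUNKS_PREVIEW_NUMBER


def limit_preview_chunks(child_nodes: list, preview: bool) -> list:
    # Only previews are truncated; a real indexing run keeps every child node.
    if preview and len(child_nodes) > CHILD_CHUNKS_PREVIEW_NUMBER:
        return child_nodes[:CHILD_CHUNKS_PREVIEW_NUMBER]
    return child_nodes
```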
@@ -54,7 +54,12 @@ class ASRTool(BuiltinTool):
items.append((provider, model.model))
return items

def get_runtime_parameters(self) -> list[ToolParameter]:
def get_runtime_parameters(
self,
conversation_id: Optional[str] = None,
app_id: Optional[str] = None,
message_id: Optional[str] = None,
) -> list[ToolParameter]:
parameters = []

options = []
@@ -62,7 +62,12 @@ class TTSTool(BuiltinTool):
items.append((provider, model.model, voices))
return items

def get_runtime_parameters(self) -> list[ToolParameter]:
def get_runtime_parameters(
self,
conversation_id: Optional[str] = None,
app_id: Optional[str] = None,
message_id: Optional[str] = None,
) -> list[ToolParameter]:
parameters = []

options = []
@@ -212,8 +212,23 @@ class ApiTool(Tool):
else:
body = body

if method in {"get", "head", "post", "put", "delete", "patch"}:
response: httpx.Response = getattr(ssrf_proxy, method)(
if method in {
"get",
"head",
"post",
"put",
"delete",
"patch",
"options",
"GET",
"POST",
"PUT",
"PATCH",
"DELETE",
"HEAD",
"OPTIONS",
}:
response: httpx.Response = getattr(ssrf_proxy, method.lower())(
url,
params=params,
headers=headers,
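As a hedged, self-contained sketch of the case-insensitive dispatch above (it uses plain httpx instead of the internal ssrf_proxy wrapper, which is an assumption of this example, not what the commit does):

import httpx

ALLOWED_METHODS = {"get", "head", "post", "put", "delete", "patch", "options"}

def do_request(method: str, url: str, **kwargs) -> httpx.Response:
    # accept both "POST" and "post"; reject anything that is not a known verb
    if method.lower() not in ALLOWED_METHODS:
        raise ValueError(f"Invalid http method {method}")
    # httpx exposes module-level helpers named after the lower-cased verbs
    return getattr(httpx, method.lower())(url, **kwargs)

# do_request("POST", "https://httpbin.org/post", json={"ping": "pong"})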
@@ -147,7 +147,7 @@ class ToolInvokeMessage(BaseModel):

@field_validator("variable_name", mode="before")
@classmethod
def transform_variable_name(cls, value) -> str:
def transform_variable_name(cls, value: str) -> str:
"""
The variable name must be a string.
"""

@@ -167,6 +167,7 @@ class ToolInvokeMessage(BaseModel):
error: Optional[str] = Field(default=None, description="The error message")
status: LogStatus = Field(..., description="The status of the log")
data: Mapping[str, Any] = Field(..., description="Detailed log data")
metadata: Optional[Mapping[str, Any]] = Field(default=None, description="The metadata of the log")

class MessageType(Enum):
TEXT = "text"
@@ -203,6 +203,7 @@ class AgentLogEvent(BaseAgentEvent):
error: str | None = Field(..., description="error")
status: str = Field(..., description="status")
data: Mapping[str, Any] = Field(..., description="data")
metadata: Optional[Mapping[str, Any]] = Field(default=None, description="metadata")


InNodeEvent = BaseNodeEvent | BaseParallelBranchEvent | BaseIterationEvent | BaseAgentEvent
@@ -89,7 +89,11 @@ class AgentNode(ToolNode):

try:
# convert tool messages
yield from self._transform_message(message_stream, {}, parameters_for_log)
yield from self._transform_message(
message_stream,
{"provider": (cast(AgentNodeData, self.node_data)).agent_strategy_provider_name},
parameters_for_log,
)
except PluginDaemonClientSideError as e:
yield RunCompletedEvent(
run_result=NodeRunResult(

@@ -170,7 +174,12 @@ class AgentNode(ToolNode):
extra.get("descrption", "") or tool_runtime.entity.description.llm
)

tool_value.append(tool_runtime.entity.model_dump(mode="json"))
tool_value.append(
{
**tool_runtime.entity.model_dump(mode="json"),
"runtime_parameters": tool_runtime.runtime.runtime_parameters,
}
)
value = tool_value
if parameter.type == "model-selector":
value = cast(dict[str, Any], value)
@@ -2,14 +2,18 @@ import csv
import io
import json
import logging
import operator
import os
import tempfile
from typing import cast
from collections.abc import Mapping, Sequence
from typing import Any, cast

import docx
import pandas as pd
import pypdfium2 # type: ignore
import yaml # type: ignore
from docx.table import Table
from docx.text.paragraph import Paragraph

from configs import dify_config
from core.file import File, FileTransferMethod, file_manager

@@ -78,6 +82,23 @@ class DocumentExtractorNode(BaseNode[DocumentExtractorNodeData]):
process_data=process_data,
)

@classmethod
def _extract_variable_selector_to_variable_mapping(
cls,
*,
graph_config: Mapping[str, Any],
node_id: str,
node_data: DocumentExtractorNodeData,
) -> Mapping[str, Sequence[str]]:
"""
Extract variable selector to variable mapping
:param graph_config: graph config
:param node_id: node id
:param node_data: node data
:return:
"""
return {node_id + ".files": node_data.variable_selector}


def _extract_text_by_mime_type(*, file_content: bytes, mime_type: str) -> str:
"""Extract text from a file based on its MIME type."""
@@ -189,35 +210,56 @@ def _extract_text_from_doc(file_content: bytes) -> str:
doc_file = io.BytesIO(file_content)
doc = docx.Document(doc_file)
text = []
# Process paragraphs
for paragraph in doc.paragraphs:
if paragraph.text.strip():
text.append(paragraph.text)

# Process tables
for table in doc.tables:
# Table header
try:
# table maybe cause errors so ignore it.
if len(table.rows) > 0 and table.rows[0].cells is not None:
# Keep track of paragraph and table positions
content_items: list[tuple[int, str, Table | Paragraph]] = []

# Process paragraphs and tables
for i, paragraph in enumerate(doc.paragraphs):
if paragraph.text.strip():
content_items.append((i, "paragraph", paragraph))

for i, table in enumerate(doc.tables):
content_items.append((i, "table", table))

# Sort content items based on their original position
content_items.sort(key=operator.itemgetter(0))

# Process sorted content
for _, item_type, item in content_items:
if item_type == "paragraph":
if isinstance(item, Table):
continue
text.append(item.text)
elif item_type == "table":
# Process tables
if not isinstance(item, Table):
continue
try:
# Check if any cell in the table has text
has_content = False
for row in table.rows:
for row in item.rows:
if any(cell.text.strip() for cell in row.cells):
has_content = True
break

if has_content:
markdown_table = "| " + " | ".join(cell.text for cell in table.rows[0].cells) + " |\n"
markdown_table += "| " + " | ".join(["---"] * len(table.rows[0].cells)) + " |\n"
for row in table.rows[1:]:
markdown_table += "| " + " | ".join(cell.text for cell in row.cells) + " |\n"
cell_texts = [cell.text.replace("\n", "<br>") for cell in item.rows[0].cells]
markdown_table = f"| {' | '.join(cell_texts)} |\n"
markdown_table += f"| {' | '.join(['---'] * len(item.rows[0].cells))} |\n"

for row in item.rows[1:]:
# Replace newlines with <br> in each cell
row_cells = [cell.text.replace("\n", "<br>") for cell in row.cells]
markdown_table += "| " + " | ".join(row_cells) + " |\n"

text.append(markdown_table)
except Exception as e:
logger.warning(f"Failed to extract table from DOC/DOCX: {e}")
continue
except Exception as e:
logger.warning(f"Failed to extract table from DOC/DOCX: {e}")
continue

return "\n".join(text)

except Exception as e:
raise TextExtractionError(f"Failed to extract text from DOC/DOCX: {str(e)}") from e
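A small illustrative helper mirroring the Markdown conversion in the new extractor code (the function name is made up; it relies only on python-docx, which the commit already depends on):

from docx.table import Table

def table_to_markdown(table: Table) -> str:
    # newlines inside a cell would break the Markdown row, so they become <br>
    header_cells = [cell.text.replace("\n", "<br>") for cell in table.rows[0].cells]
    markdown = f"| {' | '.join(header_cells)} |\n"
    markdown += f"| {' | '.join(['---'] * len(header_cells))} |\n"
    for row in table.rows[1:]:
        row_cells = [cell.text.replace("\n", "<br>") for cell in row.cells]
        markdown += "| " + " | ".join(row_cells) + " |\n"
    return markdown

Calling this on each Table object (in document order) and appending the result to the extracted text list reproduces the behavior sketched above.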
@@ -68,7 +68,22 @@ class HttpRequestNodeData(BaseNodeData):
Code Node Data.
"""

method: Literal["get", "post", "put", "patch", "delete", "head"]
method: Literal[
"get",
"post",
"put",
"patch",
"delete",
"head",
"options",
"GET",
"POST",
"PUT",
"PATCH",
"DELETE",
"HEAD",
"OPTIONS",
]
url: str
authorization: HttpRequestNodeAuthorization
headers: str
@@ -37,7 +37,22 @@ BODY_TYPE_TO_CONTENT_TYPE = {


class Executor:
method: Literal["get", "head", "post", "put", "delete", "patch"]
method: Literal[
"get",
"head",
"post",
"put",
"delete",
"patch",
"options",
"GET",
"POST",
"PUT",
"PATCH",
"DELETE",
"HEAD",
"OPTIONS",
]
url: str
params: list[tuple[str, str]] | None
content: str | bytes | None

@@ -67,12 +82,6 @@ class Executor:
node_data.authorization.config.api_key
).text

# check if node_data.url is a valid URL
if not node_data.url:
raise InvalidURLError("url is required")
if not node_data.url.startswith(("http://", "https://")):
raise InvalidURLError("url should start with http:// or https://")

self.url: str = node_data.url
self.method = node_data.method
self.auth = node_data.authorization
@@ -99,6 +108,12 @@ class Executor:
def _init_url(self):
self.url = self.variable_pool.convert_template(self.node_data.url).text

# check if url is a valid URL
if not self.url:
raise InvalidURLError("url is required")
if not self.url.startswith(("http://", "https://")):
raise InvalidURLError("url should start with http:// or https://")

def _init_params(self):
"""
Almost same as _init_headers(), difference:
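A minimal sketch of the relocated check (the point of moving it into _init_url is that the URL may come from template variables and only exists after conversion); the standalone helper below is illustrative, not part of the commit:

def validate_url(url: str) -> str:
    # only absolute http(s) URLs are accepted
    if not url:
        raise ValueError("url is required")
    if not url.startswith(("http://", "https://")):
        raise ValueError("url should start with http:// or https://")
    return url

# validate_url("https://example.com/api")  # ok
# validate_url("example.com")              # raises ValueError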
@@ -158,7 +173,10 @@ class Executor:
if len(data) != 1:
raise RequestBodyError("json body type should have exactly one item")
json_string = self.variable_pool.convert_template(data[0].value).text
json_object = json.loads(json_string, strict=False)
try:
json_object = json.loads(json_string, strict=False)
except json.JSONDecodeError as e:
raise RequestBodyError(f"Failed to parse JSON: {json_string}") from e
self.json = json_object
# self.json = self._parse_object_contains_variables(json_object)
case "binary":
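For clarity, a standalone sketch of the guarded parse introduced above (RequestBodyError is simplified to a ValueError subclass here):

import json

class RequestBodyError(ValueError):
    pass

def parse_json_body(json_string: str):
    # strict=False allows control characters inside strings, matching the node's behavior
    try:
        return json.loads(json_string, strict=False)
    except json.JSONDecodeError as e:
        raise RequestBodyError(f"Failed to parse JSON: {json_string}") from e

# parse_json_body('{"a": 1}') returns {'a': 1}; parse_json_body('{oops') raises RequestBodyError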
@@ -246,7 +264,22 @@ class Executor:
"""
do http request depending on api bundle
"""
if self.method not in {"get", "head", "post", "put", "delete", "patch"}:
if self.method not in {
"get",
"head",
"post",
"put",
"delete",
"patch",
"options",
"GET",
"POST",
"PUT",
"PATCH",
"DELETE",
"HEAD",
"OPTIONS",
}:
raise InvalidHttpMethodError(f"Invalid http method {self.method}")

request_args = {
@@ -263,7 +296,7 @@ class Executor:
}
# request_args = {k: v for k, v in request_args.items() if v is not None}
try:
response = getattr(ssrf_proxy, self.method)(**request_args)
response = getattr(ssrf_proxy, self.method.lower())(**request_args)
except (ssrf_proxy.MaxRetriesExceededError, httpx.RequestError) as e:
raise HttpRequestNodeError(str(e))
# FIXME: fix type ignore, this maybe httpx type issue
@@ -1,5 +1,5 @@
from collections.abc import Generator, Mapping, Sequence
from typing import Any, cast
from typing import Any, Optional, cast

from sqlalchemy import select
from sqlalchemy.orm import Session
@@ -197,6 +197,7 @@ class ToolNode(BaseNode[ToolNodeData]):
json: list[dict] = []

agent_logs: list[AgentLogEvent] = []
agent_execution_metadata: Optional[Mapping[NodeRunMetadataKey, Any]] = {}

variables: dict[str, Any] = {}

@@ -264,6 +265,11 @@ class ToolNode(BaseNode[ToolNodeData]):
)
elif message.type == ToolInvokeMessage.MessageType.JSON:
assert isinstance(message.message, ToolInvokeMessage.JsonMessage)
if self.node_type == NodeType.AGENT:
msg_metadata = message.message.json_object.pop("execution_metadata", {})
agent_execution_metadata = {
key: value for key, value in msg_metadata.items() if key in NodeRunMetadataKey
}
json.append(message.message.json_object)
elif message.type == ToolInvokeMessage.MessageType.LINK:
assert isinstance(message.message, ToolInvokeMessage.TextMessage)

@@ -299,6 +305,7 @@ class ToolNode(BaseNode[ToolNodeData]):
status=message.message.status.value,
data=message.message.data,
label=message.message.label,
metadata=message.message.metadata,
)

# check if the agent log is already in the list

@@ -309,6 +316,7 @@ class ToolNode(BaseNode[ToolNodeData]):
log.status = agent_log.status
log.error = agent_log.error
log.label = agent_log.label
log.metadata = agent_log.metadata
break
else:
agent_logs.append(agent_log)

@@ -319,7 +327,11 @@ class ToolNode(BaseNode[ToolNodeData]):
run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
outputs={"text": text, "files": files, "json": json, **variables},
metadata={NodeRunMetadataKey.TOOL_INFO: tool_info, NodeRunMetadataKey.AGENT_LOG: agent_logs},
metadata={
**agent_execution_metadata,
NodeRunMetadataKey.TOOL_INFO: tool_info,
NodeRunMetadataKey.AGENT_LOG: agent_logs,
},
inputs=parameters_for_log,
)
)
@@ -340,6 +340,10 @@ class WorkflowEntry:
):
raise ValueError(f"Variable key {node_variable} not found in user inputs.")

# environment variable already exist in variable pool, not from user inputs
if variable_pool.get(variable_selector):
continue

# fetch variable node id from variable selector
variable_node_id = variable_selector[0]
variable_key_list = variable_selector[1:]
@@ -33,6 +33,7 @@ else
--bind "${DIFY_BIND_ADDRESS:-0.0.0.0}:${DIFY_PORT:-5001}" \
--workers ${SERVER_WORKER_AMOUNT:-1} \
--worker-class ${SERVER_WORKER_CLASS:-gevent} \
--worker-connections ${SERVER_WORKER_CONNECTIONS:-10} \
--timeout ${GUNICORN_TIMEOUT:-200} \
app:app
fi
@@ -46,7 +46,7 @@ def init_app(app: DifyApp):
timezone = pytz.timezone(log_tz)

def time_converter(seconds):
return datetime.utcfromtimestamp(seconds).astimezone(timezone).timetuple()
return datetime.fromtimestamp(seconds, tz=timezone).timetuple()

for handler in logging.root.handlers:
if handler.formatter:
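A short, assumed example of why the converter changed: utcfromtimestamp() yields a naive datetime, and astimezone() then treats it as local machine time, while fromtimestamp(seconds, tz=...) converts the epoch value directly into the configured zone.

from datetime import datetime
import pytz

timezone = pytz.timezone("Asia/Shanghai")  # example value for LOG_TZ

def time_converter(seconds: float):
    # returns the struct_time that logging.Formatter.converter expects
    return datetime.fromtimestamp(seconds, tz=timezone).timetuple()

# time_converter(0) -> struct_time for 1970-01-01 08:00:00 in Asia/Shanghai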
@@ -158,7 +158,7 @@ def _build_from_remote_url(
tenant_id: str,
transfer_method: FileTransferMethod,
) -> File:
url = mapping.get("url")
url = mapping.get("url") or mapping.get("remote_url")
if not url:
raise ValueError("Invalid file url")
@@ -6,7 +6,7 @@ import string
import subprocess
import time
import uuid
from collections.abc import Generator
from collections.abc import Generator, Mapping
from datetime import datetime
from hashlib import sha256
from typing import TYPE_CHECKING, Any, Optional, Union, cast

@@ -182,7 +182,7 @@ def generate_text_hash(text: str) -> str:
return sha256(hash_text.encode()).hexdigest()


def compact_generate_response(response: Union[dict, Generator, RateLimitGenerator]) -> Response:
def compact_generate_response(response: Union[Mapping, Generator, RateLimitGenerator]) -> Response:
if isinstance(response, dict):
return Response(response=json.dumps(response), status=200, mimetype="application/json")
else:
@@ -255,7 +255,8 @@ class NotionOAuth(OAuthDataSource):
response = requests.get(url=f"{self._NOTION_BLOCK_SEARCH}/{block_id}", headers=headers)
response_json = response.json()
if response.status_code != 200:
raise ValueError(f"Error fetching block parent page ID: {response_json.message}")
message = response_json.get("message", "unknown error")
raise ValueError(f"Error fetching block parent page ID: {message}")
parent = response_json["parent"]
parent_type = parent["type"]
if parent_type == "block_id":
@@ -0,0 +1,41 @@
"""change workflow_runs.total_tokens to bigint

Revision ID: a91b476a53de
Revises: 923752d42eb6
Create Date: 2025-01-01 20:00:01.207369

"""
from alembic import op
import models as models
import sqlalchemy as sa


# revision identifiers, used by Alembic.
revision = 'a91b476a53de'
down_revision = '923752d42eb6'
branch_labels = None
depends_on = None


def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('workflow_runs', schema=None) as batch_op:
batch_op.alter_column('total_tokens',
existing_type=sa.INTEGER(),
type_=sa.BigInteger(),
existing_nullable=False,
existing_server_default=sa.text('0'))

# ### end Alembic commands ###


def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('workflow_runs', schema=None) as batch_op:
batch_op.alter_column('total_tokens',
existing_type=sa.BigInteger(),
type_=sa.INTEGER(),
existing_nullable=False,
existing_server_default=sa.text('0'))

# ### end Alembic commands ###
@@ -415,8 +415,8 @@ class WorkflowRun(Base):
status: Mapped[str] = mapped_column(db.String(255)) # running, succeeded, failed, stopped, partial-succeeded
outputs: Mapped[Optional[str]] = mapped_column(sa.Text, default="{}")
error: Mapped[Optional[str]] = mapped_column(db.Text)
elapsed_time = db.Column(db.Float, nullable=False, server_default=db.text("0"))
total_tokens: Mapped[int] = mapped_column(server_default=db.text("0"))
elapsed_time = db.Column(db.Float, nullable=False, server_default=sa.text("0"))
total_tokens: Mapped[int] = mapped_column(sa.BigInteger, server_default=sa.text("0"))
total_steps = db.Column(db.Integer, server_default=db.text("0"))
created_by_role: Mapped[str] = mapped_column(db.String(255)) # account, end_user
created_by = db.Column(StringUUID, nullable=False)
File diff suppressed because it is too large.
@@ -71,7 +71,7 @@ pyjwt = "~2.8.0"
pypdfium2 = "~4.30.0"
python = ">=3.11,<3.13"
python-docx = "~1.1.0"
python-dotenv = "1.0.0"
python-dotenv = "1.0.1"
pyyaml = "~6.0.1"
readabilipy = "0.2.0"
redis = { version = "~5.0.3", extras = ["hiredis"] }

@@ -82,7 +82,7 @@ scikit-learn = "~1.5.1"
sentry-sdk = { version = "~1.44.1", extras = ["flask"] }
sqlalchemy = "~2.0.29"
starlette = "0.41.0"
tencentcloud-sdk-python-hunyuan = "~3.0.1158"
tencentcloud-sdk-python-hunyuan = "~3.0.1294"
tiktoken = "~0.8.0"
tokenizers = "~0.15.0"
transformers = "~4.35.0"

@@ -92,7 +92,7 @@ validators = "0.21.0"
volcengine-python-sdk = {extras = ["ark"], version = "~1.0.98"}
websocket-client = "~1.7.0"
xinference-client = "0.15.2"
yarl = "~1.9.4"
yarl = "~1.18.3"
youtube-transcript-api = "~0.6.2"
zhipuai = "~2.1.5"
# Before adding new dependency, consider place it in alphabet order (a-z) and suitable group.

@@ -157,7 +157,7 @@ opensearch-py = "2.4.0"
oracledb = "~2.2.1"
pgvecto-rs = { version = "~0.2.1", extras = ['sqlalchemy'] }
pgvector = "0.2.5"
pymilvus = "~2.4.4"
pymilvus = "~2.5.0"
pymochow = "1.3.1"
pyobvector = "~0.1.6"
qdrant-client = "1.7.3"
@@ -168,23 +168,6 @@ def clean_unused_datasets_task():
else:
plan = plan_cache.decode()
if plan == "sandbox":
# add auto disable log
documents = (
db.session.query(Document)
.filter(
Document.dataset_id == dataset.id,
Document.enabled == True,
Document.archived == False,
)
.all()
)
for document in documents:
dataset_auto_disable_log = DatasetAutoDisableLog(
tenant_id=dataset.tenant_id,
dataset_id=dataset.id,
document_id=document.id,
)
db.session.add(dataset_auto_disable_log)
# remove index
index_processor = IndexProcessorFactory(dataset.doc_form).init_index_processor()
index_processor.clean(dataset, None)
@@ -67,7 +67,7 @@ class TokenPair(BaseModel):

REFRESH_TOKEN_PREFIX = "refresh_token:"
ACCOUNT_REFRESH_TOKEN_PREFIX = "account_refresh_token:"
REFRESH_TOKEN_EXPIRY = timedelta(days=30)
REFRESH_TOKEN_EXPIRY = timedelta(days=dify_config.REFRESH_TOKEN_EXPIRE_DAYS)


class AccountService:
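Roughly, the new setting feeds the token lifetime like this (the value shown is the default from this commit; the snippet is an illustrative stand-in, not the service code):

from datetime import timedelta

REFRESH_TOKEN_EXPIRE_DAYS = 30  # stand-in for dify_config.REFRESH_TOKEN_EXPIRE_DAYS
REFRESH_TOKEN_EXPIRY = timedelta(days=REFRESH_TOKEN_EXPIRE_DAYS)

# the timedelta can be passed to Redis expiry APIs directly or converted to whole seconds
print(int(REFRESH_TOKEN_EXPIRY.total_seconds()))  # 2592000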
@@ -921,6 +921,9 @@ class RegisterService:
def invite_new_member(
cls, tenant: Tenant, email: str, language: str, role: str = "normal", inviter: Account | None = None
) -> str:
if not inviter:
raise ValueError("Inviter is required")

"""Invite new member"""
with Session(db.engine) as session:
account = session.query(Account).filter_by(email=email).first()
@@ -2,6 +2,7 @@ import logging
import uuid
from enum import StrEnum
from typing import Optional, cast
from urllib.parse import urlparse
from uuid import uuid4

import yaml # type: ignore
@@ -124,7 +125,7 @@ class AppDslService:
raise ValueError(f"Invalid import_mode: {import_mode}")

# Get YAML content
content: bytes | str = b""
content: str = ""
if mode == ImportMode.YAML_URL:
if not yaml_url:
return Import(

@@ -133,13 +134,17 @@ class AppDslService:
error="yaml_url is required when import_mode is yaml-url",
)
try:
# tricky way to handle url from github to github raw url
if yaml_url.startswith("https://github.com") and yaml_url.endswith((".yml", ".yaml")):
parsed_url = urlparse(yaml_url)
if (
parsed_url.scheme == "https"
and parsed_url.netloc == "github.com"
and parsed_url.path.endswith((".yml", ".yaml"))
):
yaml_url = yaml_url.replace("https://github.com", "https://raw.githubusercontent.com")
yaml_url = yaml_url.replace("/blob/", "/")
response = ssrf_proxy.get(yaml_url.strip(), follow_redirects=True, timeout=(10, 10))
response.raise_for_status()
content = response.content
content = response.content.decode()

if len(content) > DSL_MAX_SIZE:
return Import(
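A self-contained sketch of the stricter rewrite above: only https://github.com links whose path ends in .yml or .yaml are converted to their raw.githubusercontent.com form (the helper name is made up).

from urllib.parse import urlparse

def to_raw_github_url(yaml_url: str) -> str:
    parsed_url = urlparse(yaml_url)
    if (
        parsed_url.scheme == "https"
        and parsed_url.netloc == "github.com"
        and parsed_url.path.endswith((".yml", ".yaml"))
    ):
        # github.com/.../blob/<ref>/file.yml -> raw.githubusercontent.com/.../<ref>/file.yml
        yaml_url = yaml_url.replace("https://github.com", "https://raw.githubusercontent.com")
        yaml_url = yaml_url.replace("/blob/", "/")
    return yaml_url

print(to_raw_github_url("https://github.com/org/repo/blob/main/app.yml"))
# https://raw.githubusercontent.com/org/repo/main/app.yml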
@@ -26,9 +26,10 @@ from tasks.remove_app_and_related_data_task import remove_app_and_related_data_t


class AppService:
def get_paginate_apps(self, tenant_id: str, args: dict) -> Pagination | None:
def get_paginate_apps(self, user_id: str, tenant_id: str, args: dict) -> Pagination | None:
"""
Get app list with pagination
:param user_id: user id
:param tenant_id: tenant id
:param args: request args
:return:

@@ -44,6 +45,8 @@ class AppService:
elif args["mode"] == "channel":
filters.append(App.mode == AppMode.CHANNEL.value)

if args.get("is_created_by_me", False):
filters.append(App.created_by == user_id)
if args.get("name"):
name = args["name"][:30]
filters.append(App.name.ilike(f"%{name}%"))
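An assumed, self-contained SQLAlchemy sketch of the new is_created_by_me filter (the App model here is a cut-down stand-in, not the real one):

from sqlalchemy import Column, String, create_engine
from sqlalchemy.orm import Session, declarative_base

Base = declarative_base()

class App(Base):
    __tablename__ = "apps"
    id = Column(String, primary_key=True)
    tenant_id = Column(String)
    created_by = Column(String)
    name = Column(String)

def list_apps(session: Session, user_id: str, tenant_id: str, args: dict):
    filters = [App.tenant_id == tenant_id]
    # only apps created by the requesting user when the flag is set
    if args.get("is_created_by_me", False):
        filters.append(App.created_by == user_id)
    if args.get("name"):
        filters.append(App.name.ilike(f"%{args['name'][:30]}%"))
    return session.query(App).filter(*filters).all()

engine = create_engine("sqlite://")
Base.metadata.create_all(engine)
with Session(engine) as session:
    session.add_all([
        App(id="1", tenant_id="t1", created_by="u1", name="mine"),
        App(id="2", tenant_id="t1", created_by="u2", name="theirs"),
    ])
    session.commit()
    print([a.name for a in list_apps(session, "u1", "t1", {"is_created_by_me": True})])  # ['mine']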
@@ -1,5 +1,5 @@
import os
from typing import Optional
from typing import Literal, Optional

import httpx
from tenacity import retry, retry_if_exception_type, stop_before_delay, wait_fixed

@@ -17,7 +17,6 @@ class BillingService:
params = {"tenant_id": tenant_id}

billing_info = cls._send_request("GET", "/subscription/info", params=params)

return billing_info

@classmethod

@@ -47,12 +46,13 @@ class BillingService:
retry=retry_if_exception_type(httpx.RequestError),
reraise=True,
)
def _send_request(cls, method, endpoint, json=None, params=None):
def _send_request(cls, method: Literal["GET", "POST", "DELETE"], endpoint: str, json=None, params=None):
headers = {"Content-Type": "application/json", "Billing-Api-Secret-Key": cls.secret_key}

url = f"{cls.base_url}{endpoint}"
response = httpx.request(method, url, json=json, params=params, headers=headers)

if method == "GET" and response.status_code != httpx.codes.OK:
raise ValueError("Unable to retrieve billing information. Please try again later or contact support.")
return response.json()

@staticmethod
@@ -86,7 +86,7 @@ class DatasetService:
else:
return [], 0
else:
if user.current_role not in (TenantAccountRole.OWNER, TenantAccountRole.ADMIN):
if user.current_role != TenantAccountRole.OWNER:
# show all datasets that the user has permission to access
if permitted_dataset_ids:
query = query.filter(

@@ -382,7 +382,7 @@ class DatasetService:
if dataset.tenant_id != user.current_tenant_id:
logging.debug(f"User {user.id} does not have permission to access dataset {dataset.id}")
raise NoPermissionError("You do not have permission to access this dataset.")
if user.current_role not in (TenantAccountRole.OWNER, TenantAccountRole.ADMIN):
if user.current_role != TenantAccountRole.OWNER:
if dataset.permission == DatasetPermissionEnum.ONLY_ME and dataset.created_by != user.id:
logging.debug(f"User {user.id} does not have permission to access dataset {dataset.id}")
raise NoPermissionError("You do not have permission to access this dataset.")

@@ -404,7 +404,7 @@ class DatasetService:
if not user:
raise ValueError("User not found")

if user.current_role not in (TenantAccountRole.OWNER, TenantAccountRole.ADMIN):
if user.current_role != TenantAccountRole.OWNER:
if dataset.permission == DatasetPermissionEnum.ONLY_ME:
if dataset.created_by != user.id:
raise NoPermissionError("You do not have permission to access this dataset.")

@@ -434,6 +434,12 @@ class DatasetService:

@staticmethod
def get_dataset_auto_disable_logs(dataset_id: str) -> dict:
features = FeatureService.get_features(current_user.current_tenant_id)
if not features.billing.enabled or features.billing.subscription.plan == "sandbox":
return {
"document_ids": [],
"count": 0,
}
# get recent 30 days auto disable logs
start_date = datetime.datetime.now() - datetime.timedelta(days=30)
dataset_auto_disable_logs = DatasetAutoDisableLog.query.filter(
@@ -786,13 +792,19 @@ class DocumentService:
dataset.indexing_technique = knowledge_config.indexing_technique
if knowledge_config.indexing_technique == "high_quality":
model_manager = ModelManager()
embedding_model = model_manager.get_default_model_instance(
tenant_id=current_user.current_tenant_id, model_type=ModelType.TEXT_EMBEDDING
)
dataset.embedding_model = embedding_model.model
dataset.embedding_model_provider = embedding_model.provider
if knowledge_config.embedding_model and knowledge_config.embedding_model_provider:
dataset_embedding_model = knowledge_config.embedding_model
dataset_embedding_model_provider = knowledge_config.embedding_model_provider
else:
embedding_model = model_manager.get_default_model_instance(
tenant_id=current_user.current_tenant_id, model_type=ModelType.TEXT_EMBEDDING
)
dataset_embedding_model = embedding_model.model
dataset_embedding_model_provider = embedding_model.provider
dataset.embedding_model = dataset_embedding_model
dataset.embedding_model_provider = dataset_embedding_model_provider
dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding(
embedding_model.provider, embedding_model.model
dataset_embedding_model_provider, dataset_embedding_model
)
dataset.collection_binding_id = dataset_collection_binding.id
if not dataset.retrieval_model:

@@ -804,7 +816,11 @@ class DocumentService:
"score_threshold_enabled": False,
}

dataset.retrieval_model = knowledge_config.retrieval_model.model_dump() or default_retrieval_model # type: ignore
dataset.retrieval_model = (
knowledge_config.retrieval_model.model_dump()
if knowledge_config.retrieval_model
else default_retrieval_model
) # type: ignore

documents = []
if knowledge_config.original_document_id:
@@ -27,7 +27,7 @@ class WorkflowAppService:
query = query.join(WorkflowRun, WorkflowRun.id == WorkflowAppLog.workflow_run_id)

if keyword:
keyword_like_val = f"%{args['keyword'][:30]}%"
keyword_like_val = f"%{keyword[:30].encode('unicode_escape').decode('utf-8')}%".replace(r"\u", r"\\u")
keyword_conditions = [
WorkflowRun.inputs.ilike(keyword_like_val),
WorkflowRun.outputs.ilike(keyword_like_val),
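A quick sketch of the keyword preprocessing above (the helper name and the rationale are assumptions of this example): non-ASCII characters are escaped to their \uXXXX form, presumably to match how the JSON-serialized columns store them, and the backslash is doubled so the literal survives the LIKE pattern.

def build_keyword_pattern(keyword: str) -> str:
    escaped = keyword[:30].encode("unicode_escape").decode("utf-8")
    return f"%{escaped}%".replace(r"\u", r"\\u")

print(build_keyword_pattern("hello"))  # %hello%
print(build_keyword_pattern("你好"))   # %\\u4f60\\u597d%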
@@ -298,7 +298,7 @@ class WorkflowService:
start_at: float,
tenant_id: str,
node_id: str,
):
) -> WorkflowNodeExecution:
"""
Handle node run result

@@ -28,7 +28,7 @@ def deal_dataset_vector_index_task(dataset_id: str, action: str):

if not dataset:
raise Exception("Dataset not found")
index_type = dataset.doc_form
index_type = dataset.doc_form or IndexType.PARAGRAPH_INDEX
index_processor = IndexProcessorFactory(index_type).init_index_processor()
if action == "remove":
index_processor.clean(dataset, None, with_keywords=False)

@@ -157,6 +157,9 @@ def deal_dataset_vector_index_task(dataset_id: str, action: str):
{"indexing_status": "error", "error": str(e)}, synchronize_session=False
)
db.session.commit()
else:
# clean collection
index_processor.clean(dataset, None, with_keywords=False, delete_child_chunks=False)

end_at = time.perf_counter()
logging.info(
@@ -0,0 +1,55 @@
import os
from pathlib import Path

import pytest

from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.gpustack.speech2text.speech2text import GPUStackSpeech2TextModel


def test_validate_credentials():
model = GPUStackSpeech2TextModel()

with pytest.raises(CredentialsValidateFailedError):
model.validate_credentials(
model="faster-whisper-medium",
credentials={
"endpoint_url": "invalid_url",
"api_key": "invalid_api_key",
},
)

model.validate_credentials(
model="faster-whisper-medium",
credentials={
"endpoint_url": os.environ.get("GPUSTACK_SERVER_URL"),
"api_key": os.environ.get("GPUSTACK_API_KEY"),
},
)


def test_invoke_model():
model = GPUStackSpeech2TextModel()

# Get the directory of the current file
current_dir = os.path.dirname(os.path.abspath(__file__))

# Get assets directory
assets_dir = os.path.join(os.path.dirname(current_dir), "assets")

# Construct the path to the audio file
audio_file_path = os.path.join(assets_dir, "audio.mp3")

file = Path(audio_file_path).read_bytes()

result = model.invoke(
model="faster-whisper-medium",
credentials={
"endpoint_url": os.environ.get("GPUSTACK_SERVER_URL"),
"api_key": os.environ.get("GPUSTACK_API_KEY"),
},
file=file,
)

assert isinstance(result, str)
assert result == "1, 2, 3, 4, 5, 6, 7, 8, 9, 10"
@@ -0,0 +1,24 @@
import os

from core.model_runtime.model_providers.gpustack.tts.tts import GPUStackText2SpeechModel


def test_invoke_model():
model = GPUStackText2SpeechModel()

result = model.invoke(
model="cosyvoice-300m-sft",
tenant_id="test",
credentials={
"endpoint_url": os.environ.get("GPUSTACK_SERVER_URL"),
"api_key": os.environ.get("GPUSTACK_API_KEY"),
},
content_text="Hello world",
voice="Chinese Female",
)

content = b""
for chunk in result:
content += chunk

assert content != b""
@@ -19,9 +19,9 @@ class MilvusVectorTest(AbstractVectorTest):
)

def search_by_full_text(self):
# milvus dos not support full text searching yet in < 2.3.x
# milvus support BM25 full text search after version 2.5.0-beta
hits_by_full_text = self.vector.search_by_full_text(query=get_example_text())
assert len(hits_by_full_text) == 0
assert len(hits_by_full_text) >= 0

def get_ids_by_metadata_field(self):
ids = self.vector.get_ids_by_metadata_field(key="document_id", value=self.example_doc_id)
@@ -2,7 +2,7 @@ version: '3'
services:
# API service
api:
image: langgenius/dify-api:0.14.2
image: langgenius/dify-api:0.15.0
restart: always
environment:
# Startup mode, 'api' starts the API server.

@@ -227,7 +227,7 @@ services:
# worker service
# The Celery worker for processing the queue.
worker:
image: langgenius/dify-api:0.14.2
image: langgenius/dify-api:0.15.0
restart: always
environment:
CONSOLE_WEB_URL: ''

@@ -397,7 +397,7 @@ services:

# Frontend web application.
web:
image: langgenius/dify-web:0.14.2
image: langgenius/dify-web:0.15.0
restart: always
environment:
# The base URL of console application api server, refers to the Console base URL of WEB service if console domain is
Some files were not shown because too many files have changed in this diff.