mirror of
https://github.com/langgenius/dify.git
synced 2026-04-27 19:27:23 +08:00
Merge branch 'fix/chore-fix' into dev/plugin-deploy
This commit is contained in:
commit
00ad751a57
@ -23,6 +23,9 @@ FILES_ACCESS_TIMEOUT=300
|
|||||||
# Access token expiration time in minutes
|
# Access token expiration time in minutes
|
||||||
ACCESS_TOKEN_EXPIRE_MINUTES=60
|
ACCESS_TOKEN_EXPIRE_MINUTES=60
|
||||||
|
|
||||||
|
# Refresh token expiration time in days
|
||||||
|
REFRESH_TOKEN_EXPIRE_DAYS=30
|
||||||
|
|
||||||
# celery configuration
|
# celery configuration
|
||||||
CELERY_BROKER_URL=redis://:difyai123456@localhost:6379/1
|
CELERY_BROKER_URL=redis://:difyai123456@localhost:6379/1
|
||||||
|
|
||||||
|
|||||||
@ -14,7 +14,10 @@ if is_db_command():
|
|||||||
|
|
||||||
app = create_migrations_app()
|
app = create_migrations_app()
|
||||||
else:
|
else:
|
||||||
if os.environ.get("FLASK_DEBUG", "False") != "True":
|
# It seems that JetBrains Python debugger does not work well with gevent,
|
||||||
|
# so we need to disable gevent in debug mode.
|
||||||
|
# If you are using debugpy and set GEVENT_SUPPORT=True, you can debug with gevent.
|
||||||
|
if (flask_debug := os.environ.get("FLASK_DEBUG", "0")) and flask_debug.lower() in {"false", "0", "no"}:
|
||||||
from gevent import monkey # type: ignore
|
from gevent import monkey # type: ignore
|
||||||
|
|
||||||
# gevent
|
# gevent
|
||||||
|
|||||||
@ -546,6 +546,11 @@ class AuthConfig(BaseSettings):
|
|||||||
default=60,
|
default=60,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
REFRESH_TOKEN_EXPIRE_DAYS: PositiveFloat = Field(
|
||||||
|
description="Expiration time for refresh tokens in days",
|
||||||
|
default=30,
|
||||||
|
)
|
||||||
|
|
||||||
LOGIN_LOCKOUT_DURATION: PositiveInt = Field(
|
LOGIN_LOCKOUT_DURATION: PositiveInt = Field(
|
||||||
description="Time (in seconds) a user must wait before retrying login after exceeding the rate limit.",
|
description="Time (in seconds) a user must wait before retrying login after exceeding the rate limit.",
|
||||||
default=86400,
|
default=86400,
|
||||||
@ -725,6 +730,11 @@ class IndexingConfig(BaseSettings):
|
|||||||
default=4000,
|
default=4000,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
CHILD_CHUNKS_PREVIEW_NUMBER: PositiveInt = Field(
|
||||||
|
description="Maximum number of child chunks to preview",
|
||||||
|
default=50,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class MultiModalTransferConfig(BaseSettings):
|
class MultiModalTransferConfig(BaseSettings):
|
||||||
MULTIMODAL_SEND_FORMAT: Literal["base64", "url"] = Field(
|
MULTIMODAL_SEND_FORMAT: Literal["base64", "url"] = Field(
|
||||||
|
|||||||
@ -33,3 +33,9 @@ class MilvusConfig(BaseSettings):
|
|||||||
description="Name of the Milvus database to connect to (default is 'default')",
|
description="Name of the Milvus database to connect to (default is 'default')",
|
||||||
default="default",
|
default="default",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
MILVUS_ENABLE_HYBRID_SEARCH: bool = Field(
|
||||||
|
description="Enable hybrid search features (requires Milvus >= 2.5.0). Set to false for compatibility with "
|
||||||
|
"older versions",
|
||||||
|
default=True,
|
||||||
|
)
|
||||||
|
|||||||
@ -9,7 +9,7 @@ class PackagingInfo(BaseSettings):
|
|||||||
|
|
||||||
CURRENT_VERSION: str = Field(
|
CURRENT_VERSION: str = Field(
|
||||||
description="Dify version",
|
description="Dify version",
|
||||||
default="0.14.2",
|
default="0.15.0",
|
||||||
)
|
)
|
||||||
|
|
||||||
COMMIT_SHA: str = Field(
|
COMMIT_SHA: str = Field(
|
||||||
|
|||||||
@ -57,12 +57,13 @@ class AppListApi(Resource):
|
|||||||
)
|
)
|
||||||
parser.add_argument("name", type=str, location="args", required=False)
|
parser.add_argument("name", type=str, location="args", required=False)
|
||||||
parser.add_argument("tag_ids", type=uuid_list, location="args", required=False)
|
parser.add_argument("tag_ids", type=uuid_list, location="args", required=False)
|
||||||
|
parser.add_argument("is_created_by_me", type=inputs.boolean, location="args", required=False)
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
# get app list
|
# get app list
|
||||||
app_service = AppService()
|
app_service = AppService()
|
||||||
app_pagination = app_service.get_paginate_apps(current_user.current_tenant_id, args)
|
app_pagination = app_service.get_paginate_apps(current_user.id, current_user.current_tenant_id, args)
|
||||||
if not app_pagination:
|
if not app_pagination:
|
||||||
return {"data": [], "total": 0, "page": 1, "limit": 20, "has_more": False}
|
return {"data": [], "total": 0, "page": 1, "limit": 20, "has_more": False}
|
||||||
|
|
||||||
|
|||||||
@ -20,7 +20,6 @@ from controllers.web.error import InvokeRateLimitError as InvokeRateLimitHttpErr
|
|||||||
from core.app.apps.base_app_queue_manager import AppQueueManager
|
from core.app.apps.base_app_queue_manager import AppQueueManager
|
||||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||||
from core.errors.error import (
|
from core.errors.error import (
|
||||||
AppInvokeQuotaExceededError,
|
|
||||||
ModelCurrentlyNotSupportError,
|
ModelCurrentlyNotSupportError,
|
||||||
ProviderTokenNotInitError,
|
ProviderTokenNotInitError,
|
||||||
QuotaExceededError,
|
QuotaExceededError,
|
||||||
@ -76,7 +75,7 @@ class CompletionMessageApi(Resource):
|
|||||||
raise ProviderModelCurrentlyNotSupportError()
|
raise ProviderModelCurrentlyNotSupportError()
|
||||||
except InvokeError as e:
|
except InvokeError as e:
|
||||||
raise CompletionRequestError(e.description)
|
raise CompletionRequestError(e.description)
|
||||||
except (ValueError, AppInvokeQuotaExceededError) as e:
|
except ValueError as e:
|
||||||
raise e
|
raise e
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logging.exception("internal server error.")
|
logging.exception("internal server error.")
|
||||||
@ -141,7 +140,7 @@ class ChatMessageApi(Resource):
|
|||||||
raise InvokeRateLimitHttpError(ex.description)
|
raise InvokeRateLimitHttpError(ex.description)
|
||||||
except InvokeError as e:
|
except InvokeError as e:
|
||||||
raise CompletionRequestError(e.description)
|
raise CompletionRequestError(e.description)
|
||||||
except (ValueError, AppInvokeQuotaExceededError) as e:
|
except ValueError as e:
|
||||||
raise e
|
raise e
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logging.exception("internal server error.")
|
logging.exception("internal server error.")
|
||||||
|
|||||||
@ -273,8 +273,7 @@ FROM
|
|||||||
messages m
|
messages m
|
||||||
ON c.id = m.conversation_id
|
ON c.id = m.conversation_id
|
||||||
WHERE
|
WHERE
|
||||||
c.override_model_configs IS NULL
|
c.app_id = :app_id"""
|
||||||
AND c.app_id = :app_id"""
|
|
||||||
arg_dict = {"tz": account.timezone, "app_id": app_model.id}
|
arg_dict = {"tz": account.timezone, "app_id": app_model.id}
|
||||||
|
|
||||||
timezone = pytz.timezone(account.timezone)
|
timezone = pytz.timezone(account.timezone)
|
||||||
|
|||||||
@ -2,7 +2,7 @@ import base64
|
|||||||
import secrets
|
import secrets
|
||||||
|
|
||||||
from flask import request
|
from flask import request
|
||||||
from flask_restful import Resource, reqparse
|
from flask_restful import Resource, reqparse # type: ignore
|
||||||
from sqlalchemy import select
|
from sqlalchemy import select
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
@ -129,7 +129,7 @@ class ForgotPasswordResetApi(Resource):
|
|||||||
)
|
)
|
||||||
except WorkSpaceNotAllowedCreateError:
|
except WorkSpaceNotAllowedCreateError:
|
||||||
pass
|
pass
|
||||||
except AccountRegisterError as are:
|
except AccountRegisterError:
|
||||||
raise AccountInFreezeError()
|
raise AccountInFreezeError()
|
||||||
|
|
||||||
return {"result": "success"}
|
return {"result": "success"}
|
||||||
|
|||||||
@ -4,7 +4,7 @@ from typing import Optional
|
|||||||
|
|
||||||
import requests
|
import requests
|
||||||
from flask import current_app, redirect, request
|
from flask import current_app, redirect, request
|
||||||
from flask_restful import Resource
|
from flask_restful import Resource # type: ignore
|
||||||
from sqlalchemy import select
|
from sqlalchemy import select
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
from werkzeug.exceptions import Unauthorized
|
from werkzeug.exceptions import Unauthorized
|
||||||
|
|||||||
@ -2,8 +2,8 @@ import datetime
|
|||||||
import json
|
import json
|
||||||
|
|
||||||
from flask import request
|
from flask import request
|
||||||
from flask_login import current_user
|
from flask_login import current_user # type: ignore
|
||||||
from flask_restful import Resource, marshal_with, reqparse
|
from flask_restful import Resource, marshal_with, reqparse # type: ignore
|
||||||
from sqlalchemy import select
|
from sqlalchemy import select
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
from werkzeug.exceptions import NotFound
|
from werkzeug.exceptions import NotFound
|
||||||
|
|||||||
@ -640,6 +640,7 @@ class DatasetRetrievalSettingApi(Resource):
|
|||||||
| VectorType.MYSCALE
|
| VectorType.MYSCALE
|
||||||
| VectorType.ORACLE
|
| VectorType.ORACLE
|
||||||
| VectorType.ELASTICSEARCH
|
| VectorType.ELASTICSEARCH
|
||||||
|
| VectorType.ELASTICSEARCH_JA
|
||||||
| VectorType.PGVECTOR
|
| VectorType.PGVECTOR
|
||||||
| VectorType.TIDB_ON_QDRANT
|
| VectorType.TIDB_ON_QDRANT
|
||||||
| VectorType.LINDORM
|
| VectorType.LINDORM
|
||||||
@ -683,6 +684,7 @@ class DatasetRetrievalSettingMockApi(Resource):
|
|||||||
| VectorType.MYSCALE
|
| VectorType.MYSCALE
|
||||||
| VectorType.ORACLE
|
| VectorType.ORACLE
|
||||||
| VectorType.ELASTICSEARCH
|
| VectorType.ELASTICSEARCH
|
||||||
|
| VectorType.ELASTICSEARCH_JA
|
||||||
| VectorType.COUCHBASE
|
| VectorType.COUCHBASE
|
||||||
| VectorType.PGVECTOR
|
| VectorType.PGVECTOR
|
||||||
| VectorType.LINDORM
|
| VectorType.LINDORM
|
||||||
|
|||||||
@ -269,7 +269,8 @@ class DatasetDocumentListApi(Resource):
|
|||||||
parser.add_argument("original_document_id", type=str, required=False, location="json")
|
parser.add_argument("original_document_id", type=str, required=False, location="json")
|
||||||
parser.add_argument("doc_form", type=str, default="text_model", required=False, nullable=False, location="json")
|
parser.add_argument("doc_form", type=str, default="text_model", required=False, nullable=False, location="json")
|
||||||
parser.add_argument("retrieval_model", type=dict, required=False, nullable=False, location="json")
|
parser.add_argument("retrieval_model", type=dict, required=False, nullable=False, location="json")
|
||||||
|
parser.add_argument("embedding_model", type=str, required=False, nullable=True, location="json")
|
||||||
|
parser.add_argument("embedding_model_provider", type=str, required=False, nullable=True, location="json")
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"doc_language", type=str, default="English", required=False, nullable=False, location="json"
|
"doc_language", type=str, default="English", required=False, nullable=False, location="json"
|
||||||
)
|
)
|
||||||
|
|||||||
@ -18,7 +18,11 @@ from controllers.console.explore.error import NotChatAppError, NotCompletionAppE
|
|||||||
from controllers.console.explore.wraps import InstalledAppResource
|
from controllers.console.explore.wraps import InstalledAppResource
|
||||||
from core.app.apps.base_app_queue_manager import AppQueueManager
|
from core.app.apps.base_app_queue_manager import AppQueueManager
|
||||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||||
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
|
from core.errors.error import (
|
||||||
|
ModelCurrentlyNotSupportError,
|
||||||
|
ProviderTokenNotInitError,
|
||||||
|
QuotaExceededError,
|
||||||
|
)
|
||||||
from core.model_runtime.errors.invoke import InvokeError
|
from core.model_runtime.errors.invoke import InvokeError
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
from libs import helper
|
from libs import helper
|
||||||
|
|||||||
@ -13,7 +13,11 @@ from controllers.console.explore.error import NotWorkflowAppError
|
|||||||
from controllers.console.explore.wraps import InstalledAppResource
|
from controllers.console.explore.wraps import InstalledAppResource
|
||||||
from core.app.apps.base_app_queue_manager import AppQueueManager
|
from core.app.apps.base_app_queue_manager import AppQueueManager
|
||||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||||
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
|
from core.errors.error import (
|
||||||
|
ModelCurrentlyNotSupportError,
|
||||||
|
ProviderTokenNotInitError,
|
||||||
|
QuotaExceededError,
|
||||||
|
)
|
||||||
from core.model_runtime.errors.invoke import InvokeError
|
from core.model_runtime.errors.invoke import InvokeError
|
||||||
from libs import helper
|
from libs import helper
|
||||||
from libs.login import current_user
|
from libs.login import current_user
|
||||||
|
|||||||
@ -1,7 +1,7 @@
|
|||||||
import os
|
import os
|
||||||
|
|
||||||
from flask import session
|
from flask import session
|
||||||
from flask_restful import Resource, reqparse
|
from flask_restful import Resource, reqparse # type: ignore
|
||||||
from sqlalchemy import select
|
from sqlalchemy import select
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
|
|||||||
@ -1,6 +1,6 @@
|
|||||||
from functools import wraps
|
from functools import wraps
|
||||||
|
|
||||||
from flask_login import current_user
|
from flask_login import current_user # type: ignore
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
from werkzeug.exceptions import Forbidden
|
from werkzeug.exceptions import Forbidden
|
||||||
|
|
||||||
|
|||||||
@ -1,5 +1,5 @@
|
|||||||
from flask_login import current_user
|
from flask_login import current_user # type: ignore
|
||||||
from flask_restful import Resource
|
from flask_restful import Resource # type: ignore
|
||||||
|
|
||||||
from controllers.console import api
|
from controllers.console import api
|
||||||
from controllers.console.wraps import account_initialization_required, setup_required
|
from controllers.console.wraps import account_initialization_required, setup_required
|
||||||
|
|||||||
@ -1,5 +1,5 @@
|
|||||||
from flask_login import current_user
|
from flask_login import current_user # type: ignore
|
||||||
from flask_restful import Resource, reqparse
|
from flask_restful import Resource, reqparse # type: ignore
|
||||||
from werkzeug.exceptions import Forbidden
|
from werkzeug.exceptions import Forbidden
|
||||||
|
|
||||||
from controllers.console import api
|
from controllers.console import api
|
||||||
|
|||||||
@ -1,8 +1,8 @@
|
|||||||
import io
|
import io
|
||||||
|
|
||||||
from flask import request, send_file
|
from flask import request, send_file
|
||||||
from flask_login import current_user
|
from flask_login import current_user # type: ignore
|
||||||
from flask_restful import Resource, reqparse
|
from flask_restful import Resource, reqparse # type: ignore
|
||||||
from werkzeug.exceptions import Forbidden
|
from werkzeug.exceptions import Forbidden
|
||||||
|
|
||||||
from configs import dify_config
|
from configs import dify_config
|
||||||
|
|||||||
@ -1,5 +1,5 @@
|
|||||||
from flask import request
|
from flask import request
|
||||||
from flask_restful import Resource, marshal_with
|
from flask_restful import Resource, marshal_with # type: ignore
|
||||||
from werkzeug.exceptions import Forbidden
|
from werkzeug.exceptions import Forbidden
|
||||||
|
|
||||||
import services
|
import services
|
||||||
|
|||||||
@ -1,4 +1,4 @@
|
|||||||
from flask_restful import Resource
|
from flask_restful import Resource # type: ignore
|
||||||
|
|
||||||
from controllers.console.wraps import setup_required
|
from controllers.console.wraps import setup_required
|
||||||
from controllers.inner_api import api
|
from controllers.inner_api import api
|
||||||
|
|||||||
@ -3,7 +3,7 @@ from functools import wraps
|
|||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
from flask import request
|
from flask import request
|
||||||
from flask_restful import reqparse
|
from flask_restful import reqparse # type: ignore
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
|
|||||||
@ -18,7 +18,6 @@ from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate
|
|||||||
from core.app.apps.base_app_queue_manager import AppQueueManager
|
from core.app.apps.base_app_queue_manager import AppQueueManager
|
||||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||||
from core.errors.error import (
|
from core.errors.error import (
|
||||||
AppInvokeQuotaExceededError,
|
|
||||||
ModelCurrentlyNotSupportError,
|
ModelCurrentlyNotSupportError,
|
||||||
ProviderTokenNotInitError,
|
ProviderTokenNotInitError,
|
||||||
QuotaExceededError,
|
QuotaExceededError,
|
||||||
@ -74,7 +73,7 @@ class CompletionApi(Resource):
|
|||||||
raise ProviderModelCurrentlyNotSupportError()
|
raise ProviderModelCurrentlyNotSupportError()
|
||||||
except InvokeError as e:
|
except InvokeError as e:
|
||||||
raise CompletionRequestError(e.description)
|
raise CompletionRequestError(e.description)
|
||||||
except (ValueError, AppInvokeQuotaExceededError) as e:
|
except ValueError as e:
|
||||||
raise e
|
raise e
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logging.exception("internal server error.")
|
logging.exception("internal server error.")
|
||||||
@ -133,7 +132,7 @@ class ChatApi(Resource):
|
|||||||
raise ProviderModelCurrentlyNotSupportError()
|
raise ProviderModelCurrentlyNotSupportError()
|
||||||
except InvokeError as e:
|
except InvokeError as e:
|
||||||
raise CompletionRequestError(e.description)
|
raise CompletionRequestError(e.description)
|
||||||
except (ValueError, AppInvokeQuotaExceededError) as e:
|
except ValueError as e:
|
||||||
raise e
|
raise e
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logging.exception("internal server error.")
|
logging.exception("internal server error.")
|
||||||
|
|||||||
@ -16,7 +16,6 @@ from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate
|
|||||||
from core.app.apps.base_app_queue_manager import AppQueueManager
|
from core.app.apps.base_app_queue_manager import AppQueueManager
|
||||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||||
from core.errors.error import (
|
from core.errors.error import (
|
||||||
AppInvokeQuotaExceededError,
|
|
||||||
ModelCurrentlyNotSupportError,
|
ModelCurrentlyNotSupportError,
|
||||||
ProviderTokenNotInitError,
|
ProviderTokenNotInitError,
|
||||||
QuotaExceededError,
|
QuotaExceededError,
|
||||||
@ -94,7 +93,7 @@ class WorkflowRunApi(Resource):
|
|||||||
raise ProviderModelCurrentlyNotSupportError()
|
raise ProviderModelCurrentlyNotSupportError()
|
||||||
except InvokeError as e:
|
except InvokeError as e:
|
||||||
raise CompletionRequestError(e.description)
|
raise CompletionRequestError(e.description)
|
||||||
except (ValueError, AppInvokeQuotaExceededError) as e:
|
except ValueError as e:
|
||||||
raise e
|
raise e
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logging.exception("internal server error.")
|
logging.exception("internal server error.")
|
||||||
|
|||||||
@ -190,7 +190,10 @@ class DocumentAddByFileApi(DatasetApiResource):
|
|||||||
user=current_user,
|
user=current_user,
|
||||||
source="datasets",
|
source="datasets",
|
||||||
)
|
)
|
||||||
data_source = {"type": "upload_file", "info_list": {"file_info_list": {"file_ids": [upload_file.id]}}}
|
data_source = {
|
||||||
|
"type": "upload_file",
|
||||||
|
"info_list": {"data_source_type": "upload_file", "file_info_list": {"file_ids": [upload_file.id]}},
|
||||||
|
}
|
||||||
args["data_source"] = data_source
|
args["data_source"] = data_source
|
||||||
# validate args
|
# validate args
|
||||||
knowledge_config = KnowledgeConfig(**args)
|
knowledge_config = KnowledgeConfig(**args)
|
||||||
@ -254,7 +257,10 @@ class DocumentUpdateByFileApi(DatasetApiResource):
|
|||||||
raise FileTooLargeError(file_too_large_error.description)
|
raise FileTooLargeError(file_too_large_error.description)
|
||||||
except services.errors.file.UnsupportedFileTypeError:
|
except services.errors.file.UnsupportedFileTypeError:
|
||||||
raise UnsupportedFileTypeError()
|
raise UnsupportedFileTypeError()
|
||||||
data_source = {"type": "upload_file", "info_list": {"file_info_list": {"file_ids": [upload_file.id]}}}
|
data_source = {
|
||||||
|
"type": "upload_file",
|
||||||
|
"info_list": {"data_source_type": "upload_file", "file_info_list": {"file_ids": [upload_file.id]}},
|
||||||
|
}
|
||||||
args["data_source"] = data_source
|
args["data_source"] = data_source
|
||||||
# validate args
|
# validate args
|
||||||
args["original_document_id"] = str(document_id)
|
args["original_document_id"] = str(document_id)
|
||||||
|
|||||||
@ -1,5 +1,5 @@
|
|||||||
from collections.abc import Callable
|
from collections.abc import Callable
|
||||||
from datetime import UTC, datetime
|
from datetime import UTC, datetime, timedelta
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from functools import wraps
|
from functools import wraps
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
@ -8,6 +8,8 @@ from flask import current_app, request
|
|||||||
from flask_login import user_logged_in # type: ignore
|
from flask_login import user_logged_in # type: ignore
|
||||||
from flask_restful import Resource # type: ignore
|
from flask_restful import Resource # type: ignore
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
from sqlalchemy import select, update
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
from werkzeug.exceptions import Forbidden, Unauthorized
|
from werkzeug.exceptions import Forbidden, Unauthorized
|
||||||
|
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
@ -174,7 +176,7 @@ def validate_dataset_token(view=None):
|
|||||||
return decorator
|
return decorator
|
||||||
|
|
||||||
|
|
||||||
def validate_and_get_api_token(scope=None):
|
def validate_and_get_api_token(scope: str | None = None):
|
||||||
"""
|
"""
|
||||||
Validate and get API token.
|
Validate and get API token.
|
||||||
"""
|
"""
|
||||||
@ -188,20 +190,25 @@ def validate_and_get_api_token(scope=None):
|
|||||||
if auth_scheme != "bearer":
|
if auth_scheme != "bearer":
|
||||||
raise Unauthorized("Authorization scheme must be 'Bearer'")
|
raise Unauthorized("Authorization scheme must be 'Bearer'")
|
||||||
|
|
||||||
api_token = (
|
current_time = datetime.now(UTC).replace(tzinfo=None)
|
||||||
db.session.query(ApiToken)
|
cutoff_time = current_time - timedelta(minutes=1)
|
||||||
.filter(
|
with Session(db.engine, expire_on_commit=False) as session:
|
||||||
ApiToken.token == auth_token,
|
update_stmt = (
|
||||||
ApiToken.type == scope,
|
update(ApiToken)
|
||||||
|
.where(ApiToken.token == auth_token, ApiToken.last_used_at < cutoff_time, ApiToken.type == scope)
|
||||||
|
.values(last_used_at=current_time)
|
||||||
|
.returning(ApiToken)
|
||||||
)
|
)
|
||||||
.first()
|
result = session.execute(update_stmt)
|
||||||
)
|
api_token = result.scalar_one_or_none()
|
||||||
|
|
||||||
if not api_token:
|
if not api_token:
|
||||||
raise Unauthorized("Access token is invalid")
|
stmt = select(ApiToken).where(ApiToken.token == auth_token, ApiToken.type == scope)
|
||||||
|
api_token = session.scalar(stmt)
|
||||||
api_token.last_used_at = datetime.now(UTC).replace(tzinfo=None)
|
if not api_token:
|
||||||
db.session.commit()
|
raise Unauthorized("Access token is invalid")
|
||||||
|
else:
|
||||||
|
session.commit()
|
||||||
|
|
||||||
return api_token
|
return api_token
|
||||||
|
|
||||||
|
|||||||
@ -19,7 +19,11 @@ from controllers.web.error import InvokeRateLimitError as InvokeRateLimitHttpErr
|
|||||||
from controllers.web.wraps import WebApiResource
|
from controllers.web.wraps import WebApiResource
|
||||||
from core.app.apps.base_app_queue_manager import AppQueueManager
|
from core.app.apps.base_app_queue_manager import AppQueueManager
|
||||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||||
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
|
from core.errors.error import (
|
||||||
|
ModelCurrentlyNotSupportError,
|
||||||
|
ProviderTokenNotInitError,
|
||||||
|
QuotaExceededError,
|
||||||
|
)
|
||||||
from core.model_runtime.errors.invoke import InvokeError
|
from core.model_runtime.errors.invoke import InvokeError
|
||||||
from libs import helper
|
from libs import helper
|
||||||
from libs.helper import uuid_value
|
from libs.helper import uuid_value
|
||||||
|
|||||||
@ -14,7 +14,11 @@ from controllers.web.error import (
|
|||||||
from controllers.web.wraps import WebApiResource
|
from controllers.web.wraps import WebApiResource
|
||||||
from core.app.apps.base_app_queue_manager import AppQueueManager
|
from core.app.apps.base_app_queue_manager import AppQueueManager
|
||||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||||
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
|
from core.errors.error import (
|
||||||
|
ModelCurrentlyNotSupportError,
|
||||||
|
ProviderTokenNotInitError,
|
||||||
|
QuotaExceededError,
|
||||||
|
)
|
||||||
from core.model_runtime.errors.invoke import InvokeError
|
from core.model_runtime.errors.invoke import InvokeError
|
||||||
from libs import helper
|
from libs import helper
|
||||||
from models.model import App, AppMode, EndUser
|
from models.model import App, AppMode, EndUser
|
||||||
|
|||||||
@ -119,7 +119,7 @@ class CotAgentRunner(BaseAgentRunner, ABC):
|
|||||||
callbacks=[],
|
callbacks=[],
|
||||||
)
|
)
|
||||||
|
|
||||||
usage_dict = {}
|
usage_dict: dict[str, Optional[LLMUsage]] = {}
|
||||||
react_chunks = CotAgentOutputParser.handle_react_stream_output(chunks, usage_dict)
|
react_chunks = CotAgentOutputParser.handle_react_stream_output(chunks, usage_dict)
|
||||||
scratchpad = AgentScratchpadUnit(
|
scratchpad = AgentScratchpadUnit(
|
||||||
agent_response="",
|
agent_response="",
|
||||||
|
|||||||
@ -21,7 +21,7 @@ from core.app.apps.message_based_app_generator import MessageBasedAppGenerator
|
|||||||
from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager
|
from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager
|
||||||
from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom
|
from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom
|
||||||
from core.app.entities.task_entities import ChatbotAppBlockingResponse, ChatbotAppStreamResponse
|
from core.app.entities.task_entities import ChatbotAppBlockingResponse, ChatbotAppStreamResponse
|
||||||
from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
|
from core.model_runtime.errors.invoke import InvokeAuthorizationError
|
||||||
from core.ops.ops_trace_manager import TraceQueueManager
|
from core.ops.ops_trace_manager import TraceQueueManager
|
||||||
from core.prompt.utils.get_thread_messages_length import get_thread_messages_length
|
from core.prompt.utils.get_thread_messages_length import get_thread_messages_length
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
@ -346,7 +346,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
|||||||
except ValidationError as e:
|
except ValidationError as e:
|
||||||
logger.exception("Validation Error when generating")
|
logger.exception("Validation Error when generating")
|
||||||
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
|
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
|
||||||
except (ValueError, InvokeError) as e:
|
except ValueError as e:
|
||||||
if dify_config.DEBUG:
|
if dify_config.DEBUG:
|
||||||
logger.exception("Error when generating")
|
logger.exception("Error when generating")
|
||||||
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
|
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
|
||||||
|
|||||||
@ -77,7 +77,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
|
|||||||
graph, variable_pool = self._get_graph_and_variable_pool_of_single_iteration(
|
graph, variable_pool = self._get_graph_and_variable_pool_of_single_iteration(
|
||||||
workflow=workflow,
|
workflow=workflow,
|
||||||
node_id=self.application_generate_entity.single_iteration_run.node_id,
|
node_id=self.application_generate_entity.single_iteration_run.node_id,
|
||||||
user_inputs=self.application_generate_entity.single_iteration_run.inputs,
|
user_inputs=dict(self.application_generate_entity.single_iteration_run.inputs),
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
inputs = self.application_generate_entity.inputs
|
inputs = self.application_generate_entity.inputs
|
||||||
|
|||||||
@ -68,24 +68,17 @@ from models.account import Account
|
|||||||
from models.enums import CreatedByRole
|
from models.enums import CreatedByRole
|
||||||
from models.workflow import (
|
from models.workflow import (
|
||||||
Workflow,
|
Workflow,
|
||||||
WorkflowNodeExecution,
|
|
||||||
WorkflowRunStatus,
|
WorkflowRunStatus,
|
||||||
)
|
)
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleManage, MessageCycleManage):
|
class AdvancedChatAppGenerateTaskPipeline:
|
||||||
"""
|
"""
|
||||||
AdvancedChatAppGenerateTaskPipeline is a class that generate stream output and state management for Application.
|
AdvancedChatAppGenerateTaskPipeline is a class that generate stream output and state management for Application.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
_task_state: WorkflowTaskState
|
|
||||||
_application_generate_entity: AdvancedChatAppGenerateEntity
|
|
||||||
_workflow_system_variables: dict[SystemVariableKey, Any]
|
|
||||||
_wip_workflow_node_executions: dict[str, WorkflowNodeExecution]
|
|
||||||
_conversation_name_generate_thread: Optional[Thread] = None
|
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
application_generate_entity: AdvancedChatAppGenerateEntity,
|
application_generate_entity: AdvancedChatAppGenerateEntity,
|
||||||
@ -97,7 +90,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
|
|||||||
stream: bool,
|
stream: bool,
|
||||||
dialogue_count: int,
|
dialogue_count: int,
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__(
|
self._base_task_pipeline = BasedGenerateTaskPipeline(
|
||||||
application_generate_entity=application_generate_entity,
|
application_generate_entity=application_generate_entity,
|
||||||
queue_manager=queue_manager,
|
queue_manager=queue_manager,
|
||||||
stream=stream,
|
stream=stream,
|
||||||
@ -114,33 +107,35 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
|
|||||||
else:
|
else:
|
||||||
raise NotImplementedError(f"User type not supported: {type(user)}")
|
raise NotImplementedError(f"User type not supported: {type(user)}")
|
||||||
|
|
||||||
self._workflow_id = workflow.id
|
self._workflow_cycle_manager = WorkflowCycleManage(
|
||||||
self._workflow_features_dict = workflow.features_dict
|
application_generate_entity=application_generate_entity,
|
||||||
|
workflow_system_variables={
|
||||||
self._conversation_id = conversation.id
|
SystemVariableKey.QUERY: message.query,
|
||||||
self._conversation_mode = conversation.mode
|
SystemVariableKey.FILES: application_generate_entity.files,
|
||||||
|
SystemVariableKey.CONVERSATION_ID: conversation.id,
|
||||||
self._message_id = message.id
|
SystemVariableKey.USER_ID: user_session_id,
|
||||||
self._message_created_at = int(message.created_at.timestamp())
|
SystemVariableKey.DIALOGUE_COUNT: dialogue_count,
|
||||||
|
SystemVariableKey.APP_ID: application_generate_entity.app_config.app_id,
|
||||||
self._workflow_system_variables = {
|
SystemVariableKey.WORKFLOW_ID: workflow.id,
|
||||||
SystemVariableKey.QUERY: message.query,
|
SystemVariableKey.WORKFLOW_RUN_ID: application_generate_entity.workflow_run_id,
|
||||||
SystemVariableKey.FILES: application_generate_entity.files,
|
},
|
||||||
SystemVariableKey.CONVERSATION_ID: conversation.id,
|
)
|
||||||
SystemVariableKey.USER_ID: user_session_id,
|
|
||||||
SystemVariableKey.DIALOGUE_COUNT: dialogue_count,
|
|
||||||
SystemVariableKey.APP_ID: application_generate_entity.app_config.app_id,
|
|
||||||
SystemVariableKey.WORKFLOW_ID: workflow.id,
|
|
||||||
SystemVariableKey.WORKFLOW_RUN_ID: application_generate_entity.workflow_run_id,
|
|
||||||
}
|
|
||||||
|
|
||||||
self._task_state = WorkflowTaskState()
|
self._task_state = WorkflowTaskState()
|
||||||
self._wip_workflow_node_executions = {}
|
self._message_cycle_manager = MessageCycleManage(
|
||||||
self._wip_workflow_agent_logs = {}
|
application_generate_entity=application_generate_entity, task_state=self._task_state
|
||||||
|
)
|
||||||
|
|
||||||
self._conversation_name_generate_thread = None
|
self._application_generate_entity = application_generate_entity
|
||||||
|
self._workflow_id = workflow.id
|
||||||
|
self._workflow_features_dict = workflow.features_dict
|
||||||
|
self._conversation_id = conversation.id
|
||||||
|
self._conversation_mode = conversation.mode
|
||||||
|
self._message_id = message.id
|
||||||
|
self._message_created_at = int(message.created_at.timestamp())
|
||||||
|
self._conversation_name_generate_thread: Thread | None = None
|
||||||
self._recorded_files: list[Mapping[str, Any]] = []
|
self._recorded_files: list[Mapping[str, Any]] = []
|
||||||
self._workflow_run_id = ""
|
self._workflow_run_id: str = ""
|
||||||
|
|
||||||
def process(self) -> Union[ChatbotAppBlockingResponse, Generator[ChatbotAppStreamResponse, None, None]]:
|
def process(self) -> Union[ChatbotAppBlockingResponse, Generator[ChatbotAppStreamResponse, None, None]]:
|
||||||
"""
|
"""
|
||||||
@ -148,13 +143,13 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
|
|||||||
:return:
|
:return:
|
||||||
"""
|
"""
|
||||||
# start generate conversation name thread
|
# start generate conversation name thread
|
||||||
self._conversation_name_generate_thread = self._generate_conversation_name(
|
self._conversation_name_generate_thread = self._message_cycle_manager._generate_conversation_name(
|
||||||
conversation_id=self._conversation_id, query=self._application_generate_entity.query
|
conversation_id=self._conversation_id, query=self._application_generate_entity.query
|
||||||
)
|
)
|
||||||
|
|
||||||
generator = self._wrapper_process_stream_response(trace_manager=self._application_generate_entity.trace_manager)
|
generator = self._wrapper_process_stream_response(trace_manager=self._application_generate_entity.trace_manager)
|
||||||
|
|
||||||
if self._stream:
|
if self._base_task_pipeline._stream:
|
||||||
return self._to_stream_response(generator)
|
return self._to_stream_response(generator)
|
||||||
else:
|
else:
|
||||||
return self._to_blocking_response(generator)
|
return self._to_blocking_response(generator)
|
||||||
@ -273,24 +268,26 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
|
|||||||
# init fake graph runtime state
|
# init fake graph runtime state
|
||||||
graph_runtime_state: Optional[GraphRuntimeState] = None
|
graph_runtime_state: Optional[GraphRuntimeState] = None
|
||||||
|
|
||||||
for queue_message in self._queue_manager.listen():
|
for queue_message in self._base_task_pipeline._queue_manager.listen():
|
||||||
event = queue_message.event
|
event = queue_message.event
|
||||||
|
|
||||||
if isinstance(event, QueuePingEvent):
|
if isinstance(event, QueuePingEvent):
|
||||||
yield self._ping_stream_response()
|
yield self._base_task_pipeline._ping_stream_response()
|
||||||
elif isinstance(event, QueueErrorEvent):
|
elif isinstance(event, QueueErrorEvent):
|
||||||
with Session(db.engine) as session:
|
with Session(db.engine, expire_on_commit=False) as session:
|
||||||
err = self._handle_error(event=event, session=session, message_id=self._message_id)
|
err = self._base_task_pipeline._handle_error(
|
||||||
|
event=event, session=session, message_id=self._message_id
|
||||||
|
)
|
||||||
session.commit()
|
session.commit()
|
||||||
yield self._error_to_stream_response(err)
|
yield self._base_task_pipeline._error_to_stream_response(err)
|
||||||
break
|
break
|
||||||
elif isinstance(event, QueueWorkflowStartedEvent):
|
elif isinstance(event, QueueWorkflowStartedEvent):
|
||||||
# override graph runtime state
|
# override graph runtime state
|
||||||
graph_runtime_state = event.graph_runtime_state
|
graph_runtime_state = event.graph_runtime_state
|
||||||
|
|
||||||
with Session(db.engine) as session:
|
with Session(db.engine, expire_on_commit=False) as session:
|
||||||
# init workflow run
|
# init workflow run
|
||||||
workflow_run = self._handle_workflow_run_start(
|
workflow_run = self._workflow_cycle_manager._handle_workflow_run_start(
|
||||||
session=session,
|
session=session,
|
||||||
workflow_id=self._workflow_id,
|
workflow_id=self._workflow_id,
|
||||||
user_id=self._user_id,
|
user_id=self._user_id,
|
||||||
@ -301,7 +298,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
|
|||||||
if not message:
|
if not message:
|
||||||
raise ValueError(f"Message not found: {self._message_id}")
|
raise ValueError(f"Message not found: {self._message_id}")
|
||||||
message.workflow_run_id = workflow_run.id
|
message.workflow_run_id = workflow_run.id
|
||||||
workflow_start_resp = self._workflow_start_to_stream_response(
|
workflow_start_resp = self._workflow_cycle_manager._workflow_start_to_stream_response(
|
||||||
session=session, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
|
session=session, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
|
||||||
)
|
)
|
||||||
session.commit()
|
session.commit()
|
||||||
@ -314,12 +311,14 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
|
|||||||
if not self._workflow_run_id:
|
if not self._workflow_run_id:
|
||||||
raise ValueError("workflow run not initialized.")
|
raise ValueError("workflow run not initialized.")
|
||||||
|
|
||||||
with Session(db.engine) as session:
|
with Session(db.engine, expire_on_commit=False) as session:
|
||||||
workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id)
|
workflow_run = self._workflow_cycle_manager._get_workflow_run(
|
||||||
workflow_node_execution = self._handle_workflow_node_execution_retried(
|
session=session, workflow_run_id=self._workflow_run_id
|
||||||
|
)
|
||||||
|
workflow_node_execution = self._workflow_cycle_manager._handle_workflow_node_execution_retried(
|
||||||
session=session, workflow_run=workflow_run, event=event
|
session=session, workflow_run=workflow_run, event=event
|
||||||
)
|
)
|
||||||
node_retry_resp = self._workflow_node_retry_to_stream_response(
|
node_retry_resp = self._workflow_cycle_manager._workflow_node_retry_to_stream_response(
|
||||||
session=session,
|
session=session,
|
||||||
event=event,
|
event=event,
|
||||||
task_id=self._application_generate_entity.task_id,
|
task_id=self._application_generate_entity.task_id,
|
||||||
@ -333,13 +332,15 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
|
|||||||
if not self._workflow_run_id:
|
if not self._workflow_run_id:
|
||||||
raise ValueError("workflow run not initialized.")
|
raise ValueError("workflow run not initialized.")
|
||||||
|
|
||||||
with Session(db.engine) as session:
|
with Session(db.engine, expire_on_commit=False) as session:
|
||||||
workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id)
|
workflow_run = self._workflow_cycle_manager._get_workflow_run(
|
||||||
workflow_node_execution = self._handle_node_execution_start(
|
session=session, workflow_run_id=self._workflow_run_id
|
||||||
|
)
|
||||||
|
workflow_node_execution = self._workflow_cycle_manager._handle_node_execution_start(
|
||||||
session=session, workflow_run=workflow_run, event=event
|
session=session, workflow_run=workflow_run, event=event
|
||||||
)
|
)
|
||||||
|
|
||||||
node_start_resp = self._workflow_node_start_to_stream_response(
|
node_start_resp = self._workflow_cycle_manager._workflow_node_start_to_stream_response(
|
||||||
session=session,
|
session=session,
|
||||||
event=event,
|
event=event,
|
||||||
task_id=self._application_generate_entity.task_id,
|
task_id=self._application_generate_entity.task_id,
|
||||||
@ -352,12 +353,16 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
|
|||||||
elif isinstance(event, QueueNodeSucceededEvent):
|
elif isinstance(event, QueueNodeSucceededEvent):
|
||||||
# Record files if it's an answer node or end node
|
# Record files if it's an answer node or end node
|
||||||
if event.node_type in [NodeType.ANSWER, NodeType.END]:
|
if event.node_type in [NodeType.ANSWER, NodeType.END]:
|
||||||
self._recorded_files.extend(self._fetch_files_from_node_outputs(event.outputs or {}))
|
self._recorded_files.extend(
|
||||||
|
self._workflow_cycle_manager._fetch_files_from_node_outputs(event.outputs or {})
|
||||||
|
)
|
||||||
|
|
||||||
with Session(db.engine) as session:
|
with Session(db.engine, expire_on_commit=False) as session:
|
||||||
workflow_node_execution = self._handle_workflow_node_execution_success(session=session, event=event)
|
workflow_node_execution = self._workflow_cycle_manager._handle_workflow_node_execution_success(
|
||||||
|
session=session, event=event
|
||||||
|
)
|
||||||
|
|
||||||
node_finish_resp = self._workflow_node_finish_to_stream_response(
|
node_finish_resp = self._workflow_cycle_manager._workflow_node_finish_to_stream_response(
|
||||||
session=session,
|
session=session,
|
||||||
event=event,
|
event=event,
|
||||||
task_id=self._application_generate_entity.task_id,
|
task_id=self._application_generate_entity.task_id,
|
||||||
@ -368,32 +373,35 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
|
|||||||
if node_finish_resp:
|
if node_finish_resp:
|
||||||
yield node_finish_resp
|
yield node_finish_resp
|
||||||
elif isinstance(event, QueueNodeFailedEvent | QueueNodeInIterationFailedEvent | QueueNodeExceptionEvent):
|
elif isinstance(event, QueueNodeFailedEvent | QueueNodeInIterationFailedEvent | QueueNodeExceptionEvent):
|
||||||
with Session(db.engine) as session:
|
with Session(db.engine, expire_on_commit=False) as session:
|
||||||
workflow_node_execution = self._handle_workflow_node_execution_failed(session=session, event=event)
|
workflow_node_execution = self._workflow_cycle_manager._handle_workflow_node_execution_failed(
|
||||||
|
session=session, event=event
|
||||||
|
)
|
||||||
|
|
||||||
response_finish = self._workflow_node_finish_to_stream_response(
|
node_finish_resp = self._workflow_cycle_manager._workflow_node_finish_to_stream_response(
|
||||||
session=session,
|
session=session,
|
||||||
event=event,
|
event=event,
|
||||||
task_id=self._application_generate_entity.task_id,
|
task_id=self._application_generate_entity.task_id,
|
||||||
workflow_node_execution=workflow_node_execution,
|
workflow_node_execution=workflow_node_execution,
|
||||||
)
|
)
|
||||||
|
|
||||||
if response_finish:
|
|
||||||
yield response_finish
|
|
||||||
|
|
||||||
if node_finish_resp:
|
if node_finish_resp:
|
||||||
yield node_finish_resp
|
yield node_finish_resp
|
||||||
elif isinstance(event, QueueParallelBranchRunStartedEvent):
|
elif isinstance(event, QueueParallelBranchRunStartedEvent):
|
||||||
if not self._workflow_run_id:
|
if not self._workflow_run_id:
|
||||||
raise ValueError("workflow run not initialized.")
|
raise ValueError("workflow run not initialized.")
|
||||||
|
|
||||||
with Session(db.engine) as session:
|
with Session(db.engine, expire_on_commit=False) as session:
|
||||||
workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id)
|
workflow_run = self._workflow_cycle_manager._get_workflow_run(
|
||||||
parallel_start_resp = self._workflow_parallel_branch_start_to_stream_response(
|
session=session, workflow_run_id=self._workflow_run_id
|
||||||
session=session,
|
)
|
||||||
task_id=self._application_generate_entity.task_id,
|
parallel_start_resp = (
|
||||||
workflow_run=workflow_run,
|
self._workflow_cycle_manager._workflow_parallel_branch_start_to_stream_response(
|
||||||
event=event,
|
session=session,
|
||||||
|
task_id=self._application_generate_entity.task_id,
|
||||||
|
workflow_run=workflow_run,
|
||||||
|
event=event,
|
||||||
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
yield parallel_start_resp
|
yield parallel_start_resp
|
||||||
@ -401,13 +409,17 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
|
|||||||
if not self._workflow_run_id:
|
if not self._workflow_run_id:
|
||||||
raise ValueError("workflow run not initialized.")
|
raise ValueError("workflow run not initialized.")
|
||||||
|
|
||||||
with Session(db.engine) as session:
|
with Session(db.engine, expire_on_commit=False) as session:
|
||||||
workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id)
|
workflow_run = self._workflow_cycle_manager._get_workflow_run(
|
||||||
parallel_finish_resp = self._workflow_parallel_branch_finished_to_stream_response(
|
session=session, workflow_run_id=self._workflow_run_id
|
||||||
session=session,
|
)
|
||||||
task_id=self._application_generate_entity.task_id,
|
parallel_finish_resp = (
|
||||||
workflow_run=workflow_run,
|
self._workflow_cycle_manager._workflow_parallel_branch_finished_to_stream_response(
|
||||||
event=event,
|
session=session,
|
||||||
|
task_id=self._application_generate_entity.task_id,
|
||||||
|
workflow_run=workflow_run,
|
||||||
|
event=event,
|
||||||
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
yield parallel_finish_resp
|
yield parallel_finish_resp
|
||||||
@ -415,9 +427,11 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
|
|||||||
if not self._workflow_run_id:
|
if not self._workflow_run_id:
|
||||||
raise ValueError("workflow run not initialized.")
|
raise ValueError("workflow run not initialized.")
|
||||||
|
|
||||||
with Session(db.engine) as session:
|
with Session(db.engine, expire_on_commit=False) as session:
|
||||||
workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id)
|
workflow_run = self._workflow_cycle_manager._get_workflow_run(
|
||||||
iter_start_resp = self._workflow_iteration_start_to_stream_response(
|
session=session, workflow_run_id=self._workflow_run_id
|
||||||
|
)
|
||||||
|
iter_start_resp = self._workflow_cycle_manager._workflow_iteration_start_to_stream_response(
|
||||||
session=session,
|
session=session,
|
||||||
task_id=self._application_generate_entity.task_id,
|
task_id=self._application_generate_entity.task_id,
|
||||||
workflow_run=workflow_run,
|
workflow_run=workflow_run,
|
||||||
@ -429,9 +443,11 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
|
|||||||
if not self._workflow_run_id:
|
if not self._workflow_run_id:
|
||||||
raise ValueError("workflow run not initialized.")
|
raise ValueError("workflow run not initialized.")
|
||||||
|
|
||||||
with Session(db.engine) as session:
|
with Session(db.engine, expire_on_commit=False) as session:
|
||||||
workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id)
|
workflow_run = self._workflow_cycle_manager._get_workflow_run(
|
||||||
iter_next_resp = self._workflow_iteration_next_to_stream_response(
|
session=session, workflow_run_id=self._workflow_run_id
|
||||||
|
)
|
||||||
|
iter_next_resp = self._workflow_cycle_manager._workflow_iteration_next_to_stream_response(
|
||||||
session=session,
|
session=session,
|
||||||
task_id=self._application_generate_entity.task_id,
|
task_id=self._application_generate_entity.task_id,
|
||||||
workflow_run=workflow_run,
|
workflow_run=workflow_run,
|
||||||
@ -443,9 +459,11 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
|
|||||||
if not self._workflow_run_id:
|
if not self._workflow_run_id:
|
||||||
raise ValueError("workflow run not initialized.")
|
raise ValueError("workflow run not initialized.")
|
||||||
|
|
||||||
with Session(db.engine) as session:
|
with Session(db.engine, expire_on_commit=False) as session:
|
||||||
workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id)
|
workflow_run = self._workflow_cycle_manager._get_workflow_run(
|
||||||
iter_finish_resp = self._workflow_iteration_completed_to_stream_response(
|
session=session, workflow_run_id=self._workflow_run_id
|
||||||
|
)
|
||||||
|
iter_finish_resp = self._workflow_cycle_manager._workflow_iteration_completed_to_stream_response(
|
||||||
session=session,
|
session=session,
|
||||||
task_id=self._application_generate_entity.task_id,
|
task_id=self._application_generate_entity.task_id,
|
||||||
workflow_run=workflow_run,
|
workflow_run=workflow_run,
|
||||||
@ -460,8 +478,8 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
|
|||||||
if not graph_runtime_state:
|
if not graph_runtime_state:
|
||||||
raise ValueError("workflow run not initialized.")
|
raise ValueError("workflow run not initialized.")
|
||||||
|
|
||||||
with Session(db.engine) as session:
|
with Session(db.engine, expire_on_commit=False) as session:
|
||||||
workflow_run = self._handle_workflow_run_success(
|
workflow_run = self._workflow_cycle_manager._handle_workflow_run_success(
|
||||||
session=session,
|
session=session,
|
||||||
workflow_run_id=self._workflow_run_id,
|
workflow_run_id=self._workflow_run_id,
|
||||||
start_at=graph_runtime_state.start_at,
|
start_at=graph_runtime_state.start_at,
|
||||||
@ -472,21 +490,23 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
|
|||||||
trace_manager=trace_manager,
|
trace_manager=trace_manager,
|
||||||
)
|
)
|
||||||
|
|
||||||
workflow_finish_resp = self._workflow_finish_to_stream_response(
|
workflow_finish_resp = self._workflow_cycle_manager._workflow_finish_to_stream_response(
|
||||||
session=session, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
|
session=session, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
|
||||||
)
|
)
|
||||||
session.commit()
|
session.commit()
|
||||||
|
|
||||||
yield workflow_finish_resp
|
yield workflow_finish_resp
|
||||||
self._queue_manager.publish(QueueAdvancedChatMessageEndEvent(), PublishFrom.TASK_PIPELINE)
|
self._base_task_pipeline._queue_manager.publish(
|
||||||
|
QueueAdvancedChatMessageEndEvent(), PublishFrom.TASK_PIPELINE
|
||||||
|
)
|
||||||
elif isinstance(event, QueueWorkflowPartialSuccessEvent):
|
elif isinstance(event, QueueWorkflowPartialSuccessEvent):
|
||||||
if not self._workflow_run_id:
|
if not self._workflow_run_id:
|
||||||
raise ValueError("workflow run not initialized.")
|
raise ValueError("workflow run not initialized.")
|
||||||
if not graph_runtime_state:
|
if not graph_runtime_state:
|
||||||
raise ValueError("graph runtime state not initialized.")
|
raise ValueError("graph runtime state not initialized.")
|
||||||
|
|
||||||
with Session(db.engine) as session:
|
with Session(db.engine, expire_on_commit=False) as session:
|
||||||
workflow_run = self._handle_workflow_run_partial_success(
|
workflow_run = self._workflow_cycle_manager._handle_workflow_run_partial_success(
|
||||||
session=session,
|
session=session,
|
||||||
workflow_run_id=self._workflow_run_id,
|
workflow_run_id=self._workflow_run_id,
|
||||||
start_at=graph_runtime_state.start_at,
|
start_at=graph_runtime_state.start_at,
|
||||||
@ -497,21 +517,23 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
|
|||||||
conversation_id=None,
|
conversation_id=None,
|
||||||
trace_manager=trace_manager,
|
trace_manager=trace_manager,
|
||||||
)
|
)
|
||||||
workflow_finish_resp = self._workflow_finish_to_stream_response(
|
workflow_finish_resp = self._workflow_cycle_manager._workflow_finish_to_stream_response(
|
||||||
session=session, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
|
session=session, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
|
||||||
)
|
)
|
||||||
session.commit()
|
session.commit()
|
||||||
|
|
||||||
yield workflow_finish_resp
|
yield workflow_finish_resp
|
||||||
self._queue_manager.publish(QueueAdvancedChatMessageEndEvent(), PublishFrom.TASK_PIPELINE)
|
self._base_task_pipeline._queue_manager.publish(
|
||||||
|
QueueAdvancedChatMessageEndEvent(), PublishFrom.TASK_PIPELINE
|
||||||
|
)
|
||||||
elif isinstance(event, QueueWorkflowFailedEvent):
|
elif isinstance(event, QueueWorkflowFailedEvent):
|
||||||
if not self._workflow_run_id:
|
if not self._workflow_run_id:
|
||||||
raise ValueError("workflow run not initialized.")
|
raise ValueError("workflow run not initialized.")
|
||||||
if not graph_runtime_state:
|
if not graph_runtime_state:
|
||||||
raise ValueError("graph runtime state not initialized.")
|
raise ValueError("graph runtime state not initialized.")
|
||||||
|
|
||||||
with Session(db.engine) as session:
|
with Session(db.engine, expire_on_commit=False) as session:
|
||||||
workflow_run = self._handle_workflow_run_failed(
|
workflow_run = self._workflow_cycle_manager._handle_workflow_run_failed(
|
||||||
session=session,
|
session=session,
|
||||||
workflow_run_id=self._workflow_run_id,
|
workflow_run_id=self._workflow_run_id,
|
||||||
start_at=graph_runtime_state.start_at,
|
start_at=graph_runtime_state.start_at,
|
||||||
@ -523,20 +545,22 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
|
|||||||
trace_manager=trace_manager,
|
trace_manager=trace_manager,
|
||||||
exceptions_count=event.exceptions_count,
|
exceptions_count=event.exceptions_count,
|
||||||
)
|
)
|
||||||
workflow_finish_resp = self._workflow_finish_to_stream_response(
|
workflow_finish_resp = self._workflow_cycle_manager._workflow_finish_to_stream_response(
|
||||||
session=session, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
|
session=session, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
|
||||||
)
|
)
|
||||||
err_event = QueueErrorEvent(error=ValueError(f"Run failed: {workflow_run.error}"))
|
err_event = QueueErrorEvent(error=ValueError(f"Run failed: {workflow_run.error}"))
|
||||||
err = self._handle_error(event=err_event, session=session, message_id=self._message_id)
|
err = self._base_task_pipeline._handle_error(
|
||||||
|
event=err_event, session=session, message_id=self._message_id
|
||||||
|
)
|
||||||
session.commit()
|
session.commit()
|
||||||
|
|
||||||
yield workflow_finish_resp
|
yield workflow_finish_resp
|
||||||
yield self._error_to_stream_response(err)
|
yield self._base_task_pipeline._error_to_stream_response(err)
|
||||||
break
|
break
|
||||||
elif isinstance(event, QueueStopEvent):
|
elif isinstance(event, QueueStopEvent):
|
||||||
if self._workflow_run_id and graph_runtime_state:
|
if self._workflow_run_id and graph_runtime_state:
|
||||||
with Session(db.engine) as session:
|
with Session(db.engine, expire_on_commit=False) as session:
|
||||||
workflow_run = self._handle_workflow_run_failed(
|
workflow_run = self._workflow_cycle_manager._handle_workflow_run_failed(
|
||||||
session=session,
|
session=session,
|
||||||
workflow_run_id=self._workflow_run_id,
|
workflow_run_id=self._workflow_run_id,
|
||||||
start_at=graph_runtime_state.start_at,
|
start_at=graph_runtime_state.start_at,
|
||||||
@ -547,7 +571,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
|
|||||||
conversation_id=self._conversation_id,
|
conversation_id=self._conversation_id,
|
||||||
trace_manager=trace_manager,
|
trace_manager=trace_manager,
|
||||||
)
|
)
|
||||||
workflow_finish_resp = self._workflow_finish_to_stream_response(
|
workflow_finish_resp = self._workflow_cycle_manager._workflow_finish_to_stream_response(
|
||||||
session=session,
|
session=session,
|
||||||
task_id=self._application_generate_entity.task_id,
|
task_id=self._application_generate_entity.task_id,
|
||||||
workflow_run=workflow_run,
|
workflow_run=workflow_run,
|
||||||
@ -561,18 +585,18 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
|
|||||||
yield self._message_end_to_stream_response()
|
yield self._message_end_to_stream_response()
|
||||||
break
|
break
|
||||||
elif isinstance(event, QueueRetrieverResourcesEvent):
|
elif isinstance(event, QueueRetrieverResourcesEvent):
|
||||||
self._handle_retriever_resources(event)
|
self._message_cycle_manager._handle_retriever_resources(event)
|
||||||
|
|
||||||
with Session(db.engine) as session:
|
with Session(db.engine, expire_on_commit=False) as session:
|
||||||
message = self._get_message(session=session)
|
message = self._get_message(session=session)
|
||||||
message.message_metadata = (
|
message.message_metadata = (
|
||||||
json.dumps(jsonable_encoder(self._task_state.metadata)) if self._task_state.metadata else None
|
json.dumps(jsonable_encoder(self._task_state.metadata)) if self._task_state.metadata else None
|
||||||
)
|
)
|
||||||
session.commit()
|
session.commit()
|
||||||
elif isinstance(event, QueueAnnotationReplyEvent):
|
elif isinstance(event, QueueAnnotationReplyEvent):
|
||||||
self._handle_annotation_reply(event)
|
self._message_cycle_manager._handle_annotation_reply(event)
|
||||||
|
|
||||||
with Session(db.engine) as session:
|
with Session(db.engine, expire_on_commit=False) as session:
|
||||||
message = self._get_message(session=session)
|
message = self._get_message(session=session)
|
||||||
message.message_metadata = (
|
message.message_metadata = (
|
||||||
json.dumps(jsonable_encoder(self._task_state.metadata)) if self._task_state.metadata else None
|
json.dumps(jsonable_encoder(self._task_state.metadata)) if self._task_state.metadata else None
|
||||||
@ -593,29 +617,35 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
|
|||||||
tts_publisher.publish(queue_message)
|
tts_publisher.publish(queue_message)
|
||||||
|
|
||||||
self._task_state.answer += delta_text
|
self._task_state.answer += delta_text
|
||||||
yield self._message_to_stream_response(
|
yield self._message_cycle_manager._message_to_stream_response(
|
||||||
answer=delta_text, message_id=self._message_id, from_variable_selector=event.from_variable_selector
|
answer=delta_text, message_id=self._message_id, from_variable_selector=event.from_variable_selector
|
||||||
)
|
)
|
||||||
elif isinstance(event, QueueMessageReplaceEvent):
|
elif isinstance(event, QueueMessageReplaceEvent):
|
||||||
# published by moderation
|
# published by moderation
|
||||||
yield self._message_replace_to_stream_response(answer=event.text)
|
yield self._message_cycle_manager._message_replace_to_stream_response(answer=event.text)
|
||||||
elif isinstance(event, QueueAdvancedChatMessageEndEvent):
|
elif isinstance(event, QueueAdvancedChatMessageEndEvent):
|
||||||
if not graph_runtime_state:
|
if not graph_runtime_state:
|
||||||
raise ValueError("graph runtime state not initialized.")
|
raise ValueError("graph runtime state not initialized.")
|
||||||
|
|
||||||
output_moderation_answer = self._handle_output_moderation_when_task_finished(self._task_state.answer)
|
output_moderation_answer = self._base_task_pipeline._handle_output_moderation_when_task_finished(
|
||||||
|
self._task_state.answer
|
||||||
|
)
|
||||||
if output_moderation_answer:
|
if output_moderation_answer:
|
||||||
self._task_state.answer = output_moderation_answer
|
self._task_state.answer = output_moderation_answer
|
||||||
yield self._message_replace_to_stream_response(answer=output_moderation_answer)
|
yield self._message_cycle_manager._message_replace_to_stream_response(
|
||||||
|
answer=output_moderation_answer
|
||||||
|
)
|
||||||
|
|
||||||
# Save message
|
# Save message
|
||||||
with Session(db.engine) as session:
|
with Session(db.engine, expire_on_commit=False) as session:
|
||||||
self._save_message(session=session, graph_runtime_state=graph_runtime_state)
|
self._save_message(session=session, graph_runtime_state=graph_runtime_state)
|
||||||
session.commit()
|
session.commit()
|
||||||
|
|
||||||
yield self._message_end_to_stream_response()
|
yield self._message_end_to_stream_response()
|
||||||
elif isinstance(event, QueueAgentLogEvent):
|
elif isinstance(event, QueueAgentLogEvent):
|
||||||
yield self._handle_agent_log(task_id=self._application_generate_entity.task_id, event=event)
|
yield self._workflow_cycle_manager._handle_agent_log(
|
||||||
|
task_id=self._application_generate_entity.task_id, event=event
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
@ -629,7 +659,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
|
|||||||
def _save_message(self, *, session: Session, graph_runtime_state: Optional[GraphRuntimeState] = None) -> None:
|
def _save_message(self, *, session: Session, graph_runtime_state: Optional[GraphRuntimeState] = None) -> None:
|
||||||
message = self._get_message(session=session)
|
message = self._get_message(session=session)
|
||||||
message.answer = self._task_state.answer
|
message.answer = self._task_state.answer
|
||||||
message.provider_response_latency = time.perf_counter() - self._start_at
|
message.provider_response_latency = time.perf_counter() - self._base_task_pipeline._start_at
|
||||||
message.message_metadata = (
|
message.message_metadata = (
|
||||||
json.dumps(jsonable_encoder(self._task_state.metadata)) if self._task_state.metadata else None
|
json.dumps(jsonable_encoder(self._task_state.metadata)) if self._task_state.metadata else None
|
||||||
)
|
)
|
||||||
@ -693,20 +723,20 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
|
|||||||
:param text: text
|
:param text: text
|
||||||
:return: True if output moderation should direct output, otherwise False
|
:return: True if output moderation should direct output, otherwise False
|
||||||
"""
|
"""
|
||||||
if self._output_moderation_handler:
|
if self._base_task_pipeline._output_moderation_handler:
|
||||||
if self._output_moderation_handler.should_direct_output():
|
if self._base_task_pipeline._output_moderation_handler.should_direct_output():
|
||||||
# stop subscribe new token when output moderation should direct output
|
# stop subscribe new token when output moderation should direct output
|
||||||
self._task_state.answer = self._output_moderation_handler.get_final_output()
|
self._task_state.answer = self._base_task_pipeline._output_moderation_handler.get_final_output()
|
||||||
self._queue_manager.publish(
|
self._base_task_pipeline._queue_manager.publish(
|
||||||
QueueTextChunkEvent(text=self._task_state.answer), PublishFrom.TASK_PIPELINE
|
QueueTextChunkEvent(text=self._task_state.answer), PublishFrom.TASK_PIPELINE
|
||||||
)
|
)
|
||||||
|
|
||||||
self._queue_manager.publish(
|
self._base_task_pipeline._queue_manager.publish(
|
||||||
QueueStopEvent(stopped_by=QueueStopEvent.StopBy.OUTPUT_MODERATION), PublishFrom.TASK_PIPELINE
|
QueueStopEvent(stopped_by=QueueStopEvent.StopBy.OUTPUT_MODERATION), PublishFrom.TASK_PIPELINE
|
||||||
)
|
)
|
||||||
return True
|
return True
|
||||||
else:
|
else:
|
||||||
self._output_moderation_handler.append_new_token(text)
|
self._base_task_pipeline._output_moderation_handler.append_new_token(text)
|
||||||
|
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|||||||
@ -19,7 +19,7 @@ from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskSt
|
|||||||
from core.app.apps.message_based_app_generator import MessageBasedAppGenerator
|
from core.app.apps.message_based_app_generator import MessageBasedAppGenerator
|
||||||
from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager
|
from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager
|
||||||
from core.app.entities.app_invoke_entities import AgentChatAppGenerateEntity, InvokeFrom
|
from core.app.entities.app_invoke_entities import AgentChatAppGenerateEntity, InvokeFrom
|
||||||
from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
|
from core.model_runtime.errors.invoke import InvokeAuthorizationError
|
||||||
from core.ops.ops_trace_manager import TraceQueueManager
|
from core.ops.ops_trace_manager import TraceQueueManager
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
from factories import file_factory
|
from factories import file_factory
|
||||||
@ -251,7 +251,7 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
|
|||||||
except ValidationError as e:
|
except ValidationError as e:
|
||||||
logger.exception("Validation Error when generating")
|
logger.exception("Validation Error when generating")
|
||||||
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
|
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
|
||||||
except (ValueError, InvokeError) as e:
|
except ValueError as e:
|
||||||
if dify_config.DEBUG:
|
if dify_config.DEBUG:
|
||||||
logger.exception("Error when generating")
|
logger.exception("Error when generating")
|
||||||
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
|
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
|
||||||
|
|||||||
@ -1,6 +1,6 @@
|
|||||||
import logging
|
import logging
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from collections.abc import Generator
|
from collections.abc import Generator, Mapping
|
||||||
from typing import Any, Union
|
from typing import Any, Union
|
||||||
|
|
||||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||||
@ -15,7 +15,7 @@ class AppGenerateResponseConverter(ABC):
|
|||||||
@classmethod
|
@classmethod
|
||||||
def convert(
|
def convert(
|
||||||
cls, response: Union[AppBlockingResponse, Generator[AppStreamResponse, Any, None]], invoke_from: InvokeFrom
|
cls, response: Union[AppBlockingResponse, Generator[AppStreamResponse, Any, None]], invoke_from: InvokeFrom
|
||||||
) -> dict[str, Any] | Generator[str | dict[str, Any], Any, None]:
|
) -> Mapping[str, Any] | Generator[str | Mapping[str, Any], Any, None]:
|
||||||
if invoke_from in {InvokeFrom.DEBUGGER, InvokeFrom.SERVICE_API}:
|
if invoke_from in {InvokeFrom.DEBUGGER, InvokeFrom.SERVICE_API}:
|
||||||
if isinstance(response, AppBlockingResponse):
|
if isinstance(response, AppBlockingResponse):
|
||||||
return cls.convert_blocking_full_response(response)
|
return cls.convert_blocking_full_response(response)
|
||||||
|
|||||||
@ -18,7 +18,7 @@ from core.app.apps.chat.generate_response_converter import ChatAppGenerateRespon
|
|||||||
from core.app.apps.message_based_app_generator import MessageBasedAppGenerator
|
from core.app.apps.message_based_app_generator import MessageBasedAppGenerator
|
||||||
from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager
|
from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager
|
||||||
from core.app.entities.app_invoke_entities import ChatAppGenerateEntity, InvokeFrom
|
from core.app.entities.app_invoke_entities import ChatAppGenerateEntity, InvokeFrom
|
||||||
from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
|
from core.model_runtime.errors.invoke import InvokeAuthorizationError
|
||||||
from core.ops.ops_trace_manager import TraceQueueManager
|
from core.ops.ops_trace_manager import TraceQueueManager
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
from factories import file_factory
|
from factories import file_factory
|
||||||
@ -237,7 +237,7 @@ class ChatAppGenerator(MessageBasedAppGenerator):
|
|||||||
except ValidationError as e:
|
except ValidationError as e:
|
||||||
logger.exception("Validation Error when generating")
|
logger.exception("Validation Error when generating")
|
||||||
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
|
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
|
||||||
except (ValueError, InvokeError) as e:
|
except ValueError as e:
|
||||||
if dify_config.DEBUG:
|
if dify_config.DEBUG:
|
||||||
logger.exception("Error when generating")
|
logger.exception("Error when generating")
|
||||||
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
|
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
|
||||||
|
|||||||
@ -17,7 +17,7 @@ from core.app.apps.completion.generate_response_converter import CompletionAppGe
|
|||||||
from core.app.apps.message_based_app_generator import MessageBasedAppGenerator
|
from core.app.apps.message_based_app_generator import MessageBasedAppGenerator
|
||||||
from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager
|
from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager
|
||||||
from core.app.entities.app_invoke_entities import CompletionAppGenerateEntity, InvokeFrom
|
from core.app.entities.app_invoke_entities import CompletionAppGenerateEntity, InvokeFrom
|
||||||
from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
|
from core.model_runtime.errors.invoke import InvokeAuthorizationError
|
||||||
from core.ops.ops_trace_manager import TraceQueueManager
|
from core.ops.ops_trace_manager import TraceQueueManager
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
from factories import file_factory
|
from factories import file_factory
|
||||||
@ -37,7 +37,7 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
|
|||||||
args: Mapping[str, Any],
|
args: Mapping[str, Any],
|
||||||
invoke_from: InvokeFrom,
|
invoke_from: InvokeFrom,
|
||||||
streaming: Literal[True],
|
streaming: Literal[True],
|
||||||
) -> Generator[str, None, None]: ...
|
) -> Generator[str | Mapping[str, Any], None, None]: ...
|
||||||
|
|
||||||
@overload
|
@overload
|
||||||
def generate(
|
def generate(
|
||||||
@ -57,7 +57,7 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
|
|||||||
args: Mapping[str, Any],
|
args: Mapping[str, Any],
|
||||||
invoke_from: InvokeFrom,
|
invoke_from: InvokeFrom,
|
||||||
streaming: bool = False,
|
streaming: bool = False,
|
||||||
) -> Union[Mapping[str, Any], Generator[str, None, None]]: ...
|
) -> Union[Mapping[str, Any], Generator[str | Mapping[str, Any], None, None]]: ...
|
||||||
|
|
||||||
def generate(
|
def generate(
|
||||||
self,
|
self,
|
||||||
@ -66,7 +66,7 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
|
|||||||
args: Mapping[str, Any],
|
args: Mapping[str, Any],
|
||||||
invoke_from: InvokeFrom,
|
invoke_from: InvokeFrom,
|
||||||
streaming: bool = True,
|
streaming: bool = True,
|
||||||
) -> Union[Mapping[str, Any], Generator[str, None, None]]:
|
) -> Union[Mapping[str, Any], Generator[str | Mapping[str, Any], None, None]]:
|
||||||
"""
|
"""
|
||||||
Generate App response.
|
Generate App response.
|
||||||
|
|
||||||
@ -214,7 +214,7 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
|
|||||||
except ValidationError as e:
|
except ValidationError as e:
|
||||||
logger.exception("Validation Error when generating")
|
logger.exception("Validation Error when generating")
|
||||||
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
|
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
|
||||||
except (ValueError, InvokeError) as e:
|
except ValueError as e:
|
||||||
if dify_config.DEBUG:
|
if dify_config.DEBUG:
|
||||||
logger.exception("Error when generating")
|
logger.exception("Error when generating")
|
||||||
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
|
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
|
||||||
@ -231,7 +231,7 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
|
|||||||
user: Union[Account, EndUser],
|
user: Union[Account, EndUser],
|
||||||
invoke_from: InvokeFrom,
|
invoke_from: InvokeFrom,
|
||||||
stream: bool = True,
|
stream: bool = True,
|
||||||
) -> Union[Mapping, Generator[str, None, None]]:
|
) -> Union[Mapping, Generator[Mapping | str, None, None]]:
|
||||||
"""
|
"""
|
||||||
Generate App response.
|
Generate App response.
|
||||||
|
|
||||||
|
|||||||
@ -20,7 +20,7 @@ from core.app.apps.workflow.generate_response_converter import WorkflowAppGenera
|
|||||||
from core.app.apps.workflow.generate_task_pipeline import WorkflowAppGenerateTaskPipeline
|
from core.app.apps.workflow.generate_task_pipeline import WorkflowAppGenerateTaskPipeline
|
||||||
from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity
|
from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity
|
||||||
from core.app.entities.task_entities import WorkflowAppBlockingResponse, WorkflowAppStreamResponse
|
from core.app.entities.task_entities import WorkflowAppBlockingResponse, WorkflowAppStreamResponse
|
||||||
from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
|
from core.model_runtime.errors.invoke import InvokeAuthorizationError
|
||||||
from core.ops.ops_trace_manager import TraceQueueManager
|
from core.ops.ops_trace_manager import TraceQueueManager
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
from factories import file_factory
|
from factories import file_factory
|
||||||
@ -235,6 +235,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
|||||||
single_iteration_run=WorkflowAppGenerateEntity.SingleIterationRunEntity(
|
single_iteration_run=WorkflowAppGenerateEntity.SingleIterationRunEntity(
|
||||||
node_id=node_id, inputs=args["inputs"]
|
node_id=node_id, inputs=args["inputs"]
|
||||||
),
|
),
|
||||||
|
workflow_run_id=str(uuid.uuid4()),
|
||||||
)
|
)
|
||||||
contexts.tenant_id.set(application_generate_entity.app_config.tenant_id)
|
contexts.tenant_id.set(application_generate_entity.app_config.tenant_id)
|
||||||
contexts.plugin_tool_providers.set({})
|
contexts.plugin_tool_providers.set({})
|
||||||
@ -286,7 +287,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
|||||||
except ValidationError as e:
|
except ValidationError as e:
|
||||||
logger.exception("Validation Error when generating")
|
logger.exception("Validation Error when generating")
|
||||||
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
|
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
|
||||||
except (ValueError, InvokeError) as e:
|
except ValueError as e:
|
||||||
if dify_config.DEBUG:
|
if dify_config.DEBUG:
|
||||||
logger.exception("Error when generating")
|
logger.exception("Error when generating")
|
||||||
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
|
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
|
||||||
|
|||||||
@ -1,7 +1,7 @@
|
|||||||
import logging
|
import logging
|
||||||
import time
|
import time
|
||||||
from collections.abc import Generator
|
from collections.abc import Generator
|
||||||
from typing import Any, Optional, Union
|
from typing import Optional, Union
|
||||||
|
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
@ -59,7 +59,6 @@ from models.workflow import (
|
|||||||
Workflow,
|
Workflow,
|
||||||
WorkflowAppLog,
|
WorkflowAppLog,
|
||||||
WorkflowAppLogCreatedFrom,
|
WorkflowAppLogCreatedFrom,
|
||||||
WorkflowNodeExecution,
|
|
||||||
WorkflowRun,
|
WorkflowRun,
|
||||||
WorkflowRunStatus,
|
WorkflowRunStatus,
|
||||||
)
|
)
|
||||||
@ -67,16 +66,11 @@ from models.workflow import (
|
|||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleManage):
|
class WorkflowAppGenerateTaskPipeline:
|
||||||
"""
|
"""
|
||||||
WorkflowAppGenerateTaskPipeline is a class that generate stream output and state management for Application.
|
WorkflowAppGenerateTaskPipeline is a class that generate stream output and state management for Application.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
_task_state: WorkflowTaskState
|
|
||||||
_application_generate_entity: WorkflowAppGenerateEntity
|
|
||||||
_workflow_system_variables: dict[SystemVariableKey, Any]
|
|
||||||
_wip_workflow_node_executions: dict[str, WorkflowNodeExecution]
|
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
application_generate_entity: WorkflowAppGenerateEntity,
|
application_generate_entity: WorkflowAppGenerateEntity,
|
||||||
@ -85,7 +79,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
|
|||||||
user: Union[Account, EndUser],
|
user: Union[Account, EndUser],
|
||||||
stream: bool,
|
stream: bool,
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__(
|
self._base_task_pipeline = BasedGenerateTaskPipeline(
|
||||||
application_generate_entity=application_generate_entity,
|
application_generate_entity=application_generate_entity,
|
||||||
queue_manager=queue_manager,
|
queue_manager=queue_manager,
|
||||||
stream=stream,
|
stream=stream,
|
||||||
@ -102,17 +96,20 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
|
|||||||
else:
|
else:
|
||||||
raise ValueError(f"Invalid user type: {type(user)}")
|
raise ValueError(f"Invalid user type: {type(user)}")
|
||||||
|
|
||||||
|
self._workflow_cycle_manager = WorkflowCycleManage(
|
||||||
|
application_generate_entity=application_generate_entity,
|
||||||
|
workflow_system_variables={
|
||||||
|
SystemVariableKey.FILES: application_generate_entity.files,
|
||||||
|
SystemVariableKey.USER_ID: user_session_id,
|
||||||
|
SystemVariableKey.APP_ID: application_generate_entity.app_config.app_id,
|
||||||
|
SystemVariableKey.WORKFLOW_ID: workflow.id,
|
||||||
|
SystemVariableKey.WORKFLOW_RUN_ID: application_generate_entity.workflow_run_id,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
self._application_generate_entity = application_generate_entity
|
||||||
self._workflow_id = workflow.id
|
self._workflow_id = workflow.id
|
||||||
self._workflow_features_dict = workflow.features_dict
|
self._workflow_features_dict = workflow.features_dict
|
||||||
|
|
||||||
self._workflow_system_variables = {
|
|
||||||
SystemVariableKey.FILES: application_generate_entity.files,
|
|
||||||
SystemVariableKey.USER_ID: user_session_id,
|
|
||||||
SystemVariableKey.APP_ID: application_generate_entity.app_config.app_id,
|
|
||||||
SystemVariableKey.WORKFLOW_ID: workflow.id,
|
|
||||||
SystemVariableKey.WORKFLOW_RUN_ID: application_generate_entity.workflow_run_id,
|
|
||||||
}
|
|
||||||
|
|
||||||
self._task_state = WorkflowTaskState()
|
self._task_state = WorkflowTaskState()
|
||||||
self._workflow_run_id = ""
|
self._workflow_run_id = ""
|
||||||
|
|
||||||
@ -122,7 +119,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
|
|||||||
:return:
|
:return:
|
||||||
"""
|
"""
|
||||||
generator = self._wrapper_process_stream_response(trace_manager=self._application_generate_entity.trace_manager)
|
generator = self._wrapper_process_stream_response(trace_manager=self._application_generate_entity.trace_manager)
|
||||||
if self._stream:
|
if self._base_task_pipeline._stream:
|
||||||
return self._to_stream_response(generator)
|
return self._to_stream_response(generator)
|
||||||
else:
|
else:
|
||||||
return self._to_blocking_response(generator)
|
return self._to_blocking_response(generator)
|
||||||
@ -239,29 +236,29 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
|
|||||||
"""
|
"""
|
||||||
graph_runtime_state = None
|
graph_runtime_state = None
|
||||||
|
|
||||||
for queue_message in self._queue_manager.listen():
|
for queue_message in self._base_task_pipeline._queue_manager.listen():
|
||||||
event = queue_message.event
|
event = queue_message.event
|
||||||
|
|
||||||
if isinstance(event, QueuePingEvent):
|
if isinstance(event, QueuePingEvent):
|
||||||
yield self._ping_stream_response()
|
yield self._base_task_pipeline._ping_stream_response()
|
||||||
elif isinstance(event, QueueErrorEvent):
|
elif isinstance(event, QueueErrorEvent):
|
||||||
err = self._handle_error(event=event)
|
err = self._base_task_pipeline._handle_error(event=event)
|
||||||
yield self._error_to_stream_response(err)
|
yield self._base_task_pipeline._error_to_stream_response(err)
|
||||||
break
|
break
|
||||||
elif isinstance(event, QueueWorkflowStartedEvent):
|
elif isinstance(event, QueueWorkflowStartedEvent):
|
||||||
# override graph runtime state
|
# override graph runtime state
|
||||||
graph_runtime_state = event.graph_runtime_state
|
graph_runtime_state = event.graph_runtime_state
|
||||||
|
|
||||||
with Session(db.engine) as session:
|
with Session(db.engine, expire_on_commit=False) as session:
|
||||||
# init workflow run
|
# init workflow run
|
||||||
workflow_run = self._handle_workflow_run_start(
|
workflow_run = self._workflow_cycle_manager._handle_workflow_run_start(
|
||||||
session=session,
|
session=session,
|
||||||
workflow_id=self._workflow_id,
|
workflow_id=self._workflow_id,
|
||||||
user_id=self._user_id,
|
user_id=self._user_id,
|
||||||
created_by_role=self._created_by_role,
|
created_by_role=self._created_by_role,
|
||||||
)
|
)
|
||||||
self._workflow_run_id = workflow_run.id
|
self._workflow_run_id = workflow_run.id
|
||||||
start_resp = self._workflow_start_to_stream_response(
|
start_resp = self._workflow_cycle_manager._workflow_start_to_stream_response(
|
||||||
session=session, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
|
session=session, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
|
||||||
)
|
)
|
||||||
session.commit()
|
session.commit()
|
||||||
@ -273,12 +270,14 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
|
|||||||
):
|
):
|
||||||
if not self._workflow_run_id:
|
if not self._workflow_run_id:
|
||||||
raise ValueError("workflow run not initialized.")
|
raise ValueError("workflow run not initialized.")
|
||||||
with Session(db.engine) as session:
|
with Session(db.engine, expire_on_commit=False) as session:
|
||||||
workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id)
|
workflow_run = self._workflow_cycle_manager._get_workflow_run(
|
||||||
workflow_node_execution = self._handle_workflow_node_execution_retried(
|
session=session, workflow_run_id=self._workflow_run_id
|
||||||
|
)
|
||||||
|
workflow_node_execution = self._workflow_cycle_manager._handle_workflow_node_execution_retried(
|
||||||
session=session, workflow_run=workflow_run, event=event
|
session=session, workflow_run=workflow_run, event=event
|
||||||
)
|
)
|
||||||
response = self._workflow_node_retry_to_stream_response(
|
response = self._workflow_cycle_manager._workflow_node_retry_to_stream_response(
|
||||||
session=session,
|
session=session,
|
||||||
event=event,
|
event=event,
|
||||||
task_id=self._application_generate_entity.task_id,
|
task_id=self._application_generate_entity.task_id,
|
||||||
@ -292,12 +291,14 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
|
|||||||
if not self._workflow_run_id:
|
if not self._workflow_run_id:
|
||||||
raise ValueError("workflow run not initialized.")
|
raise ValueError("workflow run not initialized.")
|
||||||
|
|
||||||
with Session(db.engine) as session:
|
with Session(db.engine, expire_on_commit=False) as session:
|
||||||
workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id)
|
workflow_run = self._workflow_cycle_manager._get_workflow_run(
|
||||||
workflow_node_execution = self._handle_node_execution_start(
|
session=session, workflow_run_id=self._workflow_run_id
|
||||||
|
)
|
||||||
|
workflow_node_execution = self._workflow_cycle_manager._handle_node_execution_start(
|
||||||
session=session, workflow_run=workflow_run, event=event
|
session=session, workflow_run=workflow_run, event=event
|
||||||
)
|
)
|
||||||
node_start_response = self._workflow_node_start_to_stream_response(
|
node_start_response = self._workflow_cycle_manager._workflow_node_start_to_stream_response(
|
||||||
session=session,
|
session=session,
|
||||||
event=event,
|
event=event,
|
||||||
task_id=self._application_generate_entity.task_id,
|
task_id=self._application_generate_entity.task_id,
|
||||||
@ -308,9 +309,11 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
|
|||||||
if node_start_response:
|
if node_start_response:
|
||||||
yield node_start_response
|
yield node_start_response
|
||||||
elif isinstance(event, QueueNodeSucceededEvent):
|
elif isinstance(event, QueueNodeSucceededEvent):
|
||||||
with Session(db.engine) as session:
|
with Session(db.engine, expire_on_commit=False) as session:
|
||||||
workflow_node_execution = self._handle_workflow_node_execution_success(session=session, event=event)
|
workflow_node_execution = self._workflow_cycle_manager._handle_workflow_node_execution_success(
|
||||||
node_success_response = self._workflow_node_finish_to_stream_response(
|
session=session, event=event
|
||||||
|
)
|
||||||
|
node_success_response = self._workflow_cycle_manager._workflow_node_finish_to_stream_response(
|
||||||
session=session,
|
session=session,
|
||||||
event=event,
|
event=event,
|
||||||
task_id=self._application_generate_entity.task_id,
|
task_id=self._application_generate_entity.task_id,
|
||||||
@ -321,12 +324,12 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
|
|||||||
if node_success_response:
|
if node_success_response:
|
||||||
yield node_success_response
|
yield node_success_response
|
||||||
elif isinstance(event, QueueNodeFailedEvent | QueueNodeInIterationFailedEvent | QueueNodeExceptionEvent):
|
elif isinstance(event, QueueNodeFailedEvent | QueueNodeInIterationFailedEvent | QueueNodeExceptionEvent):
|
||||||
with Session(db.engine) as session:
|
with Session(db.engine, expire_on_commit=False) as session:
|
||||||
workflow_node_execution = self._handle_workflow_node_execution_failed(
|
workflow_node_execution = self._workflow_cycle_manager._handle_workflow_node_execution_failed(
|
||||||
session=session,
|
session=session,
|
||||||
event=event,
|
event=event,
|
||||||
)
|
)
|
||||||
node_failed_response = self._workflow_node_finish_to_stream_response(
|
node_failed_response = self._workflow_cycle_manager._workflow_node_finish_to_stream_response(
|
||||||
session=session,
|
session=session,
|
||||||
event=event,
|
event=event,
|
||||||
task_id=self._application_generate_entity.task_id,
|
task_id=self._application_generate_entity.task_id,
|
||||||
@ -341,13 +344,17 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
|
|||||||
if not self._workflow_run_id:
|
if not self._workflow_run_id:
|
||||||
raise ValueError("workflow run not initialized.")
|
raise ValueError("workflow run not initialized.")
|
||||||
|
|
||||||
with Session(db.engine) as session:
|
with Session(db.engine, expire_on_commit=False) as session:
|
||||||
workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id)
|
workflow_run = self._workflow_cycle_manager._get_workflow_run(
|
||||||
parallel_start_resp = self._workflow_parallel_branch_start_to_stream_response(
|
session=session, workflow_run_id=self._workflow_run_id
|
||||||
session=session,
|
)
|
||||||
task_id=self._application_generate_entity.task_id,
|
parallel_start_resp = (
|
||||||
workflow_run=workflow_run,
|
self._workflow_cycle_manager._workflow_parallel_branch_start_to_stream_response(
|
||||||
event=event,
|
session=session,
|
||||||
|
task_id=self._application_generate_entity.task_id,
|
||||||
|
workflow_run=workflow_run,
|
||||||
|
event=event,
|
||||||
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
yield parallel_start_resp
|
yield parallel_start_resp
|
||||||
@ -356,13 +363,17 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
|
|||||||
if not self._workflow_run_id:
|
if not self._workflow_run_id:
|
||||||
raise ValueError("workflow run not initialized.")
|
raise ValueError("workflow run not initialized.")
|
||||||
|
|
||||||
with Session(db.engine) as session:
|
with Session(db.engine, expire_on_commit=False) as session:
|
||||||
workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id)
|
workflow_run = self._workflow_cycle_manager._get_workflow_run(
|
||||||
parallel_finish_resp = self._workflow_parallel_branch_finished_to_stream_response(
|
session=session, workflow_run_id=self._workflow_run_id
|
||||||
session=session,
|
)
|
||||||
task_id=self._application_generate_entity.task_id,
|
parallel_finish_resp = (
|
||||||
workflow_run=workflow_run,
|
self._workflow_cycle_manager._workflow_parallel_branch_finished_to_stream_response(
|
||||||
event=event,
|
session=session,
|
||||||
|
task_id=self._application_generate_entity.task_id,
|
||||||
|
workflow_run=workflow_run,
|
||||||
|
event=event,
|
||||||
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
yield parallel_finish_resp
|
yield parallel_finish_resp
|
||||||
@ -371,9 +382,11 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
|
|||||||
if not self._workflow_run_id:
|
if not self._workflow_run_id:
|
||||||
raise ValueError("workflow run not initialized.")
|
raise ValueError("workflow run not initialized.")
|
||||||
|
|
||||||
with Session(db.engine) as session:
|
with Session(db.engine, expire_on_commit=False) as session:
|
||||||
workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id)
|
workflow_run = self._workflow_cycle_manager._get_workflow_run(
|
||||||
iter_start_resp = self._workflow_iteration_start_to_stream_response(
|
session=session, workflow_run_id=self._workflow_run_id
|
||||||
|
)
|
||||||
|
iter_start_resp = self._workflow_cycle_manager._workflow_iteration_start_to_stream_response(
|
||||||
session=session,
|
session=session,
|
||||||
task_id=self._application_generate_entity.task_id,
|
task_id=self._application_generate_entity.task_id,
|
||||||
workflow_run=workflow_run,
|
workflow_run=workflow_run,
|
||||||
@ -386,9 +399,11 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
|
|||||||
if not self._workflow_run_id:
|
if not self._workflow_run_id:
|
||||||
raise ValueError("workflow run not initialized.")
|
raise ValueError("workflow run not initialized.")
|
||||||
|
|
||||||
with Session(db.engine) as session:
|
with Session(db.engine, expire_on_commit=False) as session:
|
||||||
workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id)
|
workflow_run = self._workflow_cycle_manager._get_workflow_run(
|
||||||
iter_next_resp = self._workflow_iteration_next_to_stream_response(
|
session=session, workflow_run_id=self._workflow_run_id
|
||||||
|
)
|
||||||
|
iter_next_resp = self._workflow_cycle_manager._workflow_iteration_next_to_stream_response(
|
||||||
session=session,
|
session=session,
|
||||||
task_id=self._application_generate_entity.task_id,
|
task_id=self._application_generate_entity.task_id,
|
||||||
workflow_run=workflow_run,
|
workflow_run=workflow_run,
|
||||||
@ -401,9 +416,11 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
|
|||||||
if not self._workflow_run_id:
|
if not self._workflow_run_id:
|
||||||
raise ValueError("workflow run not initialized.")
|
raise ValueError("workflow run not initialized.")
|
||||||
|
|
||||||
with Session(db.engine) as session:
|
with Session(db.engine, expire_on_commit=False) as session:
|
||||||
workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id)
|
workflow_run = self._workflow_cycle_manager._get_workflow_run(
|
||||||
iter_finish_resp = self._workflow_iteration_completed_to_stream_response(
|
session=session, workflow_run_id=self._workflow_run_id
|
||||||
|
)
|
||||||
|
iter_finish_resp = self._workflow_cycle_manager._workflow_iteration_completed_to_stream_response(
|
||||||
session=session,
|
session=session,
|
||||||
task_id=self._application_generate_entity.task_id,
|
task_id=self._application_generate_entity.task_id,
|
||||||
workflow_run=workflow_run,
|
workflow_run=workflow_run,
|
||||||
@ -418,8 +435,8 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
|
|||||||
if not graph_runtime_state:
|
if not graph_runtime_state:
|
||||||
raise ValueError("graph runtime state not initialized.")
|
raise ValueError("graph runtime state not initialized.")
|
||||||
|
|
||||||
with Session(db.engine) as session:
|
with Session(db.engine, expire_on_commit=False) as session:
|
||||||
workflow_run = self._handle_workflow_run_success(
|
workflow_run = self._workflow_cycle_manager._handle_workflow_run_success(
|
||||||
session=session,
|
session=session,
|
||||||
workflow_run_id=self._workflow_run_id,
|
workflow_run_id=self._workflow_run_id,
|
||||||
start_at=graph_runtime_state.start_at,
|
start_at=graph_runtime_state.start_at,
|
||||||
@ -433,7 +450,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
|
|||||||
# save workflow app log
|
# save workflow app log
|
||||||
self._save_workflow_app_log(session=session, workflow_run=workflow_run)
|
self._save_workflow_app_log(session=session, workflow_run=workflow_run)
|
||||||
|
|
||||||
workflow_finish_resp = self._workflow_finish_to_stream_response(
|
workflow_finish_resp = self._workflow_cycle_manager._workflow_finish_to_stream_response(
|
||||||
session=session,
|
session=session,
|
||||||
task_id=self._application_generate_entity.task_id,
|
task_id=self._application_generate_entity.task_id,
|
||||||
workflow_run=workflow_run,
|
workflow_run=workflow_run,
|
||||||
@ -447,8 +464,8 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
|
|||||||
if not graph_runtime_state:
|
if not graph_runtime_state:
|
||||||
raise ValueError("graph runtime state not initialized.")
|
raise ValueError("graph runtime state not initialized.")
|
||||||
|
|
||||||
with Session(db.engine) as session:
|
with Session(db.engine, expire_on_commit=False) as session:
|
||||||
workflow_run = self._handle_workflow_run_partial_success(
|
workflow_run = self._workflow_cycle_manager._handle_workflow_run_partial_success(
|
||||||
session=session,
|
session=session,
|
||||||
workflow_run_id=self._workflow_run_id,
|
workflow_run_id=self._workflow_run_id,
|
||||||
start_at=graph_runtime_state.start_at,
|
start_at=graph_runtime_state.start_at,
|
||||||
@ -463,7 +480,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
|
|||||||
# save workflow app log
|
# save workflow app log
|
||||||
self._save_workflow_app_log(session=session, workflow_run=workflow_run)
|
self._save_workflow_app_log(session=session, workflow_run=workflow_run)
|
||||||
|
|
||||||
workflow_finish_resp = self._workflow_finish_to_stream_response(
|
workflow_finish_resp = self._workflow_cycle_manager._workflow_finish_to_stream_response(
|
||||||
session=session, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
|
session=session, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
|
||||||
)
|
)
|
||||||
session.commit()
|
session.commit()
|
||||||
@ -475,8 +492,8 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
|
|||||||
if not graph_runtime_state:
|
if not graph_runtime_state:
|
||||||
raise ValueError("graph runtime state not initialized.")
|
raise ValueError("graph runtime state not initialized.")
|
||||||
|
|
||||||
with Session(db.engine) as session:
|
with Session(db.engine, expire_on_commit=False) as session:
|
||||||
workflow_run = self._handle_workflow_run_failed(
|
workflow_run = self._workflow_cycle_manager._handle_workflow_run_failed(
|
||||||
session=session,
|
session=session,
|
||||||
workflow_run_id=self._workflow_run_id,
|
workflow_run_id=self._workflow_run_id,
|
||||||
start_at=graph_runtime_state.start_at,
|
start_at=graph_runtime_state.start_at,
|
||||||
@ -494,7 +511,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
|
|||||||
# save workflow app log
|
# save workflow app log
|
||||||
self._save_workflow_app_log(session=session, workflow_run=workflow_run)
|
self._save_workflow_app_log(session=session, workflow_run=workflow_run)
|
||||||
|
|
||||||
workflow_finish_resp = self._workflow_finish_to_stream_response(
|
workflow_finish_resp = self._workflow_cycle_manager._workflow_finish_to_stream_response(
|
||||||
session=session, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
|
session=session, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
|
||||||
)
|
)
|
||||||
session.commit()
|
session.commit()
|
||||||
@ -514,7 +531,9 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
|
|||||||
delta_text, from_variable_selector=event.from_variable_selector
|
delta_text, from_variable_selector=event.from_variable_selector
|
||||||
)
|
)
|
||||||
elif isinstance(event, QueueAgentLogEvent):
|
elif isinstance(event, QueueAgentLogEvent):
|
||||||
yield self._handle_agent_log(task_id=self._application_generate_entity.task_id, event=event)
|
yield self._workflow_cycle_manager._handle_agent_log(
|
||||||
|
task_id=self._application_generate_entity.task_id, event=event
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
|||||||
@ -195,7 +195,7 @@ class WorkflowAppGenerateEntity(AppGenerateEntity):
|
|||||||
|
|
||||||
# app config
|
# app config
|
||||||
app_config: WorkflowUIBasedAppConfig
|
app_config: WorkflowUIBasedAppConfig
|
||||||
workflow_run_id: Optional[str] = None
|
workflow_run_id: str
|
||||||
|
|
||||||
class SingleIterationRunEntity(BaseModel):
|
class SingleIterationRunEntity(BaseModel):
|
||||||
"""
|
"""
|
||||||
|
|||||||
@ -329,6 +329,7 @@ class QueueAgentLogEvent(AppQueueEvent):
|
|||||||
error: str | None
|
error: str | None
|
||||||
status: str
|
status: str
|
||||||
data: Mapping[str, Any]
|
data: Mapping[str, Any]
|
||||||
|
metadata: Optional[Mapping[str, Any]] = None
|
||||||
|
|
||||||
|
|
||||||
class QueueNodeRetryEvent(QueueNodeStartedEvent):
|
class QueueNodeRetryEvent(QueueNodeStartedEvent):
|
||||||
|
|||||||
@ -716,6 +716,7 @@ class AgentLogStreamResponse(StreamResponse):
|
|||||||
error: str | None
|
error: str | None
|
||||||
status: str
|
status: str
|
||||||
data: Mapping[str, Any]
|
data: Mapping[str, Any]
|
||||||
|
metadata: Optional[Mapping[str, Any]] = None
|
||||||
|
|
||||||
event: StreamEvent = StreamEvent.AGENT_LOG
|
event: StreamEvent = StreamEvent.AGENT_LOG
|
||||||
data: Data
|
data: Data
|
||||||
|
|||||||
@ -15,7 +15,6 @@ from core.app.entities.queue_entities import (
|
|||||||
from core.app.entities.task_entities import (
|
from core.app.entities.task_entities import (
|
||||||
ErrorStreamResponse,
|
ErrorStreamResponse,
|
||||||
PingStreamResponse,
|
PingStreamResponse,
|
||||||
TaskState,
|
|
||||||
)
|
)
|
||||||
from core.errors.error import QuotaExceededError
|
from core.errors.error import QuotaExceededError
|
||||||
from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
|
from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
|
||||||
@ -30,22 +29,12 @@ class BasedGenerateTaskPipeline:
|
|||||||
BasedGenerateTaskPipeline is a class that generate stream output and state management for Application.
|
BasedGenerateTaskPipeline is a class that generate stream output and state management for Application.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
_task_state: TaskState
|
|
||||||
_application_generate_entity: AppGenerateEntity
|
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
application_generate_entity: AppGenerateEntity,
|
application_generate_entity: AppGenerateEntity,
|
||||||
queue_manager: AppQueueManager,
|
queue_manager: AppQueueManager,
|
||||||
stream: bool,
|
stream: bool,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""
|
|
||||||
Initialize GenerateTaskPipeline.
|
|
||||||
:param application_generate_entity: application generate entity
|
|
||||||
:param queue_manager: queue manager
|
|
||||||
:param user: user
|
|
||||||
:param stream: stream
|
|
||||||
"""
|
|
||||||
self._application_generate_entity = application_generate_entity
|
self._application_generate_entity = application_generate_entity
|
||||||
self._queue_manager = queue_manager
|
self._queue_manager = queue_manager
|
||||||
self._start_at = time.perf_counter()
|
self._start_at = time.perf_counter()
|
||||||
|
|||||||
@ -31,10 +31,19 @@ from services.annotation_service import AppAnnotationService
|
|||||||
|
|
||||||
|
|
||||||
class MessageCycleManage:
|
class MessageCycleManage:
|
||||||
_application_generate_entity: Union[
|
def __init__(
|
||||||
ChatAppGenerateEntity, CompletionAppGenerateEntity, AgentChatAppGenerateEntity, AdvancedChatAppGenerateEntity
|
self,
|
||||||
]
|
*,
|
||||||
_task_state: Union[EasyUITaskState, WorkflowTaskState]
|
application_generate_entity: Union[
|
||||||
|
ChatAppGenerateEntity,
|
||||||
|
CompletionAppGenerateEntity,
|
||||||
|
AgentChatAppGenerateEntity,
|
||||||
|
AdvancedChatAppGenerateEntity,
|
||||||
|
],
|
||||||
|
task_state: Union[EasyUITaskState, WorkflowTaskState],
|
||||||
|
) -> None:
|
||||||
|
self._application_generate_entity = application_generate_entity
|
||||||
|
self._task_state = task_state
|
||||||
|
|
||||||
def _generate_conversation_name(self, *, conversation_id: str, query: str) -> Optional[Thread]:
|
def _generate_conversation_name(self, *, conversation_id: str, query: str) -> Optional[Thread]:
|
||||||
"""
|
"""
|
||||||
|
|||||||
@ -36,7 +36,6 @@ from core.app.entities.task_entities import (
|
|||||||
ParallelBranchStartStreamResponse,
|
ParallelBranchStartStreamResponse,
|
||||||
WorkflowFinishStreamResponse,
|
WorkflowFinishStreamResponse,
|
||||||
WorkflowStartStreamResponse,
|
WorkflowStartStreamResponse,
|
||||||
WorkflowTaskState,
|
|
||||||
)
|
)
|
||||||
from core.file import FILE_MODEL_IDENTITY, File
|
from core.file import FILE_MODEL_IDENTITY, File
|
||||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||||
@ -60,13 +59,20 @@ from models.workflow import (
|
|||||||
WorkflowRunStatus,
|
WorkflowRunStatus,
|
||||||
)
|
)
|
||||||
|
|
||||||
from .exc import WorkflowNodeExecutionNotFoundError, WorkflowRunNotFoundError
|
from .exc import WorkflowRunNotFoundError
|
||||||
|
|
||||||
|
|
||||||
class WorkflowCycleManage:
|
class WorkflowCycleManage:
|
||||||
_application_generate_entity: Union[AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity]
|
def __init__(
|
||||||
_task_state: WorkflowTaskState
|
self,
|
||||||
_workflow_system_variables: dict[SystemVariableKey, Any]
|
*,
|
||||||
|
application_generate_entity: Union[AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity],
|
||||||
|
workflow_system_variables: dict[SystemVariableKey, Any],
|
||||||
|
) -> None:
|
||||||
|
self._workflow_run: WorkflowRun | None = None
|
||||||
|
self._workflow_node_executions: dict[str, WorkflowNodeExecution] = {}
|
||||||
|
self._application_generate_entity = application_generate_entity
|
||||||
|
self._workflow_system_variables = workflow_system_variables
|
||||||
|
|
||||||
def _handle_workflow_run_start(
|
def _handle_workflow_run_start(
|
||||||
self,
|
self,
|
||||||
@ -104,7 +110,8 @@ class WorkflowCycleManage:
|
|||||||
inputs = dict(WorkflowEntry.handle_special_values(inputs) or {})
|
inputs = dict(WorkflowEntry.handle_special_values(inputs) or {})
|
||||||
|
|
||||||
# init workflow run
|
# init workflow run
|
||||||
workflow_run_id = str(self._workflow_system_variables.get(SystemVariableKey.WORKFLOW_RUN_ID, uuid4()))
|
# TODO: This workflow_run_id should always not be None, maybe we can use a more elegant way to handle this
|
||||||
|
workflow_run_id = str(self._workflow_system_variables.get(SystemVariableKey.WORKFLOW_RUN_ID) or uuid4())
|
||||||
|
|
||||||
workflow_run = WorkflowRun()
|
workflow_run = WorkflowRun()
|
||||||
workflow_run.id = workflow_run_id
|
workflow_run.id = workflow_run_id
|
||||||
@ -241,7 +248,7 @@ class WorkflowCycleManage:
|
|||||||
workflow_run.finished_at = datetime.now(UTC).replace(tzinfo=None)
|
workflow_run.finished_at = datetime.now(UTC).replace(tzinfo=None)
|
||||||
workflow_run.exceptions_count = exceptions_count
|
workflow_run.exceptions_count = exceptions_count
|
||||||
|
|
||||||
stmt = select(WorkflowNodeExecution).where(
|
stmt = select(WorkflowNodeExecution.node_execution_id).where(
|
||||||
WorkflowNodeExecution.tenant_id == workflow_run.tenant_id,
|
WorkflowNodeExecution.tenant_id == workflow_run.tenant_id,
|
||||||
WorkflowNodeExecution.app_id == workflow_run.app_id,
|
WorkflowNodeExecution.app_id == workflow_run.app_id,
|
||||||
WorkflowNodeExecution.workflow_id == workflow_run.workflow_id,
|
WorkflowNodeExecution.workflow_id == workflow_run.workflow_id,
|
||||||
@ -249,15 +256,18 @@ class WorkflowCycleManage:
|
|||||||
WorkflowNodeExecution.workflow_run_id == workflow_run.id,
|
WorkflowNodeExecution.workflow_run_id == workflow_run.id,
|
||||||
WorkflowNodeExecution.status == WorkflowNodeExecutionStatus.RUNNING.value,
|
WorkflowNodeExecution.status == WorkflowNodeExecutionStatus.RUNNING.value,
|
||||||
)
|
)
|
||||||
|
ids = session.scalars(stmt).all()
|
||||||
running_workflow_node_executions = session.scalars(stmt).all()
|
# Use self._get_workflow_node_execution here to make sure the cache is updated
|
||||||
|
running_workflow_node_executions = [
|
||||||
|
self._get_workflow_node_execution(session=session, node_execution_id=id) for id in ids if id
|
||||||
|
]
|
||||||
|
|
||||||
for workflow_node_execution in running_workflow_node_executions:
|
for workflow_node_execution in running_workflow_node_executions:
|
||||||
|
now = datetime.now(UTC).replace(tzinfo=None)
|
||||||
workflow_node_execution.status = WorkflowNodeExecutionStatus.FAILED.value
|
workflow_node_execution.status = WorkflowNodeExecutionStatus.FAILED.value
|
||||||
workflow_node_execution.error = error
|
workflow_node_execution.error = error
|
||||||
finish_at = datetime.now(UTC).replace(tzinfo=None)
|
workflow_node_execution.finished_at = now
|
||||||
workflow_node_execution.finished_at = finish_at
|
workflow_node_execution.elapsed_time = (now - workflow_node_execution.created_at).total_seconds()
|
||||||
workflow_node_execution.elapsed_time = (finish_at - workflow_node_execution.created_at).total_seconds()
|
|
||||||
|
|
||||||
if trace_manager:
|
if trace_manager:
|
||||||
trace_manager.add_trace_task(
|
trace_manager.add_trace_task(
|
||||||
@ -299,6 +309,8 @@ class WorkflowCycleManage:
|
|||||||
workflow_node_execution.created_at = datetime.now(UTC).replace(tzinfo=None)
|
workflow_node_execution.created_at = datetime.now(UTC).replace(tzinfo=None)
|
||||||
|
|
||||||
session.add(workflow_node_execution)
|
session.add(workflow_node_execution)
|
||||||
|
|
||||||
|
self._workflow_node_executions[event.node_execution_id] = workflow_node_execution
|
||||||
return workflow_node_execution
|
return workflow_node_execution
|
||||||
|
|
||||||
def _handle_workflow_node_execution_success(
|
def _handle_workflow_node_execution_success(
|
||||||
@ -325,6 +337,7 @@ class WorkflowCycleManage:
|
|||||||
workflow_node_execution.finished_at = finished_at
|
workflow_node_execution.finished_at = finished_at
|
||||||
workflow_node_execution.elapsed_time = elapsed_time
|
workflow_node_execution.elapsed_time = elapsed_time
|
||||||
|
|
||||||
|
workflow_node_execution = session.merge(workflow_node_execution)
|
||||||
return workflow_node_execution
|
return workflow_node_execution
|
||||||
|
|
||||||
def _handle_workflow_node_execution_failed(
|
def _handle_workflow_node_execution_failed(
|
||||||
@ -364,6 +377,7 @@ class WorkflowCycleManage:
|
|||||||
workflow_node_execution.elapsed_time = elapsed_time
|
workflow_node_execution.elapsed_time = elapsed_time
|
||||||
workflow_node_execution.execution_metadata = execution_metadata
|
workflow_node_execution.execution_metadata = execution_metadata
|
||||||
|
|
||||||
|
workflow_node_execution = session.merge(workflow_node_execution)
|
||||||
return workflow_node_execution
|
return workflow_node_execution
|
||||||
|
|
||||||
def _handle_workflow_node_execution_retried(
|
def _handle_workflow_node_execution_retried(
|
||||||
@ -415,6 +429,8 @@ class WorkflowCycleManage:
|
|||||||
workflow_node_execution.index = event.node_run_index
|
workflow_node_execution.index = event.node_run_index
|
||||||
|
|
||||||
session.add(workflow_node_execution)
|
session.add(workflow_node_execution)
|
||||||
|
|
||||||
|
self._workflow_node_executions[event.node_execution_id] = workflow_node_execution
|
||||||
return workflow_node_execution
|
return workflow_node_execution
|
||||||
|
|
||||||
#################################################
|
#################################################
|
||||||
@ -811,25 +827,23 @@ class WorkflowCycleManage:
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
def _get_workflow_run(self, *, session: Session, workflow_run_id: str) -> WorkflowRun:
|
def _get_workflow_run(self, *, session: Session, workflow_run_id: str) -> WorkflowRun:
|
||||||
"""
|
if self._workflow_run and self._workflow_run.id == workflow_run_id:
|
||||||
Refetch workflow run
|
cached_workflow_run = self._workflow_run
|
||||||
:param workflow_run_id: workflow run id
|
cached_workflow_run = session.merge(cached_workflow_run)
|
||||||
:return:
|
return cached_workflow_run
|
||||||
"""
|
|
||||||
stmt = select(WorkflowRun).where(WorkflowRun.id == workflow_run_id)
|
stmt = select(WorkflowRun).where(WorkflowRun.id == workflow_run_id)
|
||||||
workflow_run = session.scalar(stmt)
|
workflow_run = session.scalar(stmt)
|
||||||
if not workflow_run:
|
if not workflow_run:
|
||||||
raise WorkflowRunNotFoundError(workflow_run_id)
|
raise WorkflowRunNotFoundError(workflow_run_id)
|
||||||
|
self._workflow_run = workflow_run
|
||||||
|
|
||||||
return workflow_run
|
return workflow_run
|
||||||
|
|
||||||
def _get_workflow_node_execution(self, session: Session, node_execution_id: str) -> WorkflowNodeExecution:
|
def _get_workflow_node_execution(self, session: Session, node_execution_id: str) -> WorkflowNodeExecution:
|
||||||
stmt = select(WorkflowNodeExecution).where(WorkflowNodeExecution.node_execution_id == node_execution_id)
|
if node_execution_id not in self._workflow_node_executions:
|
||||||
workflow_node_execution = session.scalar(stmt)
|
raise ValueError(f"Workflow node execution not found: {node_execution_id}")
|
||||||
if not workflow_node_execution:
|
cached_workflow_node_execution = self._workflow_node_executions[node_execution_id]
|
||||||
raise WorkflowNodeExecutionNotFoundError(node_execution_id)
|
return cached_workflow_node_execution
|
||||||
|
|
||||||
return workflow_node_execution
|
|
||||||
|
|
||||||
def _handle_agent_log(self, task_id: str, event: QueueAgentLogEvent) -> AgentLogStreamResponse:
|
def _handle_agent_log(self, task_id: str, event: QueueAgentLogEvent) -> AgentLogStreamResponse:
|
||||||
"""
|
"""
|
||||||
@ -848,5 +862,6 @@ class WorkflowCycleManage:
|
|||||||
error=event.error,
|
error=event.error,
|
||||||
status=event.status,
|
status=event.status,
|
||||||
data=event.data,
|
data=event.data,
|
||||||
|
metadata=event.metadata,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|||||||
@ -178,7 +178,7 @@ class ModelInstance:
|
|||||||
|
|
||||||
def get_llm_num_tokens(
|
def get_llm_num_tokens(
|
||||||
self, prompt_messages: list[PromptMessage], tools: Optional[list[PromptMessageTool]] = None
|
self, prompt_messages: list[PromptMessage], tools: Optional[list[PromptMessageTool]] = None
|
||||||
) -> list[int]:
|
) -> int:
|
||||||
"""
|
"""
|
||||||
Get number of tokens for llm
|
Get number of tokens for llm
|
||||||
|
|
||||||
@ -191,7 +191,7 @@ class ModelInstance:
|
|||||||
|
|
||||||
self.model_type_instance = cast(LargeLanguageModel, self.model_type_instance)
|
self.model_type_instance = cast(LargeLanguageModel, self.model_type_instance)
|
||||||
return cast(
|
return cast(
|
||||||
list[int],
|
int,
|
||||||
self._round_robin_invoke(
|
self._round_robin_invoke(
|
||||||
function=self.model_type_instance.get_num_tokens,
|
function=self.model_type_instance.get_num_tokens,
|
||||||
model=self.model,
|
model=self.model,
|
||||||
|
|||||||
@ -1,9 +1,47 @@
|
|||||||
import tiktoken
|
from threading import Lock
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
_tokenizer: Any = None
|
||||||
|
_lock = Lock()
|
||||||
|
|
||||||
|
|
||||||
class GPT2Tokenizer:
|
class GPT2Tokenizer:
|
||||||
|
@staticmethod
|
||||||
|
def _get_num_tokens_by_gpt2(text: str) -> int:
|
||||||
|
"""
|
||||||
|
use gpt2 tokenizer to get num tokens
|
||||||
|
"""
|
||||||
|
_tokenizer = GPT2Tokenizer.get_encoder()
|
||||||
|
tokens = _tokenizer.encode(text)
|
||||||
|
return len(tokens)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_num_tokens(text: str) -> int:
|
def get_num_tokens(text: str) -> int:
|
||||||
encoding = tiktoken.encoding_for_model("gpt2")
|
# Because this process needs more cpu resource, we turn this back before we find a better way to handle it.
|
||||||
tiktoken_vec = encoding.encode(text)
|
#
|
||||||
return len(tiktoken_vec)
|
# future = _executor.submit(GPT2Tokenizer._get_num_tokens_by_gpt2, text)
|
||||||
|
# result = future.result()
|
||||||
|
# return cast(int, result)
|
||||||
|
return GPT2Tokenizer._get_num_tokens_by_gpt2(text)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_encoder() -> Any:
|
||||||
|
global _tokenizer, _lock
|
||||||
|
with _lock:
|
||||||
|
if _tokenizer is None:
|
||||||
|
# Try to use tiktoken to get the tokenizer because it is faster
|
||||||
|
#
|
||||||
|
try:
|
||||||
|
import tiktoken
|
||||||
|
|
||||||
|
_tokenizer = tiktoken.get_encoding("gpt2")
|
||||||
|
except Exception:
|
||||||
|
from os.path import abspath, dirname, join
|
||||||
|
|
||||||
|
from transformers import GPT2Tokenizer as TransformerGPT2Tokenizer # type: ignore
|
||||||
|
|
||||||
|
base_path = abspath(__file__)
|
||||||
|
gpt2_tokenizer_path = join(dirname(base_path), "gpt2")
|
||||||
|
_tokenizer = TransformerGPT2Tokenizer.from_pretrained(gpt2_tokenizer_path)
|
||||||
|
|
||||||
|
return _tokenizer
|
||||||
|
|||||||
@ -119,7 +119,7 @@ class PluginAppBackwardsInvocation(BaseBackwardsInvocation):
|
|||||||
stream: bool,
|
stream: bool,
|
||||||
inputs: Mapping,
|
inputs: Mapping,
|
||||||
files: list[dict],
|
files: list[dict],
|
||||||
):
|
) -> Generator[Mapping | str, None, None] | Mapping:
|
||||||
"""
|
"""
|
||||||
invoke workflow app
|
invoke workflow app
|
||||||
"""
|
"""
|
||||||
@ -146,7 +146,7 @@ class PluginAppBackwardsInvocation(BaseBackwardsInvocation):
|
|||||||
stream: bool,
|
stream: bool,
|
||||||
inputs: Mapping,
|
inputs: Mapping,
|
||||||
files: list[dict],
|
files: list[dict],
|
||||||
):
|
) -> Generator[Mapping | str, None, None] | Mapping:
|
||||||
"""
|
"""
|
||||||
invoke completion app
|
invoke completion app
|
||||||
"""
|
"""
|
||||||
|
|||||||
@ -268,7 +268,7 @@ Here is the extra instruction you need to follow:
|
|||||||
return summary.message.content
|
return summary.message.content
|
||||||
|
|
||||||
lines = content.split("\n")
|
lines = content.split("\n")
|
||||||
new_lines = []
|
new_lines: list[str] = []
|
||||||
# split long line into multiple lines
|
# split long line into multiple lines
|
||||||
for i in range(len(lines)):
|
for i in range(len(lines)):
|
||||||
line = lines[i]
|
line = lines[i]
|
||||||
@ -286,16 +286,16 @@ Here is the extra instruction you need to follow:
|
|||||||
|
|
||||||
# merge lines into messages with max tokens
|
# merge lines into messages with max tokens
|
||||||
messages: list[str] = []
|
messages: list[str] = []
|
||||||
for i in new_lines:
|
for i in new_lines: # type: ignore
|
||||||
if len(messages) == 0:
|
if len(messages) == 0:
|
||||||
messages.append(i)
|
messages.append(i) # type: ignore
|
||||||
else:
|
else:
|
||||||
if len(messages[-1]) + len(i) < max_tokens * 0.5:
|
if len(messages[-1]) + len(i) < max_tokens * 0.5: # type: ignore
|
||||||
messages[-1] += i
|
messages[-1] += i # type: ignore
|
||||||
if get_prompt_tokens(messages[-1] + i) > max_tokens * 0.7:
|
if get_prompt_tokens(messages[-1] + i) > max_tokens * 0.7: # type: ignore
|
||||||
messages.append(i)
|
messages.append(i) # type: ignore
|
||||||
else:
|
else:
|
||||||
messages[-1] += i
|
messages[-1] += i # type: ignore
|
||||||
|
|
||||||
summaries = []
|
summaries = []
|
||||||
for i in range(len(messages)):
|
for i in range(len(messages)):
|
||||||
|
|||||||
@ -103,7 +103,7 @@ class BasePluginManager:
|
|||||||
Make a stream request to the plugin daemon inner API and yield the response as a model.
|
Make a stream request to the plugin daemon inner API and yield the response as a model.
|
||||||
"""
|
"""
|
||||||
for line in self._stream_request(method, path, params, headers, data, files):
|
for line in self._stream_request(method, path, params, headers, data, files):
|
||||||
yield type(**json.loads(line))
|
yield type(**json.loads(line)) # type: ignore
|
||||||
|
|
||||||
def _request_with_model(
|
def _request_with_model(
|
||||||
self,
|
self,
|
||||||
|
|||||||
@ -113,6 +113,8 @@ class BaiduVector(BaseVector):
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
def delete_by_ids(self, ids: list[str]) -> None:
|
def delete_by_ids(self, ids: list[str]) -> None:
|
||||||
|
if not ids:
|
||||||
|
return
|
||||||
quoted_ids = [f"'{id}'" for id in ids]
|
quoted_ids = [f"'{id}'" for id in ids]
|
||||||
self._db.table(self._collection_name).delete(filter=f"id IN({', '.join(quoted_ids)})")
|
self._db.table(self._collection_name).delete(filter=f"id IN({', '.join(quoted_ids)})")
|
||||||
|
|
||||||
|
|||||||
@ -83,6 +83,8 @@ class ChromaVector(BaseVector):
|
|||||||
self._client.delete_collection(self._collection_name)
|
self._client.delete_collection(self._collection_name)
|
||||||
|
|
||||||
def delete_by_ids(self, ids: list[str]) -> None:
|
def delete_by_ids(self, ids: list[str]) -> None:
|
||||||
|
if not ids:
|
||||||
|
return
|
||||||
collection = self._client.get_or_create_collection(self._collection_name)
|
collection = self._client.get_or_create_collection(self._collection_name)
|
||||||
collection.delete(ids=ids)
|
collection.delete(ids=ids)
|
||||||
|
|
||||||
|
|||||||
@ -0,0 +1,104 @@
|
|||||||
|
import json
|
||||||
|
import logging
|
||||||
|
from typing import Any, Optional
|
||||||
|
|
||||||
|
from flask import current_app
|
||||||
|
|
||||||
|
from core.rag.datasource.vdb.elasticsearch.elasticsearch_vector import (
|
||||||
|
ElasticSearchConfig,
|
||||||
|
ElasticSearchVector,
|
||||||
|
ElasticSearchVectorFactory,
|
||||||
|
)
|
||||||
|
from core.rag.datasource.vdb.field import Field
|
||||||
|
from core.rag.datasource.vdb.vector_type import VectorType
|
||||||
|
from core.rag.embedding.embedding_base import Embeddings
|
||||||
|
from extensions.ext_redis import redis_client
|
||||||
|
from models.dataset import Dataset
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class ElasticSearchJaVector(ElasticSearchVector):
|
||||||
|
def create_collection(
|
||||||
|
self,
|
||||||
|
embeddings: list[list[float]],
|
||||||
|
metadatas: Optional[list[dict[Any, Any]]] = None,
|
||||||
|
index_params: Optional[dict] = None,
|
||||||
|
):
|
||||||
|
lock_name = f"vector_indexing_lock_{self._collection_name}"
|
||||||
|
with redis_client.lock(lock_name, timeout=20):
|
||||||
|
collection_exist_cache_key = f"vector_indexing_{self._collection_name}"
|
||||||
|
if redis_client.get(collection_exist_cache_key):
|
||||||
|
logger.info(f"Collection {self._collection_name} already exists.")
|
||||||
|
return
|
||||||
|
|
||||||
|
if not self._client.indices.exists(index=self._collection_name):
|
||||||
|
dim = len(embeddings[0])
|
||||||
|
settings = {
|
||||||
|
"analysis": {
|
||||||
|
"analyzer": {
|
||||||
|
"ja_analyzer": {
|
||||||
|
"type": "custom",
|
||||||
|
"char_filter": [
|
||||||
|
"icu_normalizer",
|
||||||
|
"kuromoji_iteration_mark",
|
||||||
|
],
|
||||||
|
"tokenizer": "kuromoji_tokenizer",
|
||||||
|
"filter": [
|
||||||
|
"kuromoji_baseform",
|
||||||
|
"kuromoji_part_of_speech",
|
||||||
|
"ja_stop",
|
||||||
|
"kuromoji_number",
|
||||||
|
"kuromoji_stemmer",
|
||||||
|
],
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
mappings = {
|
||||||
|
"properties": {
|
||||||
|
Field.CONTENT_KEY.value: {
|
||||||
|
"type": "text",
|
||||||
|
"analyzer": "ja_analyzer",
|
||||||
|
"search_analyzer": "ja_analyzer",
|
||||||
|
},
|
||||||
|
Field.VECTOR.value: { # Make sure the dimension is correct here
|
||||||
|
"type": "dense_vector",
|
||||||
|
"dims": dim,
|
||||||
|
"index": True,
|
||||||
|
"similarity": "cosine",
|
||||||
|
},
|
||||||
|
Field.METADATA_KEY.value: {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"doc_id": {"type": "keyword"} # Map doc_id to keyword type
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
self._client.indices.create(index=self._collection_name, settings=settings, mappings=mappings)
|
||||||
|
|
||||||
|
redis_client.set(collection_exist_cache_key, 1, ex=3600)
|
||||||
|
|
||||||
|
|
||||||
|
class ElasticSearchJaVectorFactory(ElasticSearchVectorFactory):
|
||||||
|
def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> ElasticSearchJaVector:
|
||||||
|
if dataset.index_struct_dict:
|
||||||
|
class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"]
|
||||||
|
collection_name = class_prefix
|
||||||
|
else:
|
||||||
|
dataset_id = dataset.id
|
||||||
|
collection_name = Dataset.gen_collection_name_by_id(dataset_id)
|
||||||
|
dataset.index_struct = json.dumps(self.gen_index_struct_dict(VectorType.ELASTICSEARCH, collection_name))
|
||||||
|
|
||||||
|
config = current_app.config
|
||||||
|
return ElasticSearchJaVector(
|
||||||
|
index_name=collection_name,
|
||||||
|
config=ElasticSearchConfig(
|
||||||
|
host=config.get("ELASTICSEARCH_HOST", "localhost"),
|
||||||
|
port=config.get("ELASTICSEARCH_PORT", 9200),
|
||||||
|
username=config.get("ELASTICSEARCH_USERNAME", ""),
|
||||||
|
password=config.get("ELASTICSEARCH_PASSWORD", ""),
|
||||||
|
),
|
||||||
|
attributes=[],
|
||||||
|
)
|
||||||
@ -98,6 +98,8 @@ class ElasticSearchVector(BaseVector):
|
|||||||
return bool(self._client.exists(index=self._collection_name, id=id))
|
return bool(self._client.exists(index=self._collection_name, id=id))
|
||||||
|
|
||||||
def delete_by_ids(self, ids: list[str]) -> None:
|
def delete_by_ids(self, ids: list[str]) -> None:
|
||||||
|
if not ids:
|
||||||
|
return
|
||||||
for id in ids:
|
for id in ids:
|
||||||
self._client.delete(index=self._collection_name, id=id)
|
self._client.delete(index=self._collection_name, id=id)
|
||||||
|
|
||||||
|
|||||||
@ -6,6 +6,8 @@ class Field(Enum):
|
|||||||
METADATA_KEY = "metadata"
|
METADATA_KEY = "metadata"
|
||||||
GROUP_KEY = "group_id"
|
GROUP_KEY = "group_id"
|
||||||
VECTOR = "vector"
|
VECTOR = "vector"
|
||||||
|
# Sparse Vector aims to support full text search
|
||||||
|
SPARSE_VECTOR = "sparse_vector"
|
||||||
TEXT_KEY = "text"
|
TEXT_KEY = "text"
|
||||||
PRIMARY_KEY = "id"
|
PRIMARY_KEY = "id"
|
||||||
DOC_ID = "metadata.doc_id"
|
DOC_ID = "metadata.doc_id"
|
||||||
|
|||||||
@ -2,6 +2,7 @@ import json
|
|||||||
import logging
|
import logging
|
||||||
from typing import Any, Optional
|
from typing import Any, Optional
|
||||||
|
|
||||||
|
from packaging import version
|
||||||
from pydantic import BaseModel, model_validator
|
from pydantic import BaseModel, model_validator
|
||||||
from pymilvus import MilvusClient, MilvusException # type: ignore
|
from pymilvus import MilvusClient, MilvusException # type: ignore
|
||||||
from pymilvus.milvus_client import IndexParams # type: ignore
|
from pymilvus.milvus_client import IndexParams # type: ignore
|
||||||
@ -20,16 +21,25 @@ logger = logging.getLogger(__name__)
|
|||||||
|
|
||||||
|
|
||||||
class MilvusConfig(BaseModel):
|
class MilvusConfig(BaseModel):
|
||||||
uri: str
|
"""
|
||||||
token: Optional[str] = None
|
Configuration class for Milvus connection.
|
||||||
user: str
|
"""
|
||||||
password: str
|
|
||||||
batch_size: int = 100
|
uri: str # Milvus server URI
|
||||||
database: str = "default"
|
token: Optional[str] = None # Optional token for authentication
|
||||||
|
user: str # Username for authentication
|
||||||
|
password: str # Password for authentication
|
||||||
|
batch_size: int = 100 # Batch size for operations
|
||||||
|
database: str = "default" # Database name
|
||||||
|
enable_hybrid_search: bool = False # Flag to enable hybrid search
|
||||||
|
|
||||||
@model_validator(mode="before")
|
@model_validator(mode="before")
|
||||||
@classmethod
|
@classmethod
|
||||||
def validate_config(cls, values: dict) -> dict:
|
def validate_config(cls, values: dict) -> dict:
|
||||||
|
"""
|
||||||
|
Validate the configuration values.
|
||||||
|
Raises ValueError if required fields are missing.
|
||||||
|
"""
|
||||||
if not values.get("uri"):
|
if not values.get("uri"):
|
||||||
raise ValueError("config MILVUS_URI is required")
|
raise ValueError("config MILVUS_URI is required")
|
||||||
if not values.get("user"):
|
if not values.get("user"):
|
||||||
@ -39,6 +49,9 @@ class MilvusConfig(BaseModel):
|
|||||||
return values
|
return values
|
||||||
|
|
||||||
def to_milvus_params(self):
|
def to_milvus_params(self):
|
||||||
|
"""
|
||||||
|
Convert the configuration to a dictionary of Milvus connection parameters.
|
||||||
|
"""
|
||||||
return {
|
return {
|
||||||
"uri": self.uri,
|
"uri": self.uri,
|
||||||
"token": self.token,
|
"token": self.token,
|
||||||
@ -49,26 +62,57 @@ class MilvusConfig(BaseModel):
|
|||||||
|
|
||||||
|
|
||||||
class MilvusVector(BaseVector):
|
class MilvusVector(BaseVector):
|
||||||
|
"""
|
||||||
|
Milvus vector storage implementation.
|
||||||
|
"""
|
||||||
|
|
||||||
def __init__(self, collection_name: str, config: MilvusConfig):
|
def __init__(self, collection_name: str, config: MilvusConfig):
|
||||||
super().__init__(collection_name)
|
super().__init__(collection_name)
|
||||||
self._client_config = config
|
self._client_config = config
|
||||||
self._client = self._init_client(config)
|
self._client = self._init_client(config)
|
||||||
self._consistency_level = "Session"
|
self._consistency_level = "Session" # Consistency level for Milvus operations
|
||||||
self._fields: list[str] = []
|
self._fields: list[str] = [] # List of fields in the collection
|
||||||
|
self._hybrid_search_enabled = self._check_hybrid_search_support() # Check if hybrid search is supported
|
||||||
|
|
||||||
|
def _check_hybrid_search_support(self) -> bool:
|
||||||
|
"""
|
||||||
|
Check if the current Milvus version supports hybrid search.
|
||||||
|
Returns True if the version is >= 2.5.0, otherwise False.
|
||||||
|
"""
|
||||||
|
if not self._client_config.enable_hybrid_search:
|
||||||
|
return False
|
||||||
|
|
||||||
|
try:
|
||||||
|
milvus_version = self._client.get_server_version()
|
||||||
|
return version.parse(milvus_version).base_version >= version.parse("2.5.0").base_version
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Failed to check Milvus version: {str(e)}. Disabling hybrid search.")
|
||||||
|
return False
|
||||||
|
|
||||||
def get_type(self) -> str:
|
def get_type(self) -> str:
|
||||||
|
"""
|
||||||
|
Get the type of vector storage (Milvus).
|
||||||
|
"""
|
||||||
return VectorType.MILVUS
|
return VectorType.MILVUS
|
||||||
|
|
||||||
def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
|
def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
|
||||||
|
"""
|
||||||
|
Create a collection and add texts with embeddings.
|
||||||
|
"""
|
||||||
index_params = {"metric_type": "IP", "index_type": "HNSW", "params": {"M": 8, "efConstruction": 64}}
|
index_params = {"metric_type": "IP", "index_type": "HNSW", "params": {"M": 8, "efConstruction": 64}}
|
||||||
metadatas = [d.metadata if d.metadata is not None else {} for d in texts]
|
metadatas = [d.metadata if d.metadata is not None else {} for d in texts]
|
||||||
self.create_collection(embeddings, metadatas, index_params)
|
self.create_collection(embeddings, metadatas, index_params)
|
||||||
self.add_texts(texts, embeddings)
|
self.add_texts(texts, embeddings)
|
||||||
|
|
||||||
def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
|
def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
|
||||||
|
"""
|
||||||
|
Add texts and their embeddings to the collection.
|
||||||
|
"""
|
||||||
insert_dict_list = []
|
insert_dict_list = []
|
||||||
for i in range(len(documents)):
|
for i in range(len(documents)):
|
||||||
insert_dict = {
|
insert_dict = {
|
||||||
|
# Do not need to insert the sparse_vector field separately, as the text_bm25_emb
|
||||||
|
# function will automatically convert the native text into a sparse vector for us.
|
||||||
Field.CONTENT_KEY.value: documents[i].page_content,
|
Field.CONTENT_KEY.value: documents[i].page_content,
|
||||||
Field.VECTOR.value: embeddings[i],
|
Field.VECTOR.value: embeddings[i],
|
||||||
Field.METADATA_KEY.value: documents[i].metadata,
|
Field.METADATA_KEY.value: documents[i].metadata,
|
||||||
@ -76,12 +120,11 @@ class MilvusVector(BaseVector):
|
|||||||
insert_dict_list.append(insert_dict)
|
insert_dict_list.append(insert_dict)
|
||||||
# Total insert count
|
# Total insert count
|
||||||
total_count = len(insert_dict_list)
|
total_count = len(insert_dict_list)
|
||||||
|
|
||||||
pks: list[str] = []
|
pks: list[str] = []
|
||||||
|
|
||||||
for i in range(0, total_count, 1000):
|
for i in range(0, total_count, 1000):
|
||||||
batch_insert_list = insert_dict_list[i : i + 1000]
|
|
||||||
# Insert into the collection.
|
# Insert into the collection.
|
||||||
|
batch_insert_list = insert_dict_list[i : i + 1000]
|
||||||
try:
|
try:
|
||||||
ids = self._client.insert(collection_name=self._collection_name, data=batch_insert_list)
|
ids = self._client.insert(collection_name=self._collection_name, data=batch_insert_list)
|
||||||
pks.extend(ids)
|
pks.extend(ids)
|
||||||
@ -91,6 +134,9 @@ class MilvusVector(BaseVector):
|
|||||||
return pks
|
return pks
|
||||||
|
|
||||||
def get_ids_by_metadata_field(self, key: str, value: str):
|
def get_ids_by_metadata_field(self, key: str, value: str):
|
||||||
|
"""
|
||||||
|
Get document IDs by metadata field key and value.
|
||||||
|
"""
|
||||||
result = self._client.query(
|
result = self._client.query(
|
||||||
collection_name=self._collection_name, filter=f'metadata["{key}"] == "{value}"', output_fields=["id"]
|
collection_name=self._collection_name, filter=f'metadata["{key}"] == "{value}"', output_fields=["id"]
|
||||||
)
|
)
|
||||||
@ -100,12 +146,18 @@ class MilvusVector(BaseVector):
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
def delete_by_metadata_field(self, key: str, value: str):
|
def delete_by_metadata_field(self, key: str, value: str):
|
||||||
|
"""
|
||||||
|
Delete documents by metadata field key and value.
|
||||||
|
"""
|
||||||
if self._client.has_collection(self._collection_name):
|
if self._client.has_collection(self._collection_name):
|
||||||
ids = self.get_ids_by_metadata_field(key, value)
|
ids = self.get_ids_by_metadata_field(key, value)
|
||||||
if ids:
|
if ids:
|
||||||
self._client.delete(collection_name=self._collection_name, pks=ids)
|
self._client.delete(collection_name=self._collection_name, pks=ids)
|
||||||
|
|
||||||
def delete_by_ids(self, ids: list[str]) -> None:
|
def delete_by_ids(self, ids: list[str]) -> None:
|
||||||
|
"""
|
||||||
|
Delete documents by their IDs.
|
||||||
|
"""
|
||||||
if self._client.has_collection(self._collection_name):
|
if self._client.has_collection(self._collection_name):
|
||||||
result = self._client.query(
|
result = self._client.query(
|
||||||
collection_name=self._collection_name, filter=f'metadata["doc_id"] in {ids}', output_fields=["id"]
|
collection_name=self._collection_name, filter=f'metadata["doc_id"] in {ids}', output_fields=["id"]
|
||||||
@ -115,10 +167,16 @@ class MilvusVector(BaseVector):
|
|||||||
self._client.delete(collection_name=self._collection_name, pks=ids)
|
self._client.delete(collection_name=self._collection_name, pks=ids)
|
||||||
|
|
||||||
def delete(self) -> None:
|
def delete(self) -> None:
|
||||||
|
"""
|
||||||
|
Delete the entire collection.
|
||||||
|
"""
|
||||||
if self._client.has_collection(self._collection_name):
|
if self._client.has_collection(self._collection_name):
|
||||||
self._client.drop_collection(self._collection_name, None)
|
self._client.drop_collection(self._collection_name, None)
|
||||||
|
|
||||||
def text_exists(self, id: str) -> bool:
|
def text_exists(self, id: str) -> bool:
|
||||||
|
"""
|
||||||
|
Check if a text with the given ID exists in the collection.
|
||||||
|
"""
|
||||||
if not self._client.has_collection(self._collection_name):
|
if not self._client.has_collection(self._collection_name):
|
||||||
return False
|
return False
|
||||||
|
|
||||||
@ -128,32 +186,80 @@ class MilvusVector(BaseVector):
|
|||||||
|
|
||||||
return len(result) > 0
|
return len(result) > 0
|
||||||
|
|
||||||
|
def field_exists(self, field: str) -> bool:
|
||||||
|
"""
|
||||||
|
Check if a field exists in the collection.
|
||||||
|
"""
|
||||||
|
return field in self._fields
|
||||||
|
|
||||||
|
def _process_search_results(
|
||||||
|
self, results: list[Any], output_fields: list[str], score_threshold: float = 0.0
|
||||||
|
) -> list[Document]:
|
||||||
|
"""
|
||||||
|
Common method to process search results
|
||||||
|
|
||||||
|
:param results: Search results
|
||||||
|
:param output_fields: Fields to be output
|
||||||
|
:param score_threshold: Score threshold for filtering
|
||||||
|
:return: List of documents
|
||||||
|
"""
|
||||||
|
docs = []
|
||||||
|
for result in results[0]:
|
||||||
|
metadata = result["entity"].get(output_fields[1], {})
|
||||||
|
metadata["score"] = result["distance"]
|
||||||
|
|
||||||
|
if result["distance"] > score_threshold:
|
||||||
|
doc = Document(page_content=result["entity"].get(output_fields[0], ""), metadata=metadata)
|
||||||
|
docs.append(doc)
|
||||||
|
|
||||||
|
return docs
|
||||||
|
|
||||||
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
|
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
|
||||||
# Set search parameters.
|
"""
|
||||||
|
Search for documents by vector similarity.
|
||||||
|
"""
|
||||||
results = self._client.search(
|
results = self._client.search(
|
||||||
collection_name=self._collection_name,
|
collection_name=self._collection_name,
|
||||||
data=[query_vector],
|
data=[query_vector],
|
||||||
|
anns_field=Field.VECTOR.value,
|
||||||
limit=kwargs.get("top_k", 4),
|
limit=kwargs.get("top_k", 4),
|
||||||
output_fields=[Field.CONTENT_KEY.value, Field.METADATA_KEY.value],
|
output_fields=[Field.CONTENT_KEY.value, Field.METADATA_KEY.value],
|
||||||
)
|
)
|
||||||
# Organize results.
|
|
||||||
docs = []
|
return self._process_search_results(
|
||||||
for result in results[0]:
|
results,
|
||||||
metadata = result["entity"].get(Field.METADATA_KEY.value)
|
output_fields=[Field.CONTENT_KEY.value, Field.METADATA_KEY.value],
|
||||||
metadata["score"] = result["distance"]
|
score_threshold=float(kwargs.get("score_threshold") or 0.0),
|
||||||
score_threshold = float(kwargs.get("score_threshold") or 0.0)
|
)
|
||||||
if result["distance"] > score_threshold:
|
|
||||||
doc = Document(page_content=result["entity"].get(Field.CONTENT_KEY.value), metadata=metadata)
|
|
||||||
docs.append(doc)
|
|
||||||
return docs
|
|
||||||
|
|
||||||
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
|
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
|
||||||
# milvus/zilliz doesn't support bm25 search
|
"""
|
||||||
return []
|
Search for documents by full-text search (if hybrid search is enabled).
|
||||||
|
"""
|
||||||
|
if not self._hybrid_search_enabled or not self.field_exists(Field.SPARSE_VECTOR.value):
|
||||||
|
logger.warning("Full-text search is not supported in current Milvus version (requires >= 2.5.0)")
|
||||||
|
return []
|
||||||
|
|
||||||
|
results = self._client.search(
|
||||||
|
collection_name=self._collection_name,
|
||||||
|
data=[query],
|
||||||
|
anns_field=Field.SPARSE_VECTOR.value,
|
||||||
|
limit=kwargs.get("top_k", 4),
|
||||||
|
output_fields=[Field.CONTENT_KEY.value, Field.METADATA_KEY.value],
|
||||||
|
)
|
||||||
|
|
||||||
|
return self._process_search_results(
|
||||||
|
results,
|
||||||
|
output_fields=[Field.CONTENT_KEY.value, Field.METADATA_KEY.value],
|
||||||
|
score_threshold=float(kwargs.get("score_threshold") or 0.0),
|
||||||
|
)
|
||||||
|
|
||||||
def create_collection(
|
def create_collection(
|
||||||
self, embeddings: list, metadatas: Optional[list[dict]] = None, index_params: Optional[dict] = None
|
self, embeddings: list, metadatas: Optional[list[dict]] = None, index_params: Optional[dict] = None
|
||||||
):
|
):
|
||||||
|
"""
|
||||||
|
Create a new collection in Milvus with the specified schema and index parameters.
|
||||||
|
"""
|
||||||
lock_name = "vector_indexing_lock_{}".format(self._collection_name)
|
lock_name = "vector_indexing_lock_{}".format(self._collection_name)
|
||||||
with redis_client.lock(lock_name, timeout=20):
|
with redis_client.lock(lock_name, timeout=20):
|
||||||
collection_exist_cache_key = "vector_indexing_{}".format(self._collection_name)
|
collection_exist_cache_key = "vector_indexing_{}".format(self._collection_name)
|
||||||
@ -161,7 +267,7 @@ class MilvusVector(BaseVector):
|
|||||||
return
|
return
|
||||||
# Grab the existing collection if it exists
|
# Grab the existing collection if it exists
|
||||||
if not self._client.has_collection(self._collection_name):
|
if not self._client.has_collection(self._collection_name):
|
||||||
from pymilvus import CollectionSchema, DataType, FieldSchema # type: ignore
|
from pymilvus import CollectionSchema, DataType, FieldSchema, Function, FunctionType # type: ignore
|
||||||
from pymilvus.orm.types import infer_dtype_bydata # type: ignore
|
from pymilvus.orm.types import infer_dtype_bydata # type: ignore
|
||||||
|
|
||||||
# Determine embedding dim
|
# Determine embedding dim
|
||||||
@ -170,16 +276,36 @@ class MilvusVector(BaseVector):
|
|||||||
if metadatas:
|
if metadatas:
|
||||||
fields.append(FieldSchema(Field.METADATA_KEY.value, DataType.JSON, max_length=65_535))
|
fields.append(FieldSchema(Field.METADATA_KEY.value, DataType.JSON, max_length=65_535))
|
||||||
|
|
||||||
# Create the text field
|
# Create the text field, enable_analyzer will be set True to support milvus automatically
|
||||||
fields.append(FieldSchema(Field.CONTENT_KEY.value, DataType.VARCHAR, max_length=65_535))
|
# transfer text to sparse_vector, reference: https://milvus.io/docs/full-text-search.md
|
||||||
|
fields.append(
|
||||||
|
FieldSchema(
|
||||||
|
Field.CONTENT_KEY.value,
|
||||||
|
DataType.VARCHAR,
|
||||||
|
max_length=65_535,
|
||||||
|
enable_analyzer=self._hybrid_search_enabled,
|
||||||
|
)
|
||||||
|
)
|
||||||
# Create the primary key field
|
# Create the primary key field
|
||||||
fields.append(FieldSchema(Field.PRIMARY_KEY.value, DataType.INT64, is_primary=True, auto_id=True))
|
fields.append(FieldSchema(Field.PRIMARY_KEY.value, DataType.INT64, is_primary=True, auto_id=True))
|
||||||
# Create the vector field, supports binary or float vectors
|
# Create the vector field, supports binary or float vectors
|
||||||
fields.append(FieldSchema(Field.VECTOR.value, infer_dtype_bydata(embeddings[0]), dim=dim))
|
fields.append(FieldSchema(Field.VECTOR.value, infer_dtype_bydata(embeddings[0]), dim=dim))
|
||||||
|
# Create Sparse Vector Index for the collection
|
||||||
|
if self._hybrid_search_enabled:
|
||||||
|
fields.append(FieldSchema(Field.SPARSE_VECTOR.value, DataType.SPARSE_FLOAT_VECTOR))
|
||||||
|
|
||||||
# Create the schema for the collection
|
|
||||||
schema = CollectionSchema(fields)
|
schema = CollectionSchema(fields)
|
||||||
|
|
||||||
|
# Create custom function to support text to sparse vector by BM25
|
||||||
|
if self._hybrid_search_enabled:
|
||||||
|
bm25_function = Function(
|
||||||
|
name="text_bm25_emb",
|
||||||
|
input_field_names=[Field.CONTENT_KEY.value],
|
||||||
|
output_field_names=[Field.SPARSE_VECTOR.value],
|
||||||
|
function_type=FunctionType.BM25,
|
||||||
|
)
|
||||||
|
schema.add_function(bm25_function)
|
||||||
|
|
||||||
for x in schema.fields:
|
for x in schema.fields:
|
||||||
self._fields.append(x.name)
|
self._fields.append(x.name)
|
||||||
# Since primary field is auto-id, no need to track it
|
# Since primary field is auto-id, no need to track it
|
||||||
@ -189,10 +315,15 @@ class MilvusVector(BaseVector):
|
|||||||
index_params_obj = IndexParams()
|
index_params_obj = IndexParams()
|
||||||
index_params_obj.add_index(field_name=Field.VECTOR.value, **index_params)
|
index_params_obj.add_index(field_name=Field.VECTOR.value, **index_params)
|
||||||
|
|
||||||
|
# Create Sparse Vector Index for the collection
|
||||||
|
if self._hybrid_search_enabled:
|
||||||
|
index_params_obj.add_index(
|
||||||
|
field_name=Field.SPARSE_VECTOR.value, index_type="AUTOINDEX", metric_type="BM25"
|
||||||
|
)
|
||||||
|
|
||||||
# Create the collection
|
# Create the collection
|
||||||
collection_name = self._collection_name
|
|
||||||
self._client.create_collection(
|
self._client.create_collection(
|
||||||
collection_name=collection_name,
|
collection_name=self._collection_name,
|
||||||
schema=schema,
|
schema=schema,
|
||||||
index_params=index_params_obj,
|
index_params=index_params_obj,
|
||||||
consistency_level=self._consistency_level,
|
consistency_level=self._consistency_level,
|
||||||
@ -200,12 +331,22 @@ class MilvusVector(BaseVector):
|
|||||||
redis_client.set(collection_exist_cache_key, 1, ex=3600)
|
redis_client.set(collection_exist_cache_key, 1, ex=3600)
|
||||||
|
|
||||||
def _init_client(self, config) -> MilvusClient:
|
def _init_client(self, config) -> MilvusClient:
|
||||||
|
"""
|
||||||
|
Initialize and return a Milvus client.
|
||||||
|
"""
|
||||||
client = MilvusClient(uri=config.uri, user=config.user, password=config.password, db_name=config.database)
|
client = MilvusClient(uri=config.uri, user=config.user, password=config.password, db_name=config.database)
|
||||||
return client
|
return client
|
||||||
|
|
||||||
|
|
||||||
class MilvusVectorFactory(AbstractVectorFactory):
|
class MilvusVectorFactory(AbstractVectorFactory):
|
||||||
|
"""
|
||||||
|
Factory class for creating MilvusVector instances.
|
||||||
|
"""
|
||||||
|
|
||||||
def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> MilvusVector:
|
def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> MilvusVector:
|
||||||
|
"""
|
||||||
|
Initialize a MilvusVector instance for the given dataset.
|
||||||
|
"""
|
||||||
if dataset.index_struct_dict:
|
if dataset.index_struct_dict:
|
||||||
class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"]
|
class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"]
|
||||||
collection_name = class_prefix
|
collection_name = class_prefix
|
||||||
@ -222,5 +363,6 @@ class MilvusVectorFactory(AbstractVectorFactory):
|
|||||||
user=dify_config.MILVUS_USER or "",
|
user=dify_config.MILVUS_USER or "",
|
||||||
password=dify_config.MILVUS_PASSWORD or "",
|
password=dify_config.MILVUS_PASSWORD or "",
|
||||||
database=dify_config.MILVUS_DATABASE or "",
|
database=dify_config.MILVUS_DATABASE or "",
|
||||||
|
enable_hybrid_search=dify_config.MILVUS_ENABLE_HYBRID_SEARCH or False,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|||||||
@ -100,6 +100,8 @@ class MyScaleVector(BaseVector):
|
|||||||
return results.row_count > 0
|
return results.row_count > 0
|
||||||
|
|
||||||
def delete_by_ids(self, ids: list[str]) -> None:
|
def delete_by_ids(self, ids: list[str]) -> None:
|
||||||
|
if not ids:
|
||||||
|
return
|
||||||
self._client.command(
|
self._client.command(
|
||||||
f"DELETE FROM {self._config.database}.{self._collection_name} WHERE id IN {str(tuple(ids))}"
|
f"DELETE FROM {self._config.database}.{self._collection_name} WHERE id IN {str(tuple(ids))}"
|
||||||
)
|
)
|
||||||
|
|||||||
@ -134,6 +134,8 @@ class OceanBaseVector(BaseVector):
|
|||||||
return bool(cur.rowcount != 0)
|
return bool(cur.rowcount != 0)
|
||||||
|
|
||||||
def delete_by_ids(self, ids: list[str]) -> None:
|
def delete_by_ids(self, ids: list[str]) -> None:
|
||||||
|
if not ids:
|
||||||
|
return
|
||||||
self._client.delete(table_name=self._collection_name, ids=ids)
|
self._client.delete(table_name=self._collection_name, ids=ids)
|
||||||
|
|
||||||
def get_ids_by_metadata_field(self, key: str, value: str) -> list[str]:
|
def get_ids_by_metadata_field(self, key: str, value: str) -> list[str]:
|
||||||
|
|||||||
@ -167,6 +167,8 @@ class OracleVector(BaseVector):
|
|||||||
return docs
|
return docs
|
||||||
|
|
||||||
def delete_by_ids(self, ids: list[str]) -> None:
|
def delete_by_ids(self, ids: list[str]) -> None:
|
||||||
|
if not ids:
|
||||||
|
return
|
||||||
with self._get_cursor() as cur:
|
with self._get_cursor() as cur:
|
||||||
cur.execute(f"DELETE FROM {self.table_name} WHERE id IN %s" % (tuple(ids),))
|
cur.execute(f"DELETE FROM {self.table_name} WHERE id IN %s" % (tuple(ids),))
|
||||||
|
|
||||||
|
|||||||
@ -129,6 +129,11 @@ class PGVector(BaseVector):
|
|||||||
return docs
|
return docs
|
||||||
|
|
||||||
def delete_by_ids(self, ids: list[str]) -> None:
|
def delete_by_ids(self, ids: list[str]) -> None:
|
||||||
|
# Avoiding crashes caused by performing delete operations on empty lists in certain scenarios
|
||||||
|
# Scenario 1: extract a document fails, resulting in a table not being created.
|
||||||
|
# Then clicking the retry button triggers a delete operation on an empty list.
|
||||||
|
if not ids:
|
||||||
|
return
|
||||||
with self._get_cursor() as cur:
|
with self._get_cursor() as cur:
|
||||||
cur.execute(f"DELETE FROM {self.table_name} WHERE id IN %s", (tuple(ids),))
|
cur.execute(f"DELETE FROM {self.table_name} WHERE id IN %s", (tuple(ids),))
|
||||||
|
|
||||||
|
|||||||
@ -140,6 +140,8 @@ class TencentVector(BaseVector):
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
def delete_by_ids(self, ids: list[str]) -> None:
|
def delete_by_ids(self, ids: list[str]) -> None:
|
||||||
|
if not ids:
|
||||||
|
return
|
||||||
self._db.collection(self._collection_name).delete(document_ids=ids)
|
self._db.collection(self._collection_name).delete(document_ids=ids)
|
||||||
|
|
||||||
def delete_by_metadata_field(self, key: str, value: str) -> None:
|
def delete_by_metadata_field(self, key: str, value: str) -> None:
|
||||||
|
|||||||
@ -409,27 +409,27 @@ class TidbOnQdrantVectorFactory(AbstractVectorFactory):
|
|||||||
db.session.query(TidbAuthBinding).filter(TidbAuthBinding.tenant_id == dataset.tenant_id).one_or_none()
|
db.session.query(TidbAuthBinding).filter(TidbAuthBinding.tenant_id == dataset.tenant_id).one_or_none()
|
||||||
)
|
)
|
||||||
if not tidb_auth_binding:
|
if not tidb_auth_binding:
|
||||||
idle_tidb_auth_binding = (
|
with redis_client.lock("create_tidb_serverless_cluster_lock", timeout=900):
|
||||||
db.session.query(TidbAuthBinding)
|
tidb_auth_binding = (
|
||||||
.filter(TidbAuthBinding.active == False, TidbAuthBinding.status == "ACTIVE")
|
db.session.query(TidbAuthBinding)
|
||||||
.limit(1)
|
.filter(TidbAuthBinding.tenant_id == dataset.tenant_id)
|
||||||
.one_or_none()
|
.one_or_none()
|
||||||
)
|
)
|
||||||
if idle_tidb_auth_binding:
|
if tidb_auth_binding:
|
||||||
idle_tidb_auth_binding.active = True
|
TIDB_ON_QDRANT_API_KEY = f"{tidb_auth_binding.account}:{tidb_auth_binding.password}"
|
||||||
idle_tidb_auth_binding.tenant_id = dataset.tenant_id
|
|
||||||
db.session.commit()
|
else:
|
||||||
TIDB_ON_QDRANT_API_KEY = f"{idle_tidb_auth_binding.account}:{idle_tidb_auth_binding.password}"
|
idle_tidb_auth_binding = (
|
||||||
else:
|
|
||||||
with redis_client.lock("create_tidb_serverless_cluster_lock", timeout=900):
|
|
||||||
tidb_auth_binding = (
|
|
||||||
db.session.query(TidbAuthBinding)
|
db.session.query(TidbAuthBinding)
|
||||||
.filter(TidbAuthBinding.tenant_id == dataset.tenant_id)
|
.filter(TidbAuthBinding.active == False, TidbAuthBinding.status == "ACTIVE")
|
||||||
|
.limit(1)
|
||||||
.one_or_none()
|
.one_or_none()
|
||||||
)
|
)
|
||||||
if tidb_auth_binding:
|
if idle_tidb_auth_binding:
|
||||||
TIDB_ON_QDRANT_API_KEY = f"{tidb_auth_binding.account}:{tidb_auth_binding.password}"
|
idle_tidb_auth_binding.active = True
|
||||||
|
idle_tidb_auth_binding.tenant_id = dataset.tenant_id
|
||||||
|
db.session.commit()
|
||||||
|
TIDB_ON_QDRANT_API_KEY = f"{idle_tidb_auth_binding.account}:{idle_tidb_auth_binding.password}"
|
||||||
else:
|
else:
|
||||||
new_cluster = TidbService.create_tidb_serverless_cluster(
|
new_cluster = TidbService.create_tidb_serverless_cluster(
|
||||||
dify_config.TIDB_PROJECT_ID or "",
|
dify_config.TIDB_PROJECT_ID or "",
|
||||||
@ -451,7 +451,6 @@ class TidbOnQdrantVectorFactory(AbstractVectorFactory):
|
|||||||
db.session.add(new_tidb_auth_binding)
|
db.session.add(new_tidb_auth_binding)
|
||||||
db.session.commit()
|
db.session.commit()
|
||||||
TIDB_ON_QDRANT_API_KEY = f"{new_tidb_auth_binding.account}:{new_tidb_auth_binding.password}"
|
TIDB_ON_QDRANT_API_KEY = f"{new_tidb_auth_binding.account}:{new_tidb_auth_binding.password}"
|
||||||
|
|
||||||
else:
|
else:
|
||||||
TIDB_ON_QDRANT_API_KEY = f"{tidb_auth_binding.account}:{tidb_auth_binding.password}"
|
TIDB_ON_QDRANT_API_KEY = f"{tidb_auth_binding.account}:{tidb_auth_binding.password}"
|
||||||
|
|
||||||
|
|||||||
@ -90,6 +90,12 @@ class Vector:
|
|||||||
from core.rag.datasource.vdb.elasticsearch.elasticsearch_vector import ElasticSearchVectorFactory
|
from core.rag.datasource.vdb.elasticsearch.elasticsearch_vector import ElasticSearchVectorFactory
|
||||||
|
|
||||||
return ElasticSearchVectorFactory
|
return ElasticSearchVectorFactory
|
||||||
|
case VectorType.ELASTICSEARCH_JA:
|
||||||
|
from core.rag.datasource.vdb.elasticsearch.elasticsearch_ja_vector import (
|
||||||
|
ElasticSearchJaVectorFactory,
|
||||||
|
)
|
||||||
|
|
||||||
|
return ElasticSearchJaVectorFactory
|
||||||
case VectorType.TIDB_VECTOR:
|
case VectorType.TIDB_VECTOR:
|
||||||
from core.rag.datasource.vdb.tidb_vector.tidb_vector import TiDBVectorFactory
|
from core.rag.datasource.vdb.tidb_vector.tidb_vector import TiDBVectorFactory
|
||||||
|
|
||||||
|
|||||||
@ -16,6 +16,7 @@ class VectorType(StrEnum):
|
|||||||
TENCENT = "tencent"
|
TENCENT = "tencent"
|
||||||
ORACLE = "oracle"
|
ORACLE = "oracle"
|
||||||
ELASTICSEARCH = "elasticsearch"
|
ELASTICSEARCH = "elasticsearch"
|
||||||
|
ELASTICSEARCH_JA = "elasticsearch-ja"
|
||||||
LINDORM = "lindorm"
|
LINDORM = "lindorm"
|
||||||
COUCHBASE = "couchbase"
|
COUCHBASE = "couchbase"
|
||||||
BAIDU = "baidu"
|
BAIDU = "baidu"
|
||||||
|
|||||||
@ -23,7 +23,6 @@ class PdfExtractor(BaseExtractor):
|
|||||||
self._file_cache_key = file_cache_key
|
self._file_cache_key = file_cache_key
|
||||||
|
|
||||||
def extract(self) -> list[Document]:
|
def extract(self) -> list[Document]:
|
||||||
plaintext_file_key = ""
|
|
||||||
plaintext_file_exists = False
|
plaintext_file_exists = False
|
||||||
if self._file_cache_key:
|
if self._file_cache_key:
|
||||||
try:
|
try:
|
||||||
@ -39,8 +38,8 @@ class PdfExtractor(BaseExtractor):
|
|||||||
text = "\n\n".join(text_list)
|
text = "\n\n".join(text_list)
|
||||||
|
|
||||||
# save plaintext file for caching
|
# save plaintext file for caching
|
||||||
if not plaintext_file_exists and plaintext_file_key:
|
if not plaintext_file_exists and self._file_cache_key:
|
||||||
storage.save(plaintext_file_key, text.encode("utf-8"))
|
storage.save(self._file_cache_key, text.encode("utf-8"))
|
||||||
|
|
||||||
return documents
|
return documents
|
||||||
|
|
||||||
|
|||||||
@ -3,6 +3,7 @@
|
|||||||
import uuid
|
import uuid
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
|
from configs import dify_config
|
||||||
from core.model_manager import ModelInstance
|
from core.model_manager import ModelInstance
|
||||||
from core.rag.cleaner.clean_processor import CleanProcessor
|
from core.rag.cleaner.clean_processor import CleanProcessor
|
||||||
from core.rag.datasource.retrieval_service import RetrievalService
|
from core.rag.datasource.retrieval_service import RetrievalService
|
||||||
@ -80,6 +81,10 @@ class ParentChildIndexProcessor(BaseIndexProcessor):
|
|||||||
child_nodes = self._split_child_nodes(
|
child_nodes = self._split_child_nodes(
|
||||||
document, rules, process_rule.get("mode"), kwargs.get("embedding_model_instance")
|
document, rules, process_rule.get("mode"), kwargs.get("embedding_model_instance")
|
||||||
)
|
)
|
||||||
|
if kwargs.get("preview"):
|
||||||
|
if len(child_nodes) > dify_config.CHILD_CHUNKS_PREVIEW_NUMBER:
|
||||||
|
child_nodes = child_nodes[: dify_config.CHILD_CHUNKS_PREVIEW_NUMBER]
|
||||||
|
|
||||||
document.children = child_nodes
|
document.children = child_nodes
|
||||||
doc_id = str(uuid.uuid4())
|
doc_id = str(uuid.uuid4())
|
||||||
hash = helper.generate_text_hash(document.page_content)
|
hash = helper.generate_text_hash(document.page_content)
|
||||||
|
|||||||
@ -54,7 +54,12 @@ class ASRTool(BuiltinTool):
|
|||||||
items.append((provider, model.model))
|
items.append((provider, model.model))
|
||||||
return items
|
return items
|
||||||
|
|
||||||
def get_runtime_parameters(self) -> list[ToolParameter]:
|
def get_runtime_parameters(
|
||||||
|
self,
|
||||||
|
conversation_id: Optional[str] = None,
|
||||||
|
app_id: Optional[str] = None,
|
||||||
|
message_id: Optional[str] = None,
|
||||||
|
) -> list[ToolParameter]:
|
||||||
parameters = []
|
parameters = []
|
||||||
|
|
||||||
options = []
|
options = []
|
||||||
|
|||||||
@ -62,7 +62,12 @@ class TTSTool(BuiltinTool):
|
|||||||
items.append((provider, model.model, voices))
|
items.append((provider, model.model, voices))
|
||||||
return items
|
return items
|
||||||
|
|
||||||
def get_runtime_parameters(self) -> list[ToolParameter]:
|
def get_runtime_parameters(
|
||||||
|
self,
|
||||||
|
conversation_id: Optional[str] = None,
|
||||||
|
app_id: Optional[str] = None,
|
||||||
|
message_id: Optional[str] = None,
|
||||||
|
) -> list[ToolParameter]:
|
||||||
parameters = []
|
parameters = []
|
||||||
|
|
||||||
options = []
|
options = []
|
||||||
|
|||||||
@ -212,8 +212,23 @@ class ApiTool(Tool):
|
|||||||
else:
|
else:
|
||||||
body = body
|
body = body
|
||||||
|
|
||||||
if method in {"get", "head", "post", "put", "delete", "patch"}:
|
if method in {
|
||||||
response: httpx.Response = getattr(ssrf_proxy, method)(
|
"get",
|
||||||
|
"head",
|
||||||
|
"post",
|
||||||
|
"put",
|
||||||
|
"delete",
|
||||||
|
"patch",
|
||||||
|
"options",
|
||||||
|
"GET",
|
||||||
|
"POST",
|
||||||
|
"PUT",
|
||||||
|
"PATCH",
|
||||||
|
"DELETE",
|
||||||
|
"HEAD",
|
||||||
|
"OPTIONS",
|
||||||
|
}:
|
||||||
|
response: httpx.Response = getattr(ssrf_proxy, method.lower())(
|
||||||
url,
|
url,
|
||||||
params=params,
|
params=params,
|
||||||
headers=headers,
|
headers=headers,
|
||||||
|
|||||||
@ -147,7 +147,7 @@ class ToolInvokeMessage(BaseModel):
|
|||||||
|
|
||||||
@field_validator("variable_name", mode="before")
|
@field_validator("variable_name", mode="before")
|
||||||
@classmethod
|
@classmethod
|
||||||
def transform_variable_name(cls, value) -> str:
|
def transform_variable_name(cls, value: str) -> str:
|
||||||
"""
|
"""
|
||||||
The variable name must be a string.
|
The variable name must be a string.
|
||||||
"""
|
"""
|
||||||
@ -167,6 +167,7 @@ class ToolInvokeMessage(BaseModel):
|
|||||||
error: Optional[str] = Field(default=None, description="The error message")
|
error: Optional[str] = Field(default=None, description="The error message")
|
||||||
status: LogStatus = Field(..., description="The status of the log")
|
status: LogStatus = Field(..., description="The status of the log")
|
||||||
data: Mapping[str, Any] = Field(..., description="Detailed log data")
|
data: Mapping[str, Any] = Field(..., description="Detailed log data")
|
||||||
|
metadata: Optional[Mapping[str, Any]] = Field(default=None, description="The metadata of the log")
|
||||||
|
|
||||||
class MessageType(Enum):
|
class MessageType(Enum):
|
||||||
TEXT = "text"
|
TEXT = "text"
|
||||||
|
|||||||
@ -203,6 +203,7 @@ class AgentLogEvent(BaseAgentEvent):
|
|||||||
error: str | None = Field(..., description="error")
|
error: str | None = Field(..., description="error")
|
||||||
status: str = Field(..., description="status")
|
status: str = Field(..., description="status")
|
||||||
data: Mapping[str, Any] = Field(..., description="data")
|
data: Mapping[str, Any] = Field(..., description="data")
|
||||||
|
metadata: Optional[Mapping[str, Any]] = Field(default=None, description="metadata")
|
||||||
|
|
||||||
|
|
||||||
InNodeEvent = BaseNodeEvent | BaseParallelBranchEvent | BaseIterationEvent | BaseAgentEvent
|
InNodeEvent = BaseNodeEvent | BaseParallelBranchEvent | BaseIterationEvent | BaseAgentEvent
|
||||||
|
|||||||
@ -89,7 +89,11 @@ class AgentNode(ToolNode):
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
# convert tool messages
|
# convert tool messages
|
||||||
yield from self._transform_message(message_stream, {}, parameters_for_log)
|
yield from self._transform_message(
|
||||||
|
message_stream,
|
||||||
|
{"provider": (cast(AgentNodeData, self.node_data)).agent_strategy_provider_name},
|
||||||
|
parameters_for_log,
|
||||||
|
)
|
||||||
except PluginDaemonClientSideError as e:
|
except PluginDaemonClientSideError as e:
|
||||||
yield RunCompletedEvent(
|
yield RunCompletedEvent(
|
||||||
run_result=NodeRunResult(
|
run_result=NodeRunResult(
|
||||||
@ -170,7 +174,12 @@ class AgentNode(ToolNode):
|
|||||||
extra.get("descrption", "") or tool_runtime.entity.description.llm
|
extra.get("descrption", "") or tool_runtime.entity.description.llm
|
||||||
)
|
)
|
||||||
|
|
||||||
tool_value.append(tool_runtime.entity.model_dump(mode="json"))
|
tool_value.append(
|
||||||
|
{
|
||||||
|
**tool_runtime.entity.model_dump(mode="json"),
|
||||||
|
"runtime_parameters": tool_runtime.runtime.runtime_parameters,
|
||||||
|
}
|
||||||
|
)
|
||||||
value = tool_value
|
value = tool_value
|
||||||
if parameter.type == "model-selector":
|
if parameter.type == "model-selector":
|
||||||
value = cast(dict[str, Any], value)
|
value = cast(dict[str, Any], value)
|
||||||
|
|||||||
@ -2,14 +2,18 @@ import csv
|
|||||||
import io
|
import io
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
|
import operator
|
||||||
import os
|
import os
|
||||||
import tempfile
|
import tempfile
|
||||||
from typing import cast
|
from collections.abc import Mapping, Sequence
|
||||||
|
from typing import Any, cast
|
||||||
|
|
||||||
import docx
|
import docx
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
import pypdfium2 # type: ignore
|
import pypdfium2 # type: ignore
|
||||||
import yaml # type: ignore
|
import yaml # type: ignore
|
||||||
|
from docx.table import Table
|
||||||
|
from docx.text.paragraph import Paragraph
|
||||||
|
|
||||||
from configs import dify_config
|
from configs import dify_config
|
||||||
from core.file import File, FileTransferMethod, file_manager
|
from core.file import File, FileTransferMethod, file_manager
|
||||||
@ -78,6 +82,23 @@ class DocumentExtractorNode(BaseNode[DocumentExtractorNodeData]):
|
|||||||
process_data=process_data,
|
process_data=process_data,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _extract_variable_selector_to_variable_mapping(
|
||||||
|
cls,
|
||||||
|
*,
|
||||||
|
graph_config: Mapping[str, Any],
|
||||||
|
node_id: str,
|
||||||
|
node_data: DocumentExtractorNodeData,
|
||||||
|
) -> Mapping[str, Sequence[str]]:
|
||||||
|
"""
|
||||||
|
Extract variable selector to variable mapping
|
||||||
|
:param graph_config: graph config
|
||||||
|
:param node_id: node id
|
||||||
|
:param node_data: node data
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
return {node_id + ".files": node_data.variable_selector}
|
||||||
|
|
||||||
|
|
||||||
def _extract_text_by_mime_type(*, file_content: bytes, mime_type: str) -> str:
|
def _extract_text_by_mime_type(*, file_content: bytes, mime_type: str) -> str:
|
||||||
"""Extract text from a file based on its MIME type."""
|
"""Extract text from a file based on its MIME type."""
|
||||||
@ -189,35 +210,56 @@ def _extract_text_from_doc(file_content: bytes) -> str:
|
|||||||
doc_file = io.BytesIO(file_content)
|
doc_file = io.BytesIO(file_content)
|
||||||
doc = docx.Document(doc_file)
|
doc = docx.Document(doc_file)
|
||||||
text = []
|
text = []
|
||||||
# Process paragraphs
|
|
||||||
for paragraph in doc.paragraphs:
|
|
||||||
if paragraph.text.strip():
|
|
||||||
text.append(paragraph.text)
|
|
||||||
|
|
||||||
# Process tables
|
# Keep track of paragraph and table positions
|
||||||
for table in doc.tables:
|
content_items: list[tuple[int, str, Table | Paragraph]] = []
|
||||||
# Table header
|
|
||||||
try:
|
# Process paragraphs and tables
|
||||||
# table maybe cause errors so ignore it.
|
for i, paragraph in enumerate(doc.paragraphs):
|
||||||
if len(table.rows) > 0 and table.rows[0].cells is not None:
|
if paragraph.text.strip():
|
||||||
|
content_items.append((i, "paragraph", paragraph))
|
||||||
|
|
||||||
|
for i, table in enumerate(doc.tables):
|
||||||
|
content_items.append((i, "table", table))
|
||||||
|
|
||||||
|
# Sort content items based on their original position
|
||||||
|
content_items.sort(key=operator.itemgetter(0))
|
||||||
|
|
||||||
|
# Process sorted content
|
||||||
|
for _, item_type, item in content_items:
|
||||||
|
if item_type == "paragraph":
|
||||||
|
if isinstance(item, Table):
|
||||||
|
continue
|
||||||
|
text.append(item.text)
|
||||||
|
elif item_type == "table":
|
||||||
|
# Process tables
|
||||||
|
if not isinstance(item, Table):
|
||||||
|
continue
|
||||||
|
try:
|
||||||
# Check if any cell in the table has text
|
# Check if any cell in the table has text
|
||||||
has_content = False
|
has_content = False
|
||||||
for row in table.rows:
|
for row in item.rows:
|
||||||
if any(cell.text.strip() for cell in row.cells):
|
if any(cell.text.strip() for cell in row.cells):
|
||||||
has_content = True
|
has_content = True
|
||||||
break
|
break
|
||||||
|
|
||||||
if has_content:
|
if has_content:
|
||||||
markdown_table = "| " + " | ".join(cell.text for cell in table.rows[0].cells) + " |\n"
|
cell_texts = [cell.text.replace("\n", "<br>") for cell in item.rows[0].cells]
|
||||||
markdown_table += "| " + " | ".join(["---"] * len(table.rows[0].cells)) + " |\n"
|
markdown_table = f"| {' | '.join(cell_texts)} |\n"
|
||||||
for row in table.rows[1:]:
|
markdown_table += f"| {' | '.join(['---'] * len(item.rows[0].cells))} |\n"
|
||||||
markdown_table += "| " + " | ".join(cell.text for cell in row.cells) + " |\n"
|
|
||||||
|
for row in item.rows[1:]:
|
||||||
|
# Replace newlines with <br> in each cell
|
||||||
|
row_cells = [cell.text.replace("\n", "<br>") for cell in row.cells]
|
||||||
|
markdown_table += "| " + " | ".join(row_cells) + " |\n"
|
||||||
|
|
||||||
text.append(markdown_table)
|
text.append(markdown_table)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"Failed to extract table from DOC/DOCX: {e}")
|
logger.warning(f"Failed to extract table from DOC/DOCX: {e}")
|
||||||
continue
|
continue
|
||||||
|
|
||||||
return "\n".join(text)
|
return "\n".join(text)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise TextExtractionError(f"Failed to extract text from DOC/DOCX: {str(e)}") from e
|
raise TextExtractionError(f"Failed to extract text from DOC/DOCX: {str(e)}") from e
|
||||||
|
|
||||||
|
|||||||
@ -68,7 +68,22 @@ class HttpRequestNodeData(BaseNodeData):
|
|||||||
Code Node Data.
|
Code Node Data.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
method: Literal["get", "post", "put", "patch", "delete", "head"]
|
method: Literal[
|
||||||
|
"get",
|
||||||
|
"post",
|
||||||
|
"put",
|
||||||
|
"patch",
|
||||||
|
"delete",
|
||||||
|
"head",
|
||||||
|
"options",
|
||||||
|
"GET",
|
||||||
|
"POST",
|
||||||
|
"PUT",
|
||||||
|
"PATCH",
|
||||||
|
"DELETE",
|
||||||
|
"HEAD",
|
||||||
|
"OPTIONS",
|
||||||
|
]
|
||||||
url: str
|
url: str
|
||||||
authorization: HttpRequestNodeAuthorization
|
authorization: HttpRequestNodeAuthorization
|
||||||
headers: str
|
headers: str
|
||||||
|
|||||||
@ -37,7 +37,22 @@ BODY_TYPE_TO_CONTENT_TYPE = {
|
|||||||
|
|
||||||
|
|
||||||
class Executor:
|
class Executor:
|
||||||
method: Literal["get", "head", "post", "put", "delete", "patch"]
|
method: Literal[
|
||||||
|
"get",
|
||||||
|
"head",
|
||||||
|
"post",
|
||||||
|
"put",
|
||||||
|
"delete",
|
||||||
|
"patch",
|
||||||
|
"options",
|
||||||
|
"GET",
|
||||||
|
"POST",
|
||||||
|
"PUT",
|
||||||
|
"PATCH",
|
||||||
|
"DELETE",
|
||||||
|
"HEAD",
|
||||||
|
"OPTIONS",
|
||||||
|
]
|
||||||
url: str
|
url: str
|
||||||
params: list[tuple[str, str]] | None
|
params: list[tuple[str, str]] | None
|
||||||
content: str | bytes | None
|
content: str | bytes | None
|
||||||
@ -67,12 +82,6 @@ class Executor:
|
|||||||
node_data.authorization.config.api_key
|
node_data.authorization.config.api_key
|
||||||
).text
|
).text
|
||||||
|
|
||||||
# check if node_data.url is a valid URL
|
|
||||||
if not node_data.url:
|
|
||||||
raise InvalidURLError("url is required")
|
|
||||||
if not node_data.url.startswith(("http://", "https://")):
|
|
||||||
raise InvalidURLError("url should start with http:// or https://")
|
|
||||||
|
|
||||||
self.url: str = node_data.url
|
self.url: str = node_data.url
|
||||||
self.method = node_data.method
|
self.method = node_data.method
|
||||||
self.auth = node_data.authorization
|
self.auth = node_data.authorization
|
||||||
@ -99,6 +108,12 @@ class Executor:
|
|||||||
def _init_url(self):
|
def _init_url(self):
|
||||||
self.url = self.variable_pool.convert_template(self.node_data.url).text
|
self.url = self.variable_pool.convert_template(self.node_data.url).text
|
||||||
|
|
||||||
|
# check if url is a valid URL
|
||||||
|
if not self.url:
|
||||||
|
raise InvalidURLError("url is required")
|
||||||
|
if not self.url.startswith(("http://", "https://")):
|
||||||
|
raise InvalidURLError("url should start with http:// or https://")
|
||||||
|
|
||||||
def _init_params(self):
|
def _init_params(self):
|
||||||
"""
|
"""
|
||||||
Almost same as _init_headers(), difference:
|
Almost same as _init_headers(), difference:
|
||||||
@ -158,7 +173,10 @@ class Executor:
|
|||||||
if len(data) != 1:
|
if len(data) != 1:
|
||||||
raise RequestBodyError("json body type should have exactly one item")
|
raise RequestBodyError("json body type should have exactly one item")
|
||||||
json_string = self.variable_pool.convert_template(data[0].value).text
|
json_string = self.variable_pool.convert_template(data[0].value).text
|
||||||
json_object = json.loads(json_string, strict=False)
|
try:
|
||||||
|
json_object = json.loads(json_string, strict=False)
|
||||||
|
except json.JSONDecodeError as e:
|
||||||
|
raise RequestBodyError(f"Failed to parse JSON: {json_string}") from e
|
||||||
self.json = json_object
|
self.json = json_object
|
||||||
# self.json = self._parse_object_contains_variables(json_object)
|
# self.json = self._parse_object_contains_variables(json_object)
|
||||||
case "binary":
|
case "binary":
|
||||||
@ -246,7 +264,22 @@ class Executor:
|
|||||||
"""
|
"""
|
||||||
do http request depending on api bundle
|
do http request depending on api bundle
|
||||||
"""
|
"""
|
||||||
if self.method not in {"get", "head", "post", "put", "delete", "patch"}:
|
if self.method not in {
|
||||||
|
"get",
|
||||||
|
"head",
|
||||||
|
"post",
|
||||||
|
"put",
|
||||||
|
"delete",
|
||||||
|
"patch",
|
||||||
|
"options",
|
||||||
|
"GET",
|
||||||
|
"POST",
|
||||||
|
"PUT",
|
||||||
|
"PATCH",
|
||||||
|
"DELETE",
|
||||||
|
"HEAD",
|
||||||
|
"OPTIONS",
|
||||||
|
}:
|
||||||
raise InvalidHttpMethodError(f"Invalid http method {self.method}")
|
raise InvalidHttpMethodError(f"Invalid http method {self.method}")
|
||||||
|
|
||||||
request_args = {
|
request_args = {
|
||||||
@ -263,7 +296,7 @@ class Executor:
|
|||||||
}
|
}
|
||||||
# request_args = {k: v for k, v in request_args.items() if v is not None}
|
# request_args = {k: v for k, v in request_args.items() if v is not None}
|
||||||
try:
|
try:
|
||||||
response = getattr(ssrf_proxy, self.method)(**request_args)
|
response = getattr(ssrf_proxy, self.method.lower())(**request_args)
|
||||||
except (ssrf_proxy.MaxRetriesExceededError, httpx.RequestError) as e:
|
except (ssrf_proxy.MaxRetriesExceededError, httpx.RequestError) as e:
|
||||||
raise HttpRequestNodeError(str(e))
|
raise HttpRequestNodeError(str(e))
|
||||||
# FIXME: fix type ignore, this maybe httpx type issue
|
# FIXME: fix type ignore, this maybe httpx type issue
|
||||||
|
|||||||
@ -1,5 +1,5 @@
|
|||||||
from collections.abc import Generator, Mapping, Sequence
|
from collections.abc import Generator, Mapping, Sequence
|
||||||
from typing import Any, cast
|
from typing import Any, Optional, cast
|
||||||
|
|
||||||
from sqlalchemy import select
|
from sqlalchemy import select
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
@ -197,6 +197,7 @@ class ToolNode(BaseNode[ToolNodeData]):
|
|||||||
json: list[dict] = []
|
json: list[dict] = []
|
||||||
|
|
||||||
agent_logs: list[AgentLogEvent] = []
|
agent_logs: list[AgentLogEvent] = []
|
||||||
|
agent_execution_metadata: Optional[Mapping[NodeRunMetadataKey, Any]] = {}
|
||||||
|
|
||||||
variables: dict[str, Any] = {}
|
variables: dict[str, Any] = {}
|
||||||
|
|
||||||
@ -264,6 +265,11 @@ class ToolNode(BaseNode[ToolNodeData]):
|
|||||||
)
|
)
|
||||||
elif message.type == ToolInvokeMessage.MessageType.JSON:
|
elif message.type == ToolInvokeMessage.MessageType.JSON:
|
||||||
assert isinstance(message.message, ToolInvokeMessage.JsonMessage)
|
assert isinstance(message.message, ToolInvokeMessage.JsonMessage)
|
||||||
|
if self.node_type == NodeType.AGENT:
|
||||||
|
msg_metadata = message.message.json_object.pop("execution_metadata", {})
|
||||||
|
agent_execution_metadata = {
|
||||||
|
key: value for key, value in msg_metadata.items() if key in NodeRunMetadataKey
|
||||||
|
}
|
||||||
json.append(message.message.json_object)
|
json.append(message.message.json_object)
|
||||||
elif message.type == ToolInvokeMessage.MessageType.LINK:
|
elif message.type == ToolInvokeMessage.MessageType.LINK:
|
||||||
assert isinstance(message.message, ToolInvokeMessage.TextMessage)
|
assert isinstance(message.message, ToolInvokeMessage.TextMessage)
|
||||||
@ -299,6 +305,7 @@ class ToolNode(BaseNode[ToolNodeData]):
|
|||||||
status=message.message.status.value,
|
status=message.message.status.value,
|
||||||
data=message.message.data,
|
data=message.message.data,
|
||||||
label=message.message.label,
|
label=message.message.label,
|
||||||
|
metadata=message.message.metadata,
|
||||||
)
|
)
|
||||||
|
|
||||||
# check if the agent log is already in the list
|
# check if the agent log is already in the list
|
||||||
@ -309,6 +316,7 @@ class ToolNode(BaseNode[ToolNodeData]):
|
|||||||
log.status = agent_log.status
|
log.status = agent_log.status
|
||||||
log.error = agent_log.error
|
log.error = agent_log.error
|
||||||
log.label = agent_log.label
|
log.label = agent_log.label
|
||||||
|
log.metadata = agent_log.metadata
|
||||||
break
|
break
|
||||||
else:
|
else:
|
||||||
agent_logs.append(agent_log)
|
agent_logs.append(agent_log)
|
||||||
@ -319,7 +327,11 @@ class ToolNode(BaseNode[ToolNodeData]):
|
|||||||
run_result=NodeRunResult(
|
run_result=NodeRunResult(
|
||||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||||
outputs={"text": text, "files": files, "json": json, **variables},
|
outputs={"text": text, "files": files, "json": json, **variables},
|
||||||
metadata={NodeRunMetadataKey.TOOL_INFO: tool_info, NodeRunMetadataKey.AGENT_LOG: agent_logs},
|
metadata={
|
||||||
|
**agent_execution_metadata,
|
||||||
|
NodeRunMetadataKey.TOOL_INFO: tool_info,
|
||||||
|
NodeRunMetadataKey.AGENT_LOG: agent_logs,
|
||||||
|
},
|
||||||
inputs=parameters_for_log,
|
inputs=parameters_for_log,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|||||||
@ -340,6 +340,10 @@ class WorkflowEntry:
|
|||||||
):
|
):
|
||||||
raise ValueError(f"Variable key {node_variable} not found in user inputs.")
|
raise ValueError(f"Variable key {node_variable} not found in user inputs.")
|
||||||
|
|
||||||
|
# environment variable already exist in variable pool, not from user inputs
|
||||||
|
if variable_pool.get(variable_selector):
|
||||||
|
continue
|
||||||
|
|
||||||
# fetch variable node id from variable selector
|
# fetch variable node id from variable selector
|
||||||
variable_node_id = variable_selector[0]
|
variable_node_id = variable_selector[0]
|
||||||
variable_key_list = variable_selector[1:]
|
variable_key_list = variable_selector[1:]
|
||||||
|
|||||||
@ -33,6 +33,7 @@ else
|
|||||||
--bind "${DIFY_BIND_ADDRESS:-0.0.0.0}:${DIFY_PORT:-5001}" \
|
--bind "${DIFY_BIND_ADDRESS:-0.0.0.0}:${DIFY_PORT:-5001}" \
|
||||||
--workers ${SERVER_WORKER_AMOUNT:-1} \
|
--workers ${SERVER_WORKER_AMOUNT:-1} \
|
||||||
--worker-class ${SERVER_WORKER_CLASS:-gevent} \
|
--worker-class ${SERVER_WORKER_CLASS:-gevent} \
|
||||||
|
--worker-connections ${SERVER_WORKER_CONNECTIONS:-10} \
|
||||||
--timeout ${GUNICORN_TIMEOUT:-200} \
|
--timeout ${GUNICORN_TIMEOUT:-200} \
|
||||||
app:app
|
app:app
|
||||||
fi
|
fi
|
||||||
|
|||||||
@ -46,7 +46,7 @@ def init_app(app: DifyApp):
|
|||||||
timezone = pytz.timezone(log_tz)
|
timezone = pytz.timezone(log_tz)
|
||||||
|
|
||||||
def time_converter(seconds):
|
def time_converter(seconds):
|
||||||
return datetime.utcfromtimestamp(seconds).astimezone(timezone).timetuple()
|
return datetime.fromtimestamp(seconds, tz=timezone).timetuple()
|
||||||
|
|
||||||
for handler in logging.root.handlers:
|
for handler in logging.root.handlers:
|
||||||
if handler.formatter:
|
if handler.formatter:
|
||||||
|
|||||||
@ -158,7 +158,7 @@ def _build_from_remote_url(
|
|||||||
tenant_id: str,
|
tenant_id: str,
|
||||||
transfer_method: FileTransferMethod,
|
transfer_method: FileTransferMethod,
|
||||||
) -> File:
|
) -> File:
|
||||||
url = mapping.get("url")
|
url = mapping.get("url") or mapping.get("remote_url")
|
||||||
if not url:
|
if not url:
|
||||||
raise ValueError("Invalid file url")
|
raise ValueError("Invalid file url")
|
||||||
|
|
||||||
|
|||||||
@ -6,7 +6,7 @@ import string
|
|||||||
import subprocess
|
import subprocess
|
||||||
import time
|
import time
|
||||||
import uuid
|
import uuid
|
||||||
from collections.abc import Generator
|
from collections.abc import Generator, Mapping
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from hashlib import sha256
|
from hashlib import sha256
|
||||||
from typing import TYPE_CHECKING, Any, Optional, Union, cast
|
from typing import TYPE_CHECKING, Any, Optional, Union, cast
|
||||||
@ -182,7 +182,7 @@ def generate_text_hash(text: str) -> str:
|
|||||||
return sha256(hash_text.encode()).hexdigest()
|
return sha256(hash_text.encode()).hexdigest()
|
||||||
|
|
||||||
|
|
||||||
def compact_generate_response(response: Union[dict, Generator, RateLimitGenerator]) -> Response:
|
def compact_generate_response(response: Union[Mapping, Generator, RateLimitGenerator]) -> Response:
|
||||||
if isinstance(response, dict):
|
if isinstance(response, dict):
|
||||||
return Response(response=json.dumps(response), status=200, mimetype="application/json")
|
return Response(response=json.dumps(response), status=200, mimetype="application/json")
|
||||||
else:
|
else:
|
||||||
|
|||||||
@ -255,7 +255,8 @@ class NotionOAuth(OAuthDataSource):
|
|||||||
response = requests.get(url=f"{self._NOTION_BLOCK_SEARCH}/{block_id}", headers=headers)
|
response = requests.get(url=f"{self._NOTION_BLOCK_SEARCH}/{block_id}", headers=headers)
|
||||||
response_json = response.json()
|
response_json = response.json()
|
||||||
if response.status_code != 200:
|
if response.status_code != 200:
|
||||||
raise ValueError(f"Error fetching block parent page ID: {response_json.message}")
|
message = response_json.get("message", "unknown error")
|
||||||
|
raise ValueError(f"Error fetching block parent page ID: {message}")
|
||||||
parent = response_json["parent"]
|
parent = response_json["parent"]
|
||||||
parent_type = parent["type"]
|
parent_type = parent["type"]
|
||||||
if parent_type == "block_id":
|
if parent_type == "block_id":
|
||||||
|
|||||||
@ -0,0 +1,41 @@
|
|||||||
|
"""change workflow_runs.total_tokens to bigint
|
||||||
|
|
||||||
|
Revision ID: a91b476a53de
|
||||||
|
Revises: 923752d42eb6
|
||||||
|
Create Date: 2025-01-01 20:00:01.207369
|
||||||
|
|
||||||
|
"""
|
||||||
|
from alembic import op
|
||||||
|
import models as models
|
||||||
|
import sqlalchemy as sa
|
||||||
|
|
||||||
|
|
||||||
|
# revision identifiers, used by Alembic.
|
||||||
|
revision = 'a91b476a53de'
|
||||||
|
down_revision = '923752d42eb6'
|
||||||
|
branch_labels = None
|
||||||
|
depends_on = None
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade():
|
||||||
|
# ### commands auto generated by Alembic - please adjust! ###
|
||||||
|
with op.batch_alter_table('workflow_runs', schema=None) as batch_op:
|
||||||
|
batch_op.alter_column('total_tokens',
|
||||||
|
existing_type=sa.INTEGER(),
|
||||||
|
type_=sa.BigInteger(),
|
||||||
|
existing_nullable=False,
|
||||||
|
existing_server_default=sa.text('0'))
|
||||||
|
|
||||||
|
# ### end Alembic commands ###
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade():
|
||||||
|
# ### commands auto generated by Alembic - please adjust! ###
|
||||||
|
with op.batch_alter_table('workflow_runs', schema=None) as batch_op:
|
||||||
|
batch_op.alter_column('total_tokens',
|
||||||
|
existing_type=sa.BigInteger(),
|
||||||
|
type_=sa.INTEGER(),
|
||||||
|
existing_nullable=False,
|
||||||
|
existing_server_default=sa.text('0'))
|
||||||
|
|
||||||
|
# ### end Alembic commands ###
|
||||||
@ -415,8 +415,8 @@ class WorkflowRun(Base):
|
|||||||
status: Mapped[str] = mapped_column(db.String(255)) # running, succeeded, failed, stopped, partial-succeeded
|
status: Mapped[str] = mapped_column(db.String(255)) # running, succeeded, failed, stopped, partial-succeeded
|
||||||
outputs: Mapped[Optional[str]] = mapped_column(sa.Text, default="{}")
|
outputs: Mapped[Optional[str]] = mapped_column(sa.Text, default="{}")
|
||||||
error: Mapped[Optional[str]] = mapped_column(db.Text)
|
error: Mapped[Optional[str]] = mapped_column(db.Text)
|
||||||
elapsed_time = db.Column(db.Float, nullable=False, server_default=db.text("0"))
|
elapsed_time = db.Column(db.Float, nullable=False, server_default=sa.text("0"))
|
||||||
total_tokens: Mapped[int] = mapped_column(server_default=db.text("0"))
|
total_tokens: Mapped[int] = mapped_column(sa.BigInteger, server_default=sa.text("0"))
|
||||||
total_steps = db.Column(db.Integer, server_default=db.text("0"))
|
total_steps = db.Column(db.Integer, server_default=db.text("0"))
|
||||||
created_by_role: Mapped[str] = mapped_column(db.String(255)) # account, end_user
|
created_by_role: Mapped[str] = mapped_column(db.String(255)) # account, end_user
|
||||||
created_by = db.Column(StringUUID, nullable=False)
|
created_by = db.Column(StringUUID, nullable=False)
|
||||||
|
|||||||
2397
api/poetry.lock
generated
2397
api/poetry.lock
generated
File diff suppressed because it is too large
Load Diff
@ -71,7 +71,7 @@ pyjwt = "~2.8.0"
|
|||||||
pypdfium2 = "~4.30.0"
|
pypdfium2 = "~4.30.0"
|
||||||
python = ">=3.11,<3.13"
|
python = ">=3.11,<3.13"
|
||||||
python-docx = "~1.1.0"
|
python-docx = "~1.1.0"
|
||||||
python-dotenv = "1.0.0"
|
python-dotenv = "1.0.1"
|
||||||
pyyaml = "~6.0.1"
|
pyyaml = "~6.0.1"
|
||||||
readabilipy = "0.2.0"
|
readabilipy = "0.2.0"
|
||||||
redis = { version = "~5.0.3", extras = ["hiredis"] }
|
redis = { version = "~5.0.3", extras = ["hiredis"] }
|
||||||
@ -82,7 +82,7 @@ scikit-learn = "~1.5.1"
|
|||||||
sentry-sdk = { version = "~1.44.1", extras = ["flask"] }
|
sentry-sdk = { version = "~1.44.1", extras = ["flask"] }
|
||||||
sqlalchemy = "~2.0.29"
|
sqlalchemy = "~2.0.29"
|
||||||
starlette = "0.41.0"
|
starlette = "0.41.0"
|
||||||
tencentcloud-sdk-python-hunyuan = "~3.0.1158"
|
tencentcloud-sdk-python-hunyuan = "~3.0.1294"
|
||||||
tiktoken = "~0.8.0"
|
tiktoken = "~0.8.0"
|
||||||
tokenizers = "~0.15.0"
|
tokenizers = "~0.15.0"
|
||||||
transformers = "~4.35.0"
|
transformers = "~4.35.0"
|
||||||
@ -92,7 +92,7 @@ validators = "0.21.0"
|
|||||||
volcengine-python-sdk = {extras = ["ark"], version = "~1.0.98"}
|
volcengine-python-sdk = {extras = ["ark"], version = "~1.0.98"}
|
||||||
websocket-client = "~1.7.0"
|
websocket-client = "~1.7.0"
|
||||||
xinference-client = "0.15.2"
|
xinference-client = "0.15.2"
|
||||||
yarl = "~1.9.4"
|
yarl = "~1.18.3"
|
||||||
youtube-transcript-api = "~0.6.2"
|
youtube-transcript-api = "~0.6.2"
|
||||||
zhipuai = "~2.1.5"
|
zhipuai = "~2.1.5"
|
||||||
# Before adding new dependency, consider place it in alphabet order (a-z) and suitable group.
|
# Before adding new dependency, consider place it in alphabet order (a-z) and suitable group.
|
||||||
@ -157,7 +157,7 @@ opensearch-py = "2.4.0"
|
|||||||
oracledb = "~2.2.1"
|
oracledb = "~2.2.1"
|
||||||
pgvecto-rs = { version = "~0.2.1", extras = ['sqlalchemy'] }
|
pgvecto-rs = { version = "~0.2.1", extras = ['sqlalchemy'] }
|
||||||
pgvector = "0.2.5"
|
pgvector = "0.2.5"
|
||||||
pymilvus = "~2.4.4"
|
pymilvus = "~2.5.0"
|
||||||
pymochow = "1.3.1"
|
pymochow = "1.3.1"
|
||||||
pyobvector = "~0.1.6"
|
pyobvector = "~0.1.6"
|
||||||
qdrant-client = "1.7.3"
|
qdrant-client = "1.7.3"
|
||||||
|
|||||||
@ -168,23 +168,6 @@ def clean_unused_datasets_task():
|
|||||||
else:
|
else:
|
||||||
plan = plan_cache.decode()
|
plan = plan_cache.decode()
|
||||||
if plan == "sandbox":
|
if plan == "sandbox":
|
||||||
# add auto disable log
|
|
||||||
documents = (
|
|
||||||
db.session.query(Document)
|
|
||||||
.filter(
|
|
||||||
Document.dataset_id == dataset.id,
|
|
||||||
Document.enabled == True,
|
|
||||||
Document.archived == False,
|
|
||||||
)
|
|
||||||
.all()
|
|
||||||
)
|
|
||||||
for document in documents:
|
|
||||||
dataset_auto_disable_log = DatasetAutoDisableLog(
|
|
||||||
tenant_id=dataset.tenant_id,
|
|
||||||
dataset_id=dataset.id,
|
|
||||||
document_id=document.id,
|
|
||||||
)
|
|
||||||
db.session.add(dataset_auto_disable_log)
|
|
||||||
# remove index
|
# remove index
|
||||||
index_processor = IndexProcessorFactory(dataset.doc_form).init_index_processor()
|
index_processor = IndexProcessorFactory(dataset.doc_form).init_index_processor()
|
||||||
index_processor.clean(dataset, None)
|
index_processor.clean(dataset, None)
|
||||||
|
|||||||
@ -67,7 +67,7 @@ class TokenPair(BaseModel):
|
|||||||
|
|
||||||
REFRESH_TOKEN_PREFIX = "refresh_token:"
|
REFRESH_TOKEN_PREFIX = "refresh_token:"
|
||||||
ACCOUNT_REFRESH_TOKEN_PREFIX = "account_refresh_token:"
|
ACCOUNT_REFRESH_TOKEN_PREFIX = "account_refresh_token:"
|
||||||
REFRESH_TOKEN_EXPIRY = timedelta(days=30)
|
REFRESH_TOKEN_EXPIRY = timedelta(days=dify_config.REFRESH_TOKEN_EXPIRE_DAYS)
|
||||||
|
|
||||||
|
|
||||||
class AccountService:
|
class AccountService:
|
||||||
@ -921,6 +921,9 @@ class RegisterService:
|
|||||||
def invite_new_member(
|
def invite_new_member(
|
||||||
cls, tenant: Tenant, email: str, language: str, role: str = "normal", inviter: Account | None = None
|
cls, tenant: Tenant, email: str, language: str, role: str = "normal", inviter: Account | None = None
|
||||||
) -> str:
|
) -> str:
|
||||||
|
if not inviter:
|
||||||
|
raise ValueError("Inviter is required")
|
||||||
|
|
||||||
"""Invite new member"""
|
"""Invite new member"""
|
||||||
with Session(db.engine) as session:
|
with Session(db.engine) as session:
|
||||||
account = session.query(Account).filter_by(email=email).first()
|
account = session.query(Account).filter_by(email=email).first()
|
||||||
|
|||||||
@ -2,6 +2,7 @@ import logging
|
|||||||
import uuid
|
import uuid
|
||||||
from enum import StrEnum
|
from enum import StrEnum
|
||||||
from typing import Optional, cast
|
from typing import Optional, cast
|
||||||
|
from urllib.parse import urlparse
|
||||||
from uuid import uuid4
|
from uuid import uuid4
|
||||||
|
|
||||||
import yaml # type: ignore
|
import yaml # type: ignore
|
||||||
@ -124,7 +125,7 @@ class AppDslService:
|
|||||||
raise ValueError(f"Invalid import_mode: {import_mode}")
|
raise ValueError(f"Invalid import_mode: {import_mode}")
|
||||||
|
|
||||||
# Get YAML content
|
# Get YAML content
|
||||||
content: bytes | str = b""
|
content: str = ""
|
||||||
if mode == ImportMode.YAML_URL:
|
if mode == ImportMode.YAML_URL:
|
||||||
if not yaml_url:
|
if not yaml_url:
|
||||||
return Import(
|
return Import(
|
||||||
@ -133,13 +134,17 @@ class AppDslService:
|
|||||||
error="yaml_url is required when import_mode is yaml-url",
|
error="yaml_url is required when import_mode is yaml-url",
|
||||||
)
|
)
|
||||||
try:
|
try:
|
||||||
# tricky way to handle url from github to github raw url
|
parsed_url = urlparse(yaml_url)
|
||||||
if yaml_url.startswith("https://github.com") and yaml_url.endswith((".yml", ".yaml")):
|
if (
|
||||||
|
parsed_url.scheme == "https"
|
||||||
|
and parsed_url.netloc == "github.com"
|
||||||
|
and parsed_url.path.endswith((".yml", ".yaml"))
|
||||||
|
):
|
||||||
yaml_url = yaml_url.replace("https://github.com", "https://raw.githubusercontent.com")
|
yaml_url = yaml_url.replace("https://github.com", "https://raw.githubusercontent.com")
|
||||||
yaml_url = yaml_url.replace("/blob/", "/")
|
yaml_url = yaml_url.replace("/blob/", "/")
|
||||||
response = ssrf_proxy.get(yaml_url.strip(), follow_redirects=True, timeout=(10, 10))
|
response = ssrf_proxy.get(yaml_url.strip(), follow_redirects=True, timeout=(10, 10))
|
||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
content = response.content
|
content = response.content.decode()
|
||||||
|
|
||||||
if len(content) > DSL_MAX_SIZE:
|
if len(content) > DSL_MAX_SIZE:
|
||||||
return Import(
|
return Import(
|
||||||
|
|||||||
@ -26,9 +26,10 @@ from tasks.remove_app_and_related_data_task import remove_app_and_related_data_t
|
|||||||
|
|
||||||
|
|
||||||
class AppService:
|
class AppService:
|
||||||
def get_paginate_apps(self, tenant_id: str, args: dict) -> Pagination | None:
|
def get_paginate_apps(self, user_id: str, tenant_id: str, args: dict) -> Pagination | None:
|
||||||
"""
|
"""
|
||||||
Get app list with pagination
|
Get app list with pagination
|
||||||
|
:param user_id: user id
|
||||||
:param tenant_id: tenant id
|
:param tenant_id: tenant id
|
||||||
:param args: request args
|
:param args: request args
|
||||||
:return:
|
:return:
|
||||||
@ -44,6 +45,8 @@ class AppService:
|
|||||||
elif args["mode"] == "channel":
|
elif args["mode"] == "channel":
|
||||||
filters.append(App.mode == AppMode.CHANNEL.value)
|
filters.append(App.mode == AppMode.CHANNEL.value)
|
||||||
|
|
||||||
|
if args.get("is_created_by_me", False):
|
||||||
|
filters.append(App.created_by == user_id)
|
||||||
if args.get("name"):
|
if args.get("name"):
|
||||||
name = args["name"][:30]
|
name = args["name"][:30]
|
||||||
filters.append(App.name.ilike(f"%{name}%"))
|
filters.append(App.name.ilike(f"%{name}%"))
|
||||||
|
|||||||
@ -1,5 +1,5 @@
|
|||||||
import os
|
import os
|
||||||
from typing import Optional
|
from typing import Literal, Optional
|
||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
from tenacity import retry, retry_if_exception_type, stop_before_delay, wait_fixed
|
from tenacity import retry, retry_if_exception_type, stop_before_delay, wait_fixed
|
||||||
@ -17,7 +17,6 @@ class BillingService:
|
|||||||
params = {"tenant_id": tenant_id}
|
params = {"tenant_id": tenant_id}
|
||||||
|
|
||||||
billing_info = cls._send_request("GET", "/subscription/info", params=params)
|
billing_info = cls._send_request("GET", "/subscription/info", params=params)
|
||||||
|
|
||||||
return billing_info
|
return billing_info
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@ -47,12 +46,13 @@ class BillingService:
|
|||||||
retry=retry_if_exception_type(httpx.RequestError),
|
retry=retry_if_exception_type(httpx.RequestError),
|
||||||
reraise=True,
|
reraise=True,
|
||||||
)
|
)
|
||||||
def _send_request(cls, method, endpoint, json=None, params=None):
|
def _send_request(cls, method: Literal["GET", "POST", "DELETE"], endpoint: str, json=None, params=None):
|
||||||
headers = {"Content-Type": "application/json", "Billing-Api-Secret-Key": cls.secret_key}
|
headers = {"Content-Type": "application/json", "Billing-Api-Secret-Key": cls.secret_key}
|
||||||
|
|
||||||
url = f"{cls.base_url}{endpoint}"
|
url = f"{cls.base_url}{endpoint}"
|
||||||
response = httpx.request(method, url, json=json, params=params, headers=headers)
|
response = httpx.request(method, url, json=json, params=params, headers=headers)
|
||||||
|
if method == "GET" and response.status_code != httpx.codes.OK:
|
||||||
|
raise ValueError("Unable to retrieve billing information. Please try again later or contact support.")
|
||||||
return response.json()
|
return response.json()
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
|||||||
@ -86,7 +86,7 @@ class DatasetService:
|
|||||||
else:
|
else:
|
||||||
return [], 0
|
return [], 0
|
||||||
else:
|
else:
|
||||||
if user.current_role not in (TenantAccountRole.OWNER, TenantAccountRole.ADMIN):
|
if user.current_role != TenantAccountRole.OWNER:
|
||||||
# show all datasets that the user has permission to access
|
# show all datasets that the user has permission to access
|
||||||
if permitted_dataset_ids:
|
if permitted_dataset_ids:
|
||||||
query = query.filter(
|
query = query.filter(
|
||||||
@ -382,7 +382,7 @@ class DatasetService:
|
|||||||
if dataset.tenant_id != user.current_tenant_id:
|
if dataset.tenant_id != user.current_tenant_id:
|
||||||
logging.debug(f"User {user.id} does not have permission to access dataset {dataset.id}")
|
logging.debug(f"User {user.id} does not have permission to access dataset {dataset.id}")
|
||||||
raise NoPermissionError("You do not have permission to access this dataset.")
|
raise NoPermissionError("You do not have permission to access this dataset.")
|
||||||
if user.current_role not in (TenantAccountRole.OWNER, TenantAccountRole.ADMIN):
|
if user.current_role != TenantAccountRole.OWNER:
|
||||||
if dataset.permission == DatasetPermissionEnum.ONLY_ME and dataset.created_by != user.id:
|
if dataset.permission == DatasetPermissionEnum.ONLY_ME and dataset.created_by != user.id:
|
||||||
logging.debug(f"User {user.id} does not have permission to access dataset {dataset.id}")
|
logging.debug(f"User {user.id} does not have permission to access dataset {dataset.id}")
|
||||||
raise NoPermissionError("You do not have permission to access this dataset.")
|
raise NoPermissionError("You do not have permission to access this dataset.")
|
||||||
@ -404,7 +404,7 @@ class DatasetService:
|
|||||||
if not user:
|
if not user:
|
||||||
raise ValueError("User not found")
|
raise ValueError("User not found")
|
||||||
|
|
||||||
if user.current_role not in (TenantAccountRole.OWNER, TenantAccountRole.ADMIN):
|
if user.current_role != TenantAccountRole.OWNER:
|
||||||
if dataset.permission == DatasetPermissionEnum.ONLY_ME:
|
if dataset.permission == DatasetPermissionEnum.ONLY_ME:
|
||||||
if dataset.created_by != user.id:
|
if dataset.created_by != user.id:
|
||||||
raise NoPermissionError("You do not have permission to access this dataset.")
|
raise NoPermissionError("You do not have permission to access this dataset.")
|
||||||
@ -434,6 +434,12 @@ class DatasetService:
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_dataset_auto_disable_logs(dataset_id: str) -> dict:
|
def get_dataset_auto_disable_logs(dataset_id: str) -> dict:
|
||||||
|
features = FeatureService.get_features(current_user.current_tenant_id)
|
||||||
|
if not features.billing.enabled or features.billing.subscription.plan == "sandbox":
|
||||||
|
return {
|
||||||
|
"document_ids": [],
|
||||||
|
"count": 0,
|
||||||
|
}
|
||||||
# get recent 30 days auto disable logs
|
# get recent 30 days auto disable logs
|
||||||
start_date = datetime.datetime.now() - datetime.timedelta(days=30)
|
start_date = datetime.datetime.now() - datetime.timedelta(days=30)
|
||||||
dataset_auto_disable_logs = DatasetAutoDisableLog.query.filter(
|
dataset_auto_disable_logs = DatasetAutoDisableLog.query.filter(
|
||||||
@ -786,13 +792,19 @@ class DocumentService:
|
|||||||
dataset.indexing_technique = knowledge_config.indexing_technique
|
dataset.indexing_technique = knowledge_config.indexing_technique
|
||||||
if knowledge_config.indexing_technique == "high_quality":
|
if knowledge_config.indexing_technique == "high_quality":
|
||||||
model_manager = ModelManager()
|
model_manager = ModelManager()
|
||||||
embedding_model = model_manager.get_default_model_instance(
|
if knowledge_config.embedding_model and knowledge_config.embedding_model_provider:
|
||||||
tenant_id=current_user.current_tenant_id, model_type=ModelType.TEXT_EMBEDDING
|
dataset_embedding_model = knowledge_config.embedding_model
|
||||||
)
|
dataset_embedding_model_provider = knowledge_config.embedding_model_provider
|
||||||
dataset.embedding_model = embedding_model.model
|
else:
|
||||||
dataset.embedding_model_provider = embedding_model.provider
|
embedding_model = model_manager.get_default_model_instance(
|
||||||
|
tenant_id=current_user.current_tenant_id, model_type=ModelType.TEXT_EMBEDDING
|
||||||
|
)
|
||||||
|
dataset_embedding_model = embedding_model.model
|
||||||
|
dataset_embedding_model_provider = embedding_model.provider
|
||||||
|
dataset.embedding_model = dataset_embedding_model
|
||||||
|
dataset.embedding_model_provider = dataset_embedding_model_provider
|
||||||
dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding(
|
dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding(
|
||||||
embedding_model.provider, embedding_model.model
|
dataset_embedding_model_provider, dataset_embedding_model
|
||||||
)
|
)
|
||||||
dataset.collection_binding_id = dataset_collection_binding.id
|
dataset.collection_binding_id = dataset_collection_binding.id
|
||||||
if not dataset.retrieval_model:
|
if not dataset.retrieval_model:
|
||||||
@ -804,7 +816,11 @@ class DocumentService:
|
|||||||
"score_threshold_enabled": False,
|
"score_threshold_enabled": False,
|
||||||
}
|
}
|
||||||
|
|
||||||
dataset.retrieval_model = knowledge_config.retrieval_model.model_dump() or default_retrieval_model # type: ignore
|
dataset.retrieval_model = (
|
||||||
|
knowledge_config.retrieval_model.model_dump()
|
||||||
|
if knowledge_config.retrieval_model
|
||||||
|
else default_retrieval_model
|
||||||
|
) # type: ignore
|
||||||
|
|
||||||
documents = []
|
documents = []
|
||||||
if knowledge_config.original_document_id:
|
if knowledge_config.original_document_id:
|
||||||
|
|||||||
@ -27,7 +27,7 @@ class WorkflowAppService:
|
|||||||
query = query.join(WorkflowRun, WorkflowRun.id == WorkflowAppLog.workflow_run_id)
|
query = query.join(WorkflowRun, WorkflowRun.id == WorkflowAppLog.workflow_run_id)
|
||||||
|
|
||||||
if keyword:
|
if keyword:
|
||||||
keyword_like_val = f"%{args['keyword'][:30]}%"
|
keyword_like_val = f"%{keyword[:30].encode('unicode_escape').decode('utf-8')}%".replace(r"\u", r"\\u")
|
||||||
keyword_conditions = [
|
keyword_conditions = [
|
||||||
WorkflowRun.inputs.ilike(keyword_like_val),
|
WorkflowRun.inputs.ilike(keyword_like_val),
|
||||||
WorkflowRun.outputs.ilike(keyword_like_val),
|
WorkflowRun.outputs.ilike(keyword_like_val),
|
||||||
|
|||||||
@ -298,7 +298,7 @@ class WorkflowService:
|
|||||||
start_at: float,
|
start_at: float,
|
||||||
tenant_id: str,
|
tenant_id: str,
|
||||||
node_id: str,
|
node_id: str,
|
||||||
):
|
) -> WorkflowNodeExecution:
|
||||||
"""
|
"""
|
||||||
Handle node run result
|
Handle node run result
|
||||||
|
|
||||||
|
|||||||
@ -28,7 +28,7 @@ def deal_dataset_vector_index_task(dataset_id: str, action: str):
|
|||||||
|
|
||||||
if not dataset:
|
if not dataset:
|
||||||
raise Exception("Dataset not found")
|
raise Exception("Dataset not found")
|
||||||
index_type = dataset.doc_form
|
index_type = dataset.doc_form or IndexType.PARAGRAPH_INDEX
|
||||||
index_processor = IndexProcessorFactory(index_type).init_index_processor()
|
index_processor = IndexProcessorFactory(index_type).init_index_processor()
|
||||||
if action == "remove":
|
if action == "remove":
|
||||||
index_processor.clean(dataset, None, with_keywords=False)
|
index_processor.clean(dataset, None, with_keywords=False)
|
||||||
@ -157,6 +157,9 @@ def deal_dataset_vector_index_task(dataset_id: str, action: str):
|
|||||||
{"indexing_status": "error", "error": str(e)}, synchronize_session=False
|
{"indexing_status": "error", "error": str(e)}, synchronize_session=False
|
||||||
)
|
)
|
||||||
db.session.commit()
|
db.session.commit()
|
||||||
|
else:
|
||||||
|
# clean collection
|
||||||
|
index_processor.clean(dataset, None, with_keywords=False, delete_child_chunks=False)
|
||||||
|
|
||||||
end_at = time.perf_counter()
|
end_at = time.perf_counter()
|
||||||
logging.info(
|
logging.info(
|
||||||
|
|||||||
@ -0,0 +1,55 @@
|
|||||||
|
import os
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||||
|
from core.model_runtime.model_providers.gpustack.speech2text.speech2text import GPUStackSpeech2TextModel
|
||||||
|
|
||||||
|
|
||||||
|
def test_validate_credentials():
|
||||||
|
model = GPUStackSpeech2TextModel()
|
||||||
|
|
||||||
|
with pytest.raises(CredentialsValidateFailedError):
|
||||||
|
model.validate_credentials(
|
||||||
|
model="faster-whisper-medium",
|
||||||
|
credentials={
|
||||||
|
"endpoint_url": "invalid_url",
|
||||||
|
"api_key": "invalid_api_key",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
model.validate_credentials(
|
||||||
|
model="faster-whisper-medium",
|
||||||
|
credentials={
|
||||||
|
"endpoint_url": os.environ.get("GPUSTACK_SERVER_URL"),
|
||||||
|
"api_key": os.environ.get("GPUSTACK_API_KEY"),
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_invoke_model():
|
||||||
|
model = GPUStackSpeech2TextModel()
|
||||||
|
|
||||||
|
# Get the directory of the current file
|
||||||
|
current_dir = os.path.dirname(os.path.abspath(__file__))
|
||||||
|
|
||||||
|
# Get assets directory
|
||||||
|
assets_dir = os.path.join(os.path.dirname(current_dir), "assets")
|
||||||
|
|
||||||
|
# Construct the path to the audio file
|
||||||
|
audio_file_path = os.path.join(assets_dir, "audio.mp3")
|
||||||
|
|
||||||
|
file = Path(audio_file_path).read_bytes()
|
||||||
|
|
||||||
|
result = model.invoke(
|
||||||
|
model="faster-whisper-medium",
|
||||||
|
credentials={
|
||||||
|
"endpoint_url": os.environ.get("GPUSTACK_SERVER_URL"),
|
||||||
|
"api_key": os.environ.get("GPUSTACK_API_KEY"),
|
||||||
|
},
|
||||||
|
file=file,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert isinstance(result, str)
|
||||||
|
assert result == "1, 2, 3, 4, 5, 6, 7, 8, 9, 10"
|
||||||
@ -0,0 +1,24 @@
|
|||||||
|
import os
|
||||||
|
|
||||||
|
from core.model_runtime.model_providers.gpustack.tts.tts import GPUStackText2SpeechModel
|
||||||
|
|
||||||
|
|
||||||
|
def test_invoke_model():
|
||||||
|
model = GPUStackText2SpeechModel()
|
||||||
|
|
||||||
|
result = model.invoke(
|
||||||
|
model="cosyvoice-300m-sft",
|
||||||
|
tenant_id="test",
|
||||||
|
credentials={
|
||||||
|
"endpoint_url": os.environ.get("GPUSTACK_SERVER_URL"),
|
||||||
|
"api_key": os.environ.get("GPUSTACK_API_KEY"),
|
||||||
|
},
|
||||||
|
content_text="Hello world",
|
||||||
|
voice="Chinese Female",
|
||||||
|
)
|
||||||
|
|
||||||
|
content = b""
|
||||||
|
for chunk in result:
|
||||||
|
content += chunk
|
||||||
|
|
||||||
|
assert content != b""
|
||||||
@ -19,9 +19,9 @@ class MilvusVectorTest(AbstractVectorTest):
|
|||||||
)
|
)
|
||||||
|
|
||||||
def search_by_full_text(self):
|
def search_by_full_text(self):
|
||||||
# milvus dos not support full text searching yet in < 2.3.x
|
# milvus support BM25 full text search after version 2.5.0-beta
|
||||||
hits_by_full_text = self.vector.search_by_full_text(query=get_example_text())
|
hits_by_full_text = self.vector.search_by_full_text(query=get_example_text())
|
||||||
assert len(hits_by_full_text) == 0
|
assert len(hits_by_full_text) >= 0
|
||||||
|
|
||||||
def get_ids_by_metadata_field(self):
|
def get_ids_by_metadata_field(self):
|
||||||
ids = self.vector.get_ids_by_metadata_field(key="document_id", value=self.example_doc_id)
|
ids = self.vector.get_ids_by_metadata_field(key="document_id", value=self.example_doc_id)
|
||||||
|
|||||||
@ -2,7 +2,7 @@ version: '3'
|
|||||||
services:
|
services:
|
||||||
# API service
|
# API service
|
||||||
api:
|
api:
|
||||||
image: langgenius/dify-api:0.14.2
|
image: langgenius/dify-api:0.15.0
|
||||||
restart: always
|
restart: always
|
||||||
environment:
|
environment:
|
||||||
# Startup mode, 'api' starts the API server.
|
# Startup mode, 'api' starts the API server.
|
||||||
@ -227,7 +227,7 @@ services:
|
|||||||
# worker service
|
# worker service
|
||||||
# The Celery worker for processing the queue.
|
# The Celery worker for processing the queue.
|
||||||
worker:
|
worker:
|
||||||
image: langgenius/dify-api:0.14.2
|
image: langgenius/dify-api:0.15.0
|
||||||
restart: always
|
restart: always
|
||||||
environment:
|
environment:
|
||||||
CONSOLE_WEB_URL: ''
|
CONSOLE_WEB_URL: ''
|
||||||
@ -397,7 +397,7 @@ services:
|
|||||||
|
|
||||||
# Frontend web application.
|
# Frontend web application.
|
||||||
web:
|
web:
|
||||||
image: langgenius/dify-web:0.14.2
|
image: langgenius/dify-web:0.15.0
|
||||||
restart: always
|
restart: always
|
||||||
environment:
|
environment:
|
||||||
# The base URL of console application api server, refers to the Console base URL of WEB service if console domain is
|
# The base URL of console application api server, refers to the Console base URL of WEB service if console domain is
|
||||||
|
|||||||
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue
Block a user