Merge branch 'feat/queue-based-graph-engine' into feat/rag-2

# Conflicts:
#	api/core/memory/token_buffer_memory.py
#	api/core/rag/extractor/notion_extractor.py
#	api/core/repositories/sqlalchemy_workflow_node_execution_repository.py
#	api/core/variables/variables.py
#	api/core/workflow/graph/graph.py
#	api/core/workflow/graph_engine/entities/event.py
#	api/services/dataset_service.py
#	web/app/components/app-sidebar/index.tsx
#	web/app/components/base/tag-management/selector.tsx
#	web/app/components/base/toast/index.tsx
#	web/app/components/datasets/create/website/index.tsx
#	web/app/components/datasets/create/website/jina-reader/base/options-wrap.tsx
#	web/app/components/workflow/header/version-history-button.tsx
#	web/app/components/workflow/hooks/use-inspect-vars-crud-common.ts
#	web/app/components/workflow/hooks/use-workflow-interactions.ts
#	web/app/components/workflow/panel/version-history-panel/index.tsx
#	web/service/base.ts
This commit is contained in:
jyong 2025-09-03 15:01:06 +08:00
commit d4aed3df5c
572 changed files with 16030 additions and 7973 deletions

View File

@ -42,11 +42,7 @@ jobs:
- name: Run Unit tests - name: Run Unit tests
run: | run: |
uv run --project api bash dev/pytest/pytest_unit_tests.sh uv run --project api bash dev/pytest/pytest_unit_tests.sh
- name: Run ty check
run: |
cd api
uv add --dev ty
uv run ty check || true
- name: Run pyrefly check - name: Run pyrefly check
run: | run: |
cd api cd api
@ -66,15 +62,6 @@ jobs:
- name: Run dify config tests - name: Run dify config tests
run: uv run --project api dev/pytest/pytest_config_tests.py run: uv run --project api dev/pytest/pytest_config_tests.py
- name: MyPy Cache
uses: actions/cache@v4
with:
path: api/.mypy_cache
key: mypy-${{ matrix.python-version }}-${{ runner.os }}-${{ hashFiles('api/uv.lock') }}
- name: Run MyPy Checks
run: dev/mypy-check
- name: Set up dotenvs - name: Set up dotenvs
run: | run: |
cp docker/.env.example docker/.env cp docker/.env.example docker/.env

View File

@ -26,6 +26,7 @@ jobs:
- name: ast-grep - name: ast-grep
run: | run: |
uvx --from ast-grep-cli sg --pattern 'db.session.query($WHATEVER).filter($HERE)' --rewrite 'db.session.query($WHATEVER).where($HERE)' -l py --update-all uvx --from ast-grep-cli sg --pattern 'db.session.query($WHATEVER).filter($HERE)' --rewrite 'db.session.query($WHATEVER).where($HERE)' -l py --update-all
uvx --from ast-grep-cli sg --pattern 'session.query($WHATEVER).filter($HERE)' --rewrite 'session.query($WHATEVER).where($HERE)' -l py --update-all
- name: mdformat - name: mdformat
run: | run: |
uvx mdformat . uvx mdformat .

View File

@ -12,7 +12,6 @@ permissions:
statuses: write statuses: write
contents: read contents: read
jobs: jobs:
python-style: python-style:
name: Python Style name: Python Style
@ -44,21 +43,14 @@ jobs:
if: steps.changed-files.outputs.any_changed == 'true' if: steps.changed-files.outputs.any_changed == 'true'
run: uv sync --project api --dev run: uv sync --project api --dev
- name: Ruff check - name: Run Basedpyright Checks
if: steps.changed-files.outputs.any_changed == 'true' if: steps.changed-files.outputs.any_changed == 'true'
run: | run: dev/basedpyright-check
uv run --directory api ruff --version
uv run --directory api ruff check ./
uv run --directory api ruff format --check ./
- name: Dotenv check - name: Dotenv check
if: steps.changed-files.outputs.any_changed == 'true' if: steps.changed-files.outputs.any_changed == 'true'
run: uv run --project api dotenv-linter ./api/.env.example ./web/.env.example run: uv run --project api dotenv-linter ./api/.env.example ./web/.env.example
- name: Lint hints
if: failure()
run: echo "Please run 'dev/reformat' to fix the fixable linting errors."
web-style: web-style:
name: Web Style name: Web Style
runs-on: ubuntu-latest runs-on: ubuntu-latest
@ -100,7 +92,9 @@ jobs:
- name: Web style check - name: Web style check
if: steps.changed-files.outputs.any_changed == 'true' if: steps.changed-files.outputs.any_changed == 'true'
working-directory: ./web working-directory: ./web
run: pnpm run lint run: |
pnpm run lint
pnpm run eslint
docker-compose-template: docker-compose-template:
name: Docker Compose Template name: Docker Compose Template

5
.gitignore vendored
View File

@ -123,10 +123,12 @@ venv.bak/
# mkdocs documentation # mkdocs documentation
/site /site
# mypy # type checking
.mypy_cache/ .mypy_cache/
.dmypy.json .dmypy.json
dmypy.json dmypy.json
pyrightconfig.json
!api/pyrightconfig.json
# Pyre type checker # Pyre type checker
.pyre/ .pyre/
@ -195,7 +197,6 @@ sdks/python-client/dify_client.egg-info
.vscode/* .vscode/*
!.vscode/launch.json.template !.vscode/launch.json.template
!.vscode/README.md !.vscode/README.md
pyrightconfig.json
api/.vscode api/.vscode
# vscode Code History Extension # vscode Code History Extension
.history .history

View File

@ -32,7 +32,7 @@ uv run --project api pytest tests/integration_tests/ # Integration tests
./dev/reformat # Run all formatters and linters ./dev/reformat # Run all formatters and linters
uv run --project api ruff check --fix ./ # Fix linting issues uv run --project api ruff check --fix ./ # Fix linting issues
uv run --project api ruff format ./ # Format code uv run --project api ruff format ./ # Format code
uv run --project api mypy . # Type checking uv run --directory api basedpyright # Type checking
``` ```
### Frontend (Web) ### Frontend (Web)

View File

@ -4,6 +4,48 @@ WEB_IMAGE=$(DOCKER_REGISTRY)/dify-web
API_IMAGE=$(DOCKER_REGISTRY)/dify-api API_IMAGE=$(DOCKER_REGISTRY)/dify-api
VERSION=latest VERSION=latest
# Backend Development Environment Setup
.PHONY: dev-setup prepare-docker prepare-web prepare-api
# Default dev setup target
dev-setup: prepare-docker prepare-web prepare-api
@echo "✅ Backend development environment setup complete!"
# Step 1: Prepare Docker middleware
prepare-docker:
@echo "🐳 Setting up Docker middleware..."
@cp -n docker/middleware.env.example docker/middleware.env 2>/dev/null || echo "Docker middleware.env already exists"
@cd docker && docker compose -f docker-compose.middleware.yaml --env-file middleware.env -p dify-middlewares-dev up -d
@echo "✅ Docker middleware started"
# Step 2: Prepare web environment
prepare-web:
@echo "🌐 Setting up web environment..."
@cp -n web/.env.example web/.env 2>/dev/null || echo "Web .env already exists"
@cd web && pnpm install
@cd web && pnpm build
@echo "✅ Web environment prepared (not started)"
# Step 3: Prepare API environment
prepare-api:
@echo "🔧 Setting up API environment..."
@cp -n api/.env.example api/.env 2>/dev/null || echo "API .env already exists"
@cd api && uv sync --dev --extra all
@cd api && uv run flask db upgrade
@echo "✅ API environment prepared (not started)"
# Clean dev environment
dev-clean:
@echo "⚠️ Stopping Docker containers..."
@cd docker && docker compose -f docker-compose.middleware.yaml --env-file middleware.env -p dify-middlewares-dev down
@echo "🗑️ Removing volumes..."
@rm -rf docker/volumes/db
@rm -rf docker/volumes/redis
@rm -rf docker/volumes/plugin_daemon
@rm -rf docker/volumes/weaviate
@rm -rf api/storage
@echo "✅ Cleanup complete"
# Build Docker images # Build Docker images
build-web: build-web:
@echo "Building web Docker image: $(WEB_IMAGE):$(VERSION)..." @echo "Building web Docker image: $(WEB_IMAGE):$(VERSION)..."
@ -39,5 +81,21 @@ build-push-web: build-web push-web
build-push-all: build-all push-all build-push-all: build-all push-all
@echo "All Docker images have been built and pushed." @echo "All Docker images have been built and pushed."
# Help target
help:
@echo "Development Setup Targets:"
@echo " make dev-setup - Run all setup steps for backend dev environment"
@echo " make prepare-docker - Set up Docker middleware"
@echo " make prepare-web - Set up web environment"
@echo " make prepare-api - Set up API environment"
@echo " make dev-clean - Stop Docker middleware containers"
@echo ""
@echo "Docker Build Targets:"
@echo " make build-web - Build web Docker image"
@echo " make build-api - Build API Docker image"
@echo " make build-all - Build all Docker images"
@echo " make push-all - Push all Docker images"
@echo " make build-push-all - Build and push all Docker images"
# Phony targets # Phony targets
.PHONY: build-web build-api push-web push-api build-all push-all build-push-all .PHONY: build-web build-api push-web push-api build-all push-all build-push-all dev-setup prepare-docker prepare-web prepare-api dev-clean help

View File

@ -34,17 +34,16 @@ ignore_imports =
[importlinter:contract:rsc] [importlinter:contract:rsc]
name = RSC name = RSC
type = layers type = layers
layers = layers =
graph_engine graph_engine
response_coordinator response_coordinator
output_registry
containers = containers =
core.workflow.graph_engine core.workflow.graph_engine
[importlinter:contract:worker] [importlinter:contract:worker]
name = Worker name = Worker
type = layers type = layers
layers = layers =
graph_engine graph_engine
worker worker
containers = containers =
@ -77,26 +76,15 @@ forbidden_modules =
core.workflow.graph_engine.layers core.workflow.graph_engine.layers
core.workflow.graph_engine.protocols core.workflow.graph_engine.protocols
[importlinter:contract:state-management-layers] [importlinter:contract:worker-management]
name = State Management Layers name = Worker Management
type = layers type = forbidden
layers = source_modules =
execution_tracker
node_state_manager
edge_state_manager
containers =
core.workflow.graph_engine.state_management
[importlinter:contract:worker-management-layers]
name = Worker Management Layers
type = layers
layers =
worker_pool
worker_factory
dynamic_scaler
activity_tracker
containers =
core.workflow.graph_engine.worker_management core.workflow.graph_engine.worker_management
forbidden_modules =
core.workflow.graph_engine.orchestration
core.workflow.graph_engine.command_processing
core.workflow.graph_engine.event_management
[importlinter:contract:error-handling-strategies] [importlinter:contract:error-handling-strategies]
name = Error Handling Strategies name = Error Handling Strategies
@ -109,14 +97,16 @@ modules =
[importlinter:contract:graph-traversal-components] [importlinter:contract:graph-traversal-components]
name = Graph Traversal Components name = Graph Traversal Components
type = independence type = layers
modules = layers =
core.workflow.graph_engine.graph_traversal.node_readiness edge_processor
core.workflow.graph_engine.graph_traversal.skip_propagator skip_propagator
containers =
core.workflow.graph_engine.graph_traversal
[importlinter:contract:command-channels] [importlinter:contract:command-channels]
name = Command Channels Independence name = Command Channels Independence
type = independence type = independence
modules = modules =
core.workflow.graph_engine.command_channels.in_memory_channel core.workflow.graph_engine.command_channels.in_memory_channel
core.workflow.graph_engine.command_channels.redis_channel core.workflow.graph_engine.command_channels.redis_channel

View File

@ -108,5 +108,5 @@ uv run celery -A app.celery beat
../dev/reformat # Run all formatters and linters ../dev/reformat # Run all formatters and linters
uv run ruff check --fix ./ # Fix linting issues uv run ruff check --fix ./ # Fix linting issues
uv run ruff format ./ # Format code uv run ruff format ./ # Format code
uv run mypy . # Type checking uv run basedpyright . # Type checking
``` ```

View File

@ -1,11 +0,0 @@
from tests.integration_tests.utils.parent_class import ParentClass
class ChildClass(ParentClass):
"""Test child class for module import helper tests"""
def __init__(self, name):
super().__init__(name)
def get_name(self):
return f"Child: {self.name}"

View File

@ -577,7 +577,7 @@ def old_metadata_migration():
for document in documents: for document in documents:
if document.doc_metadata: if document.doc_metadata:
doc_metadata = document.doc_metadata doc_metadata = document.doc_metadata
for key, value in doc_metadata.items(): for key in doc_metadata:
for field in BuiltInField: for field in BuiltInField:
if field.value == key: if field.value == key:
break break

View File

@ -29,7 +29,7 @@ class NacosSettingsSource(RemoteSettingsSource):
try: try:
content = NacosHttpClient().http_request("/nacos/v1/cs/configs", method="GET", headers={}, params=params) content = NacosHttpClient().http_request("/nacos/v1/cs/configs", method="GET", headers={}, params=params)
self.remote_configs = self._parse_config(content) self.remote_configs = self._parse_config(content)
except Exception as e: except Exception:
logger.exception("[get-access-token] exception occurred") logger.exception("[get-access-token] exception occurred")
raise raise

View File

@ -27,7 +27,7 @@ class NacosHttpClient:
response = requests.request(method, url="http://" + self.server + url, headers=headers, params=params) response = requests.request(method, url="http://" + self.server + url, headers=headers, params=params)
response.raise_for_status() response.raise_for_status()
return response.text return response.text
except requests.exceptions.RequestException as e: except requests.RequestException as e:
return f"Request to Nacos failed: {e}" return f"Request to Nacos failed: {e}"
def _inject_auth_info(self, headers, params, module="config"): def _inject_auth_info(self, headers, params, module="config"):
@ -77,6 +77,6 @@ class NacosHttpClient:
self.token = response_data.get("accessToken") self.token = response_data.get("accessToken")
self.token_ttl = response_data.get("tokenTtl", 18000) self.token_ttl = response_data.get("tokenTtl", 18000)
self.token_expire_time = current_time + self.token_ttl - 10 self.token_expire_time = current_time + self.token_ttl - 10
except Exception as e: except Exception:
logger.exception("[get-access-token] exception occur") logger.exception("[get-access-token] exception occur")
raise raise

View File

@ -19,6 +19,7 @@ language_timezone_mapping = {
"fa-IR": "Asia/Tehran", "fa-IR": "Asia/Tehran",
"sl-SI": "Europe/Ljubljana", "sl-SI": "Europe/Ljubljana",
"th-TH": "Asia/Bangkok", "th-TH": "Asia/Bangkok",
"id-ID": "Asia/Jakarta",
} }
languages = list(language_timezone_mapping.keys()) languages = list(language_timezone_mapping.keys())

View File

@ -130,15 +130,19 @@ class InsertExploreAppApi(Resource):
app.is_public = False app.is_public = False
with Session(db.engine) as session: with Session(db.engine) as session:
installed_apps = session.execute( installed_apps = (
select(InstalledApp).where( session.execute(
InstalledApp.app_id == recommended_app.app_id, select(InstalledApp).where(
InstalledApp.tenant_id != InstalledApp.app_owner_tenant_id, InstalledApp.app_id == recommended_app.app_id,
InstalledApp.tenant_id != InstalledApp.app_owner_tenant_id,
)
) )
).all() .scalars()
.all()
)
for installed_app in installed_apps: for installed_app in installed_apps:
db.session.delete(installed_app) session.delete(installed_app)
db.session.delete(recommended_app) db.session.delete(recommended_app)
db.session.commit() db.session.commit()

View File

@ -84,7 +84,7 @@ class BaseApiKeyListResource(Resource):
flask_restx.abort( flask_restx.abort(
400, 400,
message=f"Cannot create more than {self.max_keys} API keys for this resource type.", message=f"Cannot create more than {self.max_keys} API keys for this resource type.",
code="max_keys_exceeded", custom="max_keys_exceeded",
) )
key = ApiToken.generate_api_key(self.token_prefix, 24) key = ApiToken.generate_api_key(self.token_prefix, 24)

View File

@ -237,9 +237,14 @@ class AppExportApi(Resource):
# Add include_secret params # Add include_secret params
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument("include_secret", type=inputs.boolean, default=False, location="args") parser.add_argument("include_secret", type=inputs.boolean, default=False, location="args")
parser.add_argument("workflow_id", type=str, location="args")
args = parser.parse_args() args = parser.parse_args()
return {"data": AppDslService.export_dsl(app_model=app_model, include_secret=args["include_secret"])} return {
"data": AppDslService.export_dsl(
app_model=app_model, include_secret=args["include_secret"], workflow_id=args.get("workflow_id")
)
}
class AppNameApi(Resource): class AppNameApi(Resource):

View File

@ -130,7 +130,7 @@ class MessageFeedbackApi(Resource):
message_id = str(args["message_id"]) message_id = str(args["message_id"])
message = db.session.query(Message).filter(Message.id == message_id, Message.app_id == app_model.id).first() message = db.session.query(Message).where(Message.id == message_id, Message.app_id == app_model.id).first()
if not message: if not message:
raise NotFound("Message Not Exists.") raise NotFound("Message Not Exists.")

View File

@ -532,7 +532,7 @@ class PublishedWorkflowApi(Resource):
) )
app_model.workflow_id = workflow.id app_model.workflow_id = workflow.id
db.session.commit() db.session.commit() # NOTE: this is necessary for update app_model.workflow_id
workflow_created_at = TimestampField().format(workflow.created_at) workflow_created_at = TimestampField().format(workflow.created_at)

View File

@ -27,7 +27,9 @@ class WorkflowAppLogApi(Resource):
""" """
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument("keyword", type=str, location="args") parser.add_argument("keyword", type=str, location="args")
parser.add_argument("status", type=str, choices=["succeeded", "failed", "stopped"], location="args") parser.add_argument(
"status", type=str, choices=["succeeded", "failed", "stopped", "partial-succeeded"], location="args"
)
parser.add_argument( parser.add_argument(
"created_at__before", type=str, location="args", help="Filter logs created before this timestamp" "created_at__before", type=str, location="args", help="Filter logs created before this timestamp"
) )

View File

@ -81,7 +81,7 @@ class OAuthDataSourceBinding(Resource):
return {"error": "Invalid code"}, 400 return {"error": "Invalid code"}, 400
try: try:
oauth_provider.get_access_token(code) oauth_provider.get_access_token(code)
except requests.exceptions.HTTPError as e: except requests.HTTPError as e:
logger.exception( logger.exception(
"An error occurred during the OAuthCallback process with %s: %s", provider, e.response.text "An error occurred during the OAuthCallback process with %s: %s", provider, e.response.text
) )
@ -104,7 +104,7 @@ class OAuthDataSourceSync(Resource):
return {"error": "Invalid provider"}, 400 return {"error": "Invalid provider"}, 400
try: try:
oauth_provider.sync_data_source(binding_id) oauth_provider.sync_data_source(binding_id)
except requests.exceptions.HTTPError as e: except requests.HTTPError as e:
logger.exception( logger.exception(
"An error occurred during the OAuthCallback process with %s: %s", provider, e.response.text "An error occurred during the OAuthCallback process with %s: %s", provider, e.response.text
) )

View File

@ -130,7 +130,7 @@ class ResetPasswordSendEmailApi(Resource):
language = "en-US" language = "en-US"
try: try:
account = AccountService.get_user_through_email(args["email"]) account = AccountService.get_user_through_email(args["email"])
except AccountRegisterError as are: except AccountRegisterError:
raise AccountInFreezeError() raise AccountInFreezeError()
if account is None: if account is None:
@ -162,7 +162,7 @@ class EmailCodeLoginSendEmailApi(Resource):
language = "en-US" language = "en-US"
try: try:
account = AccountService.get_user_through_email(args["email"]) account = AccountService.get_user_through_email(args["email"])
except AccountRegisterError as are: except AccountRegisterError:
raise AccountInFreezeError() raise AccountInFreezeError()
if account is None: if account is None:
@ -200,7 +200,7 @@ class EmailCodeLoginApi(Resource):
AccountService.revoke_email_code_login_token(args["token"]) AccountService.revoke_email_code_login_token(args["token"])
try: try:
account = AccountService.get_user_through_email(user_email) account = AccountService.get_user_through_email(user_email)
except AccountRegisterError as are: except AccountRegisterError:
raise AccountInFreezeError() raise AccountInFreezeError()
if account: if account:
tenants = TenantService.get_join_tenants(account) tenants = TenantService.get_join_tenants(account)
@ -223,7 +223,7 @@ class EmailCodeLoginApi(Resource):
) )
except WorkSpaceNotAllowedCreateError: except WorkSpaceNotAllowedCreateError:
raise NotAllowedCreateWorkspace() raise NotAllowedCreateWorkspace()
except AccountRegisterError as are: except AccountRegisterError:
raise AccountInFreezeError() raise AccountInFreezeError()
except WorkspacesLimitExceededError: except WorkspacesLimitExceededError:
raise WorkspacesLimitExceeded() raise WorkspacesLimitExceeded()

View File

@ -80,7 +80,7 @@ class OAuthCallback(Resource):
try: try:
token = oauth_provider.get_access_token(code) token = oauth_provider.get_access_token(code)
user_info = oauth_provider.get_user_info(token) user_info = oauth_provider.get_user_info(token)
except requests.exceptions.RequestException as e: except requests.RequestException as e:
error_text = e.response.text if e.response else str(e) error_text = e.response.text if e.response else str(e)
logger.exception("An error occurred during the OAuth process with %s: %s", provider, error_text) logger.exception("An error occurred during the OAuth process with %s: %s", provider, error_text)
return {"error": "OAuth process failed"}, 400 return {"error": "OAuth process failed"}, 400

View File

@ -44,22 +44,19 @@ def oauth_server_access_token_required(view):
if not oauth_provider_app or not isinstance(oauth_provider_app, OAuthProviderApp): if not oauth_provider_app or not isinstance(oauth_provider_app, OAuthProviderApp):
raise BadRequest("Invalid oauth_provider_app") raise BadRequest("Invalid oauth_provider_app")
if not request.headers.get("Authorization"):
raise BadRequest("Authorization is required")
authorization_header = request.headers.get("Authorization") authorization_header = request.headers.get("Authorization")
if not authorization_header: if not authorization_header:
raise BadRequest("Authorization header is required") raise BadRequest("Authorization header is required")
parts = authorization_header.split(" ") parts = authorization_header.strip().split(" ")
if len(parts) != 2: if len(parts) != 2:
raise BadRequest("Invalid Authorization header format") raise BadRequest("Invalid Authorization header format")
token_type = parts[0] token_type = parts[0].strip()
if token_type != "Bearer": if token_type.lower() != "bearer":
raise BadRequest("token_type is invalid") raise BadRequest("token_type is invalid")
access_token = parts[1] access_token = parts[1].strip()
if not access_token: if not access_token:
raise BadRequest("access_token is required") raise BadRequest("access_token is required")
@ -125,7 +122,10 @@ class OAuthServerUserTokenApi(Resource):
parser.add_argument("refresh_token", type=str, required=False, location="json") parser.add_argument("refresh_token", type=str, required=False, location="json")
parsed_args = parser.parse_args() parsed_args = parser.parse_args()
grant_type = OAuthGrantType(parsed_args["grant_type"]) try:
grant_type = OAuthGrantType(parsed_args["grant_type"])
except ValueError:
raise BadRequest("invalid grant_type")
if grant_type == OAuthGrantType.AUTHORIZATION_CODE: if grant_type == OAuthGrantType.AUTHORIZATION_CODE:
if not parsed_args["code"]: if not parsed_args["code"]:
@ -163,8 +163,6 @@ class OAuthServerUserTokenApi(Resource):
"refresh_token": refresh_token, "refresh_token": refresh_token,
} }
) )
else:
raise BadRequest("invalid grant_type")
class OAuthServerUserAccountApi(Resource): class OAuthServerUserAccountApi(Resource):

View File

@ -14,6 +14,7 @@ from controllers.console.wraps import account_initialization_required, setup_req
from core.datasource.entities.datasource_entities import DatasourceProviderType, OnlineDocumentPagesMessage from core.datasource.entities.datasource_entities import DatasourceProviderType, OnlineDocumentPagesMessage
from core.datasource.online_document.online_document_plugin import OnlineDocumentDatasourcePlugin from core.datasource.online_document.online_document_plugin import OnlineDocumentDatasourcePlugin
from core.indexing_runner import IndexingRunner from core.indexing_runner import IndexingRunner
from core.rag.extractor.entity.datasource_type import DatasourceType
from core.rag.extractor.entity.extract_setting import ExtractSetting from core.rag.extractor.entity.extract_setting import ExtractSetting
from core.rag.extractor.notion_extractor import NotionExtractor from core.rag.extractor.notion_extractor import NotionExtractor
from extensions.ext_database import db from extensions.ext_database import db
@ -248,7 +249,7 @@ class DataSourceNotionApi(Resource):
credential_id = notion_info.get("credential_id") credential_id = notion_info.get("credential_id")
for page in notion_info["pages"]: for page in notion_info["pages"]:
extract_setting = ExtractSetting( extract_setting = ExtractSetting(
datasource_type="notion_import", datasource_type=DatasourceType.NOTION.value,
notion_info={ notion_info={
"credential_id": credential_id, "credential_id": credential_id,
"notion_workspace_id": workspace_id, "notion_workspace_id": workspace_id,

View File

@ -21,6 +21,7 @@ from core.indexing_runner import IndexingRunner
from core.model_runtime.entities.model_entities import ModelType from core.model_runtime.entities.model_entities import ModelType
from core.provider_manager import ProviderManager from core.provider_manager import ProviderManager
from core.rag.datasource.vdb.vector_type import VectorType from core.rag.datasource.vdb.vector_type import VectorType
from core.rag.extractor.entity.datasource_type import DatasourceType
from core.rag.extractor.entity.extract_setting import ExtractSetting from core.rag.extractor.entity.extract_setting import ExtractSetting
from core.rag.retrieval.retrieval_methods import RetrievalMethod from core.rag.retrieval.retrieval_methods import RetrievalMethod
from extensions.ext_database import db from extensions.ext_database import db
@ -431,7 +432,9 @@ class DatasetIndexingEstimateApi(Resource):
if file_details: if file_details:
for file_detail in file_details: for file_detail in file_details:
extract_setting = ExtractSetting( extract_setting = ExtractSetting(
datasource_type="upload_file", upload_file=file_detail, document_model=args["doc_form"] datasource_type=DatasourceType.FILE.value,
upload_file=file_detail,
document_model=args["doc_form"],
) )
extract_settings.append(extract_setting) extract_settings.append(extract_setting)
elif args["info_list"]["data_source_type"] == "notion_import": elif args["info_list"]["data_source_type"] == "notion_import":
@ -441,7 +444,7 @@ class DatasetIndexingEstimateApi(Resource):
credential_id = notion_info.get("credential_id") credential_id = notion_info.get("credential_id")
for page in notion_info["pages"]: for page in notion_info["pages"]:
extract_setting = ExtractSetting( extract_setting = ExtractSetting(
datasource_type="notion_import", datasource_type=DatasourceType.NOTION.value,
notion_info={ notion_info={
"credential_id": credential_id, "credential_id": credential_id,
"notion_workspace_id": workspace_id, "notion_workspace_id": workspace_id,
@ -456,7 +459,7 @@ class DatasetIndexingEstimateApi(Resource):
website_info_list = args["info_list"]["website_info_list"] website_info_list = args["info_list"]["website_info_list"]
for url in website_info_list["urls"]: for url in website_info_list["urls"]:
extract_setting = ExtractSetting( extract_setting = ExtractSetting(
datasource_type="website_crawl", datasource_type=DatasourceType.WEBSITE.value,
website_info={ website_info={
"provider": website_info_list["provider"], "provider": website_info_list["provider"],
"job_id": website_info_list["job_id"], "job_id": website_info_list["job_id"],

View File

@ -41,6 +41,7 @@ from core.model_manager import ModelManager
from core.model_runtime.entities.model_entities import ModelType from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.errors.invoke import InvokeAuthorizationError from core.model_runtime.errors.invoke import InvokeAuthorizationError
from core.plugin.impl.exc import PluginDaemonClientSideError from core.plugin.impl.exc import PluginDaemonClientSideError
from core.rag.extractor.entity.datasource_type import DatasourceType
from core.rag.extractor.entity.extract_setting import ExtractSetting from core.rag.extractor.entity.extract_setting import ExtractSetting
from extensions.ext_database import db from extensions.ext_database import db
from fields.document_fields import ( from fields.document_fields import (
@ -356,9 +357,6 @@ class DatasetInitApi(Resource):
parser.add_argument("embedding_model_provider", type=str, required=False, nullable=True, location="json") parser.add_argument("embedding_model_provider", type=str, required=False, nullable=True, location="json")
args = parser.parse_args() args = parser.parse_args()
# The role of the current user in the ta table must be admin, owner, or editor, or dataset_operator
if not current_user.is_dataset_editor:
raise Forbidden()
knowledge_config = KnowledgeConfig(**args) knowledge_config = KnowledgeConfig(**args)
if knowledge_config.indexing_technique == "high_quality": if knowledge_config.indexing_technique == "high_quality":
if knowledge_config.embedding_model is None or knowledge_config.embedding_model_provider is None: if knowledge_config.embedding_model is None or knowledge_config.embedding_model_provider is None:
@ -430,7 +428,7 @@ class DocumentIndexingEstimateApi(DocumentResource):
raise NotFound("File not found.") raise NotFound("File not found.")
extract_setting = ExtractSetting( extract_setting = ExtractSetting(
datasource_type="upload_file", upload_file=file, document_model=document.doc_form datasource_type=DatasourceType.FILE.value, upload_file=file, document_model=document.doc_form
) )
indexing_runner = IndexingRunner() indexing_runner = IndexingRunner()
@ -490,13 +488,13 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
raise NotFound("File not found.") raise NotFound("File not found.")
extract_setting = ExtractSetting( extract_setting = ExtractSetting(
datasource_type="upload_file", upload_file=file_detail, document_model=document.doc_form datasource_type=DatasourceType.FILE.value, upload_file=file_detail, document_model=document.doc_form
) )
extract_settings.append(extract_setting) extract_settings.append(extract_setting)
elif document.data_source_type == "notion_import": elif document.data_source_type == "notion_import":
extract_setting = ExtractSetting( extract_setting = ExtractSetting(
datasource_type="notion_import", datasource_type=DatasourceType.NOTION.value,
notion_info={ notion_info={
"credential_id": data_source_info["credential_id"], "credential_id": data_source_info["credential_id"],
"notion_workspace_id": data_source_info["notion_workspace_id"], "notion_workspace_id": data_source_info["notion_workspace_id"],
@ -509,7 +507,7 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
extract_settings.append(extract_setting) extract_settings.append(extract_setting)
elif document.data_source_type == "website_crawl": elif document.data_source_type == "website_crawl":
extract_setting = ExtractSetting( extract_setting = ExtractSetting(
datasource_type="website_crawl", datasource_type=DatasourceType.WEBSITE.value,
website_info={ website_info={
"provider": data_source_info["provider"], "provider": data_source_info["provider"],
"job_id": data_source_info["job_id"], "job_id": data_source_info["job_id"],

View File

@ -61,7 +61,6 @@ class ConversationApi(InstalledAppResource):
ConversationService.delete(app_model, conversation_id, current_user) ConversationService.delete(app_model, conversation_id, current_user)
except ConversationNotExistsError: except ConversationNotExistsError:
raise NotFound("Conversation Not Exists.") raise NotFound("Conversation Not Exists.")
WebConversationService.unpin(app_model, conversation_id, current_user)
return {"result": "success"}, 204 return {"result": "success"}, 204

View File

@ -67,7 +67,7 @@ class ModelProviderCredentialApi(Resource):
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json") parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
parser.add_argument("name", type=StrLen(30), required=True, nullable=False, location="json") parser.add_argument("name", type=StrLen(30), required=False, nullable=True, location="json")
args = parser.parse_args() args = parser.parse_args()
model_provider_service = ModelProviderService() model_provider_service = ModelProviderService()
@ -94,7 +94,7 @@ class ModelProviderCredentialApi(Resource):
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument("credential_id", type=uuid_value, required=True, nullable=False, location="json") parser.add_argument("credential_id", type=uuid_value, required=True, nullable=False, location="json")
parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json") parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
parser.add_argument("name", type=StrLen(30), required=True, nullable=False, location="json") parser.add_argument("name", type=StrLen(30), required=False, nullable=True, location="json")
args = parser.parse_args() args = parser.parse_args()
model_provider_service = ModelProviderService() model_provider_service = ModelProviderService()

View File

@ -219,7 +219,11 @@ class ModelProviderModelCredentialApi(Resource):
model_load_balancing_service = ModelLoadBalancingService() model_load_balancing_service = ModelLoadBalancingService()
is_load_balancing_enabled, load_balancing_configs = model_load_balancing_service.get_load_balancing_configs( is_load_balancing_enabled, load_balancing_configs = model_load_balancing_service.get_load_balancing_configs(
tenant_id=tenant_id, provider=provider, model=args["model"], model_type=args["model_type"] tenant_id=tenant_id,
provider=provider,
model=args["model"],
model_type=args["model_type"],
config_from=args.get("config_from", ""),
) )
if args.get("config_from", "") == "predefined-model": if args.get("config_from", "") == "predefined-model":
@ -263,7 +267,7 @@ class ModelProviderModelCredentialApi(Resource):
choices=[mt.value for mt in ModelType], choices=[mt.value for mt in ModelType],
location="json", location="json",
) )
parser.add_argument("name", type=StrLen(30), required=True, nullable=False, location="json") parser.add_argument("name", type=StrLen(30), required=False, nullable=True, location="json")
parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json") parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
args = parser.parse_args() args = parser.parse_args()
@ -309,7 +313,7 @@ class ModelProviderModelCredentialApi(Resource):
) )
parser.add_argument("credential_id", type=uuid_value, required=True, nullable=False, location="json") parser.add_argument("credential_id", type=uuid_value, required=True, nullable=False, location="json")
parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json") parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
parser.add_argument("name", type=StrLen(30), required=True, nullable=False, location="json") parser.add_argument("name", type=StrLen(30), required=False, nullable=True, location="json")
args = parser.parse_args() args = parser.parse_args()
model_provider_service = ModelProviderService() model_provider_service = ModelProviderService()

View File

@ -1,8 +1,12 @@
from base64 import b64encode from base64 import b64encode
from collections.abc import Callable
from functools import wraps from functools import wraps
from hashlib import sha1 from hashlib import sha1
from hmac import new as hmac_new from hmac import new as hmac_new
from typing import ParamSpec, TypeVar
P = ParamSpec("P")
R = TypeVar("R")
from flask import abort, request from flask import abort, request
from configs import dify_config from configs import dify_config
@ -10,9 +14,9 @@ from extensions.ext_database import db
from models.model import EndUser from models.model import EndUser
def billing_inner_api_only(view): def billing_inner_api_only(view: Callable[P, R]):
@wraps(view) @wraps(view)
def decorated(*args, **kwargs): def decorated(*args: P.args, **kwargs: P.kwargs):
if not dify_config.INNER_API: if not dify_config.INNER_API:
abort(404) abort(404)
@ -26,9 +30,9 @@ def billing_inner_api_only(view):
return decorated return decorated
def enterprise_inner_api_only(view): def enterprise_inner_api_only(view: Callable[P, R]):
@wraps(view) @wraps(view)
def decorated(*args, **kwargs): def decorated(*args: P.args, **kwargs: P.kwargs):
if not dify_config.INNER_API: if not dify_config.INNER_API:
abort(404) abort(404)
@ -78,9 +82,9 @@ def enterprise_inner_api_user_auth(view):
return decorated return decorated
def plugin_inner_api_only(view): def plugin_inner_api_only(view: Callable[P, R]):
@wraps(view) @wraps(view)
def decorated(*args, **kwargs): def decorated(*args: P.args, **kwargs: P.kwargs):
if not dify_config.PLUGIN_DAEMON_KEY: if not dify_config.PLUGIN_DAEMON_KEY:
abort(404) abort(404)

View File

@ -55,7 +55,7 @@ class AudioApi(Resource):
file = request.files["file"] file = request.files["file"]
try: try:
response = AudioService.transcript_asr(app_model=app_model, file=file, end_user=end_user) response = AudioService.transcript_asr(app_model=app_model, file=file, end_user=end_user.id)
return response return response
except services.errors.app_model_config.AppModelConfigBrokenError: except services.errors.app_model_config.AppModelConfigBrokenError:

View File

@ -59,7 +59,7 @@ class FilePreviewApi(Resource):
args = file_preview_parser.parse_args() args = file_preview_parser.parse_args()
# Validate file ownership and get file objects # Validate file ownership and get file objects
message_file, upload_file = self._validate_file_ownership(file_id, app_model.id) _, upload_file = self._validate_file_ownership(file_id, app_model.id)
# Get file content generator # Get file content generator
try: try:

View File

@ -413,7 +413,7 @@ class DocumentUpdateByFileApi(DatasetApiResource):
DocumentService.document_create_args_validate(knowledge_config) DocumentService.document_create_args_validate(knowledge_config)
try: try:
documents, batch = DocumentService.save_document_with_dataset_id( documents, _ = DocumentService.save_document_with_dataset_id(
dataset=dataset, dataset=dataset,
knowledge_config=knowledge_config, knowledge_config=knowledge_config,
account=dataset.created_by_account, account=dataset.created_by_account,

View File

@ -1,7 +1,7 @@
import time import time
from collections.abc import Callable from collections.abc import Callable
from datetime import timedelta from datetime import timedelta
from enum import Enum from enum import StrEnum, auto
from functools import wraps from functools import wraps
from typing import Optional from typing import Optional
@ -23,14 +23,14 @@ from models.model import ApiToken, App, EndUser
from services.feature_service import FeatureService from services.feature_service import FeatureService
class WhereisUserArg(Enum): class WhereisUserArg(StrEnum):
""" """
Enum for whereis_user_arg. Enum for whereis_user_arg.
""" """
QUERY = "query" QUERY = auto()
JSON = "json" JSON = auto()
FORM = "form" FORM = auto()
class FetchUserArg(BaseModel): class FetchUserArg(BaseModel):
@ -291,27 +291,28 @@ def create_or_update_end_user_for_user_id(app_model: App, user_id: Optional[str]
if not user_id: if not user_id:
user_id = "DEFAULT-USER" user_id = "DEFAULT-USER"
end_user = ( with Session(db.engine, expire_on_commit=False) as session:
db.session.query(EndUser) end_user = (
.where( session.query(EndUser)
EndUser.tenant_id == app_model.tenant_id, .where(
EndUser.app_id == app_model.id, EndUser.tenant_id == app_model.tenant_id,
EndUser.session_id == user_id, EndUser.app_id == app_model.id,
EndUser.type == "service_api", EndUser.session_id == user_id,
EndUser.type == "service_api",
)
.first()
) )
.first()
)
if end_user is None: if end_user is None:
end_user = EndUser( end_user = EndUser(
tenant_id=app_model.tenant_id, tenant_id=app_model.tenant_id,
app_id=app_model.id, app_id=app_model.id,
type="service_api", type="service_api",
is_anonymous=user_id == "DEFAULT-USER", is_anonymous=user_id == "DEFAULT-USER",
session_id=user_id, session_id=user_id,
) )
db.session.add(end_user) session.add(end_user)
db.session.commit() session.commit()
return end_user return end_user

View File

@ -73,8 +73,6 @@ class ConversationApi(WebApiResource):
ConversationService.delete(app_model, conversation_id, end_user) ConversationService.delete(app_model, conversation_id, end_user)
except ConversationNotExistsError: except ConversationNotExistsError:
raise NotFound("Conversation Not Exists.") raise NotFound("Conversation Not Exists.")
WebConversationService.unpin(app_model, conversation_id, end_user)
return {"result": "success"}, 204 return {"result": "success"}, 204

View File

@ -4,6 +4,7 @@ from functools import wraps
from flask import request from flask import request
from flask_restx import Resource from flask_restx import Resource
from sqlalchemy import select from sqlalchemy import select
from sqlalchemy.orm import Session
from werkzeug.exceptions import BadRequest, NotFound, Unauthorized from werkzeug.exceptions import BadRequest, NotFound, Unauthorized
from controllers.web.error import WebAppAuthAccessDeniedError, WebAppAuthRequiredError from controllers.web.error import WebAppAuthAccessDeniedError, WebAppAuthRequiredError
@ -49,18 +50,19 @@ def decode_jwt_token():
decoded = PassportService().verify(tk) decoded = PassportService().verify(tk)
app_code = decoded.get("app_code") app_code = decoded.get("app_code")
app_id = decoded.get("app_id") app_id = decoded.get("app_id")
app_model = db.session.scalar(select(App).where(App.id == app_id)) with Session(db.engine, expire_on_commit=False) as session:
site = db.session.scalar(select(Site).where(Site.code == app_code)) app_model = session.scalar(select(App).where(App.id == app_id))
if not app_model: site = session.scalar(select(Site).where(Site.code == app_code))
raise NotFound() if not app_model:
if not app_code or not site: raise NotFound()
raise BadRequest("Site URL is no longer valid.") if not app_code or not site:
if app_model.enable_site is False: raise BadRequest("Site URL is no longer valid.")
raise BadRequest("Site is disabled.") if app_model.enable_site is False:
end_user_id = decoded.get("end_user_id") raise BadRequest("Site is disabled.")
end_user = db.session.scalar(select(EndUser).where(EndUser.id == end_user_id)) end_user_id = decoded.get("end_user_id")
if not end_user: end_user = session.scalar(select(EndUser).where(EndUser.id == end_user_id))
raise NotFound() if not end_user:
raise NotFound()
# for enterprise webapp auth # for enterprise webapp auth
app_web_auth_enabled = False app_web_auth_enabled = False

View File

@ -336,7 +336,8 @@ class BaseAgentRunner(AppRunner):
""" """
Save agent thought Save agent thought
""" """
agent_thought = db.session.query(MessageAgentThought).where(MessageAgentThought.id == agent_thought_id).first() stmt = select(MessageAgentThought).where(MessageAgentThought.id == agent_thought_id)
agent_thought = db.session.scalar(stmt)
if not agent_thought: if not agent_thought:
raise ValueError("agent thought not found") raise ValueError("agent thought not found")
@ -494,7 +495,8 @@ class BaseAgentRunner(AppRunner):
return result return result
def organize_agent_user_prompt(self, message: Message) -> UserPromptMessage: def organize_agent_user_prompt(self, message: Message) -> UserPromptMessage:
files = db.session.query(MessageFile).where(MessageFile.message_id == message.id).all() stmt = select(MessageFile).where(MessageFile.message_id == message.id)
files = db.session.scalars(stmt).all()
if not files: if not files:
return UserPromptMessage(content=message.query) return UserPromptMessage(content=message.query)
if message.app_model_config: if message.app_model_config:

View File

@ -1,12 +1,14 @@
from pydantic import BaseModel, Field, ValidationError from pydantic import BaseModel, ConfigDict, Field, ValidationError
class MoreLikeThisConfig(BaseModel): class MoreLikeThisConfig(BaseModel):
enabled: bool = False enabled: bool = False
model_config = ConfigDict(extra="allow")
class AppConfigModel(BaseModel): class AppConfigModel(BaseModel):
more_like_this: MoreLikeThisConfig = Field(default_factory=MoreLikeThisConfig) more_like_this: MoreLikeThisConfig = Field(default_factory=MoreLikeThisConfig)
model_config = ConfigDict(extra="allow")
class MoreLikeThisConfigManager: class MoreLikeThisConfigManager:
@ -23,8 +25,8 @@ class MoreLikeThisConfigManager:
@classmethod @classmethod
def validate_and_set_defaults(cls, config: dict) -> tuple[dict, list[str]]: def validate_and_set_defaults(cls, config: dict) -> tuple[dict, list[str]]:
try: try:
return AppConfigModel.model_validate(config).dict(), ["more_like_this"] return AppConfigModel.model_validate(config).model_dump(), ["more_like_this"]
except ValidationError as e: except ValidationError:
raise ValueError( raise ValueError(
"more_like_this must be of dict type and enabled in more_like_this must be of boolean type" "more_like_this must be of dict type and enabled in more_like_this must be of boolean type"
) )

View File

@ -450,6 +450,12 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
worker_thread.start() worker_thread.start()
# release database connection, because the following new thread operations may take a long time
db.session.refresh(workflow)
db.session.refresh(message)
# db.session.refresh(user)
db.session.close()
# return response or stream generator # return response or stream generator
response = self._handle_advanced_chat_response( response = self._handle_advanced_chat_response(
application_generate_entity=application_generate_entity, application_generate_entity=application_generate_entity,

View File

@ -73,7 +73,9 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
app_config = self.application_generate_entity.app_config app_config = self.application_generate_entity.app_config
app_config = cast(AdvancedChatAppConfig, app_config) app_config = cast(AdvancedChatAppConfig, app_config)
app_record = db.session.query(App).where(App.id == app_config.app_id).first() with Session(db.engine, expire_on_commit=False) as session:
app_record = session.scalar(select(App).where(App.id == app_config.app_id))
if not app_record: if not app_record:
raise ValueError("App not found") raise ValueError("App not found")
@ -147,7 +149,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
environment_variables=self._workflow.environment_variables, environment_variables=self._workflow.environment_variables,
# Based on the definition of `VariableUnion`, # Based on the definition of `VariableUnion`,
# `list[Variable]` can be safely used as `list[VariableUnion]` since they are compatible. # `list[Variable]` can be safely used as `list[VariableUnion]` since they are compatible.
conversation_variables=cast(list[VariableUnion], conversation_variables), conversation_variables=conversation_variables,
) )
# init graph # init graph

View File

@ -118,7 +118,7 @@ class AdvancedChatAppGenerateResponseConverter(AppGenerateResponseConverter):
data = cls._error_to_stream_response(sub_stream_response.err) data = cls._error_to_stream_response(sub_stream_response.err)
response_chunk.update(data) response_chunk.update(data)
elif isinstance(sub_stream_response, NodeStartStreamResponse | NodeFinishStreamResponse): elif isinstance(sub_stream_response, NodeStartStreamResponse | NodeFinishStreamResponse):
response_chunk.update(sub_stream_response.to_ignore_detail_dict()) response_chunk.update(sub_stream_response.to_ignore_detail_dict()) # ty: ignore [unresolved-attribute]
else: else:
response_chunk.update(sub_stream_response.to_dict()) response_chunk.update(sub_stream_response.to_dict())

View File

@ -68,7 +68,6 @@ from core.workflow.repositories.workflow_execution_repository import WorkflowExe
from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository
from core.workflow.system_variable import SystemVariable from core.workflow.system_variable import SystemVariable
from core.workflow.workflow_cycle_manager import CycleManagerWorkflowInfo, WorkflowCycleManager from core.workflow.workflow_cycle_manager import CycleManagerWorkflowInfo, WorkflowCycleManager
from events.message_event import message_was_created
from extensions.ext_database import db from extensions.ext_database import db
from libs.datetime_utils import naive_utc_now from libs.datetime_utils import naive_utc_now
from models import Conversation, EndUser, Message, MessageFile from models import Conversation, EndUser, Message, MessageFile
@ -306,13 +305,8 @@ class AdvancedChatAppGenerateTaskPipeline:
err = self._base_task_pipeline._handle_error(event=event, session=session, message_id=self._message_id) err = self._base_task_pipeline._handle_error(event=event, session=session, message_id=self._message_id)
yield self._base_task_pipeline._error_to_stream_response(err) yield self._base_task_pipeline._error_to_stream_response(err)
def _handle_workflow_started_event( def _handle_workflow_started_event(self, **kwargs) -> Generator[StreamResponse, None, None]:
self, event: QueueWorkflowStartedEvent, *, graph_runtime_state: Optional[GraphRuntimeState] = None, **kwargs
) -> Generator[StreamResponse, None, None]:
"""Handle workflow started events.""" """Handle workflow started events."""
# Override graph runtime state - this is a side effect but necessary
graph_runtime_state = event.graph_runtime_state
with self._database_session() as session: with self._database_session() as session:
workflow_execution = self._workflow_cycle_manager.handle_workflow_run_start() workflow_execution = self._workflow_cycle_manager.handle_workflow_run_start()
self._workflow_run_id = workflow_execution.id_ self._workflow_run_id = workflow_execution.id_
@ -333,15 +327,14 @@ class AdvancedChatAppGenerateTaskPipeline:
"""Handle node retry events.""" """Handle node retry events."""
self._ensure_workflow_initialized() self._ensure_workflow_initialized()
with self._database_session() as session: workflow_node_execution = self._workflow_cycle_manager.handle_workflow_node_execution_retried(
workflow_node_execution = self._workflow_cycle_manager.handle_workflow_node_execution_retried( workflow_execution_id=self._workflow_run_id, event=event
workflow_execution_id=self._workflow_run_id, event=event )
) node_retry_resp = self._workflow_response_converter.workflow_node_retry_to_stream_response(
node_retry_resp = self._workflow_response_converter.workflow_node_retry_to_stream_response( event=event,
event=event, task_id=self._application_generate_entity.task_id,
task_id=self._application_generate_entity.task_id, workflow_node_execution=workflow_node_execution,
workflow_node_execution=workflow_node_execution, )
)
if node_retry_resp: if node_retry_resp:
yield node_retry_resp yield node_retry_resp
@ -375,13 +368,12 @@ class AdvancedChatAppGenerateTaskPipeline:
self._workflow_response_converter.fetch_files_from_node_outputs(event.outputs or {}) self._workflow_response_converter.fetch_files_from_node_outputs(event.outputs or {})
) )
with self._database_session() as session: workflow_node_execution = self._workflow_cycle_manager.handle_workflow_node_execution_success(event=event)
workflow_node_execution = self._workflow_cycle_manager.handle_workflow_node_execution_success(event=event) node_finish_resp = self._workflow_response_converter.workflow_node_finish_to_stream_response(
node_finish_resp = self._workflow_response_converter.workflow_node_finish_to_stream_response( event=event,
event=event, task_id=self._application_generate_entity.task_id,
task_id=self._application_generate_entity.task_id, workflow_node_execution=workflow_node_execution,
workflow_node_execution=workflow_node_execution, )
)
self._save_output_for_event(event, workflow_node_execution.id) self._save_output_for_event(event, workflow_node_execution.id)
@ -886,10 +878,6 @@ class AdvancedChatAppGenerateTaskPipeline:
self._task_state.metadata.usage = usage self._task_state.metadata.usage = usage
else: else:
self._task_state.metadata.usage = LLMUsage.empty_usage() self._task_state.metadata.usage = LLMUsage.empty_usage()
message_was_created.send(
message,
application_generate_entity=self._application_generate_entity,
)
def _message_end_to_stream_response(self) -> MessageEndStreamResponse: def _message_end_to_stream_response(self) -> MessageEndStreamResponse:
""" """

View File

@ -1,6 +1,8 @@
import logging import logging
from typing import cast from typing import cast
from sqlalchemy import select
from core.agent.cot_chat_agent_runner import CotChatAgentRunner from core.agent.cot_chat_agent_runner import CotChatAgentRunner
from core.agent.cot_completion_agent_runner import CotCompletionAgentRunner from core.agent.cot_completion_agent_runner import CotCompletionAgentRunner
from core.agent.entities import AgentEntity from core.agent.entities import AgentEntity
@ -44,8 +46,8 @@ class AgentChatAppRunner(AppRunner):
""" """
app_config = application_generate_entity.app_config app_config = application_generate_entity.app_config
app_config = cast(AgentChatAppConfig, app_config) app_config = cast(AgentChatAppConfig, app_config)
app_stmt = select(App).where(App.id == app_config.app_id)
app_record = db.session.query(App).where(App.id == app_config.app_id).first() app_record = db.session.scalar(app_stmt)
if not app_record: if not app_record:
raise ValueError("App not found") raise ValueError("App not found")
@ -182,11 +184,12 @@ class AgentChatAppRunner(AppRunner):
if {ModelFeature.MULTI_TOOL_CALL, ModelFeature.TOOL_CALL}.intersection(model_schema.features or []): if {ModelFeature.MULTI_TOOL_CALL, ModelFeature.TOOL_CALL}.intersection(model_schema.features or []):
agent_entity.strategy = AgentEntity.Strategy.FUNCTION_CALLING agent_entity.strategy = AgentEntity.Strategy.FUNCTION_CALLING
conversation_stmt = select(Conversation).where(Conversation.id == conversation.id)
conversation_result = db.session.query(Conversation).where(Conversation.id == conversation.id).first() conversation_result = db.session.scalar(conversation_stmt)
if conversation_result is None: if conversation_result is None:
raise ValueError("Conversation not found") raise ValueError("Conversation not found")
message_result = db.session.query(Message).where(Message.id == message.id).first() msg_stmt = select(Message).where(Message.id == message.id)
message_result = db.session.scalar(msg_stmt)
if message_result is None: if message_result is None:
raise ValueError("Message not found") raise ValueError("Message not found")
db.session.close() db.session.close()

View File

@ -1,7 +1,7 @@
import queue import queue
import time import time
from abc import abstractmethod from abc import abstractmethod
from enum import Enum from enum import IntEnum, auto
from typing import Any, Optional from typing import Any, Optional
from sqlalchemy.orm import DeclarativeMeta from sqlalchemy.orm import DeclarativeMeta
@ -19,9 +19,9 @@ from core.app.entities.queue_entities import (
from extensions.ext_redis import redis_client from extensions.ext_redis import redis_client
class PublishFrom(Enum): class PublishFrom(IntEnum):
APPLICATION_MANAGER = 1 APPLICATION_MANAGER = auto()
TASK_PIPELINE = 2 TASK_PIPELINE = auto()
class AppQueueManager: class AppQueueManager:
@ -174,7 +174,7 @@ class AppQueueManager:
def _check_for_sqlalchemy_models(self, data: Any): def _check_for_sqlalchemy_models(self, data: Any):
# from entity to dict or list # from entity to dict or list
if isinstance(data, dict): if isinstance(data, dict):
for key, value in data.items(): for value in data.values():
self._check_for_sqlalchemy_models(value) self._check_for_sqlalchemy_models(value)
elif isinstance(data, list): elif isinstance(data, list):
for item in data: for item in data:

View File

@ -1,6 +1,8 @@
import logging import logging
from typing import cast from typing import cast
from sqlalchemy import select
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
from core.app.apps.base_app_runner import AppRunner from core.app.apps.base_app_runner import AppRunner
from core.app.apps.chat.app_config_manager import ChatAppConfig from core.app.apps.chat.app_config_manager import ChatAppConfig
@ -42,8 +44,8 @@ class ChatAppRunner(AppRunner):
""" """
app_config = application_generate_entity.app_config app_config = application_generate_entity.app_config
app_config = cast(ChatAppConfig, app_config) app_config = cast(ChatAppConfig, app_config)
stmt = select(App).where(App.id == app_config.app_id)
app_record = db.session.query(App).where(App.id == app_config.app_id).first() app_record = db.session.scalar(stmt)
if not app_record: if not app_record:
raise ValueError("App not found") raise ValueError("App not found")

View File

@ -6,6 +6,7 @@ from typing import Any, Literal, Union, overload
from flask import Flask, copy_current_request_context, current_app from flask import Flask, copy_current_request_context, current_app
from pydantic import ValidationError from pydantic import ValidationError
from sqlalchemy import select
from configs import dify_config from configs import dify_config
from core.app.app_config.easy_ui_based_app.model_config.converter import ModelConfigConverter from core.app.app_config.easy_ui_based_app.model_config.converter import ModelConfigConverter
@ -248,17 +249,14 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
:param invoke_from: invoke from source :param invoke_from: invoke from source
:param stream: is stream :param stream: is stream
""" """
message = ( stmt = select(Message).where(
db.session.query(Message) Message.id == message_id,
.where( Message.app_id == app_model.id,
Message.id == message_id, Message.from_source == ("api" if isinstance(user, EndUser) else "console"),
Message.app_id == app_model.id, Message.from_end_user_id == (user.id if isinstance(user, EndUser) else None),
Message.from_source == ("api" if isinstance(user, EndUser) else "console"), Message.from_account_id == (user.id if isinstance(user, Account) else None),
Message.from_end_user_id == (user.id if isinstance(user, EndUser) else None),
Message.from_account_id == (user.id if isinstance(user, Account) else None),
)
.first()
) )
message = db.session.scalar(stmt)
if not message: if not message:
raise MessageNotExistsError() raise MessageNotExistsError()

View File

@ -1,6 +1,8 @@
import logging import logging
from typing import cast from typing import cast
from sqlalchemy import select
from core.app.apps.base_app_queue_manager import AppQueueManager from core.app.apps.base_app_queue_manager import AppQueueManager
from core.app.apps.base_app_runner import AppRunner from core.app.apps.base_app_runner import AppRunner
from core.app.apps.completion.app_config_manager import CompletionAppConfig from core.app.apps.completion.app_config_manager import CompletionAppConfig
@ -35,8 +37,8 @@ class CompletionAppRunner(AppRunner):
""" """
app_config = application_generate_entity.app_config app_config = application_generate_entity.app_config
app_config = cast(CompletionAppConfig, app_config) app_config = cast(CompletionAppConfig, app_config)
stmt = select(App).where(App.id == app_config.app_id)
app_record = db.session.query(App).where(App.id == app_config.app_id).first() app_record = db.session.scalar(stmt)
if not app_record: if not app_record:
raise ValueError("App not found") raise ValueError("App not found")

View File

@ -3,6 +3,9 @@ import logging
from collections.abc import Generator from collections.abc import Generator
from typing import Optional, Union, cast from typing import Optional, Union, cast
from sqlalchemy import select
from sqlalchemy.orm import Session
from core.app.app_config.entities import EasyUIBasedAppConfig, EasyUIBasedAppModelConfigFrom from core.app.app_config.entities import EasyUIBasedAppConfig, EasyUIBasedAppModelConfigFrom
from core.app.apps.base_app_generator import BaseAppGenerator from core.app.apps.base_app_generator import BaseAppGenerator
from core.app.apps.base_app_queue_manager import AppQueueManager from core.app.apps.base_app_queue_manager import AppQueueManager
@ -83,11 +86,10 @@ class MessageBasedAppGenerator(BaseAppGenerator):
def _get_app_model_config(self, app_model: App, conversation: Optional[Conversation] = None) -> AppModelConfig: def _get_app_model_config(self, app_model: App, conversation: Optional[Conversation] = None) -> AppModelConfig:
if conversation: if conversation:
app_model_config = ( stmt = select(AppModelConfig).where(
db.session.query(AppModelConfig) AppModelConfig.id == conversation.app_model_config_id, AppModelConfig.app_id == app_model.id
.where(AppModelConfig.id == conversation.app_model_config_id, AppModelConfig.app_id == app_model.id)
.first()
) )
app_model_config = db.session.scalar(stmt)
if not app_model_config: if not app_model_config:
raise AppModelConfigBrokenError() raise AppModelConfigBrokenError()
@ -253,7 +255,8 @@ class MessageBasedAppGenerator(BaseAppGenerator):
:param conversation_id: conversation id :param conversation_id: conversation id
:return: conversation :return: conversation
""" """
conversation = db.session.query(Conversation).where(Conversation.id == conversation_id).first() with Session(db.engine, expire_on_commit=False) as session:
conversation = session.scalar(select(Conversation).where(Conversation.id == conversation_id))
if not conversation: if not conversation:
raise ConversationNotExistsError("Conversation not exists") raise ConversationNotExistsError("Conversation not exists")
@ -266,7 +269,8 @@ class MessageBasedAppGenerator(BaseAppGenerator):
:param message_id: message id :param message_id: message id
:return: message :return: message
""" """
message = db.session.query(Message).where(Message.id == message_id).first() with Session(db.engine, expire_on_commit=False) as session:
message = session.scalar(select(Message).where(Message.id == message_id))
if message is None: if message is None:
raise MessageNotExistsError("Message not exists") raise MessageNotExistsError("Message not exists")

View File

@ -89,7 +89,7 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter):
data = cls._error_to_stream_response(sub_stream_response.err) data = cls._error_to_stream_response(sub_stream_response.err)
response_chunk.update(data) response_chunk.update(data)
elif isinstance(sub_stream_response, NodeStartStreamResponse | NodeFinishStreamResponse): elif isinstance(sub_stream_response, NodeStartStreamResponse | NodeFinishStreamResponse):
response_chunk.update(sub_stream_response.to_ignore_detail_dict()) response_chunk.update(sub_stream_response.to_ignore_detail_dict()) # ty: ignore [unresolved-attribute]
else: else:
response_chunk.update(sub_stream_response.to_dict()) response_chunk.update(sub_stream_response.to_dict())
yield response_chunk yield response_chunk

View File

@ -296,16 +296,15 @@ class WorkflowAppGenerateTaskPipeline:
"""Handle node retry events.""" """Handle node retry events."""
self._ensure_workflow_initialized() self._ensure_workflow_initialized()
with self._database_session() as session: workflow_node_execution = self._workflow_cycle_manager.handle_workflow_node_execution_retried(
workflow_node_execution = self._workflow_cycle_manager.handle_workflow_node_execution_retried( workflow_execution_id=self._workflow_run_id,
workflow_execution_id=self._workflow_run_id, event=event,
event=event, )
) response = self._workflow_response_converter.workflow_node_retry_to_stream_response(
response = self._workflow_response_converter.workflow_node_retry_to_stream_response( event=event,
event=event, task_id=self._application_generate_entity.task_id,
task_id=self._application_generate_entity.task_id, workflow_node_execution=workflow_node_execution,
workflow_node_execution=workflow_node_execution, )
)
if response: if response:
yield response yield response

View File

@ -1,6 +1,8 @@
import logging import logging
from typing import Optional from typing import Optional
from sqlalchemy import select
from core.app.entities.app_invoke_entities import InvokeFrom from core.app.entities.app_invoke_entities import InvokeFrom
from core.rag.datasource.vdb.vector_factory import Vector from core.rag.datasource.vdb.vector_factory import Vector
from extensions.ext_database import db from extensions.ext_database import db
@ -25,9 +27,8 @@ class AnnotationReplyFeature:
:param invoke_from: invoke from :param invoke_from: invoke from
:return: :return:
""" """
annotation_setting = ( stmt = select(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app_record.id)
db.session.query(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app_record.id).first() annotation_setting = db.session.scalar(stmt)
)
if not annotation_setting: if not annotation_setting:
return None return None

View File

@ -96,7 +96,11 @@ class RateLimit:
if isinstance(generator, Mapping): if isinstance(generator, Mapping):
return generator return generator
else: else:
return RateLimitGenerator(rate_limit=self, generator=generator, request_id=request_id) return RateLimitGenerator(
rate_limit=self,
generator=generator, # ty: ignore [invalid-argument-type]
request_id=request_id,
)
class RateLimitGenerator: class RateLimitGenerator:

View File

@ -50,7 +50,7 @@ class BasedGenerateTaskPipeline:
if isinstance(e, InvokeAuthorizationError): if isinstance(e, InvokeAuthorizationError):
err = InvokeAuthorizationError("Incorrect API key provided") err = InvokeAuthorizationError("Incorrect API key provided")
elif isinstance(e, InvokeError | ValueError): elif isinstance(e, InvokeError | ValueError):
err = e err = e # ty: ignore [invalid-assignment]
else: else:
description = getattr(e, "description", None) description = getattr(e, "description", None)
err = Exception(description if description is not None else str(e)) err = Exception(description if description is not None else str(e))

View File

@ -472,9 +472,10 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
:param event: agent thought event :param event: agent thought event
:return: :return:
""" """
agent_thought: Optional[MessageAgentThought] = ( with Session(db.engine, expire_on_commit=False) as session:
db.session.query(MessageAgentThought).where(MessageAgentThought.id == event.agent_thought_id).first() agent_thought: Optional[MessageAgentThought] = (
) session.query(MessageAgentThought).where(MessageAgentThought.id == event.agent_thought_id).first()
)
if agent_thought: if agent_thought:
return AgentThoughtStreamResponse( return AgentThoughtStreamResponse(

View File

@ -3,6 +3,8 @@ from threading import Thread
from typing import Optional, Union from typing import Optional, Union
from flask import Flask, current_app from flask import Flask, current_app
from sqlalchemy import select
from sqlalchemy.orm import Session
from configs import dify_config from configs import dify_config
from core.app.entities.app_invoke_entities import ( from core.app.entities.app_invoke_entities import (
@ -84,7 +86,8 @@ class MessageCycleManager:
def _generate_conversation_name_worker(self, flask_app: Flask, conversation_id: str, query: str): def _generate_conversation_name_worker(self, flask_app: Flask, conversation_id: str, query: str):
with flask_app.app_context(): with flask_app.app_context():
# get conversation and message # get conversation and message
conversation = db.session.query(Conversation).where(Conversation.id == conversation_id).first() stmt = select(Conversation).where(Conversation.id == conversation_id)
conversation = db.session.scalar(stmt)
if not conversation: if not conversation:
return return
@ -98,7 +101,7 @@ class MessageCycleManager:
try: try:
name = LLMGenerator.generate_conversation_name(app_model.tenant_id, query) name = LLMGenerator.generate_conversation_name(app_model.tenant_id, query)
conversation.name = name conversation.name = name
except Exception as e: except Exception:
if dify_config.DEBUG: if dify_config.DEBUG:
logger.exception("generate conversation name failed, conversation_id: %s", conversation_id) logger.exception("generate conversation name failed, conversation_id: %s", conversation_id)
pass pass
@ -143,7 +146,8 @@ class MessageCycleManager:
:param event: event :param event: event
:return: :return:
""" """
message_file = db.session.query(MessageFile).where(MessageFile.id == event.message_file_id).first() with Session(db.engine, expire_on_commit=False) as session:
message_file = session.scalar(select(MessageFile).where(MessageFile.id == event.message_file_id))
if message_file and message_file.url is not None: if message_file and message_file.url is not None:
# get tool file id # get tool file id
@ -183,7 +187,8 @@ class MessageCycleManager:
:param message_id: message id :param message_id: message id
:return: :return:
""" """
message_file = db.session.query(MessageFile).where(MessageFile.id == message_id).first() with Session(db.engine, expire_on_commit=False) as session:
message_file = session.scalar(select(MessageFile).where(MessageFile.id == message_id))
event_type = StreamEvent.MESSAGE_FILE if message_file else StreamEvent.MESSAGE event_type = StreamEvent.MESSAGE_FILE if message_file else StreamEvent.MESSAGE
return MessageStreamResponse( return MessageStreamResponse(

View File

@ -1,6 +1,8 @@
import logging import logging
from collections.abc import Sequence from collections.abc import Sequence
from sqlalchemy import select
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
from core.app.entities.app_invoke_entities import InvokeFrom from core.app.entities.app_invoke_entities import InvokeFrom
from core.app.entities.queue_entities import QueueRetrieverResourcesEvent from core.app.entities.queue_entities import QueueRetrieverResourcesEvent
@ -49,7 +51,8 @@ class DatasetIndexToolCallbackHandler:
for document in documents: for document in documents:
if document.metadata is not None: if document.metadata is not None:
document_id = document.metadata["document_id"] document_id = document.metadata["document_id"]
dataset_document = db.session.query(DatasetDocument).where(DatasetDocument.id == document_id).first() dataset_document_stmt = select(DatasetDocument).where(DatasetDocument.id == document_id)
dataset_document = db.session.scalar(dataset_document_stmt)
if not dataset_document: if not dataset_document:
_logger.warning( _logger.warning(
"Expected DatasetDocument record to exist, but none was found, document_id=%s", "Expected DatasetDocument record to exist, but none was found, document_id=%s",
@ -57,17 +60,14 @@ class DatasetIndexToolCallbackHandler:
) )
continue continue
if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX: if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX:
child_chunk = ( child_chunk_stmt = select(ChildChunk).where(
db.session.query(ChildChunk) ChildChunk.index_node_id == document.metadata["doc_id"],
.where( ChildChunk.dataset_id == dataset_document.dataset_id,
ChildChunk.index_node_id == document.metadata["doc_id"], ChildChunk.document_id == dataset_document.id,
ChildChunk.dataset_id == dataset_document.dataset_id,
ChildChunk.document_id == dataset_document.id,
)
.first()
) )
child_chunk = db.session.scalar(child_chunk_stmt)
if child_chunk: if child_chunk:
segment = ( _ = (
db.session.query(DocumentSegment) db.session.query(DocumentSegment)
.where(DocumentSegment.id == child_chunk.segment_id) .where(DocumentSegment.id == child_chunk.segment_id)
.update( .update(

View File

@ -1,5 +1,6 @@
import json import json
import logging import logging
import re
from collections import defaultdict from collections import defaultdict
from collections.abc import Iterator, Sequence from collections.abc import Iterator, Sequence
from json import JSONDecodeError from json import JSONDecodeError
@ -343,7 +344,65 @@ class ProviderConfiguration(BaseModel):
with Session(db.engine) as new_session: with Session(db.engine) as new_session:
return _validate(new_session) return _validate(new_session)
def create_provider_credential(self, credentials: dict, credential_name: str) -> None: def _generate_provider_credential_name(self, session) -> str:
"""
Generate a unique credential name for provider.
:return: credential name
"""
return self._generate_next_api_key_name(
session=session,
query_factory=lambda: select(ProviderCredential).where(
ProviderCredential.tenant_id == self.tenant_id,
ProviderCredential.provider_name == self.provider.provider,
),
)
def _generate_custom_model_credential_name(self, model: str, model_type: ModelType, session) -> str:
    """
    Generate a unique credential name for a custom model.

    Builds a query over the credentials stored for this tenant, provider, model
    and model type, and passes it to ``_generate_next_api_key_name``, whose
    result is returned unchanged.

    :param model: model name
    :param model_type: model type
    :param session: database session
    :return: credential name
    """

    def _credentials_query():
        # Only built when the name generator calls the factory.
        return select(ProviderModelCredential).where(
            ProviderModelCredential.tenant_id == self.tenant_id,
            ProviderModelCredential.provider_name == self.provider.provider,
            ProviderModelCredential.model_name == model,
            ProviderModelCredential.model_type == model_type.to_origin_model_type(),
        )

    return self._generate_next_api_key_name(session=session, query_factory=_credentials_query)
def _generate_next_api_key_name(self, session, query_factory) -> str:
    """
    Generate the next available "API KEY <n>" credential name.

    The statement returned by ``query_factory`` is executed on ``session``; the
    highest number found among credential names of the form "API KEY <n>" is
    incremented by one. Names that do not follow that form are ignored.

    :param session: database session
    :param query_factory: function that returns the SQLAlchemy query
    :return: next available API KEY name; "API KEY 1" when there are no records,
        when no record carries a numbered name, or when the lookup raises
    """
    fallback_name = "API KEY 1"
    try:
        records = session.execute(query_factory()).scalars().all()
        if not records:
            return fallback_name

        numbered_name = re.compile(r"^API KEY\s+(\d+)$")
        highest = 0
        for record in records:
            name = record.credential_name
            if not name:
                continue
            found = numbered_name.match(name.strip())
            if found:
                highest = max(highest, int(found.group(1)))

        # Continue the sequence after the largest number seen so far.
        return f"API KEY {highest + 1}"
    except Exception as e:
        # Best effort: fall back to the default name instead of propagating the error.
        logger.warning("Error generating next credential name: %s", str(e))
        return fallback_name
def create_provider_credential(self, credentials: dict, credential_name: str | None) -> None:
""" """
Add custom provider credentials. Add custom provider credentials.
:param credentials: provider credentials :param credentials: provider credentials
@ -351,8 +410,12 @@ class ProviderConfiguration(BaseModel):
:return: :return:
""" """
with Session(db.engine) as session: with Session(db.engine) as session:
if self._check_provider_credential_name_exists(credential_name=credential_name, session=session): if credential_name and self._check_provider_credential_name_exists(
credential_name=credential_name, session=session
):
raise ValueError(f"Credential with name '{credential_name}' already exists.") raise ValueError(f"Credential with name '{credential_name}' already exists.")
else:
credential_name = self._generate_provider_credential_name(session)
credentials = self.validate_provider_credentials(credentials=credentials, session=session) credentials = self.validate_provider_credentials(credentials=credentials, session=session)
provider_record = self._get_provider_record(session) provider_record = self._get_provider_record(session)
@ -395,7 +458,7 @@ class ProviderConfiguration(BaseModel):
self, self,
credentials: dict, credentials: dict,
credential_id: str, credential_id: str,
credential_name: str, credential_name: str | None,
) -> None: ) -> None:
""" """
update a saved provider credential (by credential_id). update a saved provider credential (by credential_id).
@ -406,7 +469,7 @@ class ProviderConfiguration(BaseModel):
:return: :return:
""" """
with Session(db.engine) as session: with Session(db.engine) as session:
if self._check_provider_credential_name_exists( if credential_name and self._check_provider_credential_name_exists(
credential_name=credential_name, session=session, exclude_id=credential_id credential_name=credential_name, session=session, exclude_id=credential_id
): ):
raise ValueError(f"Credential with name '{credential_name}' already exists.") raise ValueError(f"Credential with name '{credential_name}' already exists.")
@ -428,9 +491,9 @@ class ProviderConfiguration(BaseModel):
try: try:
# Update credential # Update credential
credential_record.encrypted_config = json.dumps(credentials) credential_record.encrypted_config = json.dumps(credentials)
credential_record.credential_name = credential_name
credential_record.updated_at = naive_utc_now() credential_record.updated_at = naive_utc_now()
if credential_name:
credential_record.credential_name = credential_name
session.commit() session.commit()
if provider_record and provider_record.credential_id == credential_id: if provider_record and provider_record.credential_id == credential_id:
@ -532,13 +595,7 @@ class ProviderConfiguration(BaseModel):
cache_type=ProviderCredentialsCacheType.LOAD_BALANCING_MODEL, cache_type=ProviderCredentialsCacheType.LOAD_BALANCING_MODEL,
) )
lb_credentials_cache.delete() lb_credentials_cache.delete()
session.delete(lb_config)
lb_config.credential_id = None
lb_config.encrypted_config = None
lb_config.enabled = False
lb_config.name = "__delete__"
lb_config.updated_at = naive_utc_now()
session.add(lb_config)
# Check if this is the currently active credential # Check if this is the currently active credential
provider_record = self._get_provider_record(session) provider_record = self._get_provider_record(session)
@ -823,7 +880,7 @@ class ProviderConfiguration(BaseModel):
return _validate(new_session) return _validate(new_session)
def create_custom_model_credential( def create_custom_model_credential(
self, model_type: ModelType, model: str, credentials: dict, credential_name: str self, model_type: ModelType, model: str, credentials: dict, credential_name: str | None
) -> None: ) -> None:
""" """
Create a custom model credential. Create a custom model credential.
@ -834,10 +891,14 @@ class ProviderConfiguration(BaseModel):
:return: :return:
""" """
with Session(db.engine) as session: with Session(db.engine) as session:
if self._check_custom_model_credential_name_exists( if credential_name and self._check_custom_model_credential_name_exists(
model=model, model_type=model_type, credential_name=credential_name, session=session model=model, model_type=model_type, credential_name=credential_name, session=session
): ):
raise ValueError(f"Model credential with name '{credential_name}' already exists for {model}.") raise ValueError(f"Model credential with name '{credential_name}' already exists for {model}.")
else:
credential_name = self._generate_custom_model_credential_name(
model=model, model_type=model_type, session=session
)
# validate custom model config # validate custom model config
credentials = self.validate_custom_model_credentials( credentials = self.validate_custom_model_credentials(
model_type=model_type, model=model, credentials=credentials, session=session model_type=model_type, model=model, credentials=credentials, session=session
@ -881,7 +942,7 @@ class ProviderConfiguration(BaseModel):
raise raise
def update_custom_model_credential( def update_custom_model_credential(
self, model_type: ModelType, model: str, credentials: dict, credential_name: str, credential_id: str self, model_type: ModelType, model: str, credentials: dict, credential_name: str | None, credential_id: str
) -> None: ) -> None:
""" """
Update a custom model credential. Update a custom model credential.
@ -894,7 +955,7 @@ class ProviderConfiguration(BaseModel):
:return: :return:
""" """
with Session(db.engine) as session: with Session(db.engine) as session:
if self._check_custom_model_credential_name_exists( if credential_name and self._check_custom_model_credential_name_exists(
model=model, model=model,
model_type=model_type, model_type=model_type,
credential_name=credential_name, credential_name=credential_name,
@ -926,8 +987,9 @@ class ProviderConfiguration(BaseModel):
try: try:
# Update credential # Update credential
credential_record.encrypted_config = json.dumps(credentials) credential_record.encrypted_config = json.dumps(credentials)
credential_record.credential_name = credential_name
credential_record.updated_at = naive_utc_now() credential_record.updated_at = naive_utc_now()
if credential_name:
credential_record.credential_name = credential_name
session.commit() session.commit()
if provider_model_record and provider_model_record.credential_id == credential_id: if provider_model_record and provider_model_record.credential_id == credential_id:
@ -983,12 +1045,7 @@ class ProviderConfiguration(BaseModel):
cache_type=ProviderCredentialsCacheType.LOAD_BALANCING_MODEL, cache_type=ProviderCredentialsCacheType.LOAD_BALANCING_MODEL,
) )
lb_credentials_cache.delete() lb_credentials_cache.delete()
lb_config.credential_id = None session.delete(lb_config)
lb_config.encrypted_config = None
lb_config.enabled = False
lb_config.name = "__delete__"
lb_config.updated_at = naive_utc_now()
session.add(lb_config)
# Check if this is the currently active credential # Check if this is the currently active credential
provider_model_record = self._get_custom_model_record(model_type, model, session=session) provider_model_record = self._get_custom_model_record(model_type, model, session=session)
@ -1055,6 +1112,7 @@ class ProviderConfiguration(BaseModel):
provider_name=self.provider.provider, provider_name=self.provider.provider,
model_name=model, model_name=model,
model_type=model_type.to_origin_model_type(), model_type=model_type.to_origin_model_type(),
is_valid=True,
credential_id=credential_id, credential_id=credential_id,
) )
else: else:
@ -1608,11 +1666,9 @@ class ProviderConfiguration(BaseModel):
if config.credential_source_type != "custom_model" if config.credential_source_type != "custom_model"
] ]
if len(provider_model_lb_configs) > 1: load_balancing_enabled = model_setting.load_balancing_enabled
load_balancing_enabled = True # when the user enables load_balancing but fewer than 2 configs are available, display a warning
has_invalid_load_balancing_configs = load_balancing_enabled and len(provider_model_lb_configs) < 2
if any(config.name == "__delete__" for config in provider_model_lb_configs):
has_invalid_load_balancing_configs = True
provider_models.append( provider_models.append(
ModelWithProviderEntity( ModelWithProviderEntity(
@ -1634,6 +1690,8 @@ class ProviderConfiguration(BaseModel):
for model_configuration in self.custom_configuration.models: for model_configuration in self.custom_configuration.models:
if model_configuration.model_type not in model_types: if model_configuration.model_type not in model_types:
continue continue
if model_configuration.unadded_to_model_list:
continue
if model and model != model_configuration.model: if model and model != model_configuration.model:
continue continue
try: try:
@ -1666,11 +1724,9 @@ class ProviderConfiguration(BaseModel):
if config.credential_source_type != "provider" if config.credential_source_type != "provider"
] ]
if len(custom_model_lb_configs) > 1: load_balancing_enabled = model_setting.load_balancing_enabled
load_balancing_enabled = True # when the user enables load_balancing but fewer than 2 configs are available, display a warning
has_invalid_load_balancing_configs = load_balancing_enabled and len(custom_model_lb_configs) < 2
if any(config.name == "__delete__" for config in custom_model_lb_configs):
has_invalid_load_balancing_configs = True
if len(model_configuration.available_model_credentials) > 0 and not model_configuration.credentials: if len(model_configuration.available_model_credentials) > 0 and not model_configuration.credentials:
status = ModelStatus.CREDENTIAL_REMOVED status = ModelStatus.CREDENTIAL_REMOVED

View File

@ -111,11 +111,21 @@ class CustomModelConfiguration(BaseModel):
current_credential_id: Optional[str] = None current_credential_id: Optional[str] = None
current_credential_name: Optional[str] = None current_credential_name: Optional[str] = None
available_model_credentials: list[CredentialConfiguration] = [] available_model_credentials: list[CredentialConfiguration] = []
unadded_to_model_list: Optional[bool] = False
# pydantic configs # pydantic configs
model_config = ConfigDict(protected_namespaces=()) model_config = ConfigDict(protected_namespaces=())
class UnaddedModelConfiguration(BaseModel):
    """
    Model class for provider unadded model configuration.

    Identifies a model by its name and type only; no credential data is carried.
    """

    # model name
    model: str
    # model type
    model_type: ModelType

    # pydantic configs
    # `model_type` begins with `model_`, which pydantic can treat as a protected
    # namespace; clear it here, matching the other configuration models in this file.
    model_config = ConfigDict(protected_namespaces=())
class CustomConfiguration(BaseModel): class CustomConfiguration(BaseModel):
""" """
Model class for provider custom configuration. Model class for provider custom configuration.
@ -123,6 +133,7 @@ class CustomConfiguration(BaseModel):
provider: Optional[CustomProviderConfiguration] = None provider: Optional[CustomProviderConfiguration] = None
models: list[CustomModelConfiguration] = [] models: list[CustomModelConfiguration] = []
can_added_models: list[UnaddedModelConfiguration] = []
class ModelLoadBalancingConfiguration(BaseModel): class ModelLoadBalancingConfiguration(BaseModel):
@ -144,6 +155,7 @@ class ModelSettings(BaseModel):
model: str model: str
model_type: ModelType model_type: ModelType
enabled: bool = True enabled: bool = True
load_balancing_enabled: bool = False
load_balancing_configs: list[ModelLoadBalancingConfiguration] = [] load_balancing_configs: list[ModelLoadBalancingConfiguration] = []
# pydantic configs # pydantic configs

View File

@ -43,9 +43,9 @@ class APIBasedExtensionRequestor:
timeout=self.timeout, timeout=self.timeout,
proxies=proxies, proxies=proxies,
) )
except requests.exceptions.Timeout: except requests.Timeout:
raise ValueError("request timeout") raise ValueError("request timeout")
except requests.exceptions.ConnectionError: except requests.ConnectionError:
raise ValueError("request connection error") raise ValueError("request connection error")
if response.status_code != 200: if response.status_code != 200:

View File

@ -91,7 +91,7 @@ class Extensible:
# Find extension class # Find extension class
extension_class = None extension_class = None
for name, obj in vars(mod).items(): for obj in vars(mod).values():
if isinstance(obj, type) and issubclass(obj, cls) and obj != cls: if isinstance(obj, type) and issubclass(obj, cls) and obj != cls:
extension_class = obj extension_class = obj
break break
@ -123,7 +123,7 @@ class Extensible:
) )
) )
except Exception as e: except Exception:
logger.exception("Error scanning extensions") logger.exception("Error scanning extensions")
raise raise

View File

@ -41,9 +41,3 @@ class Extension:
assert module_extension.extension_class is not None assert module_extension.extension_class is not None
t: type = module_extension.extension_class t: type = module_extension.extension_class
return t return t
def validate_form_schema(self, module: ExtensionModule, extension_name: str, config: dict) -> None:
module_extension = self.module_extension(module, extension_name)
form_schema = module_extension.form_schema
# TODO validate form_schema

View File

@ -1,5 +1,7 @@
from typing import Optional from typing import Optional
from sqlalchemy import select
from core.extension.api_based_extension_requestor import APIBasedExtensionRequestor from core.extension.api_based_extension_requestor import APIBasedExtensionRequestor
from core.external_data_tool.base import ExternalDataTool from core.external_data_tool.base import ExternalDataTool
from core.helper import encrypter from core.helper import encrypter
@ -28,13 +30,11 @@ class ApiExternalDataTool(ExternalDataTool):
api_based_extension_id = config.get("api_based_extension_id") api_based_extension_id = config.get("api_based_extension_id")
if not api_based_extension_id: if not api_based_extension_id:
raise ValueError("api_based_extension_id is required") raise ValueError("api_based_extension_id is required")
# get api_based_extension # get api_based_extension
api_based_extension = ( stmt = select(APIBasedExtension).where(
db.session.query(APIBasedExtension) APIBasedExtension.tenant_id == tenant_id, APIBasedExtension.id == api_based_extension_id
.where(APIBasedExtension.tenant_id == tenant_id, APIBasedExtension.id == api_based_extension_id)
.first()
) )
api_based_extension = db.session.scalar(stmt)
if not api_based_extension: if not api_based_extension:
raise ValueError("api_based_extension_id is invalid") raise ValueError("api_based_extension_id is invalid")
@ -52,13 +52,11 @@ class ApiExternalDataTool(ExternalDataTool):
raise ValueError(f"config is required, config: {self.config}") raise ValueError(f"config is required, config: {self.config}")
api_based_extension_id = self.config.get("api_based_extension_id") api_based_extension_id = self.config.get("api_based_extension_id")
assert api_based_extension_id is not None, "api_based_extension_id is required" assert api_based_extension_id is not None, "api_based_extension_id is required"
# get api_based_extension # get api_based_extension
api_based_extension = ( stmt = select(APIBasedExtension).where(
db.session.query(APIBasedExtension) APIBasedExtension.tenant_id == self.tenant_id, APIBasedExtension.id == api_based_extension_id
.where(APIBasedExtension.tenant_id == self.tenant_id, APIBasedExtension.id == api_based_extension_id)
.first()
) )
api_based_extension = db.session.scalar(stmt)
if not api_based_extension: if not api_based_extension:
raise ValueError( raise ValueError(

View File

@ -22,7 +22,6 @@ class ExternalDataToolFactory:
:param config: the form config data :param config: the form config data
:return: :return:
""" """
code_based_extension.validate_form_schema(ExtensionModule.EXTERNAL_DATA_TOOL, name, config)
extension_class = code_based_extension.extension_class(ExtensionModule.EXTERNAL_DATA_TOOL, name) extension_class = code_based_extension.extension_class(ExtensionModule.EXTERNAL_DATA_TOOL, name)
# FIXME mypy issue here, figure out how to fix it # FIXME mypy issue here, figure out how to fix it
extension_class.validate_config(tenant_id, config) # type: ignore extension_class.validate_config(tenant_id, config) # type: ignore

View File

@ -3,7 +3,7 @@ import base64
from libs import rsa from libs import rsa
def obfuscated_token(token: str): def obfuscated_token(token: str) -> str:
if not token: if not token:
return token return token
if len(token) <= 8: if len(token) <= 8:
@ -11,6 +11,10 @@ def obfuscated_token(token: str):
return token[:6] + "*" * 12 + token[-2:] return token[:6] + "*" * 12 + token[-2:]
def full_mask_token(token_length=20):
    """
    Build a mask made up entirely of asterisks.

    :param token_length: number of asterisks in the mask (defaults to 20)
    :return: the mask string
    """
    mask_char = "*"
    return mask_char * token_length
def encrypt_token(tenant_id: str, token: str): def encrypt_token(tenant_id: str, token: str):
from extensions.ext_database import db from extensions.ext_database import db
from models.account import Tenant from models.account import Tenant

View File

@ -1,6 +1,6 @@
from collections.abc import Sequence from collections.abc import Sequence
import requests import httpx
from yarl import URL from yarl import URL
from configs import dify_config from configs import dify_config
@ -23,7 +23,7 @@ def batch_fetch_plugin_manifests(plugin_ids: list[str]) -> Sequence[MarketplaceP
return [] return []
url = str(marketplace_api_url / "api/v1/plugins/batch") url = str(marketplace_api_url / "api/v1/plugins/batch")
response = requests.post(url, json={"plugin_ids": plugin_ids}) response = httpx.post(url, json={"plugin_ids": plugin_ids})
response.raise_for_status() response.raise_for_status()
return [MarketplacePluginDeclaration(**plugin) for plugin in response.json()["data"]["plugins"]] return [MarketplacePluginDeclaration(**plugin) for plugin in response.json()["data"]["plugins"]]
@ -36,13 +36,13 @@ def batch_fetch_plugin_manifests_ignore_deserialization_error(
return [] return []
url = str(marketplace_api_url / "api/v1/plugins/batch") url = str(marketplace_api_url / "api/v1/plugins/batch")
response = requests.post(url, json={"plugin_ids": plugin_ids}) response = httpx.post(url, json={"plugin_ids": plugin_ids})
response.raise_for_status() response.raise_for_status()
result: list[MarketplacePluginDeclaration] = [] result: list[MarketplacePluginDeclaration] = []
for plugin in response.json()["data"]["plugins"]: for plugin in response.json()["data"]["plugins"]:
try: try:
result.append(MarketplacePluginDeclaration(**plugin)) result.append(MarketplacePluginDeclaration(**plugin))
except Exception as e: except Exception:
pass pass
return result return result
@ -50,5 +50,5 @@ def batch_fetch_plugin_manifests_ignore_deserialization_error(
def record_install_plugin_event(plugin_unique_identifier: str): def record_install_plugin_event(plugin_unique_identifier: str):
url = str(marketplace_api_url / "api/v1/stats/plugins/install_count") url = str(marketplace_api_url / "api/v1/stats/plugins/install_count")
response = requests.post(url, json={"unique_identifier": plugin_unique_identifier}) response = httpx.post(url, json={"unique_identifier": plugin_unique_identifier})
response.raise_for_status() response.raise_for_status()

View File

@ -47,7 +47,7 @@ def get_subclasses_from_module(mod: ModuleType, parent_type: type) -> list[type]
def load_single_subclass_from_source( def load_single_subclass_from_source(
*, module_name: str, script_path: AnyStr, parent_type: type, use_lazy_loader: bool = False *, module_name: str, script_path: str, parent_type: type, use_lazy_loader: bool = False
) -> type: ) -> type:
""" """
Load a single subclass from the source Load a single subclass from the source

View File

@ -5,9 +5,10 @@ import re
import threading import threading
import time import time
import uuid import uuid
from typing import Any, Optional, cast from typing import Any, Optional
from flask import current_app from flask import current_app
from sqlalchemy import select
from sqlalchemy.orm.exc import ObjectDeletedError from sqlalchemy.orm.exc import ObjectDeletedError
from configs import dify_config from configs import dify_config
@ -18,6 +19,7 @@ from core.model_runtime.entities.model_entities import ModelType
from core.rag.cleaner.clean_processor import CleanProcessor from core.rag.cleaner.clean_processor import CleanProcessor
from core.rag.datasource.keyword.keyword_factory import Keyword from core.rag.datasource.keyword.keyword_factory import Keyword
from core.rag.docstore.dataset_docstore import DatasetDocumentStore from core.rag.docstore.dataset_docstore import DatasetDocumentStore
from core.rag.extractor.entity.datasource_type import DatasourceType
from core.rag.extractor.entity.extract_setting import ExtractSetting from core.rag.extractor.entity.extract_setting import ExtractSetting
from core.rag.index_processor.constant.index_type import IndexType from core.rag.index_processor.constant.index_type import IndexType
from core.rag.index_processor.index_processor_base import BaseIndexProcessor from core.rag.index_processor.index_processor_base import BaseIndexProcessor
@ -56,13 +58,11 @@ class IndexingRunner:
if not dataset: if not dataset:
raise ValueError("no dataset found") raise ValueError("no dataset found")
# get the process rule # get the process rule
processing_rule = ( stmt = select(DatasetProcessRule).where(
db.session.query(DatasetProcessRule) DatasetProcessRule.id == dataset_document.dataset_process_rule_id
.where(DatasetProcessRule.id == dataset_document.dataset_process_rule_id)
.first()
) )
processing_rule = db.session.scalar(stmt)
if not processing_rule: if not processing_rule:
raise ValueError("no process rule found") raise ValueError("no process rule found")
index_type = dataset_document.doc_form index_type = dataset_document.doc_form
@ -123,11 +123,8 @@ class IndexingRunner:
db.session.query(ChildChunk).where(ChildChunk.segment_id == document_segment.id).delete() db.session.query(ChildChunk).where(ChildChunk.segment_id == document_segment.id).delete()
db.session.commit() db.session.commit()
# get the process rule # get the process rule
processing_rule = ( stmt = select(DatasetProcessRule).where(DatasetProcessRule.id == dataset_document.dataset_process_rule_id)
db.session.query(DatasetProcessRule) processing_rule = db.session.scalar(stmt)
.where(DatasetProcessRule.id == dataset_document.dataset_process_rule_id)
.first()
)
if not processing_rule: if not processing_rule:
raise ValueError("no process rule found") raise ValueError("no process rule found")
@ -208,7 +205,6 @@ class IndexingRunner:
child_documents.append(child_document) child_documents.append(child_document)
document.children = child_documents document.children = child_documents
documents.append(document) documents.append(document)
# build index # build index
index_type = dataset_document.doc_form index_type = dataset_document.doc_form
index_processor = IndexProcessorFactory(index_type).init_index_processor() index_processor = IndexProcessorFactory(index_type).init_index_processor()
@ -310,7 +306,8 @@ class IndexingRunner:
# delete image files and related db records # delete image files and related db records
image_upload_file_ids = get_image_upload_file_ids(document.page_content) image_upload_file_ids = get_image_upload_file_ids(document.page_content)
for upload_file_id in image_upload_file_ids: for upload_file_id in image_upload_file_ids:
image_file = db.session.query(UploadFile).where(UploadFile.id == upload_file_id).first() stmt = select(UploadFile).where(UploadFile.id == upload_file_id)
image_file = db.session.scalar(stmt)
if image_file is None: if image_file is None:
continue continue
try: try:
@ -339,14 +336,14 @@ class IndexingRunner:
if dataset_document.data_source_type == "upload_file": if dataset_document.data_source_type == "upload_file":
if not data_source_info or "upload_file_id" not in data_source_info: if not data_source_info or "upload_file_id" not in data_source_info:
raise ValueError("no upload file found") raise ValueError("no upload file found")
stmt = select(UploadFile).where(UploadFile.id == data_source_info["upload_file_id"])
file_detail = ( file_detail = db.session.scalars(stmt).one_or_none()
db.session.query(UploadFile).where(UploadFile.id == data_source_info["upload_file_id"]).one_or_none()
)
if file_detail: if file_detail:
extract_setting = ExtractSetting( extract_setting = ExtractSetting(
datasource_type="upload_file", upload_file=file_detail, document_model=dataset_document.doc_form datasource_type=DatasourceType.FILE.value,
upload_file=file_detail,
document_model=dataset_document.doc_form,
) )
text_docs = index_processor.extract(extract_setting, process_rule_mode=process_rule["mode"]) text_docs = index_processor.extract(extract_setting, process_rule_mode=process_rule["mode"])
elif dataset_document.data_source_type == "notion_import": elif dataset_document.data_source_type == "notion_import":
@ -357,7 +354,7 @@ class IndexingRunner:
): ):
raise ValueError("no notion import info found") raise ValueError("no notion import info found")
extract_setting = ExtractSetting( extract_setting = ExtractSetting(
datasource_type="notion_import", datasource_type=DatasourceType.NOTION.value,
notion_info={ notion_info={
"credential_id": data_source_info["credential_id"], "credential_id": data_source_info["credential_id"],
"notion_workspace_id": data_source_info["notion_workspace_id"], "notion_workspace_id": data_source_info["notion_workspace_id"],
@ -378,7 +375,7 @@ class IndexingRunner:
): ):
raise ValueError("no website import info found") raise ValueError("no website import info found")
extract_setting = ExtractSetting( extract_setting = ExtractSetting(
datasource_type="website_crawl", datasource_type=DatasourceType.WEBSITE.value,
website_info={ website_info={
"provider": data_source_info["provider"], "provider": data_source_info["provider"],
"job_id": data_source_info["job_id"], "job_id": data_source_info["job_id"],
@ -401,7 +398,6 @@ class IndexingRunner:
) )
# replace doc id to document model id # replace doc id to document model id
text_docs = cast(list[Document], text_docs)
for text_doc in text_docs: for text_doc in text_docs:
if text_doc.metadata is not None: if text_doc.metadata is not None:
text_doc.metadata["document_id"] = dataset_document.id text_doc.metadata["document_id"] = dataset_document.id

View File

@ -58,11 +58,8 @@ class LLMGenerator:
prompts = [UserPromptMessage(content=prompt)] prompts = [UserPromptMessage(content=prompt)]
with measure_time() as timer: with measure_time() as timer:
response = cast( response: LLMResult = model_instance.invoke_llm(
LLMResult, prompt_messages=list(prompts), model_parameters={"max_tokens": 500, "temperature": 1}, stream=False
model_instance.invoke_llm(
prompt_messages=list(prompts), model_parameters={"max_tokens": 500, "temperature": 1}, stream=False
),
) )
answer = cast(str, response.message.content) answer = cast(str, response.message.content)
cleaned_answer = re.sub(r"^.*(\{.*\}).*$", r"\1", answer, flags=re.DOTALL) cleaned_answer = re.sub(r"^.*(\{.*\}).*$", r"\1", answer, flags=re.DOTALL)
@ -71,7 +68,7 @@ class LLMGenerator:
try: try:
result_dict = json.loads(cleaned_answer) result_dict = json.loads(cleaned_answer)
answer = result_dict["Your Output"] answer = result_dict["Your Output"]
except json.JSONDecodeError as e: except json.JSONDecodeError:
logger.exception("Failed to generate name after answer, use query instead") logger.exception("Failed to generate name after answer, use query instead")
answer = query answer = query
name = answer.strip() name = answer.strip()
@ -115,13 +112,10 @@ class LLMGenerator:
prompt_messages = [UserPromptMessage(content=prompt)] prompt_messages = [UserPromptMessage(content=prompt)]
try: try:
response = cast( response: LLMResult = model_instance.invoke_llm(
LLMResult, prompt_messages=list(prompt_messages),
model_instance.invoke_llm( model_parameters={"max_tokens": 256, "temperature": 0},
prompt_messages=list(prompt_messages), stream=False,
model_parameters={"max_tokens": 256, "temperature": 0},
stream=False,
),
) )
text_content = response.message.get_text_content() text_content = response.message.get_text_content()
@ -164,11 +158,8 @@ class LLMGenerator:
) )
try: try:
response = cast( response: LLMResult = model_instance.invoke_llm(
LLMResult, prompt_messages=list(prompt_messages), model_parameters=model_parameters, stream=False
model_instance.invoke_llm(
prompt_messages=list(prompt_messages), model_parameters=model_parameters, stream=False
),
) )
rule_config["prompt"] = cast(str, response.message.content) rule_config["prompt"] = cast(str, response.message.content)
@ -214,11 +205,8 @@ class LLMGenerator:
try: try:
try: try:
# the first step to generate the task prompt # the first step to generate the task prompt
prompt_content = cast( prompt_content: LLMResult = model_instance.invoke_llm(
LLMResult, prompt_messages=list(prompt_messages), model_parameters=model_parameters, stream=False
model_instance.invoke_llm(
prompt_messages=list(prompt_messages), model_parameters=model_parameters, stream=False
),
) )
except InvokeError as e: except InvokeError as e:
error = str(e) error = str(e)
@ -250,11 +238,8 @@ class LLMGenerator:
statement_messages = [UserPromptMessage(content=statement_generate_prompt)] statement_messages = [UserPromptMessage(content=statement_generate_prompt)]
try: try:
parameter_content = cast( parameter_content: LLMResult = model_instance.invoke_llm(
LLMResult, prompt_messages=list(parameter_messages), model_parameters=model_parameters, stream=False
model_instance.invoke_llm(
prompt_messages=list(parameter_messages), model_parameters=model_parameters, stream=False
),
) )
rule_config["variables"] = re.findall(r'"\s*([^"]+)\s*"', cast(str, parameter_content.message.content)) rule_config["variables"] = re.findall(r'"\s*([^"]+)\s*"', cast(str, parameter_content.message.content))
except InvokeError as e: except InvokeError as e:
@ -262,11 +247,8 @@ class LLMGenerator:
error_step = "generate variables" error_step = "generate variables"
try: try:
statement_content = cast( statement_content: LLMResult = model_instance.invoke_llm(
LLMResult, prompt_messages=list(statement_messages), model_parameters=model_parameters, stream=False
model_instance.invoke_llm(
prompt_messages=list(statement_messages), model_parameters=model_parameters, stream=False
),
) )
rule_config["opening_statement"] = cast(str, statement_content.message.content) rule_config["opening_statement"] = cast(str, statement_content.message.content)
except InvokeError as e: except InvokeError as e:
@ -309,11 +291,8 @@ class LLMGenerator:
prompt_messages = [UserPromptMessage(content=prompt)] prompt_messages = [UserPromptMessage(content=prompt)]
model_parameters = model_config.get("completion_params", {}) model_parameters = model_config.get("completion_params", {})
try: try:
response = cast( response: LLMResult = model_instance.invoke_llm(
LLMResult, prompt_messages=list(prompt_messages), model_parameters=model_parameters, stream=False
model_instance.invoke_llm(
prompt_messages=list(prompt_messages), model_parameters=model_parameters, stream=False
),
) )
generated_code = cast(str, response.message.content) generated_code = cast(str, response.message.content)
@ -340,13 +319,10 @@ class LLMGenerator:
prompt_messages = [SystemPromptMessage(content=prompt), UserPromptMessage(content=query)] prompt_messages = [SystemPromptMessage(content=prompt), UserPromptMessage(content=query)]
response = cast( response: LLMResult = model_instance.invoke_llm(
LLMResult, prompt_messages=prompt_messages,
model_instance.invoke_llm( model_parameters={"temperature": 0.01, "max_tokens": 2000},
prompt_messages=prompt_messages, stream=False,
model_parameters={"temperature": 0.01, "max_tokens": 2000},
stream=False,
),
) )
answer = cast(str, response.message.content) answer = cast(str, response.message.content)
@ -369,11 +345,8 @@ class LLMGenerator:
model_parameters = model_config.get("model_parameters", {}) model_parameters = model_config.get("model_parameters", {})
try: try:
response = cast( response: LLMResult = model_instance.invoke_llm(
LLMResult, prompt_messages=list(prompt_messages), model_parameters=model_parameters, stream=False
model_instance.invoke_llm(
prompt_messages=list(prompt_messages), model_parameters=model_parameters, stream=False
),
) )
raw_content = response.message.content raw_content = response.message.content
@ -560,11 +533,8 @@ class LLMGenerator:
model_parameters = {"temperature": 0.4} model_parameters = {"temperature": 0.4}
try: try:
response = cast( response: LLMResult = model_instance.invoke_llm(
LLMResult, prompt_messages=list(prompt_messages), model_parameters=model_parameters, stream=False
model_instance.invoke_llm(
prompt_messages=list(prompt_messages), model_parameters=model_parameters, stream=False
),
) )
generated_raw = cast(str, response.message.content) generated_raw = cast(str, response.message.content)

View File

@ -101,7 +101,7 @@ def handle_callback(state_key: str, authorization_code: str) -> OAuthCallbackSta
def check_support_resource_discovery(server_url: str) -> tuple[bool, str]: def check_support_resource_discovery(server_url: str) -> tuple[bool, str]:
"""Check if the server supports OAuth 2.0 Resource Discovery.""" """Check if the server supports OAuth 2.0 Resource Discovery."""
b_scheme, b_netloc, b_path, b_params, b_query, b_fragment = urlparse(server_url, "", True) b_scheme, b_netloc, b_path, _, b_query, b_fragment = urlparse(server_url, "", True)
url_for_resource_discovery = f"{b_scheme}://{b_netloc}/.well-known/oauth-protected-resource{b_path}" url_for_resource_discovery = f"{b_scheme}://{b_netloc}/.well-known/oauth-protected-resource{b_path}"
if b_query: if b_query:
url_for_resource_discovery += f"?{b_query}" url_for_resource_discovery += f"?{b_query}"
@ -117,7 +117,7 @@ def check_support_resource_discovery(server_url: str) -> tuple[bool, str]:
else: else:
return False, "" return False, ""
return False, "" return False, ""
except httpx.RequestError as e: except httpx.RequestError:
# Not support resource discovery, fall back to well-known OAuth metadata # Not support resource discovery, fall back to well-known OAuth metadata
return False, "" return False, ""

View File

@ -2,7 +2,7 @@ import logging
from collections.abc import Callable from collections.abc import Callable
from contextlib import AbstractContextManager, ExitStack from contextlib import AbstractContextManager, ExitStack
from types import TracebackType from types import TracebackType
from typing import Any, Optional, cast from typing import Any, Optional
from urllib.parse import urlparse from urllib.parse import urlparse
from core.mcp.client.sse_client import sse_client from core.mcp.client.sse_client import sse_client
@ -116,8 +116,7 @@ class MCPClient:
self._session_context = ClientSession(*streams) self._session_context = ClientSession(*streams)
self._session = self._exit_stack.enter_context(self._session_context) self._session = self._exit_stack.enter_context(self._session_context)
session = cast(ClientSession, self._session) self._session.initialize()
session.initialize()
return return
except MCPAuthError: except MCPAuthError:

View File

@ -2,7 +2,6 @@ from collections.abc import Sequence
from typing import Optional from typing import Optional
from sqlalchemy import select from sqlalchemy import select
from sqlalchemy.orm import Session
from core.app.app_config.features.file_upload.manager import FileUploadConfigManager from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
from core.file import file_manager from core.file import file_manager
@ -33,12 +32,7 @@ class TokenBufferMemory:
self.model_instance = model_instance self.model_instance = model_instance
def _build_prompt_message_with_files( def _build_prompt_message_with_files(
self, self, message_files: list[MessageFile], text_content: str, message: Message, app_record, is_user_message: bool
message_files: Sequence[MessageFile],
text_content: str,
message: Message,
app_record,
is_user_message: bool,
) -> PromptMessage: ) -> PromptMessage:
""" """
Build prompt message with files. Build prompt message with files.
@ -104,74 +98,74 @@ class TokenBufferMemory:
:param max_token_limit: max token limit :param max_token_limit: max token limit
:param message_limit: message limit :param message_limit: message limit
""" """
with Session(db.engine) as session: app_record = self.conversation.app
app_record = self.conversation.app
# fetch limited messages, and return reversed # fetch limited messages, and return reversed
stmt = ( stmt = (
select(Message) select(Message).where(Message.conversation_id == self.conversation.id).order_by(Message.created_at.desc())
.where(Message.conversation_id == self.conversation.id) )
.order_by(Message.created_at.desc())
)
if message_limit and message_limit > 0: if message_limit and message_limit > 0:
message_limit = min(message_limit, 500) message_limit = min(message_limit, 500)
else: else:
message_limit = 500 message_limit = 500
stmt = stmt.limit(message_limit) msg_limit_stmt = stmt.limit(message_limit)
messages = session.scalars(stmt).all() messages = db.session.scalars(msg_limit_stmt).all()
# instead of all messages from the conversation, we only need to extract messages # instead of all messages from the conversation, we only need to extract messages
# that belong to the thread of last message # that belong to the thread of last message
thread_messages = extract_thread_messages(messages) thread_messages = extract_thread_messages(messages)
# for newly created message, its answer is temporarily empty, we don't need to add it to memory # for newly created message, its answer is temporarily empty, we don't need to add it to memory
if thread_messages and not thread_messages[0].answer and thread_messages[0].answer_tokens == 0: if thread_messages and not thread_messages[0].answer and thread_messages[0].answer_tokens == 0:
thread_messages.pop(0) thread_messages.pop(0)
messages = list(reversed(thread_messages)) messages = list(reversed(thread_messages))
prompt_messages: list[PromptMessage] = [] prompt_messages: list[PromptMessage] = []
for message in messages: for message in messages:
# Process user message with files # Process user message with files
user_file_query = select(MessageFile).where( user_files = (
db.session.query(MessageFile)
.where(
MessageFile.message_id == message.id, MessageFile.message_id == message.id,
(MessageFile.belongs_to == "user") | (MessageFile.belongs_to.is_(None)), (MessageFile.belongs_to == "user") | (MessageFile.belongs_to.is_(None)),
) )
user_files = session.scalars(user_file_query).all() .all()
)
if user_files: if user_files:
user_prompt_message = self._build_prompt_message_with_files( user_prompt_message = self._build_prompt_message_with_files(
message_files=user_files, message_files=user_files,
text_content=message.query, text_content=message.query,
message=message, message=message,
app_record=app_record, app_record=app_record,
is_user_message=True, is_user_message=True,
)
prompt_messages.append(user_prompt_message)
else:
prompt_messages.append(UserPromptMessage(content=message.query))
# Process assistant message with files
assistant_file_query = select(MessageFile).where(
MessageFile.message_id == message.id, MessageFile.belongs_to == "assistant"
) )
assistant_files = session.scalars(assistant_file_query).all() prompt_messages.append(user_prompt_message)
else:
prompt_messages.append(UserPromptMessage(content=message.query))
if assistant_files: # Process assistant message with files
assistant_prompt_message = self._build_prompt_message_with_files( assistant_files = (
message_files=assistant_files, db.session.query(MessageFile)
text_content=message.answer, .where(MessageFile.message_id == message.id, MessageFile.belongs_to == "assistant")
message=message, .all()
app_record=app_record, )
is_user_message=False,
) if assistant_files:
prompt_messages.append(assistant_prompt_message) assistant_prompt_message = self._build_prompt_message_with_files(
else: message_files=assistant_files,
prompt_messages.append(AssistantPromptMessage(content=message.answer)) text_content=message.answer,
message=message,
app_record=app_record,
is_user_message=False,
)
prompt_messages.append(assistant_prompt_message)
else:
prompt_messages.append(AssistantPromptMessage(content=message.answer))
if not prompt_messages: if not prompt_messages:
return [] return []

View File

@ -158,8 +158,6 @@ class ModelInstance:
""" """
if not isinstance(self.model_type_instance, LargeLanguageModel): if not isinstance(self.model_type_instance, LargeLanguageModel):
raise Exception("Model type instance is not LargeLanguageModel") raise Exception("Model type instance is not LargeLanguageModel")
self.model_type_instance = cast(LargeLanguageModel, self.model_type_instance)
return cast( return cast(
Union[LLMResult, Generator], Union[LLMResult, Generator],
self._round_robin_invoke( self._round_robin_invoke(
@ -188,8 +186,6 @@ class ModelInstance:
""" """
if not isinstance(self.model_type_instance, LargeLanguageModel): if not isinstance(self.model_type_instance, LargeLanguageModel):
raise Exception("Model type instance is not LargeLanguageModel") raise Exception("Model type instance is not LargeLanguageModel")
self.model_type_instance = cast(LargeLanguageModel, self.model_type_instance)
return cast( return cast(
int, int,
self._round_robin_invoke( self._round_robin_invoke(
@ -214,8 +210,6 @@ class ModelInstance:
""" """
if not isinstance(self.model_type_instance, TextEmbeddingModel): if not isinstance(self.model_type_instance, TextEmbeddingModel):
raise Exception("Model type instance is not TextEmbeddingModel") raise Exception("Model type instance is not TextEmbeddingModel")
self.model_type_instance = cast(TextEmbeddingModel, self.model_type_instance)
return cast( return cast(
TextEmbeddingResult, TextEmbeddingResult,
self._round_robin_invoke( self._round_robin_invoke(
@ -237,8 +231,6 @@ class ModelInstance:
""" """
if not isinstance(self.model_type_instance, TextEmbeddingModel): if not isinstance(self.model_type_instance, TextEmbeddingModel):
raise Exception("Model type instance is not TextEmbeddingModel") raise Exception("Model type instance is not TextEmbeddingModel")
self.model_type_instance = cast(TextEmbeddingModel, self.model_type_instance)
return cast( return cast(
list[int], list[int],
self._round_robin_invoke( self._round_robin_invoke(
@ -269,8 +261,6 @@ class ModelInstance:
""" """
if not isinstance(self.model_type_instance, RerankModel): if not isinstance(self.model_type_instance, RerankModel):
raise Exception("Model type instance is not RerankModel") raise Exception("Model type instance is not RerankModel")
self.model_type_instance = cast(RerankModel, self.model_type_instance)
return cast( return cast(
RerankResult, RerankResult,
self._round_robin_invoke( self._round_robin_invoke(
@ -295,8 +285,6 @@ class ModelInstance:
""" """
if not isinstance(self.model_type_instance, ModerationModel): if not isinstance(self.model_type_instance, ModerationModel):
raise Exception("Model type instance is not ModerationModel") raise Exception("Model type instance is not ModerationModel")
self.model_type_instance = cast(ModerationModel, self.model_type_instance)
return cast( return cast(
bool, bool,
self._round_robin_invoke( self._round_robin_invoke(
@ -318,8 +306,6 @@ class ModelInstance:
""" """
if not isinstance(self.model_type_instance, Speech2TextModel): if not isinstance(self.model_type_instance, Speech2TextModel):
raise Exception("Model type instance is not Speech2TextModel") raise Exception("Model type instance is not Speech2TextModel")
self.model_type_instance = cast(Speech2TextModel, self.model_type_instance)
return cast( return cast(
str, str,
self._round_robin_invoke( self._round_robin_invoke(
@ -343,8 +329,6 @@ class ModelInstance:
""" """
if not isinstance(self.model_type_instance, TTSModel): if not isinstance(self.model_type_instance, TTSModel):
raise Exception("Model type instance is not TTSModel") raise Exception("Model type instance is not TTSModel")
self.model_type_instance = cast(TTSModel, self.model_type_instance)
return cast( return cast(
Iterable[bytes], Iterable[bytes],
self._round_robin_invoke( self._round_robin_invoke(
@ -404,8 +388,6 @@ class ModelInstance:
""" """
if not isinstance(self.model_type_instance, TTSModel): if not isinstance(self.model_type_instance, TTSModel):
raise Exception("Model type instance is not TTSModel") raise Exception("Model type instance is not TTSModel")
self.model_type_instance = cast(TTSModel, self.model_type_instance)
return self.model_type_instance.get_tts_model_voices( return self.model_type_instance.get_tts_model_voices(
model=self.model, credentials=self.credentials, language=language model=self.model, credentials=self.credentials, language=language
) )

View File

@ -1,6 +1,7 @@
from typing import Optional from typing import Optional
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from sqlalchemy import select
from core.extension.api_based_extension_requestor import APIBasedExtensionPoint, APIBasedExtensionRequestor from core.extension.api_based_extension_requestor import APIBasedExtensionPoint, APIBasedExtensionRequestor
from core.helper.encrypter import decrypt_token from core.helper.encrypter import decrypt_token
@ -87,10 +88,9 @@ class ApiModeration(Moderation):
@staticmethod @staticmethod
def _get_api_based_extension(tenant_id: str, api_based_extension_id: str) -> Optional[APIBasedExtension]: def _get_api_based_extension(tenant_id: str, api_based_extension_id: str) -> Optional[APIBasedExtension]:
extension = ( stmt = select(APIBasedExtension).where(
db.session.query(APIBasedExtension) APIBasedExtension.tenant_id == tenant_id, APIBasedExtension.id == api_based_extension_id
.where(APIBasedExtension.tenant_id == tenant_id, APIBasedExtension.id == api_based_extension_id)
.first()
) )
extension = db.session.scalar(stmt)
return extension return extension

View File

@ -20,7 +20,6 @@ class ModerationFactory:
:param config: the form config data :param config: the form config data
:return: :return:
""" """
code_based_extension.validate_form_schema(ExtensionModule.MODERATION, name, config)
extension_class = code_based_extension.extension_class(ExtensionModule.MODERATION, name) extension_class = code_based_extension.extension_class(ExtensionModule.MODERATION, name)
# FIXME: mypy error, try to fix it instead of using type: ignore # FIXME: mypy error, try to fix it instead of using type: ignore
extension_class.validate_config(tenant_id, config) # type: ignore extension_class.validate_config(tenant_id, config) # type: ignore

View File

@ -135,7 +135,7 @@ class OutputModeration(BaseModel):
result: ModerationOutputsResult = moderation_factory.moderation_for_outputs(moderation_buffer) result: ModerationOutputsResult = moderation_factory.moderation_for_outputs(moderation_buffer)
return result return result
except Exception as e: except Exception:
logger.exception("Moderation Output error, app_id: %s", app_id) logger.exception("Moderation Output error, app_id: %s", app_id)
return None return None

View File

@ -5,6 +5,7 @@ from typing import Optional
from urllib.parse import urljoin from urllib.parse import urljoin
from opentelemetry.trace import Link, Status, StatusCode from opentelemetry.trace import Link, Status, StatusCode
from sqlalchemy import select
from sqlalchemy.orm import Session, sessionmaker from sqlalchemy.orm import Session, sessionmaker
from core.ops.aliyun_trace.data_exporter.traceclient import ( from core.ops.aliyun_trace.data_exporter.traceclient import (
@ -260,15 +261,15 @@ class AliyunDataTrace(BaseTraceInstance):
app_id = trace_info.metadata.get("app_id") app_id = trace_info.metadata.get("app_id")
if not app_id: if not app_id:
raise ValueError("No app_id found in trace_info metadata") raise ValueError("No app_id found in trace_info metadata")
app_stmt = select(App).where(App.id == app_id)
app = session.query(App).where(App.id == app_id).first() app = session.scalar(app_stmt)
if not app: if not app:
raise ValueError(f"App with id {app_id} not found") raise ValueError(f"App with id {app_id} not found")
if not app.created_by: if not app.created_by:
raise ValueError(f"App with id {app_id} has no creator (created_by is None)") raise ValueError(f"App with id {app_id} has no creator (created_by is None)")
account_stmt = select(Account).where(Account.id == app.created_by)
service_account = session.query(Account).where(Account.id == app.created_by).first() service_account = session.scalar(account_stmt)
if not service_account: if not service_account:
raise ValueError(f"Creator account with id {app.created_by} not found for app {app_id}") raise ValueError(f"Creator account with id {app.created_by} not found for app {app_id}")
current_tenant = ( current_tenant = (

View File

@ -72,7 +72,7 @@ class TraceClient:
else: else:
logger.debug("AliyunTrace API check failed: Unexpected status code: %s", response.status_code) logger.debug("AliyunTrace API check failed: Unexpected status code: %s", response.status_code)
return False return False
except requests.exceptions.RequestException as e: except requests.RequestException as e:
logger.debug("AliyunTrace API check failed: %s", str(e)) logger.debug("AliyunTrace API check failed: %s", str(e))
raise ValueError(f"AliyunTrace API check failed: {str(e)}") raise ValueError(f"AliyunTrace API check failed: {str(e)}")

View File

@ -1,5 +1,6 @@
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from sqlalchemy import select
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from core.ops.entities.config_entity import BaseTracingConfig from core.ops.entities.config_entity import BaseTracingConfig
@ -44,14 +45,15 @@ class BaseTraceInstance(ABC):
""" """
with Session(db.engine, expire_on_commit=False) as session: with Session(db.engine, expire_on_commit=False) as session:
# Get the app to find its creator # Get the app to find its creator
app = session.query(App).where(App.id == app_id).first() app_stmt = select(App).where(App.id == app_id)
app = session.scalar(app_stmt)
if not app: if not app:
raise ValueError(f"App with id {app_id} not found") raise ValueError(f"App with id {app_id} not found")
if not app.created_by: if not app.created_by:
raise ValueError(f"App with id {app_id} has no creator (created_by is None)") raise ValueError(f"App with id {app_id} has no creator (created_by is None)")
account_stmt = select(Account).where(Account.id == app.created_by)
service_account = session.query(Account).where(Account.id == app.created_by).first() service_account = session.scalar(account_stmt)
if not service_account: if not service_account:
raise ValueError(f"Creator account with id {app.created_by} not found for app {app_id}") raise ValueError(f"Creator account with id {app.created_by} not found for app {app_id}")

View File

@ -228,9 +228,9 @@ class OpsTraceManager:
if not trace_config_data: if not trace_config_data:
return None return None
# decrypt_token # decrypt_token
app = db.session.query(App).where(App.id == app_id).first() stmt = select(App).where(App.id == app_id)
app = db.session.scalar(stmt)
if not app: if not app:
raise ValueError("App not found") raise ValueError("App not found")
@ -297,20 +297,19 @@ class OpsTraceManager:
@classmethod @classmethod
def get_app_config_through_message_id(cls, message_id: str): def get_app_config_through_message_id(cls, message_id: str):
app_model_config = None app_model_config = None
message_data = db.session.query(Message).where(Message.id == message_id).first() message_stmt = select(Message).where(Message.id == message_id)
message_data = db.session.scalar(message_stmt)
if not message_data: if not message_data:
return None return None
conversation_id = message_data.conversation_id conversation_id = message_data.conversation_id
conversation_data = db.session.query(Conversation).where(Conversation.id == conversation_id).first() conversation_stmt = select(Conversation).where(Conversation.id == conversation_id)
conversation_data = db.session.scalar(conversation_stmt)
if not conversation_data: if not conversation_data:
return None return None
if conversation_data.app_model_config_id: if conversation_data.app_model_config_id:
app_model_config = ( config_stmt = select(AppModelConfig).where(AppModelConfig.id == conversation_data.app_model_config_id)
db.session.query(AppModelConfig) app_model_config = db.session.scalar(config_stmt)
.where(AppModelConfig.id == conversation_data.app_model_config_id)
.first()
)
elif conversation_data.app_model_config_id is None and conversation_data.override_model_configs: elif conversation_data.app_model_config_id is None and conversation_data.override_model_configs:
app_model_config = conversation_data.override_model_configs app_model_config = conversation_data.override_model_configs
@ -852,7 +851,7 @@ class TraceQueueManager:
if self.trace_instance: if self.trace_instance:
trace_task.app_id = self.app_id trace_task.app_id = self.app_id
trace_manager_queue.put(trace_task) trace_manager_queue.put(trace_task)
except Exception as e: except Exception:
logger.exception("Error adding trace task, trace_type %s", trace_task.trace_type) logger.exception("Error adding trace task, trace_type %s", trace_task.trace_type)
finally: finally:
self.start_timer() self.start_timer()
@ -871,7 +870,7 @@ class TraceQueueManager:
tasks = self.collect_tasks() tasks = self.collect_tasks()
if tasks: if tasks:
self.send_to_celery(tasks) self.send_to_celery(tasks)
except Exception as e: except Exception:
logger.exception("Error processing trace tasks") logger.exception("Error processing trace tasks")
def start_timer(self): def start_timer(self):

View File

@ -1,6 +1,8 @@
from collections.abc import Generator, Mapping from collections.abc import Generator, Mapping
from typing import Optional, Union from typing import Optional, Union
from sqlalchemy import select
from controllers.service_api.wraps import create_or_update_end_user_for_user_id from controllers.service_api.wraps import create_or_update_end_user_for_user_id
from core.app.app_config.common.parameters_mapping import get_parameters_from_feature_dict from core.app.app_config.common.parameters_mapping import get_parameters_from_feature_dict
from core.app.apps.advanced_chat.app_generator import AdvancedChatAppGenerator from core.app.apps.advanced_chat.app_generator import AdvancedChatAppGenerator
@ -191,10 +193,11 @@ class PluginAppBackwardsInvocation(BaseBackwardsInvocation):
""" """
get the user by user id get the user by user id
""" """
stmt = select(EndUser).where(EndUser.id == user_id)
user = db.session.query(EndUser).where(EndUser.id == user_id).first() user = db.session.scalar(stmt)
if not user: if not user:
user = db.session.query(Account).where(Account.id == user_id).first() stmt = select(Account).where(Account.id == user_id)
user = db.session.scalar(stmt)
if not user: if not user:
raise ValueError("user not found") raise ValueError("user not found")

View File

@ -64,7 +64,7 @@ class BasePluginClient:
response = requests.request( response = requests.request(
method=method, url=str(url), headers=headers, data=data, params=params, stream=stream, files=files method=method, url=str(url), headers=headers, data=data, params=params, stream=stream, files=files
) )
except requests.exceptions.ConnectionError: except requests.ConnectionError:
logger.exception("Request to Plugin Daemon Service failed") logger.exception("Request to Plugin Daemon Service failed")
raise PluginDaemonInnerError(code=-500, message="Request to Plugin Daemon Service failed") raise PluginDaemonInnerError(code=-500, message="Request to Plugin Daemon Service failed")

View File

@ -1,6 +1,6 @@
from collections.abc import Generator from collections.abc import Generator
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import TypeVar, Union, cast from typing import TypeVar, Union
from core.agent.entities import AgentInvokeMessage from core.agent.entities import AgentInvokeMessage
from core.tools.entities.tool_entities import ToolInvokeMessage from core.tools.entities.tool_entities import ToolInvokeMessage
@ -85,7 +85,7 @@ def merge_blob_chunks(
message=ToolInvokeMessage.BlobMessage(blob=files[chunk_id].data[: files[chunk_id].bytes_written]), message=ToolInvokeMessage.BlobMessage(blob=files[chunk_id].data[: files[chunk_id].bytes_written]),
meta=resp.meta, meta=resp.meta,
) )
yield cast(MessageType, merged_message) yield merged_message
# Clean up the buffer # Clean up the buffer
del files[chunk_id] del files[chunk_id]
else: else:

View File

@ -87,7 +87,6 @@ class PromptMessageUtil:
if isinstance(prompt_message.content, list): if isinstance(prompt_message.content, list):
for content in prompt_message.content: for content in prompt_message.content:
if content.type == PromptMessageContentType.TEXT: if content.type == PromptMessageContentType.TEXT:
content = cast(TextPromptMessageContent, content)
text += content.data text += content.data
else: else:
content = cast(ImagePromptMessageContent, content) content = cast(ImagePromptMessageContent, content)

View File

@ -1,6 +1,7 @@
import contextlib import contextlib
import json import json
from collections import defaultdict from collections import defaultdict
from collections.abc import Sequence
from json import JSONDecodeError from json import JSONDecodeError
from typing import Any, Optional, cast from typing import Any, Optional, cast
@ -22,6 +23,7 @@ from core.entities.provider_entities import (
QuotaConfiguration, QuotaConfiguration,
QuotaUnit, QuotaUnit,
SystemConfiguration, SystemConfiguration,
UnaddedModelConfiguration,
) )
from core.helper import encrypter from core.helper import encrypter
from core.helper.model_provider_cache import ProviderCredentialsCache, ProviderCredentialsCacheType from core.helper.model_provider_cache import ProviderCredentialsCache, ProviderCredentialsCacheType
@ -154,8 +156,8 @@ class ProviderManager:
for provider_entity in provider_entities: for provider_entity in provider_entities:
# handle include, exclude # handle include, exclude
if is_filtered( if is_filtered(
include_set=cast(set[str], dify_config.POSITION_PROVIDER_INCLUDES_SET), include_set=dify_config.POSITION_PROVIDER_INCLUDES_SET,
exclude_set=cast(set[str], dify_config.POSITION_PROVIDER_EXCLUDES_SET), exclude_set=dify_config.POSITION_PROVIDER_EXCLUDES_SET,
data=provider_entity, data=provider_entity,
name_func=lambda x: x.provider, name_func=lambda x: x.provider,
): ):
@ -276,15 +278,11 @@ class ProviderManager:
:param model_type: model type :param model_type: model type
:return: :return:
""" """
# Get the corresponding TenantDefaultModel record stmt = select(TenantDefaultModel).where(
default_model = ( TenantDefaultModel.tenant_id == tenant_id,
db.session.query(TenantDefaultModel) TenantDefaultModel.model_type == model_type.to_origin_model_type(),
.where(
TenantDefaultModel.tenant_id == tenant_id,
TenantDefaultModel.model_type == model_type.to_origin_model_type(),
)
.first()
) )
default_model = db.session.scalar(stmt)
# If it does not exist, get the first available provider model from get_configurations # If it does not exist, get the first available provider model from get_configurations
# and update the TenantDefaultModel record # and update the TenantDefaultModel record
@ -367,16 +365,11 @@ class ProviderManager:
model_names = [model.model for model in available_models] model_names = [model.model for model in available_models]
if model not in model_names: if model not in model_names:
raise ValueError(f"Model {model} does not exist.") raise ValueError(f"Model {model} does not exist.")
stmt = select(TenantDefaultModel).where(
# Get the list of available models from get_configurations and check if it is LLM TenantDefaultModel.tenant_id == tenant_id,
default_model = ( TenantDefaultModel.model_type == model_type.to_origin_model_type(),
db.session.query(TenantDefaultModel)
.where(
TenantDefaultModel.tenant_id == tenant_id,
TenantDefaultModel.model_type == model_type.to_origin_model_type(),
)
.first()
) )
default_model = db.session.scalar(stmt)
# create or update TenantDefaultModel record # create or update TenantDefaultModel record
if default_model: if default_model:
@ -546,6 +539,23 @@ class ProviderManager:
for credential in available_credentials for credential in available_credentials
] ]
@staticmethod
def get_credentials_from_provider_model(tenant_id: str, provider_name: str) -> Sequence[ProviderModelCredential]:
"""
Get all the credentials records from ProviderModelCredential by provider_name
:param tenant_id: workspace id
:param provider_name: provider name
"""
with Session(db.engine, expire_on_commit=False) as session:
stmt = select(ProviderModelCredential).where(
ProviderModelCredential.tenant_id == tenant_id, ProviderModelCredential.provider_name == provider_name
)
all_credentials = session.scalars(stmt).all()
return all_credentials
@staticmethod @staticmethod
def _init_trial_provider_records( def _init_trial_provider_records(
tenant_id: str, provider_name_to_provider_records_dict: dict[str, list[Provider]] tenant_id: str, provider_name_to_provider_records_dict: dict[str, list[Provider]]
@ -598,16 +608,13 @@ class ProviderManager:
provider_name_to_provider_records_dict[provider_name].append(new_provider_record) provider_name_to_provider_records_dict[provider_name].append(new_provider_record)
except IntegrityError: except IntegrityError:
db.session.rollback() db.session.rollback()
existed_provider_record = ( stmt = select(Provider).where(
db.session.query(Provider) Provider.tenant_id == tenant_id,
.where( Provider.provider_name == ModelProviderID(provider_name).provider_name,
Provider.tenant_id == tenant_id, Provider.provider_type == ProviderType.SYSTEM.value,
Provider.provider_name == ModelProviderID(provider_name).provider_name, Provider.quota_type == ProviderQuotaType.TRIAL.value,
Provider.provider_type == ProviderType.SYSTEM.value,
Provider.quota_type == ProviderQuotaType.TRIAL.value,
)
.first()
) )
existed_provider_record = db.session.scalar(stmt)
if not existed_provider_record: if not existed_provider_record:
continue continue
@ -635,6 +642,44 @@ class ProviderManager:
:param provider_model_records: provider model records :param provider_model_records: provider model records
:return: :return:
""" """
# Get custom provider configuration
custom_provider_configuration = self._get_custom_provider_configuration(
tenant_id, provider_entity, provider_records
)
# Get all model credentials once
all_model_credentials = self.get_credentials_from_provider_model(tenant_id, provider_entity.provider)
# Get custom models which have not been added to the model list yet
unadded_models = self._get_can_added_models(provider_model_records, all_model_credentials)
# Get custom model configurations
custom_model_configurations = self._get_custom_model_configurations(
tenant_id, provider_entity, provider_model_records, unadded_models, all_model_credentials
)
can_added_models = [
UnaddedModelConfiguration(model=model["model"], model_type=model["model_type"]) for model in unadded_models
]
return CustomConfiguration(
provider=custom_provider_configuration,
models=custom_model_configurations,
can_added_models=can_added_models,
)
def _get_custom_provider_configuration(
self, tenant_id: str, provider_entity: ProviderEntity, provider_records: list[Provider]
) -> CustomProviderConfiguration | None:
"""Get custom provider configuration."""
# Find custom provider record (non-system)
custom_provider_record = next(
(record for record in provider_records if record.provider_type != ProviderType.SYSTEM.value), None
)
if not custom_provider_record:
return None
# Get provider credential secret variables # Get provider credential secret variables
provider_credential_secret_variables = self._extract_secret_variables( provider_credential_secret_variables = self._extract_secret_variables(
provider_entity.provider_credential_schema.credential_form_schemas provider_entity.provider_credential_schema.credential_form_schemas
@ -642,113 +687,98 @@ class ProviderManager:
else [] else []
) )
# Get custom provider record # Get and decrypt provider credentials
custom_provider_record = None provider_credentials = self._get_and_decrypt_credentials(
for provider_record in provider_records: tenant_id=tenant_id,
if provider_record.provider_type == ProviderType.SYSTEM.value: record_id=custom_provider_record.id,
continue encrypted_config=custom_provider_record.encrypted_config,
secret_variables=provider_credential_secret_variables,
cache_type=ProviderCredentialsCacheType.PROVIDER,
is_provider=True,
)
custom_provider_record = provider_record return CustomProviderConfiguration(
credentials=provider_credentials,
current_credential_name=custom_provider_record.credential_name,
current_credential_id=custom_provider_record.credential_id,
available_credentials=self.get_provider_available_credentials(
tenant_id, custom_provider_record.provider_name
),
)
# Get custom provider credentials def _get_can_added_models(
custom_provider_configuration = None self, provider_model_records: list[ProviderModel], all_model_credentials: Sequence[ProviderModelCredential]
if custom_provider_record: ) -> list[dict]:
provider_credentials_cache = ProviderCredentialsCache( """Get the custom models and credentials from enterprise version which haven't add to the model list"""
tenant_id=tenant_id, existing_model_set = {(record.model_name, record.model_type) for record in provider_model_records}
identity_id=custom_provider_record.id,
cache_type=ProviderCredentialsCacheType.PROVIDER,
)
# Get cached provider credentials # Get not added custom models credentials
cached_provider_credentials = provider_credentials_cache.get() not_added_custom_models_credentials = [
credential
for credential in all_model_credentials
if (credential.model_name, credential.model_type) not in existing_model_set
]
if not cached_provider_credentials: # Group credentials by model
try: model_to_credentials = defaultdict(list)
# fix origin data for credential in not_added_custom_models_credentials:
if custom_provider_record.encrypted_config is None: model_to_credentials[(credential.model_name, credential.model_type)].append(credential)
provider_credentials = {}
elif not custom_provider_record.encrypted_config.startswith("{"):
provider_credentials = {"openai_api_key": custom_provider_record.encrypted_config}
else:
provider_credentials = json.loads(custom_provider_record.encrypted_config)
except JSONDecodeError:
provider_credentials = {}
# Get decoding rsa key and cipher for decrypting credentials return [
if self.decoding_rsa_key is None or self.decoding_cipher_rsa is None: {
self.decoding_rsa_key, self.decoding_cipher_rsa = encrypter.get_decrypt_decoding(tenant_id) "model": model_key[0],
"model_type": ModelType.value_of(model_key[1]),
"available_model_credentials": [
CredentialConfiguration(credential_id=cred.id, credential_name=cred.credential_name)
for cred in creds
],
}
for model_key, creds in model_to_credentials.items()
]
for variable in provider_credential_secret_variables: def _get_custom_model_configurations(
if variable in provider_credentials: self,
with contextlib.suppress(ValueError): tenant_id: str,
provider_credentials[variable] = encrypter.decrypt_token_with_decoding( provider_entity: ProviderEntity,
provider_credentials.get(variable) or "", # type: ignore provider_model_records: list[ProviderModel],
self.decoding_rsa_key, can_added_models: list[dict],
self.decoding_cipher_rsa, all_model_credentials: Sequence[ProviderModelCredential],
) ) -> list[CustomModelConfiguration]:
"""Get custom model configurations."""
# cache provider credentials # Get model credential secret variables
provider_credentials_cache.set(credentials=provider_credentials)
else:
provider_credentials = cached_provider_credentials
custom_provider_configuration = CustomProviderConfiguration(
credentials=provider_credentials,
current_credential_name=custom_provider_record.credential_name,
current_credential_id=custom_provider_record.credential_id,
available_credentials=self.get_provider_available_credentials(
tenant_id, custom_provider_record.provider_name
),
)
# Get provider model credential secret variables
model_credential_secret_variables = self._extract_secret_variables( model_credential_secret_variables = self._extract_secret_variables(
provider_entity.model_credential_schema.credential_form_schemas provider_entity.model_credential_schema.credential_form_schemas
if provider_entity.model_credential_schema if provider_entity.model_credential_schema
else [] else []
) )
# Get custom provider model credentials # Create credentials lookup for efficient access
credentials_map = defaultdict(list)
for credential in all_model_credentials:
credentials_map[(credential.model_name, credential.model_type)].append(credential)
custom_model_configurations = [] custom_model_configurations = []
# Process existing model records
for provider_model_record in provider_model_records: for provider_model_record in provider_model_records:
available_model_credentials = self.get_provider_model_available_credentials( # Use pre-fetched credentials instead of individual database calls
tenant_id, available_model_credentials = [
provider_model_record.provider_name, CredentialConfiguration(credential_id=cred.id, credential_name=cred.credential_name)
provider_model_record.model_name, for cred in credentials_map.get(
provider_model_record.model_type, (provider_model_record.model_name, provider_model_record.model_type), []
)
]
# Get and decrypt model credentials
provider_model_credentials = self._get_and_decrypt_credentials(
tenant_id=tenant_id,
record_id=provider_model_record.id,
encrypted_config=provider_model_record.encrypted_config,
secret_variables=model_credential_secret_variables,
cache_type=ProviderCredentialsCacheType.MODEL,
is_provider=False,
) )
provider_model_credentials_cache = ProviderCredentialsCache(
tenant_id=tenant_id, identity_id=provider_model_record.id, cache_type=ProviderCredentialsCacheType.MODEL
)
# Get cached provider model credentials
cached_provider_model_credentials = provider_model_credentials_cache.get()
if not cached_provider_model_credentials and provider_model_record.encrypted_config:
try:
provider_model_credentials = json.loads(provider_model_record.encrypted_config)
except JSONDecodeError:
continue
# Get decoding rsa key and cipher for decrypting credentials
if self.decoding_rsa_key is None or self.decoding_cipher_rsa is None:
self.decoding_rsa_key, self.decoding_cipher_rsa = encrypter.get_decrypt_decoding(tenant_id)
for variable in model_credential_secret_variables:
if variable in provider_model_credentials:
with contextlib.suppress(ValueError):
provider_model_credentials[variable] = encrypter.decrypt_token_with_decoding(
provider_model_credentials.get(variable),
self.decoding_rsa_key,
self.decoding_cipher_rsa,
)
# cache provider model credentials
provider_model_credentials_cache.set(credentials=provider_model_credentials)
else:
provider_model_credentials = cached_provider_model_credentials
custom_model_configurations.append( custom_model_configurations.append(
CustomModelConfiguration( CustomModelConfiguration(
model=provider_model_record.model_name, model=provider_model_record.model_name,
@ -760,7 +790,71 @@ class ProviderManager:
) )
) )
return CustomConfiguration(provider=custom_provider_configuration, models=custom_model_configurations) # Add models that can be added
for model in can_added_models:
custom_model_configurations.append(
CustomModelConfiguration(
model=model["model"],
model_type=model["model_type"],
credentials=None,
current_credential_id=None,
current_credential_name=None,
available_model_credentials=model["available_model_credentials"],
unadded_to_model_list=True,
)
)
return custom_model_configurations
def _get_and_decrypt_credentials(
self,
tenant_id: str,
record_id: str,
encrypted_config: str | None,
secret_variables: list[str],
cache_type: ProviderCredentialsCacheType,
is_provider: bool = False,
) -> dict:
"""Get and decrypt credentials with caching."""
credentials_cache = ProviderCredentialsCache(
tenant_id=tenant_id,
identity_id=record_id,
cache_type=cache_type,
)
# Try to get from cache first
cached_credentials = credentials_cache.get()
if cached_credentials:
return cached_credentials
# Parse encrypted config
if not encrypted_config:
return {}
if is_provider and not encrypted_config.startswith("{"):
return {"openai_api_key": encrypted_config}
try:
credentials = cast(dict, json.loads(encrypted_config))
except JSONDecodeError:
return {}
# Decrypt secret variables
if self.decoding_rsa_key is None or self.decoding_cipher_rsa is None:
self.decoding_rsa_key, self.decoding_cipher_rsa = encrypter.get_decrypt_decoding(tenant_id)
for variable in secret_variables:
if variable in credentials:
with contextlib.suppress(ValueError):
credentials[variable] = encrypter.decrypt_token_with_decoding(
credentials.get(variable) or "",
self.decoding_rsa_key,
self.decoding_cipher_rsa,
)
# Cache the decrypted credentials
credentials_cache.set(credentials=credentials)
return credentials
def _to_system_configuration( def _to_system_configuration(
self, tenant_id: str, provider_entity: ProviderEntity, provider_records: list[Provider] self, tenant_id: str, provider_entity: ProviderEntity, provider_records: list[Provider]
@ -968,18 +1062,6 @@ class ProviderManager:
load_balancing_model_config.model_name == provider_model_setting.model_name load_balancing_model_config.model_name == provider_model_setting.model_name
and load_balancing_model_config.model_type == provider_model_setting.model_type and load_balancing_model_config.model_type == provider_model_setting.model_type
): ):
if load_balancing_model_config.name == "__delete__":
# to calculate current model whether has invalidate lb configs
load_balancing_configs.append(
ModelLoadBalancingConfiguration(
id=load_balancing_model_config.id,
name=load_balancing_model_config.name,
credentials={},
credential_source_type=load_balancing_model_config.credential_source_type,
)
)
continue
if not load_balancing_model_config.enabled: if not load_balancing_model_config.enabled:
continue continue
@ -1045,6 +1127,7 @@ class ProviderManager:
model=provider_model_setting.model_name, model=provider_model_setting.model_name,
model_type=ModelType.value_of(provider_model_setting.model_type), model_type=ModelType.value_of(provider_model_setting.model_type),
enabled=provider_model_setting.enabled, enabled=provider_model_setting.enabled,
load_balancing_enabled=provider_model_setting.load_balancing_enabled,
load_balancing_configs=load_balancing_configs if len(load_balancing_configs) > 1 else [], load_balancing_configs=load_balancing_configs if len(load_balancing_configs) > 1 else [],
) )
) )

View File

@ -3,6 +3,7 @@ from typing import Any, Optional
import orjson import orjson
from pydantic import BaseModel from pydantic import BaseModel
from sqlalchemy import select
from configs import dify_config from configs import dify_config
from core.rag.datasource.keyword.jieba.jieba_keyword_table_handler import JiebaKeywordTableHandler from core.rag.datasource.keyword.jieba.jieba_keyword_table_handler import JiebaKeywordTableHandler
@ -212,11 +213,10 @@ class Jieba(BaseKeyword):
return sorted_chunk_indices[:k] return sorted_chunk_indices[:k]
def _update_segment_keywords(self, dataset_id: str, node_id: str, keywords: list[str]): def _update_segment_keywords(self, dataset_id: str, node_id: str, keywords: list[str]):
document_segment = ( stmt = select(DocumentSegment).where(
db.session.query(DocumentSegment) DocumentSegment.dataset_id == dataset_id, DocumentSegment.index_node_id == node_id
.where(DocumentSegment.dataset_id == dataset_id, DocumentSegment.index_node_id == node_id)
.first()
) )
document_segment = db.session.scalar(stmt)
if document_segment: if document_segment:
document_segment.keywords = keywords document_segment.keywords = keywords
db.session.add(document_segment) db.session.add(document_segment)

View File

@ -3,6 +3,7 @@ from concurrent.futures import ThreadPoolExecutor
from typing import Optional from typing import Optional
from flask import Flask, current_app from flask import Flask, current_app
from sqlalchemy import select
from sqlalchemy.orm import Session, load_only from sqlalchemy.orm import Session, load_only
from configs import dify_config from configs import dify_config
@ -24,7 +25,7 @@ default_retrieval_model = {
"search_method": RetrievalMethod.SEMANTIC_SEARCH.value, "search_method": RetrievalMethod.SEMANTIC_SEARCH.value,
"reranking_enable": False, "reranking_enable": False,
"reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""}, "reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""},
"top_k": 2, "top_k": 4,
"score_threshold_enabled": False, "score_threshold_enabled": False,
} }
@ -127,7 +128,8 @@ class RetrievalService:
external_retrieval_model: Optional[dict] = None, external_retrieval_model: Optional[dict] = None,
metadata_filtering_conditions: Optional[dict] = None, metadata_filtering_conditions: Optional[dict] = None,
): ):
dataset = db.session.query(Dataset).where(Dataset.id == dataset_id).first() stmt = select(Dataset).where(Dataset.id == dataset_id)
dataset = db.session.scalar(stmt)
if not dataset: if not dataset:
return [] return []
metadata_condition = ( metadata_condition = (
@ -316,10 +318,8 @@ class RetrievalService:
if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX: if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX:
# Handle parent-child documents # Handle parent-child documents
child_index_node_id = document.metadata.get("doc_id") child_index_node_id = document.metadata.get("doc_id")
child_chunk_stmt = select(ChildChunk).where(ChildChunk.index_node_id == child_index_node_id)
child_chunk = ( child_chunk = db.session.scalar(child_chunk_stmt)
db.session.query(ChildChunk).where(ChildChunk.index_node_id == child_index_node_id).first()
)
if not child_chunk: if not child_chunk:
continue continue
@ -378,17 +378,13 @@ class RetrievalService:
index_node_id = document.metadata.get("doc_id") index_node_id = document.metadata.get("doc_id")
if not index_node_id: if not index_node_id:
continue continue
document_segment_stmt = select(DocumentSegment).where(
segment = ( DocumentSegment.dataset_id == dataset_document.dataset_id,
db.session.query(DocumentSegment) DocumentSegment.enabled == True,
.where( DocumentSegment.status == "completed",
DocumentSegment.dataset_id == dataset_document.dataset_id, DocumentSegment.index_node_id == index_node_id,
DocumentSegment.enabled == True,
DocumentSegment.status == "completed",
DocumentSegment.index_node_id == index_node_id,
)
.first()
) )
segment = db.session.scalar(document_segment_stmt)
if not segment: if not segment:
continue continue

View File

@ -192,8 +192,8 @@ class AnalyticdbVectorOpenAPI:
collection=self._collection_name, collection=self._collection_name,
metrics=self.config.metrics, metrics=self.config.metrics,
include_values=True, include_values=True,
vector=None, vector=None, # ty: ignore [invalid-argument-type]
content=None, content=None, # ty: ignore [invalid-argument-type]
top_k=1, top_k=1,
filter=f"ref_doc_id='{id}'", filter=f"ref_doc_id='{id}'",
) )
@ -211,7 +211,7 @@ class AnalyticdbVectorOpenAPI:
namespace=self.config.namespace, namespace=self.config.namespace,
namespace_password=self.config.namespace_password, namespace_password=self.config.namespace_password,
collection=self._collection_name, collection=self._collection_name,
collection_data=None, collection_data=None, # ty: ignore [invalid-argument-type]
collection_data_filter=f"ref_doc_id IN {ids_str}", collection_data_filter=f"ref_doc_id IN {ids_str}",
) )
self._client.delete_collection_data(request) self._client.delete_collection_data(request)
@ -225,7 +225,7 @@ class AnalyticdbVectorOpenAPI:
namespace=self.config.namespace, namespace=self.config.namespace,
namespace_password=self.config.namespace_password, namespace_password=self.config.namespace_password,
collection=self._collection_name, collection=self._collection_name,
collection_data=None, collection_data=None, # ty: ignore [invalid-argument-type]
collection_data_filter=f"metadata_ ->> '{key}' = '{value}'", collection_data_filter=f"metadata_ ->> '{key}' = '{value}'",
) )
self._client.delete_collection_data(request) self._client.delete_collection_data(request)
@ -249,14 +249,14 @@ class AnalyticdbVectorOpenAPI:
include_values=kwargs.pop("include_values", True), include_values=kwargs.pop("include_values", True),
metrics=self.config.metrics, metrics=self.config.metrics,
vector=query_vector, vector=query_vector,
content=None, content=None, # ty: ignore [invalid-argument-type]
top_k=kwargs.get("top_k", 4), top_k=kwargs.get("top_k", 4),
filter=where_clause, filter=where_clause,
) )
response = self._client.query_collection_data(request) response = self._client.query_collection_data(request)
documents = [] documents = []
for match in response.body.matches.match: for match in response.body.matches.match:
if match.score > score_threshold: if match.score >= score_threshold:
metadata = json.loads(match.metadata.get("metadata_")) metadata = json.loads(match.metadata.get("metadata_"))
metadata["score"] = match.score metadata["score"] = match.score
doc = Document( doc = Document(
@ -285,7 +285,7 @@ class AnalyticdbVectorOpenAPI:
collection=self._collection_name, collection=self._collection_name,
include_values=kwargs.pop("include_values", True), include_values=kwargs.pop("include_values", True),
metrics=self.config.metrics, metrics=self.config.metrics,
vector=None, vector=None, # ty: ignore [invalid-argument-type]
content=query, content=query,
top_k=kwargs.get("top_k", 4), top_k=kwargs.get("top_k", 4),
filter=where_clause, filter=where_clause,
@ -293,7 +293,7 @@ class AnalyticdbVectorOpenAPI:
response = self._client.query_collection_data(request) response = self._client.query_collection_data(request)
documents = [] documents = []
for match in response.body.matches.match: for match in response.body.matches.match:
if match.score > score_threshold: if match.score >= score_threshold:
metadata = json.loads(match.metadata.get("metadata_")) metadata = json.loads(match.metadata.get("metadata_"))
metadata["score"] = match.score metadata["score"] = match.score
doc = Document( doc = Document(

View File

@ -228,8 +228,8 @@ class AnalyticdbVectorBySql:
) )
documents = [] documents = []
for record in cur: for record in cur:
id, vector, score, page_content, metadata = record _, vector, score, page_content, metadata = record
if score > score_threshold: if score >= score_threshold:
metadata["score"] = score metadata["score"] = score
doc = Document( doc = Document(
page_content=page_content, page_content=page_content,
@ -260,7 +260,7 @@ class AnalyticdbVectorBySql:
) )
documents = [] documents = []
for record in cur: for record in cur:
id, vector, page_content, metadata, score = record _, vector, page_content, metadata, score = record
metadata["score"] = score metadata["score"] = score
doc = Document( doc = Document(
page_content=page_content, page_content=page_content,

View File

@ -157,7 +157,7 @@ class BaiduVector(BaseVector):
if meta is not None: if meta is not None:
meta = json.loads(meta) meta = json.loads(meta)
score = row.get("score", 0.0) score = row.get("score", 0.0)
if score > score_threshold: if score >= score_threshold:
meta["score"] = score meta["score"] = score
doc = Document(page_content=row_data.get(self.field_text), metadata=meta) doc = Document(page_content=row_data.get(self.field_text), metadata=meta)
docs.append(doc) docs.append(doc)

View File

@ -120,7 +120,7 @@ class ChromaVector(BaseVector):
distance = distances[index] distance = distances[index]
metadata = dict(metadatas[index]) metadata = dict(metadatas[index])
score = 1 - distance score = 1 - distance
if score > score_threshold: if score >= score_threshold:
metadata["score"] = score metadata["score"] = score
doc = Document( doc = Document(
page_content=documents[index], page_content=documents[index],

View File

@ -12,7 +12,7 @@ import clickzetta # type: ignore
from pydantic import BaseModel, model_validator from pydantic import BaseModel, model_validator
if TYPE_CHECKING: if TYPE_CHECKING:
from clickzetta import Connection from clickzetta.connector.v0.connection import Connection # type: ignore
from configs import dify_config from configs import dify_config
from core.rag.datasource.vdb.field import Field from core.rag.datasource.vdb.field import Field
@ -701,7 +701,7 @@ class ClickzettaVector(BaseVector):
len(data_rows), len(data_rows),
vector_dimension, vector_dimension,
) )
except (RuntimeError, ValueError, TypeError, ConnectionError) as e: except (RuntimeError, ValueError, TypeError, ConnectionError):
logger.exception("Parameterized SQL execution failed for %d documents", len(data_rows)) logger.exception("Parameterized SQL execution failed for %d documents", len(data_rows))
logger.exception("SQL template: %s", insert_sql) logger.exception("SQL template: %s", insert_sql)
logger.exception("Sample data row: %s", data_rows[0] if data_rows else "None") logger.exception("Sample data row: %s", data_rows[0] if data_rows else "None")
@ -787,7 +787,7 @@ class ClickzettaVector(BaseVector):
document_ids_filter = kwargs.get("document_ids_filter") document_ids_filter = kwargs.get("document_ids_filter")
# Handle filter parameter from canvas (workflow) # Handle filter parameter from canvas (workflow)
filter_param = kwargs.get("filter", {}) _ = kwargs.get("filter", {})
# Build filter clause # Build filter clause
filter_clauses = [] filter_clauses = []
@ -879,7 +879,7 @@ class ClickzettaVector(BaseVector):
document_ids_filter = kwargs.get("document_ids_filter") document_ids_filter = kwargs.get("document_ids_filter")
# Handle filter parameter from canvas (workflow) # Handle filter parameter from canvas (workflow)
filter_param = kwargs.get("filter", {}) _ = kwargs.get("filter", {})
# Build filter clause # Build filter clause
filter_clauses = [] filter_clauses = []
@ -938,7 +938,7 @@ class ClickzettaVector(BaseVector):
metadata = {} metadata = {}
else: else:
metadata = {} metadata = {}
except (json.JSONDecodeError, TypeError) as e: except (json.JSONDecodeError, TypeError):
logger.exception("JSON parsing failed") logger.exception("JSON parsing failed")
# Fallback: extract document_id with regex # Fallback: extract document_id with regex
@ -956,7 +956,7 @@ class ClickzettaVector(BaseVector):
metadata["score"] = 1.0 # Clickzetta doesn't provide relevance scores metadata["score"] = 1.0 # Clickzetta doesn't provide relevance scores
doc = Document(page_content=row[1], metadata=metadata) doc = Document(page_content=row[1], metadata=metadata)
documents.append(doc) documents.append(doc)
except (RuntimeError, ValueError, TypeError, ConnectionError) as e: except (RuntimeError, ValueError, TypeError, ConnectionError):
logger.exception("Full-text search failed") logger.exception("Full-text search failed")
# Fallback to LIKE search if full-text search fails # Fallback to LIKE search if full-text search fails
return self._search_by_like(query, **kwargs) return self._search_by_like(query, **kwargs)
@ -978,7 +978,7 @@ class ClickzettaVector(BaseVector):
document_ids_filter = kwargs.get("document_ids_filter") document_ids_filter = kwargs.get("document_ids_filter")
# Handle filter parameter from canvas (workflow) # Handle filter parameter from canvas (workflow)
filter_param = kwargs.get("filter", {}) _ = kwargs.get("filter", {})
# Build filter clause # Build filter clause
filter_clauses = [] filter_clauses = []

View File

@ -212,10 +212,10 @@ class CouchbaseVector(BaseVector):
documents_to_insert = [ documents_to_insert = [
{"text": text, "embedding": vector, "metadata": metadata} {"text": text, "embedding": vector, "metadata": metadata}
for id, text, vector, metadata in zip(uuids, texts, embeddings, metadatas) for _, text, vector, metadata in zip(uuids, texts, embeddings, metadatas)
] ]
for doc, id in zip(documents_to_insert, uuids): for doc, id in zip(documents_to_insert, uuids):
result = self._scope.collection(self._collection_name).upsert(id, doc) _ = self._scope.collection(self._collection_name).upsert(id, doc)
doc_ids.extend(uuids) doc_ids.extend(uuids)
@ -241,7 +241,7 @@ class CouchbaseVector(BaseVector):
""" """
try: try:
self._cluster.query(query, named_parameters={"doc_ids": ids}).execute() self._cluster.query(query, named_parameters={"doc_ids": ids}).execute()
except Exception as e: except Exception:
logger.exception("Failed to delete documents, ids: %s", ids) logger.exception("Failed to delete documents, ids: %s", ids)
def delete_by_document_id(self, document_id: str): def delete_by_document_id(self, document_id: str):
@ -304,9 +304,9 @@ class CouchbaseVector(BaseVector):
return docs return docs
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]: def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
top_k = kwargs.get("top_k", 2) top_k = kwargs.get("top_k", 4)
try: try:
CBrequest = search.SearchRequest.create(search.QueryStringQuery("text:" + query)) CBrequest = search.SearchRequest.create(search.QueryStringQuery("text:" + query)) # ty: ignore [too-many-positional-arguments]
search_iter = self._scope.search( search_iter = self._scope.search(
self._collection_name + "_search", CBrequest, SearchOptions(limit=top_k, fields=["*"]) self._collection_name + "_search", CBrequest, SearchOptions(limit=top_k, fields=["*"])
) )

View File

@ -138,7 +138,7 @@ class ElasticSearchVector(BaseVector):
if not client.ping(): if not client.ping():
raise ConnectionError("Failed to connect to Elasticsearch") raise ConnectionError("Failed to connect to Elasticsearch")
except requests.exceptions.ConnectionError as e: except requests.ConnectionError as e:
raise ConnectionError(f"Vector database connection error: {str(e)}") raise ConnectionError(f"Vector database connection error: {str(e)}")
except Exception as e: except Exception as e:
raise ConnectionError(f"Elasticsearch client initialization failed: {str(e)}") raise ConnectionError(f"Elasticsearch client initialization failed: {str(e)}")
@ -216,7 +216,7 @@ class ElasticSearchVector(BaseVector):
docs = [] docs = []
for doc, score in docs_and_scores: for doc, score in docs_and_scores:
score_threshold = float(kwargs.get("score_threshold") or 0.0) score_threshold = float(kwargs.get("score_threshold") or 0.0)
if score > score_threshold: if score >= score_threshold:
if doc.metadata is not None: if doc.metadata is not None:
doc.metadata["score"] = score doc.metadata["score"] = score
docs.append(doc) docs.append(doc)

View File

@ -127,7 +127,7 @@ class HuaweiCloudVector(BaseVector):
docs = [] docs = []
for doc, score in docs_and_scores: for doc, score in docs_and_scores:
score_threshold = float(kwargs.get("score_threshold") or 0.0) score_threshold = float(kwargs.get("score_threshold") or 0.0)
if score > score_threshold: if score >= score_threshold:
if doc.metadata is not None: if doc.metadata is not None:
doc.metadata["score"] = score doc.metadata["score"] = score
docs.append(doc) docs.append(doc)

View File

@ -275,7 +275,7 @@ class LindormVectorStore(BaseVector):
docs = [] docs = []
for doc, score in docs_and_scores: for doc, score in docs_and_scores:
score_threshold = kwargs.get("score_threshold", 0.0) or 0.0 score_threshold = kwargs.get("score_threshold", 0.0) or 0.0
if score > score_threshold: if score >= score_threshold:
if doc.metadata is not None: if doc.metadata is not None:
doc.metadata["score"] = score doc.metadata["score"] = score
docs.append(doc) docs.append(doc)

View File

@ -99,7 +99,7 @@ class MatrixoneVector(BaseVector):
return client return client
try: try:
client.create_full_text_index() client.create_full_text_index()
except Exception as e: except Exception:
logger.exception("Failed to create full text index") logger.exception("Failed to create full text index")
redis_client.set(collection_exist_cache_key, 1, ex=3600) redis_client.set(collection_exist_cache_key, 1, ex=3600)
return client return client

View File

@ -376,7 +376,12 @@ class MilvusVector(BaseVector):
if config.token: if config.token:
client = MilvusClient(uri=config.uri, token=config.token, db_name=config.database) client = MilvusClient(uri=config.uri, token=config.token, db_name=config.database)
else: else:
client = MilvusClient(uri=config.uri, user=config.user, password=config.password, db_name=config.database) client = MilvusClient(
uri=config.uri,
user=config.user or "",
password=config.password or "",
db_name=config.database,
)
return client return client

View File

@ -194,7 +194,7 @@ class OpenGauss(BaseVector):
metadata, text, distance = record metadata, text, distance = record
score = 1 - distance score = 1 - distance
metadata["score"] = score metadata["score"] = score
if score > score_threshold: if score >= score_threshold:
docs.append(Document(page_content=text, metadata=metadata)) docs.append(Document(page_content=text, metadata=metadata))
return docs return docs

View File

@ -197,7 +197,7 @@ class OpenSearchVector(BaseVector):
try: try:
response = self._client.search(index=self._collection_name.lower(), body=query) response = self._client.search(index=self._collection_name.lower(), body=query)
except Exception as e: except Exception:
logger.exception("Error executing vector search, query: %s", query) logger.exception("Error executing vector search, query: %s", query)
raise raise
@ -211,7 +211,7 @@ class OpenSearchVector(BaseVector):
metadata["score"] = hit["_score"] metadata["score"] = hit["_score"]
score_threshold = float(kwargs.get("score_threshold") or 0.0) score_threshold = float(kwargs.get("score_threshold") or 0.0)
if hit["_score"] > score_threshold: if hit["_score"] >= score_threshold:
doc = Document(page_content=hit["_source"].get(Field.CONTENT_KEY.value), metadata=metadata) doc = Document(page_content=hit["_source"].get(Field.CONTENT_KEY.value), metadata=metadata)
docs.append(doc) docs.append(doc)

View File

@ -261,7 +261,7 @@ class OracleVector(BaseVector):
metadata, text, distance = record metadata, text, distance = record
score = 1 - distance score = 1 - distance
metadata["score"] = score metadata["score"] = score
if score > score_threshold: if score >= score_threshold:
docs.append(Document(page_content=text, metadata=metadata)) docs.append(Document(page_content=text, metadata=metadata))
conn.close() conn.close()
return docs return docs

Some files were not shown because too many files have changed in this diff Show More