Merge branch 'feat/queue-based-graph-engine' into feat/rag-2

# Conflicts:
#	api/core/memory/token_buffer_memory.py
#	api/core/rag/extractor/notion_extractor.py
#	api/core/repositories/sqlalchemy_workflow_node_execution_repository.py
#	api/core/variables/variables.py
#	api/core/workflow/graph/graph.py
#	api/core/workflow/graph_engine/entities/event.py
#	api/services/dataset_service.py
#	web/app/components/app-sidebar/index.tsx
#	web/app/components/base/tag-management/selector.tsx
#	web/app/components/base/toast/index.tsx
#	web/app/components/datasets/create/website/index.tsx
#	web/app/components/datasets/create/website/jina-reader/base/options-wrap.tsx
#	web/app/components/workflow/header/version-history-button.tsx
#	web/app/components/workflow/hooks/use-inspect-vars-crud-common.ts
#	web/app/components/workflow/hooks/use-workflow-interactions.ts
#	web/app/components/workflow/panel/version-history-panel/index.tsx
#	web/service/base.ts
jyong 2025-09-03 15:01:06 +08:00
commit d4aed3df5c
572 changed files with 16030 additions and 7973 deletions

View File

@ -42,11 +42,7 @@ jobs:
- name: Run Unit tests
run: |
uv run --project api bash dev/pytest/pytest_unit_tests.sh
- name: Run ty check
run: |
cd api
uv add --dev ty
uv run ty check || true
- name: Run pyrefly check
run: |
cd api
@ -66,15 +62,6 @@ jobs:
- name: Run dify config tests
run: uv run --project api dev/pytest/pytest_config_tests.py
- name: MyPy Cache
uses: actions/cache@v4
with:
path: api/.mypy_cache
key: mypy-${{ matrix.python-version }}-${{ runner.os }}-${{ hashFiles('api/uv.lock') }}
- name: Run MyPy Checks
run: dev/mypy-check
- name: Set up dotenvs
run: |
cp docker/.env.example docker/.env

View File

@ -26,6 +26,7 @@ jobs:
- name: ast-grep
run: |
uvx --from ast-grep-cli sg --pattern 'db.session.query($WHATEVER).filter($HERE)' --rewrite 'db.session.query($WHATEVER).where($HERE)' -l py --update-all
uvx --from ast-grep-cli sg --pattern 'session.query($WHATEVER).filter($HERE)' --rewrite 'session.query($WHATEVER).where($HERE)' -l py --update-all
- name: mdformat
run: |
uvx mdformat .
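
The ast-grep step above mechanically migrates legacy SQLAlchemy `Query.filter()` calls to the `Query.where()` spelling. A minimal sketch of what the rewrite changes, using a hypothetical model (`Query.where()` has been an alias of `Query.filter()` since SQLAlchemy 1.4, so behavior is unchanged):

```python
from sqlalchemy import Column, Integer, String, create_engine
from sqlalchemy.orm import Session, declarative_base

Base = declarative_base()

class Message(Base):  # hypothetical minimal model for illustration
    __tablename__ = "messages"
    id = Column(Integer, primary_key=True)
    app_id = Column(String)

engine = create_engine("sqlite://")
Base.metadata.create_all(engine)

with Session(engine) as session:
    # Legacy spelling, matched by the ast-grep pattern:
    legacy = session.query(Message).filter(Message.id == 1).first()
    # Rewritten spelling; same semantics, 2.0-style naming:
    modern = session.query(Message).where(Message.id == 1).first()
    assert legacy == modern  # both None on an empty table
```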

View File

@ -12,7 +12,6 @@ permissions:
statuses: write
contents: read
jobs:
python-style:
name: Python Style
@ -44,21 +43,14 @@ jobs:
if: steps.changed-files.outputs.any_changed == 'true'
run: uv sync --project api --dev
- name: Ruff check
- name: Run Basedpyright Checks
if: steps.changed-files.outputs.any_changed == 'true'
run: |
uv run --directory api ruff --version
uv run --directory api ruff check ./
uv run --directory api ruff format --check ./
run: dev/basedpyright-check
- name: Dotenv check
if: steps.changed-files.outputs.any_changed == 'true'
run: uv run --project api dotenv-linter ./api/.env.example ./web/.env.example
- name: Lint hints
if: failure()
run: echo "Please run 'dev/reformat' to fix the fixable linting errors."
web-style:
name: Web Style
runs-on: ubuntu-latest
@ -100,7 +92,9 @@ jobs:
- name: Web style check
if: steps.changed-files.outputs.any_changed == 'true'
working-directory: ./web
run: pnpm run lint
run: |
pnpm run lint
pnpm run eslint
docker-compose-template:
name: Docker Compose Template

.gitignore
View File

@ -123,10 +123,12 @@ venv.bak/
# mkdocs documentation
/site
# mypy
# type checking
.mypy_cache/
.dmypy.json
dmypy.json
pyrightconfig.json
!api/pyrightconfig.json
# Pyre type checker
.pyre/
@ -195,7 +197,6 @@ sdks/python-client/dify_client.egg-info
.vscode/*
!.vscode/launch.json.template
!.vscode/README.md
pyrightconfig.json
api/.vscode
# vscode Code History Extension
.history

View File

@ -32,7 +32,7 @@ uv run --project api pytest tests/integration_tests/ # Integration tests
./dev/reformat # Run all formatters and linters
uv run --project api ruff check --fix ./ # Fix linting issues
uv run --project api ruff format ./ # Format code
uv run --project api mypy . # Type checking
uv run --directory api basedpyright # Type checking
```
### Frontend (Web)

View File

@ -4,6 +4,48 @@ WEB_IMAGE=$(DOCKER_REGISTRY)/dify-web
API_IMAGE=$(DOCKER_REGISTRY)/dify-api
VERSION=latest
# Backend Development Environment Setup
.PHONY: dev-setup prepare-docker prepare-web prepare-api
# Default dev setup target
dev-setup: prepare-docker prepare-web prepare-api
@echo "✅ Backend development environment setup complete!"
# Step 1: Prepare Docker middleware
prepare-docker:
@echo "🐳 Setting up Docker middleware..."
@cp -n docker/middleware.env.example docker/middleware.env 2>/dev/null || echo "Docker middleware.env already exists"
@cd docker && docker compose -f docker-compose.middleware.yaml --env-file middleware.env -p dify-middlewares-dev up -d
@echo "✅ Docker middleware started"
# Step 2: Prepare web environment
prepare-web:
@echo "🌐 Setting up web environment..."
@cp -n web/.env.example web/.env 2>/dev/null || echo "Web .env already exists"
@cd web && pnpm install
@cd web && pnpm build
@echo "✅ Web environment prepared (not started)"
# Step 3: Prepare API environment
prepare-api:
@echo "🔧 Setting up API environment..."
@cp -n api/.env.example api/.env 2>/dev/null || echo "API .env already exists"
@cd api && uv sync --dev --extra all
@cd api && uv run flask db upgrade
@echo "✅ API environment prepared (not started)"
# Clean dev environment
dev-clean:
@echo "⚠️ Stopping Docker containers..."
@cd docker && docker compose -f docker-compose.middleware.yaml --env-file middleware.env -p dify-middlewares-dev down
@echo "🗑️ Removing volumes..."
@rm -rf docker/volumes/db
@rm -rf docker/volumes/redis
@rm -rf docker/volumes/plugin_daemon
@rm -rf docker/volumes/weaviate
@rm -rf api/storage
@echo "✅ Cleanup complete"
# Build Docker images
build-web:
@echo "Building web Docker image: $(WEB_IMAGE):$(VERSION)..."
@ -39,5 +81,21 @@ build-push-web: build-web push-web
build-push-all: build-all push-all
@echo "All Docker images have been built and pushed."
# Help target
help:
@echo "Development Setup Targets:"
@echo " make dev-setup - Run all setup steps for backend dev environment"
@echo " make prepare-docker - Set up Docker middleware"
@echo " make prepare-web - Set up web environment"
@echo " make prepare-api - Set up API environment"
@echo " make dev-clean - Stop Docker middleware containers"
@echo ""
@echo "Docker Build Targets:"
@echo " make build-web - Build web Docker image"
@echo " make build-api - Build API Docker image"
@echo " make build-all - Build all Docker images"
@echo " make push-all - Push all Docker images"
@echo " make build-push-all - Build and push all Docker images"
# Phony targets
.PHONY: build-web build-api push-web push-api build-all push-all build-push-all
.PHONY: build-web build-api push-web push-api build-all push-all build-push-all dev-setup prepare-docker prepare-web prepare-api dev-clean help

View File

@ -34,17 +34,16 @@ ignore_imports =
[importlinter:contract:rsc]
name = RSC
type = layers
layers =
graph_engine
response_coordinator
output_registry
containers =
core.workflow.graph_engine
[importlinter:contract:worker]
name = Worker
type = layers
layers =
graph_engine
worker
containers =
@ -77,26 +76,15 @@ forbidden_modules =
core.workflow.graph_engine.layers
core.workflow.graph_engine.protocols
[importlinter:contract:state-management-layers]
name = State Management Layers
type = layers
layers =
execution_tracker
node_state_manager
edge_state_manager
containers =
core.workflow.graph_engine.state_management
[importlinter:contract:worker-management-layers]
name = Worker Management Layers
type = layers
layers =
worker_pool
worker_factory
dynamic_scaler
activity_tracker
containers =
[importlinter:contract:worker-management]
name = Worker Management
type = forbidden
source_modules =
core.workflow.graph_engine.worker_management
forbidden_modules =
core.workflow.graph_engine.orchestration
core.workflow.graph_engine.command_processing
core.workflow.graph_engine.event_management
[importlinter:contract:error-handling-strategies]
name = Error Handling Strategies
@ -109,14 +97,16 @@ modules =
[importlinter:contract:graph-traversal-components]
name = Graph Traversal Components
type = independence
modules =
core.workflow.graph_engine.graph_traversal.node_readiness
core.workflow.graph_engine.graph_traversal.skip_propagator
type = layers
layers =
edge_processor
skip_propagator
containers =
core.workflow.graph_engine.graph_traversal
[importlinter:contract:command-channels]
name = Command Channels Independence
type = independence
modules =
core.workflow.graph_engine.command_channels.in_memory_channel
core.workflow.graph_engine.command_channels.redis_channel

View File

@ -108,5 +108,5 @@ uv run celery -A app.celery beat
../dev/reformat # Run all formatters and linters
uv run ruff check --fix ./ # Fix linting issues
uv run ruff format ./ # Format code
uv run mypy . # Type checking
uv run basedpyright . # Type checking
```

View File

@ -1,11 +0,0 @@
from tests.integration_tests.utils.parent_class import ParentClass
class ChildClass(ParentClass):
"""Test child class for module import helper tests"""
def __init__(self, name):
super().__init__(name)
def get_name(self):
return f"Child: {self.name}"

View File

@ -577,7 +577,7 @@ def old_metadata_migration():
for document in documents:
if document.doc_metadata:
doc_metadata = document.doc_metadata
for key, value in doc_metadata.items():
for key in doc_metadata:
for field in BuiltInField:
if field.value == key:
break

View File

@ -29,7 +29,7 @@ class NacosSettingsSource(RemoteSettingsSource):
try:
content = NacosHttpClient().http_request("/nacos/v1/cs/configs", method="GET", headers={}, params=params)
self.remote_configs = self._parse_config(content)
except Exception as e:
except Exception:
logger.exception("[get-access-token] exception occurred")
raise
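
Several hunks in this commit drop the unused `as e` binding: `logger.exception()` records the active exception and its traceback on its own, so a bare `except Exception:` is enough when the handler only logs and re-raises. A minimal sketch of the pattern (`fetch_config` is a stand-in for the Nacos call):

```python
import logging

logger = logging.getLogger(__name__)

def fetch_config() -> str:
    raise TimeoutError("Nacos unreachable")  # stand-in failure

def load_remote_settings() -> str:
    try:
        return fetch_config()
    except Exception:
        # logger.exception() appends the current traceback automatically,
        # so the bound variable `e` was dead weight.
        logger.exception("[get-access-token] exception occurred")
        raise

# load_remote_settings()  # would log the full traceback, then re-raise
```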

View File

@ -27,7 +27,7 @@ class NacosHttpClient:
response = requests.request(method, url="http://" + self.server + url, headers=headers, params=params)
response.raise_for_status()
return response.text
except requests.exceptions.RequestException as e:
except requests.RequestException as e:
return f"Request to Nacos failed: {e}"
def _inject_auth_info(self, headers, params, module="config"):
@ -77,6 +77,6 @@ class NacosHttpClient:
self.token = response_data.get("accessToken")
self.token_ttl = response_data.get("tokenTtl", 18000)
self.token_expire_time = current_time + self.token_ttl - 10
except Exception as e:
except Exception:
logger.exception("[get-access-token] exception occur")
raise

View File

@ -19,6 +19,7 @@ language_timezone_mapping = {
"fa-IR": "Asia/Tehran",
"sl-SI": "Europe/Ljubljana",
"th-TH": "Asia/Bangkok",
"id-ID": "Asia/Jakarta",
}
languages = list(language_timezone_mapping.keys())

View File

@ -130,15 +130,19 @@ class InsertExploreAppApi(Resource):
app.is_public = False
with Session(db.engine) as session:
installed_apps = session.execute(
select(InstalledApp).where(
InstalledApp.app_id == recommended_app.app_id,
InstalledApp.tenant_id != InstalledApp.app_owner_tenant_id,
installed_apps = (
session.execute(
select(InstalledApp).where(
InstalledApp.app_id == recommended_app.app_id,
InstalledApp.tenant_id != InstalledApp.app_owner_tenant_id,
)
)
).all()
.scalars()
.all()
)
for installed_app in installed_apps:
db.session.delete(installed_app)
for installed_app in installed_apps:
session.delete(installed_app)
db.session.delete(recommended_app)
db.session.commit()
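
The `InstalledApp` hunk above swaps `session.execute(select(...)).all()` for `.scalars().all()`: `execute().all()` yields `Row` wrappers, while `.scalars()` unwraps the single entity column so iteration produces ORM objects that `session.delete()` accepts. A sketch with a hypothetical minimal model:

```python
from sqlalchemy import Column, Integer, create_engine, select
from sqlalchemy.orm import Session, declarative_base

Base = declarative_base()

class InstalledApp(Base):  # hypothetical minimal model for illustration
    __tablename__ = "installed_apps"
    id = Column(Integer, primary_key=True)

engine = create_engine("sqlite://")
Base.metadata.create_all(engine)

with Session(engine) as session:
    session.add(InstalledApp(id=1))
    session.flush()

    rows = session.execute(select(InstalledApp)).all()
    assert not isinstance(rows[0], InstalledApp)  # Row wrappers

    apps = session.execute(select(InstalledApp)).scalars().all()
    assert isinstance(apps[0], InstalledApp)      # unwrapped entities
    for app in apps:
        session.delete(app)  # needs entities, not Rows
```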

View File

@ -84,7 +84,7 @@ class BaseApiKeyListResource(Resource):
flask_restx.abort(
400,
message=f"Cannot create more than {self.max_keys} API keys for this resource type.",
code="max_keys_exceeded",
custom="max_keys_exceeded",
)
key = ApiToken.generate_api_key(self.token_prefix, 24)

View File

@ -237,9 +237,14 @@ class AppExportApi(Resource):
# Add include_secret params
parser = reqparse.RequestParser()
parser.add_argument("include_secret", type=inputs.boolean, default=False, location="args")
parser.add_argument("workflow_id", type=str, location="args")
args = parser.parse_args()
return {"data": AppDslService.export_dsl(app_model=app_model, include_secret=args["include_secret"])}
return {
"data": AppDslService.export_dsl(
app_model=app_model, include_secret=args["include_secret"], workflow_id=args.get("workflow_id")
)
}
class AppNameApi(Resource):

View File

@ -130,7 +130,7 @@ class MessageFeedbackApi(Resource):
message_id = str(args["message_id"])
message = db.session.query(Message).filter(Message.id == message_id, Message.app_id == app_model.id).first()
message = db.session.query(Message).where(Message.id == message_id, Message.app_id == app_model.id).first()
if not message:
raise NotFound("Message Not Exists.")

View File

@ -532,7 +532,7 @@ class PublishedWorkflowApi(Resource):
)
app_model.workflow_id = workflow.id
db.session.commit()
db.session.commit() # NOTE: this is necessary to update app_model.workflow_id
workflow_created_at = TimestampField().format(workflow.created_at)

View File

@ -27,7 +27,9 @@ class WorkflowAppLogApi(Resource):
"""
parser = reqparse.RequestParser()
parser.add_argument("keyword", type=str, location="args")
parser.add_argument("status", type=str, choices=["succeeded", "failed", "stopped"], location="args")
parser.add_argument(
"status", type=str, choices=["succeeded", "failed", "stopped", "partial-succeeded"], location="args"
)
parser.add_argument(
"created_at__before", type=str, location="args", help="Filter logs created before this timestamp"
)

View File

@ -81,7 +81,7 @@ class OAuthDataSourceBinding(Resource):
return {"error": "Invalid code"}, 400
try:
oauth_provider.get_access_token(code)
except requests.exceptions.HTTPError as e:
except requests.HTTPError as e:
logger.exception(
"An error occurred during the OAuthCallback process with %s: %s", provider, e.response.text
)
@ -104,7 +104,7 @@ class OAuthDataSourceSync(Resource):
return {"error": "Invalid provider"}, 400
try:
oauth_provider.sync_data_source(binding_id)
except requests.exceptions.HTTPError as e:
except requests.HTTPError as e:
logger.exception(
"An error occurred during the OAuthCallback process with %s: %s", provider, e.response.text
)

View File

@ -130,7 +130,7 @@ class ResetPasswordSendEmailApi(Resource):
language = "en-US"
try:
account = AccountService.get_user_through_email(args["email"])
except AccountRegisterError as are:
except AccountRegisterError:
raise AccountInFreezeError()
if account is None:
@ -162,7 +162,7 @@ class EmailCodeLoginSendEmailApi(Resource):
language = "en-US"
try:
account = AccountService.get_user_through_email(args["email"])
except AccountRegisterError as are:
except AccountRegisterError:
raise AccountInFreezeError()
if account is None:
@ -200,7 +200,7 @@ class EmailCodeLoginApi(Resource):
AccountService.revoke_email_code_login_token(args["token"])
try:
account = AccountService.get_user_through_email(user_email)
except AccountRegisterError as are:
except AccountRegisterError:
raise AccountInFreezeError()
if account:
tenants = TenantService.get_join_tenants(account)
@ -223,7 +223,7 @@ class EmailCodeLoginApi(Resource):
)
except WorkSpaceNotAllowedCreateError:
raise NotAllowedCreateWorkspace()
except AccountRegisterError as are:
except AccountRegisterError:
raise AccountInFreezeError()
except WorkspacesLimitExceededError:
raise WorkspacesLimitExceeded()

View File

@ -80,7 +80,7 @@ class OAuthCallback(Resource):
try:
token = oauth_provider.get_access_token(code)
user_info = oauth_provider.get_user_info(token)
except requests.exceptions.RequestException as e:
except requests.RequestException as e:
error_text = e.response.text if e.response else str(e)
logger.exception("An error occurred during the OAuth process with %s: %s", provider, error_text)
return {"error": "OAuth process failed"}, 400

View File

@ -44,22 +44,19 @@ def oauth_server_access_token_required(view):
if not oauth_provider_app or not isinstance(oauth_provider_app, OAuthProviderApp):
raise BadRequest("Invalid oauth_provider_app")
if not request.headers.get("Authorization"):
raise BadRequest("Authorization is required")
authorization_header = request.headers.get("Authorization")
if not authorization_header:
raise BadRequest("Authorization header is required")
parts = authorization_header.split(" ")
parts = authorization_header.strip().split(" ")
if len(parts) != 2:
raise BadRequest("Invalid Authorization header format")
token_type = parts[0]
if token_type != "Bearer":
token_type = parts[0].strip()
if token_type.lower() != "bearer":
raise BadRequest("token_type is invalid")
access_token = parts[1]
access_token = parts[1].strip()
if not access_token:
raise BadRequest("access_token is required")
@ -125,7 +122,10 @@ class OAuthServerUserTokenApi(Resource):
parser.add_argument("refresh_token", type=str, required=False, location="json")
parsed_args = parser.parse_args()
grant_type = OAuthGrantType(parsed_args["grant_type"])
try:
grant_type = OAuthGrantType(parsed_args["grant_type"])
except ValueError:
raise BadRequest("invalid grant_type")
if grant_type == OAuthGrantType.AUTHORIZATION_CODE:
if not parsed_args["code"]:
@ -163,8 +163,6 @@ class OAuthServerUserTokenApi(Resource):
"refresh_token": refresh_token,
}
)
else:
raise BadRequest("invalid grant_type")
class OAuthServerUserAccountApi(Resource):
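
The `grant_type` change above relies on Enum construction raising `ValueError` for unknown values, so the validation can happen up front instead of falling through to a dangling `else`. A self-contained sketch, assuming the member values from the handler (`BadRequest` stands in for werkzeug's exception):

```python
from enum import Enum

class OAuthGrantType(Enum):  # member values assumed from the handler above
    AUTHORIZATION_CODE = "authorization_code"
    REFRESH_TOKEN = "refresh_token"

class BadRequest(Exception):  # stand-in for werkzeug.exceptions.BadRequest
    pass

def parse_grant_type(raw: str) -> OAuthGrantType:
    try:
        return OAuthGrantType(raw)  # raises ValueError on unknown values
    except ValueError:
        raise BadRequest("invalid grant_type")

assert parse_grant_type("refresh_token") is OAuthGrantType.REFRESH_TOKEN
# parse_grant_type("password")  # would raise BadRequest
```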

View File

@ -14,6 +14,7 @@ from controllers.console.wraps import account_initialization_required, setup_req
from core.datasource.entities.datasource_entities import DatasourceProviderType, OnlineDocumentPagesMessage
from core.datasource.online_document.online_document_plugin import OnlineDocumentDatasourcePlugin
from core.indexing_runner import IndexingRunner
from core.rag.extractor.entity.datasource_type import DatasourceType
from core.rag.extractor.entity.extract_setting import ExtractSetting
from core.rag.extractor.notion_extractor import NotionExtractor
from extensions.ext_database import db
@ -248,7 +249,7 @@ class DataSourceNotionApi(Resource):
credential_id = notion_info.get("credential_id")
for page in notion_info["pages"]:
extract_setting = ExtractSetting(
datasource_type="notion_import",
datasource_type=DatasourceType.NOTION.value,
notion_info={
"credential_id": credential_id,
"notion_workspace_id": workspace_id,

View File

@ -21,6 +21,7 @@ from core.indexing_runner import IndexingRunner
from core.model_runtime.entities.model_entities import ModelType
from core.provider_manager import ProviderManager
from core.rag.datasource.vdb.vector_type import VectorType
from core.rag.extractor.entity.datasource_type import DatasourceType
from core.rag.extractor.entity.extract_setting import ExtractSetting
from core.rag.retrieval.retrieval_methods import RetrievalMethod
from extensions.ext_database import db
@ -431,7 +432,9 @@ class DatasetIndexingEstimateApi(Resource):
if file_details:
for file_detail in file_details:
extract_setting = ExtractSetting(
datasource_type="upload_file", upload_file=file_detail, document_model=args["doc_form"]
datasource_type=DatasourceType.FILE.value,
upload_file=file_detail,
document_model=args["doc_form"],
)
extract_settings.append(extract_setting)
elif args["info_list"]["data_source_type"] == "notion_import":
@ -441,7 +444,7 @@ class DatasetIndexingEstimateApi(Resource):
credential_id = notion_info.get("credential_id")
for page in notion_info["pages"]:
extract_setting = ExtractSetting(
datasource_type="notion_import",
datasource_type=DatasourceType.NOTION.value,
notion_info={
"credential_id": credential_id,
"notion_workspace_id": workspace_id,
@ -456,7 +459,7 @@ class DatasetIndexingEstimateApi(Resource):
website_info_list = args["info_list"]["website_info_list"]
for url in website_info_list["urls"]:
extract_setting = ExtractSetting(
datasource_type="website_crawl",
datasource_type=DatasourceType.WEBSITE.value,
website_info={
"provider": website_info_list["provider"],
"job_id": website_info_list["job_id"],

View File

@ -41,6 +41,7 @@ from core.model_manager import ModelManager
from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.errors.invoke import InvokeAuthorizationError
from core.plugin.impl.exc import PluginDaemonClientSideError
from core.rag.extractor.entity.datasource_type import DatasourceType
from core.rag.extractor.entity.extract_setting import ExtractSetting
from extensions.ext_database import db
from fields.document_fields import (
@ -356,9 +357,6 @@ class DatasetInitApi(Resource):
parser.add_argument("embedding_model_provider", type=str, required=False, nullable=True, location="json")
args = parser.parse_args()
# The role of the current user in the ta table must be admin, owner, or editor, or dataset_operator
if not current_user.is_dataset_editor:
raise Forbidden()
knowledge_config = KnowledgeConfig(**args)
if knowledge_config.indexing_technique == "high_quality":
if knowledge_config.embedding_model is None or knowledge_config.embedding_model_provider is None:
@ -430,7 +428,7 @@ class DocumentIndexingEstimateApi(DocumentResource):
raise NotFound("File not found.")
extract_setting = ExtractSetting(
datasource_type="upload_file", upload_file=file, document_model=document.doc_form
datasource_type=DatasourceType.FILE.value, upload_file=file, document_model=document.doc_form
)
indexing_runner = IndexingRunner()
@ -490,13 +488,13 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
raise NotFound("File not found.")
extract_setting = ExtractSetting(
datasource_type="upload_file", upload_file=file_detail, document_model=document.doc_form
datasource_type=DatasourceType.FILE.value, upload_file=file_detail, document_model=document.doc_form
)
extract_settings.append(extract_setting)
elif document.data_source_type == "notion_import":
extract_setting = ExtractSetting(
datasource_type="notion_import",
datasource_type=DatasourceType.NOTION.value,
notion_info={
"credential_id": data_source_info["credential_id"],
"notion_workspace_id": data_source_info["notion_workspace_id"],
@ -509,7 +507,7 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
extract_settings.append(extract_setting)
elif document.data_source_type == "website_crawl":
extract_setting = ExtractSetting(
datasource_type="website_crawl",
datasource_type=DatasourceType.WEBSITE.value,
website_info={
"provider": data_source_info["provider"],
"job_id": data_source_info["job_id"],

View File

@ -61,7 +61,6 @@ class ConversationApi(InstalledAppResource):
ConversationService.delete(app_model, conversation_id, current_user)
except ConversationNotExistsError:
raise NotFound("Conversation Not Exists.")
WebConversationService.unpin(app_model, conversation_id, current_user)
return {"result": "success"}, 204

View File

@ -67,7 +67,7 @@ class ModelProviderCredentialApi(Resource):
parser = reqparse.RequestParser()
parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
parser.add_argument("name", type=StrLen(30), required=True, nullable=False, location="json")
parser.add_argument("name", type=StrLen(30), required=False, nullable=True, location="json")
args = parser.parse_args()
model_provider_service = ModelProviderService()
@ -94,7 +94,7 @@ class ModelProviderCredentialApi(Resource):
parser = reqparse.RequestParser()
parser.add_argument("credential_id", type=uuid_value, required=True, nullable=False, location="json")
parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
parser.add_argument("name", type=StrLen(30), required=True, nullable=False, location="json")
parser.add_argument("name", type=StrLen(30), required=False, nullable=True, location="json")
args = parser.parse_args()
model_provider_service = ModelProviderService()

View File

@ -219,7 +219,11 @@ class ModelProviderModelCredentialApi(Resource):
model_load_balancing_service = ModelLoadBalancingService()
is_load_balancing_enabled, load_balancing_configs = model_load_balancing_service.get_load_balancing_configs(
tenant_id=tenant_id, provider=provider, model=args["model"], model_type=args["model_type"]
tenant_id=tenant_id,
provider=provider,
model=args["model"],
model_type=args["model_type"],
config_from=args.get("config_from", ""),
)
if args.get("config_from", "") == "predefined-model":
@ -263,7 +267,7 @@ class ModelProviderModelCredentialApi(Resource):
choices=[mt.value for mt in ModelType],
location="json",
)
parser.add_argument("name", type=StrLen(30), required=True, nullable=False, location="json")
parser.add_argument("name", type=StrLen(30), required=False, nullable=True, location="json")
parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
args = parser.parse_args()
@ -309,7 +313,7 @@ class ModelProviderModelCredentialApi(Resource):
)
parser.add_argument("credential_id", type=uuid_value, required=True, nullable=False, location="json")
parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
parser.add_argument("name", type=StrLen(30), required=True, nullable=False, location="json")
parser.add_argument("name", type=StrLen(30), required=False, nullable=True, location="json")
args = parser.parse_args()
model_provider_service = ModelProviderService()

View File

@ -1,8 +1,12 @@
from base64 import b64encode
from collections.abc import Callable
from functools import wraps
from hashlib import sha1
from hmac import new as hmac_new
from typing import ParamSpec, TypeVar
P = ParamSpec("P")
R = TypeVar("R")
from flask import abort, request
from configs import dify_config
@ -10,9 +14,9 @@ from extensions.ext_database import db
from models.model import EndUser
def billing_inner_api_only(view):
def billing_inner_api_only(view: Callable[P, R]):
@wraps(view)
def decorated(*args, **kwargs):
def decorated(*args: P.args, **kwargs: P.kwargs):
if not dify_config.INNER_API:
abort(404)
@ -26,9 +30,9 @@ def billing_inner_api_only(view):
return decorated
def enterprise_inner_api_only(view):
def enterprise_inner_api_only(view: Callable[P, R]):
@wraps(view)
def decorated(*args, **kwargs):
def decorated(*args: P.args, **kwargs: P.kwargs):
if not dify_config.INNER_API:
abort(404)
@ -78,9 +82,9 @@ def enterprise_inner_api_user_auth(view):
return decorated
def plugin_inner_api_only(view):
def plugin_inner_api_only(view: Callable[P, R]):
@wraps(view)
def decorated(*args, **kwargs):
def decorated(*args: P.args, **kwargs: P.kwargs):
if not dify_config.PLUGIN_DAEMON_KEY:
abort(404)
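
The `ParamSpec`/`TypeVar` pair added above lets these inner-API decorators preserve the wrapped view's signature for type checkers instead of collapsing it to `(*args: Any, **kwargs: Any)`. A self-contained sketch of the pattern (the gate-keeping body is elided):

```python
from collections.abc import Callable
from functools import wraps
from typing import ParamSpec, TypeVar

P = ParamSpec("P")
R = TypeVar("R")

def inner_api_only(view: Callable[P, R]) -> Callable[P, R]:
    @wraps(view)
    def decorated(*args: P.args, **kwargs: P.kwargs) -> R:
        # gate-keeping (feature flags, auth) would go here
        return view(*args, **kwargs)
    return decorated

@inner_api_only
def add(a: int, b: int) -> int:
    return a + b

total = add(1, 2)  # type checkers still see (a: int, b: int) -> int
```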

View File

@ -55,7 +55,7 @@ class AudioApi(Resource):
file = request.files["file"]
try:
response = AudioService.transcript_asr(app_model=app_model, file=file, end_user=end_user)
response = AudioService.transcript_asr(app_model=app_model, file=file, end_user=end_user.id)
return response
except services.errors.app_model_config.AppModelConfigBrokenError:

View File

@ -59,7 +59,7 @@ class FilePreviewApi(Resource):
args = file_preview_parser.parse_args()
# Validate file ownership and get file objects
message_file, upload_file = self._validate_file_ownership(file_id, app_model.id)
_, upload_file = self._validate_file_ownership(file_id, app_model.id)
# Get file content generator
try:

View File

@ -413,7 +413,7 @@ class DocumentUpdateByFileApi(DatasetApiResource):
DocumentService.document_create_args_validate(knowledge_config)
try:
documents, batch = DocumentService.save_document_with_dataset_id(
documents, _ = DocumentService.save_document_with_dataset_id(
dataset=dataset,
knowledge_config=knowledge_config,
account=dataset.created_by_account,

View File

@ -1,7 +1,7 @@
import time
from collections.abc import Callable
from datetime import timedelta
from enum import Enum
from enum import StrEnum, auto
from functools import wraps
from typing import Optional
@ -23,14 +23,14 @@ from models.model import ApiToken, App, EndUser
from services.feature_service import FeatureService
class WhereisUserArg(Enum):
class WhereisUserArg(StrEnum):
"""
Enum for whereis_user_arg.
"""
QUERY = "query"
JSON = "json"
FORM = "form"
QUERY = auto()
JSON = auto()
FORM = auto()
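
With `StrEnum`, `auto()` assigns each member its lowercased name, so the explicit string values could be dropped without changing behavior, and members still compare equal to plain strings. A quick check:

```python
from enum import StrEnum, auto  # StrEnum requires Python 3.11+

class WhereisUserArg(StrEnum):
    QUERY = auto()
    JSON = auto()
    FORM = auto()

# auto() in a StrEnum lowercases the member name, so the enum
# keeps its previous "query"/"json"/"form" values:
assert WhereisUserArg.QUERY == "query"
assert WhereisUserArg.JSON.value == "json"
```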
class FetchUserArg(BaseModel):
@ -291,27 +291,28 @@ def create_or_update_end_user_for_user_id(app_model: App, user_id: Optional[str]
if not user_id:
user_id = "DEFAULT-USER"
end_user = (
db.session.query(EndUser)
.where(
EndUser.tenant_id == app_model.tenant_id,
EndUser.app_id == app_model.id,
EndUser.session_id == user_id,
EndUser.type == "service_api",
with Session(db.engine, expire_on_commit=False) as session:
end_user = (
session.query(EndUser)
.where(
EndUser.tenant_id == app_model.tenant_id,
EndUser.app_id == app_model.id,
EndUser.session_id == user_id,
EndUser.type == "service_api",
)
.first()
)
.first()
)
if end_user is None:
end_user = EndUser(
tenant_id=app_model.tenant_id,
app_id=app_model.id,
type="service_api",
is_anonymous=user_id == "DEFAULT-USER",
session_id=user_id,
)
db.session.add(end_user)
db.session.commit()
if end_user is None:
end_user = EndUser(
tenant_id=app_model.tenant_id,
app_id=app_model.id,
type="service_api",
is_anonymous=user_id == "DEFAULT-USER",
session_id=user_id,
)
session.add(end_user)
session.commit()
return end_user
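
`expire_on_commit=False` matters here because the `EndUser` is returned after the `with` block closes the session; with the default `expire_on_commit=True`, the commit would expire the instance's attributes and any later access would try to refresh against a closed session. A sketch of the safe pattern, with a hypothetical minimal model:

```python
from sqlalchemy import Column, Integer, String, create_engine
from sqlalchemy.orm import Session, declarative_base

Base = declarative_base()

class EndUser(Base):  # hypothetical minimal model for illustration
    __tablename__ = "end_users"
    id = Column(Integer, primary_key=True)
    session_id = Column(String)

engine = create_engine("sqlite://")
Base.metadata.create_all(engine)

def get_or_create(session_id: str) -> EndUser:
    with Session(engine, expire_on_commit=False) as session:
        user = EndUser(session_id=session_id)
        session.add(user)
        session.commit()  # attributes stay loaded, not expired
        return user

user = get_or_create("DEFAULT-USER")
print(user.session_id)  # safe: no refresh against a closed session
```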

View File

@ -73,8 +73,6 @@ class ConversationApi(WebApiResource):
ConversationService.delete(app_model, conversation_id, end_user)
except ConversationNotExistsError:
raise NotFound("Conversation Not Exists.")
WebConversationService.unpin(app_model, conversation_id, end_user)
return {"result": "success"}, 204

View File

@ -4,6 +4,7 @@ from functools import wraps
from flask import request
from flask_restx import Resource
from sqlalchemy import select
from sqlalchemy.orm import Session
from werkzeug.exceptions import BadRequest, NotFound, Unauthorized
from controllers.web.error import WebAppAuthAccessDeniedError, WebAppAuthRequiredError
@ -49,18 +50,19 @@ def decode_jwt_token():
decoded = PassportService().verify(tk)
app_code = decoded.get("app_code")
app_id = decoded.get("app_id")
app_model = db.session.scalar(select(App).where(App.id == app_id))
site = db.session.scalar(select(Site).where(Site.code == app_code))
if not app_model:
raise NotFound()
if not app_code or not site:
raise BadRequest("Site URL is no longer valid.")
if app_model.enable_site is False:
raise BadRequest("Site is disabled.")
end_user_id = decoded.get("end_user_id")
end_user = db.session.scalar(select(EndUser).where(EndUser.id == end_user_id))
if not end_user:
raise NotFound()
with Session(db.engine, expire_on_commit=False) as session:
app_model = session.scalar(select(App).where(App.id == app_id))
site = session.scalar(select(Site).where(Site.code == app_code))
if not app_model:
raise NotFound()
if not app_code or not site:
raise BadRequest("Site URL is no longer valid.")
if app_model.enable_site is False:
raise BadRequest("Site is disabled.")
end_user_id = decoded.get("end_user_id")
end_user = session.scalar(select(EndUser).where(EndUser.id == end_user_id))
if not end_user:
raise NotFound()
# for enterprise webapp auth
app_web_auth_enabled = False

View File

@ -336,7 +336,8 @@ class BaseAgentRunner(AppRunner):
"""
Save agent thought
"""
agent_thought = db.session.query(MessageAgentThought).where(MessageAgentThought.id == agent_thought_id).first()
stmt = select(MessageAgentThought).where(MessageAgentThought.id == agent_thought_id)
agent_thought = db.session.scalar(stmt)
if not agent_thought:
raise ValueError("agent thought not found")
@ -494,7 +495,8 @@ class BaseAgentRunner(AppRunner):
return result
def organize_agent_user_prompt(self, message: Message) -> UserPromptMessage:
files = db.session.query(MessageFile).where(MessageFile.message_id == message.id).all()
stmt = select(MessageFile).where(MessageFile.message_id == message.id)
files = db.session.scalars(stmt).all()
if not files:
return UserPromptMessage(content=message.query)
if message.app_model_config:
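
The recurring rewrite in this and the following app-runner files is from legacy `db.session.query(Model).where(...).first()` to 2.0-style `select()` statements executed via `session.scalar()` / `session.scalars()`. The two spellings are equivalent, as this sketch with a hypothetical model shows:

```python
from sqlalchemy import Column, Integer, String, create_engine, select
from sqlalchemy.orm import Session, declarative_base

Base = declarative_base()

class MessageFile(Base):  # hypothetical minimal model for illustration
    __tablename__ = "message_files"
    id = Column(Integer, primary_key=True)
    message_id = Column(String)

engine = create_engine("sqlite://")
Base.metadata.create_all(engine)

with Session(engine) as session:
    # Legacy 1.x style:
    first = session.query(MessageFile).where(MessageFile.message_id == "m1").first()

    # 2.0 style, as in the hunks above:
    stmt = select(MessageFile).where(MessageFile.message_id == "m1")
    same = session.scalar(stmt)              # first entity or None
    all_files = session.scalars(stmt).all()  # list of entities
    assert first == same and all_files == []
```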

View File

@ -1,12 +1,14 @@
from pydantic import BaseModel, Field, ValidationError
from pydantic import BaseModel, ConfigDict, Field, ValidationError
class MoreLikeThisConfig(BaseModel):
enabled: bool = False
model_config = ConfigDict(extra="allow")
class AppConfigModel(BaseModel):
more_like_this: MoreLikeThisConfig = Field(default_factory=MoreLikeThisConfig)
model_config = ConfigDict(extra="allow")
class MoreLikeThisConfigManager:
@ -23,8 +25,8 @@ class MoreLikeThisConfigManager:
@classmethod
def validate_and_set_defaults(cls, config: dict) -> tuple[dict, list[str]]:
try:
return AppConfigModel.model_validate(config).dict(), ["more_like_this"]
except ValidationError as e:
return AppConfigModel.model_validate(config).model_dump(), ["more_like_this"]
except ValidationError:
raise ValueError(
"more_like_this must be of dict type and enabled in more_like_this must be of boolean type"
)
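
Two Pydantic v2 idioms appear in this hunk: `model_config = ConfigDict(extra="allow")` lets unknown keys pass through validation instead of being rejected, and `.model_dump()` replaces the deprecated v1 `.dict()`. A minimal sketch:

```python
from pydantic import BaseModel, ConfigDict, Field

class MoreLikeThisConfig(BaseModel):
    model_config = ConfigDict(extra="allow")  # keep unknown keys
    enabled: bool = False

class AppConfigModel(BaseModel):
    model_config = ConfigDict(extra="allow")
    more_like_this: MoreLikeThisConfig = Field(default_factory=MoreLikeThisConfig)

validated = AppConfigModel.model_validate(
    {"more_like_this": {"enabled": True}, "opening_statement": "hi"}
)
# model_dump() is the v2 replacement for the deprecated dict():
assert validated.model_dump()["opening_statement"] == "hi"
```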

View File

@ -450,6 +450,12 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
worker_thread.start()
# release database connection, because the following new thread operations may take a long time
db.session.refresh(workflow)
db.session.refresh(message)
# db.session.refresh(user)
db.session.close()
# return response or stream generator
response = self._handle_advanced_chat_response(
application_generate_entity=application_generate_entity,

View File

@ -73,7 +73,9 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
app_config = self.application_generate_entity.app_config
app_config = cast(AdvancedChatAppConfig, app_config)
app_record = db.session.query(App).where(App.id == app_config.app_id).first()
with Session(db.engine, expire_on_commit=False) as session:
app_record = session.scalar(select(App).where(App.id == app_config.app_id))
if not app_record:
raise ValueError("App not found")
@ -147,7 +149,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
environment_variables=self._workflow.environment_variables,
# Based on the definition of `VariableUnion`,
# `list[Variable]` can be safely used as `list[VariableUnion]` since they are compatible.
conversation_variables=cast(list[VariableUnion], conversation_variables),
conversation_variables=conversation_variables,
)
# init graph

View File

@ -118,7 +118,7 @@ class AdvancedChatAppGenerateResponseConverter(AppGenerateResponseConverter):
data = cls._error_to_stream_response(sub_stream_response.err)
response_chunk.update(data)
elif isinstance(sub_stream_response, NodeStartStreamResponse | NodeFinishStreamResponse):
response_chunk.update(sub_stream_response.to_ignore_detail_dict())
response_chunk.update(sub_stream_response.to_ignore_detail_dict()) # ty: ignore [unresolved-attribute]
else:
response_chunk.update(sub_stream_response.to_dict())

View File

@ -68,7 +68,6 @@ from core.workflow.repositories.workflow_execution_repository import WorkflowExe
from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository
from core.workflow.system_variable import SystemVariable
from core.workflow.workflow_cycle_manager import CycleManagerWorkflowInfo, WorkflowCycleManager
from events.message_event import message_was_created
from extensions.ext_database import db
from libs.datetime_utils import naive_utc_now
from models import Conversation, EndUser, Message, MessageFile
@ -306,13 +305,8 @@ class AdvancedChatAppGenerateTaskPipeline:
err = self._base_task_pipeline._handle_error(event=event, session=session, message_id=self._message_id)
yield self._base_task_pipeline._error_to_stream_response(err)
def _handle_workflow_started_event(
self, event: QueueWorkflowStartedEvent, *, graph_runtime_state: Optional[GraphRuntimeState] = None, **kwargs
) -> Generator[StreamResponse, None, None]:
def _handle_workflow_started_event(self, **kwargs) -> Generator[StreamResponse, None, None]:
"""Handle workflow started events."""
# Override graph runtime state - this is a side effect but necessary
graph_runtime_state = event.graph_runtime_state
with self._database_session() as session:
workflow_execution = self._workflow_cycle_manager.handle_workflow_run_start()
self._workflow_run_id = workflow_execution.id_
@ -333,15 +327,14 @@ class AdvancedChatAppGenerateTaskPipeline:
"""Handle node retry events."""
self._ensure_workflow_initialized()
with self._database_session() as session:
workflow_node_execution = self._workflow_cycle_manager.handle_workflow_node_execution_retried(
workflow_execution_id=self._workflow_run_id, event=event
)
node_retry_resp = self._workflow_response_converter.workflow_node_retry_to_stream_response(
event=event,
task_id=self._application_generate_entity.task_id,
workflow_node_execution=workflow_node_execution,
)
workflow_node_execution = self._workflow_cycle_manager.handle_workflow_node_execution_retried(
workflow_execution_id=self._workflow_run_id, event=event
)
node_retry_resp = self._workflow_response_converter.workflow_node_retry_to_stream_response(
event=event,
task_id=self._application_generate_entity.task_id,
workflow_node_execution=workflow_node_execution,
)
if node_retry_resp:
yield node_retry_resp
@ -375,13 +368,12 @@ class AdvancedChatAppGenerateTaskPipeline:
self._workflow_response_converter.fetch_files_from_node_outputs(event.outputs or {})
)
with self._database_session() as session:
workflow_node_execution = self._workflow_cycle_manager.handle_workflow_node_execution_success(event=event)
node_finish_resp = self._workflow_response_converter.workflow_node_finish_to_stream_response(
event=event,
task_id=self._application_generate_entity.task_id,
workflow_node_execution=workflow_node_execution,
)
workflow_node_execution = self._workflow_cycle_manager.handle_workflow_node_execution_success(event=event)
node_finish_resp = self._workflow_response_converter.workflow_node_finish_to_stream_response(
event=event,
task_id=self._application_generate_entity.task_id,
workflow_node_execution=workflow_node_execution,
)
self._save_output_for_event(event, workflow_node_execution.id)
@ -886,10 +878,6 @@ class AdvancedChatAppGenerateTaskPipeline:
self._task_state.metadata.usage = usage
else:
self._task_state.metadata.usage = LLMUsage.empty_usage()
message_was_created.send(
message,
application_generate_entity=self._application_generate_entity,
)
def _message_end_to_stream_response(self) -> MessageEndStreamResponse:
"""

View File

@ -1,6 +1,8 @@
import logging
from typing import cast
from sqlalchemy import select
from core.agent.cot_chat_agent_runner import CotChatAgentRunner
from core.agent.cot_completion_agent_runner import CotCompletionAgentRunner
from core.agent.entities import AgentEntity
@ -44,8 +46,8 @@ class AgentChatAppRunner(AppRunner):
"""
app_config = application_generate_entity.app_config
app_config = cast(AgentChatAppConfig, app_config)
app_record = db.session.query(App).where(App.id == app_config.app_id).first()
app_stmt = select(App).where(App.id == app_config.app_id)
app_record = db.session.scalar(app_stmt)
if not app_record:
raise ValueError("App not found")
@ -182,11 +184,12 @@ class AgentChatAppRunner(AppRunner):
if {ModelFeature.MULTI_TOOL_CALL, ModelFeature.TOOL_CALL}.intersection(model_schema.features or []):
agent_entity.strategy = AgentEntity.Strategy.FUNCTION_CALLING
conversation_result = db.session.query(Conversation).where(Conversation.id == conversation.id).first()
conversation_stmt = select(Conversation).where(Conversation.id == conversation.id)
conversation_result = db.session.scalar(conversation_stmt)
if conversation_result is None:
raise ValueError("Conversation not found")
message_result = db.session.query(Message).where(Message.id == message.id).first()
msg_stmt = select(Message).where(Message.id == message.id)
message_result = db.session.scalar(msg_stmt)
if message_result is None:
raise ValueError("Message not found")
db.session.close()

View File

@ -1,7 +1,7 @@
import queue
import time
from abc import abstractmethod
from enum import Enum
from enum import IntEnum, auto
from typing import Any, Optional
from sqlalchemy.orm import DeclarativeMeta
@ -19,9 +19,9 @@ from core.app.entities.queue_entities import (
from extensions.ext_redis import redis_client
class PublishFrom(Enum):
APPLICATION_MANAGER = 1
TASK_PIPELINE = 2
class PublishFrom(IntEnum):
APPLICATION_MANAGER = auto()
TASK_PIPELINE = auto()
class AppQueueManager:
@ -174,7 +174,7 @@ class AppQueueManager:
def _check_for_sqlalchemy_models(self, data: Any):
# from entity to dict or list
if isinstance(data, dict):
for key, value in data.items():
for value in data.values():
self._check_for_sqlalchemy_models(value)
elif isinstance(data, list):
for item in data:

View File

@ -1,6 +1,8 @@
import logging
from typing import cast
from sqlalchemy import select
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
from core.app.apps.base_app_runner import AppRunner
from core.app.apps.chat.app_config_manager import ChatAppConfig
@ -42,8 +44,8 @@ class ChatAppRunner(AppRunner):
"""
app_config = application_generate_entity.app_config
app_config = cast(ChatAppConfig, app_config)
app_record = db.session.query(App).where(App.id == app_config.app_id).first()
stmt = select(App).where(App.id == app_config.app_id)
app_record = db.session.scalar(stmt)
if not app_record:
raise ValueError("App not found")

View File

@ -6,6 +6,7 @@ from typing import Any, Literal, Union, overload
from flask import Flask, copy_current_request_context, current_app
from pydantic import ValidationError
from sqlalchemy import select
from configs import dify_config
from core.app.app_config.easy_ui_based_app.model_config.converter import ModelConfigConverter
@ -248,17 +249,14 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
:param invoke_from: invoke from source
:param stream: is stream
"""
message = (
db.session.query(Message)
.where(
Message.id == message_id,
Message.app_id == app_model.id,
Message.from_source == ("api" if isinstance(user, EndUser) else "console"),
Message.from_end_user_id == (user.id if isinstance(user, EndUser) else None),
Message.from_account_id == (user.id if isinstance(user, Account) else None),
)
.first()
stmt = select(Message).where(
Message.id == message_id,
Message.app_id == app_model.id,
Message.from_source == ("api" if isinstance(user, EndUser) else "console"),
Message.from_end_user_id == (user.id if isinstance(user, EndUser) else None),
Message.from_account_id == (user.id if isinstance(user, Account) else None),
)
message = db.session.scalar(stmt)
if not message:
raise MessageNotExistsError()

View File

@ -1,6 +1,8 @@
import logging
from typing import cast
from sqlalchemy import select
from core.app.apps.base_app_queue_manager import AppQueueManager
from core.app.apps.base_app_runner import AppRunner
from core.app.apps.completion.app_config_manager import CompletionAppConfig
@ -35,8 +37,8 @@ class CompletionAppRunner(AppRunner):
"""
app_config = application_generate_entity.app_config
app_config = cast(CompletionAppConfig, app_config)
app_record = db.session.query(App).where(App.id == app_config.app_id).first()
stmt = select(App).where(App.id == app_config.app_id)
app_record = db.session.scalar(stmt)
if not app_record:
raise ValueError("App not found")

View File

@ -3,6 +3,9 @@ import logging
from collections.abc import Generator
from typing import Optional, Union, cast
from sqlalchemy import select
from sqlalchemy.orm import Session
from core.app.app_config.entities import EasyUIBasedAppConfig, EasyUIBasedAppModelConfigFrom
from core.app.apps.base_app_generator import BaseAppGenerator
from core.app.apps.base_app_queue_manager import AppQueueManager
@ -83,11 +86,10 @@ class MessageBasedAppGenerator(BaseAppGenerator):
def _get_app_model_config(self, app_model: App, conversation: Optional[Conversation] = None) -> AppModelConfig:
if conversation:
app_model_config = (
db.session.query(AppModelConfig)
.where(AppModelConfig.id == conversation.app_model_config_id, AppModelConfig.app_id == app_model.id)
.first()
stmt = select(AppModelConfig).where(
AppModelConfig.id == conversation.app_model_config_id, AppModelConfig.app_id == app_model.id
)
app_model_config = db.session.scalar(stmt)
if not app_model_config:
raise AppModelConfigBrokenError()
@ -253,7 +255,8 @@ class MessageBasedAppGenerator(BaseAppGenerator):
:param conversation_id: conversation id
:return: conversation
"""
conversation = db.session.query(Conversation).where(Conversation.id == conversation_id).first()
with Session(db.engine, expire_on_commit=False) as session:
conversation = session.scalar(select(Conversation).where(Conversation.id == conversation_id))
if not conversation:
raise ConversationNotExistsError("Conversation not exists")
@ -266,7 +269,8 @@ class MessageBasedAppGenerator(BaseAppGenerator):
:param message_id: message id
:return: message
"""
message = db.session.query(Message).where(Message.id == message_id).first()
with Session(db.engine, expire_on_commit=False) as session:
message = session.scalar(select(Message).where(Message.id == message_id))
if message is None:
raise MessageNotExistsError("Message not exists")

View File

@ -89,7 +89,7 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter):
data = cls._error_to_stream_response(sub_stream_response.err)
response_chunk.update(data)
elif isinstance(sub_stream_response, NodeStartStreamResponse | NodeFinishStreamResponse):
response_chunk.update(sub_stream_response.to_ignore_detail_dict())
response_chunk.update(sub_stream_response.to_ignore_detail_dict()) # ty: ignore [unresolved-attribute]
else:
response_chunk.update(sub_stream_response.to_dict())
yield response_chunk

View File

@ -296,16 +296,15 @@ class WorkflowAppGenerateTaskPipeline:
"""Handle node retry events."""
self._ensure_workflow_initialized()
with self._database_session() as session:
workflow_node_execution = self._workflow_cycle_manager.handle_workflow_node_execution_retried(
workflow_execution_id=self._workflow_run_id,
event=event,
)
response = self._workflow_response_converter.workflow_node_retry_to_stream_response(
event=event,
task_id=self._application_generate_entity.task_id,
workflow_node_execution=workflow_node_execution,
)
workflow_node_execution = self._workflow_cycle_manager.handle_workflow_node_execution_retried(
workflow_execution_id=self._workflow_run_id,
event=event,
)
response = self._workflow_response_converter.workflow_node_retry_to_stream_response(
event=event,
task_id=self._application_generate_entity.task_id,
workflow_node_execution=workflow_node_execution,
)
if response:
yield response

View File

@ -1,6 +1,8 @@
import logging
from typing import Optional
from sqlalchemy import select
from core.app.entities.app_invoke_entities import InvokeFrom
from core.rag.datasource.vdb.vector_factory import Vector
from extensions.ext_database import db
@ -25,9 +27,8 @@ class AnnotationReplyFeature:
:param invoke_from: invoke from
:return:
"""
annotation_setting = (
db.session.query(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app_record.id).first()
)
stmt = select(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app_record.id)
annotation_setting = db.session.scalar(stmt)
if not annotation_setting:
return None

View File

@ -96,7 +96,11 @@ class RateLimit:
if isinstance(generator, Mapping):
return generator
else:
return RateLimitGenerator(rate_limit=self, generator=generator, request_id=request_id)
return RateLimitGenerator(
rate_limit=self,
generator=generator, # ty: ignore [invalid-argument-type]
request_id=request_id,
)
class RateLimitGenerator:

View File

@ -50,7 +50,7 @@ class BasedGenerateTaskPipeline:
if isinstance(e, InvokeAuthorizationError):
err = InvokeAuthorizationError("Incorrect API key provided")
elif isinstance(e, InvokeError | ValueError):
err = e
err = e # ty: ignore [invalid-assignment]
else:
description = getattr(e, "description", None)
err = Exception(description if description is not None else str(e))

View File

@ -472,9 +472,10 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
:param event: agent thought event
:return:
"""
agent_thought: Optional[MessageAgentThought] = (
db.session.query(MessageAgentThought).where(MessageAgentThought.id == event.agent_thought_id).first()
)
with Session(db.engine, expire_on_commit=False) as session:
agent_thought: Optional[MessageAgentThought] = (
session.query(MessageAgentThought).where(MessageAgentThought.id == event.agent_thought_id).first()
)
if agent_thought:
return AgentThoughtStreamResponse(

View File

@ -3,6 +3,8 @@ from threading import Thread
from typing import Optional, Union
from flask import Flask, current_app
from sqlalchemy import select
from sqlalchemy.orm import Session
from configs import dify_config
from core.app.entities.app_invoke_entities import (
@ -84,7 +86,8 @@ class MessageCycleManager:
def _generate_conversation_name_worker(self, flask_app: Flask, conversation_id: str, query: str):
with flask_app.app_context():
# get conversation and message
conversation = db.session.query(Conversation).where(Conversation.id == conversation_id).first()
stmt = select(Conversation).where(Conversation.id == conversation_id)
conversation = db.session.scalar(stmt)
if not conversation:
return
@ -98,7 +101,7 @@ class MessageCycleManager:
try:
name = LLMGenerator.generate_conversation_name(app_model.tenant_id, query)
conversation.name = name
except Exception as e:
except Exception:
if dify_config.DEBUG:
logger.exception("generate conversation name failed, conversation_id: %s", conversation_id)
pass
@ -143,7 +146,8 @@ class MessageCycleManager:
:param event: event
:return:
"""
message_file = db.session.query(MessageFile).where(MessageFile.id == event.message_file_id).first()
with Session(db.engine, expire_on_commit=False) as session:
message_file = session.scalar(select(MessageFile).where(MessageFile.id == event.message_file_id))
if message_file and message_file.url is not None:
# get tool file id
@ -183,7 +187,8 @@ class MessageCycleManager:
:param message_id: message id
:return:
"""
message_file = db.session.query(MessageFile).where(MessageFile.id == message_id).first()
with Session(db.engine, expire_on_commit=False) as session:
message_file = session.scalar(select(MessageFile).where(MessageFile.id == message_id))
event_type = StreamEvent.MESSAGE_FILE if message_file else StreamEvent.MESSAGE
return MessageStreamResponse(

View File

@ -1,6 +1,8 @@
import logging
from collections.abc import Sequence
from sqlalchemy import select
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
from core.app.entities.app_invoke_entities import InvokeFrom
from core.app.entities.queue_entities import QueueRetrieverResourcesEvent
@ -49,7 +51,8 @@ class DatasetIndexToolCallbackHandler:
for document in documents:
if document.metadata is not None:
document_id = document.metadata["document_id"]
dataset_document = db.session.query(DatasetDocument).where(DatasetDocument.id == document_id).first()
dataset_document_stmt = select(DatasetDocument).where(DatasetDocument.id == document_id)
dataset_document = db.session.scalar(dataset_document_stmt)
if not dataset_document:
_logger.warning(
"Expected DatasetDocument record to exist, but none was found, document_id=%s",
@ -57,17 +60,14 @@ class DatasetIndexToolCallbackHandler:
)
continue
if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX:
child_chunk = (
db.session.query(ChildChunk)
.where(
ChildChunk.index_node_id == document.metadata["doc_id"],
ChildChunk.dataset_id == dataset_document.dataset_id,
ChildChunk.document_id == dataset_document.id,
)
.first()
child_chunk_stmt = select(ChildChunk).where(
ChildChunk.index_node_id == document.metadata["doc_id"],
ChildChunk.dataset_id == dataset_document.dataset_id,
ChildChunk.document_id == dataset_document.id,
)
child_chunk = db.session.scalar(child_chunk_stmt)
if child_chunk:
segment = (
_ = (
db.session.query(DocumentSegment)
.where(DocumentSegment.id == child_chunk.segment_id)
.update(

View File

@ -1,5 +1,6 @@
import json
import logging
import re
from collections import defaultdict
from collections.abc import Iterator, Sequence
from json import JSONDecodeError
@ -343,7 +344,65 @@ class ProviderConfiguration(BaseModel):
with Session(db.engine) as new_session:
return _validate(new_session)
def create_provider_credential(self, credentials: dict, credential_name: str) -> None:
def _generate_provider_credential_name(self, session) -> str:
"""
Generate a unique credential name for provider.
:return: credential name
"""
return self._generate_next_api_key_name(
session=session,
query_factory=lambda: select(ProviderCredential).where(
ProviderCredential.tenant_id == self.tenant_id,
ProviderCredential.provider_name == self.provider.provider,
),
)
def _generate_custom_model_credential_name(self, model: str, model_type: ModelType, session) -> str:
"""
Generate a unique credential name for custom model.
:return: credential name
"""
return self._generate_next_api_key_name(
session=session,
query_factory=lambda: select(ProviderModelCredential).where(
ProviderModelCredential.tenant_id == self.tenant_id,
ProviderModelCredential.provider_name == self.provider.provider,
ProviderModelCredential.model_name == model,
ProviderModelCredential.model_type == model_type.to_origin_model_type(),
),
)
def _generate_next_api_key_name(self, session, query_factory) -> str:
"""
Generate next available API KEY name by finding the highest numbered suffix.
:param session: database session
:param query_factory: function that returns the SQLAlchemy query
:return: next available API KEY name
"""
try:
stmt = query_factory()
credential_records = session.execute(stmt).scalars().all()
if not credential_records:
return "API KEY 1"
# Extract numbers from API KEY pattern using list comprehension
pattern = re.compile(r"^API KEY\s+(\d+)$")
numbers = [
int(match.group(1))
for cr in credential_records
if cr.credential_name and (match := pattern.match(cr.credential_name.strip()))
]
# Return next sequential number
next_number = max(numbers, default=0) + 1
return f"API KEY {next_number}"
except Exception as e:
logger.warning("Error generating next credential name: %s", str(e))
return "API KEY 1"
def create_provider_credential(self, credentials: dict, credential_name: str | None) -> None:
"""
Add custom provider credentials.
:param credentials: provider credentials
@ -351,8 +410,12 @@ class ProviderConfiguration(BaseModel):
:return:
"""
with Session(db.engine) as session:
if self._check_provider_credential_name_exists(credential_name=credential_name, session=session):
if credential_name and self._check_provider_credential_name_exists(
credential_name=credential_name, session=session
):
raise ValueError(f"Credential with name '{credential_name}' already exists.")
else:
credential_name = self._generate_provider_credential_name(session)
credentials = self.validate_provider_credentials(credentials=credentials, session=session)
provider_record = self._get_provider_record(session)
@ -395,7 +458,7 @@ class ProviderConfiguration(BaseModel):
self,
credentials: dict,
credential_id: str,
credential_name: str,
credential_name: str | None,
) -> None:
"""
update a saved provider credential (by credential_id).
@ -406,7 +469,7 @@ class ProviderConfiguration(BaseModel):
:return:
"""
with Session(db.engine) as session:
if self._check_provider_credential_name_exists(
if credential_name and self._check_provider_credential_name_exists(
credential_name=credential_name, session=session, exclude_id=credential_id
):
raise ValueError(f"Credential with name '{credential_name}' already exists.")
@ -428,9 +491,9 @@ class ProviderConfiguration(BaseModel):
try:
# Update credential
credential_record.encrypted_config = json.dumps(credentials)
credential_record.credential_name = credential_name
credential_record.updated_at = naive_utc_now()
if credential_name:
credential_record.credential_name = credential_name
session.commit()
if provider_record and provider_record.credential_id == credential_id:
@ -532,13 +595,7 @@ class ProviderConfiguration(BaseModel):
cache_type=ProviderCredentialsCacheType.LOAD_BALANCING_MODEL,
)
lb_credentials_cache.delete()
lb_config.credential_id = None
lb_config.encrypted_config = None
lb_config.enabled = False
lb_config.name = "__delete__"
lb_config.updated_at = naive_utc_now()
session.add(lb_config)
session.delete(lb_config)
# Check if this is the currently active credential
provider_record = self._get_provider_record(session)
@ -823,7 +880,7 @@ class ProviderConfiguration(BaseModel):
return _validate(new_session)
def create_custom_model_credential(
self, model_type: ModelType, model: str, credentials: dict, credential_name: str
self, model_type: ModelType, model: str, credentials: dict, credential_name: str | None
) -> None:
"""
Create a custom model credential.
@ -834,10 +891,14 @@ class ProviderConfiguration(BaseModel):
:return:
"""
with Session(db.engine) as session:
if self._check_custom_model_credential_name_exists(
if credential_name and self._check_custom_model_credential_name_exists(
model=model, model_type=model_type, credential_name=credential_name, session=session
):
raise ValueError(f"Model credential with name '{credential_name}' already exists for {model}.")
else:
credential_name = self._generate_custom_model_credential_name(
model=model, model_type=model_type, session=session
)
# validate custom model config
credentials = self.validate_custom_model_credentials(
model_type=model_type, model=model, credentials=credentials, session=session
@ -881,7 +942,7 @@ class ProviderConfiguration(BaseModel):
raise
def update_custom_model_credential(
self, model_type: ModelType, model: str, credentials: dict, credential_name: str, credential_id: str
self, model_type: ModelType, model: str, credentials: dict, credential_name: str | None, credential_id: str
) -> None:
"""
Update a custom model credential.
@ -894,7 +955,7 @@ class ProviderConfiguration(BaseModel):
:return:
"""
with Session(db.engine) as session:
if self._check_custom_model_credential_name_exists(
if credential_name and self._check_custom_model_credential_name_exists(
model=model,
model_type=model_type,
credential_name=credential_name,
@ -926,8 +987,9 @@ class ProviderConfiguration(BaseModel):
try:
# Update credential
credential_record.encrypted_config = json.dumps(credentials)
credential_record.credential_name = credential_name
credential_record.updated_at = naive_utc_now()
if credential_name:
credential_record.credential_name = credential_name
session.commit()
if provider_model_record and provider_model_record.credential_id == credential_id:
@ -983,12 +1045,7 @@ class ProviderConfiguration(BaseModel):
cache_type=ProviderCredentialsCacheType.LOAD_BALANCING_MODEL,
)
lb_credentials_cache.delete()
lb_config.credential_id = None
lb_config.encrypted_config = None
lb_config.enabled = False
lb_config.name = "__delete__"
lb_config.updated_at = naive_utc_now()
session.add(lb_config)
session.delete(lb_config)
# Check if this is the currently active credential
provider_model_record = self._get_custom_model_record(model_type, model, session=session)
@ -1055,6 +1112,7 @@ class ProviderConfiguration(BaseModel):
provider_name=self.provider.provider,
model_name=model,
model_type=model_type.to_origin_model_type(),
is_valid=True,
credential_id=credential_id,
)
else:
@ -1608,11 +1666,9 @@ class ProviderConfiguration(BaseModel):
if config.credential_source_type != "custom_model"
]
if len(provider_model_lb_configs) > 1:
load_balancing_enabled = True
if any(config.name == "__delete__" for config in provider_model_lb_configs):
has_invalid_load_balancing_configs = True
load_balancing_enabled = model_setting.load_balancing_enabled
# when the user enables load_balancing but fewer than 2 configs are available, display a warning
has_invalid_load_balancing_configs = load_balancing_enabled and len(provider_model_lb_configs) < 2
provider_models.append(
ModelWithProviderEntity(
@ -1634,6 +1690,8 @@ class ProviderConfiguration(BaseModel):
for model_configuration in self.custom_configuration.models:
if model_configuration.model_type not in model_types:
continue
if model_configuration.unadded_to_model_list:
continue
if model and model != model_configuration.model:
continue
try:
@ -1666,11 +1724,9 @@ class ProviderConfiguration(BaseModel):
if config.credential_source_type != "provider"
]
if len(custom_model_lb_configs) > 1:
load_balancing_enabled = True
if any(config.name == "__delete__" for config in custom_model_lb_configs):
has_invalid_load_balancing_configs = True
load_balancing_enabled = model_setting.load_balancing_enabled
# when the user enables load_balancing but fewer than 2 configs are available, display a warning
has_invalid_load_balancing_configs = load_balancing_enabled and len(custom_model_lb_configs) < 2
if len(model_configuration.available_model_credentials) > 0 and not model_configuration.credentials:
status = ModelStatus.CREDENTIAL_REMOVED
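Both hunks above replace the `__delete__` tombstone heuristic with the persisted flag; a condensed sketch of the new rule (model_setting is duck-typed here):
def lb_state(model_setting, lb_configs: list) -> tuple[bool, bool]:
    # Enabled now comes straight from the stored model setting.
    enabled = model_setting.load_balancing_enabled
    # Warn when balancing is on but fewer than 2 configs can serve it.
    has_invalid = enabled and len(lb_configs) < 2
    return enabled, has_invalid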

View File

@ -111,11 +111,21 @@ class CustomModelConfiguration(BaseModel):
current_credential_id: Optional[str] = None
current_credential_name: Optional[str] = None
available_model_credentials: list[CredentialConfiguration] = []
unadded_to_model_list: Optional[bool] = False
# pydantic configs
model_config = ConfigDict(protected_namespaces=())
class UnaddedModelConfiguration(BaseModel):
"""
Model class for provider unadded model configuration.
"""
model: str
model_type: ModelType
class CustomConfiguration(BaseModel):
"""
Model class for provider custom configuration.
@ -123,6 +133,7 @@ class CustomConfiguration(BaseModel):
provider: Optional[CustomProviderConfiguration] = None
models: list[CustomModelConfiguration] = []
can_added_models: list[UnaddedModelConfiguration] = []
class ModelLoadBalancingConfiguration(BaseModel):
@ -144,6 +155,7 @@ class ModelSettings(BaseModel):
model: str
model_type: ModelType
enabled: bool = True
load_balancing_enabled: bool = False
load_balancing_configs: list[ModelLoadBalancingConfiguration] = []
# pydantic configs
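A hypothetical construction of the entities added above, with illustrative model names:
from core.entities.provider_entities import ModelSettings, UnaddedModelConfiguration
from core.model_runtime.entities.model_entities import ModelType

unadded = UnaddedModelConfiguration(model="gpt-4o-mini", model_type=ModelType.LLM)
settings = ModelSettings(
    model="gpt-4o-mini",
    model_type=ModelType.LLM,
    load_balancing_enabled=True,  # new field; defaults to False
)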

View File

@ -43,9 +43,9 @@ class APIBasedExtensionRequestor:
timeout=self.timeout,
proxies=proxies,
)
except requests.exceptions.Timeout:
except requests.Timeout:
raise ValueError("request timeout")
except requests.exceptions.ConnectionError:
except requests.ConnectionError:
raise ValueError("request connection error")
if response.status_code != 200:

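The shorter spellings are safe because requests re-exports its exception classes at the package top level; both names refer to the same class:
import requests

assert requests.Timeout is requests.exceptions.Timeout
assert requests.ConnectionError is requests.exceptions.ConnectionError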
View File

@ -91,7 +91,7 @@ class Extensible:
# Find extension class
extension_class = None
for name, obj in vars(mod).items():
for obj in vars(mod).values():
if isinstance(obj, type) and issubclass(obj, cls) and obj != cls:
extension_class = obj
break
@ -123,7 +123,7 @@ class Extensible:
)
)
except Exception as e:
except Exception:
logger.exception("Error scanning extensions")
raise
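Dropping the unused `as e` binding is safe because logger.exception() captures the active traceback from the exception context by itself; a minimal sketch (scan() is a hypothetical stand-in for the extension scan):
import logging

logger = logging.getLogger(__name__)

def scan() -> None:
    raise RuntimeError("boom")

try:
    scan()
except Exception:
    # The traceback is attached automatically; no bound name is needed.
    logger.exception("Error scanning extensions")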

View File

@ -41,9 +41,3 @@ class Extension:
assert module_extension.extension_class is not None
t: type = module_extension.extension_class
return t
def validate_form_schema(self, module: ExtensionModule, extension_name: str, config: dict) -> None:
module_extension = self.module_extension(module, extension_name)
form_schema = module_extension.form_schema
# TODO validate form_schema

View File

@ -1,5 +1,7 @@
from typing import Optional
from sqlalchemy import select
from core.extension.api_based_extension_requestor import APIBasedExtensionRequestor
from core.external_data_tool.base import ExternalDataTool
from core.helper import encrypter
@ -28,13 +30,11 @@ class ApiExternalDataTool(ExternalDataTool):
api_based_extension_id = config.get("api_based_extension_id")
if not api_based_extension_id:
raise ValueError("api_based_extension_id is required")
# get api_based_extension
api_based_extension = (
db.session.query(APIBasedExtension)
.where(APIBasedExtension.tenant_id == tenant_id, APIBasedExtension.id == api_based_extension_id)
.first()
stmt = select(APIBasedExtension).where(
APIBasedExtension.tenant_id == tenant_id, APIBasedExtension.id == api_based_extension_id
)
api_based_extension = db.session.scalar(stmt)
if not api_based_extension:
raise ValueError("api_based_extension_id is invalid")
@ -52,13 +52,11 @@ class ApiExternalDataTool(ExternalDataTool):
raise ValueError(f"config is required, config: {self.config}")
api_based_extension_id = self.config.get("api_based_extension_id")
assert api_based_extension_id is not None, "api_based_extension_id is required"
# get api_based_extension
api_based_extension = (
db.session.query(APIBasedExtension)
.where(APIBasedExtension.tenant_id == self.tenant_id, APIBasedExtension.id == api_based_extension_id)
.first()
stmt = select(APIBasedExtension).where(
APIBasedExtension.tenant_id == self.tenant_id, APIBasedExtension.id == api_based_extension_id
)
api_based_extension = db.session.scalar(stmt)
if not api_based_extension:
raise ValueError(

View File

@ -22,7 +22,6 @@ class ExternalDataToolFactory:
:param config: the form config data
:return:
"""
code_based_extension.validate_form_schema(ExtensionModule.EXTERNAL_DATA_TOOL, name, config)
extension_class = code_based_extension.extension_class(ExtensionModule.EXTERNAL_DATA_TOOL, name)
# FIXME mypy issue here, figure out how to fix it
extension_class.validate_config(tenant_id, config) # type: ignore

View File

@ -3,7 +3,7 @@ import base64
from libs import rsa
def obfuscated_token(token: str):
def obfuscated_token(token: str) -> str:
if not token:
return token
if len(token) <= 8:
@ -11,6 +11,10 @@ def obfuscated_token(token: str):
return token[:6] + "*" * 12 + token[-2:]
def full_mask_token(token_length=20):
return "*" * token_length
def encrypt_token(tenant_id: str, token: str):
from extensions.ext_database import db
from models.account import Tenant
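Expected behavior of the masking helpers above, with an illustrative 21-character token (the import path assumes this file is core/helper/encrypter.py, as the imports elsewhere in this commit suggest):
from core.helper.encrypter import full_mask_token, obfuscated_token

token = "sk-abc123def456ghi789"  # illustrative value
assert obfuscated_token(token) == token[:6] + "*" * 12 + token[-2:]  # "sk-abc************89"
assert obfuscated_token("") == ""
assert full_mask_token() == "*" * 20
assert full_mask_token(8) == "********"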

View File

@ -1,6 +1,6 @@
from collections.abc import Sequence
import requests
import httpx
from yarl import URL
from configs import dify_config
@ -23,7 +23,7 @@ def batch_fetch_plugin_manifests(plugin_ids: list[str]) -> Sequence[MarketplaceP
return []
url = str(marketplace_api_url / "api/v1/plugins/batch")
response = requests.post(url, json={"plugin_ids": plugin_ids})
response = httpx.post(url, json={"plugin_ids": plugin_ids})
response.raise_for_status()
return [MarketplacePluginDeclaration(**plugin) for plugin in response.json()["data"]["plugins"]]
@ -36,13 +36,13 @@ def batch_fetch_plugin_manifests_ignore_deserialization_error(
return []
url = str(marketplace_api_url / "api/v1/plugins/batch")
response = requests.post(url, json={"plugin_ids": plugin_ids})
response = httpx.post(url, json={"plugin_ids": plugin_ids})
response.raise_for_status()
result: list[MarketplacePluginDeclaration] = []
for plugin in response.json()["data"]["plugins"]:
try:
result.append(MarketplacePluginDeclaration(**plugin))
except Exception as e:
except Exception:
pass
return result
@ -50,5 +50,5 @@ def batch_fetch_plugin_manifests_ignore_deserialization_error(
def record_install_plugin_event(plugin_unique_identifier: str):
url = str(marketplace_api_url / "api/v1/stats/plugins/install_count")
response = requests.post(url, json={"unique_identifier": plugin_unique_identifier})
response = httpx.post(url, json={"unique_identifier": plugin_unique_identifier})
response.raise_for_status()
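For these simple JSON POSTs the swap is call-compatible; a sketch with an illustrative URL and plugin id:
import httpx

url = "https://marketplace.dify.ai/api/v1/plugins/batch"  # illustrative
response = httpx.post(url, json={"plugin_ids": ["example/plugin"]})
response.raise_for_status()
plugins = response.json()["data"]["plugins"]

One behavioral difference worth remembering is that httpx does not follow redirects by default, unlike requests.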

View File

@ -47,7 +47,7 @@ def get_subclasses_from_module(mod: ModuleType, parent_type: type) -> list[type]
def load_single_subclass_from_source(
*, module_name: str, script_path: AnyStr, parent_type: type, use_lazy_loader: bool = False
*, module_name: str, script_path: str, parent_type: type, use_lazy_loader: bool = False
) -> type:
"""
Load a single subclass from the source

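The annotation fix is right because AnyStr is a constrained TypeVar intended to link several str/bytes parameters; on a single parameter it constrains nothing:
from typing import AnyStr

def concat(a: AnyStr, b: AnyStr) -> AnyStr:
    # Legitimate use: both arguments and the result share one type.
    return a + b

def load(script_path: str) -> None:
    # A lone AnyStr parameter adds no linkage; plain str says the same thing.
    ...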
View File

@ -5,9 +5,10 @@ import re
import threading
import time
import uuid
from typing import Any, Optional, cast
from typing import Any, Optional
from flask import current_app
from sqlalchemy import select
from sqlalchemy.orm.exc import ObjectDeletedError
from configs import dify_config
@ -18,6 +19,7 @@ from core.model_runtime.entities.model_entities import ModelType
from core.rag.cleaner.clean_processor import CleanProcessor
from core.rag.datasource.keyword.keyword_factory import Keyword
from core.rag.docstore.dataset_docstore import DatasetDocumentStore
from core.rag.extractor.entity.datasource_type import DatasourceType
from core.rag.extractor.entity.extract_setting import ExtractSetting
from core.rag.index_processor.constant.index_type import IndexType
from core.rag.index_processor.index_processor_base import BaseIndexProcessor
@ -56,13 +58,11 @@ class IndexingRunner:
if not dataset:
raise ValueError("no dataset found")
# get the process rule
processing_rule = (
db.session.query(DatasetProcessRule)
.where(DatasetProcessRule.id == dataset_document.dataset_process_rule_id)
.first()
stmt = select(DatasetProcessRule).where(
DatasetProcessRule.id == dataset_document.dataset_process_rule_id
)
processing_rule = db.session.scalar(stmt)
if not processing_rule:
raise ValueError("no process rule found")
index_type = dataset_document.doc_form
@ -123,11 +123,8 @@ class IndexingRunner:
db.session.query(ChildChunk).where(ChildChunk.segment_id == document_segment.id).delete()
db.session.commit()
# get the process rule
processing_rule = (
db.session.query(DatasetProcessRule)
.where(DatasetProcessRule.id == dataset_document.dataset_process_rule_id)
.first()
)
stmt = select(DatasetProcessRule).where(DatasetProcessRule.id == dataset_document.dataset_process_rule_id)
processing_rule = db.session.scalar(stmt)
if not processing_rule:
raise ValueError("no process rule found")
@ -208,7 +205,6 @@ class IndexingRunner:
child_documents.append(child_document)
document.children = child_documents
documents.append(document)
# build index
index_type = dataset_document.doc_form
index_processor = IndexProcessorFactory(index_type).init_index_processor()
@ -310,7 +306,8 @@ class IndexingRunner:
# delete image files and related db records
image_upload_file_ids = get_image_upload_file_ids(document.page_content)
for upload_file_id in image_upload_file_ids:
image_file = db.session.query(UploadFile).where(UploadFile.id == upload_file_id).first()
stmt = select(UploadFile).where(UploadFile.id == upload_file_id)
image_file = db.session.scalar(stmt)
if image_file is None:
continue
try:
@ -339,14 +336,14 @@ class IndexingRunner:
if dataset_document.data_source_type == "upload_file":
if not data_source_info or "upload_file_id" not in data_source_info:
raise ValueError("no upload file found")
file_detail = (
db.session.query(UploadFile).where(UploadFile.id == data_source_info["upload_file_id"]).one_or_none()
)
stmt = select(UploadFile).where(UploadFile.id == data_source_info["upload_file_id"])
file_detail = db.session.scalars(stmt).one_or_none()
if file_detail:
extract_setting = ExtractSetting(
datasource_type="upload_file", upload_file=file_detail, document_model=dataset_document.doc_form
datasource_type=DatasourceType.FILE.value,
upload_file=file_detail,
document_model=dataset_document.doc_form,
)
text_docs = index_processor.extract(extract_setting, process_rule_mode=process_rule["mode"])
elif dataset_document.data_source_type == "notion_import":
@ -357,7 +354,7 @@ class IndexingRunner:
):
raise ValueError("no notion import info found")
extract_setting = ExtractSetting(
datasource_type="notion_import",
datasource_type=DatasourceType.NOTION.value,
notion_info={
"credential_id": data_source_info["credential_id"],
"notion_workspace_id": data_source_info["notion_workspace_id"],
@ -378,7 +375,7 @@ class IndexingRunner:
):
raise ValueError("no website import info found")
extract_setting = ExtractSetting(
datasource_type="website_crawl",
datasource_type=DatasourceType.WEBSITE.value,
website_info={
"provider": data_source_info["provider"],
"job_id": data_source_info["job_id"],
@ -401,7 +398,6 @@ class IndexingRunner:
)
# replace doc id with document model id
text_docs = cast(list[Document], text_docs)
for text_doc in text_docs:
if text_doc.metadata is not None:
text_doc.metadata["document_id"] = dataset_document.id

View File

@ -58,11 +58,8 @@ class LLMGenerator:
prompts = [UserPromptMessage(content=prompt)]
with measure_time() as timer:
response = cast(
LLMResult,
model_instance.invoke_llm(
prompt_messages=list(prompts), model_parameters={"max_tokens": 500, "temperature": 1}, stream=False
),
response: LLMResult = model_instance.invoke_llm(
prompt_messages=list(prompts), model_parameters={"max_tokens": 500, "temperature": 1}, stream=False
)
answer = cast(str, response.message.content)
cleaned_answer = re.sub(r"^.*(\{.*\}).*$", r"\1", answer, flags=re.DOTALL)
@ -71,7 +68,7 @@ class LLMGenerator:
try:
result_dict = json.loads(cleaned_answer)
answer = result_dict["Your Output"]
except json.JSONDecodeError as e:
except json.JSONDecodeError:
logger.exception("Failed to generate name after answer, use query instead")
answer = query
name = answer.strip()
@ -115,13 +112,10 @@ class LLMGenerator:
prompt_messages = [UserPromptMessage(content=prompt)]
try:
response = cast(
LLMResult,
model_instance.invoke_llm(
prompt_messages=list(prompt_messages),
model_parameters={"max_tokens": 256, "temperature": 0},
stream=False,
),
response: LLMResult = model_instance.invoke_llm(
prompt_messages=list(prompt_messages),
model_parameters={"max_tokens": 256, "temperature": 0},
stream=False,
)
text_content = response.message.get_text_content()
@ -164,11 +158,8 @@ class LLMGenerator:
)
try:
response = cast(
LLMResult,
model_instance.invoke_llm(
prompt_messages=list(prompt_messages), model_parameters=model_parameters, stream=False
),
response: LLMResult = model_instance.invoke_llm(
prompt_messages=list(prompt_messages), model_parameters=model_parameters, stream=False
)
rule_config["prompt"] = cast(str, response.message.content)
@ -214,11 +205,8 @@ class LLMGenerator:
try:
try:
# the first step to generate the task prompt
prompt_content = cast(
LLMResult,
model_instance.invoke_llm(
prompt_messages=list(prompt_messages), model_parameters=model_parameters, stream=False
),
prompt_content: LLMResult = model_instance.invoke_llm(
prompt_messages=list(prompt_messages), model_parameters=model_parameters, stream=False
)
except InvokeError as e:
error = str(e)
@ -250,11 +238,8 @@ class LLMGenerator:
statement_messages = [UserPromptMessage(content=statement_generate_prompt)]
try:
parameter_content = cast(
LLMResult,
model_instance.invoke_llm(
prompt_messages=list(parameter_messages), model_parameters=model_parameters, stream=False
),
parameter_content: LLMResult = model_instance.invoke_llm(
prompt_messages=list(parameter_messages), model_parameters=model_parameters, stream=False
)
rule_config["variables"] = re.findall(r'"\s*([^"]+)\s*"', cast(str, parameter_content.message.content))
except InvokeError as e:
@ -262,11 +247,8 @@ class LLMGenerator:
error_step = "generate variables"
try:
statement_content = cast(
LLMResult,
model_instance.invoke_llm(
prompt_messages=list(statement_messages), model_parameters=model_parameters, stream=False
),
statement_content: LLMResult = model_instance.invoke_llm(
prompt_messages=list(statement_messages), model_parameters=model_parameters, stream=False
)
rule_config["opening_statement"] = cast(str, statement_content.message.content)
except InvokeError as e:
@ -309,11 +291,8 @@ class LLMGenerator:
prompt_messages = [UserPromptMessage(content=prompt)]
model_parameters = model_config.get("completion_params", {})
try:
response = cast(
LLMResult,
model_instance.invoke_llm(
prompt_messages=list(prompt_messages), model_parameters=model_parameters, stream=False
),
response: LLMResult = model_instance.invoke_llm(
prompt_messages=list(prompt_messages), model_parameters=model_parameters, stream=False
)
generated_code = cast(str, response.message.content)
@ -340,13 +319,10 @@ class LLMGenerator:
prompt_messages = [SystemPromptMessage(content=prompt), UserPromptMessage(content=query)]
response = cast(
LLMResult,
model_instance.invoke_llm(
prompt_messages=prompt_messages,
model_parameters={"temperature": 0.01, "max_tokens": 2000},
stream=False,
),
response: LLMResult = model_instance.invoke_llm(
prompt_messages=prompt_messages,
model_parameters={"temperature": 0.01, "max_tokens": 2000},
stream=False,
)
answer = cast(str, response.message.content)
@ -369,11 +345,8 @@ class LLMGenerator:
model_parameters = model_config.get("model_parameters", {})
try:
response = cast(
LLMResult,
model_instance.invoke_llm(
prompt_messages=list(prompt_messages), model_parameters=model_parameters, stream=False
),
response: LLMResult = model_instance.invoke_llm(
prompt_messages=list(prompt_messages), model_parameters=model_parameters, stream=False
)
raw_content = response.message.content
@ -560,11 +533,8 @@ class LLMGenerator:
model_parameters = {"temperature": 0.4}
try:
response = cast(
LLMResult,
model_instance.invoke_llm(
prompt_messages=list(prompt_messages), model_parameters=model_parameters, stream=False
),
response: LLMResult = model_instance.invoke_llm(
prompt_messages=list(prompt_messages), model_parameters=model_parameters, stream=False
)
generated_raw = cast(str, response.message.content)
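Every hunk in this file applies the same mechanical change: the typing.cast() wrapper around invoke_llm() becomes a variable annotation. A sketch of the pattern, assuming (as these call sites do) that stream=False always yields a non-streaming result:
# Before:
#   response = cast(LLMResult, model_instance.invoke_llm(..., stream=False))
# After: the annotation states the same expectation with less nesting and,
# like cast(), has no runtime effect.
response: LLMResult = model_instance.invoke_llm(
    prompt_messages=list(prompt_messages),
    model_parameters=model_parameters,
    stream=False,
)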

View File

@ -101,7 +101,7 @@ def handle_callback(state_key: str, authorization_code: str) -> OAuthCallbackSta
def check_support_resource_discovery(server_url: str) -> tuple[bool, str]:
"""Check if the server supports OAuth 2.0 Resource Discovery."""
b_scheme, b_netloc, b_path, b_params, b_query, b_fragment = urlparse(server_url, "", True)
b_scheme, b_netloc, b_path, _, b_query, b_fragment = urlparse(server_url, "", True)
url_for_resource_discovery = f"{b_scheme}://{b_netloc}/.well-known/oauth-protected-resource{b_path}"
if b_query:
url_for_resource_discovery += f"?{b_query}"
@ -117,7 +117,7 @@ def check_support_resource_discovery(server_url: str) -> tuple[bool, str]:
else:
return False, ""
return False, ""
except httpx.RequestError as e:
except httpx.RequestError:
# Resource discovery not supported; fall back to well-known OAuth metadata
return False, ""

View File

@ -2,7 +2,7 @@ import logging
from collections.abc import Callable
from contextlib import AbstractContextManager, ExitStack
from types import TracebackType
from typing import Any, Optional, cast
from typing import Any, Optional
from urllib.parse import urlparse
from core.mcp.client.sse_client import sse_client
@ -116,8 +116,7 @@ class MCPClient:
self._session_context = ClientSession(*streams)
self._session = self._exit_stack.enter_context(self._session_context)
session = cast(ClientSession, self._session)
session.initialize()
self._session.initialize()
return
except MCPAuthError:

View File

@ -2,7 +2,6 @@ from collections.abc import Sequence
from typing import Optional
from sqlalchemy import select
from sqlalchemy.orm import Session
from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
from core.file import file_manager
@ -33,12 +32,7 @@ class TokenBufferMemory:
self.model_instance = model_instance
def _build_prompt_message_with_files(
self,
message_files: Sequence[MessageFile],
text_content: str,
message: Message,
app_record,
is_user_message: bool,
self, message_files: list[MessageFile], text_content: str, message: Message, app_record, is_user_message: bool
) -> PromptMessage:
"""
Build prompt message with files.
@ -104,74 +98,74 @@ class TokenBufferMemory:
:param max_token_limit: max token limit
:param message_limit: message limit
"""
with Session(db.engine) as session:
app_record = self.conversation.app
app_record = self.conversation.app
# fetch limited messages, and return reversed
stmt = (
select(Message)
.where(Message.conversation_id == self.conversation.id)
.order_by(Message.created_at.desc())
)
# fetch limited messages, and return reversed
stmt = (
select(Message).where(Message.conversation_id == self.conversation.id).order_by(Message.created_at.desc())
)
if message_limit and message_limit > 0:
message_limit = min(message_limit, 500)
else:
message_limit = 500
if message_limit and message_limit > 0:
message_limit = min(message_limit, 500)
else:
message_limit = 500
stmt = stmt.limit(message_limit)
msg_limit_stmt = stmt.limit(message_limit)
messages = session.scalars(stmt).all()
messages = db.session.scalars(msg_limit_stmt).all()
# instead of all messages from the conversation, we only need to extract messages
# that belong to the thread of last message
thread_messages = extract_thread_messages(messages)
# instead of all messages from the conversation, we only need to extract messages
# that belong to the thread of last message
thread_messages = extract_thread_messages(messages)
# for newly created message, its answer is temporarily empty, we don't need to add it to memory
if thread_messages and not thread_messages[0].answer and thread_messages[0].answer_tokens == 0:
thread_messages.pop(0)
# for newly created message, its answer is temporarily empty, we don't need to add it to memory
if thread_messages and not thread_messages[0].answer and thread_messages[0].answer_tokens == 0:
thread_messages.pop(0)
messages = list(reversed(thread_messages))
messages = list(reversed(thread_messages))
prompt_messages: list[PromptMessage] = []
for message in messages:
# Process user message with files
user_file_query = select(MessageFile).where(
prompt_messages: list[PromptMessage] = []
for message in messages:
# Process user message with files
user_files = (
db.session.query(MessageFile)
.where(
MessageFile.message_id == message.id,
(MessageFile.belongs_to == "user") | (MessageFile.belongs_to.is_(None)),
)
user_files = session.scalars(user_file_query).all()
.all()
)
if user_files:
user_prompt_message = self._build_prompt_message_with_files(
message_files=user_files,
text_content=message.query,
message=message,
app_record=app_record,
is_user_message=True,
)
prompt_messages.append(user_prompt_message)
else:
prompt_messages.append(UserPromptMessage(content=message.query))
# Process assistant message with files
assistant_file_query = select(MessageFile).where(
MessageFile.message_id == message.id, MessageFile.belongs_to == "assistant"
if user_files:
user_prompt_message = self._build_prompt_message_with_files(
message_files=user_files,
text_content=message.query,
message=message,
app_record=app_record,
is_user_message=True,
)
assistant_files = session.scalars(assistant_file_query).all()
prompt_messages.append(user_prompt_message)
else:
prompt_messages.append(UserPromptMessage(content=message.query))
if assistant_files:
assistant_prompt_message = self._build_prompt_message_with_files(
message_files=assistant_files,
text_content=message.answer,
message=message,
app_record=app_record,
is_user_message=False,
)
prompt_messages.append(assistant_prompt_message)
else:
prompt_messages.append(AssistantPromptMessage(content=message.answer))
# Process assistant message with files
assistant_files = (
db.session.query(MessageFile)
.where(MessageFile.message_id == message.id, MessageFile.belongs_to == "assistant")
.all()
)
if assistant_files:
assistant_prompt_message = self._build_prompt_message_with_files(
message_files=assistant_files,
text_content=message.answer,
message=message,
app_record=app_record,
is_user_message=False,
)
prompt_messages.append(assistant_prompt_message)
else:
prompt_messages.append(AssistantPromptMessage(content=message.answer))
if not prompt_messages:
return []

View File

@ -158,8 +158,6 @@ class ModelInstance:
"""
if not isinstance(self.model_type_instance, LargeLanguageModel):
raise Exception("Model type instance is not LargeLanguageModel")
self.model_type_instance = cast(LargeLanguageModel, self.model_type_instance)
return cast(
Union[LLMResult, Generator],
self._round_robin_invoke(
@ -188,8 +186,6 @@ class ModelInstance:
"""
if not isinstance(self.model_type_instance, LargeLanguageModel):
raise Exception("Model type instance is not LargeLanguageModel")
self.model_type_instance = cast(LargeLanguageModel, self.model_type_instance)
return cast(
int,
self._round_robin_invoke(
@ -214,8 +210,6 @@ class ModelInstance:
"""
if not isinstance(self.model_type_instance, TextEmbeddingModel):
raise Exception("Model type instance is not TextEmbeddingModel")
self.model_type_instance = cast(TextEmbeddingModel, self.model_type_instance)
return cast(
TextEmbeddingResult,
self._round_robin_invoke(
@ -237,8 +231,6 @@ class ModelInstance:
"""
if not isinstance(self.model_type_instance, TextEmbeddingModel):
raise Exception("Model type instance is not TextEmbeddingModel")
self.model_type_instance = cast(TextEmbeddingModel, self.model_type_instance)
return cast(
list[int],
self._round_robin_invoke(
@ -269,8 +261,6 @@ class ModelInstance:
"""
if not isinstance(self.model_type_instance, RerankModel):
raise Exception("Model type instance is not RerankModel")
self.model_type_instance = cast(RerankModel, self.model_type_instance)
return cast(
RerankResult,
self._round_robin_invoke(
@ -295,8 +285,6 @@ class ModelInstance:
"""
if not isinstance(self.model_type_instance, ModerationModel):
raise Exception("Model type instance is not ModerationModel")
self.model_type_instance = cast(ModerationModel, self.model_type_instance)
return cast(
bool,
self._round_robin_invoke(
@ -318,8 +306,6 @@ class ModelInstance:
"""
if not isinstance(self.model_type_instance, Speech2TextModel):
raise Exception("Model type instance is not Speech2TextModel")
self.model_type_instance = cast(Speech2TextModel, self.model_type_instance)
return cast(
str,
self._round_robin_invoke(
@ -343,8 +329,6 @@ class ModelInstance:
"""
if not isinstance(self.model_type_instance, TTSModel):
raise Exception("Model type instance is not TTSModel")
self.model_type_instance = cast(TTSModel, self.model_type_instance)
return cast(
Iterable[bytes],
self._round_robin_invoke(
@ -404,8 +388,6 @@ class ModelInstance:
"""
if not isinstance(self.model_type_instance, TTSModel):
raise Exception("Model type instance is not TTSModel")
self.model_type_instance = cast(TTSModel, self.model_type_instance)
return self.model_type_instance.get_tts_model_voices(
model=self.model, credentials=self.credentials, language=language
)
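Each removed cast-and-reassign pair was redundant for the same reason: after the isinstance() guard, type checkers already narrow self.model_type_instance inside the method, so the extra lines changed nothing at runtime or check time. Schematically:
if not isinstance(self.model_type_instance, LargeLanguageModel):
    raise Exception("Model type instance is not LargeLanguageModel")
# From here until the next assignment, checkers treat
# self.model_type_instance as LargeLanguageModel; no cast() needed.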

View File

@ -1,6 +1,7 @@
from typing import Optional
from pydantic import BaseModel, Field
from sqlalchemy import select
from core.extension.api_based_extension_requestor import APIBasedExtensionPoint, APIBasedExtensionRequestor
from core.helper.encrypter import decrypt_token
@ -87,10 +88,9 @@ class ApiModeration(Moderation):
@staticmethod
def _get_api_based_extension(tenant_id: str, api_based_extension_id: str) -> Optional[APIBasedExtension]:
extension = (
db.session.query(APIBasedExtension)
.where(APIBasedExtension.tenant_id == tenant_id, APIBasedExtension.id == api_based_extension_id)
.first()
stmt = select(APIBasedExtension).where(
APIBasedExtension.tenant_id == tenant_id, APIBasedExtension.id == api_based_extension_id
)
extension = db.session.scalar(stmt)
return extension

View File

@ -20,7 +20,6 @@ class ModerationFactory:
:param config: the form config data
:return:
"""
code_based_extension.validate_form_schema(ExtensionModule.MODERATION, name, config)
extension_class = code_based_extension.extension_class(ExtensionModule.MODERATION, name)
# FIXME: mypy error, try to fix it instead of using type: ignore
extension_class.validate_config(tenant_id, config) # type: ignore

View File

@ -135,7 +135,7 @@ class OutputModeration(BaseModel):
result: ModerationOutputsResult = moderation_factory.moderation_for_outputs(moderation_buffer)
return result
except Exception as e:
except Exception:
logger.exception("Moderation Output error, app_id: %s", app_id)
return None

View File

@ -5,6 +5,7 @@ from typing import Optional
from urllib.parse import urljoin
from opentelemetry.trace import Link, Status, StatusCode
from sqlalchemy import select
from sqlalchemy.orm import Session, sessionmaker
from core.ops.aliyun_trace.data_exporter.traceclient import (
@ -260,15 +261,15 @@ class AliyunDataTrace(BaseTraceInstance):
app_id = trace_info.metadata.get("app_id")
if not app_id:
raise ValueError("No app_id found in trace_info metadata")
app = session.query(App).where(App.id == app_id).first()
app_stmt = select(App).where(App.id == app_id)
app = session.scalar(app_stmt)
if not app:
raise ValueError(f"App with id {app_id} not found")
if not app.created_by:
raise ValueError(f"App with id {app_id} has no creator (created_by is None)")
service_account = session.query(Account).where(Account.id == app.created_by).first()
account_stmt = select(Account).where(Account.id == app.created_by)
service_account = session.scalar(account_stmt)
if not service_account:
raise ValueError(f"Creator account with id {app.created_by} not found for app {app_id}")
current_tenant = (

View File

@ -72,7 +72,7 @@ class TraceClient:
else:
logger.debug("AliyunTrace API check failed: Unexpected status code: %s", response.status_code)
return False
except requests.exceptions.RequestException as e:
except requests.RequestException as e:
logger.debug("AliyunTrace API check failed: %s", str(e))
raise ValueError(f"AliyunTrace API check failed: {str(e)}")

View File

@ -1,5 +1,6 @@
from abc import ABC, abstractmethod
from sqlalchemy import select
from sqlalchemy.orm import Session
from core.ops.entities.config_entity import BaseTracingConfig
@ -44,14 +45,15 @@ class BaseTraceInstance(ABC):
"""
with Session(db.engine, expire_on_commit=False) as session:
# Get the app to find its creator
app = session.query(App).where(App.id == app_id).first()
app_stmt = select(App).where(App.id == app_id)
app = session.scalar(app_stmt)
if not app:
raise ValueError(f"App with id {app_id} not found")
if not app.created_by:
raise ValueError(f"App with id {app_id} has no creator (created_by is None)")
service_account = session.query(Account).where(Account.id == app.created_by).first()
account_stmt = select(Account).where(Account.id == app.created_by)
service_account = session.scalar(account_stmt)
if not service_account:
raise ValueError(f"Creator account with id {app.created_by} not found for app {app_id}")

View File

@ -228,9 +228,9 @@ class OpsTraceManager:
if not trace_config_data:
return None
# decrypt_token
app = db.session.query(App).where(App.id == app_id).first()
stmt = select(App).where(App.id == app_id)
app = db.session.scalar(stmt)
if not app:
raise ValueError("App not found")
@ -297,20 +297,19 @@ class OpsTraceManager:
@classmethod
def get_app_config_through_message_id(cls, message_id: str):
app_model_config = None
message_data = db.session.query(Message).where(Message.id == message_id).first()
message_stmt = select(Message).where(Message.id == message_id)
message_data = db.session.scalar(message_stmt)
if not message_data:
return None
conversation_id = message_data.conversation_id
conversation_data = db.session.query(Conversation).where(Conversation.id == conversation_id).first()
conversation_stmt = select(Conversation).where(Conversation.id == conversation_id)
conversation_data = db.session.scalar(conversation_stmt)
if not conversation_data:
return None
if conversation_data.app_model_config_id:
app_model_config = (
db.session.query(AppModelConfig)
.where(AppModelConfig.id == conversation_data.app_model_config_id)
.first()
)
config_stmt = select(AppModelConfig).where(AppModelConfig.id == conversation_data.app_model_config_id)
app_model_config = db.session.scalar(config_stmt)
elif conversation_data.app_model_config_id is None and conversation_data.override_model_configs:
app_model_config = conversation_data.override_model_configs
@ -852,7 +851,7 @@ class TraceQueueManager:
if self.trace_instance:
trace_task.app_id = self.app_id
trace_manager_queue.put(trace_task)
except Exception as e:
except Exception:
logger.exception("Error adding trace task, trace_type %s", trace_task.trace_type)
finally:
self.start_timer()
@ -871,7 +870,7 @@ class TraceQueueManager:
tasks = self.collect_tasks()
if tasks:
self.send_to_celery(tasks)
except Exception as e:
except Exception:
logger.exception("Error processing trace tasks")
def start_timer(self):

View File

@ -1,6 +1,8 @@
from collections.abc import Generator, Mapping
from typing import Optional, Union
from sqlalchemy import select
from controllers.service_api.wraps import create_or_update_end_user_for_user_id
from core.app.app_config.common.parameters_mapping import get_parameters_from_feature_dict
from core.app.apps.advanced_chat.app_generator import AdvancedChatAppGenerator
@ -191,10 +193,11 @@ class PluginAppBackwardsInvocation(BaseBackwardsInvocation):
"""
get the user by user id
"""
user = db.session.query(EndUser).where(EndUser.id == user_id).first()
stmt = select(EndUser).where(EndUser.id == user_id)
user = db.session.scalar(stmt)
if not user:
user = db.session.query(Account).where(Account.id == user_id).first()
stmt = select(Account).where(Account.id == user_id)
user = db.session.scalar(stmt)
if not user:
raise ValueError("user not found")

View File

@ -64,7 +64,7 @@ class BasePluginClient:
response = requests.request(
method=method, url=str(url), headers=headers, data=data, params=params, stream=stream, files=files
)
except requests.exceptions.ConnectionError:
except requests.ConnectionError:
logger.exception("Request to Plugin Daemon Service failed")
raise PluginDaemonInnerError(code=-500, message="Request to Plugin Daemon Service failed")

View File

@ -1,6 +1,6 @@
from collections.abc import Generator
from dataclasses import dataclass, field
from typing import TypeVar, Union, cast
from typing import TypeVar, Union
from core.agent.entities import AgentInvokeMessage
from core.tools.entities.tool_entities import ToolInvokeMessage
@ -85,7 +85,7 @@ def merge_blob_chunks(
message=ToolInvokeMessage.BlobMessage(blob=files[chunk_id].data[: files[chunk_id].bytes_written]),
meta=resp.meta,
)
yield cast(MessageType, merged_message)
yield merged_message
# Clean up the buffer
del files[chunk_id]
else:

View File

@ -87,7 +87,6 @@ class PromptMessageUtil:
if isinstance(prompt_message.content, list):
for content in prompt_message.content:
if content.type == PromptMessageContentType.TEXT:
content = cast(TextPromptMessageContent, content)
text += content.data
else:
content = cast(ImagePromptMessageContent, content)

View File

@ -1,6 +1,7 @@
import contextlib
import json
from collections import defaultdict
from collections.abc import Sequence
from json import JSONDecodeError
from typing import Any, Optional, cast
@ -22,6 +23,7 @@ from core.entities.provider_entities import (
QuotaConfiguration,
QuotaUnit,
SystemConfiguration,
UnaddedModelConfiguration,
)
from core.helper import encrypter
from core.helper.model_provider_cache import ProviderCredentialsCache, ProviderCredentialsCacheType
@ -154,8 +156,8 @@ class ProviderManager:
for provider_entity in provider_entities:
# handle include, exclude
if is_filtered(
include_set=cast(set[str], dify_config.POSITION_PROVIDER_INCLUDES_SET),
exclude_set=cast(set[str], dify_config.POSITION_PROVIDER_EXCLUDES_SET),
include_set=dify_config.POSITION_PROVIDER_INCLUDES_SET,
exclude_set=dify_config.POSITION_PROVIDER_EXCLUDES_SET,
data=provider_entity,
name_func=lambda x: x.provider,
):
@ -276,15 +278,11 @@ class ProviderManager:
:param model_type: model type
:return:
"""
# Get the corresponding TenantDefaultModel record
default_model = (
db.session.query(TenantDefaultModel)
.where(
TenantDefaultModel.tenant_id == tenant_id,
TenantDefaultModel.model_type == model_type.to_origin_model_type(),
)
.first()
stmt = select(TenantDefaultModel).where(
TenantDefaultModel.tenant_id == tenant_id,
TenantDefaultModel.model_type == model_type.to_origin_model_type(),
)
default_model = db.session.scalar(stmt)
# If it does not exist, get the first available provider model from get_configurations
# and update the TenantDefaultModel record
@ -367,16 +365,11 @@ class ProviderManager:
model_names = [model.model for model in available_models]
if model not in model_names:
raise ValueError(f"Model {model} does not exist.")
# Get the list of available models from get_configurations and check if it is LLM
default_model = (
db.session.query(TenantDefaultModel)
.where(
TenantDefaultModel.tenant_id == tenant_id,
TenantDefaultModel.model_type == model_type.to_origin_model_type(),
)
.first()
stmt = select(TenantDefaultModel).where(
TenantDefaultModel.tenant_id == tenant_id,
TenantDefaultModel.model_type == model_type.to_origin_model_type(),
)
default_model = db.session.scalar(stmt)
# create or update TenantDefaultModel record
if default_model:
@ -546,6 +539,23 @@ class ProviderManager:
for credential in available_credentials
]
@staticmethod
def get_credentials_from_provider_model(tenant_id: str, provider_name: str) -> Sequence[ProviderModelCredential]:
"""
Get all credential records from ProviderModelCredential by provider_name
:param tenant_id: workspace id
:param provider_name: provider name
"""
with Session(db.engine, expire_on_commit=False) as session:
stmt = select(ProviderModelCredential).where(
ProviderModelCredential.tenant_id == tenant_id, ProviderModelCredential.provider_name == provider_name
)
all_credentials = session.scalars(stmt).all()
return all_credentials
@staticmethod
def _init_trial_provider_records(
tenant_id: str, provider_name_to_provider_records_dict: dict[str, list[Provider]]
@ -598,16 +608,13 @@ class ProviderManager:
provider_name_to_provider_records_dict[provider_name].append(new_provider_record)
except IntegrityError:
db.session.rollback()
existed_provider_record = (
db.session.query(Provider)
.where(
Provider.tenant_id == tenant_id,
Provider.provider_name == ModelProviderID(provider_name).provider_name,
Provider.provider_type == ProviderType.SYSTEM.value,
Provider.quota_type == ProviderQuotaType.TRIAL.value,
)
.first()
stmt = select(Provider).where(
Provider.tenant_id == tenant_id,
Provider.provider_name == ModelProviderID(provider_name).provider_name,
Provider.provider_type == ProviderType.SYSTEM.value,
Provider.quota_type == ProviderQuotaType.TRIAL.value,
)
existed_provider_record = db.session.scalar(stmt)
if not existed_provider_record:
continue
@ -635,6 +642,44 @@ class ProviderManager:
:param provider_model_records: provider model records
:return:
"""
# Get custom provider configuration
custom_provider_configuration = self._get_custom_provider_configuration(
tenant_id, provider_entity, provider_records
)
# Get all model credentials once
all_model_credentials = self.get_credentials_from_provider_model(tenant_id, provider_entity.provider)
# Get custom models which have not been added to the model list yet
unadded_models = self._get_can_added_models(provider_model_records, all_model_credentials)
# Get custom model configurations
custom_model_configurations = self._get_custom_model_configurations(
tenant_id, provider_entity, provider_model_records, unadded_models, all_model_credentials
)
can_added_models = [
UnaddedModelConfiguration(model=model["model"], model_type=model["model_type"]) for model in unadded_models
]
return CustomConfiguration(
provider=custom_provider_configuration,
models=custom_model_configurations,
can_added_models=can_added_models,
)
def _get_custom_provider_configuration(
self, tenant_id: str, provider_entity: ProviderEntity, provider_records: list[Provider]
) -> CustomProviderConfiguration | None:
"""Get custom provider configuration."""
# Find custom provider record (non-system)
custom_provider_record = next(
(record for record in provider_records if record.provider_type != ProviderType.SYSTEM.value), None
)
if not custom_provider_record:
return None
# Get provider credential secret variables
provider_credential_secret_variables = self._extract_secret_variables(
provider_entity.provider_credential_schema.credential_form_schemas
@ -642,113 +687,98 @@ class ProviderManager:
else []
)
# Get custom provider record
custom_provider_record = None
for provider_record in provider_records:
if provider_record.provider_type == ProviderType.SYSTEM.value:
continue
# Get and decrypt provider credentials
provider_credentials = self._get_and_decrypt_credentials(
tenant_id=tenant_id,
record_id=custom_provider_record.id,
encrypted_config=custom_provider_record.encrypted_config,
secret_variables=provider_credential_secret_variables,
cache_type=ProviderCredentialsCacheType.PROVIDER,
is_provider=True,
)
custom_provider_record = provider_record
return CustomProviderConfiguration(
credentials=provider_credentials,
current_credential_name=custom_provider_record.credential_name,
current_credential_id=custom_provider_record.credential_id,
available_credentials=self.get_provider_available_credentials(
tenant_id, custom_provider_record.provider_name
),
)
# Get custom provider credentials
custom_provider_configuration = None
if custom_provider_record:
provider_credentials_cache = ProviderCredentialsCache(
tenant_id=tenant_id,
identity_id=custom_provider_record.id,
cache_type=ProviderCredentialsCacheType.PROVIDER,
)
def _get_can_added_models(
self, provider_model_records: list[ProviderModel], all_model_credentials: Sequence[ProviderModelCredential]
) -> list[dict]:
"""Get the custom models and credentials from enterprise version which haven't add to the model list"""
existing_model_set = {(record.model_name, record.model_type) for record in provider_model_records}
# Get cached provider credentials
cached_provider_credentials = provider_credentials_cache.get()
# Get credentials of custom models that have not been added yet
not_added_custom_models_credentials = [
credential
for credential in all_model_credentials
if (credential.model_name, credential.model_type) not in existing_model_set
]
if not cached_provider_credentials:
try:
# fix origin data
if custom_provider_record.encrypted_config is None:
provider_credentials = {}
elif not custom_provider_record.encrypted_config.startswith("{"):
provider_credentials = {"openai_api_key": custom_provider_record.encrypted_config}
else:
provider_credentials = json.loads(custom_provider_record.encrypted_config)
except JSONDecodeError:
provider_credentials = {}
# Group credentials by model
model_to_credentials = defaultdict(list)
for credential in not_added_custom_models_credentials:
model_to_credentials[(credential.model_name, credential.model_type)].append(credential)
# Get decoding rsa key and cipher for decrypting credentials
if self.decoding_rsa_key is None or self.decoding_cipher_rsa is None:
self.decoding_rsa_key, self.decoding_cipher_rsa = encrypter.get_decrypt_decoding(tenant_id)
return [
{
"model": model_key[0],
"model_type": ModelType.value_of(model_key[1]),
"available_model_credentials": [
CredentialConfiguration(credential_id=cred.id, credential_name=cred.credential_name)
for cred in creds
],
}
for model_key, creds in model_to_credentials.items()
]
for variable in provider_credential_secret_variables:
if variable in provider_credentials:
with contextlib.suppress(ValueError):
provider_credentials[variable] = encrypter.decrypt_token_with_decoding(
provider_credentials.get(variable) or "", # type: ignore
self.decoding_rsa_key,
self.decoding_cipher_rsa,
)
# cache provider credentials
provider_credentials_cache.set(credentials=provider_credentials)
else:
provider_credentials = cached_provider_credentials
custom_provider_configuration = CustomProviderConfiguration(
credentials=provider_credentials,
current_credential_name=custom_provider_record.credential_name,
current_credential_id=custom_provider_record.credential_id,
available_credentials=self.get_provider_available_credentials(
tenant_id, custom_provider_record.provider_name
),
)
# Get provider model credential secret variables
def _get_custom_model_configurations(
self,
tenant_id: str,
provider_entity: ProviderEntity,
provider_model_records: list[ProviderModel],
can_added_models: list[dict],
all_model_credentials: Sequence[ProviderModelCredential],
) -> list[CustomModelConfiguration]:
"""Get custom model configurations."""
# Get model credential secret variables
model_credential_secret_variables = self._extract_secret_variables(
provider_entity.model_credential_schema.credential_form_schemas
if provider_entity.model_credential_schema
else []
)
# Get custom provider model credentials
# Create credentials lookup for efficient access
credentials_map = defaultdict(list)
for credential in all_model_credentials:
credentials_map[(credential.model_name, credential.model_type)].append(credential)
custom_model_configurations = []
# Process existing model records
for provider_model_record in provider_model_records:
available_model_credentials = self.get_provider_model_available_credentials(
tenant_id,
provider_model_record.provider_name,
provider_model_record.model_name,
provider_model_record.model_type,
# Use pre-fetched credentials instead of individual database calls
available_model_credentials = [
CredentialConfiguration(credential_id=cred.id, credential_name=cred.credential_name)
for cred in credentials_map.get(
(provider_model_record.model_name, provider_model_record.model_type), []
)
]
# Get and decrypt model credentials
provider_model_credentials = self._get_and_decrypt_credentials(
tenant_id=tenant_id,
record_id=provider_model_record.id,
encrypted_config=provider_model_record.encrypted_config,
secret_variables=model_credential_secret_variables,
cache_type=ProviderCredentialsCacheType.MODEL,
is_provider=False,
)
provider_model_credentials_cache = ProviderCredentialsCache(
tenant_id=tenant_id, identity_id=provider_model_record.id, cache_type=ProviderCredentialsCacheType.MODEL
)
# Get cached provider model credentials
cached_provider_model_credentials = provider_model_credentials_cache.get()
if not cached_provider_model_credentials and provider_model_record.encrypted_config:
try:
provider_model_credentials = json.loads(provider_model_record.encrypted_config)
except JSONDecodeError:
continue
# Get decoding rsa key and cipher for decrypting credentials
if self.decoding_rsa_key is None or self.decoding_cipher_rsa is None:
self.decoding_rsa_key, self.decoding_cipher_rsa = encrypter.get_decrypt_decoding(tenant_id)
for variable in model_credential_secret_variables:
if variable in provider_model_credentials:
with contextlib.suppress(ValueError):
provider_model_credentials[variable] = encrypter.decrypt_token_with_decoding(
provider_model_credentials.get(variable),
self.decoding_rsa_key,
self.decoding_cipher_rsa,
)
# cache provider model credentials
provider_model_credentials_cache.set(credentials=provider_model_credentials)
else:
provider_model_credentials = cached_provider_model_credentials
custom_model_configurations.append(
CustomModelConfiguration(
model=provider_model_record.model_name,
@ -760,7 +790,71 @@ class ProviderManager:
)
)
return CustomConfiguration(provider=custom_provider_configuration, models=custom_model_configurations)
# Add models that can be added
for model in can_added_models:
custom_model_configurations.append(
CustomModelConfiguration(
model=model["model"],
model_type=model["model_type"],
credentials=None,
current_credential_id=None,
current_credential_name=None,
available_model_credentials=model["available_model_credentials"],
unadded_to_model_list=True,
)
)
return custom_model_configurations
def _get_and_decrypt_credentials(
self,
tenant_id: str,
record_id: str,
encrypted_config: str | None,
secret_variables: list[str],
cache_type: ProviderCredentialsCacheType,
is_provider: bool = False,
) -> dict:
"""Get and decrypt credentials with caching."""
credentials_cache = ProviderCredentialsCache(
tenant_id=tenant_id,
identity_id=record_id,
cache_type=cache_type,
)
# Try to get from cache first
cached_credentials = credentials_cache.get()
if cached_credentials:
return cached_credentials
# Parse encrypted config
if not encrypted_config:
return {}
if is_provider and not encrypted_config.startswith("{"):
return {"openai_api_key": encrypted_config}
try:
credentials = cast(dict, json.loads(encrypted_config))
except JSONDecodeError:
return {}
# Decrypt secret variables
if self.decoding_rsa_key is None or self.decoding_cipher_rsa is None:
self.decoding_rsa_key, self.decoding_cipher_rsa = encrypter.get_decrypt_decoding(tenant_id)
for variable in secret_variables:
if variable in credentials:
with contextlib.suppress(ValueError):
credentials[variable] = encrypter.decrypt_token_with_decoding(
credentials.get(variable) or "",
self.decoding_rsa_key,
self.decoding_cipher_rsa,
)
# Cache the decrypted credentials
credentials_cache.set(credentials=credentials)
return credentials
def _to_system_configuration(
self, tenant_id: str, provider_entity: ProviderEntity, provider_records: list[Provider]
@ -968,18 +1062,6 @@ class ProviderManager:
load_balancing_model_config.model_name == provider_model_setting.model_name
and load_balancing_model_config.model_type == provider_model_setting.model_type
):
if load_balancing_model_config.name == "__delete__":
# to calculate current model whether has invalidate lb configs
load_balancing_configs.append(
ModelLoadBalancingConfiguration(
id=load_balancing_model_config.id,
name=load_balancing_model_config.name,
credentials={},
credential_source_type=load_balancing_model_config.credential_source_type,
)
)
continue
if not load_balancing_model_config.enabled:
continue
@ -1045,6 +1127,7 @@ class ProviderManager:
model=provider_model_setting.model_name,
model_type=ModelType.value_of(provider_model_setting.model_type),
enabled=provider_model_setting.enabled,
load_balancing_enabled=provider_model_setting.load_balancing_enabled,
load_balancing_configs=load_balancing_configs if len(load_balancing_configs) > 1 else [],
)
)
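The (model_name, model_type) grouping idiom used twice above, shown in isolation with an illustrative lookup:
from collections import defaultdict

credentials_map = defaultdict(list)
for credential in all_model_credentials:
    credentials_map[(credential.model_name, credential.model_type)].append(credential)

# .get() with a default keeps lookups safe for models without credentials:
creds = credentials_map.get(("gpt-4o", "llm"), [])  # illustrative key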

View File

@ -3,6 +3,7 @@ from typing import Any, Optional
import orjson
from pydantic import BaseModel
from sqlalchemy import select
from configs import dify_config
from core.rag.datasource.keyword.jieba.jieba_keyword_table_handler import JiebaKeywordTableHandler
@ -212,11 +213,10 @@ class Jieba(BaseKeyword):
return sorted_chunk_indices[:k]
def _update_segment_keywords(self, dataset_id: str, node_id: str, keywords: list[str]):
document_segment = (
db.session.query(DocumentSegment)
.where(DocumentSegment.dataset_id == dataset_id, DocumentSegment.index_node_id == node_id)
.first()
stmt = select(DocumentSegment).where(
DocumentSegment.dataset_id == dataset_id, DocumentSegment.index_node_id == node_id
)
document_segment = db.session.scalar(stmt)
if document_segment:
document_segment.keywords = keywords
db.session.add(document_segment)

View File

@ -3,6 +3,7 @@ from concurrent.futures import ThreadPoolExecutor
from typing import Optional
from flask import Flask, current_app
from sqlalchemy import select
from sqlalchemy.orm import Session, load_only
from configs import dify_config
@ -24,7 +25,7 @@ default_retrieval_model = {
"search_method": RetrievalMethod.SEMANTIC_SEARCH.value,
"reranking_enable": False,
"reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""},
"top_k": 2,
"top_k": 4,
"score_threshold_enabled": False,
}
@ -127,7 +128,8 @@ class RetrievalService:
external_retrieval_model: Optional[dict] = None,
metadata_filtering_conditions: Optional[dict] = None,
):
dataset = db.session.query(Dataset).where(Dataset.id == dataset_id).first()
stmt = select(Dataset).where(Dataset.id == dataset_id)
dataset = db.session.scalar(stmt)
if not dataset:
return []
metadata_condition = (
@ -316,10 +318,8 @@ class RetrievalService:
if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX:
# Handle parent-child documents
child_index_node_id = document.metadata.get("doc_id")
child_chunk = (
db.session.query(ChildChunk).where(ChildChunk.index_node_id == child_index_node_id).first()
)
child_chunk_stmt = select(ChildChunk).where(ChildChunk.index_node_id == child_index_node_id)
child_chunk = db.session.scalar(child_chunk_stmt)
if not child_chunk:
continue
@ -378,17 +378,13 @@ class RetrievalService:
index_node_id = document.metadata.get("doc_id")
if not index_node_id:
continue
segment = (
db.session.query(DocumentSegment)
.where(
DocumentSegment.dataset_id == dataset_document.dataset_id,
DocumentSegment.enabled == True,
DocumentSegment.status == "completed",
DocumentSegment.index_node_id == index_node_id,
)
.first()
document_segment_stmt = select(DocumentSegment).where(
DocumentSegment.dataset_id == dataset_document.dataset_id,
DocumentSegment.enabled == True,
DocumentSegment.status == "completed",
DocumentSegment.index_node_id == index_node_id,
)
segment = db.session.scalar(document_segment_stmt)
if not segment:
continue

View File

@ -192,8 +192,8 @@ class AnalyticdbVectorOpenAPI:
collection=self._collection_name,
metrics=self.config.metrics,
include_values=True,
vector=None,
content=None,
vector=None, # ty: ignore [invalid-argument-type]
content=None, # ty: ignore [invalid-argument-type]
top_k=1,
filter=f"ref_doc_id='{id}'",
)
@ -211,7 +211,7 @@ class AnalyticdbVectorOpenAPI:
namespace=self.config.namespace,
namespace_password=self.config.namespace_password,
collection=self._collection_name,
collection_data=None,
collection_data=None, # ty: ignore [invalid-argument-type]
collection_data_filter=f"ref_doc_id IN {ids_str}",
)
self._client.delete_collection_data(request)
@ -225,7 +225,7 @@ class AnalyticdbVectorOpenAPI:
namespace=self.config.namespace,
namespace_password=self.config.namespace_password,
collection=self._collection_name,
collection_data=None,
collection_data=None, # ty: ignore [invalid-argument-type]
collection_data_filter=f"metadata_ ->> '{key}' = '{value}'",
)
self._client.delete_collection_data(request)
@ -249,14 +249,14 @@ class AnalyticdbVectorOpenAPI:
include_values=kwargs.pop("include_values", True),
metrics=self.config.metrics,
vector=query_vector,
content=None,
content=None, # ty: ignore [invalid-argument-type]
top_k=kwargs.get("top_k", 4),
filter=where_clause,
)
response = self._client.query_collection_data(request)
documents = []
for match in response.body.matches.match:
if match.score > score_threshold:
if match.score >= score_threshold:
metadata = json.loads(match.metadata.get("metadata_"))
metadata["score"] = match.score
doc = Document(
@ -285,7 +285,7 @@ class AnalyticdbVectorOpenAPI:
collection=self._collection_name,
include_values=kwargs.pop("include_values", True),
metrics=self.config.metrics,
vector=None,
vector=None, # ty: ignore [invalid-argument-type]
content=query,
top_k=kwargs.get("top_k", 4),
filter=where_clause,
@ -293,7 +293,7 @@ class AnalyticdbVectorOpenAPI:
response = self._client.query_collection_data(request)
documents = []
for match in response.body.matches.match:
if match.score > score_threshold:
if match.score >= score_threshold:
metadata = json.loads(match.metadata.get("metadata_"))
metadata["score"] = match.score
doc = Document(

View File

@ -228,8 +228,8 @@ class AnalyticdbVectorBySql:
)
documents = []
for record in cur:
id, vector, score, page_content, metadata = record
if score > score_threshold:
_, vector, score, page_content, metadata = record
if score >= score_threshold:
metadata["score"] = score
doc = Document(
page_content=page_content,
@@ -260,7 +260,7 @@ class AnalyticdbVectorBySql:
                 )
                 documents = []
                 for record in cur:
-                    id, vector, page_content, metadata, score = record
+                    _, vector, page_content, metadata, score = record
                     metadata["score"] = score
                     doc = Document(
                         page_content=page_content,

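Note on the renamed unpack targets above: binding the unused first column to _ keeps the positional row unpacking intact while satisfying unused-variable lint rules, and it stops shadowing the id builtin. The shape, as in both hunks:

    for record in cur:
        # The id column is deliberately discarded into _; the remaining
        # columns unpack positionally in the order the SELECT returns them.
        _, vector, score, page_content, metadata = record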
View File

@@ -157,7 +157,7 @@ class BaiduVector(BaseVector):
             if meta is not None:
                 meta = json.loads(meta)
             score = row.get("score", 0.0)
-            if score > score_threshold:
+            if score >= score_threshold:
                 meta["score"] = score
                 doc = Document(page_content=row_data.get(self.field_text), metadata=meta)
                 docs.append(doc)

View File

@@ -120,7 +120,7 @@ class ChromaVector(BaseVector):
                 distance = distances[index]
                 metadata = dict(metadatas[index])
                 score = 1 - distance
-                if score > score_threshold:
+                if score >= score_threshold:
                     metadata["score"] = score
                     doc = Document(
                         page_content=documents[index],

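Note on score = 1 - distance above: Chroma returns distances, where smaller means closer, while the threshold filter works on similarity scores, where larger means better, so the code inverts before comparing. A sketch, assuming a cosine-style metric in which similarity is one minus distance:

    distance = 0.25               # returned by the store; smaller = closer
    score = 1 - distance          # similarity-style score; larger = better
    if score >= score_threshold:  # inclusive, matching the rest of this commit
        metadata["score"] = score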
View File

@@ -12,7 +12,7 @@ import clickzetta  # type: ignore
 from pydantic import BaseModel, model_validator

 if TYPE_CHECKING:
-    from clickzetta import Connection
+    from clickzetta.connector.v0.connection import Connection  # type: ignore

 from configs import dify_config
 from core.rag.datasource.vdb.field import Field
@@ -701,7 +701,7 @@ class ClickzettaVector(BaseVector):
                     len(data_rows),
                     vector_dimension,
                 )
-            except (RuntimeError, ValueError, TypeError, ConnectionError) as e:
+            except (RuntimeError, ValueError, TypeError, ConnectionError):
                 logger.exception("Parameterized SQL execution failed for %d documents", len(data_rows))
                 logger.exception("SQL template: %s", insert_sql)
                 logger.exception("Sample data row: %s", data_rows[0] if data_rows else "None")
@@ -787,7 +787,7 @@ class ClickzettaVector(BaseVector):
         document_ids_filter = kwargs.get("document_ids_filter")

         # Handle filter parameter from canvas (workflow)
-        filter_param = kwargs.get("filter", {})
+        _ = kwargs.get("filter", {})

         # Build filter clause
         filter_clauses = []
@@ -879,7 +879,7 @@ class ClickzettaVector(BaseVector):
         document_ids_filter = kwargs.get("document_ids_filter")

         # Handle filter parameter from canvas (workflow)
-        filter_param = kwargs.get("filter", {})
+        _ = kwargs.get("filter", {})

         # Build filter clause
         filter_clauses = []
@@ -938,7 +938,7 @@ class ClickzettaVector(BaseVector):
                             metadata = {}
                     else:
                         metadata = {}
-                except (json.JSONDecodeError, TypeError) as e:
+                except (json.JSONDecodeError, TypeError):
                     logger.exception("JSON parsing failed")
                     # Fallback: extract document_id with regex
@@ -956,7 +956,7 @@ class ClickzettaVector(BaseVector):
                     metadata["score"] = 1.0  # Clickzetta doesn't provide relevance scores
                     doc = Document(page_content=row[1], metadata=metadata)
                     documents.append(doc)
-        except (RuntimeError, ValueError, TypeError, ConnectionError) as e:
+        except (RuntimeError, ValueError, TypeError, ConnectionError):
             logger.exception("Full-text search failed")
             # Fallback to LIKE search if full-text search fails
             return self._search_by_like(query, **kwargs)
@@ -978,7 +978,7 @@ class ClickzettaVector(BaseVector):
         document_ids_filter = kwargs.get("document_ids_filter")

         # Handle filter parameter from canvas (workflow)
-        filter_param = kwargs.get("filter", {})
+        _ = kwargs.get("filter", {})

         # Build filter clause
         filter_clauses = []

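Note on the dropped "as e" bindings in this file (and in the Couchbase, Matrixone and OpenSearch hunks below): each handler only calls logger.exception(), which runs inside an active except block and therefore records the in-flight exception and its traceback on its own, so the binding was never read and lint flags it. The shape:

    try:
        cursor.execute(insert_sql, data_rows)
    except (RuntimeError, ValueError, TypeError, ConnectionError):
        # logger.exception() attaches the current exception and its
        # traceback automatically; no "as e" binding is required.
        logger.exception("Parameterized SQL execution failed for %d documents", len(data_rows))
        raise  # illustrative: re-raise after logging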
View File

@@ -212,10 +212,10 @@ class CouchbaseVector(BaseVector):
         documents_to_insert = [
             {"text": text, "embedding": vector, "metadata": metadata}
-            for id, text, vector, metadata in zip(uuids, texts, embeddings, metadatas)
+            for _, text, vector, metadata in zip(uuids, texts, embeddings, metadatas)
         ]
         for doc, id in zip(documents_to_insert, uuids):
-            result = self._scope.collection(self._collection_name).upsert(id, doc)
+            _ = self._scope.collection(self._collection_name).upsert(id, doc)

         doc_ids.extend(uuids)
@@ -241,7 +241,7 @@ class CouchbaseVector(BaseVector):
         """
         try:
             self._cluster.query(query, named_parameters={"doc_ids": ids}).execute()
-        except Exception as e:
+        except Exception:
             logger.exception("Failed to delete documents, ids: %s", ids)

     def delete_by_document_id(self, document_id: str):
@@ -304,9 +304,9 @@ class CouchbaseVector(BaseVector):
         return docs

     def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
-        top_k = kwargs.get("top_k", 2)
+        top_k = kwargs.get("top_k", 4)
         try:
-            CBrequest = search.SearchRequest.create(search.QueryStringQuery("text:" + query))
+            CBrequest = search.SearchRequest.create(search.QueryStringQuery("text:" + query))  # ty: ignore [too-many-positional-arguments]
             search_iter = self._scope.search(
                 self._collection_name + "_search", CBrequest, SearchOptions(limit=top_k, fields=["*"])
             )

View File

@@ -138,7 +138,7 @@ class ElasticSearchVector(BaseVector):
             if not client.ping():
                 raise ConnectionError("Failed to connect to Elasticsearch")
-        except requests.exceptions.ConnectionError as e:
+        except requests.ConnectionError as e:
             raise ConnectionError(f"Vector database connection error: {str(e)}")
         except Exception as e:
             raise ConnectionError(f"Elasticsearch client initialization failed: {str(e)}")
@@ -216,7 +216,7 @@ class ElasticSearchVector(BaseVector):
         docs = []
         for doc, score in docs_and_scores:
             score_threshold = float(kwargs.get("score_threshold") or 0.0)
-            if score > score_threshold:
+            if score >= score_threshold:
                 if doc.metadata is not None:
                     doc.metadata["score"] = score
                     docs.append(doc)

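Note on requests.exceptions.ConnectionError becoming requests.ConnectionError: requests re-exports its exception classes at the package top level, so both names refer to the same class and the handler's behavior is unchanged; the new spelling is just shorter.

    import requests

    # Top-level alias: the very same class object.
    assert requests.ConnectionError is requests.exceptions.ConnectionError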
View File

@@ -127,7 +127,7 @@ class HuaweiCloudVector(BaseVector):
         docs = []
         for doc, score in docs_and_scores:
             score_threshold = float(kwargs.get("score_threshold") or 0.0)
-            if score > score_threshold:
+            if score >= score_threshold:
                 if doc.metadata is not None:
                     doc.metadata["score"] = score
                     docs.append(doc)

View File

@@ -275,7 +275,7 @@ class LindormVectorStore(BaseVector):
         docs = []
         for doc, score in docs_and_scores:
             score_threshold = kwargs.get("score_threshold", 0.0) or 0.0
-            if score > score_threshold:
+            if score >= score_threshold:
                 if doc.metadata is not None:
                     doc.metadata["score"] = score
                     docs.append(doc)

View File

@@ -99,7 +99,7 @@ class MatrixoneVector(BaseVector):
             return client
         try:
             client.create_full_text_index()
-        except Exception as e:
+        except Exception:
             logger.exception("Failed to create full text index")
         redis_client.set(collection_exist_cache_key, 1, ex=3600)
         return client

View File

@@ -376,7 +376,12 @@ class MilvusVector(BaseVector):
         if config.token:
             client = MilvusClient(uri=config.uri, token=config.token, db_name=config.database)
         else:
-            client = MilvusClient(uri=config.uri, user=config.user, password=config.password, db_name=config.database)
+            client = MilvusClient(
+                uri=config.uri,
+                user=config.user or "",
+                password=config.password or "",
+                db_name=config.database,
+            )
         return client

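Note on the MilvusClient construction above: the config's user and password fields are optional, while recent pymilvus versions type these parameters as str with empty-string defaults, so the rewrite coerces unset values to "" rather than passing None, presumably to satisfy the str-typed signature (an empty string means "no credential"). The guard, mirroring the hunk:

    # Token auth wins when configured; otherwise fall back to
    # user/password, coercing Optional[str] config fields to ""
    # for the str-typed pymilvus signature.
    if config.token:
        client = MilvusClient(uri=config.uri, token=config.token, db_name=config.database)
    else:
        client = MilvusClient(
            uri=config.uri,
            user=config.user or "",
            password=config.password or "",
            db_name=config.database,
        )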
View File

@@ -194,7 +194,7 @@ class OpenGauss(BaseVector):
                 metadata, text, distance = record
                 score = 1 - distance
                 metadata["score"] = score
-                if score > score_threshold:
+                if score >= score_threshold:
                     docs.append(Document(page_content=text, metadata=metadata))
         return docs

View File

@@ -197,7 +197,7 @@ class OpenSearchVector(BaseVector):
         try:
             response = self._client.search(index=self._collection_name.lower(), body=query)
-        except Exception as e:
+        except Exception:
             logger.exception("Error executing vector search, query: %s", query)
             raise
@@ -211,7 +211,7 @@ class OpenSearchVector(BaseVector):
             metadata["score"] = hit["_score"]
             score_threshold = float(kwargs.get("score_threshold") or 0.0)
-            if hit["_score"] > score_threshold:
+            if hit["_score"] >= score_threshold:
                 doc = Document(page_content=hit["_source"].get(Field.CONTENT_KEY.value), metadata=metadata)
                 docs.append(doc)

View File

@@ -261,7 +261,7 @@ class OracleVector(BaseVector):
                 metadata, text, distance = record
                 score = 1 - distance
                 metadata["score"] = score
-                if score > score_threshold:
+                if score >= score_threshold:
                     docs.append(Document(page_content=text, metadata=metadata))
         conn.close()
         return docs

Some files were not shown because too many files have changed in this diff.