mirror of https://github.com/langgenius/dify.git
merge main
commit 27d6fee1ed
@@ -42,11 +42,7 @@ jobs:
- name: Run Unit tests
run: |
uv run --project api bash dev/pytest/pytest_unit_tests.sh
- name: Run ty check
run: |
cd api
uv add --dev ty
uv run ty check || true

- name: Run pyrefly check
run: |
cd api
@@ -66,15 +62,6 @@ jobs:
- name: Run dify config tests
run: uv run --project api dev/pytest/pytest_config_tests.py

- name: MyPy Cache
uses: actions/cache@v4
with:
path: api/.mypy_cache
key: mypy-${{ matrix.python-version }}-${{ runner.os }}-${{ hashFiles('api/uv.lock') }}

- name: Run MyPy Checks
run: dev/mypy-check

- name: Set up dotenvs
run: |
cp docker/.env.example docker/.env
@@ -44,6 +44,14 @@ jobs:
if: steps.changed-files.outputs.any_changed == 'true'
run: uv sync --project api --dev

- name: Run Basedpyright Checks
if: steps.changed-files.outputs.any_changed == 'true'
run: dev/basedpyright-check

- name: Run Mypy Type Checks
if: steps.changed-files.outputs.any_changed == 'true'
run: uv --directory api run mypy --exclude-gitignore --exclude 'tests/' --exclude 'migrations/' --check-untyped-defs --disable-error-code=import-untyped .

- name: Dotenv check
if: steps.changed-files.outputs.any_changed == 'true'
run: uv run --project api dotenv-linter ./api/.env.example ./web/.env.example
@@ -67,12 +67,22 @@ jobs:
working-directory: ./web
run: pnpm run auto-gen-i18n ${{ env.FILE_ARGS }}

- name: Generate i18n type definitions
if: env.FILES_CHANGED == 'true'
working-directory: ./web
run: pnpm run gen:i18n-types

- name: Create Pull Request
if: env.FILES_CHANGED == 'true'
uses: peter-evans/create-pull-request@v6
with:
token: ${{ secrets.GITHUB_TOKEN }}
commit-message: Update i18n files based on en-US changes
title: 'chore: translate i18n files'
body: This PR was automatically created to update i18n files based on changes in en-US locale.
commit-message: Update i18n files and type definitions based on en-US changes
title: 'chore: translate i18n files and update type definitions'
body: |
This PR was automatically created to update i18n files and TypeScript type definitions based on changes in en-US locale.

**Changes included:**
- Updated translation files for all locales
- Regenerated TypeScript type definitions for type safety
branch: chore/automated-i18n-updates
@@ -47,6 +47,11 @@ jobs:
working-directory: ./web
run: pnpm install --frozen-lockfile

- name: Check i18n types synchronization
if: steps.changed-files.outputs.any_changed == 'true'
working-directory: ./web
run: pnpm run check:i18n-types

- name: Run tests
if: steps.changed-files.outputs.any_changed == 'true'
working-directory: ./web
@@ -123,10 +123,12 @@ venv.bak/
# mkdocs documentation
/site

# mypy
# type checking
.mypy_cache/
.dmypy.json
dmypy.json
pyrightconfig.json
!api/pyrightconfig.json

# Pyre type checker
.pyre/

@@ -195,7 +197,6 @@ sdks/python-client/dify_client.egg-info
.vscode/*
!.vscode/launch.json.template
!.vscode/README.md
pyrightconfig.json
api/.vscode
# vscode Code History Extension
.history
@@ -32,7 +32,7 @@ uv run --project api pytest tests/integration_tests/ # Integration tests
./dev/reformat # Run all formatters and linters
uv run --project api ruff check --fix ./ # Fix linting issues
uv run --project api ruff format ./ # Format code
uv run --project api mypy . # Type checking
uv run --directory api basedpyright # Type checking
```

### Frontend (Web)
Makefile (60 lines changed)
@@ -4,6 +4,48 @@ WEB_IMAGE=$(DOCKER_REGISTRY)/dify-web
API_IMAGE=$(DOCKER_REGISTRY)/dify-api
VERSION=latest

# Backend Development Environment Setup
.PHONY: dev-setup prepare-docker prepare-web prepare-api

# Default dev setup target
dev-setup: prepare-docker prepare-web prepare-api
@echo "✅ Backend development environment setup complete!"

# Step 1: Prepare Docker middleware
prepare-docker:
@echo "🐳 Setting up Docker middleware..."
@cp -n docker/middleware.env.example docker/middleware.env 2>/dev/null || echo "Docker middleware.env already exists"
@cd docker && docker compose -f docker-compose.middleware.yaml --env-file middleware.env -p dify-middlewares-dev up -d
@echo "✅ Docker middleware started"

# Step 2: Prepare web environment
prepare-web:
@echo "🌐 Setting up web environment..."
@cp -n web/.env.example web/.env 2>/dev/null || echo "Web .env already exists"
@cd web && pnpm install
@cd web && pnpm build
@echo "✅ Web environment prepared (not started)"

# Step 3: Prepare API environment
prepare-api:
@echo "🔧 Setting up API environment..."
@cp -n api/.env.example api/.env 2>/dev/null || echo "API .env already exists"
@cd api && uv sync --dev
@cd api && uv run flask db upgrade
@echo "✅ API environment prepared (not started)"

# Clean dev environment
dev-clean:
@echo "⚠️ Stopping Docker containers..."
@cd docker && docker compose -f docker-compose.middleware.yaml --env-file middleware.env -p dify-middlewares-dev down
@echo "🗑️ Removing volumes..."
@rm -rf docker/volumes/db
@rm -rf docker/volumes/redis
@rm -rf docker/volumes/plugin_daemon
@rm -rf docker/volumes/weaviate
@rm -rf api/storage
@echo "✅ Cleanup complete"

# Build Docker images
build-web:
@echo "Building web Docker image: $(WEB_IMAGE):$(VERSION)..."
@@ -39,5 +81,21 @@ build-push-web: build-web push-web
build-push-all: build-all push-all
@echo "All Docker images have been built and pushed."

# Help target
help:
@echo "Development Setup Targets:"
@echo " make dev-setup - Run all setup steps for backend dev environment"
@echo " make prepare-docker - Set up Docker middleware"
@echo " make prepare-web - Set up web environment"
@echo " make prepare-api - Set up API environment"
@echo " make dev-clean - Stop Docker middleware containers"
@echo ""
@echo "Docker Build Targets:"
@echo " make build-web - Build web Docker image"
@echo " make build-api - Build API Docker image"
@echo " make build-all - Build all Docker images"
@echo " make push-all - Push all Docker images"
@echo " make build-push-all - Build and push all Docker images"

# Phony targets
.PHONY: build-web build-api push-web push-api build-all push-all build-push-all
.PHONY: build-web build-api push-web push-api build-all push-all build-push-all dev-setup prepare-docker prepare-web prepare-api dev-clean help
@@ -75,6 +75,7 @@ DB_PASSWORD=difyai123456
DB_HOST=localhost
DB_PORT=5432
DB_DATABASE=dify
SQLALCHEMY_POOL_PRE_PING=true

# Storage configuration
# use for store upload files, private keys...
@@ -108,5 +108,5 @@ uv run celery -A app.celery beat
../dev/reformat # Run all formatters and linters
uv run ruff check --fix ./ # Fix linting issues
uv run ruff format ./ # Format code
uv run mypy . # Type checking
uv run basedpyright . # Type checking
```
@@ -25,6 +25,9 @@ def create_flask_app_with_configs() -> DifyApp:
# add an unique identifier to each request
RecyclableContextVar.increment_thread_recycles()

# Capture the decorator's return value to avoid pyright reportUnusedFunction
_ = before_request

return dify_app
@@ -1,11 +0,0 @@
from tests.integration_tests.utils.parent_class import ParentClass


class ChildClass(ParentClass):
"""Test child class for module import helper tests"""

def __init__(self, name):
super().__init__(name)

def get_name(self):
return f"Child: {self.name}"
@@ -571,7 +571,7 @@ def old_metadata_migration():
for document in documents:
if document.doc_metadata:
doc_metadata = document.doc_metadata
for key, value in doc_metadata.items():
for key in doc_metadata:
for field in BuiltInField:
if field.value == key:
break
@@ -29,7 +29,7 @@ class NacosSettingsSource(RemoteSettingsSource):
try:
content = NacosHttpClient().http_request("/nacos/v1/cs/configs", method="GET", headers={}, params=params)
self.remote_configs = self._parse_config(content)
except Exception as e:
except Exception:
logger.exception("[get-access-token] exception occurred")
raise
@@ -27,7 +27,7 @@ class NacosHttpClient:
response = requests.request(method, url="http://" + self.server + url, headers=headers, params=params)
response.raise_for_status()
return response.text
except requests.exceptions.RequestException as e:
except requests.RequestException as e:
return f"Request to Nacos failed: {e}"

def _inject_auth_info(self, headers, params, module="config"):

@@ -77,6 +77,6 @@ class NacosHttpClient:
self.token = response_data.get("accessToken")
self.token_ttl = response_data.get("tokenTtl", 18000)
self.token_expire_time = current_time + self.token_ttl - 10
except Exception as e:
except Exception:
logger.exception("[get-access-token] exception occur")
raise
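The hunk above caches the Nacos access token and computes an expiry timestamp (`current_time + token_ttl - 10`) so the client only re-authenticates shortly before the TTL runs out. A minimal standalone sketch of that bookkeeping, with hypothetical names (this is not the NacosHttpClient API):

```python
import time


class TokenCache:
    """Illustrative token cache: refresh only when the cached token is about to expire."""

    def __init__(self, refresh_margin: int = 10):
        self.token: str | None = None
        self.token_expire_time: float = 0.0
        self.refresh_margin = refresh_margin  # seconds shaved off the TTL, mirroring the "- 10" above

    def get(self, fetch_token) -> str:
        now = time.time()
        if self.token is None or now >= self.token_expire_time:
            token, ttl = fetch_token()  # e.g. ("abc123", 18000), standing in for the login HTTP call
            self.token = token
            self.token_expire_time = now + ttl - self.refresh_margin
        return self.token
```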
@@ -19,6 +19,7 @@ language_timezone_mapping = {
"fa-IR": "Asia/Tehran",
"sl-SI": "Europe/Ljubljana",
"th-TH": "Asia/Bangkok",
"id-ID": "Asia/Jakarta",
}

languages = list(language_timezone_mapping.keys())
@@ -84,7 +84,7 @@ class BaseApiKeyListResource(Resource):
flask_restx.abort(
400,
message=f"Cannot create more than {self.max_keys} API keys for this resource type.",
code="max_keys_exceeded",
custom="max_keys_exceeded",
)

key = ApiToken.generate_api_key(self.token_prefix, 24)
@@ -237,9 +237,14 @@ class AppExportApi(Resource):
# Add include_secret params
parser = reqparse.RequestParser()
parser.add_argument("include_secret", type=inputs.boolean, default=False, location="args")
parser.add_argument("workflow_id", type=str, location="args")
args = parser.parse_args()

return {"data": AppDslService.export_dsl(app_model=app_model, include_secret=args["include_secret"])}
return {
"data": AppDslService.export_dsl(
app_model=app_model, include_secret=args["include_secret"], workflow_id=args.get("workflow_id")
)
}


class AppNameApi(Resource):
@@ -117,7 +117,7 @@ class CompletionConversationDetailApi(Resource):
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT])
@get_app_model(mode=AppMode.COMPLETION)
def delete(self, app_model, conversation_id):
if not current_user.is_editor:
raise Forbidden()
@@ -526,7 +526,7 @@ class PublishedWorkflowApi(Resource):
)

app_model.workflow_id = workflow.id
db.session.commit()
db.session.commit() # NOTE: this is necessary for update app_model.workflow_id

workflow_created_at = TimestampField().format(workflow.created_at)
@@ -27,7 +27,9 @@ class WorkflowAppLogApi(Resource):
"""
parser = reqparse.RequestParser()
parser.add_argument("keyword", type=str, location="args")
parser.add_argument("status", type=str, choices=["succeeded", "failed", "stopped"], location="args")
parser.add_argument(
"status", type=str, choices=["succeeded", "failed", "stopped", "partial-succeeded"], location="args"
)
parser.add_argument(
"created_at__before", type=str, location="args", help="Filter logs created before this timestamp"
)
@@ -81,7 +81,7 @@ class OAuthDataSourceBinding(Resource):
return {"error": "Invalid code"}, 400
try:
oauth_provider.get_access_token(code)
except requests.exceptions.HTTPError as e:
except requests.HTTPError as e:
logger.exception(
"An error occurred during the OAuthCallback process with %s: %s", provider, e.response.text
)

@@ -104,7 +104,7 @@ class OAuthDataSourceSync(Resource):
return {"error": "Invalid provider"}, 400
try:
oauth_provider.sync_data_source(binding_id)
except requests.exceptions.HTTPError as e:
except requests.HTTPError as e:
logger.exception(
"An error occurred during the OAuthCallback process with %s: %s", provider, e.response.text
)
@@ -130,7 +130,7 @@ class ResetPasswordSendEmailApi(Resource):
language = "en-US"
try:
account = AccountService.get_user_through_email(args["email"])
except AccountRegisterError as are:
except AccountRegisterError:
raise AccountInFreezeError()

if account is None:

@@ -162,7 +162,7 @@ class EmailCodeLoginSendEmailApi(Resource):
language = "en-US"
try:
account = AccountService.get_user_through_email(args["email"])
except AccountRegisterError as are:
except AccountRegisterError:
raise AccountInFreezeError()

if account is None:

@@ -200,7 +200,7 @@ class EmailCodeLoginApi(Resource):
AccountService.revoke_email_code_login_token(args["token"])
try:
account = AccountService.get_user_through_email(user_email)
except AccountRegisterError as are:
except AccountRegisterError:
raise AccountInFreezeError()
if account:
tenants = TenantService.get_join_tenants(account)

@@ -223,7 +223,7 @@ class EmailCodeLoginApi(Resource):
)
except WorkSpaceNotAllowedCreateError:
raise NotAllowedCreateWorkspace()
except AccountRegisterError as are:
except AccountRegisterError:
raise AccountInFreezeError()
except WorkspacesLimitExceededError:
raise WorkspacesLimitExceeded()
@@ -80,7 +80,7 @@ class OAuthCallback(Resource):
try:
token = oauth_provider.get_access_token(code)
user_info = oauth_provider.get_user_info(token)
except requests.exceptions.RequestException as e:
except requests.RequestException as e:
error_text = e.response.text if e.response else str(e)
logger.exception("An error occurred during the OAuth process with %s: %s", provider, error_text)
return {"error": "OAuth process failed"}, 400
@@ -2,7 +2,7 @@ from functools import wraps
from typing import cast

import flask_login
from flask import request
from flask import jsonify, request
from flask_restx import Resource, reqparse
from werkzeug.exceptions import BadRequest, NotFound

@@ -46,23 +46,38 @@ def oauth_server_access_token_required(view):

authorization_header = request.headers.get("Authorization")
if not authorization_header:
raise BadRequest("Authorization header is required")
response = jsonify({"error": "Authorization header is required"})
response.status_code = 401
response.headers["WWW-Authenticate"] = "Bearer"
return response

parts = authorization_header.strip().split(" ")
parts = authorization_header.strip().split(None, 1)
if len(parts) != 2:
raise BadRequest("Invalid Authorization header format")
response = jsonify({"error": "Invalid Authorization header format"})
response.status_code = 401
response.headers["WWW-Authenticate"] = "Bearer"
return response

token_type = parts[0].strip()
if token_type.lower() != "bearer":
raise BadRequest("token_type is invalid")
response = jsonify({"error": "token_type is invalid"})
response.status_code = 401
response.headers["WWW-Authenticate"] = "Bearer"
return response

access_token = parts[1].strip()
if not access_token:
raise BadRequest("access_token is required")
response = jsonify({"error": "access_token is required"})
response.status_code = 401
response.headers["WWW-Authenticate"] = "Bearer"
return response

account = OAuthServerService.validate_oauth_access_token(oauth_provider_app.client_id, access_token)
if not account:
raise BadRequest("access_token or client_id is invalid")
response = jsonify({"error": "access_token or client_id is invalid"})
response.status_code = 401
response.headers["WWW-Authenticate"] = "Bearer"
return response

kwargs["account"] = account
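The rewritten decorator above replies with a 401 JSON body and a `WWW-Authenticate: Bearer` header instead of raising `BadRequest`, and it now splits the Authorization header with `split(None, 1)` so any run of whitespace between the scheme and the token is tolerated. A small sketch of that parsing logic on its own, outside Flask (the helper name is illustrative):

```python
def parse_bearer_token(authorization_header: str | None) -> str | None:
    """Return the access token from an Authorization header, or None if it is malformed."""
    if not authorization_header:
        return None
    # split(None, 1) splits on any whitespace run; split(" ") would yield empty
    # fields for double spaces or tabs and reject otherwise valid headers.
    parts = authorization_header.strip().split(None, 1)
    if len(parts) != 2:
        return None
    scheme, token = parts[0].strip(), parts[1].strip()
    if scheme.lower() != "bearer" or not token:
        return None
    return token


assert parse_bearer_token("Bearer   abc123") == "abc123"
assert parse_bearer_token("Basic abc123") is None
```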
@ -10,6 +10,7 @@ from werkzeug.exceptions import NotFound
|
|||
from controllers.console import api
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from core.indexing_runner import IndexingRunner
|
||||
from core.rag.extractor.entity.datasource_type import DatasourceType
|
||||
from core.rag.extractor.entity.extract_setting import ExtractSetting
|
||||
from core.rag.extractor.notion_extractor import NotionExtractor
|
||||
from extensions.ext_database import db
|
||||
|
|
@ -214,7 +215,7 @@ class DataSourceNotionApi(Resource):
|
|||
workspace_id = notion_info["workspace_id"]
|
||||
for page in notion_info["pages"]:
|
||||
extract_setting = ExtractSetting(
|
||||
datasource_type="notion_import",
|
||||
datasource_type=DatasourceType.NOTION.value,
|
||||
notion_info={
|
||||
"notion_workspace_id": workspace_id,
|
||||
"notion_obj_id": page["page_id"],
|
||||
|
|
|
|||
|
|
@ -22,6 +22,7 @@ from core.model_runtime.entities.model_entities import ModelType
|
|||
from core.plugin.entities.plugin import ModelProviderID
|
||||
from core.provider_manager import ProviderManager
|
||||
from core.rag.datasource.vdb.vector_type import VectorType
|
||||
from core.rag.extractor.entity.datasource_type import DatasourceType
|
||||
from core.rag.extractor.entity.extract_setting import ExtractSetting
|
||||
from core.rag.retrieval.retrieval_methods import RetrievalMethod
|
||||
from extensions.ext_database import db
|
||||
|
|
@ -422,7 +423,9 @@ class DatasetIndexingEstimateApi(Resource):
|
|||
if file_details:
|
||||
for file_detail in file_details:
|
||||
extract_setting = ExtractSetting(
|
||||
datasource_type="upload_file", upload_file=file_detail, document_model=args["doc_form"]
|
||||
datasource_type=DatasourceType.FILE.value,
|
||||
upload_file=file_detail,
|
||||
document_model=args["doc_form"],
|
||||
)
|
||||
extract_settings.append(extract_setting)
|
||||
elif args["info_list"]["data_source_type"] == "notion_import":
|
||||
|
|
@ -431,7 +434,7 @@ class DatasetIndexingEstimateApi(Resource):
|
|||
workspace_id = notion_info["workspace_id"]
|
||||
for page in notion_info["pages"]:
|
||||
extract_setting = ExtractSetting(
|
||||
datasource_type="notion_import",
|
||||
datasource_type=DatasourceType.NOTION.value,
|
||||
notion_info={
|
||||
"notion_workspace_id": workspace_id,
|
||||
"notion_obj_id": page["page_id"],
|
||||
|
|
@ -445,7 +448,7 @@ class DatasetIndexingEstimateApi(Resource):
|
|||
website_info_list = args["info_list"]["website_info_list"]
|
||||
for url in website_info_list["urls"]:
|
||||
extract_setting = ExtractSetting(
|
||||
datasource_type="website_crawl",
|
||||
datasource_type=DatasourceType.WEBSITE.value,
|
||||
website_info={
|
||||
"provider": website_info_list["provider"],
|
||||
"job_id": website_info_list["job_id"],
|
||||
|
|
|
|||
|
|
@ -40,6 +40,7 @@ from core.model_manager import ModelManager
|
|||
from core.model_runtime.entities.model_entities import ModelType
|
||||
from core.model_runtime.errors.invoke import InvokeAuthorizationError
|
||||
from core.plugin.impl.exc import PluginDaemonClientSideError
|
||||
from core.rag.extractor.entity.datasource_type import DatasourceType
|
||||
from core.rag.extractor.entity.extract_setting import ExtractSetting
|
||||
from extensions.ext_database import db
|
||||
from fields.document_fields import (
|
||||
|
|
@ -425,7 +426,7 @@ class DocumentIndexingEstimateApi(DocumentResource):
|
|||
raise NotFound("File not found.")
|
||||
|
||||
extract_setting = ExtractSetting(
|
||||
datasource_type="upload_file", upload_file=file, document_model=document.doc_form
|
||||
datasource_type=DatasourceType.FILE.value, upload_file=file, document_model=document.doc_form
|
||||
)
|
||||
|
||||
indexing_runner = IndexingRunner()
|
||||
|
|
@ -485,13 +486,13 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
|
|||
raise NotFound("File not found.")
|
||||
|
||||
extract_setting = ExtractSetting(
|
||||
datasource_type="upload_file", upload_file=file_detail, document_model=document.doc_form
|
||||
datasource_type=DatasourceType.FILE.value, upload_file=file_detail, document_model=document.doc_form
|
||||
)
|
||||
extract_settings.append(extract_setting)
|
||||
|
||||
elif document.data_source_type == "notion_import":
|
||||
extract_setting = ExtractSetting(
|
||||
datasource_type="notion_import",
|
||||
datasource_type=DatasourceType.NOTION.value,
|
||||
notion_info={
|
||||
"notion_workspace_id": data_source_info["notion_workspace_id"],
|
||||
"notion_obj_id": data_source_info["notion_page_id"],
|
||||
|
|
@ -503,7 +504,7 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
|
|||
extract_settings.append(extract_setting)
|
||||
elif document.data_source_type == "website_crawl":
|
||||
extract_setting = ExtractSetting(
|
||||
datasource_type="website_crawl",
|
||||
datasource_type=DatasourceType.WEBSITE.value,
|
||||
website_info={
|
||||
"provider": data_source_info["provider"],
|
||||
"job_id": data_source_info["job_id"],
|
||||
|
|
|
|||
|
|
@ -61,7 +61,6 @@ class ConversationApi(InstalledAppResource):
|
|||
ConversationService.delete(app_model, conversation_id, current_user)
|
||||
except ConversationNotExistsError:
|
||||
raise NotFound("Conversation Not Exists.")
|
||||
WebConversationService.unpin(app_model, conversation_id, current_user)
|
||||
|
||||
return {"result": "success"}, 204
|
||||
|
||||
|
|
|
|||
|
|
@ -67,7 +67,7 @@ class ModelProviderCredentialApi(Resource):
|
|||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
|
||||
parser.add_argument("name", type=StrLen(30), required=True, nullable=False, location="json")
|
||||
parser.add_argument("name", type=StrLen(30), required=False, nullable=True, location="json")
|
||||
args = parser.parse_args()
|
||||
|
||||
model_provider_service = ModelProviderService()
|
||||
|
|
@ -94,7 +94,7 @@ class ModelProviderCredentialApi(Resource):
|
|||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("credential_id", type=uuid_value, required=True, nullable=False, location="json")
|
||||
parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
|
||||
parser.add_argument("name", type=StrLen(30), required=True, nullable=False, location="json")
|
||||
parser.add_argument("name", type=StrLen(30), required=False, nullable=True, location="json")
|
||||
args = parser.parse_args()
|
||||
|
||||
model_provider_service = ModelProviderService()
|
||||
|
|
|
|||
|
|
@ -219,7 +219,11 @@ class ModelProviderModelCredentialApi(Resource):
|
|||
|
||||
model_load_balancing_service = ModelLoadBalancingService()
|
||||
is_load_balancing_enabled, load_balancing_configs = model_load_balancing_service.get_load_balancing_configs(
|
||||
tenant_id=tenant_id, provider=provider, model=args["model"], model_type=args["model_type"]
|
||||
tenant_id=tenant_id,
|
||||
provider=provider,
|
||||
model=args["model"],
|
||||
model_type=args["model_type"],
|
||||
config_from=args.get("config_from", ""),
|
||||
)
|
||||
|
||||
if args.get("config_from", "") == "predefined-model":
|
||||
|
|
@ -263,7 +267,7 @@ class ModelProviderModelCredentialApi(Resource):
|
|||
choices=[mt.value for mt in ModelType],
|
||||
location="json",
|
||||
)
|
||||
parser.add_argument("name", type=StrLen(30), required=True, nullable=False, location="json")
|
||||
parser.add_argument("name", type=StrLen(30), required=False, nullable=True, location="json")
|
||||
parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
|
||||
args = parser.parse_args()
|
||||
|
||||
|
|
@ -309,7 +313,7 @@ class ModelProviderModelCredentialApi(Resource):
|
|||
)
|
||||
parser.add_argument("credential_id", type=uuid_value, required=True, nullable=False, location="json")
|
||||
parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
|
||||
parser.add_argument("name", type=StrLen(30), required=True, nullable=False, location="json")
|
||||
parser.add_argument("name", type=StrLen(30), required=False, nullable=True, location="json")
|
||||
args = parser.parse_args()
|
||||
|
||||
model_provider_service = ModelProviderService()
|
||||
|
|
|
|||
|
|
@ -55,7 +55,7 @@ class AudioApi(Resource):
|
|||
file = request.files["file"]
|
||||
|
||||
try:
|
||||
response = AudioService.transcript_asr(app_model=app_model, file=file, end_user=end_user)
|
||||
response = AudioService.transcript_asr(app_model=app_model, file=file, end_user=end_user.id)
|
||||
|
||||
return response
|
||||
except services.errors.app_model_config.AppModelConfigBrokenError:
|
||||
|
|
|
|||
|
|
@ -59,7 +59,7 @@ class FilePreviewApi(Resource):
|
|||
args = file_preview_parser.parse_args()
|
||||
|
||||
# Validate file ownership and get file objects
|
||||
message_file, upload_file = self._validate_file_ownership(file_id, app_model.id)
|
||||
_, upload_file = self._validate_file_ownership(file_id, app_model.id)
|
||||
|
||||
# Get file content generator
|
||||
try:
|
||||
|
|
|
|||
|
|
@ -410,7 +410,7 @@ class DocumentUpdateByFileApi(DatasetApiResource):
|
|||
DocumentService.document_create_args_validate(knowledge_config)
|
||||
|
||||
try:
|
||||
documents, batch = DocumentService.save_document_with_dataset_id(
|
||||
documents, _ = DocumentService.save_document_with_dataset_id(
|
||||
dataset=dataset,
|
||||
knowledge_config=knowledge_config,
|
||||
account=dataset.created_by_account,
|
||||
|
|
|
|||
|
|
@@ -291,27 +291,28 @@ def create_or_update_end_user_for_user_id(app_model: App, user_id: Optional[str]
if not user_id:
user_id = "DEFAULT-USER"

end_user = (
db.session.query(EndUser)
.where(
EndUser.tenant_id == app_model.tenant_id,
EndUser.app_id == app_model.id,
EndUser.session_id == user_id,
EndUser.type == "service_api",
with Session(db.engine, expire_on_commit=False) as session:
end_user = (
session.query(EndUser)
.where(
EndUser.tenant_id == app_model.tenant_id,
EndUser.app_id == app_model.id,
EndUser.session_id == user_id,
EndUser.type == "service_api",
)
.first()
)
.first()
)

if end_user is None:
end_user = EndUser(
tenant_id=app_model.tenant_id,
app_id=app_model.id,
type="service_api",
is_anonymous=user_id == "DEFAULT-USER",
session_id=user_id,
)
db.session.add(end_user)
db.session.commit()
if end_user is None:
end_user = EndUser(
tenant_id=app_model.tenant_id,
app_id=app_model.id,
type="service_api",
is_anonymous=user_id == "DEFAULT-USER",
session_id=user_id,
)
session.add(end_user)
session.commit()

return end_user
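The change above moves both the lookup and the creation of the `EndUser` into one `Session(db.engine, expire_on_commit=False)` block, so the returned object keeps its loaded attributes after the session closes. A rough sketch of the same get-or-create pattern with a placeholder model and an in-memory engine (not Dify's models):

```python
from sqlalchemy import String, create_engine, select
from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column


class Base(DeclarativeBase):
    pass


class Visitor(Base):  # placeholder standing in for EndUser
    __tablename__ = "visitors"
    id: Mapped[int] = mapped_column(primary_key=True)
    session_id: Mapped[str] = mapped_column(String(64), unique=True)


engine = create_engine("sqlite:///:memory:")
Base.metadata.create_all(engine)


def get_or_create_visitor(session_id: str) -> Visitor:
    # expire_on_commit=False keeps attributes populated after commit,
    # so the instance stays readable once the session is closed.
    with Session(engine, expire_on_commit=False) as session:
        visitor = session.scalar(select(Visitor).where(Visitor.session_id == session_id))
        if visitor is None:
            visitor = Visitor(session_id=session_id)
            session.add(visitor)
            session.commit()
        return visitor


print(get_or_create_visitor("DEFAULT-USER").session_id)
```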
@@ -73,8 +73,6 @@ class ConversationApi(WebApiResource):
ConversationService.delete(app_model, conversation_id, end_user)
except ConversationNotExistsError:
raise NotFound("Conversation Not Exists.")
WebConversationService.unpin(app_model, conversation_id, end_user)

return {"result": "success"}, 204
|
|
@ -4,6 +4,7 @@ from functools import wraps
|
|||
from flask import request
|
||||
from flask_restx import Resource
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
from werkzeug.exceptions import BadRequest, NotFound, Unauthorized
|
||||
|
||||
from controllers.web.error import WebAppAuthAccessDeniedError, WebAppAuthRequiredError
|
||||
|
|
@ -49,18 +50,19 @@ def decode_jwt_token():
|
|||
decoded = PassportService().verify(tk)
|
||||
app_code = decoded.get("app_code")
|
||||
app_id = decoded.get("app_id")
|
||||
app_model = db.session.scalar(select(App).where(App.id == app_id))
|
||||
site = db.session.scalar(select(Site).where(Site.code == app_code))
|
||||
if not app_model:
|
||||
raise NotFound()
|
||||
if not app_code or not site:
|
||||
raise BadRequest("Site URL is no longer valid.")
|
||||
if app_model.enable_site is False:
|
||||
raise BadRequest("Site is disabled.")
|
||||
end_user_id = decoded.get("end_user_id")
|
||||
end_user = db.session.scalar(select(EndUser).where(EndUser.id == end_user_id))
|
||||
if not end_user:
|
||||
raise NotFound()
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
app_model = session.scalar(select(App).where(App.id == app_id))
|
||||
site = session.scalar(select(Site).where(Site.code == app_code))
|
||||
if not app_model:
|
||||
raise NotFound()
|
||||
if not app_code or not site:
|
||||
raise BadRequest("Site URL is no longer valid.")
|
||||
if app_model.enable_site is False:
|
||||
raise BadRequest("Site is disabled.")
|
||||
end_user_id = decoded.get("end_user_id")
|
||||
end_user = session.scalar(select(EndUser).where(EndUser.id == end_user_id))
|
||||
if not end_user:
|
||||
raise NotFound()
|
||||
|
||||
# for enterprise webapp auth
|
||||
app_web_auth_enabled = False
|
||||
|
|
|
|||
|
|
@ -26,7 +26,7 @@ class MoreLikeThisConfigManager:
|
|||
def validate_and_set_defaults(cls, config: dict) -> tuple[dict, list[str]]:
|
||||
try:
|
||||
return AppConfigModel.model_validate(config).model_dump(), ["more_like_this"]
|
||||
except ValidationError as e:
|
||||
except ValidationError:
|
||||
raise ValueError(
|
||||
"more_like_this must be of dict type and enabled in more_like_this must be of boolean type"
|
||||
)
|
||||
|
|
|
|||
|
|
@ -453,7 +453,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
|||
# release database connection, because the following new thread operations may take a long time
|
||||
db.session.refresh(workflow)
|
||||
db.session.refresh(message)
|
||||
db.session.refresh(user)
|
||||
# db.session.refresh(user)
|
||||
db.session.close()
|
||||
|
||||
# return response or stream generator
|
||||
|
|
|
|||
|
|
@ -118,7 +118,7 @@ class AdvancedChatAppGenerateResponseConverter(AppGenerateResponseConverter):
|
|||
data = cls._error_to_stream_response(sub_stream_response.err)
|
||||
response_chunk.update(data)
|
||||
elif isinstance(sub_stream_response, NodeStartStreamResponse | NodeFinishStreamResponse):
|
||||
response_chunk.update(sub_stream_response.to_ignore_detail_dict())
|
||||
response_chunk.update(sub_stream_response.to_ignore_detail_dict()) # ty: ignore [unresolved-attribute]
|
||||
else:
|
||||
response_chunk.update(sub_stream_response.to_dict())
|
||||
|
||||
|
|
|
|||
|
|
@ -310,13 +310,8 @@ class AdvancedChatAppGenerateTaskPipeline:
|
|||
err = self._base_task_pipeline._handle_error(event=event, session=session, message_id=self._message_id)
|
||||
yield self._base_task_pipeline._error_to_stream_response(err)
|
||||
|
||||
def _handle_workflow_started_event(
|
||||
self, event: QueueWorkflowStartedEvent, *, graph_runtime_state: Optional[GraphRuntimeState] = None, **kwargs
|
||||
) -> Generator[StreamResponse, None, None]:
|
||||
def _handle_workflow_started_event(self, *args, **kwargs) -> Generator[StreamResponse, None, None]:
|
||||
"""Handle workflow started events."""
|
||||
# Override graph runtime state - this is a side effect but necessary
|
||||
graph_runtime_state = event.graph_runtime_state
|
||||
|
||||
with self._database_session() as session:
|
||||
workflow_execution = self._workflow_cycle_manager.handle_workflow_run_start()
|
||||
self._workflow_run_id = workflow_execution.id_
|
||||
|
|
@ -337,15 +332,14 @@ class AdvancedChatAppGenerateTaskPipeline:
|
|||
"""Handle node retry events."""
|
||||
self._ensure_workflow_initialized()
|
||||
|
||||
with self._database_session() as session:
|
||||
workflow_node_execution = self._workflow_cycle_manager.handle_workflow_node_execution_retried(
|
||||
workflow_execution_id=self._workflow_run_id, event=event
|
||||
)
|
||||
node_retry_resp = self._workflow_response_converter.workflow_node_retry_to_stream_response(
|
||||
event=event,
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
workflow_node_execution=workflow_node_execution,
|
||||
)
|
||||
workflow_node_execution = self._workflow_cycle_manager.handle_workflow_node_execution_retried(
|
||||
workflow_execution_id=self._workflow_run_id, event=event
|
||||
)
|
||||
node_retry_resp = self._workflow_response_converter.workflow_node_retry_to_stream_response(
|
||||
event=event,
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
workflow_node_execution=workflow_node_execution,
|
||||
)
|
||||
|
||||
if node_retry_resp:
|
||||
yield node_retry_resp
|
||||
|
|
@ -379,13 +373,12 @@ class AdvancedChatAppGenerateTaskPipeline:
|
|||
self._workflow_response_converter.fetch_files_from_node_outputs(event.outputs or {})
|
||||
)
|
||||
|
||||
with self._database_session() as session:
|
||||
workflow_node_execution = self._workflow_cycle_manager.handle_workflow_node_execution_success(event=event)
|
||||
node_finish_resp = self._workflow_response_converter.workflow_node_finish_to_stream_response(
|
||||
event=event,
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
workflow_node_execution=workflow_node_execution,
|
||||
)
|
||||
workflow_node_execution = self._workflow_cycle_manager.handle_workflow_node_execution_success(event=event)
|
||||
node_finish_resp = self._workflow_response_converter.workflow_node_finish_to_stream_response(
|
||||
event=event,
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
workflow_node_execution=workflow_node_execution,
|
||||
)
|
||||
|
||||
self._save_output_for_event(event, workflow_node_execution.id)
|
||||
|
||||
|
|
|
|||
|
|
@ -159,7 +159,7 @@ class AppQueueManager:
|
|||
def _check_for_sqlalchemy_models(self, data: Any):
|
||||
# from entity to dict or list
|
||||
if isinstance(data, dict):
|
||||
for key, value in data.items():
|
||||
for value in data.values():
|
||||
self._check_for_sqlalchemy_models(value)
|
||||
elif isinstance(data, list):
|
||||
for item in data:
|
||||
|
|
|
|||
|
|
@ -89,7 +89,7 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter):
|
|||
data = cls._error_to_stream_response(sub_stream_response.err)
|
||||
response_chunk.update(data)
|
||||
elif isinstance(sub_stream_response, NodeStartStreamResponse | NodeFinishStreamResponse):
|
||||
response_chunk.update(sub_stream_response.to_ignore_detail_dict())
|
||||
response_chunk.update(sub_stream_response.to_ignore_detail_dict()) # ty: ignore [unresolved-attribute]
|
||||
else:
|
||||
response_chunk.update(sub_stream_response.to_dict())
|
||||
yield response_chunk
|
||||
|
|
|
|||
|
|
@ -300,16 +300,15 @@ class WorkflowAppGenerateTaskPipeline:
|
|||
"""Handle node retry events."""
|
||||
self._ensure_workflow_initialized()
|
||||
|
||||
with self._database_session() as session:
|
||||
workflow_node_execution = self._workflow_cycle_manager.handle_workflow_node_execution_retried(
|
||||
workflow_execution_id=self._workflow_run_id,
|
||||
event=event,
|
||||
)
|
||||
response = self._workflow_response_converter.workflow_node_retry_to_stream_response(
|
||||
event=event,
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
workflow_node_execution=workflow_node_execution,
|
||||
)
|
||||
workflow_node_execution = self._workflow_cycle_manager.handle_workflow_node_execution_retried(
|
||||
workflow_execution_id=self._workflow_run_id,
|
||||
event=event,
|
||||
)
|
||||
response = self._workflow_response_converter.workflow_node_retry_to_stream_response(
|
||||
event=event,
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
workflow_node_execution=workflow_node_execution,
|
||||
)
|
||||
|
||||
if response:
|
||||
yield response
|
||||
|
|
|
|||
|
|
@@ -1,5 +1,5 @@
from collections.abc import Mapping, Sequence
from enum import Enum
from enum import StrEnum
from typing import Any, Optional

from pydantic import BaseModel, ConfigDict, Field, ValidationInfo, field_validator

@@ -11,7 +11,7 @@ from core.file import File, FileUploadConfig
from core.model_runtime.entities.model_entities import AIModelEntity


class InvokeFrom(Enum):
class InvokeFrom(StrEnum):
"""
Invoke From.
"""
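Rebasing `InvokeFrom` on `StrEnum` makes every member a real `str`, so members compare equal to plain strings and serialize without reaching for `.value`. A short illustration of the difference (member names here are made up; requires Python 3.11+):

```python
from enum import Enum, StrEnum


class PlainSource(Enum):
    WEB_APP = "web-app"


class StrSource(StrEnum):
    WEB_APP = "web-app"


# A plain Enum member is not a string; a StrEnum member is.
assert PlainSource.WEB_APP != "web-app"
assert StrSource.WEB_APP == "web-app"
assert f"{StrSource.WEB_APP}" == "web-app"
```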
@@ -96,7 +96,11 @@ class RateLimit:
if isinstance(generator, Mapping):
return generator
else:
return RateLimitGenerator(rate_limit=self, generator=generator, request_id=request_id)
return RateLimitGenerator(
rate_limit=self,
generator=generator, # ty: ignore [invalid-argument-type]
request_id=request_id,
)


class RateLimitGenerator:
|
|
@ -50,7 +50,7 @@ class BasedGenerateTaskPipeline:
|
|||
if isinstance(e, InvokeAuthorizationError):
|
||||
err = InvokeAuthorizationError("Incorrect API key provided")
|
||||
elif isinstance(e, InvokeError | ValueError):
|
||||
err = e
|
||||
err = e # ty: ignore [invalid-assignment]
|
||||
else:
|
||||
description = getattr(e, "description", None)
|
||||
err = Exception(description if description is not None else str(e))
|
||||
|
|
|
|||
|
|
@ -99,12 +99,13 @@ class MessageCycleManager:
|
|||
|
||||
# generate conversation name
|
||||
try:
|
||||
name = LLMGenerator.generate_conversation_name(app_model.tenant_id, query)
|
||||
name = LLMGenerator.generate_conversation_name(
|
||||
app_model.tenant_id, query, conversation_id, conversation.app_id
|
||||
)
|
||||
conversation.name = name
|
||||
except Exception as e:
|
||||
except Exception:
|
||||
if dify_config.DEBUG:
|
||||
logger.exception("generate conversation name failed, conversation_id: %s", conversation_id)
|
||||
pass
|
||||
|
||||
db.session.merge(conversation)
|
||||
db.session.commit()
|
||||
|
|
|
|||
|
|
@ -67,7 +67,7 @@ class DatasetIndexToolCallbackHandler:
|
|||
)
|
||||
child_chunk = db.session.scalar(child_chunk_stmt)
|
||||
if child_chunk:
|
||||
segment = (
|
||||
_ = (
|
||||
db.session.query(DocumentSegment)
|
||||
.where(DocumentSegment.id == child_chunk.segment_id)
|
||||
.update(
|
||||
|
|
|
|||
|
|
@@ -1,5 +1,6 @@
import json
import logging
import re
from collections import defaultdict
from collections.abc import Iterator, Sequence
from json import JSONDecodeError

@@ -343,7 +344,65 @@ class ProviderConfiguration(BaseModel):
with Session(db.engine) as new_session:
return _validate(new_session)

def create_provider_credential(self, credentials: dict, credential_name: str) -> None:
def _generate_provider_credential_name(self, session) -> str:
"""
Generate a unique credential name for provider.
:return: credential name
"""
return self._generate_next_api_key_name(
session=session,
query_factory=lambda: select(ProviderCredential).where(
ProviderCredential.tenant_id == self.tenant_id,
ProviderCredential.provider_name == self.provider.provider,
),
)

def _generate_custom_model_credential_name(self, model: str, model_type: ModelType, session) -> str:
"""
Generate a unique credential name for custom model.
:return: credential name
"""
return self._generate_next_api_key_name(
session=session,
query_factory=lambda: select(ProviderModelCredential).where(
ProviderModelCredential.tenant_id == self.tenant_id,
ProviderModelCredential.provider_name == self.provider.provider,
ProviderModelCredential.model_name == model,
ProviderModelCredential.model_type == model_type.to_origin_model_type(),
),
)

def _generate_next_api_key_name(self, session, query_factory) -> str:
"""
Generate next available API KEY name by finding the highest numbered suffix.
:param session: database session
:param query_factory: function that returns the SQLAlchemy query
:return: next available API KEY name
"""
try:
stmt = query_factory()
credential_records = session.execute(stmt).scalars().all()

if not credential_records:
return "API KEY 1"

# Extract numbers from API KEY pattern using list comprehension
pattern = re.compile(r"^API KEY\s+(\d+)$")
numbers = [
int(match.group(1))
for cr in credential_records
if cr.credential_name and (match := pattern.match(cr.credential_name.strip()))
]

# Return next sequential number
next_number = max(numbers, default=0) + 1
return f"API KEY {next_number}"

except Exception as e:
logger.warning("Error generating next credential name: %s", str(e))
return "API KEY 1"

def create_provider_credential(self, credentials: dict, credential_name: str | None) -> None:
"""
Add custom provider credentials.
:param credentials: provider credentials

@@ -351,8 +410,11 @@ class ProviderConfiguration(BaseModel):
:return:
"""
with Session(db.engine) as session:
if self._check_provider_credential_name_exists(credential_name=credential_name, session=session):
raise ValueError(f"Credential with name '{credential_name}' already exists.")
if credential_name:
if self._check_provider_credential_name_exists(credential_name=credential_name, session=session):
raise ValueError(f"Credential with name '{credential_name}' already exists.")
else:
credential_name = self._generate_provider_credential_name(session)

credentials = self.validate_provider_credentials(credentials=credentials, session=session)
provider_record = self._get_provider_record(session)
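When no credential name is supplied, the new `_generate_next_api_key_name` helper scans existing names for the `API KEY <n>` pattern and hands back the next number, defaulting to `API KEY 1`. The same numbering logic, sketched on a plain list of names with no database session involved:

```python
import re

_PATTERN = re.compile(r"^API KEY\s+(\d+)$")


def next_api_key_name(existing_names: list[str]) -> str:
    """Return the next free 'API KEY <n>' name given the names already in use."""
    numbers = [
        int(match.group(1))
        for name in existing_names
        if name and (match := _PATTERN.match(name.strip()))
    ]
    return f"API KEY {max(numbers, default=0) + 1}"


assert next_api_key_name([]) == "API KEY 1"
assert next_api_key_name(["API KEY 1", "API KEY 3", "my key"]) == "API KEY 4"
```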
|
||||
|
|
@ -395,7 +457,7 @@ class ProviderConfiguration(BaseModel):
|
|||
self,
|
||||
credentials: dict,
|
||||
credential_id: str,
|
||||
credential_name: str,
|
||||
credential_name: str | None,
|
||||
) -> None:
|
||||
"""
|
||||
update a saved provider credential (by credential_id).
|
||||
|
|
@ -406,7 +468,7 @@ class ProviderConfiguration(BaseModel):
|
|||
:return:
|
||||
"""
|
||||
with Session(db.engine) as session:
|
||||
if self._check_provider_credential_name_exists(
|
||||
if credential_name and self._check_provider_credential_name_exists(
|
||||
credential_name=credential_name, session=session, exclude_id=credential_id
|
||||
):
|
||||
raise ValueError(f"Credential with name '{credential_name}' already exists.")
|
||||
|
|
@ -428,9 +490,9 @@ class ProviderConfiguration(BaseModel):
|
|||
try:
|
||||
# Update credential
|
||||
credential_record.encrypted_config = json.dumps(credentials)
|
||||
credential_record.credential_name = credential_name
|
||||
credential_record.updated_at = naive_utc_now()
|
||||
|
||||
if credential_name:
|
||||
credential_record.credential_name = credential_name
|
||||
session.commit()
|
||||
|
||||
if provider_record and provider_record.credential_id == credential_id:
|
||||
|
|
@ -532,13 +594,7 @@ class ProviderConfiguration(BaseModel):
|
|||
cache_type=ProviderCredentialsCacheType.LOAD_BALANCING_MODEL,
|
||||
)
|
||||
lb_credentials_cache.delete()
|
||||
|
||||
lb_config.credential_id = None
|
||||
lb_config.encrypted_config = None
|
||||
lb_config.enabled = False
|
||||
lb_config.name = "__delete__"
|
||||
lb_config.updated_at = naive_utc_now()
|
||||
session.add(lb_config)
|
||||
session.delete(lb_config)
|
||||
|
||||
# Check if this is the currently active credential
|
||||
provider_record = self._get_provider_record(session)
|
||||
|
|
@ -822,7 +878,7 @@ class ProviderConfiguration(BaseModel):
|
|||
return _validate(new_session)
|
||||
|
||||
def create_custom_model_credential(
|
||||
self, model_type: ModelType, model: str, credentials: dict, credential_name: str
|
||||
self, model_type: ModelType, model: str, credentials: dict, credential_name: str | None
|
||||
) -> None:
|
||||
"""
|
||||
Create a custom model credential.
|
||||
|
|
@ -833,10 +889,15 @@ class ProviderConfiguration(BaseModel):
|
|||
:return:
|
||||
"""
|
||||
with Session(db.engine) as session:
|
||||
if self._check_custom_model_credential_name_exists(
|
||||
model=model, model_type=model_type, credential_name=credential_name, session=session
|
||||
):
|
||||
raise ValueError(f"Model credential with name '{credential_name}' already exists for {model}.")
|
||||
if credential_name:
|
||||
if self._check_custom_model_credential_name_exists(
|
||||
model=model, model_type=model_type, credential_name=credential_name, session=session
|
||||
):
|
||||
raise ValueError(f"Model credential with name '{credential_name}' already exists for {model}.")
|
||||
else:
|
||||
credential_name = self._generate_custom_model_credential_name(
|
||||
model=model, model_type=model_type, session=session
|
||||
)
|
||||
# validate custom model config
|
||||
credentials = self.validate_custom_model_credentials(
|
||||
model_type=model_type, model=model, credentials=credentials, session=session
|
||||
|
|
@ -880,7 +941,7 @@ class ProviderConfiguration(BaseModel):
|
|||
raise
|
||||
|
||||
def update_custom_model_credential(
|
||||
self, model_type: ModelType, model: str, credentials: dict, credential_name: str, credential_id: str
|
||||
self, model_type: ModelType, model: str, credentials: dict, credential_name: str | None, credential_id: str
|
||||
) -> None:
|
||||
"""
|
||||
Update a custom model credential.
|
||||
|
|
@ -893,7 +954,7 @@ class ProviderConfiguration(BaseModel):
|
|||
:return:
|
||||
"""
|
||||
with Session(db.engine) as session:
|
||||
if self._check_custom_model_credential_name_exists(
|
||||
if credential_name and self._check_custom_model_credential_name_exists(
|
||||
model=model,
|
||||
model_type=model_type,
|
||||
credential_name=credential_name,
|
||||
|
|
@ -925,8 +986,9 @@ class ProviderConfiguration(BaseModel):
|
|||
try:
|
||||
# Update credential
|
||||
credential_record.encrypted_config = json.dumps(credentials)
|
||||
credential_record.credential_name = credential_name
|
||||
credential_record.updated_at = naive_utc_now()
|
||||
if credential_name:
|
||||
credential_record.credential_name = credential_name
|
||||
session.commit()
|
||||
|
||||
if provider_model_record and provider_model_record.credential_id == credential_id:
|
||||
|
|
@ -982,12 +1044,7 @@ class ProviderConfiguration(BaseModel):
|
|||
cache_type=ProviderCredentialsCacheType.LOAD_BALANCING_MODEL,
|
||||
)
|
||||
lb_credentials_cache.delete()
|
||||
lb_config.credential_id = None
|
||||
lb_config.encrypted_config = None
|
||||
lb_config.enabled = False
|
||||
lb_config.name = "__delete__"
|
||||
lb_config.updated_at = naive_utc_now()
|
||||
session.add(lb_config)
|
||||
session.delete(lb_config)
|
||||
|
||||
# Check if this is the currently active credential
|
||||
provider_model_record = self._get_custom_model_record(model_type, model, session=session)
|
||||
|
|
@ -1054,6 +1111,7 @@ class ProviderConfiguration(BaseModel):
|
|||
provider_name=self.provider.provider,
|
||||
model_name=model,
|
||||
model_type=model_type.to_origin_model_type(),
|
||||
is_valid=True,
|
||||
credential_id=credential_id,
|
||||
)
|
||||
else:
|
||||
|
|
@ -1605,11 +1663,9 @@ class ProviderConfiguration(BaseModel):
|
|||
if config.credential_source_type != "custom_model"
|
||||
]
|
||||
|
||||
if len(provider_model_lb_configs) > 1:
|
||||
load_balancing_enabled = True
|
||||
|
||||
if any(config.name == "__delete__" for config in provider_model_lb_configs):
|
||||
has_invalid_load_balancing_configs = True
|
||||
load_balancing_enabled = model_setting.load_balancing_enabled
|
||||
# when the user enable load_balancing but available configs are less than 2 display warning
|
||||
has_invalid_load_balancing_configs = load_balancing_enabled and len(provider_model_lb_configs) < 2
|
||||
|
||||
provider_models.append(
|
||||
ModelWithProviderEntity(
|
||||
|
|
@ -1631,6 +1687,8 @@ class ProviderConfiguration(BaseModel):
|
|||
for model_configuration in self.custom_configuration.models:
|
||||
if model_configuration.model_type not in model_types:
|
||||
continue
|
||||
if model_configuration.unadded_to_model_list:
|
||||
continue
|
||||
if model and model != model_configuration.model:
|
||||
continue
|
||||
try:
|
||||
|
|
@ -1663,11 +1721,9 @@ class ProviderConfiguration(BaseModel):
|
|||
if config.credential_source_type != "provider"
|
||||
]
|
||||
|
||||
if len(custom_model_lb_configs) > 1:
|
||||
load_balancing_enabled = True
|
||||
|
||||
if any(config.name == "__delete__" for config in custom_model_lb_configs):
|
||||
has_invalid_load_balancing_configs = True
|
||||
load_balancing_enabled = model_setting.load_balancing_enabled
|
||||
# when the user enable load_balancing but available configs are less than 2 display warning
|
||||
has_invalid_load_balancing_configs = load_balancing_enabled and len(custom_model_lb_configs) < 2
|
||||
|
||||
if len(model_configuration.available_model_credentials) > 0 and not model_configuration.credentials:
|
||||
status = ModelStatus.CREDENTIAL_REMOVED
|
||||
|
|
|
|||
|
|
@ -111,11 +111,21 @@ class CustomModelConfiguration(BaseModel):
|
|||
current_credential_id: Optional[str] = None
|
||||
current_credential_name: Optional[str] = None
|
||||
available_model_credentials: list[CredentialConfiguration] = []
|
||||
unadded_to_model_list: Optional[bool] = False
|
||||
|
||||
# pydantic configs
|
||||
model_config = ConfigDict(protected_namespaces=())
|
||||
|
||||
|
||||
class UnaddedModelConfiguration(BaseModel):
|
||||
"""
|
||||
Model class for provider unadded model configuration.
|
||||
"""
|
||||
|
||||
model: str
|
||||
model_type: ModelType
|
||||
|
||||
|
||||
class CustomConfiguration(BaseModel):
|
||||
"""
|
||||
Model class for provider custom configuration.
|
||||
|
|
@ -123,6 +133,7 @@ class CustomConfiguration(BaseModel):
|
|||
|
||||
provider: Optional[CustomProviderConfiguration] = None
|
||||
models: list[CustomModelConfiguration] = []
|
||||
can_added_models: list[UnaddedModelConfiguration] = []
|
||||
|
||||
|
||||
class ModelLoadBalancingConfiguration(BaseModel):
|
||||
|
|
@ -144,6 +155,7 @@ class ModelSettings(BaseModel):
|
|||
model: str
|
||||
model_type: ModelType
|
||||
enabled: bool = True
|
||||
load_balancing_enabled: bool = False
|
||||
load_balancing_configs: list[ModelLoadBalancingConfiguration] = []
|
||||
|
||||
# pydantic configs
|
||||
|
|
|
|||
|
|
@ -43,9 +43,9 @@ class APIBasedExtensionRequestor:
|
|||
timeout=self.timeout,
|
||||
proxies=proxies,
|
||||
)
|
||||
except requests.exceptions.Timeout:
|
||||
except requests.Timeout:
|
||||
raise ValueError("request timeout")
|
||||
except requests.exceptions.ConnectionError:
|
||||
except requests.ConnectionError:
|
||||
raise ValueError("request connection error")
|
||||
|
||||
if response.status_code != 200:
|
||||
|
|
|
|||
|
|
@ -91,7 +91,7 @@ class Extensible:
|
|||
|
||||
# Find extension class
|
||||
extension_class = None
|
||||
for name, obj in vars(mod).items():
|
||||
for obj in vars(mod).values():
|
||||
if isinstance(obj, type) and issubclass(obj, cls) and obj != cls:
|
||||
extension_class = obj
|
||||
break
|
||||
|
|
@ -123,7 +123,7 @@ class Extensible:
|
|||
)
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
except Exception:
|
||||
logger.exception("Error scanning extensions")
|
||||
raise
|
||||
|
||||
|
|
|
|||
|
|
@ -41,9 +41,3 @@ class Extension:
|
|||
assert module_extension.extension_class is not None
|
||||
t: type = module_extension.extension_class
|
||||
return t
|
||||
|
||||
def validate_form_schema(self, module: ExtensionModule, extension_name: str, config: dict) -> None:
|
||||
module_extension = self.module_extension(module, extension_name)
|
||||
form_schema = module_extension.form_schema
|
||||
|
||||
# TODO validate form_schema
|
||||
|
|
|
|||
|
|
@ -22,7 +22,6 @@ class ExternalDataToolFactory:
|
|||
:param config: the form config data
|
||||
:return:
|
||||
"""
|
||||
code_based_extension.validate_form_schema(ExtensionModule.EXTERNAL_DATA_TOOL, name, config)
|
||||
extension_class = code_based_extension.extension_class(ExtensionModule.EXTERNAL_DATA_TOOL, name)
|
||||
# FIXME mypy issue here, figure out how to fix it
|
||||
extension_class.validate_config(tenant_id, config) # type: ignore
|
||||
|
|
|
|||
|
|
@ -42,7 +42,7 @@ def batch_fetch_plugin_manifests_ignore_deserialization_error(
|
|||
for plugin in response.json()["data"]["plugins"]:
|
||||
try:
|
||||
result.append(MarketplacePluginDeclaration(**plugin))
|
||||
except Exception as e:
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return result
|
||||
|
|
|
|||
|
|
@ -47,7 +47,7 @@ def get_subclasses_from_module(mod: ModuleType, parent_type: type) -> list[type]
|
|||
|
||||
|
||||
def load_single_subclass_from_source(
|
||||
*, module_name: str, script_path: AnyStr, parent_type: type, use_lazy_loader: bool = False
|
||||
*, module_name: str, script_path: str, parent_type: type, use_lazy_loader: bool = False
|
||||
) -> type:
|
||||
"""
|
||||
Load a single subclass from the source
|
||||
|
|
|
|||
|
|
@@ -1,7 +1,7 @@
import os
from collections import OrderedDict
from collections.abc import Callable
from typing import Any
from typing import TypeVar

from configs import dify_config
from core.tools.utils.yaml_utils import load_yaml_file

@@ -72,11 +72,14 @@ def pin_position_map(original_position_map: dict[str, int], pin_list: list[str])
return position_map


T = TypeVar("T")


def is_filtered(
include_set: set[str],
exclude_set: set[str],
data: Any,
name_func: Callable[[Any], str],
data: T,
name_func: Callable[[T], str],
) -> bool:
"""
Check if the object should be filtered out.

@@ -103,9 +106,9 @@ def is_filtered(

def sort_by_position_map(
position_map: dict[str, int],
data: list[Any],
name_func: Callable[[Any], str],
) -> list[Any]:
data: list[T],
name_func: Callable[[T], str],
):
"""
Sort the objects by the position map.
If the name of the object is not in the position map, it will be put at the end.

@@ -122,9 +125,9 @@ def sort_by_position_map(

def sort_to_dict_by_position_map(
position_map: dict[str, int],
data: list[Any],
name_func: Callable[[Any], str],
) -> OrderedDict[str, Any]:
data: list[T],
name_func: Callable[[T], str],
):
"""
Sort the objects into a ordered dict by the position map.
If the name of the object is not in the position map, it will be put at the end.

@@ -134,4 +137,4 @@ def sort_to_dict_by_position_map(
:return: an OrderedDict with the sorted pairs of name and object
"""
sorted_items = sort_by_position_map(position_map, data, name_func)
return OrderedDict([(name_func(item), item) for item in sorted_items])
return OrderedDict((name_func(item), item) for item in sorted_items)
|
||||
|
|
|
|||
|
|
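The position-helper hunks above swap `Any` for a `TypeVar` so the filter and sort helpers keep the element type. A minimal, self-contained sketch of the same pattern is shown below; the filtering rule and exact signatures are simplified assumptions, not the repository's verbatim implementation.

```python
from collections import OrderedDict
from collections.abc import Callable
from typing import TypeVar

T = TypeVar("T")


def is_filtered(include_set: set[str], exclude_set: set[str], data: T, name_func: Callable[[T], str]) -> bool:
    # Assumed semantics: drop the item when an include list exists and the name
    # is missing from it, or when the name is explicitly excluded.
    name = name_func(data)
    if include_set and name not in include_set:
        return True
    return name in exclude_set


def sort_by_position_map(position_map: dict[str, int], data: list[T], name_func: Callable[[T], str]) -> list[T]:
    # Names missing from the map sort after every pinned position.
    return sorted(data, key=lambda item: position_map.get(name_func(item), len(position_map)))


def sort_to_dict_by_position_map(
    position_map: dict[str, int], data: list[T], name_func: Callable[[T], str]
) -> OrderedDict[str, T]:
    sorted_items = sort_by_position_map(position_map, data, name_func)
    return OrderedDict((name_func(item), item) for item in sorted_items)
```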
@@ -5,7 +5,7 @@ import re
import threading
import time
import uuid
from typing import Any, Optional, cast
from typing import Any, Optional

from flask import current_app
from sqlalchemy import select

@@ -19,6 +19,7 @@ from core.model_runtime.entities.model_entities import ModelType
from core.rag.cleaner.clean_processor import CleanProcessor
from core.rag.datasource.keyword.keyword_factory import Keyword
from core.rag.docstore.dataset_docstore import DatasetDocumentStore
from core.rag.extractor.entity.datasource_type import DatasourceType
from core.rag.extractor.entity.extract_setting import ExtractSetting
from core.rag.index_processor.constant.index_type import IndexType
from core.rag.index_processor.index_processor_base import BaseIndexProcessor
@@ -269,7 +270,9 @@ class IndexingRunner:
tenant_id=tenant_id,
model_type=ModelType.TEXT_EMBEDDING,
)
preview_texts = [] # type: ignore
# keep separate, avoid union-list ambiguity
preview_texts: list[PreviewDetail] = []
qa_preview_texts: list[QAPreviewDetail] = []

total_segments = 0
index_type = doc_form
@@ -292,14 +295,14 @@ class IndexingRunner:
for document in documents:
if len(preview_texts) < 10:
if doc_form and doc_form == "qa_model":
preview_detail = QAPreviewDetail(
qa_detail = QAPreviewDetail(
question=document.page_content, answer=document.metadata.get("answer") or ""
)
preview_texts.append(preview_detail)
qa_preview_texts.append(qa_detail)
else:
preview_detail = PreviewDetail(content=document.page_content) # type: ignore
preview_detail = PreviewDetail(content=document.page_content)
if document.children:
preview_detail.child_chunks = [child.page_content for child in document.children] # type: ignore
preview_detail.child_chunks = [child.page_content for child in document.children]
preview_texts.append(preview_detail)

# delete image files and related db records
@@ -320,8 +323,8 @@ class IndexingRunner:
db.session.delete(image_file)

if doc_form and doc_form == "qa_model":
return IndexingEstimate(total_segments=total_segments * 20, qa_preview=preview_texts, preview=[])
return IndexingEstimate(total_segments=total_segments, preview=preview_texts) # type: ignore
return IndexingEstimate(total_segments=total_segments * 20, qa_preview=qa_preview_texts, preview=[])
return IndexingEstimate(total_segments=total_segments, preview=preview_texts)

def _extract(
self, index_processor: BaseIndexProcessor, dataset_document: DatasetDocument, process_rule: dict
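The two hunks above replace a single untyped preview list with separately typed `PreviewDetail` and `QAPreviewDetail` lists, which is what lets the `# type: ignore` comments disappear. A rough sketch of the idea, with field names inferred from the diff and models deliberately simplified:

```python
from pydantic import BaseModel, Field


class PreviewDetail(BaseModel):
    content: str
    child_chunks: list[str] | None = None


class QAPreviewDetail(BaseModel):
    question: str
    answer: str


class IndexingEstimate(BaseModel):
    total_segments: int
    preview: list[PreviewDetail] = Field(default_factory=list)
    qa_preview: list[QAPreviewDetail] | None = None


def build_estimate(
    doc_form: str,
    total_segments: int,
    preview_texts: list[PreviewDetail],
    qa_preview_texts: list[QAPreviewDetail],
) -> IndexingEstimate:
    # Keeping two homogeneous lists avoids the union-typed list that previously
    # required casts and "type: ignore" comments.
    if doc_form == "qa_model":
        return IndexingEstimate(total_segments=total_segments * 20, qa_preview=qa_preview_texts, preview=[])
    return IndexingEstimate(total_segments=total_segments, preview=preview_texts)
```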
@@ -340,7 +343,9 @@ class IndexingRunner:

if file_detail:
extract_setting = ExtractSetting(
datasource_type="upload_file", upload_file=file_detail, document_model=dataset_document.doc_form
datasource_type=DatasourceType.FILE.value,
upload_file=file_detail,
document_model=dataset_document.doc_form,
)
text_docs = index_processor.extract(extract_setting, process_rule_mode=process_rule["mode"])
elif dataset_document.data_source_type == "notion_import":
@@ -351,7 +356,7 @@ class IndexingRunner:
):
raise ValueError("no notion import info found")
extract_setting = ExtractSetting(
datasource_type="notion_import",
datasource_type=DatasourceType.NOTION.value,
notion_info={
"notion_workspace_id": data_source_info["notion_workspace_id"],
"notion_obj_id": data_source_info["notion_page_id"],
@@ -371,7 +376,7 @@ class IndexingRunner:
):
raise ValueError("no website import info found")
extract_setting = ExtractSetting(
datasource_type="website_crawl",
datasource_type=DatasourceType.WEBSITE.value,
website_info={
"provider": data_source_info["provider"],
"job_id": data_source_info["job_id"],
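The three `ExtractSetting` hunks above replace hard-coded datasource strings with enum members; the removed literals make the intended values visible. A hypothetical standalone equivalent of that enum, for illustration only:

```python
from enum import Enum


class DatasourceType(Enum):
    # Values mirror the string literals removed in the hunks above.
    FILE = "upload_file"
    NOTION = "notion_import"
    WEBSITE = "website_crawl"


# Call sites then read, for example:
# ExtractSetting(datasource_type=DatasourceType.FILE.value, upload_file=file_detail, document_model=doc_form)
```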
@ -394,7 +399,6 @@ class IndexingRunner:
|
|||
)
|
||||
|
||||
# replace doc id to document model id
|
||||
text_docs = cast(list[Document], text_docs)
|
||||
for text_doc in text_docs:
|
||||
if text_doc.metadata is not None:
|
||||
text_doc.metadata["document_id"] = dataset_document.id
|
||||
|
|
@ -422,6 +426,7 @@ class IndexingRunner:
|
|||
"""
|
||||
Get the NodeParser object according to the processing rule.
|
||||
"""
|
||||
character_splitter: TextSplitter
|
||||
if processing_rule_mode in ["custom", "hierarchical"]:
|
||||
# The user-defined segmentation rule
|
||||
max_segmentation_tokens_length = dify_config.INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH
|
||||
|
|
@ -448,7 +453,7 @@ class IndexingRunner:
|
|||
embedding_model_instance=embedding_model_instance,
|
||||
)
|
||||
|
||||
return character_splitter # type: ignore
|
||||
return character_splitter
|
||||
|
||||
def _split_to_documents_for_estimate(
|
||||
self, text_docs: list[Document], splitter: TextSplitter, processing_rule: DatasetProcessRule
|
||||
|
|
|
|||
|
|
@ -56,11 +56,8 @@ class LLMGenerator:
|
|||
prompts = [UserPromptMessage(content=prompt)]
|
||||
|
||||
with measure_time() as timer:
|
||||
response = cast(
|
||||
LLMResult,
|
||||
model_instance.invoke_llm(
|
||||
prompt_messages=list(prompts), model_parameters={"max_tokens": 500, "temperature": 1}, stream=False
|
||||
),
|
||||
response: LLMResult = model_instance.invoke_llm(
|
||||
prompt_messages=list(prompts), model_parameters={"max_tokens": 500, "temperature": 1}, stream=False
|
||||
)
|
||||
answer = cast(str, response.message.content)
|
||||
cleaned_answer = re.sub(r"^.*(\{.*\}).*$", r"\1", answer, flags=re.DOTALL)
|
||||
|
|
@ -69,7 +66,7 @@ class LLMGenerator:
|
|||
try:
|
||||
result_dict = json.loads(cleaned_answer)
|
||||
answer = result_dict["Your Output"]
|
||||
except json.JSONDecodeError as e:
|
||||
except json.JSONDecodeError:
|
||||
logger.exception("Failed to generate name after answer, use query instead")
|
||||
answer = query
|
||||
name = answer.strip()
|
||||
|
|
@ -113,13 +110,10 @@ class LLMGenerator:
|
|||
prompt_messages = [UserPromptMessage(content=prompt)]
|
||||
|
||||
try:
|
||||
response = cast(
|
||||
LLMResult,
|
||||
model_instance.invoke_llm(
|
||||
prompt_messages=list(prompt_messages),
|
||||
model_parameters={"max_tokens": 256, "temperature": 0},
|
||||
stream=False,
|
||||
),
|
||||
response: LLMResult = model_instance.invoke_llm(
|
||||
prompt_messages=list(prompt_messages),
|
||||
model_parameters={"max_tokens": 256, "temperature": 0},
|
||||
stream=False,
|
||||
)
|
||||
|
||||
text_content = response.message.get_text_content()
|
||||
|
|
@ -162,11 +156,8 @@ class LLMGenerator:
|
|||
)
|
||||
|
||||
try:
|
||||
response = cast(
|
||||
LLMResult,
|
||||
model_instance.invoke_llm(
|
||||
prompt_messages=list(prompt_messages), model_parameters=model_parameters, stream=False
|
||||
),
|
||||
response: LLMResult = model_instance.invoke_llm(
|
||||
prompt_messages=list(prompt_messages), model_parameters=model_parameters, stream=False
|
||||
)
|
||||
|
||||
rule_config["prompt"] = cast(str, response.message.content)
|
||||
|
|
@ -212,11 +203,8 @@ class LLMGenerator:
|
|||
try:
|
||||
try:
|
||||
# the first step to generate the task prompt
|
||||
prompt_content = cast(
|
||||
LLMResult,
|
||||
model_instance.invoke_llm(
|
||||
prompt_messages=list(prompt_messages), model_parameters=model_parameters, stream=False
|
||||
),
|
||||
prompt_content: LLMResult = model_instance.invoke_llm(
|
||||
prompt_messages=list(prompt_messages), model_parameters=model_parameters, stream=False
|
||||
)
|
||||
except InvokeError as e:
|
||||
error = str(e)
|
||||
|
|
@ -248,11 +236,8 @@ class LLMGenerator:
|
|||
statement_messages = [UserPromptMessage(content=statement_generate_prompt)]
|
||||
|
||||
try:
|
||||
parameter_content = cast(
|
||||
LLMResult,
|
||||
model_instance.invoke_llm(
|
||||
prompt_messages=list(parameter_messages), model_parameters=model_parameters, stream=False
|
||||
),
|
||||
parameter_content: LLMResult = model_instance.invoke_llm(
|
||||
prompt_messages=list(parameter_messages), model_parameters=model_parameters, stream=False
|
||||
)
|
||||
rule_config["variables"] = re.findall(r'"\s*([^"]+)\s*"', cast(str, parameter_content.message.content))
|
||||
except InvokeError as e:
|
||||
|
|
@ -260,11 +245,8 @@ class LLMGenerator:
|
|||
error_step = "generate variables"
|
||||
|
||||
try:
|
||||
statement_content = cast(
|
||||
LLMResult,
|
||||
model_instance.invoke_llm(
|
||||
prompt_messages=list(statement_messages), model_parameters=model_parameters, stream=False
|
||||
),
|
||||
statement_content: LLMResult = model_instance.invoke_llm(
|
||||
prompt_messages=list(statement_messages), model_parameters=model_parameters, stream=False
|
||||
)
|
||||
rule_config["opening_statement"] = cast(str, statement_content.message.content)
|
||||
except InvokeError as e:
|
||||
|
|
@ -307,11 +289,8 @@ class LLMGenerator:
|
|||
prompt_messages = [UserPromptMessage(content=prompt)]
|
||||
model_parameters = model_config.get("completion_params", {})
|
||||
try:
|
||||
response = cast(
|
||||
LLMResult,
|
||||
model_instance.invoke_llm(
|
||||
prompt_messages=list(prompt_messages), model_parameters=model_parameters, stream=False
|
||||
),
|
||||
response: LLMResult = model_instance.invoke_llm(
|
||||
prompt_messages=list(prompt_messages), model_parameters=model_parameters, stream=False
|
||||
)
|
||||
|
||||
generated_code = cast(str, response.message.content)
|
||||
|
|
@ -338,13 +317,10 @@ class LLMGenerator:
|
|||
|
||||
prompt_messages = [SystemPromptMessage(content=prompt), UserPromptMessage(content=query)]
|
||||
|
||||
response = cast(
|
||||
LLMResult,
|
||||
model_instance.invoke_llm(
|
||||
prompt_messages=prompt_messages,
|
||||
model_parameters={"temperature": 0.01, "max_tokens": 2000},
|
||||
stream=False,
|
||||
),
|
||||
response: LLMResult = model_instance.invoke_llm(
|
||||
prompt_messages=prompt_messages,
|
||||
model_parameters={"temperature": 0.01, "max_tokens": 2000},
|
||||
stream=False,
|
||||
)
|
||||
|
||||
answer = cast(str, response.message.content)
|
||||
|
|
@ -367,11 +343,8 @@ class LLMGenerator:
|
|||
model_parameters = model_config.get("model_parameters", {})
|
||||
|
||||
try:
|
||||
response = cast(
|
||||
LLMResult,
|
||||
model_instance.invoke_llm(
|
||||
prompt_messages=list(prompt_messages), model_parameters=model_parameters, stream=False
|
||||
),
|
||||
response: LLMResult = model_instance.invoke_llm(
|
||||
prompt_messages=list(prompt_messages), model_parameters=model_parameters, stream=False
|
||||
)
|
||||
|
||||
raw_content = response.message.content
|
||||
|
|
@ -555,11 +528,8 @@ class LLMGenerator:
|
|||
model_parameters = {"temperature": 0.4}
|
||||
|
||||
try:
|
||||
response = cast(
|
||||
LLMResult,
|
||||
model_instance.invoke_llm(
|
||||
prompt_messages=list(prompt_messages), model_parameters=model_parameters, stream=False
|
||||
),
|
||||
response: LLMResult = model_instance.invoke_llm(
|
||||
prompt_messages=list(prompt_messages), model_parameters=model_parameters, stream=False
|
||||
)
|
||||
|
||||
generated_raw = cast(str, response.message.content)
|
||||
|
|
|
|||
|
|
@ -101,7 +101,7 @@ def handle_callback(state_key: str, authorization_code: str) -> OAuthCallbackSta
|
|||
|
||||
def check_support_resource_discovery(server_url: str) -> tuple[bool, str]:
|
||||
"""Check if the server supports OAuth 2.0 Resource Discovery."""
|
||||
b_scheme, b_netloc, b_path, b_params, b_query, b_fragment = urlparse(server_url, "", True)
|
||||
b_scheme, b_netloc, b_path, _, b_query, b_fragment = urlparse(server_url, "", True)
|
||||
url_for_resource_discovery = f"{b_scheme}://{b_netloc}/.well-known/oauth-protected-resource{b_path}"
|
||||
if b_query:
|
||||
url_for_resource_discovery += f"?{b_query}"
|
||||
|
|
@ -117,7 +117,7 @@ def check_support_resource_discovery(server_url: str) -> tuple[bool, str]:
|
|||
else:
|
||||
return False, ""
|
||||
return False, ""
|
||||
except httpx.RequestError as e:
|
||||
except httpx.RequestError:
|
||||
# Not support resource discovery, fall back to well-known OAuth metadata
|
||||
return False, ""
|
||||
|
||||
|
|
|
|||
|
|
@ -246,6 +246,10 @@ class StreamableHTTPTransport:
|
|||
logger.debug("Received 202 Accepted")
|
||||
return
|
||||
|
||||
if response.status_code == 204:
|
||||
logger.debug("Received 204 No Content")
|
||||
return
|
||||
|
||||
if response.status_code == 404:
|
||||
if isinstance(message.root, JSONRPCRequest):
|
||||
self._send_session_terminated_error(
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@ import logging
|
|||
from collections.abc import Callable
|
||||
from contextlib import AbstractContextManager, ExitStack
|
||||
from types import TracebackType
|
||||
from typing import Any, Optional, cast
|
||||
from typing import Any, Optional
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from core.mcp.client.sse_client import sse_client
|
||||
|
|
@ -116,8 +116,7 @@ class MCPClient:
|
|||
|
||||
self._session_context = ClientSession(*streams)
|
||||
self._session = self._exit_stack.enter_context(self._session_context)
|
||||
session = cast(ClientSession, self._session)
|
||||
session.initialize()
|
||||
self._session.initialize()
|
||||
return
|
||||
|
||||
except MCPAuthError:
|
||||
|
|
|
|||
|
|
@ -258,5 +258,5 @@ def convert_input_form_to_parameters(
|
|||
parameters[item.variable]["type"] = "string"
|
||||
parameters[item.variable]["enum"] = item.options
|
||||
elif item.type == VariableEntityType.NUMBER:
|
||||
parameters[item.variable]["type"] = "float"
|
||||
parameters[item.variable]["type"] = "number"
|
||||
return parameters, required
|
||||
|
|
|
|||
|
|
@@ -7,7 +7,7 @@ This module provides the interface for invoking and authenticating various model

## Features

- Supports capability invocation for 5 types of models
- Supports capability invocation for 6 types of models

- `LLM` - LLM text completion, dialogue, pre-computed tokens capability
- `Text Embedding Model` - Text Embedding, pre-computed tokens capability
@@ -7,7 +7,7 @@

## Features

- Supports capability invocation for 5 model types
- Supports capability invocation for 6 model types

- `LLM` - LLM text completion, dialogue, and pre-computed tokens capability
- `Text Embedding Model` - text embedding and pre-computed tokens capability
@ -20,7 +20,6 @@ class ModerationFactory:
|
|||
:param config: the form config data
|
||||
:return:
|
||||
"""
|
||||
code_based_extension.validate_form_schema(ExtensionModule.MODERATION, name, config)
|
||||
extension_class = code_based_extension.extension_class(ExtensionModule.MODERATION, name)
|
||||
# FIXME: mypy error, try to fix it instead of using type: ignore
|
||||
extension_class.validate_config(tenant_id, config) # type: ignore
|
||||
|
|
|
|||
|
|
@ -135,7 +135,7 @@ class OutputModeration(BaseModel):
|
|||
|
||||
result: ModerationOutputsResult = moderation_factory.moderation_for_outputs(moderation_buffer)
|
||||
return result
|
||||
except Exception as e:
|
||||
except Exception:
|
||||
logger.exception("Moderation Output error, app_id: %s", app_id)
|
||||
|
||||
return None
|
||||
|
|
|
|||
|
|
@ -72,7 +72,7 @@ class TraceClient:
|
|||
else:
|
||||
logger.debug("AliyunTrace API check failed: Unexpected status code: %s", response.status_code)
|
||||
return False
|
||||
except requests.exceptions.RequestException as e:
|
||||
except requests.RequestException as e:
|
||||
logger.debug("AliyunTrace API check failed: %s", str(e))
|
||||
raise ValueError(f"AliyunTrace API check failed: {str(e)}")
|
||||
|
||||
|
|
|
|||
|
|
@ -849,7 +849,7 @@ class TraceQueueManager:
|
|||
if self.trace_instance:
|
||||
trace_task.app_id = self.app_id
|
||||
trace_manager_queue.put(trace_task)
|
||||
except Exception as e:
|
||||
except Exception:
|
||||
logger.exception("Error adding trace task, trace_type %s", trace_task.trace_type)
|
||||
finally:
|
||||
self.start_timer()
|
||||
|
|
@ -868,7 +868,7 @@ class TraceQueueManager:
|
|||
tasks = self.collect_tasks()
|
||||
if tasks:
|
||||
self.send_to_celery(tasks)
|
||||
except Exception as e:
|
||||
except Exception:
|
||||
logger.exception("Error processing trace tasks")
|
||||
|
||||
def start_timer(self):
|
||||
|
|
|
|||
|
|
@ -120,7 +120,7 @@ class WeaveDataTrace(BaseTraceInstance):
|
|||
workflow_attributes["trace_id"] = trace_id
|
||||
workflow_attributes["start_time"] = trace_info.start_time
|
||||
workflow_attributes["end_time"] = trace_info.end_time
|
||||
workflow_attributes["tags"] = ["workflow"]
|
||||
workflow_attributes["tags"] = ["dify_workflow"]
|
||||
|
||||
workflow_run = WeaveTraceModel(
|
||||
file_list=trace_info.file_list,
|
||||
|
|
@ -156,6 +156,9 @@ class WeaveDataTrace(BaseTraceInstance):
|
|||
workflow_run_id=trace_info.workflow_run_id
|
||||
)
|
||||
|
||||
# rearrange workflow_node_executions by starting time
|
||||
workflow_node_executions = sorted(workflow_node_executions, key=lambda x: x.created_at)
|
||||
|
||||
for node_execution in workflow_node_executions:
|
||||
node_execution_id = node_execution.id
|
||||
tenant_id = trace_info.tenant_id # Use from trace_info instead
|
||||
|
|
|
|||
|
|
@ -2,6 +2,7 @@ from collections.abc import Generator, Mapping
|
|||
from typing import Optional, Union
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from controllers.service_api.wraps import create_or_update_end_user_for_user_id
|
||||
from core.app.app_config.common.parameters_mapping import get_parameters_from_feature_dict
|
||||
|
|
@ -194,11 +195,12 @@ class PluginAppBackwardsInvocation(BaseBackwardsInvocation):
|
|||
"""
|
||||
get the user by user id
|
||||
"""
|
||||
stmt = select(EndUser).where(EndUser.id == user_id)
|
||||
user = db.session.scalar(stmt)
|
||||
if not user:
|
||||
stmt = select(Account).where(Account.id == user_id)
|
||||
user = db.session.scalar(stmt)
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
stmt = select(EndUser).where(EndUser.id == user_id)
|
||||
user = session.scalar(stmt)
|
||||
if not user:
|
||||
stmt = select(Account).where(Account.id == user_id)
|
||||
user = session.scalar(stmt)
|
||||
|
||||
if not user:
|
||||
raise ValueError("user not found")
|
||||
|
|
|
|||
|
|
@@ -4,7 +4,8 @@ import re
from collections.abc import Mapping
from typing import Any, Optional

from pydantic import BaseModel, Field, model_validator
from packaging.version import InvalidVersion, Version
from pydantic import BaseModel, Field, field_validator, model_validator
from werkzeug.exceptions import NotFound

from core.agent.plugin_entities import AgentStrategyProviderEntity
@@ -71,10 +72,21 @@ class PluginDeclaration(BaseModel):
endpoints: Optional[list[str]] = Field(default_factory=list[str])

class Meta(BaseModel):
minimum_dify_version: Optional[str] = Field(default=None, pattern=r"^\d{1,4}(\.\d{1,4}){1,3}(-\w{1,16})?$")
minimum_dify_version: Optional[str] = Field(default=None)
version: Optional[str] = Field(default=None)

version: str = Field(..., pattern=r"^\d{1,4}(\.\d{1,4}){1,3}(-\w{1,16})?$")
@field_validator("minimum_dify_version")
@classmethod
def validate_minimum_dify_version(cls, v: Optional[str]) -> Optional[str]:
if v is None:
return v
try:
Version(v)
return v
except InvalidVersion as e:
raise ValueError(f"Invalid version format: {v}") from e

version: str = Field(...)
author: Optional[str] = Field(..., pattern=r"^[a-zA-Z0-9_-]{1,64}$")
name: str = Field(..., pattern=r"^[a-z0-9_-]{1,128}$")
description: I18nObject
@@ -94,6 +106,15 @@ class PluginDeclaration(BaseModel):
agent_strategy: Optional[AgentStrategyProviderEntity] = None
meta: Meta

@field_validator("version")
@classmethod
def validate_version(cls, v: str) -> str:
try:
Version(v)
return v
except InvalidVersion as e:
raise ValueError(f"Invalid version format: {v}") from e

@model_validator(mode="before")
@classmethod
def validate_category(cls, values: dict) -> dict:
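The validators added above move plugin version checking onto `packaging`'s parser instead of regex patterns on the fields. The pattern can be reproduced in isolation roughly like this; the `PluginMeta` class name is illustrative and the model is reduced to the two relevant fields:

```python
from typing import Optional

from packaging.version import InvalidVersion, Version
from pydantic import BaseModel, Field, field_validator


class PluginMeta(BaseModel):
    version: Optional[str] = Field(default=None)
    minimum_dify_version: Optional[str] = Field(default=None)

    @field_validator("version", "minimum_dify_version")
    @classmethod
    def validate_version_fields(cls, v: Optional[str]) -> Optional[str]:
        if v is None:
            return v
        try:
            Version(v)  # raises InvalidVersion for malformed strings such as "1..0"
            return v
        except InvalidVersion as e:
            raise ValueError(f"Invalid version format: {v}") from e
```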
@ -64,7 +64,7 @@ class BasePluginClient:
|
|||
response = requests.request(
|
||||
method=method, url=str(url), headers=headers, data=data, params=params, stream=stream, files=files
|
||||
)
|
||||
except requests.exceptions.ConnectionError:
|
||||
except requests.ConnectionError:
|
||||
logger.exception("Request to Plugin Daemon Service failed")
|
||||
raise PluginDaemonInnerError(code=-500, message="Request to Plugin Daemon Service failed")
|
||||
|
||||
|
|
|
|||
|
|
@ -1,8 +1,9 @@
|
|||
import contextlib
|
||||
import json
|
||||
from collections import defaultdict
|
||||
from collections.abc import Sequence
|
||||
from json import JSONDecodeError
|
||||
from typing import Any, Optional
|
||||
from typing import Any, Optional, cast
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
|
|
@ -22,6 +23,7 @@ from core.entities.provider_entities import (
|
|||
QuotaConfiguration,
|
||||
QuotaUnit,
|
||||
SystemConfiguration,
|
||||
UnaddedModelConfiguration,
|
||||
)
|
||||
from core.helper import encrypter
|
||||
from core.helper.model_provider_cache import ProviderCredentialsCache, ProviderCredentialsCacheType
|
||||
|
|
@ -148,6 +150,9 @@ class ProviderManager:
|
|||
tenant_id
|
||||
)
|
||||
|
||||
# Get All provider model credentials
|
||||
provider_name_to_provider_model_credentials_dict = self._get_all_provider_model_credentials(tenant_id)
|
||||
|
||||
provider_configurations = ProviderConfigurations(tenant_id=tenant_id)
|
||||
|
||||
# Construct ProviderConfiguration objects for each provider
|
||||
|
|
@ -169,10 +174,18 @@ class ProviderManager:
|
|||
provider_model_records.extend(
|
||||
provider_name_to_provider_model_records_dict.get(provider_id_entity.provider_name, [])
|
||||
)
|
||||
provider_model_credentials = provider_name_to_provider_model_credentials_dict.get(
|
||||
provider_entity.provider, []
|
||||
)
|
||||
provider_id_entity = ModelProviderID(provider_name)
|
||||
if provider_id_entity.is_langgenius():
|
||||
provider_model_credentials.extend(
|
||||
provider_name_to_provider_model_credentials_dict.get(provider_id_entity.provider_name, [])
|
||||
)
|
||||
|
||||
# Convert to custom configuration
|
||||
custom_configuration = self._to_custom_configuration(
|
||||
tenant_id, provider_entity, provider_records, provider_model_records
|
||||
tenant_id, provider_entity, provider_records, provider_model_records, provider_model_credentials
|
||||
)
|
||||
|
||||
# Convert to system configuration
|
||||
|
|
@ -451,6 +464,24 @@ class ProviderManager:
|
|||
)
|
||||
return provider_name_to_provider_model_settings_dict
|
||||
|
||||
@staticmethod
|
||||
def _get_all_provider_model_credentials(tenant_id: str) -> dict[str, list[ProviderModelCredential]]:
|
||||
"""
|
||||
Get All provider model credentials of the workspace.
|
||||
|
||||
:param tenant_id: workspace id
|
||||
:return:
|
||||
"""
|
||||
provider_name_to_provider_model_credentials_dict = defaultdict(list)
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
stmt = select(ProviderModelCredential).where(ProviderModelCredential.tenant_id == tenant_id)
|
||||
provider_model_credentials = session.scalars(stmt)
|
||||
for provider_model_credential in provider_model_credentials:
|
||||
provider_name_to_provider_model_credentials_dict[provider_model_credential.provider_name].append(
|
||||
provider_model_credential
|
||||
)
|
||||
return provider_name_to_provider_model_credentials_dict
|
||||
|
||||
@staticmethod
|
||||
def _get_all_provider_load_balancing_configs(tenant_id: str) -> dict[str, list[LoadBalancingModelConfig]]:
|
||||
"""
|
||||
|
|
@ -613,6 +644,7 @@ class ProviderManager:
|
|||
provider_entity: ProviderEntity,
|
||||
provider_records: list[Provider],
|
||||
provider_model_records: list[ProviderModel],
|
||||
provider_model_credentials: list[ProviderModelCredential],
|
||||
) -> CustomConfiguration:
|
||||
"""
|
||||
Convert to custom configuration.
|
||||
|
|
@ -623,6 +655,41 @@ class ProviderManager:
|
|||
:param provider_model_records: provider model records
|
||||
:return:
|
||||
"""
|
||||
# Get custom provider configuration
|
||||
custom_provider_configuration = self._get_custom_provider_configuration(
|
||||
tenant_id, provider_entity, provider_records
|
||||
)
|
||||
|
||||
# Get custom models which have not been added to the model list yet
|
||||
unadded_models = self._get_can_added_models(provider_model_records, provider_model_credentials)
|
||||
|
||||
# Get custom model configurations
|
||||
custom_model_configurations = self._get_custom_model_configurations(
|
||||
tenant_id, provider_entity, provider_model_records, unadded_models, provider_model_credentials
|
||||
)
|
||||
|
||||
can_added_models = [
|
||||
UnaddedModelConfiguration(model=model["model"], model_type=model["model_type"]) for model in unadded_models
|
||||
]
|
||||
|
||||
return CustomConfiguration(
|
||||
provider=custom_provider_configuration,
|
||||
models=custom_model_configurations,
|
||||
can_added_models=can_added_models,
|
||||
)
|
||||
|
||||
def _get_custom_provider_configuration(
|
||||
self, tenant_id: str, provider_entity: ProviderEntity, provider_records: list[Provider]
|
||||
) -> CustomProviderConfiguration | None:
|
||||
"""Get custom provider configuration."""
|
||||
# Find custom provider record (non-system)
|
||||
custom_provider_record = next(
|
||||
(record for record in provider_records if record.provider_type != ProviderType.SYSTEM.value), None
|
||||
)
|
||||
|
||||
if not custom_provider_record:
|
||||
return None
|
||||
|
||||
# Get provider credential secret variables
|
||||
provider_credential_secret_variables = self._extract_secret_variables(
|
||||
provider_entity.provider_credential_schema.credential_form_schemas
|
||||
|
|
@ -630,113 +697,98 @@ class ProviderManager:
|
|||
else []
|
||||
)
|
||||
|
||||
# Get custom provider record
|
||||
custom_provider_record = None
|
||||
for provider_record in provider_records:
|
||||
if provider_record.provider_type == ProviderType.SYSTEM.value:
|
||||
continue
|
||||
# Get and decrypt provider credentials
|
||||
provider_credentials = self._get_and_decrypt_credentials(
|
||||
tenant_id=tenant_id,
|
||||
record_id=custom_provider_record.id,
|
||||
encrypted_config=custom_provider_record.encrypted_config,
|
||||
secret_variables=provider_credential_secret_variables,
|
||||
cache_type=ProviderCredentialsCacheType.PROVIDER,
|
||||
is_provider=True,
|
||||
)
|
||||
|
||||
custom_provider_record = provider_record
|
||||
return CustomProviderConfiguration(
|
||||
credentials=provider_credentials,
|
||||
current_credential_name=custom_provider_record.credential_name,
|
||||
current_credential_id=custom_provider_record.credential_id,
|
||||
available_credentials=self.get_provider_available_credentials(
|
||||
tenant_id, custom_provider_record.provider_name
|
||||
),
|
||||
)
|
||||
|
||||
# Get custom provider credentials
|
||||
custom_provider_configuration = None
|
||||
if custom_provider_record:
|
||||
provider_credentials_cache = ProviderCredentialsCache(
|
||||
tenant_id=tenant_id,
|
||||
identity_id=custom_provider_record.id,
|
||||
cache_type=ProviderCredentialsCacheType.PROVIDER,
|
||||
)
|
||||
def _get_can_added_models(
|
||||
self, provider_model_records: list[ProviderModel], all_model_credentials: Sequence[ProviderModelCredential]
|
||||
) -> list[dict]:
|
||||
"""Get the custom models and credentials from enterprise version which haven't add to the model list"""
|
||||
existing_model_set = {(record.model_name, record.model_type) for record in provider_model_records}
|
||||
|
||||
# Get cached provider credentials
|
||||
cached_provider_credentials = provider_credentials_cache.get()
|
||||
# Get not added custom models credentials
|
||||
not_added_custom_models_credentials = [
|
||||
credential
|
||||
for credential in all_model_credentials
|
||||
if (credential.model_name, credential.model_type) not in existing_model_set
|
||||
]
|
||||
|
||||
if not cached_provider_credentials:
|
||||
try:
|
||||
# fix origin data
|
||||
if custom_provider_record.encrypted_config is None:
|
||||
provider_credentials = {}
|
||||
elif not custom_provider_record.encrypted_config.startswith("{"):
|
||||
provider_credentials = {"openai_api_key": custom_provider_record.encrypted_config}
|
||||
else:
|
||||
provider_credentials = json.loads(custom_provider_record.encrypted_config)
|
||||
except JSONDecodeError:
|
||||
provider_credentials = {}
|
||||
# Group credentials by model
|
||||
model_to_credentials = defaultdict(list)
|
||||
for credential in not_added_custom_models_credentials:
|
||||
model_to_credentials[(credential.model_name, credential.model_type)].append(credential)
|
||||
|
||||
# Get decoding rsa key and cipher for decrypting credentials
|
||||
if self.decoding_rsa_key is None or self.decoding_cipher_rsa is None:
|
||||
self.decoding_rsa_key, self.decoding_cipher_rsa = encrypter.get_decrypt_decoding(tenant_id)
|
||||
return [
|
||||
{
|
||||
"model": model_key[0],
|
||||
"model_type": ModelType.value_of(model_key[1]),
|
||||
"available_model_credentials": [
|
||||
CredentialConfiguration(credential_id=cred.id, credential_name=cred.credential_name)
|
||||
for cred in creds
|
||||
],
|
||||
}
|
||||
for model_key, creds in model_to_credentials.items()
|
||||
]
|
||||
|
||||
for variable in provider_credential_secret_variables:
|
||||
if variable in provider_credentials:
|
||||
with contextlib.suppress(ValueError):
|
||||
provider_credentials[variable] = encrypter.decrypt_token_with_decoding(
|
||||
provider_credentials.get(variable) or "", # type: ignore
|
||||
self.decoding_rsa_key,
|
||||
self.decoding_cipher_rsa,
|
||||
)
|
||||
|
||||
# cache provider credentials
|
||||
provider_credentials_cache.set(credentials=provider_credentials)
|
||||
else:
|
||||
provider_credentials = cached_provider_credentials
|
||||
|
||||
custom_provider_configuration = CustomProviderConfiguration(
|
||||
credentials=provider_credentials,
|
||||
current_credential_name=custom_provider_record.credential_name,
|
||||
current_credential_id=custom_provider_record.credential_id,
|
||||
available_credentials=self.get_provider_available_credentials(
|
||||
tenant_id, custom_provider_record.provider_name
|
||||
),
|
||||
)
|
||||
|
||||
# Get provider model credential secret variables
|
||||
def _get_custom_model_configurations(
|
||||
self,
|
||||
tenant_id: str,
|
||||
provider_entity: ProviderEntity,
|
||||
provider_model_records: list[ProviderModel],
|
||||
can_added_models: list[dict],
|
||||
all_model_credentials: Sequence[ProviderModelCredential],
|
||||
) -> list[CustomModelConfiguration]:
|
||||
"""Get custom model configurations."""
|
||||
# Get model credential secret variables
|
||||
model_credential_secret_variables = self._extract_secret_variables(
|
||||
provider_entity.model_credential_schema.credential_form_schemas
|
||||
if provider_entity.model_credential_schema
|
||||
else []
|
||||
)
|
||||
|
||||
# Get custom provider model credentials
|
||||
# Create credentials lookup for efficient access
|
||||
credentials_map = defaultdict(list)
|
||||
for credential in all_model_credentials:
|
||||
credentials_map[(credential.model_name, credential.model_type)].append(credential)
|
||||
|
||||
custom_model_configurations = []
|
||||
|
||||
# Process existing model records
|
||||
for provider_model_record in provider_model_records:
|
||||
available_model_credentials = self.get_provider_model_available_credentials(
|
||||
tenant_id,
|
||||
provider_model_record.provider_name,
|
||||
provider_model_record.model_name,
|
||||
provider_model_record.model_type,
|
||||
# Use pre-fetched credentials instead of individual database calls
|
||||
available_model_credentials = [
|
||||
CredentialConfiguration(credential_id=cred.id, credential_name=cred.credential_name)
|
||||
for cred in credentials_map.get(
|
||||
(provider_model_record.model_name, provider_model_record.model_type), []
|
||||
)
|
||||
]
|
||||
|
||||
# Get and decrypt model credentials
|
||||
provider_model_credentials = self._get_and_decrypt_credentials(
|
||||
tenant_id=tenant_id,
|
||||
record_id=provider_model_record.id,
|
||||
encrypted_config=provider_model_record.encrypted_config,
|
||||
secret_variables=model_credential_secret_variables,
|
||||
cache_type=ProviderCredentialsCacheType.MODEL,
|
||||
is_provider=False,
|
||||
)
|
||||
|
||||
provider_model_credentials_cache = ProviderCredentialsCache(
|
||||
tenant_id=tenant_id, identity_id=provider_model_record.id, cache_type=ProviderCredentialsCacheType.MODEL
|
||||
)
|
||||
|
||||
# Get cached provider model credentials
|
||||
cached_provider_model_credentials = provider_model_credentials_cache.get()
|
||||
|
||||
if not cached_provider_model_credentials and provider_model_record.encrypted_config:
|
||||
try:
|
||||
provider_model_credentials = json.loads(provider_model_record.encrypted_config)
|
||||
except JSONDecodeError:
|
||||
continue
|
||||
|
||||
# Get decoding rsa key and cipher for decrypting credentials
|
||||
if self.decoding_rsa_key is None or self.decoding_cipher_rsa is None:
|
||||
self.decoding_rsa_key, self.decoding_cipher_rsa = encrypter.get_decrypt_decoding(tenant_id)
|
||||
|
||||
for variable in model_credential_secret_variables:
|
||||
if variable in provider_model_credentials:
|
||||
with contextlib.suppress(ValueError):
|
||||
provider_model_credentials[variable] = encrypter.decrypt_token_with_decoding(
|
||||
provider_model_credentials.get(variable),
|
||||
self.decoding_rsa_key,
|
||||
self.decoding_cipher_rsa,
|
||||
)
|
||||
|
||||
# cache provider model credentials
|
||||
provider_model_credentials_cache.set(credentials=provider_model_credentials)
|
||||
else:
|
||||
provider_model_credentials = cached_provider_model_credentials
|
||||
|
||||
custom_model_configurations.append(
|
||||
CustomModelConfiguration(
|
||||
model=provider_model_record.model_name,
|
||||
|
|
@ -748,7 +800,71 @@ class ProviderManager:
|
|||
)
|
||||
)
|
||||
|
||||
return CustomConfiguration(provider=custom_provider_configuration, models=custom_model_configurations)
|
||||
# Add models that can be added
|
||||
for model in can_added_models:
|
||||
custom_model_configurations.append(
|
||||
CustomModelConfiguration(
|
||||
model=model["model"],
|
||||
model_type=model["model_type"],
|
||||
credentials=None,
|
||||
current_credential_id=None,
|
||||
current_credential_name=None,
|
||||
available_model_credentials=model["available_model_credentials"],
|
||||
unadded_to_model_list=True,
|
||||
)
|
||||
)
|
||||
|
||||
return custom_model_configurations
|
||||
|
||||
def _get_and_decrypt_credentials(
|
||||
self,
|
||||
tenant_id: str,
|
||||
record_id: str,
|
||||
encrypted_config: str | None,
|
||||
secret_variables: list[str],
|
||||
cache_type: ProviderCredentialsCacheType,
|
||||
is_provider: bool = False,
|
||||
) -> dict:
|
||||
"""Get and decrypt credentials with caching."""
|
||||
credentials_cache = ProviderCredentialsCache(
|
||||
tenant_id=tenant_id,
|
||||
identity_id=record_id,
|
||||
cache_type=cache_type,
|
||||
)
|
||||
|
||||
# Try to get from cache first
|
||||
cached_credentials = credentials_cache.get()
|
||||
if cached_credentials:
|
||||
return cached_credentials
|
||||
|
||||
# Parse encrypted config
|
||||
if not encrypted_config:
|
||||
return {}
|
||||
|
||||
if is_provider and not encrypted_config.startswith("{"):
|
||||
return {"openai_api_key": encrypted_config}
|
||||
|
||||
try:
|
||||
credentials = cast(dict, json.loads(encrypted_config))
|
||||
except JSONDecodeError:
|
||||
return {}
|
||||
|
||||
# Decrypt secret variables
|
||||
if self.decoding_rsa_key is None or self.decoding_cipher_rsa is None:
|
||||
self.decoding_rsa_key, self.decoding_cipher_rsa = encrypter.get_decrypt_decoding(tenant_id)
|
||||
|
||||
for variable in secret_variables:
|
||||
if variable in credentials:
|
||||
with contextlib.suppress(ValueError):
|
||||
credentials[variable] = encrypter.decrypt_token_with_decoding(
|
||||
credentials.get(variable) or "",
|
||||
self.decoding_rsa_key,
|
||||
self.decoding_cipher_rsa,
|
||||
)
|
||||
|
||||
# Cache the decrypted credentials
|
||||
credentials_cache.set(credentials=credentials)
|
||||
return credentials
|
||||
|
||||
def _to_system_configuration(
|
||||
self, tenant_id: str, provider_entity: ProviderEntity, provider_records: list[Provider]
|
||||
|
|
@ -956,18 +1072,6 @@ class ProviderManager:
|
|||
load_balancing_model_config.model_name == provider_model_setting.model_name
|
||||
and load_balancing_model_config.model_type == provider_model_setting.model_type
|
||||
):
|
||||
if load_balancing_model_config.name == "__delete__":
|
||||
# to calculate current model whether has invalidate lb configs
|
||||
load_balancing_configs.append(
|
||||
ModelLoadBalancingConfiguration(
|
||||
id=load_balancing_model_config.id,
|
||||
name=load_balancing_model_config.name,
|
||||
credentials={},
|
||||
credential_source_type=load_balancing_model_config.credential_source_type,
|
||||
)
|
||||
)
|
||||
continue
|
||||
|
||||
if not load_balancing_model_config.enabled:
|
||||
continue
|
||||
|
||||
|
|
@ -1033,6 +1137,7 @@ class ProviderManager:
|
|||
model=provider_model_setting.model_name,
|
||||
model_type=ModelType.value_of(provider_model_setting.model_type),
|
||||
enabled=provider_model_setting.enabled,
|
||||
load_balancing_enabled=provider_model_setting.load_balancing_enabled,
|
||||
load_balancing_configs=load_balancing_configs if len(load_balancing_configs) > 1 else [],
|
||||
)
|
||||
)
|
||||
|
|
|
|||
|
|
@ -192,8 +192,8 @@ class AnalyticdbVectorOpenAPI:
|
|||
collection=self._collection_name,
|
||||
metrics=self.config.metrics,
|
||||
include_values=True,
|
||||
vector=None,
|
||||
content=None,
|
||||
vector=None, # ty: ignore [invalid-argument-type]
|
||||
content=None, # ty: ignore [invalid-argument-type]
|
||||
top_k=1,
|
||||
filter=f"ref_doc_id='{id}'",
|
||||
)
|
||||
|
|
@ -211,7 +211,7 @@ class AnalyticdbVectorOpenAPI:
|
|||
namespace=self.config.namespace,
|
||||
namespace_password=self.config.namespace_password,
|
||||
collection=self._collection_name,
|
||||
collection_data=None,
|
||||
collection_data=None, # ty: ignore [invalid-argument-type]
|
||||
collection_data_filter=f"ref_doc_id IN {ids_str}",
|
||||
)
|
||||
self._client.delete_collection_data(request)
|
||||
|
|
@ -225,7 +225,7 @@ class AnalyticdbVectorOpenAPI:
|
|||
namespace=self.config.namespace,
|
||||
namespace_password=self.config.namespace_password,
|
||||
collection=self._collection_name,
|
||||
collection_data=None,
|
||||
collection_data=None, # ty: ignore [invalid-argument-type]
|
||||
collection_data_filter=f"metadata_ ->> '{key}' = '{value}'",
|
||||
)
|
||||
self._client.delete_collection_data(request)
|
||||
|
|
@ -249,7 +249,7 @@ class AnalyticdbVectorOpenAPI:
|
|||
include_values=kwargs.pop("include_values", True),
|
||||
metrics=self.config.metrics,
|
||||
vector=query_vector,
|
||||
content=None,
|
||||
content=None, # ty: ignore [invalid-argument-type]
|
||||
top_k=kwargs.get("top_k", 4),
|
||||
filter=where_clause,
|
||||
)
|
||||
|
|
@ -285,7 +285,7 @@ class AnalyticdbVectorOpenAPI:
|
|||
collection=self._collection_name,
|
||||
include_values=kwargs.pop("include_values", True),
|
||||
metrics=self.config.metrics,
|
||||
vector=None,
|
||||
vector=None, # ty: ignore [invalid-argument-type]
|
||||
content=query,
|
||||
top_k=kwargs.get("top_k", 4),
|
||||
filter=where_clause,
|
||||
|
|
|
|||
|
|
@ -228,7 +228,7 @@ class AnalyticdbVectorBySql:
|
|||
)
|
||||
documents = []
|
||||
for record in cur:
|
||||
id, vector, score, page_content, metadata = record
|
||||
_, vector, score, page_content, metadata = record
|
||||
if score >= score_threshold:
|
||||
metadata["score"] = score
|
||||
doc = Document(
|
||||
|
|
@ -260,7 +260,7 @@ class AnalyticdbVectorBySql:
|
|||
)
|
||||
documents = []
|
||||
for record in cur:
|
||||
id, vector, page_content, metadata, score = record
|
||||
_, vector, page_content, metadata, score = record
|
||||
metadata["score"] = score
|
||||
doc = Document(
|
||||
page_content=page_content,
|
||||
|
|
|
|||
|
|
@ -12,7 +12,7 @@ import clickzetta # type: ignore
|
|||
from pydantic import BaseModel, model_validator
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from clickzetta import Connection
|
||||
from clickzetta.connector.v0.connection import Connection # type: ignore
|
||||
|
||||
from configs import dify_config
|
||||
from core.rag.datasource.vdb.field import Field
|
||||
|
|
@ -701,7 +701,7 @@ class ClickzettaVector(BaseVector):
|
|||
len(data_rows),
|
||||
vector_dimension,
|
||||
)
|
||||
except (RuntimeError, ValueError, TypeError, ConnectionError) as e:
|
||||
except (RuntimeError, ValueError, TypeError, ConnectionError):
|
||||
logger.exception("Parameterized SQL execution failed for %d documents", len(data_rows))
|
||||
logger.exception("SQL template: %s", insert_sql)
|
||||
logger.exception("Sample data row: %s", data_rows[0] if data_rows else "None")
|
||||
|
|
@ -787,7 +787,7 @@ class ClickzettaVector(BaseVector):
|
|||
document_ids_filter = kwargs.get("document_ids_filter")
|
||||
|
||||
# Handle filter parameter from canvas (workflow)
|
||||
filter_param = kwargs.get("filter", {})
|
||||
_ = kwargs.get("filter", {})
|
||||
|
||||
# Build filter clause
|
||||
filter_clauses = []
|
||||
|
|
@ -879,7 +879,7 @@ class ClickzettaVector(BaseVector):
|
|||
document_ids_filter = kwargs.get("document_ids_filter")
|
||||
|
||||
# Handle filter parameter from canvas (workflow)
|
||||
filter_param = kwargs.get("filter", {})
|
||||
_ = kwargs.get("filter", {})
|
||||
|
||||
# Build filter clause
|
||||
filter_clauses = []
|
||||
|
|
@ -938,7 +938,7 @@ class ClickzettaVector(BaseVector):
|
|||
metadata = {}
|
||||
else:
|
||||
metadata = {}
|
||||
except (json.JSONDecodeError, TypeError) as e:
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
logger.exception("JSON parsing failed")
|
||||
# Fallback: extract document_id with regex
|
||||
|
||||
|
|
@ -956,7 +956,7 @@ class ClickzettaVector(BaseVector):
|
|||
metadata["score"] = 1.0 # Clickzetta doesn't provide relevance scores
|
||||
doc = Document(page_content=row[1], metadata=metadata)
|
||||
documents.append(doc)
|
||||
except (RuntimeError, ValueError, TypeError, ConnectionError) as e:
|
||||
except (RuntimeError, ValueError, TypeError, ConnectionError):
|
||||
logger.exception("Full-text search failed")
|
||||
# Fallback to LIKE search if full-text search fails
|
||||
return self._search_by_like(query, **kwargs)
|
||||
|
|
@ -978,7 +978,7 @@ class ClickzettaVector(BaseVector):
|
|||
document_ids_filter = kwargs.get("document_ids_filter")
|
||||
|
||||
# Handle filter parameter from canvas (workflow)
|
||||
filter_param = kwargs.get("filter", {})
|
||||
_ = kwargs.get("filter", {})
|
||||
|
||||
# Build filter clause
|
||||
filter_clauses = []
|
||||
|
|
|
|||
|
|
@ -212,10 +212,10 @@ class CouchbaseVector(BaseVector):
|
|||
|
||||
documents_to_insert = [
|
||||
{"text": text, "embedding": vector, "metadata": metadata}
|
||||
for id, text, vector, metadata in zip(uuids, texts, embeddings, metadatas)
|
||||
for _, text, vector, metadata in zip(uuids, texts, embeddings, metadatas)
|
||||
]
|
||||
for doc, id in zip(documents_to_insert, uuids):
|
||||
result = self._scope.collection(self._collection_name).upsert(id, doc)
|
||||
_ = self._scope.collection(self._collection_name).upsert(id, doc)
|
||||
|
||||
doc_ids.extend(uuids)
|
||||
|
||||
|
|
@ -241,7 +241,7 @@ class CouchbaseVector(BaseVector):
|
|||
"""
|
||||
try:
|
||||
self._cluster.query(query, named_parameters={"doc_ids": ids}).execute()
|
||||
except Exception as e:
|
||||
except Exception:
|
||||
logger.exception("Failed to delete documents, ids: %s", ids)
|
||||
|
||||
def delete_by_document_id(self, document_id: str):
|
||||
|
|
@ -306,7 +306,7 @@ class CouchbaseVector(BaseVector):
|
|||
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
|
||||
top_k = kwargs.get("top_k", 4)
|
||||
try:
|
||||
CBrequest = search.SearchRequest.create(search.QueryStringQuery("text:" + query))
|
||||
CBrequest = search.SearchRequest.create(search.QueryStringQuery("text:" + query)) # ty: ignore [too-many-positional-arguments]
|
||||
search_iter = self._scope.search(
|
||||
self._collection_name + "_search", CBrequest, SearchOptions(limit=top_k, fields=["*"])
|
||||
)
|
||||
|
|
|
|||
|
|
@ -138,7 +138,7 @@ class ElasticSearchVector(BaseVector):
|
|||
if not client.ping():
|
||||
raise ConnectionError("Failed to connect to Elasticsearch")
|
||||
|
||||
except requests.exceptions.ConnectionError as e:
|
||||
except requests.ConnectionError as e:
|
||||
raise ConnectionError(f"Vector database connection error: {str(e)}")
|
||||
except Exception as e:
|
||||
raise ConnectionError(f"Elasticsearch client initialization failed: {str(e)}")
|
||||
|
|
|
|||
|
|
@ -99,7 +99,7 @@ class MatrixoneVector(BaseVector):
|
|||
return client
|
||||
try:
|
||||
client.create_full_text_index()
|
||||
except Exception as e:
|
||||
except Exception:
|
||||
logger.exception("Failed to create full text index")
|
||||
redis_client.set(collection_exist_cache_key, 1, ex=3600)
|
||||
return client
|
||||
|
|
|
|||
|
|
@ -376,7 +376,12 @@ class MilvusVector(BaseVector):
|
|||
if config.token:
|
||||
client = MilvusClient(uri=config.uri, token=config.token, db_name=config.database)
|
||||
else:
|
||||
client = MilvusClient(uri=config.uri, user=config.user, password=config.password, db_name=config.database)
|
||||
client = MilvusClient(
|
||||
uri=config.uri,
|
||||
user=config.user or "",
|
||||
password=config.password or "",
|
||||
db_name=config.database,
|
||||
)
|
||||
return client
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -152,8 +152,8 @@ class MyScaleVector(BaseVector):
|
|||
)
|
||||
for r in self._client.query(sql).named_results()
|
||||
]
|
||||
except Exception as e:
|
||||
logger.exception("\033[91m\033[1m%s\033[0m \033[95m%s\033[0m", type(e), str(e)) # noqa:TRY401
|
||||
except Exception:
|
||||
logger.exception("Vector search operation failed")
|
||||
return []
|
||||
|
||||
def delete(self) -> None:
|
||||
|
|
|
|||
|
|
@ -197,7 +197,7 @@ class OpenSearchVector(BaseVector):
|
|||
|
||||
try:
|
||||
response = self._client.search(index=self._collection_name.lower(), body=query)
|
||||
except Exception as e:
|
||||
except Exception:
|
||||
logger.exception("Error executing vector search, query: %s", query)
|
||||
raise
|
||||
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
import json
|
||||
import logging
|
||||
import math
|
||||
from collections.abc import Iterable
|
||||
from typing import Any, Optional
|
||||
|
||||
import tablestore # type: ignore
|
||||
|
|
@ -71,7 +72,7 @@ class TableStoreVector(BaseVector):
|
|||
table_result = result.get_result_by_table(self._table_name)
|
||||
for item in table_result:
|
||||
if item.is_ok and item.row:
|
||||
kv = {k: v for k, v, t in item.row.attribute_columns}
|
||||
kv = {k: v for k, v, _ in item.row.attribute_columns}
|
||||
docs.append(
|
||||
Document(
|
||||
page_content=kv[Field.CONTENT_KEY.value], metadata=json.loads(kv[Field.METADATA_KEY.value])
|
||||
|
|
@ -102,9 +103,12 @@ class TableStoreVector(BaseVector):
|
|||
return uuids
|
||||
|
||||
def text_exists(self, id: str) -> bool:
|
||||
_, return_row, _ = self._tablestore_client.get_row(
|
||||
result = self._tablestore_client.get_row(
|
||||
table_name=self._table_name, primary_key=[("id", id)], columns_to_get=["id"]
|
||||
)
|
||||
assert isinstance(result, tuple | list)
|
||||
# Unpack the tuple result
|
||||
_, return_row, _ = result
|
||||
|
||||
return return_row is not None
|
||||
|
||||
|
|
@ -169,6 +173,7 @@ class TableStoreVector(BaseVector):
|
|||
|
||||
def _create_search_index_if_not_exist(self, dimension: int) -> None:
|
||||
search_index_list = self._tablestore_client.list_search_index(table_name=self._table_name)
|
||||
assert isinstance(search_index_list, Iterable)
|
||||
if self._index_name in [t[1] for t in search_index_list]:
|
||||
logger.info("Tablestore system index[%s] already exists", self._index_name)
|
||||
return None
|
||||
|
|
@ -212,6 +217,7 @@ class TableStoreVector(BaseVector):
|
|||
|
||||
def _delete_table_if_exist(self):
|
||||
search_index_list = self._tablestore_client.list_search_index(table_name=self._table_name)
|
||||
assert isinstance(search_index_list, Iterable)
|
||||
for resp_tuple in search_index_list:
|
||||
self._tablestore_client.delete_search_index(resp_tuple[0], resp_tuple[1])
|
||||
logger.info("Tablestore delete index[%s] successfully.", self._index_name)
|
||||
|
|
@ -269,7 +275,7 @@ class TableStoreVector(BaseVector):
|
|||
)
|
||||
|
||||
if search_response is not None:
|
||||
rows.extend([row[0][0][1] for row in search_response.rows])
|
||||
rows.extend([row[0][0][1] for row in list(search_response.rows)])
|
||||
|
||||
if search_response is None or search_response.next_token == b"":
|
||||
break
|
||||
|
|
|
|||
|
|
@ -32,9 +32,9 @@ class VikingDBConfig(BaseModel):
|
|||
scheme: str
|
||||
connection_timeout: int
|
||||
socket_timeout: int
|
||||
index_type: str = IndexType.HNSW
|
||||
distance: str = DistanceType.L2
|
||||
quant: str = QuantType.Float
|
||||
index_type: str = str(IndexType.HNSW)
|
||||
distance: str = str(DistanceType.L2)
|
||||
quant: str = str(QuantType.Float)
|
||||
|
||||
|
||||
class VikingDBVector(BaseVector):
|
||||
|
|
|
|||
|
|
@ -37,22 +37,15 @@ class WeaviateVector(BaseVector):
|
|||
self._attributes = attributes
|
||||
|
||||
def _init_client(self, config: WeaviateConfig) -> weaviate.Client:
|
||||
auth_config = weaviate.auth.AuthApiKey(api_key=config.api_key)
|
||||
auth_config = weaviate.AuthApiKey(api_key=config.api_key or "")
|
||||
|
||||
weaviate.connect.connection.has_grpc = False
|
||||
|
||||
# Fix to minimize the performance impact of the deprecation check in weaviate-client 3.24.0,
|
||||
# by changing the connection timeout to pypi.org from 1 second to 0.001 seconds.
|
||||
# TODO: This can be removed once weaviate-client is updated to 3.26.7 or higher,
|
||||
# which does not contain the deprecation check.
|
||||
if hasattr(weaviate.connect.connection, "PYPI_TIMEOUT"):
|
||||
weaviate.connect.connection.PYPI_TIMEOUT = 0.001
|
||||
weaviate.connect.connection.has_grpc = False # ty: ignore [unresolved-attribute]
|
||||
|
||||
try:
|
||||
client = weaviate.Client(
|
||||
url=config.endpoint, auth_client_secret=auth_config, timeout_config=(5, 60), startup_period=None
|
||||
)
|
||||
except requests.exceptions.ConnectionError:
|
||||
except requests.ConnectionError:
|
||||
raise ConnectionError("Vector database connection error")
|
||||
|
||||
client.batch.configure(
|
||||
|
|
|
|||
|
|
@ -107,7 +107,7 @@ class Blob(BaseModel):
|
|||
Blob instance
|
||||
"""
|
||||
if mime_type is None and guess_type:
|
||||
_mimetype = mimetypes.guess_type(path)[0] if guess_type else None
|
||||
_mimetype = mimetypes.guess_type(path)[0]
|
||||
else:
|
||||
_mimetype = mime_type
|
||||
# We do not load the data immediately, instead we treat the blob as a
|
||||
|
|
|
|||
|
|
@ -45,7 +45,7 @@ class ExtractProcessor:
|
|||
cls, upload_file: UploadFile, return_text: bool = False, is_automatic: bool = False
|
||||
) -> Union[list[Document], str]:
|
||||
extract_setting = ExtractSetting(
|
||||
datasource_type="upload_file", upload_file=upload_file, document_model="text_model"
|
||||
datasource_type=DatasourceType.FILE.value, upload_file=upload_file, document_model="text_model"
|
||||
)
|
||||
if return_text:
|
||||
delimiter = "\n"
|
||||
|
|
@ -76,7 +76,7 @@ class ExtractProcessor:
|
|||
# https://stackoverflow.com/questions/26541416/generate-temporary-file-names-without-creating-actual-file-in-python#comment90414256_26541521
|
||||
file_path = f"{temp_dir}/{tempfile.gettempdir()}{suffix}"
|
||||
Path(file_path).write_bytes(response.content)
|
||||
extract_setting = ExtractSetting(datasource_type="upload_file", document_model="text_model")
|
||||
extract_setting = ExtractSetting(datasource_type=DatasourceType.FILE.value, document_model="text_model")
|
||||
if return_text:
|
||||
delimiter = "\n"
|
||||
return delimiter.join(
|
||||
|
|
|
|||
|
|
@ -22,7 +22,6 @@ class FirecrawlApp:
|
|||
"formats": ["markdown"],
|
||||
"onlyMainContent": True,
|
||||
"timeout": 30000,
|
||||
"integration": "dify",
|
||||
}
|
||||
if params:
|
||||
json_data.update(params)
|
||||
|
|
@ -40,7 +39,7 @@ class FirecrawlApp:
|
|||
def crawl_url(self, url, params=None) -> str:
|
||||
# Documentation: https://docs.firecrawl.dev/api-reference/endpoint/crawl-post
|
||||
headers = self._prepare_headers()
|
||||
json_data = {"url": url, "integration": "dify"}
|
||||
json_data = {"url": url}
|
||||
if params:
|
||||
json_data.update(params)
|
||||
response = self._post_request(f"{self.base_url}/v1/crawl", json_data, headers)
|
||||
|
|
@ -138,7 +137,6 @@ class FirecrawlApp:
|
|||
"timeout": 60000,
|
||||
"ignoreInvalidURLs": False,
|
||||
"scrapeOptions": {},
|
||||
"integration": "dify",
|
||||
}
|
||||
if params:
|
||||
json_data.update(params)
|
||||
|
|
|
|||
|
|
@ -23,7 +23,7 @@ class UnstructuredWordExtractor(BaseExtractor):
|
|||
unstructured_version = tuple(int(x) for x in __unstructured_version__.split("."))
|
||||
# check the file extension
|
||||
try:
|
||||
import magic # noqa: F401
|
||||
import magic # noqa: F401 # pyright: ignore[reportUnusedImport]
|
||||
|
||||
is_doc = detect_filetype(self._file_path) == FileType.DOC
|
||||
except ImportError:
|
||||
|
|
|
|||
|
|
@ -30,6 +30,7 @@ class BaseIndexProcessor(ABC):
|
|||
def load(self, dataset: Dataset, documents: list[Document], with_keywords: bool = True, **kwargs):
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def clean(self, dataset: Dataset, node_ids: Optional[list[str]], with_keywords: bool = True, **kwargs):
|
||||
raise NotImplementedError
|
||||
|
||||
|
|
|
|||
|
|
@ -36,7 +36,7 @@ class ParentChildIndexProcessor(BaseIndexProcessor):
|
|||
if not process_rule.get("rules"):
|
||||
raise ValueError("No rules found in process rule.")
|
||||
rules = Rule(**process_rule.get("rules"))
|
||||
all_documents = [] # type: ignore
|
||||
all_documents: list[Document] = []
|
||||
if rules.parent_mode == ParentMode.PARAGRAPH:
|
||||
# Split the text documents into nodes.
|
||||
if not rules.segmentation:
|
||||
|
|
|
|||
|
|
@ -113,7 +113,7 @@ class QAIndexProcessor(BaseIndexProcessor):
|
|||
# Skip the first row
|
||||
df = pd.read_csv(file)
|
||||
text_docs = []
|
||||
for index, row in df.iterrows():
|
||||
for _, row in df.iterrows():
|
||||
data = Document(page_content=row.iloc[0], metadata={"answer": row.iloc[1]})
|
||||
text_docs.append(data)
|
||||
if len(text_docs) == 0:
|
||||
|
|
@ -183,7 +183,7 @@ class QAIndexProcessor(BaseIndexProcessor):
|
|||
qa_document.metadata["doc_hash"] = hash
|
||||
qa_documents.append(qa_document)
|
||||
format_documents.extend(qa_documents)
|
||||
except Exception as e:
|
||||
except Exception:
|
||||
logger.exception("Failed to format qa document")
|
||||
|
||||
all_qa_documents.extend(format_documents)
|
||||
|
|
|
|||
|
|
@ -9,7 +9,6 @@ from typing import Any, Optional, Union, cast
|
|||
from flask import Flask, current_app
|
||||
from sqlalchemy import Float, and_, or_, select, text
|
||||
from sqlalchemy import cast as sqlalchemy_cast
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.app.app_config.entities import (
|
||||
DatasetEntity,
|
||||
|
|
@ -526,7 +525,7 @@ class DatasetRetrieval:
|
|||
)
|
||||
child_chunk = db.session.scalar(child_chunk_stmt)
|
||||
if child_chunk:
|
||||
segment = (
|
||||
_ = (
|
||||
db.session.query(DocumentSegment)
|
||||
.where(DocumentSegment.id == child_chunk.segment_id)
|
||||
.update(
|
||||
|
|
@ -534,7 +533,6 @@ class DatasetRetrieval:
|
|||
synchronize_session=False,
|
||||
)
|
||||
)
|
||||
db.session.commit()
|
||||
else:
|
||||
query = db.session.query(DocumentSegment).where(
|
||||
DocumentSegment.index_node_id == document.metadata["doc_id"]
|
||||
|
|
@@ -594,9 +592,8 @@ class DatasetRetrieval:
    metadata_condition: Optional[MetadataCondition] = None,
):
    with flask_app.app_context():
    with Session(db.engine) as session:
        dataset_stmt = select(Dataset).where(Dataset.id == dataset_id)
        dataset = db.session.scalar(dataset_stmt)
        dataset_stmt = select(Dataset).where(Dataset.id == dataset_id)
        dataset = db.session.scalar(dataset_stmt)

        if not dataset:
            return []
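The hunk above moves the dataset lookup into an explicitly scoped SQLAlchemy Session built from the engine. A generic sketch of that pattern; the engine, model, and data here are placeholders rather than Dify's:

```python
from sqlalchemy import String, create_engine, select
from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column


class Base(DeclarativeBase):
    pass


class Dataset(Base):  # placeholder model, not the project's Dataset
    __tablename__ = "datasets"
    id: Mapped[str] = mapped_column(String, primary_key=True)
    name: Mapped[str] = mapped_column(String)


engine = create_engine("sqlite:///:memory:")
Base.metadata.create_all(engine)

with Session(engine) as session:
    session.add(Dataset(id="ds-1", name="docs"))
    session.commit()


def fetch_dataset(dataset_id: str) -> Dataset | None:
    # Scoping the Session to the unit of work keeps thread-local state out of the picture
    with Session(engine) as session:
        return session.scalar(select(Dataset).where(Dataset.id == dataset_id))


found = fetch_dataset("ds-1")
print(found.name if found else None)  # "docs"
```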
@@ -988,7 +985,7 @@ class DatasetRetrieval:
)

# handle invoke result
result_text, usage = self._handle_invoke_result(invoke_result=invoke_result)
result_text, _ = self._handle_invoke_result(invoke_result=invoke_result)

result_text_json = parse_and_check_json_markdown(result_text, [])
automatic_metadata_filters = []
@@ -1003,7 +1000,7 @@ class DatasetRetrieval:
            "condition": item.get("comparison_operator"),
        }
    )
except Exception as e:
except Exception:
    return None
return automatic_metadata_filters
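The surrounding method builds automatic metadata filters from LLM output that is expected to be JSON, possibly wrapped in a markdown fence, and bails out with None on any parse failure. A hedged stand-in for that parse-and-collect step; only the "condition"/"comparison_operator" mapping is taken from the diff, while the other field names and the fence-stripping helper are assumptions:

```python
import json
from typing import Any, Optional


def parse_json_markdown(text: str) -> Any:
    """Stand-in for the project's parse_and_check_json_markdown helper."""
    cleaned = text.strip().strip("`")          # drop any markdown fence characters
    cleaned = cleaned.removeprefix("json").strip()
    return json.loads(cleaned)


def extract_metadata_filters(result_text: str) -> Optional[list[dict]]:
    automatic_metadata_filters = []
    try:
        for item in parse_json_markdown(result_text):
            automatic_metadata_filters.append(
                {
                    "metadata_name": item.get("metadata_name"),
                    "value": item.get("value"),
                    "condition": item.get("comparison_operator"),
                }
            )
    except Exception:
        return None
    return automatic_metadata_filters


sample = '[{"metadata_name": "lang", "comparison_operator": "is", "value": "en"}]'
print(extract_metadata_filters(sample))
```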
@@ -19,5 +19,5 @@ class StructuredChatOutputParser:
    return ReactAction(response["action"], response.get("action_input", {}), text)
else:
    return ReactFinish({"output": text}, text)
except Exception as e:
except Exception:
    raise ValueError(f"Could not parse LLM output: {text}")
@@ -1,4 +1,4 @@
from typing import Union, cast
from typing import Union

from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
from core.model_manager import ModelInstance

@@ -28,18 +28,15 @@ class FunctionCallMultiDatasetRouter:
    SystemPromptMessage(content="You are a helpful AI assistant."),
    UserPromptMessage(content=query),
]
result = cast(
    LLMResult,
    model_instance.invoke_llm(
        prompt_messages=prompt_messages,
        tools=dataset_tools,
        stream=False,
        model_parameters={"temperature": 0.2, "top_p": 0.3, "max_tokens": 1500},
    ),
result: LLMResult = model_instance.invoke_llm(
    prompt_messages=prompt_messages,
    tools=dataset_tools,
    stream=False,
    model_parameters={"temperature": 0.2, "top_p": 0.3, "max_tokens": 1500},
)
if result.message.tool_calls:
    # get retrieval model config
    return result.message.tool_calls[0].function.name
return None
except Exception as e:
except Exception:
    return None
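This router change (and the similar one in the next file) swaps typing.cast for an annotated assignment: cast does nothing at runtime, so when the non-streaming call already returns the expected type, a plain annotation documents the same expectation and stays visible to pyright/mypy. A generic before/after with stand-in types:

```python
from typing import cast


class LLMResult:
    def __init__(self, text: str) -> None:
        self.text = text


def invoke_llm(stream: bool = False) -> LLMResult:
    # Stand-in: the non-streaming call already returns an LLMResult
    return LLMResult("answer")


# Before: cast() is a no-op at runtime and hides the real return type from the reader.
result_old = cast(LLMResult, invoke_llm(stream=False))

# After: an annotated assignment states the same thing and lets the type checker verify it.
result: LLMResult = invoke_llm(stream=False)

assert isinstance(result_old, LLMResult) and isinstance(result, LLMResult)
```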
@@ -1,5 +1,5 @@
from collections.abc import Generator, Sequence
from typing import Union, cast
from typing import Union

from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
from core.model_manager import ModelInstance

@@ -77,7 +77,7 @@ class ReactMultiDatasetRouter:
        user_id=user_id,
        tenant_id=tenant_id,
    )
except Exception as e:
except Exception:
    return None

def _react_invoke(

@@ -120,7 +120,7 @@ class ReactMultiDatasetRouter:
    memory=None,
    model_config=model_config,
)
result_text, usage = self._invoke_llm(
result_text, _ = self._invoke_llm(
    completion_param=model_config.parameters,
    model_instance=model_instance,
    prompt_messages=prompt_messages,

@@ -150,15 +150,12 @@ class ReactMultiDatasetRouter:
:param stop: stop
:return:
"""
invoke_result = cast(
    Generator[LLMResult, None, None],
    model_instance.invoke_llm(
        prompt_messages=prompt_messages,
        model_parameters=completion_param,
        stop=stop,
        stream=True,
        user=user_id,
    ),
invoke_result: Generator[LLMResult, None, None] = model_instance.invoke_llm(
    prompt_messages=prompt_messages,
    model_parameters=completion_param,
    stop=stop,
    stream=True,
    user=user_id,
)

# handle invoke result
@@ -119,7 +119,7 @@ class CeleryWorkflowExecutionRepository(WorkflowExecutionRepository):

    logger.debug("Queued async save for workflow execution: %s", execution.id_)

except Exception as e:
except Exception:
    logger.exception("Failed to queue save operation for execution %s", execution.id_)
    # In case of Celery failure, we could implement a fallback to synchronous save
    # For now, we'll re-raise the exception
@@ -142,7 +142,7 @@ class CeleryWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository):

    logger.debug("Cached and queued async save for workflow node execution: %s", execution.id)

except Exception as e:
except Exception:
    logger.exception("Failed to cache or queue save operation for node execution %s", execution.id)
    # In case of Celery failure, we could implement a fallback to synchronous save
    # For now, we'll re-raise the exception

@@ -185,6 +185,6 @@ class CeleryWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository):
    logger.debug("Retrieved %d workflow node executions for run %s from cache", len(result), workflow_run_id)
    return result

except Exception as e:
except Exception:
    logger.exception("Failed to get workflow node executions for run %s from cache", workflow_run_id)
    return []
@@ -7,9 +7,12 @@ import logging
from collections.abc import Sequence
from typing import Optional, Union

import psycopg2.errors
from sqlalchemy import UnaryExpression, asc, desc, select
from sqlalchemy.engine import Engine
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import sessionmaker
from tenacity import before_sleep_log, retry, retry_if_exception, stop_after_attempt

from core.model_runtime.utils.encoders import jsonable_encoder
from core.workflow.entities.workflow_node_execution import (

@@ -21,6 +24,7 @@ from core.workflow.nodes.enums import NodeType
from core.workflow.repositories.workflow_node_execution_repository import OrderConfig, WorkflowNodeExecutionRepository
from core.workflow.workflow_type_encoder import WorkflowRuntimeTypeConverter
from libs.helper import extract_tenant_id
from libs.uuid_utils import uuidv7
from models import (
    Account,
    CreatorUserRole,
@@ -186,18 +190,31 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository)
    db_model.finished_at = domain_model.finished_at
    return db_model

def _is_duplicate_key_error(self, exception: BaseException) -> bool:
    """Check if the exception is a duplicate key constraint violation."""
    return isinstance(exception, IntegrityError) and isinstance(exception.orig, psycopg2.errors.UniqueViolation)

def _regenerate_id_on_duplicate(
    self, execution: WorkflowNodeExecution, db_model: WorkflowNodeExecutionModel
) -> None:
    """Regenerate UUID v7 for both domain and database models when duplicate key detected."""
    new_id = str(uuidv7())
    logger.warning(
        "Duplicate key conflict for workflow node execution ID %s, generating new UUID v7: %s", db_model.id, new_id
    )
    db_model.id = new_id
    execution.id = new_id

def save(self, execution: WorkflowNodeExecution) -> None:
    """
    Save or update a NodeExecution domain entity to the database.

    This method serves as a domain-to-database adapter that:
    1. Converts the domain entity to its database representation
    2. Persists the database model using SQLAlchemy's merge operation
    2. Checks for existing records and updates or inserts accordingly
    3. Maintains proper multi-tenancy by including tenant context during conversion
    4. Updates the in-memory cache for faster subsequent lookups

    The method handles both creating new records and updating existing ones through
    SQLAlchemy's merge operation.
    5. Handles duplicate key conflicts by retrying with a new UUID v7

    Args:
        execution: The NodeExecution domain entity to persist
@@ -205,19 +222,62 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository)
    # Convert domain model to database model using tenant context and other attributes
    db_model = self.to_db_model(execution)

    # Create a new database session
    with self._session_factory() as session:
        # SQLAlchemy merge intelligently handles both insert and update operations
        # based on the presence of the primary key
        session.merge(db_model)
        session.commit()
    # Use tenacity for retry logic with duplicate key handling
    @retry(
        stop=stop_after_attempt(3),
        retry=retry_if_exception(self._is_duplicate_key_error),
        before_sleep=before_sleep_log(logger, logging.WARNING),
        reraise=True,
    )
    def _save_with_retry():
        try:
            self._persist_to_database(db_model)
        except IntegrityError as e:
            if self._is_duplicate_key_error(e):
                # Generate new UUID and retry
                self._regenerate_id_on_duplicate(execution, db_model)
                raise  # Let tenacity handle the retry
            else:
                # Different integrity error, don't retry
                logger.exception("Non-duplicate key integrity error while saving workflow node execution")
                raise

    # Update the in-memory cache for faster subsequent lookups
    # Only cache if we have a node_execution_id to use as the cache key
    try:
        _save_with_retry()

        # Update the in-memory cache after successful save
        if db_model.node_execution_id:
            logger.debug("Updating cache for node_execution_id: %s", db_model.node_execution_id)
            self._node_execution_cache[db_model.node_execution_id] = db_model

    except Exception:
        logger.exception("Failed to save workflow node execution after all retries")
        raise

def _persist_to_database(self, db_model: WorkflowNodeExecutionModel) -> None:
    """
    Persist the database model to the database.

    Checks if a record with the same ID exists and either updates it or creates a new one.

    Args:
        db_model: The database model to persist
    """
    with self._session_factory() as session:
        # Check if record already exists
        existing = session.get(WorkflowNodeExecutionModel, db_model.id)

        if existing:
            # Update existing record by copying all non-private attributes
            for key, value in db_model.__dict__.items():
                if not key.startswith("_"):
                    setattr(existing, key, value)
        else:
            # Add new record
            session.add(db_model)

        session.commit()

def get_db_models_by_workflow_run(
    self,
    workflow_run_id: str,
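The repository hunks above retry the insert up to three times, regenerating the primary key only when the IntegrityError wraps a PostgreSQL UniqueViolation. A condensed, self-contained sketch of that control flow; the in-memory "table" and the persist function stand in for _persist_to_database, and uuid4 stands in for the project's UUID v7 helper:

```python
import logging
import uuid

import psycopg2.errors
from sqlalchemy.exc import IntegrityError
from tenacity import before_sleep_log, retry, retry_if_exception, stop_after_attempt

logging.basicConfig(level=logging.WARNING)
logger = logging.getLogger(__name__)


def _is_duplicate_key_error(exception: BaseException) -> bool:
    return isinstance(exception, IntegrityError) and isinstance(exception.orig, psycopg2.errors.UniqueViolation)


# Toy "table": pretend the first generated id already exists, forcing one retry.
existing_ids: set[str] = set()
record = {"id": str(uuid.uuid4())}
existing_ids.add(record["id"])


def persist(row: dict) -> None:
    """Stand-in for _persist_to_database: raise the same error shape on a key collision."""
    if row["id"] in existing_ids:
        raise IntegrityError("INSERT INTO workflow_node_executions ...", {}, psycopg2.errors.UniqueViolation())
    existing_ids.add(row["id"])


@retry(
    stop=stop_after_attempt(3),
    retry=retry_if_exception(_is_duplicate_key_error),
    before_sleep=before_sleep_log(logger, logging.WARNING),
    reraise=True,
)
def save_with_retry() -> None:
    try:
        persist(record)
    except IntegrityError as e:
        if _is_duplicate_key_error(e):
            record["id"] = str(uuid.uuid4())  # the real code regenerates a UUID v7 here
            raise  # let tenacity schedule the next attempt
        logger.exception("Non-duplicate integrity error")
        raise


save_with_retry()
print("saved with id", record["id"])
```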
Some files were not shown because too many files have changed in this diff.