mirror of https://github.com/langgenius/dify.git
Merge branch 'feat/queue-based-graph-engine' into feat/rag-2
# Conflicts: # api/core/memory/token_buffer_memory.py # api/core/rag/extractor/notion_extractor.py # api/core/repositories/sqlalchemy_workflow_node_execution_repository.py # api/core/variables/variables.py # api/core/workflow/graph/graph.py # api/core/workflow/graph_engine/entities/event.py # api/services/dataset_service.py # web/app/components/app-sidebar/index.tsx # web/app/components/base/tag-management/selector.tsx # web/app/components/base/toast/index.tsx # web/app/components/datasets/create/website/index.tsx # web/app/components/datasets/create/website/jina-reader/base/options-wrap.tsx # web/app/components/workflow/header/version-history-button.tsx # web/app/components/workflow/hooks/use-inspect-vars-crud-common.ts # web/app/components/workflow/hooks/use-workflow-interactions.ts # web/app/components/workflow/panel/version-history-panel/index.tsx # web/service/base.ts
This commit is contained in:
commit
d4aed3df5c
|
|
@ -42,11 +42,7 @@ jobs:
|
|||
- name: Run Unit tests
|
||||
run: |
|
||||
uv run --project api bash dev/pytest/pytest_unit_tests.sh
|
||||
- name: Run ty check
|
||||
run: |
|
||||
cd api
|
||||
uv add --dev ty
|
||||
uv run ty check || true
|
||||
|
||||
- name: Run pyrefly check
|
||||
run: |
|
||||
cd api
|
||||
|
|
@ -66,15 +62,6 @@ jobs:
|
|||
- name: Run dify config tests
|
||||
run: uv run --project api dev/pytest/pytest_config_tests.py
|
||||
|
||||
- name: MyPy Cache
|
||||
uses: actions/cache@v4
|
||||
with:
|
||||
path: api/.mypy_cache
|
||||
key: mypy-${{ matrix.python-version }}-${{ runner.os }}-${{ hashFiles('api/uv.lock') }}
|
||||
|
||||
- name: Run MyPy Checks
|
||||
run: dev/mypy-check
|
||||
|
||||
- name: Set up dotenvs
|
||||
run: |
|
||||
cp docker/.env.example docker/.env
|
||||
|
|
|
|||
|
|
@ -26,6 +26,7 @@ jobs:
|
|||
- name: ast-grep
|
||||
run: |
|
||||
uvx --from ast-grep-cli sg --pattern 'db.session.query($WHATEVER).filter($HERE)' --rewrite 'db.session.query($WHATEVER).where($HERE)' -l py --update-all
|
||||
uvx --from ast-grep-cli sg --pattern 'session.query($WHATEVER).filter($HERE)' --rewrite 'session.query($WHATEVER).where($HERE)' -l py --update-all
|
||||
- name: mdformat
|
||||
run: |
|
||||
uvx mdformat .
|
||||
|
|
|
|||
|
|
@ -12,7 +12,6 @@ permissions:
|
|||
statuses: write
|
||||
contents: read
|
||||
|
||||
|
||||
jobs:
|
||||
python-style:
|
||||
name: Python Style
|
||||
|
|
@ -44,21 +43,14 @@ jobs:
|
|||
if: steps.changed-files.outputs.any_changed == 'true'
|
||||
run: uv sync --project api --dev
|
||||
|
||||
- name: Ruff check
|
||||
- name: Run Basedpyright Checks
|
||||
if: steps.changed-files.outputs.any_changed == 'true'
|
||||
run: |
|
||||
uv run --directory api ruff --version
|
||||
uv run --directory api ruff check ./
|
||||
uv run --directory api ruff format --check ./
|
||||
run: dev/basedpyright-check
|
||||
|
||||
- name: Dotenv check
|
||||
if: steps.changed-files.outputs.any_changed == 'true'
|
||||
run: uv run --project api dotenv-linter ./api/.env.example ./web/.env.example
|
||||
|
||||
- name: Lint hints
|
||||
if: failure()
|
||||
run: echo "Please run 'dev/reformat' to fix the fixable linting errors."
|
||||
|
||||
web-style:
|
||||
name: Web Style
|
||||
runs-on: ubuntu-latest
|
||||
|
|
@ -100,7 +92,9 @@ jobs:
|
|||
- name: Web style check
|
||||
if: steps.changed-files.outputs.any_changed == 'true'
|
||||
working-directory: ./web
|
||||
run: pnpm run lint
|
||||
run: |
|
||||
pnpm run lint
|
||||
pnpm run eslint
|
||||
|
||||
docker-compose-template:
|
||||
name: Docker Compose Template
|
||||
|
|
|
|||
|
|
@ -123,10 +123,12 @@ venv.bak/
|
|||
# mkdocs documentation
|
||||
/site
|
||||
|
||||
# mypy
|
||||
# type checking
|
||||
.mypy_cache/
|
||||
.dmypy.json
|
||||
dmypy.json
|
||||
pyrightconfig.json
|
||||
!api/pyrightconfig.json
|
||||
|
||||
# Pyre type checker
|
||||
.pyre/
|
||||
|
|
@ -195,7 +197,6 @@ sdks/python-client/dify_client.egg-info
|
|||
.vscode/*
|
||||
!.vscode/launch.json.template
|
||||
!.vscode/README.md
|
||||
pyrightconfig.json
|
||||
api/.vscode
|
||||
# vscode Code History Extension
|
||||
.history
|
||||
|
|
|
|||
|
|
@ -32,7 +32,7 @@ uv run --project api pytest tests/integration_tests/ # Integration tests
|
|||
./dev/reformat # Run all formatters and linters
|
||||
uv run --project api ruff check --fix ./ # Fix linting issues
|
||||
uv run --project api ruff format ./ # Format code
|
||||
uv run --project api mypy . # Type checking
|
||||
uv run --directory api basedpyright # Type checking
|
||||
```
|
||||
|
||||
### Frontend (Web)
|
||||
|
|
|
|||
60
Makefile
60
Makefile
|
|
@ -4,6 +4,48 @@ WEB_IMAGE=$(DOCKER_REGISTRY)/dify-web
|
|||
API_IMAGE=$(DOCKER_REGISTRY)/dify-api
|
||||
VERSION=latest
|
||||
|
||||
# Backend Development Environment Setup
|
||||
.PHONY: dev-setup prepare-docker prepare-web prepare-api
|
||||
|
||||
# Default dev setup target
|
||||
dev-setup: prepare-docker prepare-web prepare-api
|
||||
@echo "✅ Backend development environment setup complete!"
|
||||
|
||||
# Step 1: Prepare Docker middleware
|
||||
prepare-docker:
|
||||
@echo "🐳 Setting up Docker middleware..."
|
||||
@cp -n docker/middleware.env.example docker/middleware.env 2>/dev/null || echo "Docker middleware.env already exists"
|
||||
@cd docker && docker compose -f docker-compose.middleware.yaml --env-file middleware.env -p dify-middlewares-dev up -d
|
||||
@echo "✅ Docker middleware started"
|
||||
|
||||
# Step 2: Prepare web environment
|
||||
prepare-web:
|
||||
@echo "🌐 Setting up web environment..."
|
||||
@cp -n web/.env.example web/.env 2>/dev/null || echo "Web .env already exists"
|
||||
@cd web && pnpm install
|
||||
@cd web && pnpm build
|
||||
@echo "✅ Web environment prepared (not started)"
|
||||
|
||||
# Step 3: Prepare API environment
|
||||
prepare-api:
|
||||
@echo "🔧 Setting up API environment..."
|
||||
@cp -n api/.env.example api/.env 2>/dev/null || echo "API .env already exists"
|
||||
@cd api && uv sync --dev --extra all
|
||||
@cd api && uv run flask db upgrade
|
||||
@echo "✅ API environment prepared (not started)"
|
||||
|
||||
# Clean dev environment
|
||||
dev-clean:
|
||||
@echo "⚠️ Stopping Docker containers..."
|
||||
@cd docker && docker compose -f docker-compose.middleware.yaml --env-file middleware.env -p dify-middlewares-dev down
|
||||
@echo "🗑️ Removing volumes..."
|
||||
@rm -rf docker/volumes/db
|
||||
@rm -rf docker/volumes/redis
|
||||
@rm -rf docker/volumes/plugin_daemon
|
||||
@rm -rf docker/volumes/weaviate
|
||||
@rm -rf api/storage
|
||||
@echo "✅ Cleanup complete"
|
||||
|
||||
# Build Docker images
|
||||
build-web:
|
||||
@echo "Building web Docker image: $(WEB_IMAGE):$(VERSION)..."
|
||||
|
|
@ -39,5 +81,21 @@ build-push-web: build-web push-web
|
|||
build-push-all: build-all push-all
|
||||
@echo "All Docker images have been built and pushed."
|
||||
|
||||
# Help target
|
||||
help:
|
||||
@echo "Development Setup Targets:"
|
||||
@echo " make dev-setup - Run all setup steps for backend dev environment"
|
||||
@echo " make prepare-docker - Set up Docker middleware"
|
||||
@echo " make prepare-web - Set up web environment"
|
||||
@echo " make prepare-api - Set up API environment"
|
||||
@echo " make dev-clean - Stop Docker middleware containers"
|
||||
@echo ""
|
||||
@echo "Docker Build Targets:"
|
||||
@echo " make build-web - Build web Docker image"
|
||||
@echo " make build-api - Build API Docker image"
|
||||
@echo " make build-all - Build all Docker images"
|
||||
@echo " make push-all - Push all Docker images"
|
||||
@echo " make build-push-all - Build and push all Docker images"
|
||||
|
||||
# Phony targets
|
||||
.PHONY: build-web build-api push-web push-api build-all push-all build-push-all
|
||||
.PHONY: build-web build-api push-web push-api build-all push-all build-push-all dev-setup prepare-docker prepare-web prepare-api dev-clean help
|
||||
|
|
|
|||
|
|
@ -34,17 +34,16 @@ ignore_imports =
|
|||
[importlinter:contract:rsc]
|
||||
name = RSC
|
||||
type = layers
|
||||
layers =
|
||||
layers =
|
||||
graph_engine
|
||||
response_coordinator
|
||||
output_registry
|
||||
containers =
|
||||
core.workflow.graph_engine
|
||||
|
||||
[importlinter:contract:worker]
|
||||
name = Worker
|
||||
type = layers
|
||||
layers =
|
||||
layers =
|
||||
graph_engine
|
||||
worker
|
||||
containers =
|
||||
|
|
@ -77,26 +76,15 @@ forbidden_modules =
|
|||
core.workflow.graph_engine.layers
|
||||
core.workflow.graph_engine.protocols
|
||||
|
||||
[importlinter:contract:state-management-layers]
|
||||
name = State Management Layers
|
||||
type = layers
|
||||
layers =
|
||||
execution_tracker
|
||||
node_state_manager
|
||||
edge_state_manager
|
||||
containers =
|
||||
core.workflow.graph_engine.state_management
|
||||
|
||||
[importlinter:contract:worker-management-layers]
|
||||
name = Worker Management Layers
|
||||
type = layers
|
||||
layers =
|
||||
worker_pool
|
||||
worker_factory
|
||||
dynamic_scaler
|
||||
activity_tracker
|
||||
containers =
|
||||
[importlinter:contract:worker-management]
|
||||
name = Worker Management
|
||||
type = forbidden
|
||||
source_modules =
|
||||
core.workflow.graph_engine.worker_management
|
||||
forbidden_modules =
|
||||
core.workflow.graph_engine.orchestration
|
||||
core.workflow.graph_engine.command_processing
|
||||
core.workflow.graph_engine.event_management
|
||||
|
||||
[importlinter:contract:error-handling-strategies]
|
||||
name = Error Handling Strategies
|
||||
|
|
@ -109,14 +97,16 @@ modules =
|
|||
|
||||
[importlinter:contract:graph-traversal-components]
|
||||
name = Graph Traversal Components
|
||||
type = independence
|
||||
modules =
|
||||
core.workflow.graph_engine.graph_traversal.node_readiness
|
||||
core.workflow.graph_engine.graph_traversal.skip_propagator
|
||||
type = layers
|
||||
layers =
|
||||
edge_processor
|
||||
skip_propagator
|
||||
containers =
|
||||
core.workflow.graph_engine.graph_traversal
|
||||
|
||||
[importlinter:contract:command-channels]
|
||||
name = Command Channels Independence
|
||||
type = independence
|
||||
modules =
|
||||
core.workflow.graph_engine.command_channels.in_memory_channel
|
||||
core.workflow.graph_engine.command_channels.redis_channel
|
||||
core.workflow.graph_engine.command_channels.redis_channel
|
||||
|
|
|
|||
|
|
@ -108,5 +108,5 @@ uv run celery -A app.celery beat
|
|||
../dev/reformat # Run all formatters and linters
|
||||
uv run ruff check --fix ./ # Fix linting issues
|
||||
uv run ruff format ./ # Format code
|
||||
uv run mypy . # Type checking
|
||||
uv run basedpyright . # Type checking
|
||||
```
|
||||
|
|
|
|||
|
|
@ -1,11 +0,0 @@
|
|||
from tests.integration_tests.utils.parent_class import ParentClass
|
||||
|
||||
|
||||
class ChildClass(ParentClass):
|
||||
"""Test child class for module import helper tests"""
|
||||
|
||||
def __init__(self, name):
|
||||
super().__init__(name)
|
||||
|
||||
def get_name(self):
|
||||
return f"Child: {self.name}"
|
||||
|
|
@ -577,7 +577,7 @@ def old_metadata_migration():
|
|||
for document in documents:
|
||||
if document.doc_metadata:
|
||||
doc_metadata = document.doc_metadata
|
||||
for key, value in doc_metadata.items():
|
||||
for key in doc_metadata:
|
||||
for field in BuiltInField:
|
||||
if field.value == key:
|
||||
break
|
||||
|
|
|
|||
|
|
@ -29,7 +29,7 @@ class NacosSettingsSource(RemoteSettingsSource):
|
|||
try:
|
||||
content = NacosHttpClient().http_request("/nacos/v1/cs/configs", method="GET", headers={}, params=params)
|
||||
self.remote_configs = self._parse_config(content)
|
||||
except Exception as e:
|
||||
except Exception:
|
||||
logger.exception("[get-access-token] exception occurred")
|
||||
raise
|
||||
|
||||
|
|
|
|||
|
|
@ -27,7 +27,7 @@ class NacosHttpClient:
|
|||
response = requests.request(method, url="http://" + self.server + url, headers=headers, params=params)
|
||||
response.raise_for_status()
|
||||
return response.text
|
||||
except requests.exceptions.RequestException as e:
|
||||
except requests.RequestException as e:
|
||||
return f"Request to Nacos failed: {e}"
|
||||
|
||||
def _inject_auth_info(self, headers, params, module="config"):
|
||||
|
|
@ -77,6 +77,6 @@ class NacosHttpClient:
|
|||
self.token = response_data.get("accessToken")
|
||||
self.token_ttl = response_data.get("tokenTtl", 18000)
|
||||
self.token_expire_time = current_time + self.token_ttl - 10
|
||||
except Exception as e:
|
||||
except Exception:
|
||||
logger.exception("[get-access-token] exception occur")
|
||||
raise
|
||||
|
|
|
|||
|
|
@ -19,6 +19,7 @@ language_timezone_mapping = {
|
|||
"fa-IR": "Asia/Tehran",
|
||||
"sl-SI": "Europe/Ljubljana",
|
||||
"th-TH": "Asia/Bangkok",
|
||||
"id-ID": "Asia/Jakarta",
|
||||
}
|
||||
|
||||
languages = list(language_timezone_mapping.keys())
|
||||
|
|
|
|||
|
|
@ -130,15 +130,19 @@ class InsertExploreAppApi(Resource):
|
|||
app.is_public = False
|
||||
|
||||
with Session(db.engine) as session:
|
||||
installed_apps = session.execute(
|
||||
select(InstalledApp).where(
|
||||
InstalledApp.app_id == recommended_app.app_id,
|
||||
InstalledApp.tenant_id != InstalledApp.app_owner_tenant_id,
|
||||
installed_apps = (
|
||||
session.execute(
|
||||
select(InstalledApp).where(
|
||||
InstalledApp.app_id == recommended_app.app_id,
|
||||
InstalledApp.tenant_id != InstalledApp.app_owner_tenant_id,
|
||||
)
|
||||
)
|
||||
).all()
|
||||
.scalars()
|
||||
.all()
|
||||
)
|
||||
|
||||
for installed_app in installed_apps:
|
||||
db.session.delete(installed_app)
|
||||
for installed_app in installed_apps:
|
||||
session.delete(installed_app)
|
||||
|
||||
db.session.delete(recommended_app)
|
||||
db.session.commit()
|
||||
|
|
|
|||
|
|
@ -84,7 +84,7 @@ class BaseApiKeyListResource(Resource):
|
|||
flask_restx.abort(
|
||||
400,
|
||||
message=f"Cannot create more than {self.max_keys} API keys for this resource type.",
|
||||
code="max_keys_exceeded",
|
||||
custom="max_keys_exceeded",
|
||||
)
|
||||
|
||||
key = ApiToken.generate_api_key(self.token_prefix, 24)
|
||||
|
|
|
|||
|
|
@ -237,9 +237,14 @@ class AppExportApi(Resource):
|
|||
# Add include_secret params
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("include_secret", type=inputs.boolean, default=False, location="args")
|
||||
parser.add_argument("workflow_id", type=str, location="args")
|
||||
args = parser.parse_args()
|
||||
|
||||
return {"data": AppDslService.export_dsl(app_model=app_model, include_secret=args["include_secret"])}
|
||||
return {
|
||||
"data": AppDslService.export_dsl(
|
||||
app_model=app_model, include_secret=args["include_secret"], workflow_id=args.get("workflow_id")
|
||||
)
|
||||
}
|
||||
|
||||
|
||||
class AppNameApi(Resource):
|
||||
|
|
|
|||
|
|
@ -130,7 +130,7 @@ class MessageFeedbackApi(Resource):
|
|||
|
||||
message_id = str(args["message_id"])
|
||||
|
||||
message = db.session.query(Message).filter(Message.id == message_id, Message.app_id == app_model.id).first()
|
||||
message = db.session.query(Message).where(Message.id == message_id, Message.app_id == app_model.id).first()
|
||||
|
||||
if not message:
|
||||
raise NotFound("Message Not Exists.")
|
||||
|
|
|
|||
|
|
@ -532,7 +532,7 @@ class PublishedWorkflowApi(Resource):
|
|||
)
|
||||
|
||||
app_model.workflow_id = workflow.id
|
||||
db.session.commit()
|
||||
db.session.commit() # NOTE: this is necessary for update app_model.workflow_id
|
||||
|
||||
workflow_created_at = TimestampField().format(workflow.created_at)
|
||||
|
||||
|
|
|
|||
|
|
@ -27,7 +27,9 @@ class WorkflowAppLogApi(Resource):
|
|||
"""
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("keyword", type=str, location="args")
|
||||
parser.add_argument("status", type=str, choices=["succeeded", "failed", "stopped"], location="args")
|
||||
parser.add_argument(
|
||||
"status", type=str, choices=["succeeded", "failed", "stopped", "partial-succeeded"], location="args"
|
||||
)
|
||||
parser.add_argument(
|
||||
"created_at__before", type=str, location="args", help="Filter logs created before this timestamp"
|
||||
)
|
||||
|
|
|
|||
|
|
@ -81,7 +81,7 @@ class OAuthDataSourceBinding(Resource):
|
|||
return {"error": "Invalid code"}, 400
|
||||
try:
|
||||
oauth_provider.get_access_token(code)
|
||||
except requests.exceptions.HTTPError as e:
|
||||
except requests.HTTPError as e:
|
||||
logger.exception(
|
||||
"An error occurred during the OAuthCallback process with %s: %s", provider, e.response.text
|
||||
)
|
||||
|
|
@ -104,7 +104,7 @@ class OAuthDataSourceSync(Resource):
|
|||
return {"error": "Invalid provider"}, 400
|
||||
try:
|
||||
oauth_provider.sync_data_source(binding_id)
|
||||
except requests.exceptions.HTTPError as e:
|
||||
except requests.HTTPError as e:
|
||||
logger.exception(
|
||||
"An error occurred during the OAuthCallback process with %s: %s", provider, e.response.text
|
||||
)
|
||||
|
|
|
|||
|
|
@ -130,7 +130,7 @@ class ResetPasswordSendEmailApi(Resource):
|
|||
language = "en-US"
|
||||
try:
|
||||
account = AccountService.get_user_through_email(args["email"])
|
||||
except AccountRegisterError as are:
|
||||
except AccountRegisterError:
|
||||
raise AccountInFreezeError()
|
||||
|
||||
if account is None:
|
||||
|
|
@ -162,7 +162,7 @@ class EmailCodeLoginSendEmailApi(Resource):
|
|||
language = "en-US"
|
||||
try:
|
||||
account = AccountService.get_user_through_email(args["email"])
|
||||
except AccountRegisterError as are:
|
||||
except AccountRegisterError:
|
||||
raise AccountInFreezeError()
|
||||
|
||||
if account is None:
|
||||
|
|
@ -200,7 +200,7 @@ class EmailCodeLoginApi(Resource):
|
|||
AccountService.revoke_email_code_login_token(args["token"])
|
||||
try:
|
||||
account = AccountService.get_user_through_email(user_email)
|
||||
except AccountRegisterError as are:
|
||||
except AccountRegisterError:
|
||||
raise AccountInFreezeError()
|
||||
if account:
|
||||
tenants = TenantService.get_join_tenants(account)
|
||||
|
|
@ -223,7 +223,7 @@ class EmailCodeLoginApi(Resource):
|
|||
)
|
||||
except WorkSpaceNotAllowedCreateError:
|
||||
raise NotAllowedCreateWorkspace()
|
||||
except AccountRegisterError as are:
|
||||
except AccountRegisterError:
|
||||
raise AccountInFreezeError()
|
||||
except WorkspacesLimitExceededError:
|
||||
raise WorkspacesLimitExceeded()
|
||||
|
|
|
|||
|
|
@ -80,7 +80,7 @@ class OAuthCallback(Resource):
|
|||
try:
|
||||
token = oauth_provider.get_access_token(code)
|
||||
user_info = oauth_provider.get_user_info(token)
|
||||
except requests.exceptions.RequestException as e:
|
||||
except requests.RequestException as e:
|
||||
error_text = e.response.text if e.response else str(e)
|
||||
logger.exception("An error occurred during the OAuth process with %s: %s", provider, error_text)
|
||||
return {"error": "OAuth process failed"}, 400
|
||||
|
|
|
|||
|
|
@ -44,22 +44,19 @@ def oauth_server_access_token_required(view):
|
|||
if not oauth_provider_app or not isinstance(oauth_provider_app, OAuthProviderApp):
|
||||
raise BadRequest("Invalid oauth_provider_app")
|
||||
|
||||
if not request.headers.get("Authorization"):
|
||||
raise BadRequest("Authorization is required")
|
||||
|
||||
authorization_header = request.headers.get("Authorization")
|
||||
if not authorization_header:
|
||||
raise BadRequest("Authorization header is required")
|
||||
|
||||
parts = authorization_header.split(" ")
|
||||
parts = authorization_header.strip().split(" ")
|
||||
if len(parts) != 2:
|
||||
raise BadRequest("Invalid Authorization header format")
|
||||
|
||||
token_type = parts[0]
|
||||
if token_type != "Bearer":
|
||||
token_type = parts[0].strip()
|
||||
if token_type.lower() != "bearer":
|
||||
raise BadRequest("token_type is invalid")
|
||||
|
||||
access_token = parts[1]
|
||||
access_token = parts[1].strip()
|
||||
if not access_token:
|
||||
raise BadRequest("access_token is required")
|
||||
|
||||
|
|
@ -125,7 +122,10 @@ class OAuthServerUserTokenApi(Resource):
|
|||
parser.add_argument("refresh_token", type=str, required=False, location="json")
|
||||
parsed_args = parser.parse_args()
|
||||
|
||||
grant_type = OAuthGrantType(parsed_args["grant_type"])
|
||||
try:
|
||||
grant_type = OAuthGrantType(parsed_args["grant_type"])
|
||||
except ValueError:
|
||||
raise BadRequest("invalid grant_type")
|
||||
|
||||
if grant_type == OAuthGrantType.AUTHORIZATION_CODE:
|
||||
if not parsed_args["code"]:
|
||||
|
|
@ -163,8 +163,6 @@ class OAuthServerUserTokenApi(Resource):
|
|||
"refresh_token": refresh_token,
|
||||
}
|
||||
)
|
||||
else:
|
||||
raise BadRequest("invalid grant_type")
|
||||
|
||||
|
||||
class OAuthServerUserAccountApi(Resource):
|
||||
|
|
|
|||
|
|
@ -14,6 +14,7 @@ from controllers.console.wraps import account_initialization_required, setup_req
|
|||
from core.datasource.entities.datasource_entities import DatasourceProviderType, OnlineDocumentPagesMessage
|
||||
from core.datasource.online_document.online_document_plugin import OnlineDocumentDatasourcePlugin
|
||||
from core.indexing_runner import IndexingRunner
|
||||
from core.rag.extractor.entity.datasource_type import DatasourceType
|
||||
from core.rag.extractor.entity.extract_setting import ExtractSetting
|
||||
from core.rag.extractor.notion_extractor import NotionExtractor
|
||||
from extensions.ext_database import db
|
||||
|
|
@ -248,7 +249,7 @@ class DataSourceNotionApi(Resource):
|
|||
credential_id = notion_info.get("credential_id")
|
||||
for page in notion_info["pages"]:
|
||||
extract_setting = ExtractSetting(
|
||||
datasource_type="notion_import",
|
||||
datasource_type=DatasourceType.NOTION.value,
|
||||
notion_info={
|
||||
"credential_id": credential_id,
|
||||
"notion_workspace_id": workspace_id,
|
||||
|
|
|
|||
|
|
@ -21,6 +21,7 @@ from core.indexing_runner import IndexingRunner
|
|||
from core.model_runtime.entities.model_entities import ModelType
|
||||
from core.provider_manager import ProviderManager
|
||||
from core.rag.datasource.vdb.vector_type import VectorType
|
||||
from core.rag.extractor.entity.datasource_type import DatasourceType
|
||||
from core.rag.extractor.entity.extract_setting import ExtractSetting
|
||||
from core.rag.retrieval.retrieval_methods import RetrievalMethod
|
||||
from extensions.ext_database import db
|
||||
|
|
@ -431,7 +432,9 @@ class DatasetIndexingEstimateApi(Resource):
|
|||
if file_details:
|
||||
for file_detail in file_details:
|
||||
extract_setting = ExtractSetting(
|
||||
datasource_type="upload_file", upload_file=file_detail, document_model=args["doc_form"]
|
||||
datasource_type=DatasourceType.FILE.value,
|
||||
upload_file=file_detail,
|
||||
document_model=args["doc_form"],
|
||||
)
|
||||
extract_settings.append(extract_setting)
|
||||
elif args["info_list"]["data_source_type"] == "notion_import":
|
||||
|
|
@ -441,7 +444,7 @@ class DatasetIndexingEstimateApi(Resource):
|
|||
credential_id = notion_info.get("credential_id")
|
||||
for page in notion_info["pages"]:
|
||||
extract_setting = ExtractSetting(
|
||||
datasource_type="notion_import",
|
||||
datasource_type=DatasourceType.NOTION.value,
|
||||
notion_info={
|
||||
"credential_id": credential_id,
|
||||
"notion_workspace_id": workspace_id,
|
||||
|
|
@ -456,7 +459,7 @@ class DatasetIndexingEstimateApi(Resource):
|
|||
website_info_list = args["info_list"]["website_info_list"]
|
||||
for url in website_info_list["urls"]:
|
||||
extract_setting = ExtractSetting(
|
||||
datasource_type="website_crawl",
|
||||
datasource_type=DatasourceType.WEBSITE.value,
|
||||
website_info={
|
||||
"provider": website_info_list["provider"],
|
||||
"job_id": website_info_list["job_id"],
|
||||
|
|
|
|||
|
|
@ -41,6 +41,7 @@ from core.model_manager import ModelManager
|
|||
from core.model_runtime.entities.model_entities import ModelType
|
||||
from core.model_runtime.errors.invoke import InvokeAuthorizationError
|
||||
from core.plugin.impl.exc import PluginDaemonClientSideError
|
||||
from core.rag.extractor.entity.datasource_type import DatasourceType
|
||||
from core.rag.extractor.entity.extract_setting import ExtractSetting
|
||||
from extensions.ext_database import db
|
||||
from fields.document_fields import (
|
||||
|
|
@ -356,9 +357,6 @@ class DatasetInitApi(Resource):
|
|||
parser.add_argument("embedding_model_provider", type=str, required=False, nullable=True, location="json")
|
||||
args = parser.parse_args()
|
||||
|
||||
# The role of the current user in the ta table must be admin, owner, or editor, or dataset_operator
|
||||
if not current_user.is_dataset_editor:
|
||||
raise Forbidden()
|
||||
knowledge_config = KnowledgeConfig(**args)
|
||||
if knowledge_config.indexing_technique == "high_quality":
|
||||
if knowledge_config.embedding_model is None or knowledge_config.embedding_model_provider is None:
|
||||
|
|
@ -430,7 +428,7 @@ class DocumentIndexingEstimateApi(DocumentResource):
|
|||
raise NotFound("File not found.")
|
||||
|
||||
extract_setting = ExtractSetting(
|
||||
datasource_type="upload_file", upload_file=file, document_model=document.doc_form
|
||||
datasource_type=DatasourceType.FILE.value, upload_file=file, document_model=document.doc_form
|
||||
)
|
||||
|
||||
indexing_runner = IndexingRunner()
|
||||
|
|
@ -490,13 +488,13 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
|
|||
raise NotFound("File not found.")
|
||||
|
||||
extract_setting = ExtractSetting(
|
||||
datasource_type="upload_file", upload_file=file_detail, document_model=document.doc_form
|
||||
datasource_type=DatasourceType.FILE.value, upload_file=file_detail, document_model=document.doc_form
|
||||
)
|
||||
extract_settings.append(extract_setting)
|
||||
|
||||
elif document.data_source_type == "notion_import":
|
||||
extract_setting = ExtractSetting(
|
||||
datasource_type="notion_import",
|
||||
datasource_type=DatasourceType.NOTION.value,
|
||||
notion_info={
|
||||
"credential_id": data_source_info["credential_id"],
|
||||
"notion_workspace_id": data_source_info["notion_workspace_id"],
|
||||
|
|
@ -509,7 +507,7 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
|
|||
extract_settings.append(extract_setting)
|
||||
elif document.data_source_type == "website_crawl":
|
||||
extract_setting = ExtractSetting(
|
||||
datasource_type="website_crawl",
|
||||
datasource_type=DatasourceType.WEBSITE.value,
|
||||
website_info={
|
||||
"provider": data_source_info["provider"],
|
||||
"job_id": data_source_info["job_id"],
|
||||
|
|
|
|||
|
|
@ -61,7 +61,6 @@ class ConversationApi(InstalledAppResource):
|
|||
ConversationService.delete(app_model, conversation_id, current_user)
|
||||
except ConversationNotExistsError:
|
||||
raise NotFound("Conversation Not Exists.")
|
||||
WebConversationService.unpin(app_model, conversation_id, current_user)
|
||||
|
||||
return {"result": "success"}, 204
|
||||
|
||||
|
|
|
|||
|
|
@ -67,7 +67,7 @@ class ModelProviderCredentialApi(Resource):
|
|||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
|
||||
parser.add_argument("name", type=StrLen(30), required=True, nullable=False, location="json")
|
||||
parser.add_argument("name", type=StrLen(30), required=False, nullable=True, location="json")
|
||||
args = parser.parse_args()
|
||||
|
||||
model_provider_service = ModelProviderService()
|
||||
|
|
@ -94,7 +94,7 @@ class ModelProviderCredentialApi(Resource):
|
|||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("credential_id", type=uuid_value, required=True, nullable=False, location="json")
|
||||
parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
|
||||
parser.add_argument("name", type=StrLen(30), required=True, nullable=False, location="json")
|
||||
parser.add_argument("name", type=StrLen(30), required=False, nullable=True, location="json")
|
||||
args = parser.parse_args()
|
||||
|
||||
model_provider_service = ModelProviderService()
|
||||
|
|
|
|||
|
|
@ -219,7 +219,11 @@ class ModelProviderModelCredentialApi(Resource):
|
|||
|
||||
model_load_balancing_service = ModelLoadBalancingService()
|
||||
is_load_balancing_enabled, load_balancing_configs = model_load_balancing_service.get_load_balancing_configs(
|
||||
tenant_id=tenant_id, provider=provider, model=args["model"], model_type=args["model_type"]
|
||||
tenant_id=tenant_id,
|
||||
provider=provider,
|
||||
model=args["model"],
|
||||
model_type=args["model_type"],
|
||||
config_from=args.get("config_from", ""),
|
||||
)
|
||||
|
||||
if args.get("config_from", "") == "predefined-model":
|
||||
|
|
@ -263,7 +267,7 @@ class ModelProviderModelCredentialApi(Resource):
|
|||
choices=[mt.value for mt in ModelType],
|
||||
location="json",
|
||||
)
|
||||
parser.add_argument("name", type=StrLen(30), required=True, nullable=False, location="json")
|
||||
parser.add_argument("name", type=StrLen(30), required=False, nullable=True, location="json")
|
||||
parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
|
||||
args = parser.parse_args()
|
||||
|
||||
|
|
@ -309,7 +313,7 @@ class ModelProviderModelCredentialApi(Resource):
|
|||
)
|
||||
parser.add_argument("credential_id", type=uuid_value, required=True, nullable=False, location="json")
|
||||
parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
|
||||
parser.add_argument("name", type=StrLen(30), required=True, nullable=False, location="json")
|
||||
parser.add_argument("name", type=StrLen(30), required=False, nullable=True, location="json")
|
||||
args = parser.parse_args()
|
||||
|
||||
model_provider_service = ModelProviderService()
|
||||
|
|
|
|||
|
|
@ -1,8 +1,12 @@
|
|||
from base64 import b64encode
|
||||
from collections.abc import Callable
|
||||
from functools import wraps
|
||||
from hashlib import sha1
|
||||
from hmac import new as hmac_new
|
||||
from typing import ParamSpec, TypeVar
|
||||
|
||||
P = ParamSpec("P")
|
||||
R = TypeVar("R")
|
||||
from flask import abort, request
|
||||
|
||||
from configs import dify_config
|
||||
|
|
@ -10,9 +14,9 @@ from extensions.ext_database import db
|
|||
from models.model import EndUser
|
||||
|
||||
|
||||
def billing_inner_api_only(view):
|
||||
def billing_inner_api_only(view: Callable[P, R]):
|
||||
@wraps(view)
|
||||
def decorated(*args, **kwargs):
|
||||
def decorated(*args: P.args, **kwargs: P.kwargs):
|
||||
if not dify_config.INNER_API:
|
||||
abort(404)
|
||||
|
||||
|
|
@ -26,9 +30,9 @@ def billing_inner_api_only(view):
|
|||
return decorated
|
||||
|
||||
|
||||
def enterprise_inner_api_only(view):
|
||||
def enterprise_inner_api_only(view: Callable[P, R]):
|
||||
@wraps(view)
|
||||
def decorated(*args, **kwargs):
|
||||
def decorated(*args: P.args, **kwargs: P.kwargs):
|
||||
if not dify_config.INNER_API:
|
||||
abort(404)
|
||||
|
||||
|
|
@ -78,9 +82,9 @@ def enterprise_inner_api_user_auth(view):
|
|||
return decorated
|
||||
|
||||
|
||||
def plugin_inner_api_only(view):
|
||||
def plugin_inner_api_only(view: Callable[P, R]):
|
||||
@wraps(view)
|
||||
def decorated(*args, **kwargs):
|
||||
def decorated(*args: P.args, **kwargs: P.kwargs):
|
||||
if not dify_config.PLUGIN_DAEMON_KEY:
|
||||
abort(404)
|
||||
|
||||
|
|
|
|||
|
|
@ -55,7 +55,7 @@ class AudioApi(Resource):
|
|||
file = request.files["file"]
|
||||
|
||||
try:
|
||||
response = AudioService.transcript_asr(app_model=app_model, file=file, end_user=end_user)
|
||||
response = AudioService.transcript_asr(app_model=app_model, file=file, end_user=end_user.id)
|
||||
|
||||
return response
|
||||
except services.errors.app_model_config.AppModelConfigBrokenError:
|
||||
|
|
|
|||
|
|
@ -59,7 +59,7 @@ class FilePreviewApi(Resource):
|
|||
args = file_preview_parser.parse_args()
|
||||
|
||||
# Validate file ownership and get file objects
|
||||
message_file, upload_file = self._validate_file_ownership(file_id, app_model.id)
|
||||
_, upload_file = self._validate_file_ownership(file_id, app_model.id)
|
||||
|
||||
# Get file content generator
|
||||
try:
|
||||
|
|
|
|||
|
|
@ -413,7 +413,7 @@ class DocumentUpdateByFileApi(DatasetApiResource):
|
|||
DocumentService.document_create_args_validate(knowledge_config)
|
||||
|
||||
try:
|
||||
documents, batch = DocumentService.save_document_with_dataset_id(
|
||||
documents, _ = DocumentService.save_document_with_dataset_id(
|
||||
dataset=dataset,
|
||||
knowledge_config=knowledge_config,
|
||||
account=dataset.created_by_account,
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
import time
|
||||
from collections.abc import Callable
|
||||
from datetime import timedelta
|
||||
from enum import Enum
|
||||
from enum import StrEnum, auto
|
||||
from functools import wraps
|
||||
from typing import Optional
|
||||
|
||||
|
|
@ -23,14 +23,14 @@ from models.model import ApiToken, App, EndUser
|
|||
from services.feature_service import FeatureService
|
||||
|
||||
|
||||
class WhereisUserArg(Enum):
|
||||
class WhereisUserArg(StrEnum):
|
||||
"""
|
||||
Enum for whereis_user_arg.
|
||||
"""
|
||||
|
||||
QUERY = "query"
|
||||
JSON = "json"
|
||||
FORM = "form"
|
||||
QUERY = auto()
|
||||
JSON = auto()
|
||||
FORM = auto()
|
||||
|
||||
|
||||
class FetchUserArg(BaseModel):
|
||||
|
|
@ -291,27 +291,28 @@ def create_or_update_end_user_for_user_id(app_model: App, user_id: Optional[str]
|
|||
if not user_id:
|
||||
user_id = "DEFAULT-USER"
|
||||
|
||||
end_user = (
|
||||
db.session.query(EndUser)
|
||||
.where(
|
||||
EndUser.tenant_id == app_model.tenant_id,
|
||||
EndUser.app_id == app_model.id,
|
||||
EndUser.session_id == user_id,
|
||||
EndUser.type == "service_api",
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
end_user = (
|
||||
session.query(EndUser)
|
||||
.where(
|
||||
EndUser.tenant_id == app_model.tenant_id,
|
||||
EndUser.app_id == app_model.id,
|
||||
EndUser.session_id == user_id,
|
||||
EndUser.type == "service_api",
|
||||
)
|
||||
.first()
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
if end_user is None:
|
||||
end_user = EndUser(
|
||||
tenant_id=app_model.tenant_id,
|
||||
app_id=app_model.id,
|
||||
type="service_api",
|
||||
is_anonymous=user_id == "DEFAULT-USER",
|
||||
session_id=user_id,
|
||||
)
|
||||
db.session.add(end_user)
|
||||
db.session.commit()
|
||||
if end_user is None:
|
||||
end_user = EndUser(
|
||||
tenant_id=app_model.tenant_id,
|
||||
app_id=app_model.id,
|
||||
type="service_api",
|
||||
is_anonymous=user_id == "DEFAULT-USER",
|
||||
session_id=user_id,
|
||||
)
|
||||
session.add(end_user)
|
||||
session.commit()
|
||||
|
||||
return end_user
|
||||
|
||||
|
|
|
|||
|
|
@ -73,8 +73,6 @@ class ConversationApi(WebApiResource):
|
|||
ConversationService.delete(app_model, conversation_id, end_user)
|
||||
except ConversationNotExistsError:
|
||||
raise NotFound("Conversation Not Exists.")
|
||||
WebConversationService.unpin(app_model, conversation_id, end_user)
|
||||
|
||||
return {"result": "success"}, 204
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -4,6 +4,7 @@ from functools import wraps
|
|||
from flask import request
|
||||
from flask_restx import Resource
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
from werkzeug.exceptions import BadRequest, NotFound, Unauthorized
|
||||
|
||||
from controllers.web.error import WebAppAuthAccessDeniedError, WebAppAuthRequiredError
|
||||
|
|
@ -49,18 +50,19 @@ def decode_jwt_token():
|
|||
decoded = PassportService().verify(tk)
|
||||
app_code = decoded.get("app_code")
|
||||
app_id = decoded.get("app_id")
|
||||
app_model = db.session.scalar(select(App).where(App.id == app_id))
|
||||
site = db.session.scalar(select(Site).where(Site.code == app_code))
|
||||
if not app_model:
|
||||
raise NotFound()
|
||||
if not app_code or not site:
|
||||
raise BadRequest("Site URL is no longer valid.")
|
||||
if app_model.enable_site is False:
|
||||
raise BadRequest("Site is disabled.")
|
||||
end_user_id = decoded.get("end_user_id")
|
||||
end_user = db.session.scalar(select(EndUser).where(EndUser.id == end_user_id))
|
||||
if not end_user:
|
||||
raise NotFound()
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
app_model = session.scalar(select(App).where(App.id == app_id))
|
||||
site = session.scalar(select(Site).where(Site.code == app_code))
|
||||
if not app_model:
|
||||
raise NotFound()
|
||||
if not app_code or not site:
|
||||
raise BadRequest("Site URL is no longer valid.")
|
||||
if app_model.enable_site is False:
|
||||
raise BadRequest("Site is disabled.")
|
||||
end_user_id = decoded.get("end_user_id")
|
||||
end_user = session.scalar(select(EndUser).where(EndUser.id == end_user_id))
|
||||
if not end_user:
|
||||
raise NotFound()
|
||||
|
||||
# for enterprise webapp auth
|
||||
app_web_auth_enabled = False
|
||||
|
|
|
|||
|
|
@ -336,7 +336,8 @@ class BaseAgentRunner(AppRunner):
|
|||
"""
|
||||
Save agent thought
|
||||
"""
|
||||
agent_thought = db.session.query(MessageAgentThought).where(MessageAgentThought.id == agent_thought_id).first()
|
||||
stmt = select(MessageAgentThought).where(MessageAgentThought.id == agent_thought_id)
|
||||
agent_thought = db.session.scalar(stmt)
|
||||
if not agent_thought:
|
||||
raise ValueError("agent thought not found")
|
||||
|
||||
|
|
@ -494,7 +495,8 @@ class BaseAgentRunner(AppRunner):
|
|||
return result
|
||||
|
||||
def organize_agent_user_prompt(self, message: Message) -> UserPromptMessage:
|
||||
files = db.session.query(MessageFile).where(MessageFile.message_id == message.id).all()
|
||||
stmt = select(MessageFile).where(MessageFile.message_id == message.id)
|
||||
files = db.session.scalars(stmt).all()
|
||||
if not files:
|
||||
return UserPromptMessage(content=message.query)
|
||||
if message.app_model_config:
|
||||
|
|
|
|||
|
|
@ -1,12 +1,14 @@
|
|||
from pydantic import BaseModel, Field, ValidationError
|
||||
from pydantic import BaseModel, ConfigDict, Field, ValidationError
|
||||
|
||||
|
||||
class MoreLikeThisConfig(BaseModel):
|
||||
enabled: bool = False
|
||||
model_config = ConfigDict(extra="allow")
|
||||
|
||||
|
||||
class AppConfigModel(BaseModel):
|
||||
more_like_this: MoreLikeThisConfig = Field(default_factory=MoreLikeThisConfig)
|
||||
model_config = ConfigDict(extra="allow")
|
||||
|
||||
|
||||
class MoreLikeThisConfigManager:
|
||||
|
|
@ -23,8 +25,8 @@ class MoreLikeThisConfigManager:
|
|||
@classmethod
|
||||
def validate_and_set_defaults(cls, config: dict) -> tuple[dict, list[str]]:
|
||||
try:
|
||||
return AppConfigModel.model_validate(config).dict(), ["more_like_this"]
|
||||
except ValidationError as e:
|
||||
return AppConfigModel.model_validate(config).model_dump(), ["more_like_this"]
|
||||
except ValidationError:
|
||||
raise ValueError(
|
||||
"more_like_this must be of dict type and enabled in more_like_this must be of boolean type"
|
||||
)
|
||||
|
|
|
|||
|
|
@ -450,6 +450,12 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
|||
|
||||
worker_thread.start()
|
||||
|
||||
# release database connection, because the following new thread operations may take a long time
|
||||
db.session.refresh(workflow)
|
||||
db.session.refresh(message)
|
||||
# db.session.refresh(user)
|
||||
db.session.close()
|
||||
|
||||
# return response or stream generator
|
||||
response = self._handle_advanced_chat_response(
|
||||
application_generate_entity=application_generate_entity,
|
||||
|
|
|
|||
|
|
@ -73,7 +73,9 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
|
|||
app_config = self.application_generate_entity.app_config
|
||||
app_config = cast(AdvancedChatAppConfig, app_config)
|
||||
|
||||
app_record = db.session.query(App).where(App.id == app_config.app_id).first()
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
app_record = session.scalar(select(App).where(App.id == app_config.app_id))
|
||||
|
||||
if not app_record:
|
||||
raise ValueError("App not found")
|
||||
|
||||
|
|
@ -147,7 +149,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
|
|||
environment_variables=self._workflow.environment_variables,
|
||||
# Based on the definition of `VariableUnion`,
|
||||
# `list[Variable]` can be safely used as `list[VariableUnion]` since they are compatible.
|
||||
conversation_variables=cast(list[VariableUnion], conversation_variables),
|
||||
conversation_variables=conversation_variables,
|
||||
)
|
||||
|
||||
# init graph
|
||||
|
|
|
|||
|
|
@ -118,7 +118,7 @@ class AdvancedChatAppGenerateResponseConverter(AppGenerateResponseConverter):
|
|||
data = cls._error_to_stream_response(sub_stream_response.err)
|
||||
response_chunk.update(data)
|
||||
elif isinstance(sub_stream_response, NodeStartStreamResponse | NodeFinishStreamResponse):
|
||||
response_chunk.update(sub_stream_response.to_ignore_detail_dict())
|
||||
response_chunk.update(sub_stream_response.to_ignore_detail_dict()) # ty: ignore [unresolved-attribute]
|
||||
else:
|
||||
response_chunk.update(sub_stream_response.to_dict())
|
||||
|
||||
|
|
|
|||
|
|
@ -68,7 +68,6 @@ from core.workflow.repositories.workflow_execution_repository import WorkflowExe
|
|||
from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository
|
||||
from core.workflow.system_variable import SystemVariable
|
||||
from core.workflow.workflow_cycle_manager import CycleManagerWorkflowInfo, WorkflowCycleManager
|
||||
from events.message_event import message_was_created
|
||||
from extensions.ext_database import db
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from models import Conversation, EndUser, Message, MessageFile
|
||||
|
|
@ -306,13 +305,8 @@ class AdvancedChatAppGenerateTaskPipeline:
|
|||
err = self._base_task_pipeline._handle_error(event=event, session=session, message_id=self._message_id)
|
||||
yield self._base_task_pipeline._error_to_stream_response(err)
|
||||
|
||||
def _handle_workflow_started_event(
|
||||
self, event: QueueWorkflowStartedEvent, *, graph_runtime_state: Optional[GraphRuntimeState] = None, **kwargs
|
||||
) -> Generator[StreamResponse, None, None]:
|
||||
def _handle_workflow_started_event(self, **kwargs) -> Generator[StreamResponse, None, None]:
|
||||
"""Handle workflow started events."""
|
||||
# Override graph runtime state - this is a side effect but necessary
|
||||
graph_runtime_state = event.graph_runtime_state
|
||||
|
||||
with self._database_session() as session:
|
||||
workflow_execution = self._workflow_cycle_manager.handle_workflow_run_start()
|
||||
self._workflow_run_id = workflow_execution.id_
|
||||
|
|
@ -333,15 +327,14 @@ class AdvancedChatAppGenerateTaskPipeline:
|
|||
"""Handle node retry events."""
|
||||
self._ensure_workflow_initialized()
|
||||
|
||||
with self._database_session() as session:
|
||||
workflow_node_execution = self._workflow_cycle_manager.handle_workflow_node_execution_retried(
|
||||
workflow_execution_id=self._workflow_run_id, event=event
|
||||
)
|
||||
node_retry_resp = self._workflow_response_converter.workflow_node_retry_to_stream_response(
|
||||
event=event,
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
workflow_node_execution=workflow_node_execution,
|
||||
)
|
||||
workflow_node_execution = self._workflow_cycle_manager.handle_workflow_node_execution_retried(
|
||||
workflow_execution_id=self._workflow_run_id, event=event
|
||||
)
|
||||
node_retry_resp = self._workflow_response_converter.workflow_node_retry_to_stream_response(
|
||||
event=event,
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
workflow_node_execution=workflow_node_execution,
|
||||
)
|
||||
|
||||
if node_retry_resp:
|
||||
yield node_retry_resp
|
||||
|
|
@ -375,13 +368,12 @@ class AdvancedChatAppGenerateTaskPipeline:
|
|||
self._workflow_response_converter.fetch_files_from_node_outputs(event.outputs or {})
|
||||
)
|
||||
|
||||
with self._database_session() as session:
|
||||
workflow_node_execution = self._workflow_cycle_manager.handle_workflow_node_execution_success(event=event)
|
||||
node_finish_resp = self._workflow_response_converter.workflow_node_finish_to_stream_response(
|
||||
event=event,
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
workflow_node_execution=workflow_node_execution,
|
||||
)
|
||||
workflow_node_execution = self._workflow_cycle_manager.handle_workflow_node_execution_success(event=event)
|
||||
node_finish_resp = self._workflow_response_converter.workflow_node_finish_to_stream_response(
|
||||
event=event,
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
workflow_node_execution=workflow_node_execution,
|
||||
)
|
||||
|
||||
self._save_output_for_event(event, workflow_node_execution.id)
|
||||
|
||||
|
|
@ -886,10 +878,6 @@ class AdvancedChatAppGenerateTaskPipeline:
|
|||
self._task_state.metadata.usage = usage
|
||||
else:
|
||||
self._task_state.metadata.usage = LLMUsage.empty_usage()
|
||||
message_was_created.send(
|
||||
message,
|
||||
application_generate_entity=self._application_generate_entity,
|
||||
)
|
||||
|
||||
def _message_end_to_stream_response(self) -> MessageEndStreamResponse:
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -1,6 +1,8 @@
|
|||
import logging
|
||||
from typing import cast
|
||||
|
||||
from sqlalchemy import select
|
||||
|
||||
from core.agent.cot_chat_agent_runner import CotChatAgentRunner
|
||||
from core.agent.cot_completion_agent_runner import CotCompletionAgentRunner
|
||||
from core.agent.entities import AgentEntity
|
||||
|
|
@ -44,8 +46,8 @@ class AgentChatAppRunner(AppRunner):
|
|||
"""
|
||||
app_config = application_generate_entity.app_config
|
||||
app_config = cast(AgentChatAppConfig, app_config)
|
||||
|
||||
app_record = db.session.query(App).where(App.id == app_config.app_id).first()
|
||||
app_stmt = select(App).where(App.id == app_config.app_id)
|
||||
app_record = db.session.scalar(app_stmt)
|
||||
if not app_record:
|
||||
raise ValueError("App not found")
|
||||
|
||||
|
|
@ -182,11 +184,12 @@ class AgentChatAppRunner(AppRunner):
|
|||
|
||||
if {ModelFeature.MULTI_TOOL_CALL, ModelFeature.TOOL_CALL}.intersection(model_schema.features or []):
|
||||
agent_entity.strategy = AgentEntity.Strategy.FUNCTION_CALLING
|
||||
|
||||
conversation_result = db.session.query(Conversation).where(Conversation.id == conversation.id).first()
|
||||
conversation_stmt = select(Conversation).where(Conversation.id == conversation.id)
|
||||
conversation_result = db.session.scalar(conversation_stmt)
|
||||
if conversation_result is None:
|
||||
raise ValueError("Conversation not found")
|
||||
message_result = db.session.query(Message).where(Message.id == message.id).first()
|
||||
msg_stmt = select(Message).where(Message.id == message.id)
|
||||
message_result = db.session.scalar(msg_stmt)
|
||||
if message_result is None:
|
||||
raise ValueError("Message not found")
|
||||
db.session.close()
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
import queue
|
||||
import time
|
||||
from abc import abstractmethod
|
||||
from enum import Enum
|
||||
from enum import IntEnum, auto
|
||||
from typing import Any, Optional
|
||||
|
||||
from sqlalchemy.orm import DeclarativeMeta
|
||||
|
|
@ -19,9 +19,9 @@ from core.app.entities.queue_entities import (
|
|||
from extensions.ext_redis import redis_client
|
||||
|
||||
|
||||
class PublishFrom(Enum):
|
||||
APPLICATION_MANAGER = 1
|
||||
TASK_PIPELINE = 2
|
||||
class PublishFrom(IntEnum):
|
||||
APPLICATION_MANAGER = auto()
|
||||
TASK_PIPELINE = auto()
|
||||
|
||||
|
||||
class AppQueueManager:
|
||||
|
|
@ -174,7 +174,7 @@ class AppQueueManager:
|
|||
def _check_for_sqlalchemy_models(self, data: Any):
|
||||
# from entity to dict or list
|
||||
if isinstance(data, dict):
|
||||
for key, value in data.items():
|
||||
for value in data.values():
|
||||
self._check_for_sqlalchemy_models(value)
|
||||
elif isinstance(data, list):
|
||||
for item in data:
|
||||
|
|
|
|||
|
|
@ -1,6 +1,8 @@
|
|||
import logging
|
||||
from typing import cast
|
||||
|
||||
from sqlalchemy import select
|
||||
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
|
||||
from core.app.apps.base_app_runner import AppRunner
|
||||
from core.app.apps.chat.app_config_manager import ChatAppConfig
|
||||
|
|
@ -42,8 +44,8 @@ class ChatAppRunner(AppRunner):
|
|||
"""
|
||||
app_config = application_generate_entity.app_config
|
||||
app_config = cast(ChatAppConfig, app_config)
|
||||
|
||||
app_record = db.session.query(App).where(App.id == app_config.app_id).first()
|
||||
stmt = select(App).where(App.id == app_config.app_id)
|
||||
app_record = db.session.scalar(stmt)
|
||||
if not app_record:
|
||||
raise ValueError("App not found")
|
||||
|
||||
|
|
|
|||
|
|
@ -6,6 +6,7 @@ from typing import Any, Literal, Union, overload
|
|||
|
||||
from flask import Flask, copy_current_request_context, current_app
|
||||
from pydantic import ValidationError
|
||||
from sqlalchemy import select
|
||||
|
||||
from configs import dify_config
|
||||
from core.app.app_config.easy_ui_based_app.model_config.converter import ModelConfigConverter
|
||||
|
|
@ -248,17 +249,14 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
|
|||
:param invoke_from: invoke from source
|
||||
:param stream: is stream
|
||||
"""
|
||||
message = (
|
||||
db.session.query(Message)
|
||||
.where(
|
||||
Message.id == message_id,
|
||||
Message.app_id == app_model.id,
|
||||
Message.from_source == ("api" if isinstance(user, EndUser) else "console"),
|
||||
Message.from_end_user_id == (user.id if isinstance(user, EndUser) else None),
|
||||
Message.from_account_id == (user.id if isinstance(user, Account) else None),
|
||||
)
|
||||
.first()
|
||||
stmt = select(Message).where(
|
||||
Message.id == message_id,
|
||||
Message.app_id == app_model.id,
|
||||
Message.from_source == ("api" if isinstance(user, EndUser) else "console"),
|
||||
Message.from_end_user_id == (user.id if isinstance(user, EndUser) else None),
|
||||
Message.from_account_id == (user.id if isinstance(user, Account) else None),
|
||||
)
|
||||
message = db.session.scalar(stmt)
|
||||
|
||||
if not message:
|
||||
raise MessageNotExistsError()
|
||||
|
|
|
|||
|
|
@ -1,6 +1,8 @@
|
|||
import logging
|
||||
from typing import cast
|
||||
|
||||
from sqlalchemy import select
|
||||
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager
|
||||
from core.app.apps.base_app_runner import AppRunner
|
||||
from core.app.apps.completion.app_config_manager import CompletionAppConfig
|
||||
|
|
@ -35,8 +37,8 @@ class CompletionAppRunner(AppRunner):
|
|||
"""
|
||||
app_config = application_generate_entity.app_config
|
||||
app_config = cast(CompletionAppConfig, app_config)
|
||||
|
||||
app_record = db.session.query(App).where(App.id == app_config.app_id).first()
|
||||
stmt = select(App).where(App.id == app_config.app_id)
|
||||
app_record = db.session.scalar(stmt)
|
||||
if not app_record:
|
||||
raise ValueError("App not found")
|
||||
|
||||
|
|
|
|||
|
|
@ -3,6 +3,9 @@ import logging
|
|||
from collections.abc import Generator
|
||||
from typing import Optional, Union, cast
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.app.app_config.entities import EasyUIBasedAppConfig, EasyUIBasedAppModelConfigFrom
|
||||
from core.app.apps.base_app_generator import BaseAppGenerator
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager
|
||||
|
|
@ -83,11 +86,10 @@ class MessageBasedAppGenerator(BaseAppGenerator):
|
|||
|
||||
def _get_app_model_config(self, app_model: App, conversation: Optional[Conversation] = None) -> AppModelConfig:
|
||||
if conversation:
|
||||
app_model_config = (
|
||||
db.session.query(AppModelConfig)
|
||||
.where(AppModelConfig.id == conversation.app_model_config_id, AppModelConfig.app_id == app_model.id)
|
||||
.first()
|
||||
stmt = select(AppModelConfig).where(
|
||||
AppModelConfig.id == conversation.app_model_config_id, AppModelConfig.app_id == app_model.id
|
||||
)
|
||||
app_model_config = db.session.scalar(stmt)
|
||||
|
||||
if not app_model_config:
|
||||
raise AppModelConfigBrokenError()
|
||||
|
|
@ -253,7 +255,8 @@ class MessageBasedAppGenerator(BaseAppGenerator):
|
|||
:param conversation_id: conversation id
|
||||
:return: conversation
|
||||
"""
|
||||
conversation = db.session.query(Conversation).where(Conversation.id == conversation_id).first()
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
conversation = session.scalar(select(Conversation).where(Conversation.id == conversation_id))
|
||||
|
||||
if not conversation:
|
||||
raise ConversationNotExistsError("Conversation not exists")
|
||||
|
|
@ -266,7 +269,8 @@ class MessageBasedAppGenerator(BaseAppGenerator):
|
|||
:param message_id: message id
|
||||
:return: message
|
||||
"""
|
||||
message = db.session.query(Message).where(Message.id == message_id).first()
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
message = session.scalar(select(Message).where(Message.id == message_id))
|
||||
|
||||
if message is None:
|
||||
raise MessageNotExistsError("Message not exists")
|
||||
|
|
|
|||
|
|
@ -89,7 +89,7 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter):
|
|||
data = cls._error_to_stream_response(sub_stream_response.err)
|
||||
response_chunk.update(data)
|
||||
elif isinstance(sub_stream_response, NodeStartStreamResponse | NodeFinishStreamResponse):
|
||||
response_chunk.update(sub_stream_response.to_ignore_detail_dict())
|
||||
response_chunk.update(sub_stream_response.to_ignore_detail_dict()) # ty: ignore [unresolved-attribute]
|
||||
else:
|
||||
response_chunk.update(sub_stream_response.to_dict())
|
||||
yield response_chunk
|
||||
|
|
|
|||
|
|
@ -296,16 +296,15 @@ class WorkflowAppGenerateTaskPipeline:
|
|||
"""Handle node retry events."""
|
||||
self._ensure_workflow_initialized()
|
||||
|
||||
with self._database_session() as session:
|
||||
workflow_node_execution = self._workflow_cycle_manager.handle_workflow_node_execution_retried(
|
||||
workflow_execution_id=self._workflow_run_id,
|
||||
event=event,
|
||||
)
|
||||
response = self._workflow_response_converter.workflow_node_retry_to_stream_response(
|
||||
event=event,
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
workflow_node_execution=workflow_node_execution,
|
||||
)
|
||||
workflow_node_execution = self._workflow_cycle_manager.handle_workflow_node_execution_retried(
|
||||
workflow_execution_id=self._workflow_run_id,
|
||||
event=event,
|
||||
)
|
||||
response = self._workflow_response_converter.workflow_node_retry_to_stream_response(
|
||||
event=event,
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
workflow_node_execution=workflow_node_execution,
|
||||
)
|
||||
|
||||
if response:
|
||||
yield response
|
||||
|
|
|
|||
|
|
@ -1,6 +1,8 @@
|
|||
import logging
|
||||
from typing import Optional
|
||||
|
||||
from sqlalchemy import select
|
||||
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.rag.datasource.vdb.vector_factory import Vector
|
||||
from extensions.ext_database import db
|
||||
|
|
@ -25,9 +27,8 @@ class AnnotationReplyFeature:
|
|||
:param invoke_from: invoke from
|
||||
:return:
|
||||
"""
|
||||
annotation_setting = (
|
||||
db.session.query(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app_record.id).first()
|
||||
)
|
||||
stmt = select(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app_record.id)
|
||||
annotation_setting = db.session.scalar(stmt)
|
||||
|
||||
if not annotation_setting:
|
||||
return None
|
||||
|
|
|
|||
|
|
@ -96,7 +96,11 @@ class RateLimit:
|
|||
if isinstance(generator, Mapping):
|
||||
return generator
|
||||
else:
|
||||
return RateLimitGenerator(rate_limit=self, generator=generator, request_id=request_id)
|
||||
return RateLimitGenerator(
|
||||
rate_limit=self,
|
||||
generator=generator, # ty: ignore [invalid-argument-type]
|
||||
request_id=request_id,
|
||||
)
|
||||
|
||||
|
||||
class RateLimitGenerator:
|
||||
|
|
|
|||
|
|
@ -50,7 +50,7 @@ class BasedGenerateTaskPipeline:
|
|||
if isinstance(e, InvokeAuthorizationError):
|
||||
err = InvokeAuthorizationError("Incorrect API key provided")
|
||||
elif isinstance(e, InvokeError | ValueError):
|
||||
err = e
|
||||
err = e # ty: ignore [invalid-assignment]
|
||||
else:
|
||||
description = getattr(e, "description", None)
|
||||
err = Exception(description if description is not None else str(e))
|
||||
|
|
|
|||
|
|
@@ -472,9 +472,10 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
         :param event: agent thought event
         :return:
         """
-        agent_thought: Optional[MessageAgentThought] = (
-            db.session.query(MessageAgentThought).where(MessageAgentThought.id == event.agent_thought_id).first()
-        )
+        with Session(db.engine, expire_on_commit=False) as session:
+            agent_thought: Optional[MessageAgentThought] = (
+                session.query(MessageAgentThought).where(MessageAgentThought.id == event.agent_thought_id).first()
+            )

         if agent_thought:
             return AgentThoughtStreamResponse(

@@ -3,6 +3,8 @@ from threading import Thread
 from typing import Optional, Union

 from flask import Flask, current_app
+from sqlalchemy import select
+from sqlalchemy.orm import Session

 from configs import dify_config
 from core.app.entities.app_invoke_entities import (

@@ -84,7 +86,8 @@ class MessageCycleManager:
     def _generate_conversation_name_worker(self, flask_app: Flask, conversation_id: str, query: str):
         with flask_app.app_context():
             # get conversation and message
-            conversation = db.session.query(Conversation).where(Conversation.id == conversation_id).first()
+            stmt = select(Conversation).where(Conversation.id == conversation_id)
+            conversation = db.session.scalar(stmt)

             if not conversation:
                 return

@@ -98,7 +101,7 @@ class MessageCycleManager:
             try:
                 name = LLMGenerator.generate_conversation_name(app_model.tenant_id, query)
                 conversation.name = name
-            except Exception as e:
+            except Exception:
                 if dify_config.DEBUG:
                     logger.exception("generate conversation name failed, conversation_id: %s", conversation_id)
                 pass

@@ -143,7 +146,8 @@ class MessageCycleManager:
         :param event: event
         :return:
         """
-        message_file = db.session.query(MessageFile).where(MessageFile.id == event.message_file_id).first()
+        with Session(db.engine, expire_on_commit=False) as session:
+            message_file = session.scalar(select(MessageFile).where(MessageFile.id == event.message_file_id))

         if message_file and message_file.url is not None:
             # get tool file id

@@ -183,7 +187,8 @@ class MessageCycleManager:
         :param message_id: message id
         :return:
         """
-        message_file = db.session.query(MessageFile).where(MessageFile.id == message_id).first()
+        with Session(db.engine, expire_on_commit=False) as session:
+            message_file = session.scalar(select(MessageFile).where(MessageFile.id == message_id))
         event_type = StreamEvent.MESSAGE_FILE if message_file else StreamEvent.MESSAGE

         return MessageStreamResponse(

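Note: the two hunks above switch one-off reads to a short-lived Session opened with expire_on_commit=False, so the loaded row stays readable after the block exits. A runnable sketch of that detached-read pattern (hypothetical Note model, illustrative names):

# expire_on_commit=False keeps attribute values loaded, so the ORM object
# can still be read once the session is closed (detached state).
from sqlalchemy import String, create_engine, select
from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column

class Base(DeclarativeBase):
    pass

class Note(Base):
    __tablename__ = "notes"
    id: Mapped[int] = mapped_column(primary_key=True)
    url: Mapped[str] = mapped_column(String(255))

engine = create_engine("sqlite://")
Base.metadata.create_all(engine)
with Session(engine) as session:
    session.add(Note(url="https://example.com/f.png"))
    session.commit()

with Session(engine, expire_on_commit=False) as session:
    note = session.scalar(select(Note).where(Note.id == 1))

if note is not None:
    print(note.url)  # safe: attributes were not expired on close
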
@@ -1,6 +1,8 @@
 import logging
 from collections.abc import Sequence

+from sqlalchemy import select
+
 from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
 from core.app.entities.app_invoke_entities import InvokeFrom
 from core.app.entities.queue_entities import QueueRetrieverResourcesEvent

@@ -49,7 +51,8 @@ class DatasetIndexToolCallbackHandler:
         for document in documents:
             if document.metadata is not None:
                 document_id = document.metadata["document_id"]
-                dataset_document = db.session.query(DatasetDocument).where(DatasetDocument.id == document_id).first()
+                dataset_document_stmt = select(DatasetDocument).where(DatasetDocument.id == document_id)
+                dataset_document = db.session.scalar(dataset_document_stmt)
                 if not dataset_document:
                     _logger.warning(
                         "Expected DatasetDocument record to exist, but none was found, document_id=%s",

@@ -57,17 +60,14 @@ class DatasetIndexToolCallbackHandler:
                     )
                     continue
                 if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX:
-                    child_chunk = (
-                        db.session.query(ChildChunk)
-                        .where(
-                            ChildChunk.index_node_id == document.metadata["doc_id"],
-                            ChildChunk.dataset_id == dataset_document.dataset_id,
-                            ChildChunk.document_id == dataset_document.id,
-                        )
-                        .first()
+                    child_chunk_stmt = select(ChildChunk).where(
+                        ChildChunk.index_node_id == document.metadata["doc_id"],
+                        ChildChunk.dataset_id == dataset_document.dataset_id,
+                        ChildChunk.document_id == dataset_document.id,
                     )
+                    child_chunk = db.session.scalar(child_chunk_stmt)
                     if child_chunk:
-                        segment = (
+                        _ = (
                             db.session.query(DocumentSegment)
                             .where(DocumentSegment.id == child_chunk.segment_id)
                             .update(

@@ -1,5 +1,6 @@
 import json
 import logging
+import re
 from collections import defaultdict
 from collections.abc import Iterator, Sequence
 from json import JSONDecodeError

@@ -343,7 +344,65 @@ class ProviderConfiguration(BaseModel):
         with Session(db.engine) as new_session:
             return _validate(new_session)

-    def create_provider_credential(self, credentials: dict, credential_name: str) -> None:
+    def _generate_provider_credential_name(self, session) -> str:
+        """
+        Generate a unique credential name for provider.
+        :return: credential name
+        """
+        return self._generate_next_api_key_name(
+            session=session,
+            query_factory=lambda: select(ProviderCredential).where(
+                ProviderCredential.tenant_id == self.tenant_id,
+                ProviderCredential.provider_name == self.provider.provider,
+            ),
+        )
+
+    def _generate_custom_model_credential_name(self, model: str, model_type: ModelType, session) -> str:
+        """
+        Generate a unique credential name for custom model.
+        :return: credential name
+        """
+        return self._generate_next_api_key_name(
+            session=session,
+            query_factory=lambda: select(ProviderModelCredential).where(
+                ProviderModelCredential.tenant_id == self.tenant_id,
+                ProviderModelCredential.provider_name == self.provider.provider,
+                ProviderModelCredential.model_name == model,
+                ProviderModelCredential.model_type == model_type.to_origin_model_type(),
+            ),
+        )
+
+    def _generate_next_api_key_name(self, session, query_factory) -> str:
+        """
+        Generate next available API KEY name by finding the highest numbered suffix.
+        :param session: database session
+        :param query_factory: function that returns the SQLAlchemy query
+        :return: next available API KEY name
+        """
+        try:
+            stmt = query_factory()
+            credential_records = session.execute(stmt).scalars().all()
+
+            if not credential_records:
+                return "API KEY 1"
+
+            # Extract numbers from API KEY pattern using list comprehension
+            pattern = re.compile(r"^API KEY\s+(\d+)$")
+            numbers = [
+                int(match.group(1))
+                for cr in credential_records
+                if cr.credential_name and (match := pattern.match(cr.credential_name.strip()))
+            ]
+
+            # Return next sequential number
+            next_number = max(numbers, default=0) + 1
+            return f"API KEY {next_number}"
+
+        except Exception as e:
+            logger.warning("Error generating next credential name: %s", str(e))
+            return "API KEY 1"
+
+    def create_provider_credential(self, credentials: dict, credential_name: str | None) -> None:
         """
         Add custom provider credentials.
         :param credentials: provider credentials

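Note: _generate_next_api_key_name scans existing credential names for an "API KEY n" suffix and picks the next integer. A standalone sketch of just that numbering logic (the helper and its plain-list input are illustrative; the regex mirrors the code above):

# Standalone sketch of the "API KEY n" numbering used above.
import re

def next_api_key_name(existing_names: list[str]) -> str:
    pattern = re.compile(r"^API KEY\s+(\d+)$")
    numbers = [
        int(m.group(1))
        for name in existing_names
        if name and (m := pattern.match(name.strip()))
    ]
    # max(..., default=0) makes the very first credential "API KEY 1"
    return f"API KEY {max(numbers, default=0) + 1}"

assert next_api_key_name([]) == "API KEY 1"
assert next_api_key_name(["API KEY 1", "API KEY 3", "my-key"]) == "API KEY 4"
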
@@ -351,8 +410,12 @@ class ProviderConfiguration(BaseModel):
         :return:
         """
         with Session(db.engine) as session:
-            if self._check_provider_credential_name_exists(credential_name=credential_name, session=session):
+            if credential_name and self._check_provider_credential_name_exists(
+                credential_name=credential_name, session=session
+            ):
                 raise ValueError(f"Credential with name '{credential_name}' already exists.")
+            else:
+                credential_name = self._generate_provider_credential_name(session)

             credentials = self.validate_provider_credentials(credentials=credentials, session=session)
             provider_record = self._get_provider_record(session)

@@ -395,7 +458,7 @@ class ProviderConfiguration(BaseModel):
         self,
         credentials: dict,
         credential_id: str,
-        credential_name: str,
+        credential_name: str | None,
     ) -> None:
         """
         update a saved provider credential (by credential_id).

@@ -406,7 +469,7 @@ class ProviderConfiguration(BaseModel):
         :return:
         """
         with Session(db.engine) as session:
-            if self._check_provider_credential_name_exists(
+            if credential_name and self._check_provider_credential_name_exists(
                 credential_name=credential_name, session=session, exclude_id=credential_id
             ):
                 raise ValueError(f"Credential with name '{credential_name}' already exists.")

@@ -428,9 +491,9 @@ class ProviderConfiguration(BaseModel):
             try:
                 # Update credential
                 credential_record.encrypted_config = json.dumps(credentials)
-                credential_record.credential_name = credential_name
                 credential_record.updated_at = naive_utc_now()
+                if credential_name:
+                    credential_record.credential_name = credential_name
                 session.commit()

                 if provider_record and provider_record.credential_id == credential_id:

@@ -532,13 +595,7 @@ class ProviderConfiguration(BaseModel):
                     cache_type=ProviderCredentialsCacheType.LOAD_BALANCING_MODEL,
                 )
                 lb_credentials_cache.delete()
-
-                lb_config.credential_id = None
-                lb_config.encrypted_config = None
-                lb_config.enabled = False
-                lb_config.name = "__delete__"
-                lb_config.updated_at = naive_utc_now()
-                session.add(lb_config)
+                session.delete(lb_config)

             # Check if this is the currently active credential
             provider_record = self._get_provider_record(session)

@@ -823,7 +880,7 @@ class ProviderConfiguration(BaseModel):
             return _validate(new_session)

     def create_custom_model_credential(
-        self, model_type: ModelType, model: str, credentials: dict, credential_name: str
+        self, model_type: ModelType, model: str, credentials: dict, credential_name: str | None
     ) -> None:
         """
         Create a custom model credential.

@@ -834,10 +891,14 @@ class ProviderConfiguration(BaseModel):
         :return:
         """
         with Session(db.engine) as session:
-            if self._check_custom_model_credential_name_exists(
+            if credential_name and self._check_custom_model_credential_name_exists(
                 model=model, model_type=model_type, credential_name=credential_name, session=session
             ):
                 raise ValueError(f"Model credential with name '{credential_name}' already exists for {model}.")
+            else:
+                credential_name = self._generate_custom_model_credential_name(
+                    model=model, model_type=model_type, session=session
+                )
             # validate custom model config
             credentials = self.validate_custom_model_credentials(
                 model_type=model_type, model=model, credentials=credentials, session=session

@@ -881,7 +942,7 @@ class ProviderConfiguration(BaseModel):
             raise

     def update_custom_model_credential(
-        self, model_type: ModelType, model: str, credentials: dict, credential_name: str, credential_id: str
+        self, model_type: ModelType, model: str, credentials: dict, credential_name: str | None, credential_id: str
     ) -> None:
         """
         Update a custom model credential.

@@ -894,7 +955,7 @@ class ProviderConfiguration(BaseModel):
         :return:
         """
         with Session(db.engine) as session:
-            if self._check_custom_model_credential_name_exists(
+            if credential_name and self._check_custom_model_credential_name_exists(
                 model=model,
                 model_type=model_type,
                 credential_name=credential_name,

@@ -926,8 +987,9 @@ class ProviderConfiguration(BaseModel):
             try:
                 # Update credential
                 credential_record.encrypted_config = json.dumps(credentials)
-                credential_record.credential_name = credential_name
                 credential_record.updated_at = naive_utc_now()
+                if credential_name:
+                    credential_record.credential_name = credential_name
                 session.commit()

                 if provider_model_record and provider_model_record.credential_id == credential_id:

@@ -983,12 +1045,7 @@ class ProviderConfiguration(BaseModel):
                     cache_type=ProviderCredentialsCacheType.LOAD_BALANCING_MODEL,
                 )
                 lb_credentials_cache.delete()
-                lb_config.credential_id = None
-                lb_config.encrypted_config = None
-                lb_config.enabled = False
-                lb_config.name = "__delete__"
-                lb_config.updated_at = naive_utc_now()
-                session.add(lb_config)
+                session.delete(lb_config)

             # Check if this is the currently active credential
             provider_model_record = self._get_custom_model_record(model_type, model, session=session)

@@ -1055,6 +1112,7 @@ class ProviderConfiguration(BaseModel):
                     provider_name=self.provider.provider,
                     model_name=model,
                     model_type=model_type.to_origin_model_type(),
+                    is_valid=True,
                     credential_id=credential_id,
                 )
             else:

@@ -1608,11 +1666,9 @@ class ProviderConfiguration(BaseModel):
                     if config.credential_source_type != "custom_model"
                 ]

-                if len(provider_model_lb_configs) > 1:
-                    load_balancing_enabled = True
-
-                if any(config.name == "__delete__" for config in provider_model_lb_configs):
-                    has_invalid_load_balancing_configs = True
+                load_balancing_enabled = model_setting.load_balancing_enabled
+                # when the user enable load_balancing but available configs are less than 2 display warning
+                has_invalid_load_balancing_configs = load_balancing_enabled and len(provider_model_lb_configs) < 2

                 provider_models.append(
                     ModelWithProviderEntity(

@@ -1634,6 +1690,8 @@ class ProviderConfiguration(BaseModel):
             for model_configuration in self.custom_configuration.models:
                 if model_configuration.model_type not in model_types:
                     continue
+                if model_configuration.unadded_to_model_list:
+                    continue
                 if model and model != model_configuration.model:
                     continue
                 try:

@@ -1666,11 +1724,9 @@ class ProviderConfiguration(BaseModel):
                     if config.credential_source_type != "provider"
                 ]

-                if len(custom_model_lb_configs) > 1:
-                    load_balancing_enabled = True
-
-                if any(config.name == "__delete__" for config in custom_model_lb_configs):
-                    has_invalid_load_balancing_configs = True
+                load_balancing_enabled = model_setting.load_balancing_enabled
+                # when the user enable load_balancing but available configs are less than 2 display warning
+                has_invalid_load_balancing_configs = load_balancing_enabled and len(custom_model_lb_configs) < 2

                 if len(model_configuration.available_model_credentials) > 0 and not model_configuration.credentials:
                     status = ModelStatus.CREDENTIAL_REMOVED

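Note: both branches above now take the flag from model_setting instead of inferring it, and surface a warning when balancing is enabled with fewer than two usable configs. A tiny runnable sketch of the rule:

# Sketch of the new load-balancing warning rule.
def has_invalid_lb_configs(load_balancing_enabled: bool, config_count: int) -> bool:
    return load_balancing_enabled and config_count < 2

assert has_invalid_lb_configs(True, 1) is True    # enabled, nothing to balance
assert has_invalid_lb_configs(True, 2) is False   # enabled with a real pool
assert has_invalid_lb_configs(False, 0) is False  # disabled: never warned
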
@@ -111,11 +111,21 @@ class CustomModelConfiguration(BaseModel):
     current_credential_id: Optional[str] = None
     current_credential_name: Optional[str] = None
     available_model_credentials: list[CredentialConfiguration] = []
+    unadded_to_model_list: Optional[bool] = False

     # pydantic configs
     model_config = ConfigDict(protected_namespaces=())


+class UnaddedModelConfiguration(BaseModel):
+    """
+    Model class for provider unadded model configuration.
+    """
+
+    model: str
+    model_type: ModelType
+
+
 class CustomConfiguration(BaseModel):
     """
     Model class for provider custom configuration.

@@ -123,6 +133,7 @@ class CustomConfiguration(BaseModel):

     provider: Optional[CustomProviderConfiguration] = None
     models: list[CustomModelConfiguration] = []
+    can_added_models: list[UnaddedModelConfiguration] = []


 class ModelLoadBalancingConfiguration(BaseModel):

@@ -144,6 +155,7 @@ class ModelSettings(BaseModel):
     model: str
     model_type: ModelType
     enabled: bool = True
+    load_balancing_enabled: bool = False
     load_balancing_configs: list[ModelLoadBalancingConfiguration] = []

     # pydantic configs

@@ -43,9 +43,9 @@ class APIBasedExtensionRequestor:
                 timeout=self.timeout,
                 proxies=proxies,
             )
-        except requests.exceptions.Timeout:
+        except requests.Timeout:
             raise ValueError("request timeout")
-        except requests.exceptions.ConnectionError:
+        except requests.ConnectionError:
             raise ValueError("request connection error")

         if response.status_code != 200:

@@ -91,7 +91,7 @@ class Extensible:

                 # Find extension class
                 extension_class = None
-                for name, obj in vars(mod).items():
+                for obj in vars(mod).values():
                     if isinstance(obj, type) and issubclass(obj, cls) and obj != cls:
                         extension_class = obj
                         break

@@ -123,7 +123,7 @@ class Extensible:
                         )
                     )

-        except Exception as e:
+        except Exception:
             logger.exception("Error scanning extensions")
             raise

@@ -41,9 +41,3 @@ class Extension:
         assert module_extension.extension_class is not None
         t: type = module_extension.extension_class
         return t
-
-    def validate_form_schema(self, module: ExtensionModule, extension_name: str, config: dict) -> None:
-        module_extension = self.module_extension(module, extension_name)
-        form_schema = module_extension.form_schema
-
-        # TODO validate form_schema

@@ -1,5 +1,7 @@
 from typing import Optional

+from sqlalchemy import select
+
 from core.extension.api_based_extension_requestor import APIBasedExtensionRequestor
 from core.external_data_tool.base import ExternalDataTool
 from core.helper import encrypter

@@ -28,13 +30,11 @@ class ApiExternalDataTool(ExternalDataTool):
         api_based_extension_id = config.get("api_based_extension_id")
         if not api_based_extension_id:
             raise ValueError("api_based_extension_id is required")

         # get api_based_extension
-        api_based_extension = (
-            db.session.query(APIBasedExtension)
-            .where(APIBasedExtension.tenant_id == tenant_id, APIBasedExtension.id == api_based_extension_id)
-            .first()
+        stmt = select(APIBasedExtension).where(
+            APIBasedExtension.tenant_id == tenant_id, APIBasedExtension.id == api_based_extension_id
         )
+        api_based_extension = db.session.scalar(stmt)

         if not api_based_extension:
             raise ValueError("api_based_extension_id is invalid")

@@ -52,13 +52,11 @@ class ApiExternalDataTool(ExternalDataTool):
             raise ValueError(f"config is required, config: {self.config}")
         api_based_extension_id = self.config.get("api_based_extension_id")
         assert api_based_extension_id is not None, "api_based_extension_id is required"

         # get api_based_extension
-        api_based_extension = (
-            db.session.query(APIBasedExtension)
-            .where(APIBasedExtension.tenant_id == self.tenant_id, APIBasedExtension.id == api_based_extension_id)
-            .first()
+        stmt = select(APIBasedExtension).where(
+            APIBasedExtension.tenant_id == self.tenant_id, APIBasedExtension.id == api_based_extension_id
         )
+        api_based_extension = db.session.scalar(stmt)

         if not api_based_extension:
             raise ValueError(

@@ -22,7 +22,6 @@ class ExternalDataToolFactory:
         :param config: the form config data
         :return:
         """
-        code_based_extension.validate_form_schema(ExtensionModule.EXTERNAL_DATA_TOOL, name, config)
         extension_class = code_based_extension.extension_class(ExtensionModule.EXTERNAL_DATA_TOOL, name)
         # FIXME mypy issue here, figure out how to fix it
         extension_class.validate_config(tenant_id, config)  # type: ignore

@@ -3,7 +3,7 @@ import base64
 from libs import rsa


-def obfuscated_token(token: str):
+def obfuscated_token(token: str) -> str:
     if not token:
         return token
     if len(token) <= 8:

@@ -11,6 +11,10 @@ def obfuscated_token(token: str):
     return token[:6] + "*" * 12 + token[-2:]


+def full_mask_token(token_length=20):
+    return "*" * token_length
+
+
 def encrypt_token(tenant_id: str, token: str):
     from extensions.ext_database import db
     from models.account import Tenant

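Note: the masking helpers now carry an explicit return type, and full_mask_token hides the whole value. A runnable sketch of the masking behavior; the <=8 branch is cut off in the hunk above, so fully masking short tokens is an assumption here:

# Sketch of the token-masking helpers; short-token behavior is assumed.
def obfuscated_token_sketch(token: str) -> str:
    if not token:
        return token
    if len(token) <= 8:
        return "*" * len(token)  # assumption: short tokens are fully hidden
    return token[:6] + "*" * 12 + token[-2:]

def full_mask_token(token_length=20):
    return "*" * token_length

assert obfuscated_token_sketch("sk-abcdefghijklmnop") == "sk-abc************op"
assert full_mask_token(4) == "****"
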
@@ -1,6 +1,6 @@
 from collections.abc import Sequence

-import requests
+import httpx
 from yarl import URL

 from configs import dify_config

@@ -23,7 +23,7 @@ def batch_fetch_plugin_manifests(plugin_ids: list[str]) -> Sequence[MarketplacePluginDeclaration]:
         return []

     url = str(marketplace_api_url / "api/v1/plugins/batch")
-    response = requests.post(url, json={"plugin_ids": plugin_ids})
+    response = httpx.post(url, json={"plugin_ids": plugin_ids})
     response.raise_for_status()

     return [MarketplacePluginDeclaration(**plugin) for plugin in response.json()["data"]["plugins"]]

@@ -36,13 +36,13 @@ def batch_fetch_plugin_manifests_ignore_deserialization_error(
         return []

     url = str(marketplace_api_url / "api/v1/plugins/batch")
-    response = requests.post(url, json={"plugin_ids": plugin_ids})
+    response = httpx.post(url, json={"plugin_ids": plugin_ids})
     response.raise_for_status()
     result: list[MarketplacePluginDeclaration] = []
     for plugin in response.json()["data"]["plugins"]:
         try:
             result.append(MarketplacePluginDeclaration(**plugin))
-        except Exception as e:
+        except Exception:
             pass

     return result

@@ -50,5 +50,5 @@ def batch_fetch_plugin_manifests_ignore_deserialization_error(

 def record_install_plugin_event(plugin_unique_identifier: str):
     url = str(marketplace_api_url / "api/v1/stats/plugins/install_count")
-    response = requests.post(url, json={"unique_identifier": plugin_unique_identifier})
+    response = httpx.post(url, json={"unique_identifier": plugin_unique_identifier})
     response.raise_for_status()

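Note: the marketplace helpers swap requests for httpx; for simple synchronous POSTs the call shape is identical, so only the import changes. A sketch against an illustrative URL (not the real marketplace endpoint):

# Sketch of the requests -> httpx swap; URL and payload are illustrative.
import httpx

response = httpx.post(
    "https://marketplace.example.com/api/v1/plugins/batch",
    json={"plugin_ids": ["plugin-a", "plugin-b"]},
)
response.raise_for_status()  # raises httpx.HTTPStatusError on 4xx/5xx
plugins = response.json()["data"]["plugins"]
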
@@ -47,7 +47,7 @@ def get_subclasses_from_module(mod: ModuleType, parent_type: type) -> list[type]:


 def load_single_subclass_from_source(
-    *, module_name: str, script_path: AnyStr, parent_type: type, use_lazy_loader: bool = False
+    *, module_name: str, script_path: str, parent_type: type, use_lazy_loader: bool = False
 ) -> type:
     """
     Load a single subclass from the source

@@ -5,9 +5,10 @@ import re
 import threading
 import time
 import uuid
-from typing import Any, Optional, cast
+from typing import Any, Optional

 from flask import current_app
+from sqlalchemy import select
 from sqlalchemy.orm.exc import ObjectDeletedError

 from configs import dify_config

@@ -18,6 +19,7 @@ from core.model_runtime.entities.model_entities import ModelType
 from core.rag.cleaner.clean_processor import CleanProcessor
 from core.rag.datasource.keyword.keyword_factory import Keyword
 from core.rag.docstore.dataset_docstore import DatasetDocumentStore
+from core.rag.extractor.entity.datasource_type import DatasourceType
 from core.rag.extractor.entity.extract_setting import ExtractSetting
 from core.rag.index_processor.constant.index_type import IndexType
 from core.rag.index_processor.index_processor_base import BaseIndexProcessor

@@ -56,13 +58,11 @@ class IndexingRunner:

         if not dataset:
             raise ValueError("no dataset found")

         # get the process rule
-        processing_rule = (
-            db.session.query(DatasetProcessRule)
-            .where(DatasetProcessRule.id == dataset_document.dataset_process_rule_id)
-            .first()
+        stmt = select(DatasetProcessRule).where(
+            DatasetProcessRule.id == dataset_document.dataset_process_rule_id
         )
+        processing_rule = db.session.scalar(stmt)
         if not processing_rule:
             raise ValueError("no process rule found")
         index_type = dataset_document.doc_form

@@ -123,11 +123,8 @@ class IndexingRunner:
         db.session.query(ChildChunk).where(ChildChunk.segment_id == document_segment.id).delete()
         db.session.commit()
         # get the process rule
-        processing_rule = (
-            db.session.query(DatasetProcessRule)
-            .where(DatasetProcessRule.id == dataset_document.dataset_process_rule_id)
-            .first()
-        )
+        stmt = select(DatasetProcessRule).where(DatasetProcessRule.id == dataset_document.dataset_process_rule_id)
+        processing_rule = db.session.scalar(stmt)
         if not processing_rule:
             raise ValueError("no process rule found")

@@ -208,7 +205,6 @@ class IndexingRunner:
                     child_documents.append(child_document)
                 document.children = child_documents
             documents.append(document)
-
         # build index
         index_type = dataset_document.doc_form
         index_processor = IndexProcessorFactory(index_type).init_index_processor()

@@ -310,7 +306,8 @@ class IndexingRunner:
         # delete image files and related db records
         image_upload_file_ids = get_image_upload_file_ids(document.page_content)
         for upload_file_id in image_upload_file_ids:
-            image_file = db.session.query(UploadFile).where(UploadFile.id == upload_file_id).first()
+            stmt = select(UploadFile).where(UploadFile.id == upload_file_id)
+            image_file = db.session.scalar(stmt)
             if image_file is None:
                 continue
             try:

@@ -339,14 +336,14 @@ class IndexingRunner:
         if dataset_document.data_source_type == "upload_file":
             if not data_source_info or "upload_file_id" not in data_source_info:
                 raise ValueError("no upload file found")

-            file_detail = (
-                db.session.query(UploadFile).where(UploadFile.id == data_source_info["upload_file_id"]).one_or_none()
-            )
+            stmt = select(UploadFile).where(UploadFile.id == data_source_info["upload_file_id"])
+            file_detail = db.session.scalars(stmt).one_or_none()

             if file_detail:
                 extract_setting = ExtractSetting(
-                    datasource_type="upload_file", upload_file=file_detail, document_model=dataset_document.doc_form
+                    datasource_type=DatasourceType.FILE.value,
+                    upload_file=file_detail,
+                    document_model=dataset_document.doc_form,
                 )
                 text_docs = index_processor.extract(extract_setting, process_rule_mode=process_rule["mode"])
         elif dataset_document.data_source_type == "notion_import":

@@ -357,7 +354,7 @@ class IndexingRunner:
             ):
                 raise ValueError("no notion import info found")
             extract_setting = ExtractSetting(
-                datasource_type="notion_import",
+                datasource_type=DatasourceType.NOTION.value,
                 notion_info={
                     "credential_id": data_source_info["credential_id"],
                     "notion_workspace_id": data_source_info["notion_workspace_id"],

@@ -378,7 +375,7 @@ class IndexingRunner:
             ):
                 raise ValueError("no website import info found")
             extract_setting = ExtractSetting(
-                datasource_type="website_crawl",
+                datasource_type=DatasourceType.WEBSITE.value,
                 website_info={
                     "provider": data_source_info["provider"],
                     "job_id": data_source_info["job_id"],

@@ -401,7 +398,6 @@ class IndexingRunner:
             )

         # replace doc id to document model id
-        text_docs = cast(list[Document], text_docs)
        for text_doc in text_docs:
             if text_doc.metadata is not None:
                 text_doc.metadata["document_id"] = dataset_document.id

@@ -58,11 +58,8 @@ class LLMGenerator:
         prompts = [UserPromptMessage(content=prompt)]

         with measure_time() as timer:
-            response = cast(
-                LLMResult,
-                model_instance.invoke_llm(
-                    prompt_messages=list(prompts), model_parameters={"max_tokens": 500, "temperature": 1}, stream=False
-                ),
+            response: LLMResult = model_instance.invoke_llm(
+                prompt_messages=list(prompts), model_parameters={"max_tokens": 500, "temperature": 1}, stream=False
             )
         answer = cast(str, response.message.content)
         cleaned_answer = re.sub(r"^.*(\{.*\}).*$", r"\1", answer, flags=re.DOTALL)

@@ -71,7 +68,7 @@ class LLMGenerator:
         try:
             result_dict = json.loads(cleaned_answer)
             answer = result_dict["Your Output"]
-        except json.JSONDecodeError as e:
+        except json.JSONDecodeError:
             logger.exception("Failed to generate name after answer, use query instead")
             answer = query
         name = answer.strip()

@@ -115,13 +112,10 @@ class LLMGenerator:
         prompt_messages = [UserPromptMessage(content=prompt)]

         try:
-            response = cast(
-                LLMResult,
-                model_instance.invoke_llm(
-                    prompt_messages=list(prompt_messages),
-                    model_parameters={"max_tokens": 256, "temperature": 0},
-                    stream=False,
-                ),
+            response: LLMResult = model_instance.invoke_llm(
+                prompt_messages=list(prompt_messages),
+                model_parameters={"max_tokens": 256, "temperature": 0},
+                stream=False,
             )

             text_content = response.message.get_text_content()

@@ -164,11 +158,8 @@ class LLMGenerator:
         )

         try:
-            response = cast(
-                LLMResult,
-                model_instance.invoke_llm(
-                    prompt_messages=list(prompt_messages), model_parameters=model_parameters, stream=False
-                ),
+            response: LLMResult = model_instance.invoke_llm(
+                prompt_messages=list(prompt_messages), model_parameters=model_parameters, stream=False
             )

             rule_config["prompt"] = cast(str, response.message.content)

@@ -214,11 +205,8 @@ class LLMGenerator:
         try:
             try:
                 # the first step to generate the task prompt
-                prompt_content = cast(
-                    LLMResult,
-                    model_instance.invoke_llm(
-                        prompt_messages=list(prompt_messages), model_parameters=model_parameters, stream=False
-                    ),
+                prompt_content: LLMResult = model_instance.invoke_llm(
+                    prompt_messages=list(prompt_messages), model_parameters=model_parameters, stream=False
                 )
             except InvokeError as e:
                 error = str(e)

@@ -250,11 +238,8 @@ class LLMGenerator:
         statement_messages = [UserPromptMessage(content=statement_generate_prompt)]

         try:
-            parameter_content = cast(
-                LLMResult,
-                model_instance.invoke_llm(
-                    prompt_messages=list(parameter_messages), model_parameters=model_parameters, stream=False
-                ),
+            parameter_content: LLMResult = model_instance.invoke_llm(
+                prompt_messages=list(parameter_messages), model_parameters=model_parameters, stream=False
             )
             rule_config["variables"] = re.findall(r'"\s*([^"]+)\s*"', cast(str, parameter_content.message.content))
         except InvokeError as e:

@@ -262,11 +247,8 @@ class LLMGenerator:
             error_step = "generate variables"

         try:
-            statement_content = cast(
-                LLMResult,
-                model_instance.invoke_llm(
-                    prompt_messages=list(statement_messages), model_parameters=model_parameters, stream=False
-                ),
+            statement_content: LLMResult = model_instance.invoke_llm(
+                prompt_messages=list(statement_messages), model_parameters=model_parameters, stream=False
             )
             rule_config["opening_statement"] = cast(str, statement_content.message.content)
         except InvokeError as e:

@@ -309,11 +291,8 @@ class LLMGenerator:
         prompt_messages = [UserPromptMessage(content=prompt)]
         model_parameters = model_config.get("completion_params", {})
         try:
-            response = cast(
-                LLMResult,
-                model_instance.invoke_llm(
-                    prompt_messages=list(prompt_messages), model_parameters=model_parameters, stream=False
-                ),
+            response: LLMResult = model_instance.invoke_llm(
+                prompt_messages=list(prompt_messages), model_parameters=model_parameters, stream=False
             )

             generated_code = cast(str, response.message.content)

@@ -340,13 +319,10 @@ class LLMGenerator:

         prompt_messages = [SystemPromptMessage(content=prompt), UserPromptMessage(content=query)]

-        response = cast(
-            LLMResult,
-            model_instance.invoke_llm(
-                prompt_messages=prompt_messages,
-                model_parameters={"temperature": 0.01, "max_tokens": 2000},
-                stream=False,
-            ),
+        response: LLMResult = model_instance.invoke_llm(
+            prompt_messages=prompt_messages,
+            model_parameters={"temperature": 0.01, "max_tokens": 2000},
+            stream=False,
         )

         answer = cast(str, response.message.content)

@@ -369,11 +345,8 @@ class LLMGenerator:
         model_parameters = model_config.get("model_parameters", {})

         try:
-            response = cast(
-                LLMResult,
-                model_instance.invoke_llm(
-                    prompt_messages=list(prompt_messages), model_parameters=model_parameters, stream=False
-                ),
+            response: LLMResult = model_instance.invoke_llm(
+                prompt_messages=list(prompt_messages), model_parameters=model_parameters, stream=False
             )

             raw_content = response.message.content

@@ -560,11 +533,8 @@ class LLMGenerator:
         model_parameters = {"temperature": 0.4}

         try:
-            response = cast(
-                LLMResult,
-                model_instance.invoke_llm(
-                    prompt_messages=list(prompt_messages), model_parameters=model_parameters, stream=False
-                ),
+            response: LLMResult = model_instance.invoke_llm(
+                prompt_messages=list(prompt_messages), model_parameters=model_parameters, stream=False
             )

             generated_raw = cast(str, response.message.content)

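Note: every hunk in this file replaces a runtime cast(LLMResult, ...) with a variable annotation on the non-streaming invoke. A runnable toy of the pattern (illustrative names; the ignore comment stands in for the checker caveat the diff accepts):

# Toy version of the cast() cleanup: the function still returns a Union,
# so call sites annotate the variable instead of wrapping calls in cast().
from collections.abc import Generator
from typing import Union

def invoke_llm(stream: bool = False) -> Union[str, Generator[str, None, None]]:
    return (w for w in ["hi"]) if stream else "hi"

response: str = invoke_llm(stream=False)  # type: ignore[assignment]
print(response.upper())  # runtime value is a plain str when stream=False
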
@@ -101,7 +101,7 @@ def handle_callback(state_key: str, authorization_code: str) -> OAuthCallbackState:

 def check_support_resource_discovery(server_url: str) -> tuple[bool, str]:
     """Check if the server supports OAuth 2.0 Resource Discovery."""
-    b_scheme, b_netloc, b_path, b_params, b_query, b_fragment = urlparse(server_url, "", True)
+    b_scheme, b_netloc, b_path, _, b_query, b_fragment = urlparse(server_url, "", True)
     url_for_resource_discovery = f"{b_scheme}://{b_netloc}/.well-known/oauth-protected-resource{b_path}"
     if b_query:
         url_for_resource_discovery += f"?{b_query}"

@@ -117,7 +117,7 @@ def check_support_resource_discovery(server_url: str) -> tuple[bool, str]:
             else:
                 return False, ""
         return False, ""
-    except httpx.RequestError as e:
+    except httpx.RequestError:
         # Not support resource discovery, fall back to well-known OAuth metadata
         return False, ""

@@ -2,7 +2,7 @@ import logging
 from collections.abc import Callable
 from contextlib import AbstractContextManager, ExitStack
 from types import TracebackType
-from typing import Any, Optional, cast
+from typing import Any, Optional
 from urllib.parse import urlparse

 from core.mcp.client.sse_client import sse_client

@@ -116,8 +116,7 @@ class MCPClient:

                 self._session_context = ClientSession(*streams)
                 self._session = self._exit_stack.enter_context(self._session_context)
-                session = cast(ClientSession, self._session)
-                session.initialize()
+                self._session.initialize()
                 return

             except MCPAuthError:

@@ -2,7 +2,6 @@ from collections.abc import Sequence
 from typing import Optional

 from sqlalchemy import select
-from sqlalchemy.orm import Session

 from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
 from core.file import file_manager

@@ -33,12 +32,7 @@ class TokenBufferMemory:
         self.model_instance = model_instance

     def _build_prompt_message_with_files(
-        self,
-        message_files: Sequence[MessageFile],
-        text_content: str,
-        message: Message,
-        app_record,
-        is_user_message: bool,
+        self, message_files: list[MessageFile], text_content: str, message: Message, app_record, is_user_message: bool
     ) -> PromptMessage:
         """
         Build prompt message with files.

@@ -104,74 +98,74 @@ class TokenBufferMemory:
         :param max_token_limit: max token limit
         :param message_limit: message limit
         """
-        with Session(db.engine) as session:
-            app_record = self.conversation.app
+        app_record = self.conversation.app

-            # fetch limited messages, and return reversed
-            stmt = (
-                select(Message)
-                .where(Message.conversation_id == self.conversation.id)
-                .order_by(Message.created_at.desc())
-            )
+        # fetch limited messages, and return reversed
+        stmt = (
+            select(Message).where(Message.conversation_id == self.conversation.id).order_by(Message.created_at.desc())
+        )

-            if message_limit and message_limit > 0:
-                message_limit = min(message_limit, 500)
-            else:
-                message_limit = 500
+        if message_limit and message_limit > 0:
+            message_limit = min(message_limit, 500)
+        else:
+            message_limit = 500

-            stmt = stmt.limit(message_limit)
+        msg_limit_stmt = stmt.limit(message_limit)

-            messages = session.scalars(stmt).all()
+        messages = db.session.scalars(msg_limit_stmt).all()

-            # instead of all messages from the conversation, we only need to extract messages
-            # that belong to the thread of last message
-            thread_messages = extract_thread_messages(messages)
+        # instead of all messages from the conversation, we only need to extract messages
+        # that belong to the thread of last message
+        thread_messages = extract_thread_messages(messages)

-            # for newly created message, its answer is temporarily empty, we don't need to add it to memory
-            if thread_messages and not thread_messages[0].answer and thread_messages[0].answer_tokens == 0:
-                thread_messages.pop(0)
+        # for newly created message, its answer is temporarily empty, we don't need to add it to memory
+        if thread_messages and not thread_messages[0].answer and thread_messages[0].answer_tokens == 0:
+            thread_messages.pop(0)

-            messages = list(reversed(thread_messages))
+        messages = list(reversed(thread_messages))

-            prompt_messages: list[PromptMessage] = []
-            for message in messages:
-                # Process user message with files
-                user_file_query = select(MessageFile).where(
+        prompt_messages: list[PromptMessage] = []
+        for message in messages:
+            # Process user message with files
+            user_files = (
+                db.session.query(MessageFile)
+                .where(
                     MessageFile.message_id == message.id,
                     (MessageFile.belongs_to == "user") | (MessageFile.belongs_to.is_(None)),
                 )
-                user_files = session.scalars(user_file_query).all()
+                .all()
+            )

-                if user_files:
-                    user_prompt_message = self._build_prompt_message_with_files(
-                        message_files=user_files,
-                        text_content=message.query,
-                        message=message,
-                        app_record=app_record,
-                        is_user_message=True,
-                    )
-                    prompt_messages.append(user_prompt_message)
-
-                else:
-                    prompt_messages.append(UserPromptMessage(content=message.query))
-
-                # Process assistant message with files
-                assistant_file_query = select(MessageFile).where(
-                    MessageFile.message_id == message.id, MessageFile.belongs_to == "assistant"
+            if user_files:
+                user_prompt_message = self._build_prompt_message_with_files(
+                    message_files=user_files,
+                    text_content=message.query,
+                    message=message,
+                    app_record=app_record,
+                    is_user_message=True,
                 )
-                assistant_files = session.scalars(assistant_file_query).all()
+                prompt_messages.append(user_prompt_message)
+            else:
+                prompt_messages.append(UserPromptMessage(content=message.query))

-                if assistant_files:
-                    assistant_prompt_message = self._build_prompt_message_with_files(
-                        message_files=assistant_files,
-                        text_content=message.answer,
-                        message=message,
-                        app_record=app_record,
-                        is_user_message=False,
-                    )
-                    prompt_messages.append(assistant_prompt_message)
-                else:
-                    prompt_messages.append(AssistantPromptMessage(content=message.answer))
+            # Process assistant message with files
+            assistant_files = (
+                db.session.query(MessageFile)
+                .where(MessageFile.message_id == message.id, MessageFile.belongs_to == "assistant")
+                .all()
+            )
+
+            if assistant_files:
+                assistant_prompt_message = self._build_prompt_message_with_files(
+                    message_files=assistant_files,
+                    text_content=message.answer,
+                    message=message,
+                    app_record=app_record,
+                    is_user_message=False,
+                )
+                prompt_messages.append(assistant_prompt_message)
+            else:
+                prompt_messages.append(AssistantPromptMessage(content=message.answer))

         if not prompt_messages:
             return []

@@ -158,8 +158,6 @@ class ModelInstance:
         """
         if not isinstance(self.model_type_instance, LargeLanguageModel):
             raise Exception("Model type instance is not LargeLanguageModel")
-
-        self.model_type_instance = cast(LargeLanguageModel, self.model_type_instance)
         return cast(
             Union[LLMResult, Generator],
             self._round_robin_invoke(

@@ -188,8 +186,6 @@ class ModelInstance:
         """
         if not isinstance(self.model_type_instance, LargeLanguageModel):
             raise Exception("Model type instance is not LargeLanguageModel")
-
-        self.model_type_instance = cast(LargeLanguageModel, self.model_type_instance)
         return cast(
             int,
             self._round_robin_invoke(

@@ -214,8 +210,6 @@ class ModelInstance:
         """
         if not isinstance(self.model_type_instance, TextEmbeddingModel):
             raise Exception("Model type instance is not TextEmbeddingModel")
-
-        self.model_type_instance = cast(TextEmbeddingModel, self.model_type_instance)
         return cast(
             TextEmbeddingResult,
             self._round_robin_invoke(

@@ -237,8 +231,6 @@ class ModelInstance:
         """
         if not isinstance(self.model_type_instance, TextEmbeddingModel):
             raise Exception("Model type instance is not TextEmbeddingModel")
-
-        self.model_type_instance = cast(TextEmbeddingModel, self.model_type_instance)
         return cast(
             list[int],
             self._round_robin_invoke(

@@ -269,8 +261,6 @@ class ModelInstance:
         """
         if not isinstance(self.model_type_instance, RerankModel):
             raise Exception("Model type instance is not RerankModel")
-
-        self.model_type_instance = cast(RerankModel, self.model_type_instance)
         return cast(
             RerankResult,
             self._round_robin_invoke(

@@ -295,8 +285,6 @@ class ModelInstance:
         """
         if not isinstance(self.model_type_instance, ModerationModel):
             raise Exception("Model type instance is not ModerationModel")
-
-        self.model_type_instance = cast(ModerationModel, self.model_type_instance)
         return cast(
             bool,
             self._round_robin_invoke(

@@ -318,8 +306,6 @@ class ModelInstance:
         """
         if not isinstance(self.model_type_instance, Speech2TextModel):
             raise Exception("Model type instance is not Speech2TextModel")
-
-        self.model_type_instance = cast(Speech2TextModel, self.model_type_instance)
         return cast(
             str,
             self._round_robin_invoke(

@@ -343,8 +329,6 @@ class ModelInstance:
         """
         if not isinstance(self.model_type_instance, TTSModel):
             raise Exception("Model type instance is not TTSModel")
-
-        self.model_type_instance = cast(TTSModel, self.model_type_instance)
         return cast(
             Iterable[bytes],
             self._round_robin_invoke(

@@ -404,8 +388,6 @@ class ModelInstance:
         """
         if not isinstance(self.model_type_instance, TTSModel):
             raise Exception("Model type instance is not TTSModel")
-
-        self.model_type_instance = cast(TTSModel, self.model_type_instance)
         return self.model_type_instance.get_tts_model_voices(
             model=self.model, credentials=self.credentials, language=language
         )

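Note: the deleted self-assignments were no-ops for type narrowing — once the isinstance guard raises, a checker already treats the attribute as the subtype for the rest of the method. A runnable toy of guard-based narrowing (illustrative names):

# Toy of isinstance narrowing: after the guard raises, no cast is needed
# for the checker to treat the value as the subtype.
class BaseModelType: ...

class LargeLanguageModelToy(BaseModelType):
    def invoke(self) -> str:
        return "ok"

def call(instance: BaseModelType) -> str:
    if not isinstance(instance, LargeLanguageModelToy):
        raise Exception("Model type instance is not LargeLanguageModel")
    return instance.invoke()  # narrowed: no cast() required here

print(call(LargeLanguageModelToy()))  # -> ok
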
@@ -1,6 +1,7 @@
 from typing import Optional

 from pydantic import BaseModel, Field
+from sqlalchemy import select

 from core.extension.api_based_extension_requestor import APIBasedExtensionPoint, APIBasedExtensionRequestor
 from core.helper.encrypter import decrypt_token

@@ -87,10 +88,9 @@ class ApiModeration(Moderation):

     @staticmethod
     def _get_api_based_extension(tenant_id: str, api_based_extension_id: str) -> Optional[APIBasedExtension]:
-        extension = (
-            db.session.query(APIBasedExtension)
-            .where(APIBasedExtension.tenant_id == tenant_id, APIBasedExtension.id == api_based_extension_id)
-            .first()
+        stmt = select(APIBasedExtension).where(
+            APIBasedExtension.tenant_id == tenant_id, APIBasedExtension.id == api_based_extension_id
         )
+        extension = db.session.scalar(stmt)

         return extension

@@ -20,7 +20,6 @@ class ModerationFactory:
         :param config: the form config data
         :return:
         """
-        code_based_extension.validate_form_schema(ExtensionModule.MODERATION, name, config)
         extension_class = code_based_extension.extension_class(ExtensionModule.MODERATION, name)
         # FIXME: mypy error, try to fix it instead of using type: ignore
         extension_class.validate_config(tenant_id, config)  # type: ignore

@@ -135,7 +135,7 @@ class OutputModeration(BaseModel):

             result: ModerationOutputsResult = moderation_factory.moderation_for_outputs(moderation_buffer)
             return result
-        except Exception as e:
+        except Exception:
             logger.exception("Moderation Output error, app_id: %s", app_id)

         return None

@@ -5,6 +5,7 @@ from typing import Optional
 from urllib.parse import urljoin

 from opentelemetry.trace import Link, Status, StatusCode
+from sqlalchemy import select
 from sqlalchemy.orm import Session, sessionmaker

 from core.ops.aliyun_trace.data_exporter.traceclient import (

@@ -260,15 +261,15 @@ class AliyunDataTrace(BaseTraceInstance):
             app_id = trace_info.metadata.get("app_id")
             if not app_id:
                 raise ValueError("No app_id found in trace_info metadata")

-            app = session.query(App).where(App.id == app_id).first()
+            app_stmt = select(App).where(App.id == app_id)
+            app = session.scalar(app_stmt)
             if not app:
                 raise ValueError(f"App with id {app_id} not found")

             if not app.created_by:
                 raise ValueError(f"App with id {app_id} has no creator (created_by is None)")

-            service_account = session.query(Account).where(Account.id == app.created_by).first()
+            account_stmt = select(Account).where(Account.id == app.created_by)
+            service_account = session.scalar(account_stmt)
             if not service_account:
                 raise ValueError(f"Creator account with id {app.created_by} not found for app {app_id}")
             current_tenant = (

@@ -72,7 +72,7 @@ class TraceClient:
             else:
                 logger.debug("AliyunTrace API check failed: Unexpected status code: %s", response.status_code)
                 return False
-        except requests.exceptions.RequestException as e:
+        except requests.RequestException as e:
             logger.debug("AliyunTrace API check failed: %s", str(e))
             raise ValueError(f"AliyunTrace API check failed: {str(e)}")

@@ -1,5 +1,6 @@
 from abc import ABC, abstractmethod

+from sqlalchemy import select
 from sqlalchemy.orm import Session

 from core.ops.entities.config_entity import BaseTracingConfig

@@ -44,14 +45,15 @@ class BaseTraceInstance(ABC):
         """
         with Session(db.engine, expire_on_commit=False) as session:
             # Get the app to find its creator
-            app = session.query(App).where(App.id == app_id).first()
+            app_stmt = select(App).where(App.id == app_id)
+            app = session.scalar(app_stmt)
             if not app:
                 raise ValueError(f"App with id {app_id} not found")

             if not app.created_by:
                 raise ValueError(f"App with id {app_id} has no creator (created_by is None)")

-            service_account = session.query(Account).where(Account.id == app.created_by).first()
+            account_stmt = select(Account).where(Account.id == app.created_by)
+            service_account = session.scalar(account_stmt)
             if not service_account:
                 raise ValueError(f"Creator account with id {app.created_by} not found for app {app_id}")

@@ -228,9 +228,9 @@ class OpsTraceManager:

         if not trace_config_data:
             return None

         # decrypt_token
-        app = db.session.query(App).where(App.id == app_id).first()
+        stmt = select(App).where(App.id == app_id)
+        app = db.session.scalar(stmt)
         if not app:
             raise ValueError("App not found")

@@ -297,20 +297,19 @@ class OpsTraceManager:
     @classmethod
     def get_app_config_through_message_id(cls, message_id: str):
         app_model_config = None
-        message_data = db.session.query(Message).where(Message.id == message_id).first()
+        message_stmt = select(Message).where(Message.id == message_id)
+        message_data = db.session.scalar(message_stmt)
         if not message_data:
             return None
         conversation_id = message_data.conversation_id
-        conversation_data = db.session.query(Conversation).where(Conversation.id == conversation_id).first()
+        conversation_stmt = select(Conversation).where(Conversation.id == conversation_id)
+        conversation_data = db.session.scalar(conversation_stmt)
         if not conversation_data:
             return None

         if conversation_data.app_model_config_id:
-            app_model_config = (
-                db.session.query(AppModelConfig)
-                .where(AppModelConfig.id == conversation_data.app_model_config_id)
-                .first()
-            )
+            config_stmt = select(AppModelConfig).where(AppModelConfig.id == conversation_data.app_model_config_id)
+            app_model_config = db.session.scalar(config_stmt)
         elif conversation_data.app_model_config_id is None and conversation_data.override_model_configs:
             app_model_config = conversation_data.override_model_configs

@@ -852,7 +851,7 @@ class TraceQueueManager:
             if self.trace_instance:
                 trace_task.app_id = self.app_id
                 trace_manager_queue.put(trace_task)
-        except Exception as e:
+        except Exception:
             logger.exception("Error adding trace task, trace_type %s", trace_task.trace_type)
         finally:
             self.start_timer()

@@ -871,7 +870,7 @@ class TraceQueueManager:
             tasks = self.collect_tasks()
             if tasks:
                 self.send_to_celery(tasks)
-        except Exception as e:
+        except Exception:
             logger.exception("Error processing trace tasks")

     def start_timer(self):

@@ -1,6 +1,8 @@
 from collections.abc import Generator, Mapping
 from typing import Optional, Union

+from sqlalchemy import select
+
 from controllers.service_api.wraps import create_or_update_end_user_for_user_id
 from core.app.app_config.common.parameters_mapping import get_parameters_from_feature_dict
 from core.app.apps.advanced_chat.app_generator import AdvancedChatAppGenerator

@@ -191,10 +193,11 @@ class PluginAppBackwardsInvocation(BaseBackwardsInvocation):
         """
         get the user by user id
         """

-        user = db.session.query(EndUser).where(EndUser.id == user_id).first()
+        stmt = select(EndUser).where(EndUser.id == user_id)
+        user = db.session.scalar(stmt)
         if not user:
-            user = db.session.query(Account).where(Account.id == user_id).first()
+            stmt = select(Account).where(Account.id == user_id)
+            user = db.session.scalar(stmt)

         if not user:
             raise ValueError("user not found")

@@ -64,7 +64,7 @@ class BasePluginClient:
             response = requests.request(
                 method=method, url=str(url), headers=headers, data=data, params=params, stream=stream, files=files
             )
-        except requests.exceptions.ConnectionError:
+        except requests.ConnectionError:
             logger.exception("Request to Plugin Daemon Service failed")
             raise PluginDaemonInnerError(code=-500, message="Request to Plugin Daemon Service failed")

@@ -1,6 +1,6 @@
 from collections.abc import Generator
 from dataclasses import dataclass, field
-from typing import TypeVar, Union, cast
+from typing import TypeVar, Union

 from core.agent.entities import AgentInvokeMessage
 from core.tools.entities.tool_entities import ToolInvokeMessage

@@ -85,7 +85,7 @@ def merge_blob_chunks(
                     message=ToolInvokeMessage.BlobMessage(blob=files[chunk_id].data[: files[chunk_id].bytes_written]),
                     meta=resp.meta,
                 )
-                yield cast(MessageType, merged_message)
+                yield merged_message
                 # Clean up the buffer
                 del files[chunk_id]
             else:

@@ -87,7 +87,6 @@ class PromptMessageUtil:
                 if isinstance(prompt_message.content, list):
                     for content in prompt_message.content:
                         if content.type == PromptMessageContentType.TEXT:
-                            content = cast(TextPromptMessageContent, content)
                             text += content.data
                         else:
                             content = cast(ImagePromptMessageContent, content)

@ -1,6 +1,7 @@
|
|||
import contextlib
|
||||
import json
|
||||
from collections import defaultdict
|
||||
from collections.abc import Sequence
|
||||
from json import JSONDecodeError
|
||||
from typing import Any, Optional, cast
|
||||
|
||||
|
|
@ -22,6 +23,7 @@ from core.entities.provider_entities import (
|
|||
QuotaConfiguration,
|
||||
QuotaUnit,
|
||||
SystemConfiguration,
|
||||
UnaddedModelConfiguration,
|
||||
)
|
||||
from core.helper import encrypter
|
||||
from core.helper.model_provider_cache import ProviderCredentialsCache, ProviderCredentialsCacheType
|
||||
|
|
@ -154,8 +156,8 @@ class ProviderManager:
|
|||
for provider_entity in provider_entities:
|
||||
# handle include, exclude
|
||||
if is_filtered(
|
||||
include_set=cast(set[str], dify_config.POSITION_PROVIDER_INCLUDES_SET),
|
||||
exclude_set=cast(set[str], dify_config.POSITION_PROVIDER_EXCLUDES_SET),
|
||||
include_set=dify_config.POSITION_PROVIDER_INCLUDES_SET,
|
||||
exclude_set=dify_config.POSITION_PROVIDER_EXCLUDES_SET,
|
||||
data=provider_entity,
|
||||
name_func=lambda x: x.provider,
|
||||
):
|
||||
|
|
@ -276,15 +278,11 @@ class ProviderManager:
|
|||
:param model_type: model type
|
||||
:return:
|
||||
"""
|
||||
# Get the corresponding TenantDefaultModel record
|
||||
default_model = (
|
||||
db.session.query(TenantDefaultModel)
|
||||
.where(
|
||||
TenantDefaultModel.tenant_id == tenant_id,
|
||||
TenantDefaultModel.model_type == model_type.to_origin_model_type(),
|
||||
)
|
||||
.first()
|
||||
stmt = select(TenantDefaultModel).where(
|
||||
TenantDefaultModel.tenant_id == tenant_id,
|
||||
TenantDefaultModel.model_type == model_type.to_origin_model_type(),
|
||||
)
|
||||
default_model = db.session.scalar(stmt)
|
||||
|
||||
# If it does not exist, get the first available provider model from get_configurations
|
||||
# and update the TenantDefaultModel record
|
||||
|
|
@@ -367,16 +365,11 @@ class ProviderManager:
         model_names = [model.model for model in available_models]
         if model not in model_names:
             raise ValueError(f"Model {model} does not exist.")

         # Get the list of available models from get_configurations and check if it is LLM
-        default_model = (
-            db.session.query(TenantDefaultModel)
-            .where(
-                TenantDefaultModel.tenant_id == tenant_id,
-                TenantDefaultModel.model_type == model_type.to_origin_model_type(),
-            )
-            .first()
-        )
+        stmt = select(TenantDefaultModel).where(
+            TenantDefaultModel.tenant_id == tenant_id,
+            TenantDefaultModel.model_type == model_type.to_origin_model_type(),
+        )
+        default_model = db.session.scalar(stmt)

         # create or update TenantDefaultModel record
         if default_model:

@@ -546,6 +539,23 @@ class ProviderManager:
             for credential in available_credentials
         ]

+    @staticmethod
+    def get_credentials_from_provider_model(tenant_id: str, provider_name: str) -> Sequence[ProviderModelCredential]:
+        """
+        Get all the credentials records from ProviderModelCredential by provider_name
+
+        :param tenant_id: workspace id
+        :param provider_name: provider name
+
+        """
+        with Session(db.engine, expire_on_commit=False) as session:
+            stmt = select(ProviderModelCredential).where(
+                ProviderModelCredential.tenant_id == tenant_id, ProviderModelCredential.provider_name == provider_name
+            )
+
+            all_credentials = session.scalars(stmt).all()
+            return all_credentials
+
     @staticmethod
     def _init_trial_provider_records(
         tenant_id: str, provider_name_to_provider_records_dict: dict[str, list[Provider]]

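The `expire_on_commit=False` flag in the new helper presumably keeps the returned ORM rows usable after the session closes, which suits a read-only accessor. A minimal sketch of that session pattern (placeholder engine and query):

from sqlalchemy import create_engine, text
from sqlalchemy.orm import Session

engine = create_engine("sqlite:///:memory:")  # placeholder engine

# With expire_on_commit=False, attribute values loaded inside the block
# are not expired when the session ends, so callers can keep the rows.
with Session(engine, expire_on_commit=False) as session:
    value = session.scalar(text("SELECT 1"))
assert value == 1
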
@@ -598,16 +608,13 @@ class ProviderManager:
                     provider_name_to_provider_records_dict[provider_name].append(new_provider_record)
                 except IntegrityError:
                     db.session.rollback()
-                    existed_provider_record = (
-                        db.session.query(Provider)
-                        .where(
-                            Provider.tenant_id == tenant_id,
-                            Provider.provider_name == ModelProviderID(provider_name).provider_name,
-                            Provider.provider_type == ProviderType.SYSTEM.value,
-                            Provider.quota_type == ProviderQuotaType.TRIAL.value,
-                        )
-                        .first()
-                    )
+                    stmt = select(Provider).where(
+                        Provider.tenant_id == tenant_id,
+                        Provider.provider_name == ModelProviderID(provider_name).provider_name,
+                        Provider.provider_type == ProviderType.SYSTEM.value,
+                        Provider.quota_type == ProviderQuotaType.TRIAL.value,
+                    )
+                    existed_provider_record = db.session.scalar(stmt)
                     if not existed_provider_record:
                         continue

@@ -635,6 +642,44 @@ class ProviderManager:
         :param provider_model_records: provider model records
         :return:
         """
+        # Get custom provider configuration
+        custom_provider_configuration = self._get_custom_provider_configuration(
+            tenant_id, provider_entity, provider_records
+        )
+
+        # Get all model credentials once
+        all_model_credentials = self.get_credentials_from_provider_model(tenant_id, provider_entity.provider)
+
+        # Get custom models which have not been added to the model list yet
+        unadded_models = self._get_can_added_models(provider_model_records, all_model_credentials)
+
+        # Get custom model configurations
+        custom_model_configurations = self._get_custom_model_configurations(
+            tenant_id, provider_entity, provider_model_records, unadded_models, all_model_credentials
+        )
+
+        can_added_models = [
+            UnaddedModelConfiguration(model=model["model"], model_type=model["model_type"]) for model in unadded_models
+        ]
+
+        return CustomConfiguration(
+            provider=custom_provider_configuration,
+            models=custom_model_configurations,
+            can_added_models=can_added_models,
+        )
+
+    def _get_custom_provider_configuration(
+        self, tenant_id: str, provider_entity: ProviderEntity, provider_records: list[Provider]
+    ) -> CustomProviderConfiguration | None:
+        """Get custom provider configuration."""
+        # Find custom provider record (non-system)
+        custom_provider_record = next(
+            (record for record in provider_records if record.provider_type != ProviderType.SYSTEM.value), None
+        )
+
+        if not custom_provider_record:
+            return None
+
+        # Get provider credential secret variables
         provider_credential_secret_variables = self._extract_secret_variables(
             provider_entity.provider_credential_schema.credential_form_schemas

@@ -642,113 +687,98 @@ class ProviderManager:
             else []
         )

-        # Get custom provider record
-        custom_provider_record = None
-        for provider_record in provider_records:
-            if provider_record.provider_type == ProviderType.SYSTEM.value:
-                continue
+        # Get and decrypt provider credentials
+        provider_credentials = self._get_and_decrypt_credentials(
+            tenant_id=tenant_id,
+            record_id=custom_provider_record.id,
+            encrypted_config=custom_provider_record.encrypted_config,
+            secret_variables=provider_credential_secret_variables,
+            cache_type=ProviderCredentialsCacheType.PROVIDER,
+            is_provider=True,
+        )

-            custom_provider_record = provider_record
+        return CustomProviderConfiguration(
+            credentials=provider_credentials,
+            current_credential_name=custom_provider_record.credential_name,
+            current_credential_id=custom_provider_record.credential_id,
+            available_credentials=self.get_provider_available_credentials(
+                tenant_id, custom_provider_record.provider_name
+            ),
+        )

-        # Get custom provider credentials
-        custom_provider_configuration = None
-        if custom_provider_record:
-            provider_credentials_cache = ProviderCredentialsCache(
-                tenant_id=tenant_id,
-                identity_id=custom_provider_record.id,
-                cache_type=ProviderCredentialsCacheType.PROVIDER,
-            )
+    def _get_can_added_models(
+        self, provider_model_records: list[ProviderModel], all_model_credentials: Sequence[ProviderModelCredential]
+    ) -> list[dict]:
+        """Get the custom models and credentials from enterprise version which haven't add to the model list"""
+        existing_model_set = {(record.model_name, record.model_type) for record in provider_model_records}

-            # Get cached provider credentials
-            cached_provider_credentials = provider_credentials_cache.get()
+        # Get not added custom models credentials
+        not_added_custom_models_credentials = [
+            credential
+            for credential in all_model_credentials
+            if (credential.model_name, credential.model_type) not in existing_model_set
+        ]

-            if not cached_provider_credentials:
-                try:
-                    # fix origin data
-                    if custom_provider_record.encrypted_config is None:
-                        provider_credentials = {}
-                    elif not custom_provider_record.encrypted_config.startswith("{"):
-                        provider_credentials = {"openai_api_key": custom_provider_record.encrypted_config}
-                    else:
-                        provider_credentials = json.loads(custom_provider_record.encrypted_config)
-                except JSONDecodeError:
-                    provider_credentials = {}
+        # Group credentials by model
+        model_to_credentials = defaultdict(list)
+        for credential in not_added_custom_models_credentials:
+            model_to_credentials[(credential.model_name, credential.model_type)].append(credential)

-                # Get decoding rsa key and cipher for decrypting credentials
-                if self.decoding_rsa_key is None or self.decoding_cipher_rsa is None:
-                    self.decoding_rsa_key, self.decoding_cipher_rsa = encrypter.get_decrypt_decoding(tenant_id)
+        return [
+            {
+                "model": model_key[0],
+                "model_type": ModelType.value_of(model_key[1]),
+                "available_model_credentials": [
+                    CredentialConfiguration(credential_id=cred.id, credential_name=cred.credential_name)
+                    for cred in creds
+                ],
+            }
+            for model_key, creds in model_to_credentials.items()
+        ]

-                for variable in provider_credential_secret_variables:
-                    if variable in provider_credentials:
-                        with contextlib.suppress(ValueError):
-                            provider_credentials[variable] = encrypter.decrypt_token_with_decoding(
-                                provider_credentials.get(variable) or "",  # type: ignore
-                                self.decoding_rsa_key,
-                                self.decoding_cipher_rsa,
-                            )
-
-                # cache provider credentials
-                provider_credentials_cache.set(credentials=provider_credentials)
-            else:
-                provider_credentials = cached_provider_credentials
-
-            custom_provider_configuration = CustomProviderConfiguration(
-                credentials=provider_credentials,
-                current_credential_name=custom_provider_record.credential_name,
-                current_credential_id=custom_provider_record.credential_id,
-                available_credentials=self.get_provider_available_credentials(
-                    tenant_id, custom_provider_record.provider_name
-                ),
-            )
-
-        # Get provider model credential secret variables
+    def _get_custom_model_configurations(
+        self,
+        tenant_id: str,
+        provider_entity: ProviderEntity,
+        provider_model_records: list[ProviderModel],
+        can_added_models: list[dict],
+        all_model_credentials: Sequence[ProviderModelCredential],
+    ) -> list[CustomModelConfiguration]:
+        """Get custom model configurations."""
+        # Get model credential secret variables
         model_credential_secret_variables = self._extract_secret_variables(
             provider_entity.model_credential_schema.credential_form_schemas
             if provider_entity.model_credential_schema
             else []
         )

-        # Get custom provider model credentials
+        # Create credentials lookup for efficient access
+        credentials_map = defaultdict(list)
+        for credential in all_model_credentials:
+            credentials_map[(credential.model_name, credential.model_type)].append(credential)
+
         custom_model_configurations = []

+        # Process existing model records
         for provider_model_record in provider_model_records:
-            available_model_credentials = self.get_provider_model_available_credentials(
-                tenant_id,
-                provider_model_record.provider_name,
-                provider_model_record.model_name,
-                provider_model_record.model_type,
-            )
+            # Use pre-fetched credentials instead of individual database calls
+            available_model_credentials = [
+                CredentialConfiguration(credential_id=cred.id, credential_name=cred.credential_name)
+                for cred in credentials_map.get(
+                    (provider_model_record.model_name, provider_model_record.model_type), []
+                )
+            ]

-            provider_model_credentials_cache = ProviderCredentialsCache(
-                tenant_id=tenant_id, identity_id=provider_model_record.id, cache_type=ProviderCredentialsCacheType.MODEL
-            )
-
-            # Get cached provider model credentials
-            cached_provider_model_credentials = provider_model_credentials_cache.get()
-
-            if not cached_provider_model_credentials and provider_model_record.encrypted_config:
-                try:
-                    provider_model_credentials = json.loads(provider_model_record.encrypted_config)
-                except JSONDecodeError:
-                    continue
-
-                # Get decoding rsa key and cipher for decrypting credentials
-                if self.decoding_rsa_key is None or self.decoding_cipher_rsa is None:
-                    self.decoding_rsa_key, self.decoding_cipher_rsa = encrypter.get_decrypt_decoding(tenant_id)
-
-                for variable in model_credential_secret_variables:
-                    if variable in provider_model_credentials:
-                        with contextlib.suppress(ValueError):
-                            provider_model_credentials[variable] = encrypter.decrypt_token_with_decoding(
-                                provider_model_credentials.get(variable),
-                                self.decoding_rsa_key,
-                                self.decoding_cipher_rsa,
-                            )
-
-                # cache provider model credentials
-                provider_model_credentials_cache.set(credentials=provider_model_credentials)
-            else:
-                provider_model_credentials = cached_provider_model_credentials
+            # Get and decrypt model credentials
+            provider_model_credentials = self._get_and_decrypt_credentials(
+                tenant_id=tenant_id,
+                record_id=provider_model_record.id,
+                encrypted_config=provider_model_record.encrypted_config,
+                secret_variables=model_credential_secret_variables,
+                cache_type=ProviderCredentialsCacheType.MODEL,
+                is_provider=False,
+            )

             custom_model_configurations.append(
                 CustomModelConfiguration(
                     model=provider_model_record.model_name,

@@ -760,7 +790,71 @@ class ProviderManager:
                 )
             )

-        return CustomConfiguration(provider=custom_provider_configuration, models=custom_model_configurations)
+        # Add models that can be added
+        for model in can_added_models:
+            custom_model_configurations.append(
+                CustomModelConfiguration(
+                    model=model["model"],
+                    model_type=model["model_type"],
+                    credentials=None,
+                    current_credential_id=None,
+                    current_credential_name=None,
+                    available_model_credentials=model["available_model_credentials"],
+                    unadded_to_model_list=True,
+                )
+            )
+
+        return custom_model_configurations
+
+    def _get_and_decrypt_credentials(
+        self,
+        tenant_id: str,
+        record_id: str,
+        encrypted_config: str | None,
+        secret_variables: list[str],
+        cache_type: ProviderCredentialsCacheType,
+        is_provider: bool = False,
+    ) -> dict:
+        """Get and decrypt credentials with caching."""
+        credentials_cache = ProviderCredentialsCache(
+            tenant_id=tenant_id,
+            identity_id=record_id,
+            cache_type=cache_type,
+        )
+
+        # Try to get from cache first
+        cached_credentials = credentials_cache.get()
+        if cached_credentials:
+            return cached_credentials
+
+        # Parse encrypted config
+        if not encrypted_config:
+            return {}
+
+        if is_provider and not encrypted_config.startswith("{"):
+            return {"openai_api_key": encrypted_config}
+
+        try:
+            credentials = cast(dict, json.loads(encrypted_config))
+        except JSONDecodeError:
+            return {}
+
+        # Decrypt secret variables
+        if self.decoding_rsa_key is None or self.decoding_cipher_rsa is None:
+            self.decoding_rsa_key, self.decoding_cipher_rsa = encrypter.get_decrypt_decoding(tenant_id)
+
+        for variable in secret_variables:
+            if variable in credentials:
+                with contextlib.suppress(ValueError):
+                    credentials[variable] = encrypter.decrypt_token_with_decoding(
+                        credentials.get(variable) or "",
+                        self.decoding_rsa_key,
+                        self.decoding_cipher_rsa,
+                    )
+
+        # Cache the decrypted credentials
+        credentials_cache.set(credentials=credentials)
+        return credentials

     def _to_system_configuration(
         self, tenant_id: str, provider_entity: ProviderEntity, provider_records: list[Provider]

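The new `_get_and_decrypt_credentials` helper follows a cache-aside shape: serve from cache, otherwise compute and populate. A reduced sketch with a dict-backed cache (the real code uses `ProviderCredentialsCache` and actual decryption, both stubbed here):

_cache: dict[str, dict] = {}

def get_credentials(record_id: str) -> dict:
    if (hit := _cache.get(record_id)) is not None:
        return hit                       # 1. serve from cache when possible
    creds = {"api_key": "decrypted"}     # 2. otherwise parse and decrypt (stubbed)
    _cache[record_id] = creds            # 3. populate the cache for next time
    return creds
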
@@ -968,18 +1062,6 @@ class ProviderManager:
                         load_balancing_model_config.model_name == provider_model_setting.model_name
                         and load_balancing_model_config.model_type == provider_model_setting.model_type
                     ):
-                        if load_balancing_model_config.name == "__delete__":
-                            # to calculate current model whether has invalidate lb configs
-                            load_balancing_configs.append(
-                                ModelLoadBalancingConfiguration(
-                                    id=load_balancing_model_config.id,
-                                    name=load_balancing_model_config.name,
-                                    credentials={},
-                                    credential_source_type=load_balancing_model_config.credential_source_type,
-                                )
-                            )
-                            continue
-
                         if not load_balancing_model_config.enabled:
                             continue

@@ -1045,6 +1127,7 @@ class ProviderManager:
                         model=provider_model_setting.model_name,
                         model_type=ModelType.value_of(provider_model_setting.model_type),
                         enabled=provider_model_setting.enabled,
+                        load_balancing_enabled=provider_model_setting.load_balancing_enabled,
                         load_balancing_configs=load_balancing_configs if len(load_balancing_configs) > 1 else [],
                     )
                 )

@@ -3,6 +3,7 @@ from typing import Any, Optional

 import orjson
 from pydantic import BaseModel
+from sqlalchemy import select

 from configs import dify_config
 from core.rag.datasource.keyword.jieba.jieba_keyword_table_handler import JiebaKeywordTableHandler

@@ -212,11 +213,10 @@ class Jieba(BaseKeyword):
         return sorted_chunk_indices[:k]

     def _update_segment_keywords(self, dataset_id: str, node_id: str, keywords: list[str]):
-        document_segment = (
-            db.session.query(DocumentSegment)
-            .where(DocumentSegment.dataset_id == dataset_id, DocumentSegment.index_node_id == node_id)
-            .first()
-        )
+        stmt = select(DocumentSegment).where(
+            DocumentSegment.dataset_id == dataset_id, DocumentSegment.index_node_id == node_id
+        )
+        document_segment = db.session.scalar(stmt)
         if document_segment:
             document_segment.keywords = keywords
             db.session.add(document_segment)

@@ -3,6 +3,7 @@ from concurrent.futures import ThreadPoolExecutor
 from typing import Optional

 from flask import Flask, current_app
+from sqlalchemy import select
 from sqlalchemy.orm import Session, load_only

 from configs import dify_config

@@ -24,7 +25,7 @@ default_retrieval_model = {
     "search_method": RetrievalMethod.SEMANTIC_SEARCH.value,
     "reranking_enable": False,
     "reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""},
-    "top_k": 2,
+    "top_k": 4,
     "score_threshold_enabled": False,
 }

@@ -127,7 +128,8 @@ class RetrievalService:
         external_retrieval_model: Optional[dict] = None,
         metadata_filtering_conditions: Optional[dict] = None,
     ):
-        dataset = db.session.query(Dataset).where(Dataset.id == dataset_id).first()
+        stmt = select(Dataset).where(Dataset.id == dataset_id)
+        dataset = db.session.scalar(stmt)
         if not dataset:
             return []
         metadata_condition = (

@@ -316,10 +318,8 @@ class RetrievalService:
                 if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX:
                     # Handle parent-child documents
                     child_index_node_id = document.metadata.get("doc_id")
-
-                    child_chunk = (
-                        db.session.query(ChildChunk).where(ChildChunk.index_node_id == child_index_node_id).first()
-                    )
+                    child_chunk_stmt = select(ChildChunk).where(ChildChunk.index_node_id == child_index_node_id)
+                    child_chunk = db.session.scalar(child_chunk_stmt)

                     if not child_chunk:
                         continue

@@ -378,17 +378,13 @@ class RetrievalService:
                 index_node_id = document.metadata.get("doc_id")
                 if not index_node_id:
                     continue
-
-                segment = (
-                    db.session.query(DocumentSegment)
-                    .where(
-                        DocumentSegment.dataset_id == dataset_document.dataset_id,
-                        DocumentSegment.enabled == True,
-                        DocumentSegment.status == "completed",
-                        DocumentSegment.index_node_id == index_node_id,
-                    )
-                    .first()
-                )
+                document_segment_stmt = select(DocumentSegment).where(
+                    DocumentSegment.dataset_id == dataset_document.dataset_id,
+                    DocumentSegment.enabled == True,
+                    DocumentSegment.status == "completed",
+                    DocumentSegment.index_node_id == index_node_id,
+                )
+                segment = db.session.scalar(document_segment_stmt)

                 if not segment:
                     continue

@@ -192,8 +192,8 @@ class AnalyticdbVectorOpenAPI:
                 collection=self._collection_name,
                 metrics=self.config.metrics,
                 include_values=True,
-                vector=None,
-                content=None,
+                vector=None,  # ty: ignore [invalid-argument-type]
+                content=None,  # ty: ignore [invalid-argument-type]
                 top_k=1,
                 filter=f"ref_doc_id='{id}'",
             )

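The `# ty: ignore [rule]` comments added in these hunks are rule-scoped suppressions for the `ty` type checker, analogous in spirit to mypy's `# type: ignore[code]`; the calls themselves are unchanged. A hypothetical illustration (the rule name is just an example):

def expects_int(x: int) -> int:
    return x

# Silences one diagnostic on this line while leaving runtime behavior as-is.
value = expects_int("oops")  # ty: ignore [invalid-argument-type]
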
@@ -211,7 +211,7 @@ class AnalyticdbVectorOpenAPI:
             namespace=self.config.namespace,
             namespace_password=self.config.namespace_password,
             collection=self._collection_name,
-            collection_data=None,
+            collection_data=None,  # ty: ignore [invalid-argument-type]
             collection_data_filter=f"ref_doc_id IN {ids_str}",
         )
         self._client.delete_collection_data(request)

@@ -225,7 +225,7 @@ class AnalyticdbVectorOpenAPI:
             namespace=self.config.namespace,
             namespace_password=self.config.namespace_password,
             collection=self._collection_name,
-            collection_data=None,
+            collection_data=None,  # ty: ignore [invalid-argument-type]
             collection_data_filter=f"metadata_ ->> '{key}' = '{value}'",
         )
         self._client.delete_collection_data(request)

@@ -249,14 +249,14 @@ class AnalyticdbVectorOpenAPI:
                 include_values=kwargs.pop("include_values", True),
                 metrics=self.config.metrics,
                 vector=query_vector,
-                content=None,
+                content=None,  # ty: ignore [invalid-argument-type]
                 top_k=kwargs.get("top_k", 4),
                 filter=where_clause,
             )
             response = self._client.query_collection_data(request)
             documents = []
             for match in response.body.matches.match:
-                if match.score > score_threshold:
+                if match.score >= score_threshold:
                     metadata = json.loads(match.metadata.get("metadata_"))
                     metadata["score"] = match.score
                     doc = Document(

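The `>` to `>=` change recurs across the vector stores below; the point is boundary inclusion, since a result scoring exactly at the threshold was previously dropped and is now kept:

score, score_threshold = 0.8, 0.8
# Exact-threshold matches fail the strict comparison but pass the inclusive one.
assert not (score > score_threshold)
assert score >= score_threshold
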
@@ -285,7 +285,7 @@ class AnalyticdbVectorOpenAPI:
                 collection=self._collection_name,
                 include_values=kwargs.pop("include_values", True),
                 metrics=self.config.metrics,
-                vector=None,
+                vector=None,  # ty: ignore [invalid-argument-type]
                 content=query,
                 top_k=kwargs.get("top_k", 4),
                 filter=where_clause,

@@ -293,7 +293,7 @@ class AnalyticdbVectorOpenAPI:
             response = self._client.query_collection_data(request)
             documents = []
             for match in response.body.matches.match:
-                if match.score > score_threshold:
+                if match.score >= score_threshold:
                     metadata = json.loads(match.metadata.get("metadata_"))
                     metadata["score"] = match.score
                     doc = Document(

@@ -228,8 +228,8 @@ class AnalyticdbVectorBySql:
             )
             documents = []
             for record in cur:
-                id, vector, score, page_content, metadata = record
-                if score > score_threshold:
+                _, vector, score, page_content, metadata = record
+                if score >= score_threshold:
                     metadata["score"] = score
                     doc = Document(
                         page_content=page_content,

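Renaming the unused unpacked field to `_` also stops shadowing the built-in `id()` inside the loop, which the old code silently did. An illustrative record (real rows come from a database cursor):

record = (42, [0.1, 0.2], 0.93, "chunk text", {"source": "doc-1"})

# "_" marks the field we never read and leaves the built-in id() untouched.
_, vector, score, page_content, metadata = record
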
@@ -260,7 +260,7 @@ class AnalyticdbVectorBySql:
             )
             documents = []
             for record in cur:
-                id, vector, page_content, metadata, score = record
+                _, vector, page_content, metadata, score = record
                 metadata["score"] = score
                 doc = Document(
                     page_content=page_content,

@@ -157,7 +157,7 @@ class BaiduVector(BaseVector):
                 if meta is not None:
                     meta = json.loads(meta)
                 score = row.get("score", 0.0)
-                if score > score_threshold:
+                if score >= score_threshold:
                     meta["score"] = score
                     doc = Document(page_content=row_data.get(self.field_text), metadata=meta)
                     docs.append(doc)

@@ -120,7 +120,7 @@ class ChromaVector(BaseVector):
                 distance = distances[index]
                 metadata = dict(metadatas[index])
                 score = 1 - distance
-                if score > score_threshold:
+                if score >= score_threshold:
                     metadata["score"] = score
                     doc = Document(
                         page_content=documents[index],

@@ -12,7 +12,7 @@ import clickzetta  # type: ignore
 from pydantic import BaseModel, model_validator

 if TYPE_CHECKING:
-    from clickzetta import Connection
+    from clickzetta.connector.v0.connection import Connection  # type: ignore

 from configs import dify_config
 from core.rag.datasource.vdb.field import Field

@@ -701,7 +701,7 @@ class ClickzettaVector(BaseVector):
                     len(data_rows),
                     vector_dimension,
                 )
-            except (RuntimeError, ValueError, TypeError, ConnectionError) as e:
+            except (RuntimeError, ValueError, TypeError, ConnectionError):
                 logger.exception("Parameterized SQL execution failed for %d documents", len(data_rows))
                 logger.exception("SQL template: %s", insert_sql)
                 logger.exception("Sample data row: %s", data_rows[0] if data_rows else "None")

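The repeated removal of `as e` in these handlers works because `logger.exception` already records the active traceback, so an unused exception binding adds nothing:

import logging

logger = logging.getLogger(__name__)

try:
    raise RuntimeError("boom")  # stand-in for the failing call
except RuntimeError:
    # Logs at ERROR level and appends the current traceback automatically,
    # so "as e" is only needed when the handler actually reads e.
    logger.exception("operation failed")
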
@@ -787,7 +787,7 @@ class ClickzettaVector(BaseVector):
         document_ids_filter = kwargs.get("document_ids_filter")

         # Handle filter parameter from canvas (workflow)
-        filter_param = kwargs.get("filter", {})
+        _ = kwargs.get("filter", {})

         # Build filter clause
         filter_clauses = []

@@ -879,7 +879,7 @@ class ClickzettaVector(BaseVector):
         document_ids_filter = kwargs.get("document_ids_filter")

         # Handle filter parameter from canvas (workflow)
-        filter_param = kwargs.get("filter", {})
+        _ = kwargs.get("filter", {})

         # Build filter clause
         filter_clauses = []

@@ -938,7 +938,7 @@ class ClickzettaVector(BaseVector):
                         metadata = {}
                 else:
                     metadata = {}
-            except (json.JSONDecodeError, TypeError) as e:
+            except (json.JSONDecodeError, TypeError):
                 logger.exception("JSON parsing failed")
                 # Fallback: extract document_id with regex

@@ -956,7 +956,7 @@ class ClickzettaVector(BaseVector):
                     metadata["score"] = 1.0  # Clickzetta doesn't provide relevance scores
                     doc = Document(page_content=row[1], metadata=metadata)
                     documents.append(doc)
-        except (RuntimeError, ValueError, TypeError, ConnectionError) as e:
+        except (RuntimeError, ValueError, TypeError, ConnectionError):
             logger.exception("Full-text search failed")
             # Fallback to LIKE search if full-text search fails
             return self._search_by_like(query, **kwargs)

@@ -978,7 +978,7 @@ class ClickzettaVector(BaseVector):
         document_ids_filter = kwargs.get("document_ids_filter")

         # Handle filter parameter from canvas (workflow)
-        filter_param = kwargs.get("filter", {})
+        _ = kwargs.get("filter", {})

         # Build filter clause
         filter_clauses = []

@@ -212,10 +212,10 @@ class CouchbaseVector(BaseVector):

         documents_to_insert = [
             {"text": text, "embedding": vector, "metadata": metadata}
-            for id, text, vector, metadata in zip(uuids, texts, embeddings, metadatas)
+            for _, text, vector, metadata in zip(uuids, texts, embeddings, metadatas)
         ]
         for doc, id in zip(documents_to_insert, uuids):
-            result = self._scope.collection(self._collection_name).upsert(id, doc)
+            _ = self._scope.collection(self._collection_name).upsert(id, doc)

         doc_ids.extend(uuids)

@@ -241,7 +241,7 @@ class CouchbaseVector(BaseVector):
         """
         try:
             self._cluster.query(query, named_parameters={"doc_ids": ids}).execute()
-        except Exception as e:
+        except Exception:
             logger.exception("Failed to delete documents, ids: %s", ids)

     def delete_by_document_id(self, document_id: str):

@@ -304,9 +304,9 @@ class CouchbaseVector(BaseVector):
         return docs

     def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
-        top_k = kwargs.get("top_k", 2)
+        top_k = kwargs.get("top_k", 4)
         try:
-            CBrequest = search.SearchRequest.create(search.QueryStringQuery("text:" + query))
+            CBrequest = search.SearchRequest.create(search.QueryStringQuery("text:" + query))  # ty: ignore [too-many-positional-arguments]
             search_iter = self._scope.search(
                 self._collection_name + "_search", CBrequest, SearchOptions(limit=top_k, fields=["*"])
             )

@@ -138,7 +138,7 @@ class ElasticSearchVector(BaseVector):
             if not client.ping():
                 raise ConnectionError("Failed to connect to Elasticsearch")

-        except requests.exceptions.ConnectionError as e:
+        except requests.ConnectionError as e:
             raise ConnectionError(f"Vector database connection error: {str(e)}")
         except Exception as e:
             raise ConnectionError(f"Elasticsearch client initialization failed: {str(e)}")

@@ -216,7 +216,7 @@ class ElasticSearchVector(BaseVector):
         docs = []
         for doc, score in docs_and_scores:
             score_threshold = float(kwargs.get("score_threshold") or 0.0)
-            if score > score_threshold:
+            if score >= score_threshold:
                 if doc.metadata is not None:
                     doc.metadata["score"] = score
                     docs.append(doc)

@@ -127,7 +127,7 @@ class HuaweiCloudVector(BaseVector):
         docs = []
         for doc, score in docs_and_scores:
             score_threshold = float(kwargs.get("score_threshold") or 0.0)
-            if score > score_threshold:
+            if score >= score_threshold:
                 if doc.metadata is not None:
                     doc.metadata["score"] = score
                     docs.append(doc)

@@ -275,7 +275,7 @@ class LindormVectorStore(BaseVector):
         docs = []
         for doc, score in docs_and_scores:
             score_threshold = kwargs.get("score_threshold", 0.0) or 0.0
-            if score > score_threshold:
+            if score >= score_threshold:
                 if doc.metadata is not None:
                     doc.metadata["score"] = score
                     docs.append(doc)

@@ -99,7 +99,7 @@ class MatrixoneVector(BaseVector):
             return client
         try:
             client.create_full_text_index()
-        except Exception as e:
+        except Exception:
             logger.exception("Failed to create full text index")
         redis_client.set(collection_exist_cache_key, 1, ex=3600)
         return client

@@ -376,7 +376,12 @@ class MilvusVector(BaseVector):
         if config.token:
             client = MilvusClient(uri=config.uri, token=config.token, db_name=config.database)
         else:
-            client = MilvusClient(uri=config.uri, user=config.user, password=config.password, db_name=config.database)
+            client = MilvusClient(
+                uri=config.uri,
+                user=config.user or "",
+                password=config.password or "",
+                db_name=config.database,
+            )
         return client

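The `or ""` coalescing above appears to exist to satisfy str-typed parameters when the config fields are `Optional[str]`. A sketch with an illustrative signature (not the Milvus API):

from typing import Optional

def make_client(user: str, password: str) -> tuple[str, str]:
    # Illustrative stand-in for a client whose parameters require str.
    return user, password

config_user: Optional[str] = None
config_password: Optional[str] = None

# "or ''" collapses None (and empty strings) to "", keeping the
# "no credentials" behavior while satisfying the type checker.
client = make_client(config_user or "", config_password or "")
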
@@ -194,7 +194,7 @@ class OpenGauss(BaseVector):
                 metadata, text, distance = record
                 score = 1 - distance
                 metadata["score"] = score
-                if score > score_threshold:
+                if score >= score_threshold:
                     docs.append(Document(page_content=text, metadata=metadata))
         return docs

@@ -197,7 +197,7 @@ class OpenSearchVector(BaseVector):

         try:
             response = self._client.search(index=self._collection_name.lower(), body=query)
-        except Exception as e:
+        except Exception:
             logger.exception("Error executing vector search, query: %s", query)
             raise

@@ -211,7 +211,7 @@ class OpenSearchVector(BaseVector):

             metadata["score"] = hit["_score"]
             score_threshold = float(kwargs.get("score_threshold") or 0.0)
-            if hit["_score"] > score_threshold:
+            if hit["_score"] >= score_threshold:
                 doc = Document(page_content=hit["_source"].get(Field.CONTENT_KEY.value), metadata=metadata)
                 docs.append(doc)

@@ -261,7 +261,7 @@ class OracleVector(BaseVector):
                 metadata, text, distance = record
                 score = 1 - distance
                 metadata["score"] = score
-                if score > score_threshold:
+                if score >= score_threshold:
                     docs.append(Document(page_content=text, metadata=metadata))
         conn.close()
         return docs

Some files were not shown because too many files have changed in this diff.