mirror of
https://github.com/langgenius/dify.git
synced 2026-04-27 02:36:29 +08:00
refactor: migrate session.query to select API in console controllers (#34607)
This commit is contained in:
parent
ac8bd12609
commit
396b39dff9
@ -66,13 +66,13 @@ class WebhookTriggerApi(Resource):
|
|||||||
|
|
||||||
with sessionmaker(db.engine).begin() as session:
|
with sessionmaker(db.engine).begin() as session:
|
||||||
# Get webhook trigger for this app and node
|
# Get webhook trigger for this app and node
|
||||||
webhook_trigger = (
|
webhook_trigger = session.scalar(
|
||||||
session.query(WorkflowWebhookTrigger)
|
select(WorkflowWebhookTrigger)
|
||||||
.where(
|
.where(
|
||||||
WorkflowWebhookTrigger.app_id == app_model.id,
|
WorkflowWebhookTrigger.app_id == app_model.id,
|
||||||
WorkflowWebhookTrigger.node_id == node_id,
|
WorkflowWebhookTrigger.node_id == node_id,
|
||||||
)
|
)
|
||||||
.first()
|
.limit(1)
|
||||||
)
|
)
|
||||||
|
|
||||||
if not webhook_trigger:
|
if not webhook_trigger:
|
||||||
|
|||||||
@ -3,6 +3,7 @@ import logging
|
|||||||
from flask import request
|
from flask import request
|
||||||
from flask_restx import Resource
|
from flask_restx import Resource
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
from sqlalchemy import select
|
||||||
from sqlalchemy.orm import sessionmaker
|
from sqlalchemy.orm import sessionmaker
|
||||||
|
|
||||||
from controllers.common.schema import register_schema_models
|
from controllers.common.schema import register_schema_models
|
||||||
@ -86,8 +87,8 @@ class CustomizedPipelineTemplateApi(Resource):
|
|||||||
@enterprise_license_required
|
@enterprise_license_required
|
||||||
def post(self, template_id: str):
|
def post(self, template_id: str):
|
||||||
with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
|
with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
|
||||||
template = (
|
template = session.scalar(
|
||||||
session.query(PipelineCustomizedTemplate).where(PipelineCustomizedTemplate.id == template_id).first()
|
select(PipelineCustomizedTemplate).where(PipelineCustomizedTemplate.id == template_id).limit(1)
|
||||||
)
|
)
|
||||||
if not template:
|
if not template:
|
||||||
raise ValueError("Customized pipeline template not found.")
|
raise ValueError("Customized pipeline template not found.")
|
||||||
|
|||||||
@ -1,6 +1,7 @@
|
|||||||
from collections.abc import Callable
|
from collections.abc import Callable
|
||||||
from functools import wraps
|
from functools import wraps
|
||||||
|
|
||||||
|
from sqlalchemy import select
|
||||||
from sqlalchemy.orm import sessionmaker
|
from sqlalchemy.orm import sessionmaker
|
||||||
from werkzeug.exceptions import Forbidden
|
from werkzeug.exceptions import Forbidden
|
||||||
|
|
||||||
@ -21,12 +22,12 @@ def plugin_permission_required(
|
|||||||
tenant_id = current_tenant_id
|
tenant_id = current_tenant_id
|
||||||
|
|
||||||
with sessionmaker(db.engine).begin() as session:
|
with sessionmaker(db.engine).begin() as session:
|
||||||
permission = (
|
permission = session.scalar(
|
||||||
session.query(TenantPluginPermission)
|
select(TenantPluginPermission)
|
||||||
.where(
|
.where(
|
||||||
TenantPluginPermission.tenant_id == tenant_id,
|
TenantPluginPermission.tenant_id == tenant_id,
|
||||||
)
|
)
|
||||||
.first()
|
.limit(1)
|
||||||
)
|
)
|
||||||
|
|
||||||
if not permission:
|
if not permission:
|
||||||
|
|||||||
@ -4,6 +4,7 @@ from flask import Response
|
|||||||
from flask_restx import Resource
|
from flask_restx import Resource
|
||||||
from graphon.variables.input_entities import VariableEntity
|
from graphon.variables.input_entities import VariableEntity
|
||||||
from pydantic import BaseModel, Field, ValidationError
|
from pydantic import BaseModel, Field, ValidationError
|
||||||
|
from sqlalchemy import select
|
||||||
from sqlalchemy.orm import Session, sessionmaker
|
from sqlalchemy.orm import Session, sessionmaker
|
||||||
|
|
||||||
from controllers.common.schema import register_schema_model
|
from controllers.common.schema import register_schema_model
|
||||||
@ -80,11 +81,11 @@ class MCPAppApi(Resource):
|
|||||||
|
|
||||||
def _get_mcp_server_and_app(self, server_code: str, session: Session) -> tuple[AppMCPServer, App]:
|
def _get_mcp_server_and_app(self, server_code: str, session: Session) -> tuple[AppMCPServer, App]:
|
||||||
"""Get and validate MCP server and app in one query session"""
|
"""Get and validate MCP server and app in one query session"""
|
||||||
mcp_server = session.query(AppMCPServer).where(AppMCPServer.server_code == server_code).first()
|
mcp_server = session.scalar(select(AppMCPServer).where(AppMCPServer.server_code == server_code).limit(1))
|
||||||
if not mcp_server:
|
if not mcp_server:
|
||||||
raise MCPRequestError(mcp_types.INVALID_REQUEST, "Server Not Found")
|
raise MCPRequestError(mcp_types.INVALID_REQUEST, "Server Not Found")
|
||||||
|
|
||||||
app = session.query(App).where(App.id == mcp_server.app_id).first()
|
app = session.scalar(select(App).where(App.id == mcp_server.app_id).limit(1))
|
||||||
if not app:
|
if not app:
|
||||||
raise MCPRequestError(mcp_types.INVALID_REQUEST, "App Not Found")
|
raise MCPRequestError(mcp_types.INVALID_REQUEST, "App Not Found")
|
||||||
|
|
||||||
@ -190,12 +191,12 @@ class MCPAppApi(Resource):
|
|||||||
def _retrieve_end_user(self, tenant_id: str, mcp_server_id: str) -> EndUser | None:
|
def _retrieve_end_user(self, tenant_id: str, mcp_server_id: str) -> EndUser | None:
|
||||||
"""Get end user - manages its own database session"""
|
"""Get end user - manages its own database session"""
|
||||||
with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
|
with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
|
||||||
return (
|
return session.scalar(
|
||||||
session.query(EndUser)
|
select(EndUser)
|
||||||
.where(EndUser.tenant_id == tenant_id)
|
.where(EndUser.tenant_id == tenant_id)
|
||||||
.where(EndUser.session_id == mcp_server_id)
|
.where(EndUser.session_id == mcp_server_id)
|
||||||
.where(EndUser.type == "mcp")
|
.where(EndUser.type == "mcp")
|
||||||
.first()
|
.limit(1)
|
||||||
)
|
)
|
||||||
|
|
||||||
def _create_end_user(
|
def _create_end_user(
|
||||||
|
|||||||
@ -555,7 +555,7 @@ class TestWorkflowTriggerEndpoints:
|
|||||||
|
|
||||||
trigger = MagicMock()
|
trigger = MagicMock()
|
||||||
session = MagicMock()
|
session = MagicMock()
|
||||||
session.query.return_value.where.return_value.first.return_value = trigger
|
session.scalar.return_value = trigger
|
||||||
|
|
||||||
class DummySessionCtx:
|
class DummySessionCtx:
|
||||||
def __enter__(self):
|
def __enter__(self):
|
||||||
|
|||||||
@ -444,7 +444,7 @@ class TestMCPAppApi:
|
|||||||
)
|
)
|
||||||
|
|
||||||
session = MagicMock()
|
session = MagicMock()
|
||||||
session.query().where().first.side_effect = [server, app]
|
session.scalar.side_effect = [server, app]
|
||||||
|
|
||||||
result_server, result_app = api._get_mcp_server_and_app("server-1", session)
|
result_server, result_app = api._get_mcp_server_and_app("server-1", session)
|
||||||
|
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user