Merge branch 'main' into feat/memory-orchestration-be

This commit is contained in:
Stream 2025-09-23 17:43:52 +08:00
commit b7b5b0b8d0
No known key found for this signature in database
GPG Key ID: 033728094B100D70
59 changed files with 917 additions and 753 deletions

View File

@ -1,20 +1,11 @@
import logging
import psycogreen.gevent as pscycogreen_gevent # type: ignore
from grpc.experimental import gevent as grpc_gevent # type: ignore
_logger = logging.getLogger(__name__)
def _log(message: str):
_logger.debug(message)
# grpc gevent
grpc_gevent.init_gevent()
_log("gRPC patched with gevent.")
print("gRPC patched with gevent.", flush=True) # noqa: T201
pscycogreen_gevent.patch_psycopg()
_log("psycopg2 patched with gevent.")
print("psycopg2 patched with gevent.", flush=True) # noqa: T201
from app import app, celery

View File

@ -1448,41 +1448,52 @@ def transform_datasource_credentials():
notion_credentials_tenant_mapping[tenant_id] = []
notion_credentials_tenant_mapping[tenant_id].append(notion_credential)
for tenant_id, notion_tenant_credentials in notion_credentials_tenant_mapping.items():
# check notion plugin is installed
installed_plugins = installer_manager.list_plugins(tenant_id)
installed_plugins_ids = [plugin.plugin_id for plugin in installed_plugins]
if notion_plugin_id not in installed_plugins_ids:
if notion_plugin_unique_identifier:
# install notion plugin
PluginService.install_from_marketplace_pkg(tenant_id, [notion_plugin_unique_identifier])
auth_count = 0
for notion_tenant_credential in notion_tenant_credentials:
auth_count += 1
# get credential oauth params
access_token = notion_tenant_credential.access_token
# notion info
notion_info = notion_tenant_credential.source_info
workspace_id = notion_info.get("workspace_id")
workspace_name = notion_info.get("workspace_name")
workspace_icon = notion_info.get("workspace_icon")
new_credentials = {
"integration_secret": encrypter.encrypt_token(tenant_id, access_token),
"workspace_id": workspace_id,
"workspace_name": workspace_name,
"workspace_icon": workspace_icon,
}
datasource_provider = DatasourceProvider(
provider="notion_datasource",
tenant_id=tenant_id,
plugin_id=notion_plugin_id,
auth_type=oauth_credential_type.value,
encrypted_credentials=new_credentials,
name=f"Auth {auth_count}",
avatar_url=workspace_icon or "default",
is_default=False,
tenant = db.session.query(Tenant).filter_by(id=tenant_id).first()
if not tenant:
continue
try:
# check notion plugin is installed
installed_plugins = installer_manager.list_plugins(tenant_id)
installed_plugins_ids = [plugin.plugin_id for plugin in installed_plugins]
if notion_plugin_id not in installed_plugins_ids:
if notion_plugin_unique_identifier:
# install notion plugin
PluginService.install_from_marketplace_pkg(tenant_id, [notion_plugin_unique_identifier])
auth_count = 0
for notion_tenant_credential in notion_tenant_credentials:
auth_count += 1
# get credential oauth params
access_token = notion_tenant_credential.access_token
# notion info
notion_info = notion_tenant_credential.source_info
workspace_id = notion_info.get("workspace_id")
workspace_name = notion_info.get("workspace_name")
workspace_icon = notion_info.get("workspace_icon")
new_credentials = {
"integration_secret": encrypter.encrypt_token(tenant_id, access_token),
"workspace_id": workspace_id,
"workspace_name": workspace_name,
"workspace_icon": workspace_icon,
}
datasource_provider = DatasourceProvider(
provider="notion_datasource",
tenant_id=tenant_id,
plugin_id=notion_plugin_id,
auth_type=oauth_credential_type.value,
encrypted_credentials=new_credentials,
name=f"Auth {auth_count}",
avatar_url=workspace_icon or "default",
is_default=False,
)
db.session.add(datasource_provider)
deal_notion_count += 1
except Exception as e:
click.echo(
click.style(
f"Error transforming notion credentials: {str(e)}, tenant_id: {tenant_id}", fg="red"
)
)
db.session.add(datasource_provider)
deal_notion_count += 1
continue
db.session.commit()
# deal firecrawl credentials
deal_firecrawl_count = 0
@ -1495,37 +1506,48 @@ def transform_datasource_credentials():
firecrawl_credentials_tenant_mapping[tenant_id] = []
firecrawl_credentials_tenant_mapping[tenant_id].append(firecrawl_credential)
for tenant_id, firecrawl_tenant_credentials in firecrawl_credentials_tenant_mapping.items():
# check firecrawl plugin is installed
installed_plugins = installer_manager.list_plugins(tenant_id)
installed_plugins_ids = [plugin.plugin_id for plugin in installed_plugins]
if firecrawl_plugin_id not in installed_plugins_ids:
if firecrawl_plugin_unique_identifier:
# install firecrawl plugin
PluginService.install_from_marketplace_pkg(tenant_id, [firecrawl_plugin_unique_identifier])
tenant = db.session.query(Tenant).filter_by(id=tenant_id).first()
if not tenant:
continue
try:
# check firecrawl plugin is installed
installed_plugins = installer_manager.list_plugins(tenant_id)
installed_plugins_ids = [plugin.plugin_id for plugin in installed_plugins]
if firecrawl_plugin_id not in installed_plugins_ids:
if firecrawl_plugin_unique_identifier:
# install firecrawl plugin
PluginService.install_from_marketplace_pkg(tenant_id, [firecrawl_plugin_unique_identifier])
auth_count = 0
for firecrawl_tenant_credential in firecrawl_tenant_credentials:
auth_count += 1
# get credential api key
credentials_json = json.loads(firecrawl_tenant_credential.credentials)
api_key = credentials_json.get("config", {}).get("api_key")
base_url = credentials_json.get("config", {}).get("base_url")
new_credentials = {
"firecrawl_api_key": api_key,
"base_url": base_url,
}
datasource_provider = DatasourceProvider(
provider="firecrawl",
tenant_id=tenant_id,
plugin_id=firecrawl_plugin_id,
auth_type=api_key_credential_type.value,
encrypted_credentials=new_credentials,
name=f"Auth {auth_count}",
avatar_url="default",
is_default=False,
auth_count = 0
for firecrawl_tenant_credential in firecrawl_tenant_credentials:
auth_count += 1
# get credential api key
credentials_json = json.loads(firecrawl_tenant_credential.credentials)
api_key = credentials_json.get("config", {}).get("api_key")
base_url = credentials_json.get("config", {}).get("base_url")
new_credentials = {
"firecrawl_api_key": api_key,
"base_url": base_url,
}
datasource_provider = DatasourceProvider(
provider="firecrawl",
tenant_id=tenant_id,
plugin_id=firecrawl_plugin_id,
auth_type=api_key_credential_type.value,
encrypted_credentials=new_credentials,
name=f"Auth {auth_count}",
avatar_url="default",
is_default=False,
)
db.session.add(datasource_provider)
deal_firecrawl_count += 1
except Exception as e:
click.echo(
click.style(
f"Error transforming firecrawl credentials: {str(e)}, tenant_id: {tenant_id}", fg="red"
)
)
db.session.add(datasource_provider)
deal_firecrawl_count += 1
continue
db.session.commit()
# deal jina credentials
deal_jina_count = 0
@ -1538,36 +1560,45 @@ def transform_datasource_credentials():
jina_credentials_tenant_mapping[tenant_id] = []
jina_credentials_tenant_mapping[tenant_id].append(jina_credential)
for tenant_id, jina_tenant_credentials in jina_credentials_tenant_mapping.items():
# check jina plugin is installed
installed_plugins = installer_manager.list_plugins(tenant_id)
installed_plugins_ids = [plugin.plugin_id for plugin in installed_plugins]
if jina_plugin_id not in installed_plugins_ids:
if jina_plugin_unique_identifier:
# install jina plugin
logger.debug("Installing Jina plugin %s", jina_plugin_unique_identifier)
PluginService.install_from_marketplace_pkg(tenant_id, [jina_plugin_unique_identifier])
tenant = db.session.query(Tenant).filter_by(id=tenant_id).first()
if not tenant:
continue
try:
# check jina plugin is installed
installed_plugins = installer_manager.list_plugins(tenant_id)
installed_plugins_ids = [plugin.plugin_id for plugin in installed_plugins]
if jina_plugin_id not in installed_plugins_ids:
if jina_plugin_unique_identifier:
# install jina plugin
logger.debug("Installing Jina plugin %s", jina_plugin_unique_identifier)
PluginService.install_from_marketplace_pkg(tenant_id, [jina_plugin_unique_identifier])
auth_count = 0
for jina_tenant_credential in jina_tenant_credentials:
auth_count += 1
# get credential api key
credentials_json = json.loads(jina_tenant_credential.credentials)
api_key = credentials_json.get("config", {}).get("api_key")
new_credentials = {
"integration_secret": api_key,
}
datasource_provider = DatasourceProvider(
provider="jina",
tenant_id=tenant_id,
plugin_id=jina_plugin_id,
auth_type=api_key_credential_type.value,
encrypted_credentials=new_credentials,
name=f"Auth {auth_count}",
avatar_url="default",
is_default=False,
auth_count = 0
for jina_tenant_credential in jina_tenant_credentials:
auth_count += 1
# get credential api key
credentials_json = json.loads(jina_tenant_credential.credentials)
api_key = credentials_json.get("config", {}).get("api_key")
new_credentials = {
"integration_secret": api_key,
}
datasource_provider = DatasourceProvider(
provider="jina",
tenant_id=tenant_id,
plugin_id=jina_plugin_id,
auth_type=api_key_credential_type.value,
encrypted_credentials=new_credentials,
name=f"Auth {auth_count}",
avatar_url="default",
is_default=False,
)
db.session.add(datasource_provider)
deal_jina_count += 1
except Exception as e:
click.echo(
click.style(f"Error transforming jina credentials: {str(e)}, tenant_id: {tenant_id}", fg="red")
)
db.session.add(datasource_provider)
deal_jina_count += 1
continue
db.session.commit()
except Exception as e:
click.echo(click.style(f"Error parsing client params: {str(e)}", fg="red"))

View File

@ -5,7 +5,7 @@ import logging
import os
import time
import requests
import httpx
logger = logging.getLogger(__name__)
@ -30,10 +30,10 @@ class NacosHttpClient:
params = {}
try:
self._inject_auth_info(headers, params)
response = requests.request(method, url="http://" + self.server + url, headers=headers, params=params)
response = httpx.request(method, url="http://" + self.server + url, headers=headers, params=params)
response.raise_for_status()
return response.text
except requests.RequestException as e:
except httpx.RequestError as e:
return f"Request to Nacos failed: {e}"
def _inject_auth_info(self, headers: dict[str, str], params: dict[str, str], module: str = "config") -> None:
@ -78,7 +78,7 @@ class NacosHttpClient:
params = {"username": self.username, "password": self.password}
url = "http://" + self.server + "/nacos/v1/auth/login"
try:
resp = requests.request("POST", url, headers=None, params=params)
resp = httpx.request("POST", url, headers=None, params=params)
resp.raise_for_status()
response_data = resp.json()
self.token = response_data.get("accessToken")

View File

@ -1,6 +1,6 @@
import logging
import requests
import httpx
from flask import current_app, redirect, request
from flask_login import current_user
from flask_restx import Resource, fields
@ -119,7 +119,7 @@ class OAuthDataSourceBinding(Resource):
return {"error": "Invalid code"}, 400
try:
oauth_provider.get_access_token(code)
except requests.HTTPError as e:
except httpx.HTTPStatusError as e:
logger.exception(
"An error occurred during the OAuthCallback process with %s: %s", provider, e.response.text
)
@ -152,7 +152,7 @@ class OAuthDataSourceSync(Resource):
return {"error": "Invalid provider"}, 400
try:
oauth_provider.sync_data_source(binding_id)
except requests.HTTPError as e:
except httpx.HTTPStatusError as e:
logger.exception(
"An error occurred during the OAuthCallback process with %s: %s", provider, e.response.text
)

View File

@ -1,6 +1,6 @@
import logging
import requests
import httpx
from flask import current_app, redirect, request
from flask_restx import Resource
from sqlalchemy import select
@ -101,8 +101,10 @@ class OAuthCallback(Resource):
try:
token = oauth_provider.get_access_token(code)
user_info = oauth_provider.get_user_info(token)
except requests.RequestException as e:
error_text = e.response.text if e.response else str(e)
except httpx.RequestError as e:
error_text = str(e)
if isinstance(e, httpx.HTTPStatusError):
error_text = e.response.text
logger.exception("An error occurred during the OAuth process with %s: %s", provider, error_text)
return {"error": "OAuth process failed"}, 400

View File

@ -1,7 +1,7 @@
import json
import logging
import requests
import httpx
from flask_restx import Resource, fields, reqparse
from packaging import version
@ -57,7 +57,11 @@ class VersionApi(Resource):
return result
try:
response = requests.get(check_update_url, {"current_version": args["current_version"]}, timeout=(3, 10))
response = httpx.get(
check_update_url,
params={"current_version": args["current_version"]},
timeout=httpx.Timeout(connect=3, read=10),
)
except Exception as error:
logger.warning("Check update version error: %s.", str(error))
result["version"] = args["current_version"]

View File

@ -90,29 +90,12 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
if not app_record:
raise ValueError("App not found")
if self.application_generate_entity.single_iteration_run:
# if only single iteration run is requested
graph_runtime_state = GraphRuntimeState(
variable_pool=VariablePool.empty(),
start_at=time.time(),
)
graph, variable_pool = self._get_graph_and_variable_pool_of_single_iteration(
if self.application_generate_entity.single_iteration_run or self.application_generate_entity.single_loop_run:
# Handle single iteration or single loop run
graph, variable_pool, graph_runtime_state = self._prepare_single_node_execution(
workflow=self._workflow,
node_id=self.application_generate_entity.single_iteration_run.node_id,
user_inputs=dict(self.application_generate_entity.single_iteration_run.inputs),
graph_runtime_state=graph_runtime_state,
)
elif self.application_generate_entity.single_loop_run:
# if only single loop run is requested
graph_runtime_state = GraphRuntimeState(
variable_pool=VariablePool.empty(),
start_at=time.time(),
)
graph, variable_pool = self._get_graph_and_variable_pool_of_single_loop(
workflow=self._workflow,
node_id=self.application_generate_entity.single_loop_run.node_id,
user_inputs=dict(self.application_generate_entity.single_loop_run.inputs),
graph_runtime_state=graph_runtime_state,
single_iteration_run=self.application_generate_entity.single_iteration_run,
single_loop_run=self.application_generate_entity.single_loop_run,
)
else:
inputs = self.application_generate_entity.inputs

View File

@ -427,6 +427,9 @@ class PipelineGenerator(BaseAppGenerator):
invoke_from=InvokeFrom.DEBUGGER,
call_depth=0,
workflow_execution_id=str(uuid.uuid4()),
single_iteration_run=RagPipelineGenerateEntity.SingleIterationRunEntity(
node_id=node_id, inputs=args["inputs"]
),
)
contexts.plugin_tool_providers.set({})
contexts.plugin_tool_providers_lock.set(threading.Lock())
@ -465,6 +468,7 @@ class PipelineGenerator(BaseAppGenerator):
workflow_node_execution_repository=workflow_node_execution_repository,
streaming=streaming,
variable_loader=var_loader,
context=contextvars.copy_context(),
)
def single_loop_generate(
@ -559,6 +563,7 @@ class PipelineGenerator(BaseAppGenerator):
workflow_node_execution_repository=workflow_node_execution_repository,
streaming=streaming,
variable_loader=var_loader,
context=contextvars.copy_context(),
)
def _generate_worker(

View File

@ -86,29 +86,12 @@ class PipelineRunner(WorkflowBasedAppRunner):
db.session.close()
# if only single iteration run is requested
if self.application_generate_entity.single_iteration_run:
graph_runtime_state = GraphRuntimeState(
variable_pool=VariablePool.empty(),
start_at=time.time(),
)
# if only single iteration run is requested
graph, variable_pool = self._get_graph_and_variable_pool_of_single_iteration(
if self.application_generate_entity.single_iteration_run or self.application_generate_entity.single_loop_run:
# Handle single iteration or single loop run
graph, variable_pool, graph_runtime_state = self._prepare_single_node_execution(
workflow=workflow,
node_id=self.application_generate_entity.single_iteration_run.node_id,
user_inputs=self.application_generate_entity.single_iteration_run.inputs,
graph_runtime_state=graph_runtime_state,
)
elif self.application_generate_entity.single_loop_run:
graph_runtime_state = GraphRuntimeState(
variable_pool=VariablePool.empty(),
start_at=time.time(),
)
# if only single loop run is requested
graph, variable_pool = self._get_graph_and_variable_pool_of_single_loop(
workflow=workflow,
node_id=self.application_generate_entity.single_loop_run.node_id,
user_inputs=self.application_generate_entity.single_loop_run.inputs,
graph_runtime_state=graph_runtime_state,
single_iteration_run=self.application_generate_entity.single_iteration_run,
single_loop_run=self.application_generate_entity.single_loop_run,
)
else:
inputs = self.application_generate_entity.inputs

View File

@ -51,30 +51,12 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
app_config = self.application_generate_entity.app_config
app_config = cast(WorkflowAppConfig, app_config)
# if only single iteration run is requested
if self.application_generate_entity.single_iteration_run:
# if only single iteration run is requested
graph_runtime_state = GraphRuntimeState(
variable_pool=VariablePool.empty(),
start_at=time.time(),
)
graph, variable_pool = self._get_graph_and_variable_pool_of_single_iteration(
# if only single iteration or single loop run is requested
if self.application_generate_entity.single_iteration_run or self.application_generate_entity.single_loop_run:
graph, variable_pool, graph_runtime_state = self._prepare_single_node_execution(
workflow=self._workflow,
node_id=self.application_generate_entity.single_iteration_run.node_id,
user_inputs=self.application_generate_entity.single_iteration_run.inputs,
graph_runtime_state=graph_runtime_state,
)
elif self.application_generate_entity.single_loop_run:
# if only single loop run is requested
graph_runtime_state = GraphRuntimeState(
variable_pool=VariablePool.empty(),
start_at=time.time(),
)
graph, variable_pool = self._get_graph_and_variable_pool_of_single_loop(
workflow=self._workflow,
node_id=self.application_generate_entity.single_loop_run.node_id,
user_inputs=self.application_generate_entity.single_loop_run.inputs,
graph_runtime_state=graph_runtime_state,
single_iteration_run=self.application_generate_entity.single_iteration_run,
single_loop_run=self.application_generate_entity.single_loop_run,
)
else:
inputs = self.application_generate_entity.inputs

View File

@ -1,3 +1,4 @@
import time
from collections.abc import Mapping
from typing import Any, cast
@ -119,15 +120,81 @@ class WorkflowBasedAppRunner:
return graph
def _get_graph_and_variable_pool_of_single_iteration(
def _prepare_single_node_execution(
self,
workflow: Workflow,
single_iteration_run: Any | None = None,
single_loop_run: Any | None = None,
) -> tuple[Graph, VariablePool, GraphRuntimeState]:
"""
Prepare graph, variable pool, and runtime state for single node execution
(either single iteration or single loop).
Args:
workflow: The workflow instance
single_iteration_run: SingleIterationRunEntity if running single iteration, None otherwise
single_loop_run: SingleLoopRunEntity if running single loop, None otherwise
Returns:
A tuple containing (graph, variable_pool, graph_runtime_state)
Raises:
ValueError: If neither single_iteration_run nor single_loop_run is specified
"""
# Create initial runtime state with variable pool containing environment variables
graph_runtime_state = GraphRuntimeState(
variable_pool=VariablePool(
system_variables=SystemVariable.empty(),
user_inputs={},
environment_variables=workflow.environment_variables,
),
start_at=time.time(),
)
# Determine which type of single node execution and get graph/variable_pool
if single_iteration_run:
graph, variable_pool = self._get_graph_and_variable_pool_of_single_iteration(
workflow=workflow,
node_id=single_iteration_run.node_id,
user_inputs=dict(single_iteration_run.inputs),
graph_runtime_state=graph_runtime_state,
)
elif single_loop_run:
graph, variable_pool = self._get_graph_and_variable_pool_of_single_loop(
workflow=workflow,
node_id=single_loop_run.node_id,
user_inputs=dict(single_loop_run.inputs),
graph_runtime_state=graph_runtime_state,
)
else:
raise ValueError("Neither single_iteration_run nor single_loop_run is specified")
# Return the graph, variable_pool, and the same graph_runtime_state used during graph creation
# This ensures all nodes in the graph reference the same GraphRuntimeState instance
return graph, variable_pool, graph_runtime_state
def _get_graph_and_variable_pool_for_single_node_run(
self,
workflow: Workflow,
node_id: str,
user_inputs: dict,
user_inputs: dict[str, Any],
graph_runtime_state: GraphRuntimeState,
node_type_filter_key: str, # 'iteration_id' or 'loop_id'
node_type_label: str = "node", # 'iteration' or 'loop' for error messages
) -> tuple[Graph, VariablePool]:
"""
Get variable pool of single iteration
Get graph and variable pool for single node execution (iteration or loop).
Args:
workflow: The workflow instance
node_id: The node ID to execute
user_inputs: User inputs for the node
graph_runtime_state: The graph runtime state
node_type_filter_key: The key to filter nodes ('iteration_id' or 'loop_id')
node_type_label: Label for error messages ('iteration' or 'loop')
Returns:
A tuple containing (graph, variable_pool)
"""
# fetch workflow graph
graph_config = workflow.graph_dict
@ -145,18 +212,22 @@ class WorkflowBasedAppRunner:
if not isinstance(graph_config.get("edges"), list):
raise ValueError("edges in workflow graph must be a list")
# filter nodes only in iteration
# filter nodes only in the specified node type (iteration or loop)
main_node_config = next((n for n in graph_config.get("nodes", []) if n.get("id") == node_id), None)
start_node_id = main_node_config.get("data", {}).get("start_node_id") if main_node_config else None
node_configs = [
node
for node in graph_config.get("nodes", [])
if node.get("id") == node_id or node.get("data", {}).get("iteration_id", "") == node_id
if node.get("id") == node_id
or node.get("data", {}).get(node_type_filter_key, "") == node_id
or (start_node_id and node.get("id") == start_node_id)
]
graph_config["nodes"] = node_configs
node_ids = [node.get("id") for node in node_configs]
# filter edges only in iteration
# filter edges only in the specified node type
edge_configs = [
edge
for edge in graph_config.get("edges", [])
@ -190,30 +261,26 @@ class WorkflowBasedAppRunner:
raise ValueError("graph not found in workflow")
# fetch node config from node id
iteration_node_config = None
target_node_config = None
for node in node_configs:
if node.get("id") == node_id:
iteration_node_config = node
target_node_config = node
break
if not iteration_node_config:
raise ValueError("iteration node id not found in workflow graph")
if not target_node_config:
raise ValueError(f"{node_type_label} node id not found in workflow graph")
# Get node class
node_type = NodeType(iteration_node_config.get("data", {}).get("type"))
node_version = iteration_node_config.get("data", {}).get("version", "1")
node_type = NodeType(target_node_config.get("data", {}).get("type"))
node_version = target_node_config.get("data", {}).get("version", "1")
node_cls = NODE_TYPE_CLASSES_MAPPING[node_type][node_version]
# init variable pool
variable_pool = VariablePool(
system_variables=SystemVariable.empty(),
user_inputs={},
environment_variables=workflow.environment_variables,
)
# Use the variable pool from graph_runtime_state instead of creating a new one
variable_pool = graph_runtime_state.variable_pool
try:
variable_mapping = node_cls.extract_variable_selector_to_variable_mapping(
graph_config=workflow.graph_dict, config=iteration_node_config
graph_config=workflow.graph_dict, config=target_node_config
)
except NotImplementedError:
variable_mapping = {}
@ -234,120 +301,44 @@ class WorkflowBasedAppRunner:
return graph, variable_pool
def _get_graph_and_variable_pool_of_single_iteration(
self,
workflow: Workflow,
node_id: str,
user_inputs: dict[str, Any],
graph_runtime_state: GraphRuntimeState,
) -> tuple[Graph, VariablePool]:
"""
Get variable pool of single iteration
"""
return self._get_graph_and_variable_pool_for_single_node_run(
workflow=workflow,
node_id=node_id,
user_inputs=user_inputs,
graph_runtime_state=graph_runtime_state,
node_type_filter_key="iteration_id",
node_type_label="iteration",
)
def _get_graph_and_variable_pool_of_single_loop(
self,
workflow: Workflow,
node_id: str,
user_inputs: dict,
user_inputs: dict[str, Any],
graph_runtime_state: GraphRuntimeState,
) -> tuple[Graph, VariablePool]:
"""
Get variable pool of single loop
"""
# fetch workflow graph
graph_config = workflow.graph_dict
if not graph_config:
raise ValueError("workflow graph not found")
graph_config = cast(dict[str, Any], graph_config)
if "nodes" not in graph_config or "edges" not in graph_config:
raise ValueError("nodes or edges not found in workflow graph")
if not isinstance(graph_config.get("nodes"), list):
raise ValueError("nodes in workflow graph must be a list")
if not isinstance(graph_config.get("edges"), list):
raise ValueError("edges in workflow graph must be a list")
# filter nodes only in loop
node_configs = [
node
for node in graph_config.get("nodes", [])
if node.get("id") == node_id or node.get("data", {}).get("loop_id", "") == node_id
]
graph_config["nodes"] = node_configs
node_ids = [node.get("id") for node in node_configs]
# filter edges only in loop
edge_configs = [
edge
for edge in graph_config.get("edges", [])
if (edge.get("source") is None or edge.get("source") in node_ids)
and (edge.get("target") is None or edge.get("target") in node_ids)
]
graph_config["edges"] = edge_configs
# Create required parameters for Graph.init
graph_init_params = GraphInitParams(
tenant_id=workflow.tenant_id,
app_id=self._app_id,
workflow_id=workflow.id,
graph_config=graph_config,
user_id="",
user_from=UserFrom.ACCOUNT.value,
invoke_from=InvokeFrom.SERVICE_API.value,
call_depth=0,
)
node_factory = DifyNodeFactory(
graph_init_params=graph_init_params,
return self._get_graph_and_variable_pool_for_single_node_run(
workflow=workflow,
node_id=node_id,
user_inputs=user_inputs,
graph_runtime_state=graph_runtime_state,
node_type_filter_key="loop_id",
node_type_label="loop",
)
# init graph
graph = Graph.init(graph_config=graph_config, node_factory=node_factory, root_node_id=node_id)
if not graph:
raise ValueError("graph not found in workflow")
# fetch node config from node id
loop_node_config = None
for node in node_configs:
if node.get("id") == node_id:
loop_node_config = node
break
if not loop_node_config:
raise ValueError("loop node id not found in workflow graph")
# Get node class
node_type = NodeType(loop_node_config.get("data", {}).get("type"))
node_version = loop_node_config.get("data", {}).get("version", "1")
node_cls = NODE_TYPE_CLASSES_MAPPING[node_type][node_version]
# init variable pool
variable_pool = VariablePool(
system_variables=SystemVariable.empty(),
user_inputs={},
environment_variables=workflow.environment_variables,
)
try:
variable_mapping = node_cls.extract_variable_selector_to_variable_mapping(
graph_config=workflow.graph_dict, config=loop_node_config
)
except NotImplementedError:
variable_mapping = {}
load_into_variable_pool(
self._variable_loader,
variable_pool=variable_pool,
variable_mapping=variable_mapping,
user_inputs=user_inputs,
)
WorkflowEntry.mapping_user_inputs_to_variable_pool(
variable_mapping=variable_mapping,
user_inputs=user_inputs,
variable_pool=variable_pool,
tenant_id=workflow.tenant_id,
)
return graph, variable_pool
def _handle_event(self, workflow_entry: WorkflowEntry, event: GraphEngineEvent):
"""
Handle event

View File

@ -205,16 +205,10 @@ class ProviderConfiguration(BaseModel):
"""
Get custom provider record.
"""
# get provider
model_provider_id = ModelProviderID(self.provider.provider)
provider_names = [self.provider.provider]
if model_provider_id.is_langgenius():
provider_names.append(model_provider_id.provider_name)
stmt = select(Provider).where(
Provider.tenant_id == self.tenant_id,
Provider.provider_type == ProviderType.CUSTOM.value,
Provider.provider_name.in_(provider_names),
Provider.provider_name.in_(self._get_provider_names()),
)
return session.execute(stmt).scalar_one_or_none()
@ -276,7 +270,7 @@ class ProviderConfiguration(BaseModel):
"""
stmt = select(ProviderCredential.id).where(
ProviderCredential.tenant_id == self.tenant_id,
ProviderCredential.provider_name == self.provider.provider,
ProviderCredential.provider_name.in_(self._get_provider_names()),
ProviderCredential.credential_name == credential_name,
)
if exclude_id:
@ -324,7 +318,7 @@ class ProviderConfiguration(BaseModel):
try:
stmt = select(ProviderCredential).where(
ProviderCredential.tenant_id == self.tenant_id,
ProviderCredential.provider_name == self.provider.provider,
ProviderCredential.provider_name.in_(self._get_provider_names()),
ProviderCredential.id == credential_id,
)
credential_record = s.execute(stmt).scalar_one_or_none()
@ -374,7 +368,7 @@ class ProviderConfiguration(BaseModel):
session=session,
query_factory=lambda: select(ProviderCredential).where(
ProviderCredential.tenant_id == self.tenant_id,
ProviderCredential.provider_name == self.provider.provider,
ProviderCredential.provider_name.in_(self._get_provider_names()),
),
)
@ -387,7 +381,7 @@ class ProviderConfiguration(BaseModel):
session=session,
query_factory=lambda: select(ProviderModelCredential).where(
ProviderModelCredential.tenant_id == self.tenant_id,
ProviderModelCredential.provider_name == self.provider.provider,
ProviderModelCredential.provider_name.in_(self._get_provider_names()),
ProviderModelCredential.model_name == model,
ProviderModelCredential.model_type == model_type.to_origin_model_type(),
),
@ -423,6 +417,16 @@ class ProviderConfiguration(BaseModel):
logger.warning("Error generating next credential name: %s", str(e))
return "API KEY 1"
def _get_provider_names(self):
"""
The provider name might be stored in the database as either `openai` or `langgenius/openai/openai`.
"""
model_provider_id = ModelProviderID(self.provider.provider)
provider_names = [self.provider.provider]
if model_provider_id.is_langgenius():
provider_names.append(model_provider_id.provider_name)
return provider_names
def create_provider_credential(self, credentials: dict, credential_name: str | None):
"""
Add custom provider credentials.
@ -501,7 +505,7 @@ class ProviderConfiguration(BaseModel):
stmt = select(ProviderCredential).where(
ProviderCredential.id == credential_id,
ProviderCredential.tenant_id == self.tenant_id,
ProviderCredential.provider_name == self.provider.provider,
ProviderCredential.provider_name.in_(self._get_provider_names()),
)
# Get the credential record to update
@ -554,7 +558,7 @@ class ProviderConfiguration(BaseModel):
# Find all load balancing configs that use this credential_id
stmt = select(LoadBalancingModelConfig).where(
LoadBalancingModelConfig.tenant_id == self.tenant_id,
LoadBalancingModelConfig.provider_name == self.provider.provider,
LoadBalancingModelConfig.provider_name.in_(self._get_provider_names()),
LoadBalancingModelConfig.credential_id == credential_id,
LoadBalancingModelConfig.credential_source_type == credential_source,
)
@ -591,7 +595,7 @@ class ProviderConfiguration(BaseModel):
stmt = select(ProviderCredential).where(
ProviderCredential.id == credential_id,
ProviderCredential.tenant_id == self.tenant_id,
ProviderCredential.provider_name == self.provider.provider,
ProviderCredential.provider_name.in_(self._get_provider_names()),
)
# Get the credential record to update
@ -602,7 +606,7 @@ class ProviderConfiguration(BaseModel):
# Check if this credential is used in load balancing configs
lb_stmt = select(LoadBalancingModelConfig).where(
LoadBalancingModelConfig.tenant_id == self.tenant_id,
LoadBalancingModelConfig.provider_name == self.provider.provider,
LoadBalancingModelConfig.provider_name.in_(self._get_provider_names()),
LoadBalancingModelConfig.credential_id == credential_id,
LoadBalancingModelConfig.credential_source_type == "provider",
)
@ -624,7 +628,7 @@ class ProviderConfiguration(BaseModel):
# if this is the last credential, we need to delete the provider record
count_stmt = select(func.count(ProviderCredential.id)).where(
ProviderCredential.tenant_id == self.tenant_id,
ProviderCredential.provider_name == self.provider.provider,
ProviderCredential.provider_name.in_(self._get_provider_names()),
)
available_credentials_count = session.execute(count_stmt).scalar() or 0
session.delete(credential_record)
@ -668,7 +672,7 @@ class ProviderConfiguration(BaseModel):
stmt = select(ProviderCredential).where(
ProviderCredential.id == credential_id,
ProviderCredential.tenant_id == self.tenant_id,
ProviderCredential.provider_name == self.provider.provider,
ProviderCredential.provider_name.in_(self._get_provider_names()),
)
credential_record = session.execute(stmt).scalar_one_or_none()
if not credential_record:
@ -737,7 +741,7 @@ class ProviderConfiguration(BaseModel):
stmt = select(ProviderModelCredential).where(
ProviderModelCredential.id == credential_id,
ProviderModelCredential.tenant_id == self.tenant_id,
ProviderModelCredential.provider_name == self.provider.provider,
ProviderModelCredential.provider_name.in_(self._get_provider_names()),
ProviderModelCredential.model_name == model,
ProviderModelCredential.model_type == model_type.to_origin_model_type(),
)
@ -784,7 +788,7 @@ class ProviderConfiguration(BaseModel):
"""
stmt = select(ProviderModelCredential).where(
ProviderModelCredential.tenant_id == self.tenant_id,
ProviderModelCredential.provider_name == self.provider.provider,
ProviderModelCredential.provider_name.in_(self._get_provider_names()),
ProviderModelCredential.model_name == model,
ProviderModelCredential.model_type == model_type.to_origin_model_type(),
ProviderModelCredential.credential_name == credential_name,
@ -860,7 +864,7 @@ class ProviderConfiguration(BaseModel):
stmt = select(ProviderModelCredential).where(
ProviderModelCredential.id == credential_id,
ProviderModelCredential.tenant_id == self.tenant_id,
ProviderModelCredential.provider_name == self.provider.provider,
ProviderModelCredential.provider_name.in_(self._get_provider_names()),
ProviderModelCredential.model_name == model,
ProviderModelCredential.model_type == model_type.to_origin_model_type(),
)
@ -997,7 +1001,7 @@ class ProviderConfiguration(BaseModel):
stmt = select(ProviderModelCredential).where(
ProviderModelCredential.id == credential_id,
ProviderModelCredential.tenant_id == self.tenant_id,
ProviderModelCredential.provider_name == self.provider.provider,
ProviderModelCredential.provider_name.in_(self._get_provider_names()),
ProviderModelCredential.model_name == model,
ProviderModelCredential.model_type == model_type.to_origin_model_type(),
)
@ -1042,7 +1046,7 @@ class ProviderConfiguration(BaseModel):
stmt = select(ProviderModelCredential).where(
ProviderModelCredential.id == credential_id,
ProviderModelCredential.tenant_id == self.tenant_id,
ProviderModelCredential.provider_name == self.provider.provider,
ProviderModelCredential.provider_name.in_(self._get_provider_names()),
ProviderModelCredential.model_name == model,
ProviderModelCredential.model_type == model_type.to_origin_model_type(),
)
@ -1052,7 +1056,7 @@ class ProviderConfiguration(BaseModel):
lb_stmt = select(LoadBalancingModelConfig).where(
LoadBalancingModelConfig.tenant_id == self.tenant_id,
LoadBalancingModelConfig.provider_name == self.provider.provider,
LoadBalancingModelConfig.provider_name.in_(self._get_provider_names()),
LoadBalancingModelConfig.credential_id == credential_id,
LoadBalancingModelConfig.credential_source_type == "custom_model",
)
@ -1075,7 +1079,7 @@ class ProviderConfiguration(BaseModel):
# if this is the last credential, we need to delete the custom model record
count_stmt = select(func.count(ProviderModelCredential.id)).where(
ProviderModelCredential.tenant_id == self.tenant_id,
ProviderModelCredential.provider_name == self.provider.provider,
ProviderModelCredential.provider_name.in_(self._get_provider_names()),
ProviderModelCredential.model_name == model,
ProviderModelCredential.model_type == model_type.to_origin_model_type(),
)
@ -1115,7 +1119,7 @@ class ProviderConfiguration(BaseModel):
stmt = select(ProviderModelCredential).where(
ProviderModelCredential.id == credential_id,
ProviderModelCredential.tenant_id == self.tenant_id,
ProviderModelCredential.provider_name == self.provider.provider,
ProviderModelCredential.provider_name.in_(self._get_provider_names()),
ProviderModelCredential.model_name == model,
ProviderModelCredential.model_type == model_type.to_origin_model_type(),
)
@ -1157,7 +1161,7 @@ class ProviderConfiguration(BaseModel):
stmt = select(ProviderModelCredential).where(
ProviderModelCredential.id == credential_id,
ProviderModelCredential.tenant_id == self.tenant_id,
ProviderModelCredential.provider_name == self.provider.provider,
ProviderModelCredential.provider_name.in_(self._get_provider_names()),
ProviderModelCredential.model_name == model,
ProviderModelCredential.model_type == model_type.to_origin_model_type(),
)
@ -1204,15 +1208,9 @@ class ProviderConfiguration(BaseModel):
"""
Get provider model setting.
"""
model_provider_id = ModelProviderID(self.provider.provider)
provider_names = [self.provider.provider]
if model_provider_id.is_langgenius():
provider_names.append(model_provider_id.provider_name)
stmt = select(ProviderModelSetting).where(
ProviderModelSetting.tenant_id == self.tenant_id,
ProviderModelSetting.provider_name.in_(provider_names),
ProviderModelSetting.provider_name.in_(self._get_provider_names()),
ProviderModelSetting.model_type == model_type.to_origin_model_type(),
ProviderModelSetting.model_name == model,
)
@ -1384,15 +1382,9 @@ class ProviderConfiguration(BaseModel):
return
def _switch(s: Session):
# get preferred provider
model_provider_id = ModelProviderID(self.provider.provider)
provider_names = [self.provider.provider]
if model_provider_id.is_langgenius():
provider_names.append(model_provider_id.provider_name)
stmt = select(TenantPreferredModelProvider).where(
TenantPreferredModelProvider.tenant_id == self.tenant_id,
TenantPreferredModelProvider.provider_name.in_(provider_names),
TenantPreferredModelProvider.provider_name.in_(self._get_provider_names()),
)
preferred_model_provider = s.execute(stmt).scalars().first()

View File

@ -8,7 +8,7 @@ from collections import deque
from collections.abc import Sequence
from datetime import datetime
import requests
import httpx
from opentelemetry import trace as trace_api
from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter
from opentelemetry.sdk.resources import Resource
@ -65,13 +65,13 @@ class TraceClient:
def api_check(self):
try:
response = requests.head(self.endpoint, timeout=5)
response = httpx.head(self.endpoint, timeout=5)
if response.status_code == 405:
return True
else:
logger.debug("AliyunTrace API check failed: Unexpected status code: %s", response.status_code)
return False
except requests.RequestException as e:
except httpx.RequestError as e:
logger.debug("AliyunTrace API check failed: %s", str(e))
raise ValueError(f"AliyunTrace API check failed: {str(e)}")

View File

@ -513,6 +513,21 @@ class ProviderManager:
return provider_name_to_provider_load_balancing_model_configs_dict
@staticmethod
def _get_provider_names(provider_name: str) -> list[str]:
"""
provider_name: `openai` or `langgenius/openai/openai`
return: [`openai`, `langgenius/openai/openai`]
"""
provider_names = [provider_name]
model_provider_id = ModelProviderID(provider_name)
if model_provider_id.is_langgenius():
if "/" in provider_name:
provider_names.append(model_provider_id.provider_name)
else:
provider_names.append(str(model_provider_id))
return provider_names
@staticmethod
def get_provider_available_credentials(tenant_id: str, provider_name: str) -> list[CredentialConfiguration]:
"""
@ -525,7 +540,10 @@ class ProviderManager:
with Session(db.engine, expire_on_commit=False) as session:
stmt = (
select(ProviderCredential)
.where(ProviderCredential.tenant_id == tenant_id, ProviderCredential.provider_name == provider_name)
.where(
ProviderCredential.tenant_id == tenant_id,
ProviderCredential.provider_name.in_(ProviderManager._get_provider_names(provider_name)),
)
.order_by(ProviderCredential.created_at.desc())
)
@ -554,7 +572,7 @@ class ProviderManager:
select(ProviderModelCredential)
.where(
ProviderModelCredential.tenant_id == tenant_id,
ProviderModelCredential.provider_name == provider_name,
ProviderModelCredential.provider_name.in_(ProviderManager._get_provider_names(provider_name)),
ProviderModelCredential.model_name == model_name,
ProviderModelCredential.model_type == model_type,
)

View File

@ -41,7 +41,8 @@ class GraphExecutionState(BaseModel):
completed: bool = Field(default=False)
aborted: bool = Field(default=False)
error: GraphExecutionErrorState | None = Field(default=None)
node_executions: list[NodeExecutionState] = Field(default_factory=list)
exceptions_count: int = Field(default=0)
node_executions: list[NodeExecutionState] = Field(default_factory=list[NodeExecutionState])
def _serialize_error(error: Exception | None) -> GraphExecutionErrorState | None:
@ -103,7 +104,8 @@ class GraphExecution:
completed: bool = False
aborted: bool = False
error: Exception | None = None
node_executions: dict[str, NodeExecution] = field(default_factory=dict)
node_executions: dict[str, NodeExecution] = field(default_factory=dict[str, NodeExecution])
exceptions_count: int = 0
def start(self) -> None:
"""Mark the graph execution as started."""
@ -172,6 +174,7 @@ class GraphExecution:
completed=self.completed,
aborted=self.aborted,
error=_serialize_error(self.error),
exceptions_count=self.exceptions_count,
node_executions=node_states,
)
@ -195,6 +198,7 @@ class GraphExecution:
self.completed = state.completed
self.aborted = state.aborted
self.error = _deserialize_error(state.error)
self.exceptions_count = state.exceptions_count
self.node_executions = {
item.node_id: NodeExecution(
node_id=item.node_id,
@ -205,3 +209,7 @@ class GraphExecution:
)
for item in state.node_executions
}
def record_node_failure(self) -> None:
"""Increment the count of node failures encountered during execution."""
self.exceptions_count += 1

View File

@ -3,11 +3,12 @@ Event handler implementations for different event types.
"""
import logging
from collections.abc import Mapping
from functools import singledispatchmethod
from typing import TYPE_CHECKING, final
from core.workflow.entities import GraphRuntimeState
from core.workflow.enums import NodeExecutionType
from core.workflow.enums import ErrorStrategy, NodeExecutionType
from core.workflow.graph import Graph
from core.workflow.graph_events import (
GraphNodeEventBase,
@ -122,13 +123,15 @@ class EventHandler:
"""
# Track execution in domain model
node_execution = self._graph_execution.get_or_create_node_execution(event.node_id)
is_initial_attempt = node_execution.retry_count == 0
node_execution.mark_started(event.id)
# Track in response coordinator for stream ordering
self._response_coordinator.track_node_execution(event.node_id, event.id)
# Collect the event
self._event_collector.collect(event)
# Collect the event only for the first attempt; retries remain silent
if is_initial_attempt:
self._event_collector.collect(event)
@_dispatch.register
def _(self, event: NodeRunStreamChunkEvent) -> None:
@ -161,7 +164,7 @@ class EventHandler:
node_execution.mark_taken()
# Store outputs in variable pool
self._store_node_outputs(event)
self._store_node_outputs(event.node_id, event.node_run_result.outputs)
# Forward to response coordinator and emit streaming events
streaming_events = self._response_coordinator.intercept_event(event)
@ -191,7 +194,7 @@ class EventHandler:
# Handle response node outputs
if node.execution_type == NodeExecutionType.RESPONSE:
self._update_response_outputs(event)
self._update_response_outputs(event.node_run_result.outputs)
# Collect the event
self._event_collector.collect(event)
@ -207,6 +210,7 @@ class EventHandler:
# Update domain model
node_execution = self._graph_execution.get_or_create_node_execution(event.node_id)
node_execution.mark_failed(event.error)
self._graph_execution.record_node_failure()
result = self._error_handler.handle_node_failure(event)
@ -227,10 +231,40 @@ class EventHandler:
Args:
event: The node exception event
"""
# Node continues via fail-branch, so it's technically "succeeded"
# Node continues via fail-branch/default-value, treat as completion
node_execution = self._graph_execution.get_or_create_node_execution(event.node_id)
node_execution.mark_taken()
# Persist outputs produced by the exception strategy (e.g. default values)
self._store_node_outputs(event.node_id, event.node_run_result.outputs)
node = self._graph.nodes[event.node_id]
if node.error_strategy == ErrorStrategy.DEFAULT_VALUE:
ready_nodes, edge_streaming_events = self._edge_processor.process_node_success(event.node_id)
elif node.error_strategy == ErrorStrategy.FAIL_BRANCH:
ready_nodes, edge_streaming_events = self._edge_processor.handle_branch_completion(
event.node_id, event.node_run_result.edge_source_handle
)
else:
raise NotImplementedError(f"Unsupported error strategy: {node.error_strategy}")
for edge_event in edge_streaming_events:
self._event_collector.collect(edge_event)
for node_id in ready_nodes:
self._state_manager.enqueue_node(node_id)
self._state_manager.start_execution(node_id)
# Update response outputs if applicable
if node.execution_type == NodeExecutionType.RESPONSE:
self._update_response_outputs(event.node_run_result.outputs)
self._state_manager.finish_execution(event.node_id)
# Collect the exception event for observers
self._event_collector.collect(event)
@_dispatch.register
def _(self, event: NodeRunRetryEvent) -> None:
"""
@ -242,21 +276,31 @@ class EventHandler:
node_execution = self._graph_execution.get_or_create_node_execution(event.node_id)
node_execution.increment_retry()
def _store_node_outputs(self, event: NodeRunSucceededEvent) -> None:
# Finish the previous attempt before re-queuing the node
self._state_manager.finish_execution(event.node_id)
# Emit retry event for observers
self._event_collector.collect(event)
# Re-queue node for execution
self._state_manager.enqueue_node(event.node_id)
self._state_manager.start_execution(event.node_id)
def _store_node_outputs(self, node_id: str, outputs: Mapping[str, object]) -> None:
"""
Store node outputs in the variable pool.
Args:
event: The node succeeded event containing outputs
"""
for variable_name, variable_value in event.node_run_result.outputs.items():
self._graph_runtime_state.variable_pool.add((event.node_id, variable_name), variable_value)
for variable_name, variable_value in outputs.items():
self._graph_runtime_state.variable_pool.add((node_id, variable_name), variable_value)
def _update_response_outputs(self, event: NodeRunSucceededEvent) -> None:
def _update_response_outputs(self, outputs: Mapping[str, object]) -> None:
"""Update response outputs for response nodes."""
# TODO: Design a mechanism for nodes to notify the engine about how to update outputs
# in runtime state, rather than allowing nodes to directly access runtime state.
for key, value in event.node_run_result.outputs.items():
for key, value in outputs.items():
if key == "answer":
existing = self._graph_runtime_state.get_output("answer", "")
if existing:

View File

@ -5,6 +5,7 @@ Unified event manager for collecting and emitting events.
import threading
import time
from collections.abc import Generator
from contextlib import contextmanager
from typing import final
from core.workflow.graph_events import GraphEngineEvent
@ -51,43 +52,23 @@ class ReadWriteLock:
"""Release a write lock."""
self._read_ready.release()
def read_lock(self) -> "ReadLockContext":
@contextmanager
def read_lock(self):
"""Return a context manager for read locking."""
return ReadLockContext(self)
self.acquire_read()
try:
yield
finally:
self.release_read()
def write_lock(self) -> "WriteLockContext":
@contextmanager
def write_lock(self):
"""Return a context manager for write locking."""
return WriteLockContext(self)
@final
class ReadLockContext:
"""Context manager for read locks."""
def __init__(self, lock: ReadWriteLock) -> None:
self._lock = lock
def __enter__(self) -> "ReadLockContext":
self._lock.acquire_read()
return self
def __exit__(self, exc_type: type[BaseException] | None, exc_val: BaseException | None, exc_tb: object) -> None:
self._lock.release_read()
@final
class WriteLockContext:
"""Context manager for write locks."""
def __init__(self, lock: ReadWriteLock) -> None:
self._lock = lock
def __enter__(self) -> "WriteLockContext":
self._lock.acquire_write()
return self
def __exit__(self, exc_type: type[BaseException] | None, exc_val: BaseException | None, exc_tb: object) -> None:
self._lock.release_write()
self.acquire_write()
try:
yield
finally:
self.release_write()
@final

View File

@ -23,6 +23,7 @@ from core.workflow.graph_events import (
GraphNodeEventBase,
GraphRunAbortedEvent,
GraphRunFailedEvent,
GraphRunPartialSucceededEvent,
GraphRunStartedEvent,
GraphRunSucceededEvent,
)
@ -260,12 +261,23 @@ class GraphEngine:
if self._graph_execution.error:
raise self._graph_execution.error
else:
yield GraphRunSucceededEvent(
outputs=self._graph_runtime_state.outputs,
)
outputs = self._graph_runtime_state.outputs
exceptions_count = self._graph_execution.exceptions_count
if exceptions_count > 0:
yield GraphRunPartialSucceededEvent(
exceptions_count=exceptions_count,
outputs=outputs,
)
else:
yield GraphRunSucceededEvent(
outputs=outputs,
)
except Exception as e:
yield GraphRunFailedEvent(error=str(e))
yield GraphRunFailedEvent(
error=str(e),
exceptions_count=self._graph_execution.exceptions_count,
)
raise
finally:

View File

@ -15,6 +15,7 @@ from core.workflow.graph_events import (
GraphEngineEvent,
GraphRunAbortedEvent,
GraphRunFailedEvent,
GraphRunPartialSucceededEvent,
GraphRunStartedEvent,
GraphRunSucceededEvent,
NodeRunExceptionEvent,
@ -127,6 +128,13 @@ class DebugLoggingLayer(GraphEngineLayer):
if self.include_outputs and event.outputs:
self.logger.info(" Final outputs: %s", self._format_dict(event.outputs))
elif isinstance(event, GraphRunPartialSucceededEvent):
self.logger.warning("⚠️ Graph run partially succeeded")
if event.exceptions_count > 0:
self.logger.warning(" Total exceptions: %s", event.exceptions_count)
if self.include_outputs and event.outputs:
self.logger.info(" Final outputs: %s", self._format_dict(event.outputs))
elif isinstance(event, GraphRunFailedEvent):
self.logger.error("❌ Graph run failed: %s", event.error)
if event.exceptions_count > 0:
@ -138,6 +146,12 @@ class DebugLoggingLayer(GraphEngineLayer):
self.logger.info(" Partial outputs: %s", self._format_dict(event.outputs))
# Node-level events
# Retry before Started because Retry subclasses Started;
elif isinstance(event, NodeRunRetryEvent):
self.retry_count += 1
self.logger.warning("🔄 Node retry: %s (attempt %s)", event.node_id, event.retry_index)
self.logger.warning(" Previous error: %s", event.error)
elif isinstance(event, NodeRunStartedEvent):
self.node_count += 1
self.logger.info('▶️ Node started: %s - "%s" (type: %s)', event.node_id, event.node_title, event.node_type)
@ -167,11 +181,6 @@ class DebugLoggingLayer(GraphEngineLayer):
self.logger.warning("⚠️ Node exception handled: %s", event.node_id)
self.logger.warning(" Error: %s", event.error)
elif isinstance(event, NodeRunRetryEvent):
self.retry_count += 1
self.logger.warning("🔄 Node retry: %s (attempt %s)", event.node_id, event.retry_index)
self.logger.warning(" Previous error: %s", event.error)
elif isinstance(event, NodeRunStreamChunkEvent):
# Log stream chunks at debug level to avoid spam
final_indicator = " (FINAL)" if event.is_final else ""

View File

@ -19,6 +19,7 @@ from core.workflow.enums import (
from core.workflow.graph_events import (
GraphNodeEventBase,
GraphRunFailedEvent,
GraphRunPartialSucceededEvent,
GraphRunSucceededEvent,
)
from core.workflow.node_events import (
@ -372,43 +373,16 @@ class IterationNode(Node):
variable_mapping: dict[str, Sequence[str]] = {
f"{node_id}.input_selector": typed_node_data.iterator_selector,
}
iteration_node_ids = set()
# init graph
from core.workflow.entities import GraphInitParams, GraphRuntimeState
from core.workflow.graph import Graph
from core.workflow.nodes.node_factory import DifyNodeFactory
# Create minimal GraphInitParams for static analysis
graph_init_params = GraphInitParams(
tenant_id="",
app_id="",
workflow_id="",
graph_config=graph_config,
user_id="",
user_from="",
invoke_from="",
call_depth=0,
)
# Create minimal GraphRuntimeState for static analysis
from core.workflow.entities import VariablePool
graph_runtime_state = GraphRuntimeState(
variable_pool=VariablePool(),
start_at=0,
)
# Create node factory for static analysis
node_factory = DifyNodeFactory(graph_init_params=graph_init_params, graph_runtime_state=graph_runtime_state)
iteration_graph = Graph.init(
graph_config=graph_config,
node_factory=node_factory,
root_node_id=typed_node_data.start_node_id,
)
if not iteration_graph:
raise IterationGraphNotFoundError("iteration graph not found")
# Find all nodes that belong to this loop
nodes = graph_config.get("nodes", [])
for node in nodes:
node_data = node.get("data", {})
if node_data.get("iteration_id") == node_id:
in_iteration_node_id = node.get("id")
if in_iteration_node_id:
iteration_node_ids.add(in_iteration_node_id)
# Get node configs from graph_config instead of non-existent node_id_config_mapping
node_configs = {node["id"]: node for node in graph_config.get("nodes", []) if "id" in node}
@ -444,9 +418,7 @@ class IterationNode(Node):
variable_mapping.update(sub_node_variable_mapping)
# remove variable out from iteration
variable_mapping = {
key: value for key, value in variable_mapping.items() if value[0] not in iteration_graph.node_ids
}
variable_mapping = {key: value for key, value in variable_mapping.items() if value[0] not in iteration_node_ids}
return variable_mapping
@ -485,7 +457,7 @@ class IterationNode(Node):
if isinstance(event, GraphNodeEventBase):
self._append_iteration_info_to_event(event=event, iter_run_index=current_index)
yield event
elif isinstance(event, GraphRunSucceededEvent):
elif isinstance(event, (GraphRunSucceededEvent, GraphRunPartialSucceededEvent)):
result = variable_pool.get(self._node_data.output_selector)
if result is None:
outputs.append(None)

View File

@ -63,7 +63,7 @@ class RetrievalSetting(BaseModel):
Retrieval Setting.
"""
search_method: Literal["semantic_search", "keyword_search", "fulltext_search", "hybrid_search"]
search_method: Literal["semantic_search", "keyword_search", "full_text_search", "hybrid_search"]
top_k: int
score_threshold: float | None = 0.5
score_threshold_enabled: bool = False

View File

@ -1,3 +1,4 @@
import contextlib
import json
import logging
from collections.abc import Callable, Generator, Mapping, Sequence
@ -127,11 +128,13 @@ class LoopNode(Node):
try:
reach_break_condition = False
if break_conditions:
_, _, reach_break_condition = condition_processor.process_conditions(
variable_pool=self.graph_runtime_state.variable_pool,
conditions=break_conditions,
operator=logical_operator,
)
with contextlib.suppress(ValueError):
_, _, reach_break_condition = condition_processor.process_conditions(
variable_pool=self.graph_runtime_state.variable_pool,
conditions=break_conditions,
operator=logical_operator,
)
if reach_break_condition:
loop_count = 0
cost_tokens = 0
@ -295,42 +298,11 @@ class LoopNode(Node):
variable_mapping = {}
# init graph
from core.workflow.entities import GraphInitParams, GraphRuntimeState, VariablePool
from core.workflow.graph import Graph
from core.workflow.nodes.node_factory import DifyNodeFactory
# Extract loop node IDs statically from graph_config
# Create minimal GraphInitParams for static analysis
graph_init_params = GraphInitParams(
tenant_id="",
app_id="",
workflow_id="",
graph_config=graph_config,
user_id="",
user_from="",
invoke_from="",
call_depth=0,
)
loop_node_ids = cls._extract_loop_node_ids_from_config(graph_config, node_id)
# Create minimal GraphRuntimeState for static analysis
graph_runtime_state = GraphRuntimeState(
variable_pool=VariablePool(),
start_at=0,
)
# Create node factory for static analysis
node_factory = DifyNodeFactory(graph_init_params=graph_init_params, graph_runtime_state=graph_runtime_state)
loop_graph = Graph.init(
graph_config=graph_config,
node_factory=node_factory,
root_node_id=typed_node_data.start_node_id,
)
if not loop_graph:
raise ValueError("loop graph not found")
# Get node configs from graph_config instead of non-existent node_id_config_mapping
# Get node configs from graph_config
node_configs = {node["id"]: node for node in graph_config.get("nodes", []) if "id" in node}
for sub_node_id, sub_node_config in node_configs.items():
if sub_node_config.get("data", {}).get("loop_id") != node_id:
@ -371,12 +343,35 @@ class LoopNode(Node):
variable_mapping[f"{node_id}.{loop_variable.label}"] = selector
# remove variable out from loop
variable_mapping = {
key: value for key, value in variable_mapping.items() if value[0] not in loop_graph.node_ids
}
variable_mapping = {key: value for key, value in variable_mapping.items() if value[0] not in loop_node_ids}
return variable_mapping
@classmethod
def _extract_loop_node_ids_from_config(cls, graph_config: Mapping[str, Any], loop_node_id: str) -> set[str]:
"""
Extract node IDs that belong to a specific loop from graph configuration.
This method statically analyzes the graph configuration to find all nodes
that are part of the specified loop, without creating actual node instances.
:param graph_config: the complete graph configuration
:param loop_node_id: the ID of the loop node
:return: set of node IDs that belong to the loop
"""
loop_node_ids = set()
# Find all nodes that belong to this loop
nodes = graph_config.get("nodes", [])
for node in nodes:
node_data = node.get("data", {})
if node_data.get("loop_id") == loop_node_id:
node_id = node.get("id")
if node_id:
loop_node_ids.add(node_id)
return loop_node_ids
@staticmethod
def _get_segment_for_constant(var_type: SegmentType, original_value: Any) -> Segment:
"""Get the appropriate segment type for a constant value."""

View File

@ -33,6 +33,7 @@ file_fields = {
"created_by": fields.String,
"created_at": TimestampField,
"preview_url": fields.String,
"source_url": fields.String,
}

View File

@ -1,10 +1,32 @@
import psycogreen.gevent as pscycogreen_gevent # type: ignore
from gevent import events as gevent_events
from grpc.experimental import gevent as grpc_gevent # type: ignore
# NOTE(QuantumGhost): here we cannot use post_fork to patch gRPC, as
# grpc_gevent.init_gevent must be called after patching stdlib.
# Gunicorn calls `post_init` before applying monkey patch.
# Use `post_init` to setup gRPC gevent support would cause deadlock and
# some other weird issues.
#
# ref:
# - https://github.com/grpc/grpc/blob/62533ea13879d6ee95c6fda11ec0826ca822c9dd/src/python/grpcio/grpc/experimental/gevent.py
# - https://github.com/gevent/gevent/issues/2060#issuecomment-3016768668
# - https://github.com/benoitc/gunicorn/blob/master/gunicorn/arbiter.py#L607-L613
def post_fork(server, worker):
def post_patch(event):
# this function is only called for gevent worker.
# from gevent docs (https://www.gevent.org/api/gevent.monkey.html):
# You can also subscribe to the events to provide additional patching beyond what gevent distributes, either for
# additional standard library modules, or for third-party packages. The suggested time to do this patching is in
# the subscriber for gevent.events.GeventDidPatchBuiltinModulesEvent.
if not isinstance(event, gevent_events.GeventDidPatchBuiltinModulesEvent):
return
# grpc gevent
grpc_gevent.init_gevent()
server.log.info("gRPC patched with gevent.")
print("gRPC patched with gevent.", flush=True) # noqa: T201
pscycogreen_gevent.patch_psycopg()
server.log.info("psycopg2 patched with gevent.")
print("psycopg2 patched with gevent.", flush=True) # noqa: T201
gevent_events.subscribers.append(post_patch)

View File

@ -1,7 +1,7 @@
import urllib.parse
from dataclasses import dataclass
import requests
import httpx
@dataclass
@ -58,7 +58,7 @@ class GitHubOAuth(OAuth):
"redirect_uri": self.redirect_uri,
}
headers = {"Accept": "application/json"}
response = requests.post(self._TOKEN_URL, data=data, headers=headers)
response = httpx.post(self._TOKEN_URL, data=data, headers=headers)
response_json = response.json()
access_token = response_json.get("access_token")
@ -70,11 +70,11 @@ class GitHubOAuth(OAuth):
def get_raw_user_info(self, token: str):
headers = {"Authorization": f"token {token}"}
response = requests.get(self._USER_INFO_URL, headers=headers)
response = httpx.get(self._USER_INFO_URL, headers=headers)
response.raise_for_status()
user_info = response.json()
email_response = requests.get(self._EMAIL_INFO_URL, headers=headers)
email_response = httpx.get(self._EMAIL_INFO_URL, headers=headers)
email_info = email_response.json()
primary_email: dict = next((email for email in email_info if email["primary"] == True), {})
@ -112,7 +112,7 @@ class GoogleOAuth(OAuth):
"redirect_uri": self.redirect_uri,
}
headers = {"Accept": "application/json"}
response = requests.post(self._TOKEN_URL, data=data, headers=headers)
response = httpx.post(self._TOKEN_URL, data=data, headers=headers)
response_json = response.json()
access_token = response_json.get("access_token")
@ -124,7 +124,7 @@ class GoogleOAuth(OAuth):
def get_raw_user_info(self, token: str):
headers = {"Authorization": f"Bearer {token}"}
response = requests.get(self._USER_INFO_URL, headers=headers)
response = httpx.get(self._USER_INFO_URL, headers=headers)
response.raise_for_status()
return response.json()

View File

@ -1,7 +1,7 @@
import urllib.parse
from typing import Any
import requests
import httpx
from flask_login import current_user
from sqlalchemy import select
@ -43,7 +43,7 @@ class NotionOAuth(OAuthDataSource):
data = {"code": code, "grant_type": "authorization_code", "redirect_uri": self.redirect_uri}
headers = {"Accept": "application/json"}
auth = (self.client_id, self.client_secret)
response = requests.post(self._TOKEN_URL, data=data, auth=auth, headers=headers)
response = httpx.post(self._TOKEN_URL, data=data, auth=auth, headers=headers)
response_json = response.json()
access_token = response_json.get("access_token")
@ -239,7 +239,7 @@ class NotionOAuth(OAuthDataSource):
"Notion-Version": "2022-06-28",
}
response = requests.post(url=self._NOTION_PAGE_SEARCH, json=data, headers=headers)
response = httpx.post(url=self._NOTION_PAGE_SEARCH, json=data, headers=headers)
response_json = response.json()
results.extend(response_json.get("results", []))
@ -254,7 +254,7 @@ class NotionOAuth(OAuthDataSource):
"Authorization": f"Bearer {access_token}",
"Notion-Version": "2022-06-28",
}
response = requests.get(url=f"{self._NOTION_BLOCK_SEARCH}/{block_id}", headers=headers)
response = httpx.get(url=f"{self._NOTION_BLOCK_SEARCH}/{block_id}", headers=headers)
response_json = response.json()
if response.status_code != 200:
message = response_json.get("message", "unknown error")
@ -270,7 +270,7 @@ class NotionOAuth(OAuthDataSource):
"Authorization": f"Bearer {access_token}",
"Notion-Version": "2022-06-28",
}
response = requests.get(url=self._NOTION_BOT_USER, headers=headers)
response = httpx.get(url=self._NOTION_BOT_USER, headers=headers)
response_json = response.json()
if "object" in response_json and response_json["object"] == "user":
user_type = response_json["type"]
@ -294,7 +294,7 @@ class NotionOAuth(OAuthDataSource):
"Authorization": f"Bearer {access_token}",
"Notion-Version": "2022-06-28",
}
response = requests.post(url=self._NOTION_PAGE_SEARCH, json=data, headers=headers)
response = httpx.post(url=self._NOTION_PAGE_SEARCH, json=data, headers=headers)
response_json = response.json()
results.extend(response_json.get("results", []))

View File

@ -47,7 +47,7 @@ def upgrade():
sa.Column('plugin_id', sa.String(length=255), nullable=False),
sa.Column('auth_type', sa.String(length=255), nullable=False),
sa.Column('encrypted_credentials', postgresql.JSONB(astext_type=sa.Text()), nullable=False),
sa.Column('avatar_url', sa.String(length=255), nullable=True),
sa.Column('avatar_url', sa.Text(), nullable=True),
sa.Column('is_default', sa.Boolean(), server_default=sa.text('false'), nullable=False),
sa.Column('expires_at', sa.Integer(), server_default='-1', nullable=False),
sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False),

View File

@ -35,7 +35,7 @@ class DatasourceProvider(Base):
plugin_id: Mapped[str] = db.Column(db.String(255), nullable=False)
auth_type: Mapped[str] = db.Column(db.String(255), nullable=False)
encrypted_credentials: Mapped[dict] = db.Column(JSONB, nullable=False)
avatar_url: Mapped[str] = db.Column(db.String(255), nullable=True, default="default")
avatar_url: Mapped[str] = db.Column(db.Text, nullable=True, default="default")
is_default: Mapped[bool] = db.Column(db.Boolean, nullable=False, server_default=db.text("false"))
expires_at: Mapped[int] = db.Column(db.Integer, nullable=False, server_default="-1")

View File

@ -1,6 +1,6 @@
[project]
name = "dify-api"
version = "2.0.0-beta2"
version = "1.9.0"
requires-python = ">=3.11,<3.13"
dependencies = [

View File

@ -1,6 +1,6 @@
import json
import requests
import httpx
from services.auth.api_key_auth_base import ApiKeyAuthBase
@ -36,7 +36,7 @@ class FirecrawlAuth(ApiKeyAuthBase):
return {"Content-Type": "application/json", "Authorization": f"Bearer {self.api_key}"}
def _post_request(self, url, data, headers):
return requests.post(url, headers=headers, json=data)
return httpx.post(url, headers=headers, json=data)
def _handle_error(self, response):
if response.status_code in {402, 409, 500}:

View File

@ -1,6 +1,6 @@
import json
import requests
import httpx
from services.auth.api_key_auth_base import ApiKeyAuthBase
@ -31,7 +31,7 @@ class JinaAuth(ApiKeyAuthBase):
return {"Content-Type": "application/json", "Authorization": f"Bearer {self.api_key}"}
def _post_request(self, url, data, headers):
return requests.post(url, headers=headers, json=data)
return httpx.post(url, headers=headers, json=data)
def _handle_error(self, response):
if response.status_code in {402, 409, 500}:

View File

@ -1,6 +1,6 @@
import json
import requests
import httpx
from services.auth.api_key_auth_base import ApiKeyAuthBase
@ -31,7 +31,7 @@ class JinaAuth(ApiKeyAuthBase):
return {"Content-Type": "application/json", "Authorization": f"Bearer {self.api_key}"}
def _post_request(self, url, data, headers):
return requests.post(url, headers=headers, json=data)
return httpx.post(url, headers=headers, json=data)
def _handle_error(self, response):
if response.status_code in {402, 409, 500}:

View File

@ -1,7 +1,7 @@
import json
from urllib.parse import urljoin
import requests
import httpx
from services.auth.api_key_auth_base import ApiKeyAuthBase
@ -31,7 +31,7 @@ class WatercrawlAuth(ApiKeyAuthBase):
return {"Content-Type": "application/json", "X-API-KEY": self.api_key}
def _get_request(self, url, headers):
return requests.get(url, headers=headers)
return httpx.get(url, headers=headers)
def _handle_error(self, response):
if response.status_code in {402, 409, 500}:

View File

@ -83,7 +83,7 @@ class RetrievalSetting(BaseModel):
Retrieval Setting.
"""
search_method: Literal["semantic_search", "fulltext_search", "keyword_search", "hybrid_search"]
search_method: Literal["semantic_search", "full_text_search", "keyword_search", "hybrid_search"]
top_k: int
score_threshold: float | None = 0.5
score_threshold_enabled: bool = False

View File

@ -1,6 +1,6 @@
import os
import requests
import httpx
class OperationService:
@ -12,7 +12,7 @@ class OperationService:
headers = {"Content-Type": "application/json", "Billing-Api-Secret-Key": cls.secret_key}
url = f"{cls.base_url}{endpoint}"
response = requests.request(method, url, json=json, params=params, headers=headers)
response = httpx.request(method, url, json=json, params=params, headers=headers)
return response.json()

View File

@ -3,7 +3,7 @@ import json
from dataclasses import dataclass
from typing import Any
import requests
import httpx
from flask_login import current_user
from core.helper import encrypter
@ -216,7 +216,7 @@ class WebsiteService:
@classmethod
def _crawl_with_jinareader(cls, request: CrawlRequest, api_key: str) -> dict[str, Any]:
if not request.options.crawl_sub_pages:
response = requests.get(
response = httpx.get(
f"https://r.jina.ai/{request.url}",
headers={"Accept": "application/json", "Authorization": f"Bearer {api_key}"},
)
@ -224,7 +224,7 @@ class WebsiteService:
raise ValueError("Failed to crawl:")
return {"status": "active", "data": response.json().get("data")}
else:
response = requests.post(
response = httpx.post(
"https://adaptivecrawl-kir3wx7b3a-uc.a.run.app",
json={
"url": request.url,
@ -287,7 +287,7 @@ class WebsiteService:
@classmethod
def _get_jinareader_status(cls, job_id: str, api_key: str) -> dict[str, Any]:
response = requests.post(
response = httpx.post(
"https://adaptivecrawlstatus-kir3wx7b3a-uc.a.run.app",
headers={"Content-Type": "application/json", "Authorization": f"Bearer {api_key}"},
json={"taskId": job_id},
@ -303,7 +303,7 @@ class WebsiteService:
}
if crawl_status_data["status"] == "completed":
response = requests.post(
response = httpx.post(
"https://adaptivecrawlstatus-kir3wx7b3a-uc.a.run.app",
headers={"Content-Type": "application/json", "Authorization": f"Bearer {api_key}"},
json={"taskId": job_id, "urls": list(data.get("processed", {}).keys())},
@ -362,7 +362,7 @@ class WebsiteService:
@classmethod
def _get_jinareader_url_data(cls, job_id: str, url: str, api_key: str) -> dict[str, Any] | None:
if not job_id:
response = requests.get(
response = httpx.get(
f"https://r.jina.ai/{url}",
headers={"Accept": "application/json", "Authorization": f"Bearer {api_key}"},
)
@ -371,7 +371,7 @@ class WebsiteService:
return dict(response.json().get("data", {}))
else:
# Get crawl status first
status_response = requests.post(
status_response = httpx.post(
"https://adaptivecrawlstatus-kir3wx7b3a-uc.a.run.app",
headers={"Content-Type": "application/json", "Authorization": f"Bearer {api_key}"},
json={"taskId": job_id},
@ -381,7 +381,7 @@ class WebsiteService:
raise ValueError("Crawl job is not completed")
# Get processed data
data_response = requests.post(
data_response = httpx.post(
"https://adaptivecrawlstatus-kir3wx7b3a-uc.a.run.app",
headers={"Content-Type": "application/json", "Authorization": f"Bearer {api_key}"},
json={"taskId": job_id, "urls": list(status_data.get("processed", {}).keys())},

View File

@ -1,8 +1,8 @@
import os
from typing import Literal
import httpx
import pytest
import requests
from core.plugin.entities.plugin_daemon import PluginDaemonBasicResponse
from core.tools.entities.common_entities import I18nObject
@ -27,13 +27,11 @@ class MockedHttp:
@classmethod
def requests_request(
cls, method: Literal["GET", "POST", "PUT", "DELETE", "PATCH", "HEAD"], url: str, **kwargs
) -> requests.Response:
) -> httpx.Response:
"""
Mocked requests.request
Mocked httpx.request
"""
request = requests.PreparedRequest()
request.method = method
request.url = url
request = httpx.Request(method, url)
if url.endswith("/tools"):
content = PluginDaemonBasicResponse[list[ToolProviderEntity]](
code=0, message="success", data=cls.list_tools()
@ -41,8 +39,7 @@ class MockedHttp:
else:
raise ValueError("")
response = requests.Response()
response.status_code = 200
response = httpx.Response(status_code=200)
response.request = request
response._content = content.encode("utf-8")
return response
@ -54,7 +51,7 @@ MOCK_SWITCH = os.getenv("MOCK_SWITCH", "false").lower() == "true"
@pytest.fixture
def setup_http_mock(request, monkeypatch: pytest.MonkeyPatch):
if MOCK_SWITCH:
monkeypatch.setattr(requests, "request", MockedHttp.requests_request)
monkeypatch.setattr(httpx, "request", MockedHttp.requests_request)
def unpatch():
monkeypatch.undo()

View File

@ -6,7 +6,7 @@ Test Clickzetta integration in Docker environment
import os
import time
import requests
import httpx
from clickzetta import connect
@ -66,7 +66,7 @@ def test_dify_api():
max_retries = 30
for i in range(max_retries):
try:
response = requests.get(f"{base_url}/console/api/health")
response = httpx.get(f"{base_url}/console/api/health")
if response.status_code == 200:
print("✓ Dify API is ready")
break

View File

@ -173,7 +173,7 @@ class DifyTestContainers:
# Start Dify Plugin Daemon container for plugin management
# Dify Plugin Daemon provides plugin lifecycle management and execution
logger.info("Initializing Dify Plugin Daemon container...")
self.dify_plugin_daemon = DockerContainer(image="langgenius/dify-plugin-daemon:0.2.0-local")
self.dify_plugin_daemon = DockerContainer(image="langgenius/dify-plugin-daemon:0.3.0-local")
self.dify_plugin_daemon.with_exposed_ports(5002)
self.dify_plugin_daemon.env = {
"DB_HOST": db_host,

View File

@ -201,9 +201,9 @@ class TestOAuthCallback:
mock_db.session.rollback = MagicMock()
# Import the real requests module to create a proper exception
import requests
import httpx
request_exception = requests.exceptions.RequestException("OAuth error")
request_exception = httpx.RequestError("OAuth error")
request_exception.response = MagicMock()
request_exception.response.text = str(exception)

View File

@ -0,0 +1,120 @@
"""Tests for graph engine event handlers."""
from __future__ import annotations
from datetime import datetime
from core.workflow.entities import GraphRuntimeState, VariablePool
from core.workflow.enums import NodeExecutionType, NodeState, NodeType, WorkflowNodeExecutionStatus
from core.workflow.graph import Graph
from core.workflow.graph_engine.domain.graph_execution import GraphExecution
from core.workflow.graph_engine.event_management.event_handlers import EventHandler
from core.workflow.graph_engine.event_management.event_manager import EventManager
from core.workflow.graph_engine.graph_state_manager import GraphStateManager
from core.workflow.graph_engine.ready_queue.in_memory import InMemoryReadyQueue
from core.workflow.graph_engine.response_coordinator.coordinator import ResponseStreamCoordinator
from core.workflow.graph_events import NodeRunRetryEvent, NodeRunStartedEvent
from core.workflow.node_events import NodeRunResult
from core.workflow.nodes.base.entities import RetryConfig
class _StubEdgeProcessor:
"""Minimal edge processor stub for tests."""
class _StubErrorHandler:
"""Minimal error handler stub for tests."""
class _StubNode:
"""Simple node stub exposing the attributes needed by the state manager."""
def __init__(self, node_id: str) -> None:
self.id = node_id
self.state = NodeState.UNKNOWN
self.title = "Stub Node"
self.execution_type = NodeExecutionType.EXECUTABLE
self.error_strategy = None
self.retry_config = RetryConfig()
self.retry = False
def _build_event_handler(node_id: str) -> tuple[EventHandler, EventManager, GraphExecution]:
"""Construct an EventHandler with in-memory dependencies for testing."""
node = _StubNode(node_id)
graph = Graph(nodes={node_id: node}, edges={}, in_edges={}, out_edges={}, root_node=node)
variable_pool = VariablePool()
runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=0.0)
graph_execution = GraphExecution(workflow_id="test-workflow")
event_manager = EventManager()
state_manager = GraphStateManager(graph=graph, ready_queue=InMemoryReadyQueue())
response_coordinator = ResponseStreamCoordinator(variable_pool=variable_pool, graph=graph)
handler = EventHandler(
graph=graph,
graph_runtime_state=runtime_state,
graph_execution=graph_execution,
response_coordinator=response_coordinator,
event_collector=event_manager,
edge_processor=_StubEdgeProcessor(),
state_manager=state_manager,
error_handler=_StubErrorHandler(),
)
return handler, event_manager, graph_execution
def test_retry_does_not_emit_additional_start_event() -> None:
"""Ensure retry attempts do not produce duplicate start events."""
node_id = "test-node"
handler, event_manager, graph_execution = _build_event_handler(node_id)
execution_id = "exec-1"
node_type = NodeType.CODE
start_time = datetime.utcnow()
start_event = NodeRunStartedEvent(
id=execution_id,
node_id=node_id,
node_type=node_type,
node_title="Stub Node",
start_at=start_time,
)
handler.dispatch(start_event)
retry_event = NodeRunRetryEvent(
id=execution_id,
node_id=node_id,
node_type=node_type,
node_title="Stub Node",
start_at=start_time,
error="boom",
retry_index=1,
node_run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
error="boom",
error_type="TestError",
),
)
handler.dispatch(retry_event)
# Simulate the node starting execution again after retry
second_start_event = NodeRunStartedEvent(
id=execution_id,
node_id=node_id,
node_type=node_type,
node_title="Stub Node",
start_at=start_time,
)
handler.dispatch(second_start_event)
collected_types = [type(event) for event in event_manager._events] # type: ignore[attr-defined]
assert collected_types == [NodeRunStartedEvent, NodeRunRetryEvent]
node_execution = graph_execution.get_or_create_node_execution(node_id)
assert node_execution.retry_count == 1

View File

@ -10,11 +10,18 @@ import time
from hypothesis import HealthCheck, given, settings
from hypothesis import strategies as st
from core.workflow.enums import ErrorStrategy
from core.workflow.graph_engine import GraphEngine
from core.workflow.graph_engine.command_channels import InMemoryChannel
from core.workflow.graph_events import GraphRunStartedEvent, GraphRunSucceededEvent
from core.workflow.graph_events import (
GraphRunPartialSucceededEvent,
GraphRunStartedEvent,
GraphRunSucceededEvent,
)
from core.workflow.nodes.base.entities import DefaultValue, DefaultValueType
# Import the test framework from the new module
from .test_mock_config import MockConfigBuilder
from .test_table_runner import TableTestRunner, WorkflowRunner, WorkflowTestCase
@ -721,3 +728,39 @@ def test_event_sequence_validation_with_table_tests():
else:
assert result.event_sequence_match is True
assert result.success, f"Test {i + 1} failed: {result.event_mismatch_details or result.error}"
def test_graph_run_emits_partial_success_when_node_failure_recovered():
runner = TableTestRunner()
fixture_data = runner.workflow_runner.load_fixture("basic_chatflow")
mock_config = MockConfigBuilder().with_node_error("llm", "mock llm failure").build()
graph, graph_runtime_state = runner.workflow_runner.create_graph_from_fixture(
fixture_data=fixture_data,
query="hello",
use_mock_factory=True,
mock_config=mock_config,
)
llm_node = graph.nodes["llm"]
base_node_data = llm_node.get_base_node_data()
base_node_data.error_strategy = ErrorStrategy.DEFAULT_VALUE
base_node_data.default_value = [DefaultValue(key="text", value="fallback response", type=DefaultValueType.STRING)]
engine = GraphEngine(
workflow_id="test_workflow",
graph=graph,
graph_runtime_state=graph_runtime_state,
command_channel=InMemoryChannel(),
)
events = list(engine.run())
assert isinstance(events[-1], GraphRunPartialSucceededEvent)
partial_event = next(event for event in events if isinstance(event, GraphRunPartialSucceededEvent))
assert partial_event.exceptions_count == 1
assert partial_event.outputs.get("answer") == "fallback response"
assert not any(isinstance(event, GraphRunSucceededEvent) for event in events)

View File

@ -1,65 +0,0 @@
import pytest
pytest.skip(
"Retry functionality is part of Phase 2 enhanced error handling - not implemented in MVP of queue-based engine",
allow_module_level=True,
)
DEFAULT_VALUE_EDGE = [
{
"id": "start-source-node-target",
"source": "start",
"target": "node",
"sourceHandle": "source",
},
{
"id": "node-source-answer-target",
"source": "node",
"target": "answer",
"sourceHandle": "source",
},
]
def test_retry_default_value_partial_success():
"""retry default value node with partial success status"""
graph_config = {
"edges": DEFAULT_VALUE_EDGE,
"nodes": [
{"data": {"title": "start", "type": "start", "variables": []}, "id": "start"},
{"data": {"title": "answer", "type": "answer", "answer": "{{#node.result#}}"}, "id": "answer"},
ContinueOnErrorTestHelper.get_http_node(
"default-value",
[{"key": "result", "type": "string", "value": "http node got error response"}],
retry_config={"retry_config": {"max_retries": 2, "retry_interval": 1000, "retry_enabled": True}},
),
],
}
graph_engine = ContinueOnErrorTestHelper.create_test_graph_engine(graph_config)
events = list(graph_engine.run())
assert sum(1 for e in events if isinstance(e, NodeRunRetryEvent)) == 2
assert events[-1].outputs == {"answer": "http node got error response"}
assert any(isinstance(e, GraphRunPartialSucceededEvent) for e in events)
assert len(events) == 11
def test_retry_failed():
"""retry failed with success status"""
graph_config = {
"edges": DEFAULT_VALUE_EDGE,
"nodes": [
{"data": {"title": "start", "type": "start", "variables": []}, "id": "start"},
{"data": {"title": "answer", "type": "answer", "answer": "{{#node.result#}}"}, "id": "answer"},
ContinueOnErrorTestHelper.get_http_node(
None,
None,
retry_config={"retry_config": {"max_retries": 2, "retry_interval": 1000, "retry_enabled": True}},
),
],
}
graph_engine = ContinueOnErrorTestHelper.create_test_graph_engine(graph_config)
events = list(graph_engine.run())
assert sum(1 for e in events if isinstance(e, NodeRunRetryEvent)) == 2
assert any(isinstance(e, GraphRunFailedEvent) for e in events)
assert len(events) == 8

View File

@ -1,8 +1,8 @@
import urllib.parse
from unittest.mock import MagicMock, patch
import httpx
import pytest
import requests
from libs.oauth import GitHubOAuth, GoogleOAuth, OAuthUserInfo
@ -68,7 +68,7 @@ class TestGitHubOAuth(BaseOAuthTest):
({}, None, True),
],
)
@patch("requests.post")
@patch("httpx.post")
def test_should_retrieve_access_token(
self, mock_post, oauth, mock_response, response_data, expected_token, should_raise
):
@ -105,7 +105,7 @@ class TestGitHubOAuth(BaseOAuthTest):
),
],
)
@patch("requests.get")
@patch("httpx.get")
def test_should_retrieve_user_info_correctly(self, mock_get, oauth, user_data, email_data, expected_email):
user_response = MagicMock()
user_response.json.return_value = user_data
@ -121,11 +121,11 @@ class TestGitHubOAuth(BaseOAuthTest):
assert user_info.name == user_data["name"]
assert user_info.email == expected_email
@patch("requests.get")
@patch("httpx.get")
def test_should_handle_network_errors(self, mock_get, oauth):
mock_get.side_effect = requests.exceptions.RequestException("Network error")
mock_get.side_effect = httpx.RequestError("Network error")
with pytest.raises(requests.exceptions.RequestException):
with pytest.raises(httpx.RequestError):
oauth.get_raw_user_info("test_token")
@ -167,7 +167,7 @@ class TestGoogleOAuth(BaseOAuthTest):
({}, None, True),
],
)
@patch("requests.post")
@patch("httpx.post")
def test_should_retrieve_access_token(
self, mock_post, oauth, oauth_config, mock_response, response_data, expected_token, should_raise
):
@ -201,7 +201,7 @@ class TestGoogleOAuth(BaseOAuthTest):
({"sub": "123", "email": "test@example.com", "name": "Test User"}, ""), # Always returns empty string
],
)
@patch("requests.get")
@patch("httpx.get")
def test_should_retrieve_user_info_correctly(self, mock_get, oauth, mock_response, user_data, expected_name):
mock_response.json.return_value = user_data
mock_get.return_value = mock_response
@ -217,12 +217,12 @@ class TestGoogleOAuth(BaseOAuthTest):
@pytest.mark.parametrize(
"exception_type",
[
requests.exceptions.HTTPError,
requests.exceptions.ConnectionError,
requests.exceptions.Timeout,
httpx.HTTPError,
httpx.ConnectError,
httpx.TimeoutException,
],
)
@patch("requests.get")
@patch("httpx.get")
def test_should_handle_http_errors(self, mock_get, oauth, exception_type):
mock_response = MagicMock()
mock_response.raise_for_status.side_effect = exception_type("Error")

View File

@ -6,8 +6,8 @@ import json
from concurrent.futures import ThreadPoolExecutor
from unittest.mock import Mock, patch
import httpx
import pytest
import requests
from services.auth.api_key_auth_factory import ApiKeyAuthFactory
from services.auth.api_key_auth_service import ApiKeyAuthService
@ -26,7 +26,7 @@ class TestAuthIntegration:
self.watercrawl_credentials = {"auth_type": "x-api-key", "config": {"api_key": "wc_test_key_789"}}
@patch("services.auth.api_key_auth_service.db.session")
@patch("services.auth.firecrawl.firecrawl.requests.post")
@patch("services.auth.firecrawl.firecrawl.httpx.post")
@patch("services.auth.api_key_auth_service.encrypter.encrypt_token")
def test_end_to_end_auth_flow(self, mock_encrypt, mock_http, mock_session):
"""Test complete authentication flow: request → validation → encryption → storage"""
@ -47,7 +47,7 @@ class TestAuthIntegration:
mock_session.add.assert_called_once()
mock_session.commit.assert_called_once()
@patch("services.auth.firecrawl.firecrawl.requests.post")
@patch("services.auth.firecrawl.firecrawl.httpx.post")
def test_cross_component_integration(self, mock_http):
"""Test factory → provider → HTTP call integration"""
mock_http.return_value = self._create_success_response()
@ -97,7 +97,7 @@ class TestAuthIntegration:
assert "another_secret" not in factory_str
@patch("services.auth.api_key_auth_service.db.session")
@patch("services.auth.firecrawl.firecrawl.requests.post")
@patch("services.auth.firecrawl.firecrawl.httpx.post")
@patch("services.auth.api_key_auth_service.encrypter.encrypt_token")
def test_concurrent_creation_safety(self, mock_encrypt, mock_http, mock_session):
"""Test concurrent authentication creation safety"""
@ -142,31 +142,31 @@ class TestAuthIntegration:
with pytest.raises((ValueError, KeyError, TypeError, AttributeError)):
ApiKeyAuthFactory(AuthType.FIRECRAWL, invalid_input)
@patch("services.auth.firecrawl.firecrawl.requests.post")
@patch("services.auth.firecrawl.firecrawl.httpx.post")
def test_http_error_handling(self, mock_http):
"""Test proper HTTP error handling"""
mock_response = Mock()
mock_response.status_code = 401
mock_response.text = '{"error": "Unauthorized"}'
mock_response.raise_for_status.side_effect = requests.exceptions.HTTPError("Unauthorized")
mock_response.raise_for_status.side_effect = httpx.HTTPError("Unauthorized")
mock_http.return_value = mock_response
# PT012: Split into single statement for pytest.raises
factory = ApiKeyAuthFactory(AuthType.FIRECRAWL, self.firecrawl_credentials)
with pytest.raises((requests.exceptions.HTTPError, Exception)):
with pytest.raises((httpx.HTTPError, Exception)):
factory.validate_credentials()
@patch("services.auth.api_key_auth_service.db.session")
@patch("services.auth.firecrawl.firecrawl.requests.post")
@patch("services.auth.firecrawl.firecrawl.httpx.post")
def test_network_failure_recovery(self, mock_http, mock_session):
"""Test system recovery from network failures"""
mock_http.side_effect = requests.exceptions.RequestException("Network timeout")
mock_http.side_effect = httpx.RequestError("Network timeout")
mock_session.add = Mock()
mock_session.commit = Mock()
args = {"category": self.category, "provider": AuthType.FIRECRAWL, "credentials": self.firecrawl_credentials}
with pytest.raises(requests.exceptions.RequestException):
with pytest.raises(httpx.RequestError):
ApiKeyAuthService.create_provider_auth(self.tenant_id_1, args)
mock_session.commit.assert_not_called()

View File

@ -1,7 +1,7 @@
from unittest.mock import MagicMock, patch
import httpx
import pytest
import requests
from services.auth.firecrawl.firecrawl import FirecrawlAuth
@ -64,7 +64,7 @@ class TestFirecrawlAuth:
FirecrawlAuth(credentials)
assert str(exc_info.value) == expected_error
@patch("services.auth.firecrawl.firecrawl.requests.post")
@patch("services.auth.firecrawl.firecrawl.httpx.post")
def test_should_validate_valid_credentials_successfully(self, mock_post, auth_instance):
"""Test successful credential validation"""
mock_response = MagicMock()
@ -95,7 +95,7 @@ class TestFirecrawlAuth:
(500, "Internal server error"),
],
)
@patch("services.auth.firecrawl.firecrawl.requests.post")
@patch("services.auth.firecrawl.firecrawl.httpx.post")
def test_should_handle_http_errors(self, mock_post, status_code, error_message, auth_instance):
"""Test handling of various HTTP error codes"""
mock_response = MagicMock()
@ -115,7 +115,7 @@ class TestFirecrawlAuth:
(401, "Not JSON", True, "Expecting value"), # JSON decode error
],
)
@patch("services.auth.firecrawl.firecrawl.requests.post")
@patch("services.auth.firecrawl.firecrawl.httpx.post")
def test_should_handle_unexpected_errors(
self, mock_post, status_code, response_text, has_json_error, expected_error_contains, auth_instance
):
@ -134,13 +134,13 @@ class TestFirecrawlAuth:
@pytest.mark.parametrize(
("exception_type", "exception_message"),
[
(requests.ConnectionError, "Network error"),
(requests.Timeout, "Request timeout"),
(requests.ReadTimeout, "Read timeout"),
(requests.ConnectTimeout, "Connection timeout"),
(httpx.ConnectError, "Network error"),
(httpx.TimeoutException, "Request timeout"),
(httpx.ReadTimeout, "Read timeout"),
(httpx.ConnectTimeout, "Connection timeout"),
],
)
@patch("services.auth.firecrawl.firecrawl.requests.post")
@patch("services.auth.firecrawl.firecrawl.httpx.post")
def test_should_handle_network_errors(self, mock_post, exception_type, exception_message, auth_instance):
"""Test handling of various network-related errors including timeouts"""
mock_post.side_effect = exception_type(exception_message)
@ -162,7 +162,7 @@ class TestFirecrawlAuth:
FirecrawlAuth({"auth_type": "basic", "config": {"api_key": "super_secret_key_12345"}})
assert "super_secret_key_12345" not in str(exc_info.value)
@patch("services.auth.firecrawl.firecrawl.requests.post")
@patch("services.auth.firecrawl.firecrawl.httpx.post")
def test_should_use_custom_base_url_in_validation(self, mock_post):
"""Test that custom base URL is used in validation"""
mock_response = MagicMock()
@ -179,12 +179,12 @@ class TestFirecrawlAuth:
assert result is True
assert mock_post.call_args[0][0] == "https://custom.firecrawl.dev/v1/crawl"
@patch("services.auth.firecrawl.firecrawl.requests.post")
@patch("services.auth.firecrawl.firecrawl.httpx.post")
def test_should_handle_timeout_with_retry_suggestion(self, mock_post, auth_instance):
"""Test that timeout errors are handled gracefully with appropriate error message"""
mock_post.side_effect = requests.Timeout("The request timed out after 30 seconds")
mock_post.side_effect = httpx.TimeoutException("The request timed out after 30 seconds")
with pytest.raises(requests.Timeout) as exc_info:
with pytest.raises(httpx.TimeoutException) as exc_info:
auth_instance.validate_credentials()
# Verify the timeout exception is raised with original message

View File

@ -1,7 +1,7 @@
from unittest.mock import MagicMock, patch
import httpx
import pytest
import requests
from services.auth.jina.jina import JinaAuth
@ -35,7 +35,7 @@ class TestJinaAuth:
JinaAuth(credentials)
assert str(exc_info.value) == "No API key provided"
@patch("services.auth.jina.jina.requests.post")
@patch("services.auth.jina.jina.httpx.post")
def test_should_validate_valid_credentials_successfully(self, mock_post):
"""Test successful credential validation"""
mock_response = MagicMock()
@ -53,7 +53,7 @@ class TestJinaAuth:
json={"url": "https://example.com"},
)
@patch("services.auth.jina.jina.requests.post")
@patch("services.auth.jina.jina.httpx.post")
def test_should_handle_http_402_error(self, mock_post):
"""Test handling of 402 Payment Required error"""
mock_response = MagicMock()
@ -68,7 +68,7 @@ class TestJinaAuth:
auth.validate_credentials()
assert str(exc_info.value) == "Failed to authorize. Status code: 402. Error: Payment required"
@patch("services.auth.jina.jina.requests.post")
@patch("services.auth.jina.jina.httpx.post")
def test_should_handle_http_409_error(self, mock_post):
"""Test handling of 409 Conflict error"""
mock_response = MagicMock()
@ -83,7 +83,7 @@ class TestJinaAuth:
auth.validate_credentials()
assert str(exc_info.value) == "Failed to authorize. Status code: 409. Error: Conflict error"
@patch("services.auth.jina.jina.requests.post")
@patch("services.auth.jina.jina.httpx.post")
def test_should_handle_http_500_error(self, mock_post):
"""Test handling of 500 Internal Server Error"""
mock_response = MagicMock()
@ -98,7 +98,7 @@ class TestJinaAuth:
auth.validate_credentials()
assert str(exc_info.value) == "Failed to authorize. Status code: 500. Error: Internal server error"
@patch("services.auth.jina.jina.requests.post")
@patch("services.auth.jina.jina.httpx.post")
def test_should_handle_unexpected_error_with_text_response(self, mock_post):
"""Test handling of unexpected errors with text response"""
mock_response = MagicMock()
@ -114,7 +114,7 @@ class TestJinaAuth:
auth.validate_credentials()
assert str(exc_info.value) == "Failed to authorize. Status code: 403. Error: Forbidden"
@patch("services.auth.jina.jina.requests.post")
@patch("services.auth.jina.jina.httpx.post")
def test_should_handle_unexpected_error_without_text(self, mock_post):
"""Test handling of unexpected errors without text response"""
mock_response = MagicMock()
@ -130,15 +130,15 @@ class TestJinaAuth:
auth.validate_credentials()
assert str(exc_info.value) == "Unexpected error occurred while trying to authorize. Status code: 404"
@patch("services.auth.jina.jina.requests.post")
@patch("services.auth.jina.jina.httpx.post")
def test_should_handle_network_errors(self, mock_post):
"""Test handling of network connection errors"""
mock_post.side_effect = requests.ConnectionError("Network error")
mock_post.side_effect = httpx.ConnectError("Network error")
credentials = {"auth_type": "bearer", "config": {"api_key": "test_api_key_123"}}
auth = JinaAuth(credentials)
with pytest.raises(requests.ConnectionError):
with pytest.raises(httpx.ConnectError):
auth.validate_credentials()
def test_should_not_expose_api_key_in_error_messages(self):

View File

@ -1,7 +1,7 @@
from unittest.mock import MagicMock, patch
import httpx
import pytest
import requests
from services.auth.watercrawl.watercrawl import WatercrawlAuth
@ -64,7 +64,7 @@ class TestWatercrawlAuth:
WatercrawlAuth(credentials)
assert str(exc_info.value) == expected_error
@patch("services.auth.watercrawl.watercrawl.requests.get")
@patch("services.auth.watercrawl.watercrawl.httpx.get")
def test_should_validate_valid_credentials_successfully(self, mock_get, auth_instance):
"""Test successful credential validation"""
mock_response = MagicMock()
@ -87,7 +87,7 @@ class TestWatercrawlAuth:
(500, "Internal server error"),
],
)
@patch("services.auth.watercrawl.watercrawl.requests.get")
@patch("services.auth.watercrawl.watercrawl.httpx.get")
def test_should_handle_http_errors(self, mock_get, status_code, error_message, auth_instance):
"""Test handling of various HTTP error codes"""
mock_response = MagicMock()
@ -107,7 +107,7 @@ class TestWatercrawlAuth:
(401, "Not JSON", True, "Expecting value"), # JSON decode error
],
)
@patch("services.auth.watercrawl.watercrawl.requests.get")
@patch("services.auth.watercrawl.watercrawl.httpx.get")
def test_should_handle_unexpected_errors(
self, mock_get, status_code, response_text, has_json_error, expected_error_contains, auth_instance
):
@ -126,13 +126,13 @@ class TestWatercrawlAuth:
@pytest.mark.parametrize(
("exception_type", "exception_message"),
[
(requests.ConnectionError, "Network error"),
(requests.Timeout, "Request timeout"),
(requests.ReadTimeout, "Read timeout"),
(requests.ConnectTimeout, "Connection timeout"),
(httpx.ConnectError, "Network error"),
(httpx.TimeoutException, "Request timeout"),
(httpx.ReadTimeout, "Read timeout"),
(httpx.ConnectTimeout, "Connection timeout"),
],
)
@patch("services.auth.watercrawl.watercrawl.requests.get")
@patch("services.auth.watercrawl.watercrawl.httpx.get")
def test_should_handle_network_errors(self, mock_get, exception_type, exception_message, auth_instance):
"""Test handling of various network-related errors including timeouts"""
mock_get.side_effect = exception_type(exception_message)
@ -154,7 +154,7 @@ class TestWatercrawlAuth:
WatercrawlAuth({"auth_type": "bearer", "config": {"api_key": "super_secret_key_12345"}})
assert "super_secret_key_12345" not in str(exc_info.value)
@patch("services.auth.watercrawl.watercrawl.requests.get")
@patch("services.auth.watercrawl.watercrawl.httpx.get")
def test_should_use_custom_base_url_in_validation(self, mock_get):
"""Test that custom base URL is used in validation"""
mock_response = MagicMock()
@ -179,7 +179,7 @@ class TestWatercrawlAuth:
("https://app.watercrawl.dev//", "https://app.watercrawl.dev/api/v1/core/crawl-requests/"),
],
)
@patch("services.auth.watercrawl.watercrawl.requests.get")
@patch("services.auth.watercrawl.watercrawl.httpx.get")
def test_should_use_urljoin_for_url_construction(self, mock_get, base_url, expected_url):
"""Test that urljoin is used correctly for URL construction with various base URLs"""
mock_response = MagicMock()
@ -193,12 +193,12 @@ class TestWatercrawlAuth:
# Verify the correct URL was called
assert mock_get.call_args[0][0] == expected_url
@patch("services.auth.watercrawl.watercrawl.requests.get")
@patch("services.auth.watercrawl.watercrawl.httpx.get")
def test_should_handle_timeout_with_retry_suggestion(self, mock_get, auth_instance):
"""Test that timeout errors are handled gracefully with appropriate error message"""
mock_get.side_effect = requests.Timeout("The request timed out after 30 seconds")
mock_get.side_effect = httpx.TimeoutException("The request timed out after 30 seconds")
with pytest.raises(requests.Timeout) as exc_info:
with pytest.raises(httpx.TimeoutException) as exc_info:
auth_instance.validate_credentials()
# Verify the timeout exception is raised with original message

View File

@ -1,5 +1,5 @@
version = 1
revision = 2
revision = 3
requires-python = ">=3.11, <3.13"
resolution-markers = [
"python_full_version >= '3.12.4' and sys_platform == 'linux'",
@ -1273,7 +1273,7 @@ wheels = [
[[package]]
name = "dify-api"
version = "2.0.0b2"
version = "1.9.0"
source = { virtual = "." }
dependencies = [
{ name = "arize-phoenix-otel" },

View File

@ -2,7 +2,7 @@ x-shared-env: &shared-api-worker-env
services:
# API service
api:
image: langgenius/dify-api:2.0.0-beta.2
image: langgenius/dify-api:1.9.0
restart: always
environment:
# Use the shared environment variables.
@ -31,7 +31,7 @@ services:
# worker service
# The Celery worker for processing the queue.
worker:
image: langgenius/dify-api:2.0.0-beta.2
image: langgenius/dify-api:1.9.0
restart: always
environment:
# Use the shared environment variables.
@ -58,7 +58,7 @@ services:
# worker_beat service
# Celery beat for scheduling periodic tasks.
worker_beat:
image: langgenius/dify-api:2.0.0-beta.2
image: langgenius/dify-api:1.9.0
restart: always
environment:
# Use the shared environment variables.
@ -76,7 +76,7 @@ services:
# Frontend web application.
web:
image: langgenius/dify-web:2.0.0-beta.2
image: langgenius/dify-web:1.9.0
restart: always
environment:
CONSOLE_API_URL: ${CONSOLE_API_URL:-}
@ -177,7 +177,7 @@ services:
# plugin daemon
plugin_daemon:
image: langgenius/dify-plugin-daemon:0.3.0b1-local
image: langgenius/dify-plugin-daemon:0.3.0-local
restart: always
environment:
# Use the shared environment variables.

View File

@ -20,7 +20,17 @@ services:
ports:
- "${EXPOSE_POSTGRES_PORT:-5432}:5432"
healthcheck:
test: [ 'CMD', 'pg_isready', '-h', 'db', '-U', '${PGUSER:-postgres}', '-d', '${POSTGRES_DB:-dify}' ]
test:
[
"CMD",
"pg_isready",
"-h",
"db",
"-U",
"${PGUSER:-postgres}",
"-d",
"${POSTGRES_DB:-dify}",
]
interval: 1s
timeout: 3s
retries: 30
@ -41,7 +51,11 @@ services:
ports:
- "${EXPOSE_REDIS_PORT:-6379}:6379"
healthcheck:
test: [ 'CMD-SHELL', 'redis-cli -a ${REDIS_PASSWORD:-difyai123456} ping | grep -q PONG' ]
test:
[
"CMD-SHELL",
"redis-cli -a ${REDIS_PASSWORD:-difyai123456} ping | grep -q PONG",
]
# The DifySandbox
sandbox:
@ -65,13 +79,13 @@ services:
- ./volumes/sandbox/dependencies:/dependencies
- ./volumes/sandbox/conf:/conf
healthcheck:
test: [ "CMD", "curl", "-f", "http://localhost:8194/health" ]
test: ["CMD", "curl", "-f", "http://localhost:8194/health"]
networks:
- ssrf_proxy_network
# plugin daemon
plugin_daemon:
image: langgenius/dify-plugin-daemon:0.3.0b1-local
image: langgenius/dify-plugin-daemon:0.3.0-local
restart: always
env_file:
- ./middleware.env
@ -143,7 +157,12 @@ services:
volumes:
- ./ssrf_proxy/squid.conf.template:/etc/squid/squid.conf.template
- ./ssrf_proxy/docker-entrypoint.sh:/docker-entrypoint-mount.sh
entrypoint: [ "sh", "-c", "cp /docker-entrypoint-mount.sh /docker-entrypoint.sh && sed -i 's/\r$$//' /docker-entrypoint.sh && chmod +x /docker-entrypoint.sh && /docker-entrypoint.sh" ]
entrypoint:
[
"sh",
"-c",
"cp /docker-entrypoint-mount.sh /docker-entrypoint.sh && sed -i 's/\r$$//' /docker-entrypoint.sh && chmod +x /docker-entrypoint.sh && /docker-entrypoint.sh",
]
env_file:
- ./middleware.env
environment:

View File

@ -593,7 +593,7 @@ x-shared-env: &shared-api-worker-env
services:
# API service
api:
image: langgenius/dify-api:2.0.0-beta.2
image: langgenius/dify-api:1.9.0
restart: always
environment:
# Use the shared environment variables.
@ -622,7 +622,7 @@ services:
# worker service
# The Celery worker for processing the queue.
worker:
image: langgenius/dify-api:2.0.0-beta.2
image: langgenius/dify-api:1.9.0
restart: always
environment:
# Use the shared environment variables.
@ -649,7 +649,7 @@ services:
# worker_beat service
# Celery beat for scheduling periodic tasks.
worker_beat:
image: langgenius/dify-api:2.0.0-beta.2
image: langgenius/dify-api:1.9.0
restart: always
environment:
# Use the shared environment variables.
@ -667,7 +667,7 @@ services:
# Frontend web application.
web:
image: langgenius/dify-web:2.0.0-beta.2
image: langgenius/dify-web:1.9.0
restart: always
environment:
CONSOLE_API_URL: ${CONSOLE_API_URL:-}
@ -768,7 +768,7 @@ services:
# plugin daemon
plugin_daemon:
image: langgenius/dify-plugin-daemon:0.3.0b1-local
image: langgenius/dify-plugin-daemon:0.3.0-local
restart: always
environment:
# Use the shared environment variables.

View File

@ -1,4 +0,0 @@
GET /console/api/spec/schema-definitions
Host: cloud-rag.dify.dev
authorization: Bearer eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJ1c2VyX2lkIjoiNzExMDZhYTQtZWJlMC00NGMzLWI4NWYtMWQ4Mjc5ZTExOGZmIiwiZXhwIjoxNzU2MTkyNDE4LCJpc3MiOiJDTE9VRCIsInN1YiI6IkNvbnNvbGUgQVBJIFBhc3Nwb3J0In0.Yx_TMdWVXCp5YEoQ8WR90lRhHHKggxAQvEl5RUnkZuc
###

View File

@ -124,7 +124,7 @@ const useConfig = (id: string, payload: VariableAssignerNodeType) => {
const handleAddGroup = useCallback(() => {
let maxInGroupName = 1
inputs.advanced_settings.groups.forEach((item) => {
const match = item.group_name.match(/(\d+)$/)
const match = /(\d+)$/.exec(item.group_name)
if (match) {
const num = Number.parseInt(match[1], 10)
if (num > maxInGroupName)

View File

@ -1,12 +1,18 @@
@import "preflight.css";
@tailwind base;
@tailwind components;
@import '../../themes/light.css';
@import '../../themes/dark.css';
@import "../../themes/manual-light.css";
@import "../../themes/manual-dark.css";
@import "../components/base/button/index.css";
@import "../components/base/action-button/index.css";
@import "../components/base/modal/index.css";
@tailwind base;
@tailwind components;
html {
color-scheme: light;
}
@ -680,10 +686,6 @@ button:focus-within {
display: none;
}
@import "../components/base/button/index.css";
@import "../components/base/action-button/index.css";
@import "../components/base/modal/index.css";
@tailwind utilities;
@layer utilities {

View File

@ -91,12 +91,10 @@ const remoteImageURLs = [hasSetWebPrefix ? new URL(`${process.env.NEXT_PUBLIC_WE
/** @type {import('next').NextConfig} */
const nextConfig = {
basePath: process.env.NEXT_PUBLIC_BASE_PATH || '',
webpack: (config, { dev, isServer }) => {
if (dev) {
config.plugins.push(codeInspectorPlugin({ bundler: 'webpack' }))
}
return config
turbopack: {
rules: codeInspectorPlugin({
bundler: 'turbopack'
})
},
productionBrowserSourceMaps: false, // enable browser source map generation during the production build
// Configure pageExtensions to include md and mdx
@ -112,6 +110,10 @@ const nextConfig = {
})),
},
experimental: {
optimizePackageImports: [
'@remixicon/react',
'@heroicons/react'
],
},
// fix all before production. Now it slow the develop speed.
eslint: {

View File

@ -1,6 +1,6 @@
{
"name": "dify-web",
"version": "2.0.0-beta2",
"version": "1.9.0",
"private": true,
"packageManager": "pnpm@10.16.0",
"engines": {
@ -19,7 +19,7 @@
"and_qq >= 14.9"
],
"scripts": {
"dev": "cross-env NODE_OPTIONS='--inspect' next dev",
"dev": "cross-env NODE_OPTIONS='--inspect' next dev --turbopack",
"build": "next build",
"build:docker": "next build && node scripts/optimize-standalone.js",
"start": "cp -r .next/static .next/standalone/.next/static && cp -r public .next/standalone/public && cross-env PORT=$npm_config_port HOSTNAME=$npm_config_host node .next/standalone/server.js",
@ -203,7 +203,7 @@
"autoprefixer": "^10.4.20",
"babel-loader": "^10.0.0",
"bing-translate-api": "^4.0.2",
"code-inspector-plugin": "^0.18.1",
"code-inspector-plugin": "1.2.9",
"cross-env": "^7.0.3",
"eslint": "^9.35.0",
"eslint-config-next": "15.5.0",

View File

@ -519,8 +519,8 @@ importers:
specifier: ^4.0.2
version: 4.1.0
code-inspector-plugin:
specifier: ^0.18.1
version: 0.18.3
specifier: 1.2.9
version: 1.2.9
cross-env:
specifier: ^7.0.3
version: 7.0.3
@ -1372,6 +1372,24 @@ packages:
'@clack/prompts@0.11.0':
resolution: {integrity: sha512-pMN5FcrEw9hUkZA4f+zLlzivQSeQf5dRGJjSUbvVYDLvpKCdQx5OaknvKzgbtXOizhP+SJJJjqEbOe55uKKfAw==}
'@code-inspector/core@1.2.9':
resolution: {integrity: sha512-A1w+G73HlTB6S8X6sA6tT+ziWHTAcTyH+7FZ1Sgd3ZLXF/E/jT+hgRbKposjXMwxcbodRc6hBG6UyiV+VxwE6Q==}
'@code-inspector/esbuild@1.2.9':
resolution: {integrity: sha512-DuyfxGupV43CN8YElIqynAniBtE86i037+3OVJYrm3jlJscXzbV98/kOzvu+VJQQvElcDgpgD6C/aGmPvFEiUg==}
'@code-inspector/mako@1.2.9':
resolution: {integrity: sha512-8N+MHdr64AnthLB4v+YGe8/9bgog3BnkxIW/fqX5iVS0X06mF7X1pxfZOD2bABVtv1tW25lRtNs5AgvYJs0vpg==}
'@code-inspector/turbopack@1.2.9':
resolution: {integrity: sha512-UVOUbqU6rpi5eOkrFamKrdeSWb0/OFFJQBaxbgs1RK5V5f4/iVwC5KjO2wkjv8cOGU4EppLfBVSBI1ysOo8S5A==}
'@code-inspector/vite@1.2.9':
resolution: {integrity: sha512-saIokJ3o3SdrHEgTEg1fbbowbKfh7J4mYtu0i1mVfah1b1UfdCF/iFHTEJ6SADMiY47TeNZTg0TQWTlU1AWPww==}
'@code-inspector/webpack@1.2.9':
resolution: {integrity: sha512-9YEykVrOIc0zMV7pyTyZhCprjScjn6gPPmxb4/OQXKCrP2fAm+NB188rg0s95e4sM7U3qRUpPA4NUH5F7Ogo+g==}
'@cspotcode/source-map-support@0.8.1':
resolution: {integrity: sha512-IchNf6dN4tHoMFIn/7OE8LWZ19Y6q/67Bmf6vnGREv8RSbBVb9LPJxEcnwrcwX6ixSvaiGoomAUvu4YSxXrVgw==}
engines: {node: '>=12'}
@ -4425,11 +4443,8 @@ packages:
resolution: {integrity: sha512-QVb0dM5HvG+uaxitm8wONl7jltx8dqhfU33DcqtOZcLSVIKSDDLDi7+0LbAKiyI8hD9u42m2YxXSkMGWThaecQ==}
engines: {iojs: '>= 1.0.0', node: '>= 0.12.0'}
code-inspector-core@0.18.3:
resolution: {integrity: sha512-60pT2cPoguMTUYdN1MMpjoPUnuF0ud/u7M2y+Vqit/bniLEit9dySEWAVxLU/Ukc5ILrDeLKEttc6fCMl9RUrA==}
code-inspector-plugin@0.18.3:
resolution: {integrity: sha512-d9oJXZUsnvfTaQDwFmDNA2F+AR/TXIxWg1rr8KGcEskltR2prbZsfuu1z70EAn4khpx0smfi/PvIIwNJQ7FAMw==}
code-inspector-plugin@1.2.9:
resolution: {integrity: sha512-PGp/AQ03vaajimG9rn5+eQHGifrym5CSNLCViPtwzot7FM3MqEkGNqcvimH0FVuv3wDOcP5KvETAUSLf1BE3HA==}
collapse-white-space@2.1.0:
resolution: {integrity: sha512-loKTxY1zCOuG4j9f6EPnuyyYkf58RnhhWTvRoZEokgB+WbdXehfjFviyOVYkqzEWz1Q5kRiZdBYS5SwxbQYwzw==}
@ -5055,9 +5070,6 @@ packages:
esast-util-from-js@2.0.1:
resolution: {integrity: sha512-8Ja+rNJ0Lt56Pcf3TAmpBZjmx8ZcK5Ts4cAzIOjsjevg9oSXJnl6SUQ2EevU8tv3h6ZLWmoKL5H4fgWvdvfETw==}
esbuild-code-inspector-plugin@0.18.3:
resolution: {integrity: sha512-FaPt5eFMtW1oXMWqAcqfAJByNagP1V/R9dwDDLQO29JmryMF35+frskTqy+G53whmTaVi19+TCrFqhNbMZH5ZQ==}
esbuild-register@3.6.0:
resolution: {integrity: sha512-H2/S7Pm8a9CL1uhp9OvjwrBh5Pvx0H8qVOxNu8Wed9Y7qv56MPtq+GGM8RJpq6glYJn9Wspr8uw7l55uyinNeg==}
peerDependencies:
@ -6413,8 +6425,8 @@ packages:
resolution: {integrity: sha512-MbjN408fEndfiQXbFQ1vnd+1NoLDsnQW41410oQBXiyXDMYH5z505juWa4KUE1LqxRC7DgOgZDbKLxHIwm27hA==}
engines: {node: '>=0.10'}
launch-ide@1.0.1:
resolution: {integrity: sha512-U7qBxSNk774PxWq4XbmRe0ThiIstPoa4sMH/OGSYxrFVvg8x3biXcF1fsH6wasDpEmEXMdINUrQhBdwsSgKyMg==}
launch-ide@1.2.0:
resolution: {integrity: sha512-7nXSPQOt3b2JT52Ge8jp4miFcY+nrUEZxNLWBzrEfjmByDTb9b5ytqMSwGhsNwY6Cntwop+6n7rWIFN0+S8PTw==}
layout-base@1.0.2:
resolution: {integrity: sha512-8h2oVEZNktL4BH2JCOI90iD1yXwL6iNW7KcCKT2QZgQJR2vbqDsldCTPRU9NifTCqHZci57XvQQ15YTu+sTYPg==}
@ -8693,9 +8705,6 @@ packages:
vfile@6.0.3:
resolution: {integrity: sha512-KzIbH/9tXat2u30jf+smMwFCsno4wHVdNmzFyL+T/L3UGqqk6JKfVqOFOZEpZSHADH1k40ab6NUIXZq422ov3Q==}
vite-code-inspector-plugin@0.18.3:
resolution: {integrity: sha512-178H73vbDUHE+JpvfAfioUHlUr7qXCYIEa2YNXtzenFQGOjtae59P1jjcxGfa6pPHEnOoaitb13K+0qxwhi/WA==}
vm-browserify@1.1.2:
resolution: {integrity: sha512-2ham8XPWTONajOR0ohOKOHXkm3+gaBmGut3SRuu75xLd/RRaY6vqgh8NBYYk7+RW3u5AtzPQZG8F10LHkl0lAQ==}
@ -8754,9 +8763,6 @@ packages:
engines: {node: '>= 10.13.0'}
hasBin: true
webpack-code-inspector-plugin@0.18.3:
resolution: {integrity: sha512-3782rsJhBnRiw0IpR6EqnyGDQoiSq0CcGeLJ52rZXlszYCe8igXtcujq7OhI0byaivWQ1LW7sXKyMEoVpBhq0w==}
webpack-dev-middleware@6.1.3:
resolution: {integrity: sha512-A4ChP0Qj8oGociTs6UdlRUGANIGrCDL3y+pmQMc+dSsraXHCatFpmMey4mYELA+juqwUqwQsUgJJISXl1KWmiw==}
engines: {node: '>= 14.15.0'}
@ -9993,6 +9999,48 @@ snapshots:
picocolors: 1.1.1
sisteransi: 1.0.5
'@code-inspector/core@1.2.9':
dependencies:
'@vue/compiler-dom': 3.5.17
chalk: 4.1.2
dotenv: 16.6.1
launch-ide: 1.2.0
portfinder: 1.0.37
transitivePeerDependencies:
- supports-color
'@code-inspector/esbuild@1.2.9':
dependencies:
'@code-inspector/core': 1.2.9
transitivePeerDependencies:
- supports-color
'@code-inspector/mako@1.2.9':
dependencies:
'@code-inspector/core': 1.2.9
transitivePeerDependencies:
- supports-color
'@code-inspector/turbopack@1.2.9':
dependencies:
'@code-inspector/core': 1.2.9
'@code-inspector/webpack': 1.2.9
transitivePeerDependencies:
- supports-color
'@code-inspector/vite@1.2.9':
dependencies:
'@code-inspector/core': 1.2.9
chalk: 4.1.1
transitivePeerDependencies:
- supports-color
'@code-inspector/webpack@1.2.9':
dependencies:
'@code-inspector/core': 1.2.9
transitivePeerDependencies:
- supports-color
'@cspotcode/source-map-support@0.8.1':
dependencies:
'@jridgewell/trace-mapping': 0.3.9
@ -12799,7 +12847,7 @@ snapshots:
'@vue/compiler-core@3.5.17':
dependencies:
'@babel/parser': 7.28.0
'@babel/parser': 7.28.4
'@vue/shared': 3.5.17
entities: 4.5.0
estree-walker: 2.0.2
@ -13503,24 +13551,15 @@ snapshots:
co@4.6.0: {}
code-inspector-core@0.18.3:
code-inspector-plugin@1.2.9:
dependencies:
'@vue/compiler-dom': 3.5.17
'@code-inspector/core': 1.2.9
'@code-inspector/esbuild': 1.2.9
'@code-inspector/mako': 1.2.9
'@code-inspector/turbopack': 1.2.9
'@code-inspector/vite': 1.2.9
'@code-inspector/webpack': 1.2.9
chalk: 4.1.1
dotenv: 16.6.1
launch-ide: 1.0.1
portfinder: 1.0.37
transitivePeerDependencies:
- supports-color
code-inspector-plugin@0.18.3:
dependencies:
chalk: 4.1.1
code-inspector-core: 0.18.3
dotenv: 16.6.1
esbuild-code-inspector-plugin: 0.18.3
vite-code-inspector-plugin: 0.18.3
webpack-code-inspector-plugin: 0.18.3
transitivePeerDependencies:
- supports-color
@ -14160,12 +14199,6 @@ snapshots:
esast-util-from-estree: 2.0.0
vfile-message: 4.0.2
esbuild-code-inspector-plugin@0.18.3:
dependencies:
code-inspector-core: 0.18.3
transitivePeerDependencies:
- supports-color
esbuild-register@3.6.0(esbuild@0.25.0):
dependencies:
debug: 4.4.1
@ -16020,7 +16053,7 @@ snapshots:
dependencies:
language-subtag-registry: 0.3.23
launch-ide@1.0.1:
launch-ide@1.2.0:
dependencies:
chalk: 4.1.2
dotenv: 16.6.1
@ -18779,12 +18812,6 @@ snapshots:
'@types/unist': 3.0.3
vfile-message: 4.0.2
vite-code-inspector-plugin@0.18.3:
dependencies:
code-inspector-core: 0.18.3
transitivePeerDependencies:
- supports-color
vm-browserify@1.1.2: {}
void-elements@3.1.0: {}
@ -18855,12 +18882,6 @@ snapshots:
- bufferutil
- utf-8-validate
webpack-code-inspector-plugin@0.18.3:
dependencies:
code-inspector-core: 0.18.3
transitivePeerDependencies:
- supports-color
webpack-dev-middleware@6.1.3(webpack@5.100.2(esbuild@0.25.0)(uglify-js@3.19.3)):
dependencies:
colorette: 2.0.20

View File

@ -26,6 +26,9 @@
"paths": {
"@/*": [
"./*"
],
"~@/*": [
"./*"
]
}
},