mirror of https://github.com/langgenius/dify.git
Merge branch 'fix/full_text_search' into deploy/rag-dev
commit b55c354139

api/commands.py

@@ -1448,41 +1448,52 @@ def transform_datasource_credentials():
notion_credentials_tenant_mapping[tenant_id] = []
|
||||
notion_credentials_tenant_mapping[tenant_id].append(notion_credential)
|
||||
for tenant_id, notion_tenant_credentials in notion_credentials_tenant_mapping.items():
|
||||
# check notion plugin is installed
|
||||
installed_plugins = installer_manager.list_plugins(tenant_id)
|
||||
installed_plugins_ids = [plugin.plugin_id for plugin in installed_plugins]
|
||||
if notion_plugin_id not in installed_plugins_ids:
|
||||
if notion_plugin_unique_identifier:
|
||||
# install notion plugin
|
||||
PluginService.install_from_marketplace_pkg(tenant_id, [notion_plugin_unique_identifier])
|
||||
auth_count = 0
|
||||
for notion_tenant_credential in notion_tenant_credentials:
|
||||
auth_count += 1
|
||||
# get credential oauth params
|
||||
access_token = notion_tenant_credential.access_token
|
||||
# notion info
|
||||
notion_info = notion_tenant_credential.source_info
|
||||
workspace_id = notion_info.get("workspace_id")
|
||||
workspace_name = notion_info.get("workspace_name")
|
||||
workspace_icon = notion_info.get("workspace_icon")
|
||||
new_credentials = {
|
||||
"integration_secret": encrypter.encrypt_token(tenant_id, access_token),
|
||||
"workspace_id": workspace_id,
|
||||
"workspace_name": workspace_name,
|
||||
"workspace_icon": workspace_icon,
|
||||
}
|
||||
datasource_provider = DatasourceProvider(
|
||||
provider="notion_datasource",
|
||||
tenant_id=tenant_id,
|
||||
plugin_id=notion_plugin_id,
|
||||
auth_type=oauth_credential_type.value,
|
||||
encrypted_credentials=new_credentials,
|
||||
name=f"Auth {auth_count}",
|
||||
avatar_url=workspace_icon or "default",
|
||||
is_default=False,
|
||||
tenant = db.session.query(Tenant).filter_by(id=tenant_id).first()
|
||||
if not tenant:
|
||||
continue
|
||||
try:
|
||||
# check notion plugin is installed
|
||||
installed_plugins = installer_manager.list_plugins(tenant_id)
|
||||
installed_plugins_ids = [plugin.plugin_id for plugin in installed_plugins]
|
||||
if notion_plugin_id not in installed_plugins_ids:
|
||||
if notion_plugin_unique_identifier:
|
||||
# install notion plugin
|
||||
PluginService.install_from_marketplace_pkg(tenant_id, [notion_plugin_unique_identifier])
|
||||
auth_count = 0
|
||||
for notion_tenant_credential in notion_tenant_credentials:
|
||||
auth_count += 1
|
||||
# get credential oauth params
|
||||
access_token = notion_tenant_credential.access_token
|
||||
# notion info
|
||||
notion_info = notion_tenant_credential.source_info
|
||||
workspace_id = notion_info.get("workspace_id")
|
||||
workspace_name = notion_info.get("workspace_name")
|
||||
workspace_icon = notion_info.get("workspace_icon")
|
||||
new_credentials = {
|
||||
"integration_secret": encrypter.encrypt_token(tenant_id, access_token),
|
||||
"workspace_id": workspace_id,
|
||||
"workspace_name": workspace_name,
|
||||
"workspace_icon": workspace_icon,
|
||||
}
|
||||
datasource_provider = DatasourceProvider(
|
||||
provider="notion_datasource",
|
||||
tenant_id=tenant_id,
|
||||
plugin_id=notion_plugin_id,
|
||||
auth_type=oauth_credential_type.value,
|
||||
encrypted_credentials=new_credentials,
|
||||
name=f"Auth {auth_count}",
|
||||
avatar_url=workspace_icon or "default",
|
||||
is_default=False,
|
||||
)
|
||||
db.session.add(datasource_provider)
|
||||
deal_notion_count += 1
|
||||
except Exception as e:
|
||||
click.echo(
|
||||
click.style(
|
||||
f"Error transforming notion credentials: {str(e)}, tenant_id: {tenant_id}", fg="red"
|
||||
)
|
||||
)
|
||||
db.session.add(datasource_provider)
|
||||
deal_notion_count += 1
|
||||
continue
|
||||
db.session.commit()
|
||||
# deal firecrawl credentials
|
||||
deal_firecrawl_count = 0
|
||||
|
|
@@ -1495,37 +1506,48 @@ def transform_datasource_credentials():
|
|||
firecrawl_credentials_tenant_mapping[tenant_id] = []
|
||||
firecrawl_credentials_tenant_mapping[tenant_id].append(firecrawl_credential)
|
||||
for tenant_id, firecrawl_tenant_credentials in firecrawl_credentials_tenant_mapping.items():
|
||||
# check firecrawl plugin is installed
|
||||
installed_plugins = installer_manager.list_plugins(tenant_id)
|
||||
installed_plugins_ids = [plugin.plugin_id for plugin in installed_plugins]
|
||||
if firecrawl_plugin_id not in installed_plugins_ids:
|
||||
if firecrawl_plugin_unique_identifier:
|
||||
# install firecrawl plugin
|
||||
PluginService.install_from_marketplace_pkg(tenant_id, [firecrawl_plugin_unique_identifier])
|
||||
tenant = db.session.query(Tenant).filter_by(id=tenant_id).first()
|
||||
if not tenant:
|
||||
continue
|
||||
try:
|
||||
# check firecrawl plugin is installed
|
||||
installed_plugins = installer_manager.list_plugins(tenant_id)
|
||||
installed_plugins_ids = [plugin.plugin_id for plugin in installed_plugins]
|
||||
if firecrawl_plugin_id not in installed_plugins_ids:
|
||||
if firecrawl_plugin_unique_identifier:
|
||||
# install firecrawl plugin
|
||||
PluginService.install_from_marketplace_pkg(tenant_id, [firecrawl_plugin_unique_identifier])
|
||||
|
||||
auth_count = 0
|
||||
for firecrawl_tenant_credential in firecrawl_tenant_credentials:
|
||||
auth_count += 1
|
||||
# get credential api key
|
||||
credentials_json = json.loads(firecrawl_tenant_credential.credentials)
|
||||
api_key = credentials_json.get("config", {}).get("api_key")
|
||||
base_url = credentials_json.get("config", {}).get("base_url")
|
||||
new_credentials = {
|
||||
"firecrawl_api_key": api_key,
|
||||
"base_url": base_url,
|
||||
}
|
||||
datasource_provider = DatasourceProvider(
|
||||
provider="firecrawl",
|
||||
tenant_id=tenant_id,
|
||||
plugin_id=firecrawl_plugin_id,
|
||||
auth_type=api_key_credential_type.value,
|
||||
encrypted_credentials=new_credentials,
|
||||
name=f"Auth {auth_count}",
|
||||
avatar_url="default",
|
||||
is_default=False,
|
||||
auth_count = 0
|
||||
for firecrawl_tenant_credential in firecrawl_tenant_credentials:
|
||||
auth_count += 1
|
||||
# get credential api key
|
||||
credentials_json = json.loads(firecrawl_tenant_credential.credentials)
|
||||
api_key = credentials_json.get("config", {}).get("api_key")
|
||||
base_url = credentials_json.get("config", {}).get("base_url")
|
||||
new_credentials = {
|
||||
"firecrawl_api_key": api_key,
|
||||
"base_url": base_url,
|
||||
}
|
||||
datasource_provider = DatasourceProvider(
|
||||
provider="firecrawl",
|
||||
tenant_id=tenant_id,
|
||||
plugin_id=firecrawl_plugin_id,
|
||||
auth_type=api_key_credential_type.value,
|
||||
encrypted_credentials=new_credentials,
|
||||
name=f"Auth {auth_count}",
|
||||
avatar_url="default",
|
||||
is_default=False,
|
||||
)
|
||||
db.session.add(datasource_provider)
|
||||
deal_firecrawl_count += 1
|
||||
except Exception as e:
|
||||
click.echo(
|
||||
click.style(
|
||||
f"Error transforming firecrawl credentials: {str(e)}, tenant_id: {tenant_id}", fg="red"
|
||||
)
|
||||
)
|
||||
db.session.add(datasource_provider)
|
||||
deal_firecrawl_count += 1
|
||||
continue
|
||||
db.session.commit()
|
||||
# deal jina credentials
|
||||
deal_jina_count = 0
|
||||
|
|
@@ -1538,36 +1560,45 @@ def transform_datasource_credentials():
|
|||
jina_credentials_tenant_mapping[tenant_id] = []
|
||||
jina_credentials_tenant_mapping[tenant_id].append(jina_credential)
|
||||
for tenant_id, jina_tenant_credentials in jina_credentials_tenant_mapping.items():
|
||||
# check jina plugin is installed
|
||||
installed_plugins = installer_manager.list_plugins(tenant_id)
|
||||
installed_plugins_ids = [plugin.plugin_id for plugin in installed_plugins]
|
||||
if jina_plugin_id not in installed_plugins_ids:
|
||||
if jina_plugin_unique_identifier:
|
||||
# install jina plugin
|
||||
logger.debug("Installing Jina plugin %s", jina_plugin_unique_identifier)
|
||||
PluginService.install_from_marketplace_pkg(tenant_id, [jina_plugin_unique_identifier])
|
||||
tenant = db.session.query(Tenant).filter_by(id=tenant_id).first()
|
||||
if not tenant:
|
||||
continue
|
||||
try:
|
||||
# check jina plugin is installed
|
||||
installed_plugins = installer_manager.list_plugins(tenant_id)
|
||||
installed_plugins_ids = [plugin.plugin_id for plugin in installed_plugins]
|
||||
if jina_plugin_id not in installed_plugins_ids:
|
||||
if jina_plugin_unique_identifier:
|
||||
# install jina plugin
|
||||
logger.debug("Installing Jina plugin %s", jina_plugin_unique_identifier)
|
||||
PluginService.install_from_marketplace_pkg(tenant_id, [jina_plugin_unique_identifier])
|
||||
|
||||
auth_count = 0
|
||||
for jina_tenant_credential in jina_tenant_credentials:
|
||||
auth_count += 1
|
||||
# get credential api key
|
||||
credentials_json = json.loads(jina_tenant_credential.credentials)
|
||||
api_key = credentials_json.get("config", {}).get("api_key")
|
||||
new_credentials = {
|
||||
"integration_secret": api_key,
|
||||
}
|
||||
datasource_provider = DatasourceProvider(
|
||||
provider="jina",
|
||||
tenant_id=tenant_id,
|
||||
plugin_id=jina_plugin_id,
|
||||
auth_type=api_key_credential_type.value,
|
||||
encrypted_credentials=new_credentials,
|
||||
name=f"Auth {auth_count}",
|
||||
avatar_url="default",
|
||||
is_default=False,
|
||||
auth_count = 0
|
||||
for jina_tenant_credential in jina_tenant_credentials:
|
||||
auth_count += 1
|
||||
# get credential api key
|
||||
credentials_json = json.loads(jina_tenant_credential.credentials)
|
||||
api_key = credentials_json.get("config", {}).get("api_key")
|
||||
new_credentials = {
|
||||
"integration_secret": api_key,
|
||||
}
|
||||
datasource_provider = DatasourceProvider(
|
||||
provider="jina",
|
||||
tenant_id=tenant_id,
|
||||
plugin_id=jina_plugin_id,
|
||||
auth_type=api_key_credential_type.value,
|
||||
encrypted_credentials=new_credentials,
|
||||
name=f"Auth {auth_count}",
|
||||
avatar_url="default",
|
||||
is_default=False,
|
||||
)
|
||||
db.session.add(datasource_provider)
|
||||
deal_jina_count += 1
|
||||
except Exception as e:
|
||||
click.echo(
|
||||
click.style(f"Error transforming jina credentials: {str(e)}, tenant_id: {tenant_id}", fg="red")
|
||||
)
|
||||
db.session.add(datasource_provider)
|
||||
deal_jina_count += 1
|
||||
continue
|
||||
db.session.commit()
|
||||
except Exception as e:
|
||||
click.echo(click.style(f"Error parsing client params: {str(e)}", fg="red"))
@@ -5,7 +5,7 @@ import logging
import os
import time

import requests
import httpx

logger = logging.getLogger(__name__)


@@ -30,10 +30,10 @@ class NacosHttpClient:
        params = {}
        try:
            self._inject_auth_info(headers, params)
            response = requests.request(method, url="http://" + self.server + url, headers=headers, params=params)
            response = httpx.request(method, url="http://" + self.server + url, headers=headers, params=params)
            response.raise_for_status()
            return response.text
        except requests.RequestException as e:
        except httpx.RequestError as e:
            return f"Request to Nacos failed: {e}"

    def _inject_auth_info(self, headers: dict[str, str], params: dict[str, str], module: str = "config") -> None:

@@ -78,7 +78,7 @@ class NacosHttpClient:
        params = {"username": self.username, "password": self.password}
        url = "http://" + self.server + "/nacos/v1/auth/login"
        try:
            resp = requests.request("POST", url, headers=None, params=params)
            resp = httpx.request("POST", url, headers=None, params=params)
            resp.raise_for_status()
            response_data = resp.json()
            self.token = response_data.get("accessToken")
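
The hunks above swap requests for httpx in the Nacos client. The exception mapping used throughout this commit is requests.RequestException -> httpx.RequestError and requests.HTTPError -> httpx.HTTPStatusError; note that httpx.HTTPStatusError is not a subclass of httpx.RequestError (both derive from httpx.HTTPError), so a handler that needs both must catch them separately or catch httpx.HTTPError. A minimal sketch of the pattern, not part of the diff (the function name and URL are illustrative):

import httpx

def fetch_text(url: str) -> str:
    try:
        # httpx.request() mirrors requests.request(); raise_for_status() raises
        # httpx.HTTPStatusError on 4xx/5xx responses.
        response = httpx.request("GET", url, timeout=5)
        response.raise_for_status()
        return response.text
    except httpx.HTTPStatusError as e:
        # The server answered with an error status; a response object is available.
        return f"HTTP {e.response.status_code}: {e.response.text}"
    except httpx.RequestError as e:
        # Network-level failure (DNS, connect, read timeout, ...); no response object.
        return f"Request failed: {e}"
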
@@ -1,6 +1,6 @@
import logging

import requests
import httpx
from flask import current_app, redirect, request
from flask_login import current_user
from flask_restx import Resource, fields

@@ -119,7 +119,7 @@ class OAuthDataSourceBinding(Resource):
            return {"error": "Invalid code"}, 400
        try:
            oauth_provider.get_access_token(code)
        except requests.HTTPError as e:
        except httpx.HTTPStatusError as e:
            logger.exception(
                "An error occurred during the OAuthCallback process with %s: %s", provider, e.response.text
            )

@@ -152,7 +152,7 @@ class OAuthDataSourceSync(Resource):
            return {"error": "Invalid provider"}, 400
        try:
            oauth_provider.sync_data_source(binding_id)
        except requests.HTTPError as e:
        except httpx.HTTPStatusError as e:
            logger.exception(
                "An error occurred during the OAuthCallback process with %s: %s", provider, e.response.text
            )
@@ -1,6 +1,6 @@
import logging

import requests
import httpx
from flask import current_app, redirect, request
from flask_restx import Resource
from sqlalchemy import select

@@ -101,8 +101,10 @@ class OAuthCallback(Resource):
        try:
            token = oauth_provider.get_access_token(code)
            user_info = oauth_provider.get_user_info(token)
        except requests.RequestException as e:
            error_text = e.response.text if e.response else str(e)
        except httpx.RequestError as e:
            error_text = str(e)
            if isinstance(e, httpx.HTTPStatusError):
                error_text = e.response.text
            logger.exception("An error occurred during the OAuth process with %s: %s", provider, error_text)
            return {"error": "OAuth process failed"}, 400

@@ -1,7 +1,7 @@
import json
import logging

import requests
import httpx
from flask_restx import Resource, fields, reqparse
from packaging import version


@@ -57,7 +57,11 @@ class VersionApi(Resource):
            return result

        try:
            response = requests.get(check_update_url, {"current_version": args["current_version"]}, timeout=(3, 10))
            response = httpx.get(
                check_update_url,
                params={"current_version": args["current_version"]},
                timeout=httpx.Timeout(connect=3, read=10),
            )
        except Exception as error:
            logger.warning("Check update version error: %s.", str(error))
            result["version"] = args["current_version"]
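
The version check above replaces the requests tuple timeout (3, 10) with httpx.Timeout(connect=3, read=10). One caveat, stated as an assumption about current httpx releases rather than about this codebase: httpx.Timeout requires either a positional default or all four of connect/read/write/pool to be set explicitly, otherwise it raises a ValueError when constructed. A hedged sketch of the two accepted forms (URL and values are illustrative):

import httpx

# Either supply a default and override individual phases...
timeout = httpx.Timeout(10.0, connect=3.0)

# ...or spell out all four phases explicitly.
timeout = httpx.Timeout(connect=3.0, read=10.0, write=10.0, pool=10.0)

response = httpx.get("https://example.com/version", params={"current_version": "1.0.0"}, timeout=timeout)
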
@@ -79,29 +79,12 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
        if not app_record:
            raise ValueError("App not found")

        if self.application_generate_entity.single_iteration_run:
            # if only single iteration run is requested
            graph_runtime_state = GraphRuntimeState(
                variable_pool=VariablePool.empty(),
                start_at=time.time(),
            )
            graph, variable_pool = self._get_graph_and_variable_pool_of_single_iteration(
        if self.application_generate_entity.single_iteration_run or self.application_generate_entity.single_loop_run:
            # Handle single iteration or single loop run
            graph, variable_pool, graph_runtime_state = self._prepare_single_node_execution(
                workflow=self._workflow,
                node_id=self.application_generate_entity.single_iteration_run.node_id,
                user_inputs=dict(self.application_generate_entity.single_iteration_run.inputs),
                graph_runtime_state=graph_runtime_state,
            )
        elif self.application_generate_entity.single_loop_run:
            # if only single loop run is requested
            graph_runtime_state = GraphRuntimeState(
                variable_pool=VariablePool.empty(),
                start_at=time.time(),
            )
            graph, variable_pool = self._get_graph_and_variable_pool_of_single_loop(
                workflow=self._workflow,
                node_id=self.application_generate_entity.single_loop_run.node_id,
                user_inputs=dict(self.application_generate_entity.single_loop_run.inputs),
                graph_runtime_state=graph_runtime_state,
                single_iteration_run=self.application_generate_entity.single_iteration_run,
                single_loop_run=self.application_generate_entity.single_loop_run,
            )
        else:
            inputs = self.application_generate_entity.inputs
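
The runner hunks (advanced chat above, pipeline and workflow runners below) all collapse the separate single-iteration and single-loop branches into one call. A minimal sketch of the consolidated caller pattern, using only names that appear in this diff (the surrounding runner class is assumed):

entity = self.application_generate_entity

if entity.single_iteration_run or entity.single_loop_run:
    # One helper now builds the graph, the variable pool and the runtime state together,
    # so every node in the partial graph shares the same GraphRuntimeState instance.
    graph, variable_pool, graph_runtime_state = self._prepare_single_node_execution(
        workflow=self._workflow,
        single_iteration_run=entity.single_iteration_run,
        single_loop_run=entity.single_loop_run,
    )
else:
    inputs = self.application_generate_entity.inputs
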
@ -427,6 +427,9 @@ class PipelineGenerator(BaseAppGenerator):
|
|||
invoke_from=InvokeFrom.DEBUGGER,
|
||||
call_depth=0,
|
||||
workflow_execution_id=str(uuid.uuid4()),
|
||||
single_iteration_run=RagPipelineGenerateEntity.SingleIterationRunEntity(
|
||||
node_id=node_id, inputs=args["inputs"]
|
||||
),
|
||||
)
|
||||
contexts.plugin_tool_providers.set({})
|
||||
contexts.plugin_tool_providers_lock.set(threading.Lock())
|
||||
|
|
@ -465,6 +468,7 @@ class PipelineGenerator(BaseAppGenerator):
|
|||
workflow_node_execution_repository=workflow_node_execution_repository,
|
||||
streaming=streaming,
|
||||
variable_loader=var_loader,
|
||||
context=contextvars.copy_context(),
|
||||
)
|
||||
|
||||
def single_loop_generate(
|
||||
|
|
@ -559,6 +563,7 @@ class PipelineGenerator(BaseAppGenerator):
|
|||
workflow_node_execution_repository=workflow_node_execution_repository,
|
||||
streaming=streaming,
|
||||
variable_loader=var_loader,
|
||||
context=contextvars.copy_context(),
|
||||
)
|
||||
|
||||
def _generate_worker(
|
||||
|
|
|
|||
|
|
@ -86,29 +86,12 @@ class PipelineRunner(WorkflowBasedAppRunner):
|
|||
db.session.close()
|
||||
|
||||
# if only single iteration run is requested
|
||||
if self.application_generate_entity.single_iteration_run:
|
||||
graph_runtime_state = GraphRuntimeState(
|
||||
variable_pool=VariablePool.empty(),
|
||||
start_at=time.time(),
|
||||
)
|
||||
# if only single iteration run is requested
|
||||
graph, variable_pool = self._get_graph_and_variable_pool_of_single_iteration(
|
||||
if self.application_generate_entity.single_iteration_run or self.application_generate_entity.single_loop_run:
|
||||
# Handle single iteration or single loop run
|
||||
graph, variable_pool, graph_runtime_state = self._prepare_single_node_execution(
|
||||
workflow=workflow,
|
||||
node_id=self.application_generate_entity.single_iteration_run.node_id,
|
||||
user_inputs=self.application_generate_entity.single_iteration_run.inputs,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
)
|
||||
elif self.application_generate_entity.single_loop_run:
|
||||
graph_runtime_state = GraphRuntimeState(
|
||||
variable_pool=VariablePool.empty(),
|
||||
start_at=time.time(),
|
||||
)
|
||||
# if only single loop run is requested
|
||||
graph, variable_pool = self._get_graph_and_variable_pool_of_single_loop(
|
||||
workflow=workflow,
|
||||
node_id=self.application_generate_entity.single_loop_run.node_id,
|
||||
user_inputs=self.application_generate_entity.single_loop_run.inputs,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
single_iteration_run=self.application_generate_entity.single_iteration_run,
|
||||
single_loop_run=self.application_generate_entity.single_loop_run,
|
||||
)
|
||||
else:
|
||||
inputs = self.application_generate_entity.inputs
|
||||
|
|
|
|||
|
|
@ -51,30 +51,12 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
|
|||
app_config = self.application_generate_entity.app_config
|
||||
app_config = cast(WorkflowAppConfig, app_config)
|
||||
|
||||
# if only single iteration run is requested
|
||||
if self.application_generate_entity.single_iteration_run:
|
||||
# if only single iteration run is requested
|
||||
graph_runtime_state = GraphRuntimeState(
|
||||
variable_pool=VariablePool.empty(),
|
||||
start_at=time.time(),
|
||||
)
|
||||
graph, variable_pool = self._get_graph_and_variable_pool_of_single_iteration(
|
||||
# if only single iteration or single loop run is requested
|
||||
if self.application_generate_entity.single_iteration_run or self.application_generate_entity.single_loop_run:
|
||||
graph, variable_pool, graph_runtime_state = self._prepare_single_node_execution(
|
||||
workflow=self._workflow,
|
||||
node_id=self.application_generate_entity.single_iteration_run.node_id,
|
||||
user_inputs=self.application_generate_entity.single_iteration_run.inputs,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
)
|
||||
elif self.application_generate_entity.single_loop_run:
|
||||
# if only single loop run is requested
|
||||
graph_runtime_state = GraphRuntimeState(
|
||||
variable_pool=VariablePool.empty(),
|
||||
start_at=time.time(),
|
||||
)
|
||||
graph, variable_pool = self._get_graph_and_variable_pool_of_single_loop(
|
||||
workflow=self._workflow,
|
||||
node_id=self.application_generate_entity.single_loop_run.node_id,
|
||||
user_inputs=self.application_generate_entity.single_loop_run.inputs,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
single_iteration_run=self.application_generate_entity.single_iteration_run,
|
||||
single_loop_run=self.application_generate_entity.single_loop_run,
|
||||
)
|
||||
else:
|
||||
inputs = self.application_generate_entity.inputs
|
||||
|
|
|
|||
|
|
@@ -1,3 +1,4 @@
import time
from collections.abc import Mapping
from typing import Any, cast

@@ -119,15 +120,81 @@ class WorkflowBasedAppRunner:

        return graph

    def _get_graph_and_variable_pool_of_single_iteration(
    def _prepare_single_node_execution(
        self,
        workflow: Workflow,
        single_iteration_run: Any | None = None,
        single_loop_run: Any | None = None,
    ) -> tuple[Graph, VariablePool, GraphRuntimeState]:
        """
        Prepare graph, variable pool, and runtime state for single node execution
        (either single iteration or single loop).

        Args:
            workflow: The workflow instance
            single_iteration_run: SingleIterationRunEntity if running single iteration, None otherwise
            single_loop_run: SingleLoopRunEntity if running single loop, None otherwise

        Returns:
            A tuple containing (graph, variable_pool, graph_runtime_state)

        Raises:
            ValueError: If neither single_iteration_run nor single_loop_run is specified
        """
        # Create initial runtime state with variable pool containing environment variables
        graph_runtime_state = GraphRuntimeState(
            variable_pool=VariablePool(
                system_variables=SystemVariable.empty(),
                user_inputs={},
                environment_variables=workflow.environment_variables,
            ),
            start_at=time.time(),
        )

        # Determine which type of single node execution and get graph/variable_pool
        if single_iteration_run:
            graph, variable_pool = self._get_graph_and_variable_pool_of_single_iteration(
                workflow=workflow,
                node_id=single_iteration_run.node_id,
                user_inputs=dict(single_iteration_run.inputs),
                graph_runtime_state=graph_runtime_state,
            )
        elif single_loop_run:
            graph, variable_pool = self._get_graph_and_variable_pool_of_single_loop(
                workflow=workflow,
                node_id=single_loop_run.node_id,
                user_inputs=dict(single_loop_run.inputs),
                graph_runtime_state=graph_runtime_state,
            )
        else:
            raise ValueError("Neither single_iteration_run nor single_loop_run is specified")

        # Return the graph, variable_pool, and the same graph_runtime_state used during graph creation
        # This ensures all nodes in the graph reference the same GraphRuntimeState instance
        return graph, variable_pool, graph_runtime_state

    def _get_graph_and_variable_pool_for_single_node_run(
        self,
        workflow: Workflow,
        node_id: str,
        user_inputs: dict,
        user_inputs: dict[str, Any],
        graph_runtime_state: GraphRuntimeState,
        node_type_filter_key: str,  # 'iteration_id' or 'loop_id'
        node_type_label: str = "node",  # 'iteration' or 'loop' for error messages
    ) -> tuple[Graph, VariablePool]:
        """
        Get variable pool of single iteration
        Get graph and variable pool for single node execution (iteration or loop).

        Args:
            workflow: The workflow instance
            node_id: The node ID to execute
            user_inputs: User inputs for the node
            graph_runtime_state: The graph runtime state
            node_type_filter_key: The key to filter nodes ('iteration_id' or 'loop_id')
            node_type_label: Label for error messages ('iteration' or 'loop')

        Returns:
            A tuple containing (graph, variable_pool)
        """
        # fetch workflow graph
        graph_config = workflow.graph_dict
@ -145,18 +212,22 @@ class WorkflowBasedAppRunner:
|
|||
if not isinstance(graph_config.get("edges"), list):
|
||||
raise ValueError("edges in workflow graph must be a list")
|
||||
|
||||
# filter nodes only in iteration
|
||||
# filter nodes only in the specified node type (iteration or loop)
|
||||
main_node_config = next((n for n in graph_config.get("nodes", []) if n.get("id") == node_id), None)
|
||||
start_node_id = main_node_config.get("data", {}).get("start_node_id") if main_node_config else None
|
||||
node_configs = [
|
||||
node
|
||||
for node in graph_config.get("nodes", [])
|
||||
if node.get("id") == node_id or node.get("data", {}).get("iteration_id", "") == node_id
|
||||
if node.get("id") == node_id
|
||||
or node.get("data", {}).get(node_type_filter_key, "") == node_id
|
||||
or (start_node_id and node.get("id") == start_node_id)
|
||||
]
|
||||
|
||||
graph_config["nodes"] = node_configs
|
||||
|
||||
node_ids = [node.get("id") for node in node_configs]
|
||||
|
||||
# filter edges only in iteration
|
||||
# filter edges only in the specified node type
|
||||
edge_configs = [
|
||||
edge
|
||||
for edge in graph_config.get("edges", [])
|
||||
|
|
@ -190,30 +261,26 @@ class WorkflowBasedAppRunner:
|
|||
raise ValueError("graph not found in workflow")
|
||||
|
||||
# fetch node config from node id
|
||||
iteration_node_config = None
|
||||
target_node_config = None
|
||||
for node in node_configs:
|
||||
if node.get("id") == node_id:
|
||||
iteration_node_config = node
|
||||
target_node_config = node
|
||||
break
|
||||
|
||||
if not iteration_node_config:
|
||||
raise ValueError("iteration node id not found in workflow graph")
|
||||
if not target_node_config:
|
||||
raise ValueError(f"{node_type_label} node id not found in workflow graph")
|
||||
|
||||
# Get node class
|
||||
node_type = NodeType(iteration_node_config.get("data", {}).get("type"))
|
||||
node_version = iteration_node_config.get("data", {}).get("version", "1")
|
||||
node_type = NodeType(target_node_config.get("data", {}).get("type"))
|
||||
node_version = target_node_config.get("data", {}).get("version", "1")
|
||||
node_cls = NODE_TYPE_CLASSES_MAPPING[node_type][node_version]
|
||||
|
||||
# init variable pool
|
||||
variable_pool = VariablePool(
|
||||
system_variables=SystemVariable.empty(),
|
||||
user_inputs={},
|
||||
environment_variables=workflow.environment_variables,
|
||||
)
|
||||
# Use the variable pool from graph_runtime_state instead of creating a new one
|
||||
variable_pool = graph_runtime_state.variable_pool
|
||||
|
||||
try:
|
||||
variable_mapping = node_cls.extract_variable_selector_to_variable_mapping(
|
||||
graph_config=workflow.graph_dict, config=iteration_node_config
|
||||
graph_config=workflow.graph_dict, config=target_node_config
|
||||
)
|
||||
except NotImplementedError:
|
||||
variable_mapping = {}
|
||||
|
|
@ -234,120 +301,44 @@ class WorkflowBasedAppRunner:
|
|||
|
||||
return graph, variable_pool
|
||||
|
||||
def _get_graph_and_variable_pool_of_single_iteration(
|
||||
self,
|
||||
workflow: Workflow,
|
||||
node_id: str,
|
||||
user_inputs: dict[str, Any],
|
||||
graph_runtime_state: GraphRuntimeState,
|
||||
) -> tuple[Graph, VariablePool]:
|
||||
"""
|
||||
Get variable pool of single iteration
|
||||
"""
|
||||
return self._get_graph_and_variable_pool_for_single_node_run(
|
||||
workflow=workflow,
|
||||
node_id=node_id,
|
||||
user_inputs=user_inputs,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
node_type_filter_key="iteration_id",
|
||||
node_type_label="iteration",
|
||||
)
|
||||
|
||||
def _get_graph_and_variable_pool_of_single_loop(
|
||||
self,
|
||||
workflow: Workflow,
|
||||
node_id: str,
|
||||
user_inputs: dict,
|
||||
user_inputs: dict[str, Any],
|
||||
graph_runtime_state: GraphRuntimeState,
|
||||
) -> tuple[Graph, VariablePool]:
|
||||
"""
|
||||
Get variable pool of single loop
|
||||
"""
|
||||
# fetch workflow graph
|
||||
graph_config = workflow.graph_dict
|
||||
if not graph_config:
|
||||
raise ValueError("workflow graph not found")
|
||||
|
||||
graph_config = cast(dict[str, Any], graph_config)
|
||||
|
||||
if "nodes" not in graph_config or "edges" not in graph_config:
|
||||
raise ValueError("nodes or edges not found in workflow graph")
|
||||
|
||||
if not isinstance(graph_config.get("nodes"), list):
|
||||
raise ValueError("nodes in workflow graph must be a list")
|
||||
|
||||
if not isinstance(graph_config.get("edges"), list):
|
||||
raise ValueError("edges in workflow graph must be a list")
|
||||
|
||||
# filter nodes only in loop
|
||||
node_configs = [
|
||||
node
|
||||
for node in graph_config.get("nodes", [])
|
||||
if node.get("id") == node_id or node.get("data", {}).get("loop_id", "") == node_id
|
||||
]
|
||||
|
||||
graph_config["nodes"] = node_configs
|
||||
|
||||
node_ids = [node.get("id") for node in node_configs]
|
||||
|
||||
# filter edges only in loop
|
||||
edge_configs = [
|
||||
edge
|
||||
for edge in graph_config.get("edges", [])
|
||||
if (edge.get("source") is None or edge.get("source") in node_ids)
|
||||
and (edge.get("target") is None or edge.get("target") in node_ids)
|
||||
]
|
||||
|
||||
graph_config["edges"] = edge_configs
|
||||
|
||||
# Create required parameters for Graph.init
|
||||
graph_init_params = GraphInitParams(
|
||||
tenant_id=workflow.tenant_id,
|
||||
app_id=self._app_id,
|
||||
workflow_id=workflow.id,
|
||||
graph_config=graph_config,
|
||||
user_id="",
|
||||
user_from=UserFrom.ACCOUNT.value,
|
||||
invoke_from=InvokeFrom.SERVICE_API.value,
|
||||
call_depth=0,
|
||||
)
|
||||
|
||||
node_factory = DifyNodeFactory(
|
||||
graph_init_params=graph_init_params,
|
||||
return self._get_graph_and_variable_pool_for_single_node_run(
|
||||
workflow=workflow,
|
||||
node_id=node_id,
|
||||
user_inputs=user_inputs,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
node_type_filter_key="loop_id",
|
||||
node_type_label="loop",
|
||||
)
|
||||
|
||||
# init graph
|
||||
graph = Graph.init(graph_config=graph_config, node_factory=node_factory, root_node_id=node_id)
|
||||
|
||||
if not graph:
|
||||
raise ValueError("graph not found in workflow")
|
||||
|
||||
# fetch node config from node id
|
||||
loop_node_config = None
|
||||
for node in node_configs:
|
||||
if node.get("id") == node_id:
|
||||
loop_node_config = node
|
||||
break
|
||||
|
||||
if not loop_node_config:
|
||||
raise ValueError("loop node id not found in workflow graph")
|
||||
|
||||
# Get node class
|
||||
node_type = NodeType(loop_node_config.get("data", {}).get("type"))
|
||||
node_version = loop_node_config.get("data", {}).get("version", "1")
|
||||
node_cls = NODE_TYPE_CLASSES_MAPPING[node_type][node_version]
|
||||
|
||||
# init variable pool
|
||||
variable_pool = VariablePool(
|
||||
system_variables=SystemVariable.empty(),
|
||||
user_inputs={},
|
||||
environment_variables=workflow.environment_variables,
|
||||
)
|
||||
|
||||
try:
|
||||
variable_mapping = node_cls.extract_variable_selector_to_variable_mapping(
|
||||
graph_config=workflow.graph_dict, config=loop_node_config
|
||||
)
|
||||
except NotImplementedError:
|
||||
variable_mapping = {}
|
||||
load_into_variable_pool(
|
||||
self._variable_loader,
|
||||
variable_pool=variable_pool,
|
||||
variable_mapping=variable_mapping,
|
||||
user_inputs=user_inputs,
|
||||
)
|
||||
|
||||
WorkflowEntry.mapping_user_inputs_to_variable_pool(
|
||||
variable_mapping=variable_mapping,
|
||||
user_inputs=user_inputs,
|
||||
variable_pool=variable_pool,
|
||||
tenant_id=workflow.tenant_id,
|
||||
)
|
||||
|
||||
return graph, variable_pool
|
||||
|
||||
def _handle_event(self, workflow_entry: WorkflowEntry, event: GraphEngineEvent):
|
||||
"""
|
||||
Handle event
|
||||
|
|
|
|||
|
|
@ -205,16 +205,10 @@ class ProviderConfiguration(BaseModel):
|
|||
"""
|
||||
Get custom provider record.
|
||||
"""
|
||||
# get provider
|
||||
model_provider_id = ModelProviderID(self.provider.provider)
|
||||
provider_names = [self.provider.provider]
|
||||
if model_provider_id.is_langgenius():
|
||||
provider_names.append(model_provider_id.provider_name)
|
||||
|
||||
stmt = select(Provider).where(
|
||||
Provider.tenant_id == self.tenant_id,
|
||||
Provider.provider_type == ProviderType.CUSTOM.value,
|
||||
Provider.provider_name.in_(provider_names),
|
||||
Provider.provider_name.in_(self._get_provider_names()),
|
||||
)
|
||||
|
||||
return session.execute(stmt).scalar_one_or_none()
|
||||
|
|
@ -276,7 +270,7 @@ class ProviderConfiguration(BaseModel):
|
|||
"""
|
||||
stmt = select(ProviderCredential.id).where(
|
||||
ProviderCredential.tenant_id == self.tenant_id,
|
||||
ProviderCredential.provider_name == self.provider.provider,
|
||||
ProviderCredential.provider_name.in_(self._get_provider_names()),
|
||||
ProviderCredential.credential_name == credential_name,
|
||||
)
|
||||
if exclude_id:
|
||||
|
|
@ -324,7 +318,7 @@ class ProviderConfiguration(BaseModel):
|
|||
try:
|
||||
stmt = select(ProviderCredential).where(
|
||||
ProviderCredential.tenant_id == self.tenant_id,
|
||||
ProviderCredential.provider_name == self.provider.provider,
|
||||
ProviderCredential.provider_name.in_(self._get_provider_names()),
|
||||
ProviderCredential.id == credential_id,
|
||||
)
|
||||
credential_record = s.execute(stmt).scalar_one_or_none()
|
||||
|
|
@ -374,7 +368,7 @@ class ProviderConfiguration(BaseModel):
|
|||
session=session,
|
||||
query_factory=lambda: select(ProviderCredential).where(
|
||||
ProviderCredential.tenant_id == self.tenant_id,
|
||||
ProviderCredential.provider_name == self.provider.provider,
|
||||
ProviderCredential.provider_name.in_(self._get_provider_names()),
|
||||
),
|
||||
)
|
||||
|
||||
|
|
@ -387,7 +381,7 @@ class ProviderConfiguration(BaseModel):
|
|||
session=session,
|
||||
query_factory=lambda: select(ProviderModelCredential).where(
|
||||
ProviderModelCredential.tenant_id == self.tenant_id,
|
||||
ProviderModelCredential.provider_name == self.provider.provider,
|
||||
ProviderModelCredential.provider_name.in_(self._get_provider_names()),
|
||||
ProviderModelCredential.model_name == model,
|
||||
ProviderModelCredential.model_type == model_type.to_origin_model_type(),
|
||||
),
|
||||
|
|
@@ -423,6 +417,16 @@ class ProviderConfiguration(BaseModel):
            logger.warning("Error generating next credential name: %s", str(e))
            return "API KEY 1"

    def _get_provider_names(self):
        """
        The provider name might be stored in the database as either `openai` or `langgenius/openai/openai`.
        """
        model_provider_id = ModelProviderID(self.provider.provider)
        provider_names = [self.provider.provider]
        if model_provider_id.is_langgenius():
            provider_names.append(model_provider_id.provider_name)
        return provider_names

    def create_provider_credential(self, credentials: dict, credential_name: str | None):
        """
        Add custom provider credentials.
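
_get_provider_names exists because the same provider may be stored under its short name (openai) or its fully qualified plugin ID (langgenius/openai/openai); the queries in this file therefore filter with .in_() instead of an equality check. A sketch of the intended usage, reusing the model, session and select names from the hunks above rather than standing alone as a runnable script:

names = self._get_provider_names()
# e.g. ["langgenius/openai/openai", "openai"] for a first-party provider,
# or a single-element list for providers without a langgenius alias.
stmt = select(ProviderCredential).where(
    ProviderCredential.tenant_id == self.tenant_id,
    ProviderCredential.provider_name.in_(names),
)
credential = session.execute(stmt).scalar_one_or_none()
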
@ -501,7 +505,7 @@ class ProviderConfiguration(BaseModel):
|
|||
stmt = select(ProviderCredential).where(
|
||||
ProviderCredential.id == credential_id,
|
||||
ProviderCredential.tenant_id == self.tenant_id,
|
||||
ProviderCredential.provider_name == self.provider.provider,
|
||||
ProviderCredential.provider_name.in_(self._get_provider_names()),
|
||||
)
|
||||
|
||||
# Get the credential record to update
|
||||
|
|
@ -554,7 +558,7 @@ class ProviderConfiguration(BaseModel):
|
|||
# Find all load balancing configs that use this credential_id
|
||||
stmt = select(LoadBalancingModelConfig).where(
|
||||
LoadBalancingModelConfig.tenant_id == self.tenant_id,
|
||||
LoadBalancingModelConfig.provider_name == self.provider.provider,
|
||||
LoadBalancingModelConfig.provider_name.in_(self._get_provider_names()),
|
||||
LoadBalancingModelConfig.credential_id == credential_id,
|
||||
LoadBalancingModelConfig.credential_source_type == credential_source,
|
||||
)
|
||||
|
|
@ -591,7 +595,7 @@ class ProviderConfiguration(BaseModel):
|
|||
stmt = select(ProviderCredential).where(
|
||||
ProviderCredential.id == credential_id,
|
||||
ProviderCredential.tenant_id == self.tenant_id,
|
||||
ProviderCredential.provider_name == self.provider.provider,
|
||||
ProviderCredential.provider_name.in_(self._get_provider_names()),
|
||||
)
|
||||
|
||||
# Get the credential record to update
|
||||
|
|
@ -602,7 +606,7 @@ class ProviderConfiguration(BaseModel):
|
|||
# Check if this credential is used in load balancing configs
|
||||
lb_stmt = select(LoadBalancingModelConfig).where(
|
||||
LoadBalancingModelConfig.tenant_id == self.tenant_id,
|
||||
LoadBalancingModelConfig.provider_name == self.provider.provider,
|
||||
LoadBalancingModelConfig.provider_name.in_(self._get_provider_names()),
|
||||
LoadBalancingModelConfig.credential_id == credential_id,
|
||||
LoadBalancingModelConfig.credential_source_type == "provider",
|
||||
)
|
||||
|
|
@ -624,7 +628,7 @@ class ProviderConfiguration(BaseModel):
|
|||
# if this is the last credential, we need to delete the provider record
|
||||
count_stmt = select(func.count(ProviderCredential.id)).where(
|
||||
ProviderCredential.tenant_id == self.tenant_id,
|
||||
ProviderCredential.provider_name == self.provider.provider,
|
||||
ProviderCredential.provider_name.in_(self._get_provider_names()),
|
||||
)
|
||||
available_credentials_count = session.execute(count_stmt).scalar() or 0
|
||||
session.delete(credential_record)
|
||||
|
|
@ -668,7 +672,7 @@ class ProviderConfiguration(BaseModel):
|
|||
stmt = select(ProviderCredential).where(
|
||||
ProviderCredential.id == credential_id,
|
||||
ProviderCredential.tenant_id == self.tenant_id,
|
||||
ProviderCredential.provider_name == self.provider.provider,
|
||||
ProviderCredential.provider_name.in_(self._get_provider_names()),
|
||||
)
|
||||
credential_record = session.execute(stmt).scalar_one_or_none()
|
||||
if not credential_record:
|
||||
|
|
@ -737,7 +741,7 @@ class ProviderConfiguration(BaseModel):
|
|||
stmt = select(ProviderModelCredential).where(
|
||||
ProviderModelCredential.id == credential_id,
|
||||
ProviderModelCredential.tenant_id == self.tenant_id,
|
||||
ProviderModelCredential.provider_name == self.provider.provider,
|
||||
ProviderModelCredential.provider_name.in_(self._get_provider_names()),
|
||||
ProviderModelCredential.model_name == model,
|
||||
ProviderModelCredential.model_type == model_type.to_origin_model_type(),
|
||||
)
|
||||
|
|
@ -784,7 +788,7 @@ class ProviderConfiguration(BaseModel):
|
|||
"""
|
||||
stmt = select(ProviderModelCredential).where(
|
||||
ProviderModelCredential.tenant_id == self.tenant_id,
|
||||
ProviderModelCredential.provider_name == self.provider.provider,
|
||||
ProviderModelCredential.provider_name.in_(self._get_provider_names()),
|
||||
ProviderModelCredential.model_name == model,
|
||||
ProviderModelCredential.model_type == model_type.to_origin_model_type(),
|
||||
ProviderModelCredential.credential_name == credential_name,
|
||||
|
|
@ -860,7 +864,7 @@ class ProviderConfiguration(BaseModel):
|
|||
stmt = select(ProviderModelCredential).where(
|
||||
ProviderModelCredential.id == credential_id,
|
||||
ProviderModelCredential.tenant_id == self.tenant_id,
|
||||
ProviderModelCredential.provider_name == self.provider.provider,
|
||||
ProviderModelCredential.provider_name.in_(self._get_provider_names()),
|
||||
ProviderModelCredential.model_name == model,
|
||||
ProviderModelCredential.model_type == model_type.to_origin_model_type(),
|
||||
)
|
||||
|
|
@ -997,7 +1001,7 @@ class ProviderConfiguration(BaseModel):
|
|||
stmt = select(ProviderModelCredential).where(
|
||||
ProviderModelCredential.id == credential_id,
|
||||
ProviderModelCredential.tenant_id == self.tenant_id,
|
||||
ProviderModelCredential.provider_name == self.provider.provider,
|
||||
ProviderModelCredential.provider_name.in_(self._get_provider_names()),
|
||||
ProviderModelCredential.model_name == model,
|
||||
ProviderModelCredential.model_type == model_type.to_origin_model_type(),
|
||||
)
|
||||
|
|
@ -1042,7 +1046,7 @@ class ProviderConfiguration(BaseModel):
|
|||
stmt = select(ProviderModelCredential).where(
|
||||
ProviderModelCredential.id == credential_id,
|
||||
ProviderModelCredential.tenant_id == self.tenant_id,
|
||||
ProviderModelCredential.provider_name == self.provider.provider,
|
||||
ProviderModelCredential.provider_name.in_(self._get_provider_names()),
|
||||
ProviderModelCredential.model_name == model,
|
||||
ProviderModelCredential.model_type == model_type.to_origin_model_type(),
|
||||
)
|
||||
|
|
@ -1052,7 +1056,7 @@ class ProviderConfiguration(BaseModel):
|
|||
|
||||
lb_stmt = select(LoadBalancingModelConfig).where(
|
||||
LoadBalancingModelConfig.tenant_id == self.tenant_id,
|
||||
LoadBalancingModelConfig.provider_name == self.provider.provider,
|
||||
LoadBalancingModelConfig.provider_name.in_(self._get_provider_names()),
|
||||
LoadBalancingModelConfig.credential_id == credential_id,
|
||||
LoadBalancingModelConfig.credential_source_type == "custom_model",
|
||||
)
|
||||
|
|
@ -1075,7 +1079,7 @@ class ProviderConfiguration(BaseModel):
|
|||
# if this is the last credential, we need to delete the custom model record
|
||||
count_stmt = select(func.count(ProviderModelCredential.id)).where(
|
||||
ProviderModelCredential.tenant_id == self.tenant_id,
|
||||
ProviderModelCredential.provider_name == self.provider.provider,
|
||||
ProviderModelCredential.provider_name.in_(self._get_provider_names()),
|
||||
ProviderModelCredential.model_name == model,
|
||||
ProviderModelCredential.model_type == model_type.to_origin_model_type(),
|
||||
)
|
||||
|
|
@ -1115,7 +1119,7 @@ class ProviderConfiguration(BaseModel):
|
|||
stmt = select(ProviderModelCredential).where(
|
||||
ProviderModelCredential.id == credential_id,
|
||||
ProviderModelCredential.tenant_id == self.tenant_id,
|
||||
ProviderModelCredential.provider_name == self.provider.provider,
|
||||
ProviderModelCredential.provider_name.in_(self._get_provider_names()),
|
||||
ProviderModelCredential.model_name == model,
|
||||
ProviderModelCredential.model_type == model_type.to_origin_model_type(),
|
||||
)
|
||||
|
|
@ -1157,7 +1161,7 @@ class ProviderConfiguration(BaseModel):
|
|||
stmt = select(ProviderModelCredential).where(
|
||||
ProviderModelCredential.id == credential_id,
|
||||
ProviderModelCredential.tenant_id == self.tenant_id,
|
||||
ProviderModelCredential.provider_name == self.provider.provider,
|
||||
ProviderModelCredential.provider_name.in_(self._get_provider_names()),
|
||||
ProviderModelCredential.model_name == model,
|
||||
ProviderModelCredential.model_type == model_type.to_origin_model_type(),
|
||||
)
|
||||
|
|
@ -1204,15 +1208,9 @@ class ProviderConfiguration(BaseModel):
|
|||
"""
|
||||
Get provider model setting.
|
||||
"""
|
||||
|
||||
model_provider_id = ModelProviderID(self.provider.provider)
|
||||
provider_names = [self.provider.provider]
|
||||
if model_provider_id.is_langgenius():
|
||||
provider_names.append(model_provider_id.provider_name)
|
||||
|
||||
stmt = select(ProviderModelSetting).where(
|
||||
ProviderModelSetting.tenant_id == self.tenant_id,
|
||||
ProviderModelSetting.provider_name.in_(provider_names),
|
||||
ProviderModelSetting.provider_name.in_(self._get_provider_names()),
|
||||
ProviderModelSetting.model_type == model_type.to_origin_model_type(),
|
||||
ProviderModelSetting.model_name == model,
|
||||
)
|
||||
|
|
@ -1384,15 +1382,9 @@ class ProviderConfiguration(BaseModel):
|
|||
return
|
||||
|
||||
def _switch(s: Session):
|
||||
# get preferred provider
|
||||
model_provider_id = ModelProviderID(self.provider.provider)
|
||||
provider_names = [self.provider.provider]
|
||||
if model_provider_id.is_langgenius():
|
||||
provider_names.append(model_provider_id.provider_name)
|
||||
|
||||
stmt = select(TenantPreferredModelProvider).where(
|
||||
TenantPreferredModelProvider.tenant_id == self.tenant_id,
|
||||
TenantPreferredModelProvider.provider_name.in_(provider_names),
|
||||
TenantPreferredModelProvider.provider_name.in_(self._get_provider_names()),
|
||||
)
|
||||
preferred_model_provider = s.execute(stmt).scalars().first()
|
||||
|
||||
|
|
|
|||
|
|
@ -8,7 +8,7 @@ from collections import deque
|
|||
from collections.abc import Sequence
|
||||
from datetime import datetime
|
||||
|
||||
import requests
|
||||
import httpx
|
||||
from opentelemetry import trace as trace_api
|
||||
from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter
|
||||
from opentelemetry.sdk.resources import Resource
|
||||
|
|
@ -65,13 +65,13 @@ class TraceClient:
|
|||
|
||||
def api_check(self):
|
||||
try:
|
||||
response = requests.head(self.endpoint, timeout=5)
|
||||
response = httpx.head(self.endpoint, timeout=5)
|
||||
if response.status_code == 405:
|
||||
return True
|
||||
else:
|
||||
logger.debug("AliyunTrace API check failed: Unexpected status code: %s", response.status_code)
|
||||
return False
|
||||
except requests.RequestException as e:
|
||||
except httpx.RequestError as e:
|
||||
logger.debug("AliyunTrace API check failed: %s", str(e))
|
||||
raise ValueError(f"AliyunTrace API check failed: {str(e)}")
|
||||
|
||||
|
|
|
|||
|
|
@ -513,6 +513,21 @@ class ProviderManager:
|
|||
|
||||
return provider_name_to_provider_load_balancing_model_configs_dict
|
||||
|
||||
@staticmethod
|
||||
def _get_provider_names(provider_name: str) -> list[str]:
|
||||
"""
|
||||
provider_name: `openai` or `langgenius/openai/openai`
|
||||
return: [`openai`, `langgenius/openai/openai`]
|
||||
"""
|
||||
provider_names = [provider_name]
|
||||
model_provider_id = ModelProviderID(provider_name)
|
||||
if model_provider_id.is_langgenius():
|
||||
if "/" in provider_name:
|
||||
provider_names.append(model_provider_id.provider_name)
|
||||
else:
|
||||
provider_names.append(str(model_provider_id))
|
||||
return provider_names
|
||||
|
||||
@staticmethod
|
||||
def get_provider_available_credentials(tenant_id: str, provider_name: str) -> list[CredentialConfiguration]:
|
||||
"""
|
||||
|
|
@ -525,7 +540,10 @@ class ProviderManager:
|
|||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
stmt = (
|
||||
select(ProviderCredential)
|
||||
.where(ProviderCredential.tenant_id == tenant_id, ProviderCredential.provider_name == provider_name)
|
||||
.where(
|
||||
ProviderCredential.tenant_id == tenant_id,
|
||||
ProviderCredential.provider_name.in_(ProviderManager._get_provider_names(provider_name)),
|
||||
)
|
||||
.order_by(ProviderCredential.created_at.desc())
|
||||
)
|
||||
|
||||
|
|
@ -554,7 +572,7 @@ class ProviderManager:
|
|||
select(ProviderModelCredential)
|
||||
.where(
|
||||
ProviderModelCredential.tenant_id == tenant_id,
|
||||
ProviderModelCredential.provider_name == provider_name,
|
||||
ProviderModelCredential.provider_name.in_(ProviderManager._get_provider_names(provider_name)),
|
||||
ProviderModelCredential.model_name == model_name,
|
||||
ProviderModelCredential.model_type == model_type,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -41,7 +41,8 @@ class GraphExecutionState(BaseModel):
|
|||
completed: bool = Field(default=False)
|
||||
aborted: bool = Field(default=False)
|
||||
error: GraphExecutionErrorState | None = Field(default=None)
|
||||
node_executions: list[NodeExecutionState] = Field(default_factory=list)
|
||||
exceptions_count: int = Field(default=0)
|
||||
node_executions: list[NodeExecutionState] = Field(default_factory=list[NodeExecutionState])
|
||||
|
||||
|
||||
def _serialize_error(error: Exception | None) -> GraphExecutionErrorState | None:
|
||||
|
|
@ -103,7 +104,8 @@ class GraphExecution:
|
|||
completed: bool = False
|
||||
aborted: bool = False
|
||||
error: Exception | None = None
|
||||
node_executions: dict[str, NodeExecution] = field(default_factory=dict)
|
||||
node_executions: dict[str, NodeExecution] = field(default_factory=dict[str, NodeExecution])
|
||||
exceptions_count: int = 0
|
||||
|
||||
def start(self) -> None:
|
||||
"""Mark the graph execution as started."""
|
||||
|
|
@ -172,6 +174,7 @@ class GraphExecution:
|
|||
completed=self.completed,
|
||||
aborted=self.aborted,
|
||||
error=_serialize_error(self.error),
|
||||
exceptions_count=self.exceptions_count,
|
||||
node_executions=node_states,
|
||||
)
|
||||
|
||||
|
|
@ -195,6 +198,7 @@ class GraphExecution:
|
|||
self.completed = state.completed
|
||||
self.aborted = state.aborted
|
||||
self.error = _deserialize_error(state.error)
|
||||
self.exceptions_count = state.exceptions_count
|
||||
self.node_executions = {
|
||||
item.node_id: NodeExecution(
|
||||
node_id=item.node_id,
|
||||
|
|
@ -205,3 +209,7 @@ class GraphExecution:
|
|||
)
|
||||
for item in state.node_executions
|
||||
}
|
||||
|
||||
def record_node_failure(self) -> None:
|
||||
"""Increment the count of node failures encountered during execution."""
|
||||
self.exceptions_count += 1
|
||||
|
|
|
|||
|
|
@ -3,11 +3,12 @@ Event handler implementations for different event types.
|
|||
"""
|
||||
|
||||
import logging
|
||||
from collections.abc import Mapping
|
||||
from functools import singledispatchmethod
|
||||
from typing import TYPE_CHECKING, final
|
||||
|
||||
from core.workflow.entities import GraphRuntimeState
|
||||
from core.workflow.enums import NodeExecutionType
|
||||
from core.workflow.enums import ErrorStrategy, NodeExecutionType
|
||||
from core.workflow.graph import Graph
|
||||
from core.workflow.graph_events import (
|
||||
GraphNodeEventBase,
|
||||
|
|
@ -122,13 +123,15 @@ class EventHandler:
|
|||
"""
|
||||
# Track execution in domain model
|
||||
node_execution = self._graph_execution.get_or_create_node_execution(event.node_id)
|
||||
is_initial_attempt = node_execution.retry_count == 0
|
||||
node_execution.mark_started(event.id)
|
||||
|
||||
# Track in response coordinator for stream ordering
|
||||
self._response_coordinator.track_node_execution(event.node_id, event.id)
|
||||
|
||||
# Collect the event
|
||||
self._event_collector.collect(event)
|
||||
# Collect the event only for the first attempt; retries remain silent
|
||||
if is_initial_attempt:
|
||||
self._event_collector.collect(event)
|
||||
|
||||
@_dispatch.register
|
||||
def _(self, event: NodeRunStreamChunkEvent) -> None:
|
||||
|
|
@ -161,7 +164,7 @@ class EventHandler:
|
|||
node_execution.mark_taken()
|
||||
|
||||
# Store outputs in variable pool
|
||||
self._store_node_outputs(event)
|
||||
self._store_node_outputs(event.node_id, event.node_run_result.outputs)
|
||||
|
||||
# Forward to response coordinator and emit streaming events
|
||||
streaming_events = self._response_coordinator.intercept_event(event)
|
||||
|
|
@ -191,7 +194,7 @@ class EventHandler:
|
|||
|
||||
# Handle response node outputs
|
||||
if node.execution_type == NodeExecutionType.RESPONSE:
|
||||
self._update_response_outputs(event)
|
||||
self._update_response_outputs(event.node_run_result.outputs)
|
||||
|
||||
# Collect the event
|
||||
self._event_collector.collect(event)
|
||||
|
|
@ -207,6 +210,7 @@ class EventHandler:
|
|||
# Update domain model
|
||||
node_execution = self._graph_execution.get_or_create_node_execution(event.node_id)
|
||||
node_execution.mark_failed(event.error)
|
||||
self._graph_execution.record_node_failure()
|
||||
|
||||
result = self._error_handler.handle_node_failure(event)
|
||||
|
||||
|
|
@@ -227,10 +231,40 @@ class EventHandler:
        Args:
            event: The node exception event
        """
        # Node continues via fail-branch, so it's technically "succeeded"
        # Node continues via fail-branch/default-value, treat as completion
        node_execution = self._graph_execution.get_or_create_node_execution(event.node_id)
        node_execution.mark_taken()

        # Persist outputs produced by the exception strategy (e.g. default values)
        self._store_node_outputs(event.node_id, event.node_run_result.outputs)

        node = self._graph.nodes[event.node_id]

        if node.error_strategy == ErrorStrategy.DEFAULT_VALUE:
            ready_nodes, edge_streaming_events = self._edge_processor.process_node_success(event.node_id)
        elif node.error_strategy == ErrorStrategy.FAIL_BRANCH:
            ready_nodes, edge_streaming_events = self._edge_processor.handle_branch_completion(
                event.node_id, event.node_run_result.edge_source_handle
            )
        else:
            raise NotImplementedError(f"Unsupported error strategy: {node.error_strategy}")

        for edge_event in edge_streaming_events:
            self._event_collector.collect(edge_event)

        for node_id in ready_nodes:
            self._state_manager.enqueue_node(node_id)
            self._state_manager.start_execution(node_id)

        # Update response outputs if applicable
        if node.execution_type == NodeExecutionType.RESPONSE:
            self._update_response_outputs(event.node_run_result.outputs)

        self._state_manager.finish_execution(event.node_id)

        # Collect the exception event for observers
        self._event_collector.collect(event)

    @_dispatch.register
    def _(self, event: NodeRunRetryEvent) -> None:
        """

@@ -242,21 +276,31 @@ class EventHandler:
        node_execution = self._graph_execution.get_or_create_node_execution(event.node_id)
        node_execution.increment_retry()

    def _store_node_outputs(self, event: NodeRunSucceededEvent) -> None:
        # Finish the previous attempt before re-queuing the node
        self._state_manager.finish_execution(event.node_id)

        # Emit retry event for observers
        self._event_collector.collect(event)

        # Re-queue node for execution
        self._state_manager.enqueue_node(event.node_id)
        self._state_manager.start_execution(event.node_id)

    def _store_node_outputs(self, node_id: str, outputs: Mapping[str, object]) -> None:
        """
        Store node outputs in the variable pool.

        Args:
            node_id: The ID of the node that produced the outputs
            outputs: The output mapping to store in the variable pool
        """
        for variable_name, variable_value in event.node_run_result.outputs.items():
            self._graph_runtime_state.variable_pool.add((event.node_id, variable_name), variable_value)
        for variable_name, variable_value in outputs.items():
            self._graph_runtime_state.variable_pool.add((node_id, variable_name), variable_value)

    def _update_response_outputs(self, event: NodeRunSucceededEvent) -> None:
    def _update_response_outputs(self, outputs: Mapping[str, object]) -> None:
        """Update response outputs for response nodes."""
        # TODO: Design a mechanism for nodes to notify the engine about how to update outputs
        # in runtime state, rather than allowing nodes to directly access runtime state.
        for key, value in event.node_run_result.outputs.items():
        for key, value in outputs.items():
            if key == "answer":
                existing = self._graph_runtime_state.get_output("answer", "")
                if existing:
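
The exception handler above now treats a handled node failure as a completion and routes it according to the node's error strategy. A small sketch isolating that routing decision, using the enum and method names from the hunk (the route_after_exception helper and the edge_processor protocol are illustrative, not defined in this diff):

def route_after_exception(node, event, edge_processor):
    if node.error_strategy == ErrorStrategy.DEFAULT_VALUE:
        # Fallback outputs were produced, so the normal success edges still fire.
        return edge_processor.process_node_success(event.node_id)
    if node.error_strategy == ErrorStrategy.FAIL_BRANCH:
        # Only the edges attached to the failure handle are taken.
        return edge_processor.handle_branch_completion(
            event.node_id, event.node_run_result.edge_source_handle
        )
    raise NotImplementedError(f"Unsupported error strategy: {node.error_strategy}")
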
@@ -5,6 +5,7 @@ Unified event manager for collecting and emitting events.
import threading
import time
from collections.abc import Generator
from contextlib import contextmanager
from typing import final

from core.workflow.graph_events import GraphEngineEvent

@@ -51,43 +52,23 @@ class ReadWriteLock:
        """Release a write lock."""
        self._read_ready.release()

    def read_lock(self) -> "ReadLockContext":
    @contextmanager
    def read_lock(self):
        """Return a context manager for read locking."""
        return ReadLockContext(self)
        self.acquire_read()
        try:
            yield
        finally:
            self.release_read()

    def write_lock(self) -> "WriteLockContext":
    @contextmanager
    def write_lock(self):
        """Return a context manager for write locking."""
        return WriteLockContext(self)


@final
class ReadLockContext:
    """Context manager for read locks."""

    def __init__(self, lock: ReadWriteLock) -> None:
        self._lock = lock

    def __enter__(self) -> "ReadLockContext":
        self._lock.acquire_read()
        return self

    def __exit__(self, exc_type: type[BaseException] | None, exc_val: BaseException | None, exc_tb: object) -> None:
        self._lock.release_read()


@final
class WriteLockContext:
    """Context manager for write locks."""

    def __init__(self, lock: ReadWriteLock) -> None:
        self._lock = lock

    def __enter__(self) -> "WriteLockContext":
        self._lock.acquire_write()
        return self

    def __exit__(self, exc_type: type[BaseException] | None, exc_val: BaseException | None, exc_tb: object) -> None:
        self._lock.release_write()
        self.acquire_write()
        try:
            yield
        finally:
            self.release_write()


@final
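With read_lock() and write_lock() now implemented as @contextmanager generators, callers use them directly in with-statements and the dedicated ReadLockContext/WriteLockContext wrapper classes become unnecessary. A minimal usage sketch (the surrounding functions are hypothetical):

lock = ReadWriteLock()


def snapshot_events(events: list) -> list:
    # Multiple readers may hold the lock concurrently
    with lock.read_lock():
        return list(events)


def append_event(events: list, event) -> None:
    # Writers get exclusive access; the lock is released even if append() raises
    with lock.write_lock():
        events.append(event)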
@@ -23,6 +23,7 @@ from core.workflow.graph_events import (
    GraphNodeEventBase,
    GraphRunAbortedEvent,
    GraphRunFailedEvent,
    GraphRunPartialSucceededEvent,
    GraphRunStartedEvent,
    GraphRunSucceededEvent,
)

@@ -260,12 +261,23 @@ class GraphEngine:
            if self._graph_execution.error:
                raise self._graph_execution.error
            else:
                yield GraphRunSucceededEvent(
                    outputs=self._graph_runtime_state.outputs,
                )
            outputs = self._graph_runtime_state.outputs
            exceptions_count = self._graph_execution.exceptions_count
            if exceptions_count > 0:
                yield GraphRunPartialSucceededEvent(
                    exceptions_count=exceptions_count,
                    outputs=outputs,
                )
            else:
                yield GraphRunSucceededEvent(
                    outputs=outputs,
                )

        except Exception as e:
            yield GraphRunFailedEvent(error=str(e))
            yield GraphRunFailedEvent(
                error=str(e),
                exceptions_count=self._graph_execution.exceptions_count,
            )
            raise

        finally:
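A caller iterating GraphEngine.run() can now tell a clean run from one where node failures were absorbed by error strategies. A rough consumer sketch, assuming `engine` is an already constructed GraphEngine:

for event in engine.run():
    if isinstance(event, GraphRunPartialSucceededEvent):
        # Some nodes failed, but their error strategies still produced usable outputs
        print(f"partial success with {event.exceptions_count} exception(s): {event.outputs}")
    elif isinstance(event, GraphRunSucceededEvent):
        print(f"success: {event.outputs}")
    elif isinstance(event, GraphRunFailedEvent):
        print(f"failed: {event.error} ({event.exceptions_count} exception(s))")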
@@ -15,6 +15,7 @@ from core.workflow.graph_events import (
    GraphEngineEvent,
    GraphRunAbortedEvent,
    GraphRunFailedEvent,
    GraphRunPartialSucceededEvent,
    GraphRunStartedEvent,
    GraphRunSucceededEvent,
    NodeRunExceptionEvent,

@@ -127,6 +128,13 @@ class DebugLoggingLayer(GraphEngineLayer):
            if self.include_outputs and event.outputs:
                self.logger.info(" Final outputs: %s", self._format_dict(event.outputs))

        elif isinstance(event, GraphRunPartialSucceededEvent):
            self.logger.warning("⚠️ Graph run partially succeeded")
            if event.exceptions_count > 0:
                self.logger.warning(" Total exceptions: %s", event.exceptions_count)
            if self.include_outputs and event.outputs:
                self.logger.info(" Final outputs: %s", self._format_dict(event.outputs))

        elif isinstance(event, GraphRunFailedEvent):
            self.logger.error("❌ Graph run failed: %s", event.error)
            if event.exceptions_count > 0:

@@ -138,6 +146,12 @@ class DebugLoggingLayer(GraphEngineLayer):
                self.logger.info(" Partial outputs: %s", self._format_dict(event.outputs))

        # Node-level events
        # Retry before Started because Retry subclasses Started;
        elif isinstance(event, NodeRunRetryEvent):
            self.retry_count += 1
            self.logger.warning("🔄 Node retry: %s (attempt %s)", event.node_id, event.retry_index)
            self.logger.warning(" Previous error: %s", event.error)

        elif isinstance(event, NodeRunStartedEvent):
            self.node_count += 1
            self.logger.info('▶️ Node started: %s - "%s" (type: %s)', event.node_id, event.node_title, event.node_type)

@@ -167,11 +181,6 @@ class DebugLoggingLayer(GraphEngineLayer):
            self.logger.warning("⚠️ Node exception handled: %s", event.node_id)
            self.logger.warning(" Error: %s", event.error)

        elif isinstance(event, NodeRunRetryEvent):
            self.retry_count += 1
            self.logger.warning("🔄 Node retry: %s (attempt %s)", event.node_id, event.retry_index)
            self.logger.warning(" Previous error: %s", event.error)

        elif isinstance(event, NodeRunStreamChunkEvent):
            # Log stream chunks at debug level to avoid spam
            final_indicator = " (FINAL)" if event.is_final else ""
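Moving the NodeRunRetryEvent branch above NodeRunStartedEvent matters because isinstance() on a parent class also matches its subclasses, so the more specific retry branch must be checked first. A small illustration with stand-in classes (not the real event types):

class Started: ...


class Retry(Started): ...


def describe(event) -> str:
    # Check the subclass first, otherwise every Retry would be reported as Started
    if isinstance(event, Retry):
        return "retry"
    if isinstance(event, Started):
        return "started"
    return "other"


assert describe(Retry()) == "retry"
assert describe(Started()) == "started"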
@@ -19,6 +19,7 @@ from core.workflow.enums import (
from core.workflow.graph_events import (
    GraphNodeEventBase,
    GraphRunFailedEvent,
    GraphRunPartialSucceededEvent,
    GraphRunSucceededEvent,
)
from core.workflow.node_events import (

@@ -372,43 +373,16 @@ class IterationNode(Node):
        variable_mapping: dict[str, Sequence[str]] = {
            f"{node_id}.input_selector": typed_node_data.iterator_selector,
        }
        iteration_node_ids = set()

        # init graph
        from core.workflow.entities import GraphInitParams, GraphRuntimeState
        from core.workflow.graph import Graph
        from core.workflow.nodes.node_factory import DifyNodeFactory

        # Create minimal GraphInitParams for static analysis
        graph_init_params = GraphInitParams(
            tenant_id="",
            app_id="",
            workflow_id="",
            graph_config=graph_config,
            user_id="",
            user_from="",
            invoke_from="",
            call_depth=0,
        )

        # Create minimal GraphRuntimeState for static analysis
        from core.workflow.entities import VariablePool

        graph_runtime_state = GraphRuntimeState(
            variable_pool=VariablePool(),
            start_at=0,
        )

        # Create node factory for static analysis
        node_factory = DifyNodeFactory(graph_init_params=graph_init_params, graph_runtime_state=graph_runtime_state)

        iteration_graph = Graph.init(
            graph_config=graph_config,
            node_factory=node_factory,
            root_node_id=typed_node_data.start_node_id,
        )

        if not iteration_graph:
            raise IterationGraphNotFoundError("iteration graph not found")
        # Find all nodes that belong to this loop
        nodes = graph_config.get("nodes", [])
        for node in nodes:
            node_data = node.get("data", {})
            if node_data.get("iteration_id") == node_id:
                in_iteration_node_id = node.get("id")
                if in_iteration_node_id:
                    iteration_node_ids.add(in_iteration_node_id)

        # Get node configs from graph_config instead of non-existent node_id_config_mapping
        node_configs = {node["id"]: node for node in graph_config.get("nodes", []) if "id" in node}

@@ -444,9 +418,7 @@ class IterationNode(Node):
            variable_mapping.update(sub_node_variable_mapping)

        # remove variable out from iteration
        variable_mapping = {
            key: value for key, value in variable_mapping.items() if value[0] not in iteration_graph.node_ids
        }
        variable_mapping = {key: value for key, value in variable_mapping.items() if value[0] not in iteration_node_ids}

        return variable_mapping

@@ -485,7 +457,7 @@ class IterationNode(Node):
            if isinstance(event, GraphNodeEventBase):
                self._append_iteration_info_to_event(event=event, iter_run_index=current_index)
                yield event
            elif isinstance(event, GraphRunSucceededEvent):
            elif isinstance(event, (GraphRunSucceededEvent, GraphRunPartialSucceededEvent)):
                result = variable_pool.get(self._node_data.output_selector)
                if result is None:
                    outputs.append(None)
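The replacement avoids building a full Graph (with a throwaway GraphInitParams, GraphRuntimeState, and node factory) just to learn which nodes sit inside the iteration: membership can be read straight off the node dicts in graph_config. A standalone sketch of that idea, with an invented sample config for illustration:

def extract_iteration_node_ids(graph_config: dict, iteration_node_id: str) -> set[str]:
    # A node belongs to the iteration when its data block points back to the iteration node
    return {
        node["id"]
        for node in graph_config.get("nodes", [])
        if node.get("data", {}).get("iteration_id") == iteration_node_id and node.get("id")
    }


sample_config = {
    "nodes": [
        {"id": "iter-1", "data": {"type": "iteration"}},
        {"id": "llm-1", "data": {"iteration_id": "iter-1"}},
        {"id": "code-1", "data": {"iteration_id": "iter-1"}},
        {"id": "outside", "data": {}},
    ]
}
assert extract_iteration_node_ids(sample_config, "iter-1") == {"llm-1", "code-1"}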
@@ -63,7 +63,7 @@ class RetrievalSetting(BaseModel):
    Retrieval Setting.
    """

    search_method: Literal["semantic_search", "keyword_search", "fulltext_search", "hybrid_search"]
    search_method: Literal["semantic_search", "keyword_search", "full_text_search", "hybrid_search"]
    top_k: int
    score_threshold: float | None = 0.5
    score_threshold_enabled: bool = False
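The Literal rename is load-bearing: pydantic rejects any value outside the listed options, so clients still sending the old "fulltext_search" spelling fail validation rather than silently falling back. A quick check, assuming pydantic v2 and a hypothetical model name:

from typing import Literal

from pydantic import BaseModel, ValidationError


class RetrievalSettingExample(BaseModel):
    search_method: Literal["semantic_search", "keyword_search", "full_text_search", "hybrid_search"]
    top_k: int


RetrievalSettingExample(search_method="full_text_search", top_k=3)  # accepted

try:
    RetrievalSettingExample(search_method="fulltext_search", top_k=3)
except ValidationError as e:
    print(e.errors()[0]["type"])  # literal_error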
@@ -1,3 +1,4 @@
import contextlib
import json
import logging
from collections.abc import Callable, Generator, Mapping, Sequence

@@ -127,11 +128,13 @@ class LoopNode(Node):
        try:
            reach_break_condition = False
            if break_conditions:
                _, _, reach_break_condition = condition_processor.process_conditions(
                    variable_pool=self.graph_runtime_state.variable_pool,
                    conditions=break_conditions,
                    operator=logical_operator,
                )
                with contextlib.suppress(ValueError):
                    _, _, reach_break_condition = condition_processor.process_conditions(
                        variable_pool=self.graph_runtime_state.variable_pool,
                        conditions=break_conditions,
                        operator=logical_operator,
                    )

            if reach_break_condition:
                loop_count = 0
                cost_tokens = 0
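contextlib.suppress(ValueError) is shorthand for a try/except that swallows the error and leaves reach_break_condition at its default instead of aborting the loop setup. A runnable equivalent, with a hypothetical stand-in for the condition processor:

import contextlib


def check_break_condition(raw: str) -> bool:
    # Hypothetical stand-in for condition_processor.process_conditions(...)
    return int(raw) > 3  # int() raises ValueError for non-numeric input


reach_break_condition = False
with contextlib.suppress(ValueError):
    reach_break_condition = check_break_condition("not-a-number")
# reach_break_condition is still False; the ValueError was swallowed rather than propagated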
@@ -295,42 +298,11 @@ class LoopNode(Node):

        variable_mapping = {}

        # init graph
        from core.workflow.entities import GraphInitParams, GraphRuntimeState, VariablePool
        from core.workflow.graph import Graph
        from core.workflow.nodes.node_factory import DifyNodeFactory
        # Extract loop node IDs statically from graph_config

        # Create minimal GraphInitParams for static analysis
        graph_init_params = GraphInitParams(
            tenant_id="",
            app_id="",
            workflow_id="",
            graph_config=graph_config,
            user_id="",
            user_from="",
            invoke_from="",
            call_depth=0,
        )
        loop_node_ids = cls._extract_loop_node_ids_from_config(graph_config, node_id)

        # Create minimal GraphRuntimeState for static analysis
        graph_runtime_state = GraphRuntimeState(
            variable_pool=VariablePool(),
            start_at=0,
        )

        # Create node factory for static analysis
        node_factory = DifyNodeFactory(graph_init_params=graph_init_params, graph_runtime_state=graph_runtime_state)

        loop_graph = Graph.init(
            graph_config=graph_config,
            node_factory=node_factory,
            root_node_id=typed_node_data.start_node_id,
        )

        if not loop_graph:
            raise ValueError("loop graph not found")

        # Get node configs from graph_config instead of non-existent node_id_config_mapping
        # Get node configs from graph_config
        node_configs = {node["id"]: node for node in graph_config.get("nodes", []) if "id" in node}
        for sub_node_id, sub_node_config in node_configs.items():
            if sub_node_config.get("data", {}).get("loop_id") != node_id:

@@ -371,12 +343,35 @@ class LoopNode(Node):
            variable_mapping[f"{node_id}.{loop_variable.label}"] = selector

        # remove variable out from loop
        variable_mapping = {
            key: value for key, value in variable_mapping.items() if value[0] not in loop_graph.node_ids
        }
        variable_mapping = {key: value for key, value in variable_mapping.items() if value[0] not in loop_node_ids}

        return variable_mapping

    @classmethod
    def _extract_loop_node_ids_from_config(cls, graph_config: Mapping[str, Any], loop_node_id: str) -> set[str]:
        """
        Extract node IDs that belong to a specific loop from graph configuration.

        This method statically analyzes the graph configuration to find all nodes
        that are part of the specified loop, without creating actual node instances.

        :param graph_config: the complete graph configuration
        :param loop_node_id: the ID of the loop node
        :return: set of node IDs that belong to the loop
        """
        loop_node_ids = set()

        # Find all nodes that belong to this loop
        nodes = graph_config.get("nodes", [])
        for node in nodes:
            node_data = node.get("data", {})
            if node_data.get("loop_id") == loop_node_id:
                node_id = node.get("id")
                if node_id:
                    loop_node_ids.add(node_id)

        return loop_node_ids

    @staticmethod
    def _get_segment_for_constant(var_type: SegmentType, original_value: Any) -> Segment:
        """Get the appropriate segment type for a constant value."""
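The extracted ID set then drives the final filtering step: any mapping entry whose selector starts with a node inside the loop is dropped, since those values are produced per-iteration rather than supplied from outside. A small standalone illustration with invented selectors:

loop_node_ids = {"inner-llm", "inner-code"}
variable_mapping = {
    "loop-1.input_selector": ["start", "query"],
    "loop-1.#inner-llm.text#": ["inner-llm", "text"],
    "loop-1.#outside.result#": ["outside", "result"],
}

# Keep only selectors that reference nodes outside the loop body
variable_mapping = {key: value for key, value in variable_mapping.items() if value[0] not in loop_node_ids}
assert set(variable_mapping) == {"loop-1.input_selector", "loop-1.#outside.result#"}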
@@ -1,7 +1,7 @@
import urllib.parse
from dataclasses import dataclass

import requests
import httpx


@dataclass

@@ -58,7 +58,7 @@ class GitHubOAuth(OAuth):
            "redirect_uri": self.redirect_uri,
        }
        headers = {"Accept": "application/json"}
        response = requests.post(self._TOKEN_URL, data=data, headers=headers)
        response = httpx.post(self._TOKEN_URL, data=data, headers=headers)

        response_json = response.json()
        access_token = response_json.get("access_token")

@@ -70,11 +70,11 @@ class GitHubOAuth(OAuth):

    def get_raw_user_info(self, token: str):
        headers = {"Authorization": f"token {token}"}
        response = requests.get(self._USER_INFO_URL, headers=headers)
        response = httpx.get(self._USER_INFO_URL, headers=headers)
        response.raise_for_status()
        user_info = response.json()

        email_response = requests.get(self._EMAIL_INFO_URL, headers=headers)
        email_response = httpx.get(self._EMAIL_INFO_URL, headers=headers)
        email_info = email_response.json()
        primary_email: dict = next((email for email in email_info if email["primary"] == True), {})

@@ -112,7 +112,7 @@ class GoogleOAuth(OAuth):
            "redirect_uri": self.redirect_uri,
        }
        headers = {"Accept": "application/json"}
        response = requests.post(self._TOKEN_URL, data=data, headers=headers)
        response = httpx.post(self._TOKEN_URL, data=data, headers=headers)

        response_json = response.json()
        access_token = response_json.get("access_token")

@@ -124,7 +124,7 @@ class GoogleOAuth(OAuth):

    def get_raw_user_info(self, token: str):
        headers = {"Authorization": f"Bearer {token}"}
        response = requests.get(self._USER_INFO_URL, headers=headers)
        response = httpx.get(self._USER_INFO_URL, headers=headers)
        response.raise_for_status()
        return response.json()
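For these simple synchronous calls, httpx is close to a drop-in replacement for requests: get/post take the same arguments here, and .json() and .raise_for_status() behave the same way. One behavioural difference worth keeping in mind is timeouts: httpx applies a default timeout of about 5 seconds, while requests waits indefinitely unless a timeout is passed. A minimal sketch with an explicit timeout (URL and helper name are hypothetical):

import httpx


def fetch_user(token: str) -> dict:
    headers = {"Authorization": f"Bearer {token}"}
    # Explicit timeout; httpx would otherwise fall back to its own default
    response = httpx.get("https://api.example.com/user", headers=headers, timeout=10.0)
    response.raise_for_status()
    return response.json()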
@@ -1,7 +1,7 @@
import urllib.parse
from typing import Any

import requests
import httpx
from flask_login import current_user
from sqlalchemy import select

@@ -43,7 +43,7 @@ class NotionOAuth(OAuthDataSource):
        data = {"code": code, "grant_type": "authorization_code", "redirect_uri": self.redirect_uri}
        headers = {"Accept": "application/json"}
        auth = (self.client_id, self.client_secret)
        response = requests.post(self._TOKEN_URL, data=data, auth=auth, headers=headers)
        response = httpx.post(self._TOKEN_URL, data=data, auth=auth, headers=headers)

        response_json = response.json()
        access_token = response_json.get("access_token")

@@ -239,7 +239,7 @@ class NotionOAuth(OAuthDataSource):
            "Notion-Version": "2022-06-28",
        }

        response = requests.post(url=self._NOTION_PAGE_SEARCH, json=data, headers=headers)
        response = httpx.post(url=self._NOTION_PAGE_SEARCH, json=data, headers=headers)
        response_json = response.json()

        results.extend(response_json.get("results", []))

@@ -254,7 +254,7 @@ class NotionOAuth(OAuthDataSource):
            "Authorization": f"Bearer {access_token}",
            "Notion-Version": "2022-06-28",
        }
        response = requests.get(url=f"{self._NOTION_BLOCK_SEARCH}/{block_id}", headers=headers)
        response = httpx.get(url=f"{self._NOTION_BLOCK_SEARCH}/{block_id}", headers=headers)
        response_json = response.json()
        if response.status_code != 200:
            message = response_json.get("message", "unknown error")

@@ -270,7 +270,7 @@ class NotionOAuth(OAuthDataSource):
            "Authorization": f"Bearer {access_token}",
            "Notion-Version": "2022-06-28",
        }
        response = requests.get(url=self._NOTION_BOT_USER, headers=headers)
        response = httpx.get(url=self._NOTION_BOT_USER, headers=headers)
        response_json = response.json()
        if "object" in response_json and response_json["object"] == "user":
            user_type = response_json["type"]

@@ -294,7 +294,7 @@ class NotionOAuth(OAuthDataSource):
            "Authorization": f"Bearer {access_token}",
            "Notion-Version": "2022-06-28",
        }
        response = requests.post(url=self._NOTION_PAGE_SEARCH, json=data, headers=headers)
        response = httpx.post(url=self._NOTION_PAGE_SEARCH, json=data, headers=headers)
        response_json = response.json()

        results.extend(response_json.get("results", []))
@@ -47,7 +47,7 @@ def upgrade():
        sa.Column('plugin_id', sa.String(length=255), nullable=False),
        sa.Column('auth_type', sa.String(length=255), nullable=False),
        sa.Column('encrypted_credentials', postgresql.JSONB(astext_type=sa.Text()), nullable=False),
        sa.Column('avatar_url', sa.String(length=255), nullable=True),
        sa.Column('avatar_url', sa.Text(), nullable=True),
        sa.Column('is_default', sa.Boolean(), server_default=sa.text('false'), nullable=False),
        sa.Column('expires_at', sa.Integer(), server_default='-1', nullable=False),
        sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False),

@@ -35,7 +35,7 @@ class DatasourceProvider(Base):
    plugin_id: Mapped[str] = db.Column(db.String(255), nullable=False)
    auth_type: Mapped[str] = db.Column(db.String(255), nullable=False)
    encrypted_credentials: Mapped[dict] = db.Column(JSONB, nullable=False)
    avatar_url: Mapped[str] = db.Column(db.String(255), nullable=True, default="default")
    avatar_url: Mapped[str] = db.Column(db.Text, nullable=True, default="default")
    is_default: Mapped[bool] = db.Column(db.Boolean, nullable=False, server_default=db.text("false"))
    expires_at: Mapped[int] = db.Column(db.Integer, nullable=False, server_default="-1")
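The diff widens avatar_url to Text both in the create-table migration and in the model. For a database that was already created with the old String(255) column, a separate follow-up migration would be needed; a hedged sketch of what that could look like, where the table name and the need for such a migration are assumptions rather than part of this commit:

import sqlalchemy as sa
from alembic import op


# Hypothetical follow-up migration; table name "datasource_providers" is an assumption.
def upgrade():
    op.alter_column(
        "datasource_providers",
        "avatar_url",
        existing_type=sa.String(length=255),
        type_=sa.Text(),
        existing_nullable=True,
    )


def downgrade():
    op.alter_column(
        "datasource_providers",
        "avatar_url",
        existing_type=sa.Text(),
        type_=sa.String(length=255),
        existing_nullable=True,
    )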
@@ -1,6 +1,6 @@
[project]
name = "dify-api"
version = "2.0.0-beta2"
version = "1.9.0"
requires-python = ">=3.11,<3.13"

dependencies = [
@ -1,6 +1,6 @@
|
|||
import json
|
||||
|
||||
import requests
|
||||
import httpx
|
||||
|
||||
from services.auth.api_key_auth_base import ApiKeyAuthBase
|
||||
|
||||
|
|
@ -36,7 +36,7 @@ class FirecrawlAuth(ApiKeyAuthBase):
|
|||
return {"Content-Type": "application/json", "Authorization": f"Bearer {self.api_key}"}
|
||||
|
||||
def _post_request(self, url, data, headers):
|
||||
return requests.post(url, headers=headers, json=data)
|
||||
return httpx.post(url, headers=headers, json=data)
|
||||
|
||||
def _handle_error(self, response):
|
||||
if response.status_code in {402, 409, 500}:
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
import json
|
||||
|
||||
import requests
|
||||
import httpx
|
||||
|
||||
from services.auth.api_key_auth_base import ApiKeyAuthBase
|
||||
|
||||
|
|
@ -31,7 +31,7 @@ class JinaAuth(ApiKeyAuthBase):
|
|||
return {"Content-Type": "application/json", "Authorization": f"Bearer {self.api_key}"}
|
||||
|
||||
def _post_request(self, url, data, headers):
|
||||
return requests.post(url, headers=headers, json=data)
|
||||
return httpx.post(url, headers=headers, json=data)
|
||||
|
||||
def _handle_error(self, response):
|
||||
if response.status_code in {402, 409, 500}:
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
import json
|
||||
|
||||
import requests
|
||||
import httpx
|
||||
|
||||
from services.auth.api_key_auth_base import ApiKeyAuthBase
|
||||
|
||||
|
|
@ -31,7 +31,7 @@ class JinaAuth(ApiKeyAuthBase):
|
|||
return {"Content-Type": "application/json", "Authorization": f"Bearer {self.api_key}"}
|
||||
|
||||
def _post_request(self, url, data, headers):
|
||||
return requests.post(url, headers=headers, json=data)
|
||||
return httpx.post(url, headers=headers, json=data)
|
||||
|
||||
def _handle_error(self, response):
|
||||
if response.status_code in {402, 409, 500}:
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
import json
|
||||
from urllib.parse import urljoin
|
||||
|
||||
import requests
|
||||
import httpx
|
||||
|
||||
from services.auth.api_key_auth_base import ApiKeyAuthBase
|
||||
|
||||
|
|
@ -31,7 +31,7 @@ class WatercrawlAuth(ApiKeyAuthBase):
|
|||
return {"Content-Type": "application/json", "X-API-KEY": self.api_key}
|
||||
|
||||
def _get_request(self, url, headers):
|
||||
return requests.get(url, headers=headers)
|
||||
return httpx.get(url, headers=headers)
|
||||
|
||||
def _handle_error(self, response):
|
||||
if response.status_code in {402, 409, 500}:
|
||||
|
|
|
|||
|
|
@@ -83,7 +83,7 @@ class RetrievalSetting(BaseModel):
    Retrieval Setting.
    """

    search_method: Literal["semantic_search", "fulltext_search", "keyword_search", "hybrid_search"]
    search_method: Literal["semantic_search", "full_text_search", "keyword_search", "hybrid_search"]
    top_k: int
    score_threshold: float | None = 0.5
    score_threshold_enabled: bool = False
@ -1,6 +1,6 @@
|
|||
import os
|
||||
|
||||
import requests
|
||||
import httpx
|
||||
|
||||
|
||||
class OperationService:
|
||||
|
|
@ -12,7 +12,7 @@ class OperationService:
|
|||
headers = {"Content-Type": "application/json", "Billing-Api-Secret-Key": cls.secret_key}
|
||||
|
||||
url = f"{cls.base_url}{endpoint}"
|
||||
response = requests.request(method, url, json=json, params=params, headers=headers)
|
||||
response = httpx.request(method, url, json=json, params=params, headers=headers)
|
||||
|
||||
return response.json()
|
||||
|
||||
|
|
|
|||
|
|
@ -3,7 +3,7 @@ import json
|
|||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
import requests
|
||||
import httpx
|
||||
from flask_login import current_user
|
||||
|
||||
from core.helper import encrypter
|
||||
|
|
@ -216,7 +216,7 @@ class WebsiteService:
|
|||
@classmethod
|
||||
def _crawl_with_jinareader(cls, request: CrawlRequest, api_key: str) -> dict[str, Any]:
|
||||
if not request.options.crawl_sub_pages:
|
||||
response = requests.get(
|
||||
response = httpx.get(
|
||||
f"https://r.jina.ai/{request.url}",
|
||||
headers={"Accept": "application/json", "Authorization": f"Bearer {api_key}"},
|
||||
)
|
||||
|
|
@ -224,7 +224,7 @@ class WebsiteService:
|
|||
raise ValueError("Failed to crawl:")
|
||||
return {"status": "active", "data": response.json().get("data")}
|
||||
else:
|
||||
response = requests.post(
|
||||
response = httpx.post(
|
||||
"https://adaptivecrawl-kir3wx7b3a-uc.a.run.app",
|
||||
json={
|
||||
"url": request.url,
|
||||
|
|
@ -287,7 +287,7 @@ class WebsiteService:
|
|||
|
||||
@classmethod
|
||||
def _get_jinareader_status(cls, job_id: str, api_key: str) -> dict[str, Any]:
|
||||
response = requests.post(
|
||||
response = httpx.post(
|
||||
"https://adaptivecrawlstatus-kir3wx7b3a-uc.a.run.app",
|
||||
headers={"Content-Type": "application/json", "Authorization": f"Bearer {api_key}"},
|
||||
json={"taskId": job_id},
|
||||
|
|
@ -303,7 +303,7 @@ class WebsiteService:
|
|||
}
|
||||
|
||||
if crawl_status_data["status"] == "completed":
|
||||
response = requests.post(
|
||||
response = httpx.post(
|
||||
"https://adaptivecrawlstatus-kir3wx7b3a-uc.a.run.app",
|
||||
headers={"Content-Type": "application/json", "Authorization": f"Bearer {api_key}"},
|
||||
json={"taskId": job_id, "urls": list(data.get("processed", {}).keys())},
|
||||
|
|
@ -362,7 +362,7 @@ class WebsiteService:
|
|||
@classmethod
|
||||
def _get_jinareader_url_data(cls, job_id: str, url: str, api_key: str) -> dict[str, Any] | None:
|
||||
if not job_id:
|
||||
response = requests.get(
|
||||
response = httpx.get(
|
||||
f"https://r.jina.ai/{url}",
|
||||
headers={"Accept": "application/json", "Authorization": f"Bearer {api_key}"},
|
||||
)
|
||||
|
|
@ -371,7 +371,7 @@ class WebsiteService:
|
|||
return dict(response.json().get("data", {}))
|
||||
else:
|
||||
# Get crawl status first
|
||||
status_response = requests.post(
|
||||
status_response = httpx.post(
|
||||
"https://adaptivecrawlstatus-kir3wx7b3a-uc.a.run.app",
|
||||
headers={"Content-Type": "application/json", "Authorization": f"Bearer {api_key}"},
|
||||
json={"taskId": job_id},
|
||||
|
|
@ -381,7 +381,7 @@ class WebsiteService:
|
|||
raise ValueError("Crawl job is not completed")
|
||||
|
||||
# Get processed data
|
||||
data_response = requests.post(
|
||||
data_response = httpx.post(
|
||||
"https://adaptivecrawlstatus-kir3wx7b3a-uc.a.run.app",
|
||||
headers={"Content-Type": "application/json", "Authorization": f"Bearer {api_key}"},
|
||||
json={"taskId": job_id, "urls": list(status_data.get("processed", {}).keys())},
|
||||
|
|
|
|||
|
|
@ -1,8 +1,8 @@
|
|||
import os
|
||||
from typing import Literal
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
import requests
|
||||
|
||||
from core.plugin.entities.plugin_daemon import PluginDaemonBasicResponse
|
||||
from core.tools.entities.common_entities import I18nObject
|
||||
|
|
@ -27,13 +27,11 @@ class MockedHttp:
|
|||
@classmethod
|
||||
def requests_request(
|
||||
cls, method: Literal["GET", "POST", "PUT", "DELETE", "PATCH", "HEAD"], url: str, **kwargs
|
||||
) -> requests.Response:
|
||||
) -> httpx.Response:
|
||||
"""
|
||||
Mocked requests.request
|
||||
Mocked httpx.request
|
||||
"""
|
||||
request = requests.PreparedRequest()
|
||||
request.method = method
|
||||
request.url = url
|
||||
request = httpx.Request(method, url)
|
||||
if url.endswith("/tools"):
|
||||
content = PluginDaemonBasicResponse[list[ToolProviderEntity]](
|
||||
code=0, message="success", data=cls.list_tools()
|
||||
|
|
@ -41,8 +39,7 @@ class MockedHttp:
|
|||
else:
|
||||
raise ValueError("")
|
||||
|
||||
response = requests.Response()
|
||||
response.status_code = 200
|
||||
response = httpx.Response(status_code=200)
|
||||
response.request = request
|
||||
response._content = content.encode("utf-8")
|
||||
return response
|
||||
|
|
@ -54,7 +51,7 @@ MOCK_SWITCH = os.getenv("MOCK_SWITCH", "false").lower() == "true"
|
|||
@pytest.fixture
|
||||
def setup_http_mock(request, monkeypatch: pytest.MonkeyPatch):
|
||||
if MOCK_SWITCH:
|
||||
monkeypatch.setattr(requests, "request", MockedHttp.requests_request)
|
||||
monkeypatch.setattr(httpx, "request", MockedHttp.requests_request)
|
||||
|
||||
def unpatch():
|
||||
monkeypatch.undo()
|
||||
|
|
|
|||
|
|
@ -6,7 +6,7 @@ Test Clickzetta integration in Docker environment
|
|||
import os
|
||||
import time
|
||||
|
||||
import requests
|
||||
import httpx
|
||||
from clickzetta import connect
|
||||
|
||||
|
||||
|
|
@ -66,7 +66,7 @@ def test_dify_api():
|
|||
max_retries = 30
|
||||
for i in range(max_retries):
|
||||
try:
|
||||
response = requests.get(f"{base_url}/console/api/health")
|
||||
response = httpx.get(f"{base_url}/console/api/health")
|
||||
if response.status_code == 200:
|
||||
print("✓ Dify API is ready")
|
||||
break
|
||||
|
|
|
|||
|
|
@ -173,7 +173,7 @@ class DifyTestContainers:
|
|||
# Start Dify Plugin Daemon container for plugin management
|
||||
# Dify Plugin Daemon provides plugin lifecycle management and execution
|
||||
logger.info("Initializing Dify Plugin Daemon container...")
|
||||
self.dify_plugin_daemon = DockerContainer(image="langgenius/dify-plugin-daemon:0.2.0-local")
|
||||
self.dify_plugin_daemon = DockerContainer(image="langgenius/dify-plugin-daemon:0.3.0-local")
|
||||
self.dify_plugin_daemon.with_exposed_ports(5002)
|
||||
self.dify_plugin_daemon.env = {
|
||||
"DB_HOST": db_host,
|
||||
|
|
|
|||
|
|
@ -201,9 +201,9 @@ class TestOAuthCallback:
|
|||
mock_db.session.rollback = MagicMock()
|
||||
|
||||
# Import the real requests module to create a proper exception
|
||||
import requests
|
||||
import httpx
|
||||
|
||||
request_exception = requests.exceptions.RequestException("OAuth error")
|
||||
request_exception = httpx.RequestError("OAuth error")
|
||||
request_exception.response = MagicMock()
|
||||
request_exception.response.text = str(exception)
|
||||
|
||||
|
|
|
|||
|
|
@ -0,0 +1,120 @@
|
|||
"""Tests for graph engine event handlers."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
|
||||
from core.workflow.entities import GraphRuntimeState, VariablePool
|
||||
from core.workflow.enums import NodeExecutionType, NodeState, NodeType, WorkflowNodeExecutionStatus
|
||||
from core.workflow.graph import Graph
|
||||
from core.workflow.graph_engine.domain.graph_execution import GraphExecution
|
||||
from core.workflow.graph_engine.event_management.event_handlers import EventHandler
|
||||
from core.workflow.graph_engine.event_management.event_manager import EventManager
|
||||
from core.workflow.graph_engine.graph_state_manager import GraphStateManager
|
||||
from core.workflow.graph_engine.ready_queue.in_memory import InMemoryReadyQueue
|
||||
from core.workflow.graph_engine.response_coordinator.coordinator import ResponseStreamCoordinator
|
||||
from core.workflow.graph_events import NodeRunRetryEvent, NodeRunStartedEvent
|
||||
from core.workflow.node_events import NodeRunResult
|
||||
from core.workflow.nodes.base.entities import RetryConfig
|
||||
|
||||
|
||||
class _StubEdgeProcessor:
|
||||
"""Minimal edge processor stub for tests."""
|
||||
|
||||
|
||||
class _StubErrorHandler:
|
||||
"""Minimal error handler stub for tests."""
|
||||
|
||||
|
||||
class _StubNode:
|
||||
"""Simple node stub exposing the attributes needed by the state manager."""
|
||||
|
||||
def __init__(self, node_id: str) -> None:
|
||||
self.id = node_id
|
||||
self.state = NodeState.UNKNOWN
|
||||
self.title = "Stub Node"
|
||||
self.execution_type = NodeExecutionType.EXECUTABLE
|
||||
self.error_strategy = None
|
||||
self.retry_config = RetryConfig()
|
||||
self.retry = False
|
||||
|
||||
|
||||
def _build_event_handler(node_id: str) -> tuple[EventHandler, EventManager, GraphExecution]:
|
||||
"""Construct an EventHandler with in-memory dependencies for testing."""
|
||||
|
||||
node = _StubNode(node_id)
|
||||
graph = Graph(nodes={node_id: node}, edges={}, in_edges={}, out_edges={}, root_node=node)
|
||||
|
||||
variable_pool = VariablePool()
|
||||
runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=0.0)
|
||||
graph_execution = GraphExecution(workflow_id="test-workflow")
|
||||
|
||||
event_manager = EventManager()
|
||||
state_manager = GraphStateManager(graph=graph, ready_queue=InMemoryReadyQueue())
|
||||
response_coordinator = ResponseStreamCoordinator(variable_pool=variable_pool, graph=graph)
|
||||
|
||||
handler = EventHandler(
|
||||
graph=graph,
|
||||
graph_runtime_state=runtime_state,
|
||||
graph_execution=graph_execution,
|
||||
response_coordinator=response_coordinator,
|
||||
event_collector=event_manager,
|
||||
edge_processor=_StubEdgeProcessor(),
|
||||
state_manager=state_manager,
|
||||
error_handler=_StubErrorHandler(),
|
||||
)
|
||||
|
||||
return handler, event_manager, graph_execution
|
||||
|
||||
|
||||
def test_retry_does_not_emit_additional_start_event() -> None:
|
||||
"""Ensure retry attempts do not produce duplicate start events."""
|
||||
|
||||
node_id = "test-node"
|
||||
handler, event_manager, graph_execution = _build_event_handler(node_id)
|
||||
|
||||
execution_id = "exec-1"
|
||||
node_type = NodeType.CODE
|
||||
start_time = datetime.utcnow()
|
||||
|
||||
start_event = NodeRunStartedEvent(
|
||||
id=execution_id,
|
||||
node_id=node_id,
|
||||
node_type=node_type,
|
||||
node_title="Stub Node",
|
||||
start_at=start_time,
|
||||
)
|
||||
handler.dispatch(start_event)
|
||||
|
||||
retry_event = NodeRunRetryEvent(
|
||||
id=execution_id,
|
||||
node_id=node_id,
|
||||
node_type=node_type,
|
||||
node_title="Stub Node",
|
||||
start_at=start_time,
|
||||
error="boom",
|
||||
retry_index=1,
|
||||
node_run_result=NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.FAILED,
|
||||
error="boom",
|
||||
error_type="TestError",
|
||||
),
|
||||
)
|
||||
handler.dispatch(retry_event)
|
||||
|
||||
# Simulate the node starting execution again after retry
|
||||
second_start_event = NodeRunStartedEvent(
|
||||
id=execution_id,
|
||||
node_id=node_id,
|
||||
node_type=node_type,
|
||||
node_title="Stub Node",
|
||||
start_at=start_time,
|
||||
)
|
||||
handler.dispatch(second_start_event)
|
||||
|
||||
collected_types = [type(event) for event in event_manager._events] # type: ignore[attr-defined]
|
||||
|
||||
assert collected_types == [NodeRunStartedEvent, NodeRunRetryEvent]
|
||||
|
||||
node_execution = graph_execution.get_or_create_node_execution(node_id)
|
||||
assert node_execution.retry_count == 1
|
||||
|
|
@ -10,11 +10,18 @@ import time
|
|||
from hypothesis import HealthCheck, given, settings
|
||||
from hypothesis import strategies as st
|
||||
|
||||
from core.workflow.enums import ErrorStrategy
|
||||
from core.workflow.graph_engine import GraphEngine
|
||||
from core.workflow.graph_engine.command_channels import InMemoryChannel
|
||||
from core.workflow.graph_events import GraphRunStartedEvent, GraphRunSucceededEvent
|
||||
from core.workflow.graph_events import (
|
||||
GraphRunPartialSucceededEvent,
|
||||
GraphRunStartedEvent,
|
||||
GraphRunSucceededEvent,
|
||||
)
|
||||
from core.workflow.nodes.base.entities import DefaultValue, DefaultValueType
|
||||
|
||||
# Import the test framework from the new module
|
||||
from .test_mock_config import MockConfigBuilder
|
||||
from .test_table_runner import TableTestRunner, WorkflowRunner, WorkflowTestCase
|
||||
|
||||
|
||||
|
|
@ -721,3 +728,39 @@ def test_event_sequence_validation_with_table_tests():
|
|||
else:
|
||||
assert result.event_sequence_match is True
|
||||
assert result.success, f"Test {i + 1} failed: {result.event_mismatch_details or result.error}"
|
||||
|
||||
|
||||
def test_graph_run_emits_partial_success_when_node_failure_recovered():
|
||||
runner = TableTestRunner()
|
||||
|
||||
fixture_data = runner.workflow_runner.load_fixture("basic_chatflow")
|
||||
mock_config = MockConfigBuilder().with_node_error("llm", "mock llm failure").build()
|
||||
|
||||
graph, graph_runtime_state = runner.workflow_runner.create_graph_from_fixture(
|
||||
fixture_data=fixture_data,
|
||||
query="hello",
|
||||
use_mock_factory=True,
|
||||
mock_config=mock_config,
|
||||
)
|
||||
|
||||
llm_node = graph.nodes["llm"]
|
||||
base_node_data = llm_node.get_base_node_data()
|
||||
base_node_data.error_strategy = ErrorStrategy.DEFAULT_VALUE
|
||||
base_node_data.default_value = [DefaultValue(key="text", value="fallback response", type=DefaultValueType.STRING)]
|
||||
|
||||
engine = GraphEngine(
|
||||
workflow_id="test_workflow",
|
||||
graph=graph,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
command_channel=InMemoryChannel(),
|
||||
)
|
||||
|
||||
events = list(engine.run())
|
||||
|
||||
assert isinstance(events[-1], GraphRunPartialSucceededEvent)
|
||||
|
||||
partial_event = next(event for event in events if isinstance(event, GraphRunPartialSucceededEvent))
|
||||
assert partial_event.exceptions_count == 1
|
||||
assert partial_event.outputs.get("answer") == "fallback response"
|
||||
|
||||
assert not any(isinstance(event, GraphRunSucceededEvent) for event in events)
|
||||
|
|
|
|||
|
|
@ -1,65 +0,0 @@
|
|||
import pytest
|
||||
|
||||
pytest.skip(
|
||||
"Retry functionality is part of Phase 2 enhanced error handling - not implemented in MVP of queue-based engine",
|
||||
allow_module_level=True,
|
||||
)
|
||||
|
||||
DEFAULT_VALUE_EDGE = [
|
||||
{
|
||||
"id": "start-source-node-target",
|
||||
"source": "start",
|
||||
"target": "node",
|
||||
"sourceHandle": "source",
|
||||
},
|
||||
{
|
||||
"id": "node-source-answer-target",
|
||||
"source": "node",
|
||||
"target": "answer",
|
||||
"sourceHandle": "source",
|
||||
},
|
||||
]
|
||||
|
||||
|
||||
def test_retry_default_value_partial_success():
|
||||
"""retry default value node with partial success status"""
|
||||
graph_config = {
|
||||
"edges": DEFAULT_VALUE_EDGE,
|
||||
"nodes": [
|
||||
{"data": {"title": "start", "type": "start", "variables": []}, "id": "start"},
|
||||
{"data": {"title": "answer", "type": "answer", "answer": "{{#node.result#}}"}, "id": "answer"},
|
||||
ContinueOnErrorTestHelper.get_http_node(
|
||||
"default-value",
|
||||
[{"key": "result", "type": "string", "value": "http node got error response"}],
|
||||
retry_config={"retry_config": {"max_retries": 2, "retry_interval": 1000, "retry_enabled": True}},
|
||||
),
|
||||
],
|
||||
}
|
||||
|
||||
graph_engine = ContinueOnErrorTestHelper.create_test_graph_engine(graph_config)
|
||||
events = list(graph_engine.run())
|
||||
assert sum(1 for e in events if isinstance(e, NodeRunRetryEvent)) == 2
|
||||
assert events[-1].outputs == {"answer": "http node got error response"}
|
||||
assert any(isinstance(e, GraphRunPartialSucceededEvent) for e in events)
|
||||
assert len(events) == 11
|
||||
|
||||
|
||||
def test_retry_failed():
|
||||
"""retry failed with success status"""
|
||||
graph_config = {
|
||||
"edges": DEFAULT_VALUE_EDGE,
|
||||
"nodes": [
|
||||
{"data": {"title": "start", "type": "start", "variables": []}, "id": "start"},
|
||||
{"data": {"title": "answer", "type": "answer", "answer": "{{#node.result#}}"}, "id": "answer"},
|
||||
ContinueOnErrorTestHelper.get_http_node(
|
||||
None,
|
||||
None,
|
||||
retry_config={"retry_config": {"max_retries": 2, "retry_interval": 1000, "retry_enabled": True}},
|
||||
),
|
||||
],
|
||||
}
|
||||
graph_engine = ContinueOnErrorTestHelper.create_test_graph_engine(graph_config)
|
||||
events = list(graph_engine.run())
|
||||
assert sum(1 for e in events if isinstance(e, NodeRunRetryEvent)) == 2
|
||||
assert any(isinstance(e, GraphRunFailedEvent) for e in events)
|
||||
assert len(events) == 8
|
||||
|
|
@ -1,8 +1,8 @@
|
|||
import urllib.parse
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
import requests
|
||||
|
||||
from libs.oauth import GitHubOAuth, GoogleOAuth, OAuthUserInfo
|
||||
|
||||
|
|
@ -68,7 +68,7 @@ class TestGitHubOAuth(BaseOAuthTest):
|
|||
({}, None, True),
|
||||
],
|
||||
)
|
||||
@patch("requests.post")
|
||||
@patch("httpx.post")
|
||||
def test_should_retrieve_access_token(
|
||||
self, mock_post, oauth, mock_response, response_data, expected_token, should_raise
|
||||
):
|
||||
|
|
@ -105,7 +105,7 @@ class TestGitHubOAuth(BaseOAuthTest):
|
|||
),
|
||||
],
|
||||
)
|
||||
@patch("requests.get")
|
||||
@patch("httpx.get")
|
||||
def test_should_retrieve_user_info_correctly(self, mock_get, oauth, user_data, email_data, expected_email):
|
||||
user_response = MagicMock()
|
||||
user_response.json.return_value = user_data
|
||||
|
|
@ -121,11 +121,11 @@ class TestGitHubOAuth(BaseOAuthTest):
|
|||
assert user_info.name == user_data["name"]
|
||||
assert user_info.email == expected_email
|
||||
|
||||
@patch("requests.get")
|
||||
@patch("httpx.get")
|
||||
def test_should_handle_network_errors(self, mock_get, oauth):
|
||||
mock_get.side_effect = requests.exceptions.RequestException("Network error")
|
||||
mock_get.side_effect = httpx.RequestError("Network error")
|
||||
|
||||
with pytest.raises(requests.exceptions.RequestException):
|
||||
with pytest.raises(httpx.RequestError):
|
||||
oauth.get_raw_user_info("test_token")
|
||||
|
||||
|
||||
|
|
@ -167,7 +167,7 @@ class TestGoogleOAuth(BaseOAuthTest):
|
|||
({}, None, True),
|
||||
],
|
||||
)
|
||||
@patch("requests.post")
|
||||
@patch("httpx.post")
|
||||
def test_should_retrieve_access_token(
|
||||
self, mock_post, oauth, oauth_config, mock_response, response_data, expected_token, should_raise
|
||||
):
|
||||
|
|
@ -201,7 +201,7 @@ class TestGoogleOAuth(BaseOAuthTest):
|
|||
({"sub": "123", "email": "test@example.com", "name": "Test User"}, ""), # Always returns empty string
|
||||
],
|
||||
)
|
||||
@patch("requests.get")
|
||||
@patch("httpx.get")
|
||||
def test_should_retrieve_user_info_correctly(self, mock_get, oauth, mock_response, user_data, expected_name):
|
||||
mock_response.json.return_value = user_data
|
||||
mock_get.return_value = mock_response
|
||||
|
|
@ -217,12 +217,12 @@ class TestGoogleOAuth(BaseOAuthTest):
|
|||
@pytest.mark.parametrize(
|
||||
"exception_type",
|
||||
[
|
||||
requests.exceptions.HTTPError,
|
||||
requests.exceptions.ConnectionError,
|
||||
requests.exceptions.Timeout,
|
||||
httpx.HTTPError,
|
||||
httpx.ConnectError,
|
||||
httpx.TimeoutException,
|
||||
],
|
||||
)
|
||||
@patch("requests.get")
|
||||
@patch("httpx.get")
|
||||
def test_should_handle_http_errors(self, mock_get, oauth, exception_type):
|
||||
mock_response = MagicMock()
|
||||
mock_response.raise_for_status.side_effect = exception_type("Error")
|
||||
|
|
|
|||
|
|
@ -6,8 +6,8 @@ import json
|
|||
from concurrent.futures import ThreadPoolExecutor
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
import requests
|
||||
|
||||
from services.auth.api_key_auth_factory import ApiKeyAuthFactory
|
||||
from services.auth.api_key_auth_service import ApiKeyAuthService
|
||||
|
|
@ -26,7 +26,7 @@ class TestAuthIntegration:
|
|||
self.watercrawl_credentials = {"auth_type": "x-api-key", "config": {"api_key": "wc_test_key_789"}}
|
||||
|
||||
@patch("services.auth.api_key_auth_service.db.session")
|
||||
@patch("services.auth.firecrawl.firecrawl.requests.post")
|
||||
@patch("services.auth.firecrawl.firecrawl.httpx.post")
|
||||
@patch("services.auth.api_key_auth_service.encrypter.encrypt_token")
|
||||
def test_end_to_end_auth_flow(self, mock_encrypt, mock_http, mock_session):
|
||||
"""Test complete authentication flow: request → validation → encryption → storage"""
|
||||
|
|
@ -47,7 +47,7 @@ class TestAuthIntegration:
|
|||
mock_session.add.assert_called_once()
|
||||
mock_session.commit.assert_called_once()
|
||||
|
||||
@patch("services.auth.firecrawl.firecrawl.requests.post")
|
||||
@patch("services.auth.firecrawl.firecrawl.httpx.post")
|
||||
def test_cross_component_integration(self, mock_http):
|
||||
"""Test factory → provider → HTTP call integration"""
|
||||
mock_http.return_value = self._create_success_response()
|
||||
|
|
@ -97,7 +97,7 @@ class TestAuthIntegration:
|
|||
assert "another_secret" not in factory_str
|
||||
|
||||
@patch("services.auth.api_key_auth_service.db.session")
|
||||
@patch("services.auth.firecrawl.firecrawl.requests.post")
|
||||
@patch("services.auth.firecrawl.firecrawl.httpx.post")
|
||||
@patch("services.auth.api_key_auth_service.encrypter.encrypt_token")
|
||||
def test_concurrent_creation_safety(self, mock_encrypt, mock_http, mock_session):
|
||||
"""Test concurrent authentication creation safety"""
|
||||
|
|
@ -142,31 +142,31 @@ class TestAuthIntegration:
|
|||
with pytest.raises((ValueError, KeyError, TypeError, AttributeError)):
|
||||
ApiKeyAuthFactory(AuthType.FIRECRAWL, invalid_input)
|
||||
|
||||
@patch("services.auth.firecrawl.firecrawl.requests.post")
|
||||
@patch("services.auth.firecrawl.firecrawl.httpx.post")
|
||||
def test_http_error_handling(self, mock_http):
|
||||
"""Test proper HTTP error handling"""
|
||||
mock_response = Mock()
|
||||
mock_response.status_code = 401
|
||||
mock_response.text = '{"error": "Unauthorized"}'
|
||||
mock_response.raise_for_status.side_effect = requests.exceptions.HTTPError("Unauthorized")
|
||||
mock_response.raise_for_status.side_effect = httpx.HTTPError("Unauthorized")
|
||||
mock_http.return_value = mock_response
|
||||
|
||||
# PT012: Split into single statement for pytest.raises
|
||||
factory = ApiKeyAuthFactory(AuthType.FIRECRAWL, self.firecrawl_credentials)
|
||||
with pytest.raises((requests.exceptions.HTTPError, Exception)):
|
||||
with pytest.raises((httpx.HTTPError, Exception)):
|
||||
factory.validate_credentials()
|
||||
|
||||
@patch("services.auth.api_key_auth_service.db.session")
|
||||
@patch("services.auth.firecrawl.firecrawl.requests.post")
|
||||
@patch("services.auth.firecrawl.firecrawl.httpx.post")
|
||||
def test_network_failure_recovery(self, mock_http, mock_session):
|
||||
"""Test system recovery from network failures"""
|
||||
mock_http.side_effect = requests.exceptions.RequestException("Network timeout")
|
||||
mock_http.side_effect = httpx.RequestError("Network timeout")
|
||||
mock_session.add = Mock()
|
||||
mock_session.commit = Mock()
|
||||
|
||||
args = {"category": self.category, "provider": AuthType.FIRECRAWL, "credentials": self.firecrawl_credentials}
|
||||
|
||||
with pytest.raises(requests.exceptions.RequestException):
|
||||
with pytest.raises(httpx.RequestError):
|
||||
ApiKeyAuthService.create_provider_auth(self.tenant_id_1, args)
|
||||
|
||||
mock_session.commit.assert_not_called()
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
import requests
|
||||
|
||||
from services.auth.firecrawl.firecrawl import FirecrawlAuth
|
||||
|
||||
|
|
@ -64,7 +64,7 @@ class TestFirecrawlAuth:
|
|||
FirecrawlAuth(credentials)
|
||||
assert str(exc_info.value) == expected_error
|
||||
|
||||
@patch("services.auth.firecrawl.firecrawl.requests.post")
|
||||
@patch("services.auth.firecrawl.firecrawl.httpx.post")
|
||||
def test_should_validate_valid_credentials_successfully(self, mock_post, auth_instance):
|
||||
"""Test successful credential validation"""
|
||||
mock_response = MagicMock()
|
||||
|
|
@ -95,7 +95,7 @@ class TestFirecrawlAuth:
|
|||
(500, "Internal server error"),
|
||||
],
|
||||
)
|
||||
@patch("services.auth.firecrawl.firecrawl.requests.post")
|
||||
@patch("services.auth.firecrawl.firecrawl.httpx.post")
|
||||
def test_should_handle_http_errors(self, mock_post, status_code, error_message, auth_instance):
|
||||
"""Test handling of various HTTP error codes"""
|
||||
mock_response = MagicMock()
|
||||
|
|
@ -115,7 +115,7 @@ class TestFirecrawlAuth:
|
|||
(401, "Not JSON", True, "Expecting value"), # JSON decode error
|
||||
],
|
||||
)
|
||||
@patch("services.auth.firecrawl.firecrawl.requests.post")
|
||||
@patch("services.auth.firecrawl.firecrawl.httpx.post")
|
||||
def test_should_handle_unexpected_errors(
|
||||
self, mock_post, status_code, response_text, has_json_error, expected_error_contains, auth_instance
|
||||
):
|
||||
|
|
@ -134,13 +134,13 @@ class TestFirecrawlAuth:
|
|||
@pytest.mark.parametrize(
|
||||
("exception_type", "exception_message"),
|
||||
[
|
||||
(requests.ConnectionError, "Network error"),
|
||||
(requests.Timeout, "Request timeout"),
|
||||
(requests.ReadTimeout, "Read timeout"),
|
||||
(requests.ConnectTimeout, "Connection timeout"),
|
||||
(httpx.ConnectError, "Network error"),
|
||||
(httpx.TimeoutException, "Request timeout"),
|
||||
(httpx.ReadTimeout, "Read timeout"),
|
||||
(httpx.ConnectTimeout, "Connection timeout"),
|
||||
],
|
||||
)
|
||||
@patch("services.auth.firecrawl.firecrawl.requests.post")
|
||||
@patch("services.auth.firecrawl.firecrawl.httpx.post")
|
||||
def test_should_handle_network_errors(self, mock_post, exception_type, exception_message, auth_instance):
|
||||
"""Test handling of various network-related errors including timeouts"""
|
||||
mock_post.side_effect = exception_type(exception_message)
|
||||
|
|
@ -162,7 +162,7 @@ class TestFirecrawlAuth:
|
|||
FirecrawlAuth({"auth_type": "basic", "config": {"api_key": "super_secret_key_12345"}})
|
||||
assert "super_secret_key_12345" not in str(exc_info.value)
|
||||
|
||||
@patch("services.auth.firecrawl.firecrawl.requests.post")
|
||||
@patch("services.auth.firecrawl.firecrawl.httpx.post")
|
||||
def test_should_use_custom_base_url_in_validation(self, mock_post):
|
||||
"""Test that custom base URL is used in validation"""
|
||||
mock_response = MagicMock()
|
||||
|
|
@ -179,12 +179,12 @@ class TestFirecrawlAuth:
|
|||
assert result is True
|
||||
assert mock_post.call_args[0][0] == "https://custom.firecrawl.dev/v1/crawl"
|
||||
|
||||
@patch("services.auth.firecrawl.firecrawl.requests.post")
|
||||
@patch("services.auth.firecrawl.firecrawl.httpx.post")
|
||||
def test_should_handle_timeout_with_retry_suggestion(self, mock_post, auth_instance):
|
||||
"""Test that timeout errors are handled gracefully with appropriate error message"""
|
||||
mock_post.side_effect = requests.Timeout("The request timed out after 30 seconds")
|
||||
mock_post.side_effect = httpx.TimeoutException("The request timed out after 30 seconds")
|
||||
|
||||
with pytest.raises(requests.Timeout) as exc_info:
|
||||
with pytest.raises(httpx.TimeoutException) as exc_info:
|
||||
auth_instance.validate_credentials()
|
||||
|
||||
# Verify the timeout exception is raised with original message
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
import requests
|
||||
|
||||
from services.auth.jina.jina import JinaAuth
|
||||
|
||||
|
|
@ -35,7 +35,7 @@ class TestJinaAuth:
|
|||
JinaAuth(credentials)
|
||||
assert str(exc_info.value) == "No API key provided"
|
||||
|
||||
@patch("services.auth.jina.jina.requests.post")
|
||||
@patch("services.auth.jina.jina.httpx.post")
|
||||
def test_should_validate_valid_credentials_successfully(self, mock_post):
|
||||
"""Test successful credential validation"""
|
||||
mock_response = MagicMock()
|
||||
|
|
@ -53,7 +53,7 @@ class TestJinaAuth:
|
|||
json={"url": "https://example.com"},
|
||||
)
|
||||
|
||||
@patch("services.auth.jina.jina.requests.post")
|
||||
@patch("services.auth.jina.jina.httpx.post")
|
||||
def test_should_handle_http_402_error(self, mock_post):
|
||||
"""Test handling of 402 Payment Required error"""
|
||||
mock_response = MagicMock()
|
||||
|
|
@ -68,7 +68,7 @@ class TestJinaAuth:
|
|||
auth.validate_credentials()
|
||||
assert str(exc_info.value) == "Failed to authorize. Status code: 402. Error: Payment required"
|
||||
|
||||
@patch("services.auth.jina.jina.requests.post")
|
||||
@patch("services.auth.jina.jina.httpx.post")
|
||||
def test_should_handle_http_409_error(self, mock_post):
|
||||
"""Test handling of 409 Conflict error"""
|
||||
mock_response = MagicMock()
|
||||
|
|
@ -83,7 +83,7 @@ class TestJinaAuth:
|
|||
auth.validate_credentials()
|
||||
assert str(exc_info.value) == "Failed to authorize. Status code: 409. Error: Conflict error"
|
||||
|
||||
@patch("services.auth.jina.jina.requests.post")
|
||||
@patch("services.auth.jina.jina.httpx.post")
|
||||
def test_should_handle_http_500_error(self, mock_post):
|
||||
"""Test handling of 500 Internal Server Error"""
|
||||
mock_response = MagicMock()
|
||||
|
|
@ -98,7 +98,7 @@ class TestJinaAuth:
|
|||
auth.validate_credentials()
|
||||
assert str(exc_info.value) == "Failed to authorize. Status code: 500. Error: Internal server error"
|
||||
|
||||
@patch("services.auth.jina.jina.requests.post")
|
||||
@patch("services.auth.jina.jina.httpx.post")
|
||||
def test_should_handle_unexpected_error_with_text_response(self, mock_post):
|
||||
"""Test handling of unexpected errors with text response"""
|
||||
mock_response = MagicMock()
|
||||
|
|
@ -114,7 +114,7 @@ class TestJinaAuth:
|
|||
auth.validate_credentials()
|
||||
assert str(exc_info.value) == "Failed to authorize. Status code: 403. Error: Forbidden"
|
||||
|
||||
@patch("services.auth.jina.jina.requests.post")
|
||||
@patch("services.auth.jina.jina.httpx.post")
|
||||
def test_should_handle_unexpected_error_without_text(self, mock_post):
|
||||
"""Test handling of unexpected errors without text response"""
|
||||
mock_response = MagicMock()
|
||||
|
|
@ -130,15 +130,15 @@ class TestJinaAuth:
|
|||
auth.validate_credentials()
|
||||
assert str(exc_info.value) == "Unexpected error occurred while trying to authorize. Status code: 404"
|
||||
|
||||
@patch("services.auth.jina.jina.requests.post")
|
||||
@patch("services.auth.jina.jina.httpx.post")
|
||||
def test_should_handle_network_errors(self, mock_post):
|
||||
"""Test handling of network connection errors"""
|
||||
mock_post.side_effect = requests.ConnectionError("Network error")
|
||||
mock_post.side_effect = httpx.ConnectError("Network error")
|
||||
|
||||
credentials = {"auth_type": "bearer", "config": {"api_key": "test_api_key_123"}}
|
||||
auth = JinaAuth(credentials)
|
||||
|
||||
with pytest.raises(requests.ConnectionError):
|
||||
with pytest.raises(httpx.ConnectError):
|
||||
auth.validate_credentials()
|
||||
|
||||
def test_should_not_expose_api_key_in_error_messages(self):
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
import requests
|
||||
|
||||
from services.auth.watercrawl.watercrawl import WatercrawlAuth
|
||||
|
||||
|
|
@ -64,7 +64,7 @@ class TestWatercrawlAuth:
|
|||
WatercrawlAuth(credentials)
|
||||
assert str(exc_info.value) == expected_error
|
||||
|
||||
@patch("services.auth.watercrawl.watercrawl.requests.get")
|
||||
@patch("services.auth.watercrawl.watercrawl.httpx.get")
|
||||
def test_should_validate_valid_credentials_successfully(self, mock_get, auth_instance):
|
||||
"""Test successful credential validation"""
|
||||
mock_response = MagicMock()
|
||||
|
|
@ -87,7 +87,7 @@ class TestWatercrawlAuth:
|
|||
(500, "Internal server error"),
|
||||
],
|
||||
)
|
||||
@patch("services.auth.watercrawl.watercrawl.requests.get")
|
||||
@patch("services.auth.watercrawl.watercrawl.httpx.get")
|
||||
def test_should_handle_http_errors(self, mock_get, status_code, error_message, auth_instance):
|
||||
"""Test handling of various HTTP error codes"""
|
||||
mock_response = MagicMock()
|
||||
|
|
@ -107,7 +107,7 @@ class TestWatercrawlAuth:
|
|||
(401, "Not JSON", True, "Expecting value"), # JSON decode error
|
||||
],
|
||||
)
|
||||
@patch("services.auth.watercrawl.watercrawl.requests.get")
|
||||
@patch("services.auth.watercrawl.watercrawl.httpx.get")
|
||||
def test_should_handle_unexpected_errors(
|
||||
self, mock_get, status_code, response_text, has_json_error, expected_error_contains, auth_instance
|
||||
):
|
||||
|
|
@ -126,13 +126,13 @@ class TestWatercrawlAuth:
|
|||
@pytest.mark.parametrize(
|
||||
("exception_type", "exception_message"),
|
||||
[
|
||||
(requests.ConnectionError, "Network error"),
|
||||
(requests.Timeout, "Request timeout"),
|
||||
(requests.ReadTimeout, "Read timeout"),
|
||||
(requests.ConnectTimeout, "Connection timeout"),
|
||||
(httpx.ConnectError, "Network error"),
|
||||
(httpx.TimeoutException, "Request timeout"),
|
||||
(httpx.ReadTimeout, "Read timeout"),
|
||||
(httpx.ConnectTimeout, "Connection timeout"),
|
||||
],
|
||||
)
|
||||
@patch("services.auth.watercrawl.watercrawl.requests.get")
|
||||
@patch("services.auth.watercrawl.watercrawl.httpx.get")
|
||||
def test_should_handle_network_errors(self, mock_get, exception_type, exception_message, auth_instance):
|
||||
"""Test handling of various network-related errors including timeouts"""
|
||||
mock_get.side_effect = exception_type(exception_message)
|
||||
|
|
@ -154,7 +154,7 @@ class TestWatercrawlAuth:
|
|||
WatercrawlAuth({"auth_type": "bearer", "config": {"api_key": "super_secret_key_12345"}})
|
||||
assert "super_secret_key_12345" not in str(exc_info.value)
|
||||
|
||||
@patch("services.auth.watercrawl.watercrawl.requests.get")
|
||||
@patch("services.auth.watercrawl.watercrawl.httpx.get")
|
||||
def test_should_use_custom_base_url_in_validation(self, mock_get):
|
||||
"""Test that custom base URL is used in validation"""
|
||||
mock_response = MagicMock()
|
||||
|
|
@ -179,7 +179,7 @@ class TestWatercrawlAuth:
|
|||
("https://app.watercrawl.dev//", "https://app.watercrawl.dev/api/v1/core/crawl-requests/"),
|
||||
],
|
||||
)
|
||||
@patch("services.auth.watercrawl.watercrawl.requests.get")
|
||||
@patch("services.auth.watercrawl.watercrawl.httpx.get")
|
||||
def test_should_use_urljoin_for_url_construction(self, mock_get, base_url, expected_url):
|
||||
"""Test that urljoin is used correctly for URL construction with various base URLs"""
|
||||
mock_response = MagicMock()
|
||||
|
|
@@ -193,12 +193,12 @@ class TestWatercrawlAuth:
# Verify the correct URL was called
assert mock_get.call_args[0][0] == expected_url

@patch("services.auth.watercrawl.watercrawl.requests.get")
@patch("services.auth.watercrawl.watercrawl.httpx.get")
def test_should_handle_timeout_with_retry_suggestion(self, mock_get, auth_instance):
"""Test that timeout errors are handled gracefully with appropriate error message"""
mock_get.side_effect = requests.Timeout("The request timed out after 30 seconds")
mock_get.side_effect = httpx.TimeoutException("The request timed out after 30 seconds")

with pytest.raises(requests.Timeout) as exc_info:
with pytest.raises(httpx.TimeoutException) as exc_info:
auth_instance.validate_credentials()

# Verify the timeout exception is raised with original message

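Note (editor's sketch, not part of the diff): assembled from the hunks above, the migrated timeout test would read roughly as follows once httpx fully replaces requests in these mocks. The import path and the auth_instance fixture body are assumptions inferred from the patch targets and the constructor call shown earlier in this file.

# Editorial sketch of the post-migration test; names not shown in the diff are assumed.
from unittest.mock import patch

import httpx
import pytest

from services.auth.watercrawl.watercrawl import WatercrawlAuth  # assumed import path


@pytest.fixture
def auth_instance():
    # Assumed fixture: mirrors the constructor call visible in the earlier hunk.
    return WatercrawlAuth({"auth_type": "bearer", "config": {"api_key": "test_api_key"}})


class TestWatercrawlAuth:
    @patch("services.auth.watercrawl.watercrawl.httpx.get")
    def test_should_handle_timeout_with_retry_suggestion(self, mock_get, auth_instance):
        """Timeouts raised by httpx.get should propagate with their original message."""
        mock_get.side_effect = httpx.TimeoutException("The request timed out after 30 seconds")

        with pytest.raises(httpx.TimeoutException) as exc_info:
            auth_instance.validate_credentials()

        assert "timed out" in str(exc_info.value)
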
@@ -1,5 +1,5 @@
version = 1
revision = 2
revision = 3
requires-python = ">=3.11, <3.13"
resolution-markers = [
"python_full_version >= '3.12.4' and sys_platform == 'linux'",

@@ -1273,7 +1273,7 @@ wheels = [

[[package]]
name = "dify-api"
version = "2.0.0b2"
version = "1.9.0"
source = { virtual = "." }
dependencies = [
{ name = "arize-phoenix-otel" },

@@ -2,7 +2,7 @@ x-shared-env: &shared-api-worker-env
services:
# API service
api:
image: langgenius/dify-api:2.0.0-beta.2
image: langgenius/dify-api:1.9.0
restart: always
environment:
# Use the shared environment variables.

@@ -31,7 +31,7 @@ services:
# worker service
# The Celery worker for processing the queue.
worker:
image: langgenius/dify-api:2.0.0-beta.2
image: langgenius/dify-api:1.9.0
restart: always
environment:
# Use the shared environment variables.

@@ -58,7 +58,7 @@ services:
# worker_beat service
# Celery beat for scheduling periodic tasks.
worker_beat:
image: langgenius/dify-api:2.0.0-beta.2
image: langgenius/dify-api:1.9.0
restart: always
environment:
# Use the shared environment variables.

@@ -76,7 +76,7 @@ services:

# Frontend web application.
web:
image: langgenius/dify-web:2.0.0-beta.2
image: langgenius/dify-web:1.9.0
restart: always
environment:
CONSOLE_API_URL: ${CONSOLE_API_URL:-}

@@ -177,7 +177,7 @@ services:

# plugin daemon
plugin_daemon:
image: langgenius/dify-plugin-daemon:0.3.0b1-local
image: langgenius/dify-plugin-daemon:0.3.0-local
restart: always
environment:
# Use the shared environment variables.

@@ -20,7 +20,17 @@ services:
ports:
- "${EXPOSE_POSTGRES_PORT:-5432}:5432"
healthcheck:
test: [ 'CMD', 'pg_isready', '-h', 'db', '-U', '${PGUSER:-postgres}', '-d', '${POSTGRES_DB:-dify}' ]
test:
[
"CMD",
"pg_isready",
"-h",
"db",
"-U",
"${PGUSER:-postgres}",
"-d",
"${POSTGRES_DB:-dify}",
]
interval: 1s
timeout: 3s
retries: 30

@@ -41,7 +51,11 @@ services:
ports:
- "${EXPOSE_REDIS_PORT:-6379}:6379"
healthcheck:
test: [ 'CMD-SHELL', 'redis-cli -a ${REDIS_PASSWORD:-difyai123456} ping | grep -q PONG' ]
test:
[
"CMD-SHELL",
"redis-cli -a ${REDIS_PASSWORD:-difyai123456} ping | grep -q PONG",
]

# The DifySandbox
sandbox:

@@ -65,13 +79,13 @@ services:
- ./volumes/sandbox/dependencies:/dependencies
- ./volumes/sandbox/conf:/conf
healthcheck:
test: [ "CMD", "curl", "-f", "http://localhost:8194/health" ]
test: ["CMD", "curl", "-f", "http://localhost:8194/health"]
networks:
- ssrf_proxy_network

# plugin daemon
plugin_daemon:
image: langgenius/dify-plugin-daemon:0.3.0b1-local
image: langgenius/dify-plugin-daemon:0.3.0-local
restart: always
env_file:
- ./middleware.env

@@ -143,7 +157,12 @@ services:
volumes:
- ./ssrf_proxy/squid.conf.template:/etc/squid/squid.conf.template
- ./ssrf_proxy/docker-entrypoint.sh:/docker-entrypoint-mount.sh
entrypoint: [ "sh", "-c", "cp /docker-entrypoint-mount.sh /docker-entrypoint.sh && sed -i 's/\r$$//' /docker-entrypoint.sh && chmod +x /docker-entrypoint.sh && /docker-entrypoint.sh" ]
entrypoint:
[
"sh",
"-c",
"cp /docker-entrypoint-mount.sh /docker-entrypoint.sh && sed -i 's/\r$$//' /docker-entrypoint.sh && chmod +x /docker-entrypoint.sh && /docker-entrypoint.sh",
]
env_file:
- ./middleware.env
environment:

@@ -593,7 +593,7 @@ x-shared-env: &shared-api-worker-env
services:
# API service
api:
image: langgenius/dify-api:2.0.0-beta.2
image: langgenius/dify-api:1.9.0
restart: always
environment:
# Use the shared environment variables.

@@ -622,7 +622,7 @@ services:
# worker service
# The Celery worker for processing the queue.
worker:
image: langgenius/dify-api:2.0.0-beta.2
image: langgenius/dify-api:1.9.0
restart: always
environment:
# Use the shared environment variables.

@@ -649,7 +649,7 @@ services:
# worker_beat service
# Celery beat for scheduling periodic tasks.
worker_beat:
image: langgenius/dify-api:2.0.0-beta.2
image: langgenius/dify-api:1.9.0
restart: always
environment:
# Use the shared environment variables.

@@ -667,7 +667,7 @@ services:

# Frontend web application.
web:
image: langgenius/dify-web:2.0.0-beta.2
image: langgenius/dify-web:1.9.0
restart: always
environment:
CONSOLE_API_URL: ${CONSOLE_API_URL:-}

@@ -768,7 +768,7 @@ services:

# plugin daemon
plugin_daemon:
image: langgenius/dify-plugin-daemon:0.3.0b1-local
image: langgenius/dify-plugin-daemon:0.3.0-local
restart: always
environment:
# Use the shared environment variables.

@@ -124,7 +124,7 @@ const useConfig = (id: string, payload: VariableAssignerNodeType) => {
const handleAddGroup = useCallback(() => {
let maxInGroupName = 1
inputs.advanced_settings.groups.forEach((item) => {
const match = item.group_name.match(/(\d+)$/)
const match = /(\d+)$/.exec(item.group_name)
if (match) {
const num = Number.parseInt(match[1], 10)
if (num > maxInGroupName)

@@ -1,6 +1,6 @@
{
"name": "dify-web",
"version": "2.0.0-beta2",
"version": "1.9.0",
"private": true,
"packageManager": "pnpm@10.16.0",
"engines": {