feat: move model request to plugin daemon

This commit is contained in:
takatost 2024-09-29 00:14:44 +08:00
parent 1c3213184e
commit 47c8824be6
29 changed files with 127 additions and 118 deletions

View File

@ -114,20 +114,18 @@ class PluginConfig(BaseSettings):
"""
Plugin configs
"""
PLUGIN_API_URL: HttpUrl = Field(
description='Plugin API URL',
default='http://plugin:5002',
description="Plugin API URL",
default="http://plugin:5002",
)
PLUGIN_API_KEY: str = Field(
description='Plugin API key',
default='plugin-api-key',
description="Plugin API key",
default="plugin-api-key",
)
INNER_API_KEY_FOR_PLUGIN: str = Field(
description='Inner api key for plugin',
default='inner-api-key'
)
INNER_API_KEY_FOR_PLUGIN: str = Field(description="Inner api key for plugin", default="inner-api-key")
class EndpointConfig(BaseSettings):

View File

@ -140,7 +140,7 @@ class DraftWorkflowImportApi(Resource):
# The role of the current user in the ta table must be admin, owner, or editor
if not current_user.is_editor:
raise Forbidden()
if not isinstance(current_user, Account):
raise Forbidden()
@ -167,7 +167,7 @@ class AdvancedChatDraftWorkflowRunApi(Resource):
# The role of the current user in the ta table must be admin, owner, or editor
if not current_user.is_editor:
raise Forbidden()
if not isinstance(current_user, Account):
raise Forbidden()
@ -209,7 +209,7 @@ class AdvancedChatDraftRunIterationNodeApi(Resource):
# The role of the current user in the ta table must be admin, owner, or editor
if not current_user.is_editor:
raise Forbidden()
if not isinstance(current_user, Account):
raise Forbidden()
@ -246,7 +246,7 @@ class WorkflowDraftRunIterationNodeApi(Resource):
# The role of the current user in the ta table must be admin, owner, or editor
if not current_user.is_editor:
raise Forbidden()
if not isinstance(current_user, Account):
raise Forbidden()
@ -283,7 +283,7 @@ class DraftWorkflowRunApi(Resource):
# The role of the current user in the ta table must be admin, owner, or editor
if not current_user.is_editor:
raise Forbidden()
if not isinstance(current_user, Account):
raise Forbidden()
@ -336,7 +336,7 @@ class DraftWorkflowNodeRunApi(Resource):
# The role of the current user in the ta table must be admin, owner, or editor
if not current_user.is_editor:
raise Forbidden()
if not isinstance(current_user, Account):
raise Forbidden()
@ -388,7 +388,7 @@ class PublishedWorkflowApi(Resource):
# The role of the current user in the ta table must be admin, owner, or editor
if not current_user.is_editor:
raise Forbidden()
if not isinstance(current_user, Account):
raise Forbidden()
@ -428,7 +428,7 @@ class DefaultBlockConfigApi(Resource):
# The role of the current user in the ta table must be admin, owner, or editor
if not current_user.is_editor:
raise Forbidden()
if not isinstance(current_user, Account):
raise Forbidden()
@ -464,7 +464,7 @@ class ConvertToWorkflowApi(Resource):
# The role of the current user in the ta table must be admin, owner, or editor
if not current_user.is_editor:
raise Forbidden()
if not isinstance(current_user, Account):
raise Forbidden()

View File

@ -17,11 +17,9 @@ class PluginDebuggingKeyApi(Resource):
user = current_user
if not user.is_admin_or_owner:
raise Forbidden()
tenant_id = user.current_tenant_id
return {
"key": PluginDebuggingService.get_plugin_debugging_key(tenant_id)
}
return {"key": PluginDebuggingService.get_plugin_debugging_key(tenant_id)}
api.add_resource(PluginDebuggingKeyApi, "/workspaces/current/plugin/debugging-key")
api.add_resource(PluginDebuggingKeyApi, "/workspaces/current/plugin/debugging-key")

View File

@ -1 +1 @@
from .plugin import *
from .plugin import *

View File

@ -16,31 +16,36 @@ def get_tenant(view: Optional[Callable] = None):
def decorated_view(*args, **kwargs):
# fetch json body
parser = reqparse.RequestParser()
parser.add_argument('tenant_id', type=str, required=True, location='json')
parser.add_argument('user_id', type=str, required=True, location='json')
parser.add_argument("tenant_id", type=str, required=True, location="json")
parser.add_argument("user_id", type=str, required=True, location="json")
kwargs = parser.parse_args()
user_id = kwargs.get('user_id')
tenant_id = kwargs.get('tenant_id')
user_id = kwargs.get("user_id")
tenant_id = kwargs.get("tenant_id")
del kwargs['tenant_id']
del kwargs['user_id']
del kwargs["tenant_id"]
del kwargs["user_id"]
try:
tenant_model = db.session.query(Tenant).filter(
Tenant.id == tenant_id,
).first()
tenant_model = (
db.session.query(Tenant)
.filter(
Tenant.id == tenant_id,
)
.first()
)
except Exception:
raise ValueError('tenant not found')
raise ValueError("tenant not found")
if not tenant_model:
raise ValueError('tenant not found')
raise ValueError("tenant not found")
kwargs['tenant_model'] = tenant_model
kwargs['user_id'] = user_id
kwargs["tenant_model"] = tenant_model
kwargs["user_id"] = user_id
return view_func(*args, **kwargs)
return decorated_view
if view is None:
@ -55,18 +60,18 @@ def plugin_data(view: Optional[Callable] = None, *, payload_type: type[BaseModel
try:
data = request.get_json()
except Exception:
raise ValueError('invalid json')
raise ValueError("invalid json")
try:
payload = payload_type(**data)
except Exception as e:
raise ValueError(f'invalid payload: {str(e)}')
kwargs['payload'] = payload
raise ValueError(f"invalid payload: {str(e)}")
kwargs["payload"] = payload
return view_func(*args, **kwargs)
return decorated_view
if view is None:
return decorator
else:

View File

@ -10,6 +10,7 @@ class AgentToolEntity(BaseModel):
"""
Agent Tool Entity.
"""
provider_type: ToolProviderType
provider_id: str
tool_name: str

View File

@ -366,9 +366,7 @@ class FunctionCallAgentRunner(BaseAgentRunner):
return tool_calls
def _init_system_message(
self, prompt_template: str, prompt_messages: list[PromptMessage]
) -> list[PromptMessage]:
def _init_system_message(self, prompt_template: str, prompt_messages: list[PromptMessage]) -> list[PromptMessage]:
"""
Initialize system message
"""

View File

@ -50,8 +50,9 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter):
return response
@classmethod
def convert_stream_full_response(cls, stream_response: Generator[ChatbotAppStreamResponse, None, None]) \
-> Generator[dict | str, None, None]:
def convert_stream_full_response(
cls, stream_response: Generator[ChatbotAppStreamResponse, None, None]
) -> Generator[dict | str, None, None]:
"""
Convert stream full response.
:param stream_response: stream response
@ -80,8 +81,9 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter):
yield response_chunk
@classmethod
def convert_stream_simple_response(cls, stream_response: Generator[ChatbotAppStreamResponse, None, None]) \
-> Generator[dict | str, None, None]:
def convert_stream_simple_response(
cls, stream_response: Generator[ChatbotAppStreamResponse, None, None]
) -> Generator[dict | str, None, None]:
"""
Convert stream simple response.
:param stream_response: stream response

View File

@ -20,6 +20,7 @@ class AppGenerateResponseConverter(ABC):
if isinstance(response, AppBlockingResponse):
return cls.convert_blocking_full_response(response)
else:
def _generate_full_response() -> Generator[dict | str, Any, None]:
yield from cls.convert_stream_simple_response(response)
@ -28,6 +29,7 @@ class AppGenerateResponseConverter(ABC):
if isinstance(response, AppBlockingResponse):
return cls.convert_blocking_simple_response(response)
else:
def _generate_simple_response() -> Generator[dict | str, Any, None]:
yield from cls.convert_stream_simple_response(response)
@ -45,8 +47,9 @@ class AppGenerateResponseConverter(ABC):
@classmethod
@abstractmethod
def convert_stream_full_response(cls, stream_response: Generator[AppStreamResponse, None, None]) \
-> Generator[dict | str, None, None]:
def convert_stream_full_response(
cls, stream_response: Generator[AppStreamResponse, None, None]
) -> Generator[dict | str, None, None]:
raise NotImplementedError
@classmethod

View File

@ -64,11 +64,12 @@ class BaseAppGenerator:
if isinstance(generator, dict):
return generator
else:
def gen():
for message in generator:
if isinstance(message, dict):
yield f'data: {json.dumps(message)}\n\n'
yield f"data: {json.dumps(message)}\n\n"
else:
yield f'event: {message}\n\n'
return gen()
yield f"event: {message}\n\n"
return gen()

View File

@ -50,7 +50,8 @@ class ChatAppGenerator(MessageBasedAppGenerator):
@overload
def generate(
self, app_model: App,
self,
app_model: App,
user: Union[Account, EndUser],
args: Any,
invoke_from: InvokeFrom,

View File

@ -50,8 +50,9 @@ class ChatAppGenerateResponseConverter(AppGenerateResponseConverter):
return response
@classmethod
def convert_stream_full_response(cls, stream_response: Generator[ChatbotAppStreamResponse, None, None]) \
-> Generator[dict | str, None, None]:
def convert_stream_full_response(
cls, stream_response: Generator[ChatbotAppStreamResponse, None, None]
) -> Generator[dict | str, None, None]:
"""
Convert stream full response.
:param stream_response: stream response
@ -80,8 +81,9 @@ class ChatAppGenerateResponseConverter(AppGenerateResponseConverter):
yield response_chunk
@classmethod
def convert_stream_simple_response(cls, stream_response: Generator[ChatbotAppStreamResponse, None, None]) \
-> Generator[dict | str, None, None]:
def convert_stream_simple_response(
cls, stream_response: Generator[ChatbotAppStreamResponse, None, None]
) -> Generator[dict | str, None, None]:
"""
Convert stream simple response.
:param stream_response: stream response

View File

@ -52,19 +52,17 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
@overload
def generate(
self, app_model: App,
self,
app_model: App,
user: Union[Account, EndUser],
args: dict,
invoke_from: InvokeFrom,
stream: bool = False,
) -> dict | Generator[str, None, None]: ...
def generate(self, app_model: App,
user: Union[Account, EndUser],
args: Any,
invoke_from: InvokeFrom,
stream: bool = True) \
-> Union[dict, Generator[str, None, None]]:
def generate(
self, app_model: App, user: Union[Account, EndUser], args: Any, invoke_from: InvokeFrom, stream: bool = True
) -> Union[dict, Generator[str, None, None]]:
"""
Generate App response.

View File

@ -49,8 +49,9 @@ class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter):
return response
@classmethod
def convert_stream_full_response(cls, stream_response: Generator[CompletionAppStreamResponse, None, None]) \
-> Generator[dict | str, None, None]:
def convert_stream_full_response(
cls, stream_response: Generator[CompletionAppStreamResponse, None, None]
) -> Generator[dict | str, None, None]:
"""
Convert stream full response.
:param stream_response: stream response
@ -78,8 +79,9 @@ class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter):
yield response_chunk
@classmethod
def convert_stream_simple_response(cls, stream_response: Generator[CompletionAppStreamResponse, None, None]) \
-> Generator[dict | str, None, None]:
def convert_stream_simple_response(
cls, stream_response: Generator[CompletionAppStreamResponse, None, None]
) -> Generator[dict | str, None, None]:
"""
Convert stream simple response.
:param stream_response: stream response

View File

@ -42,7 +42,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
invoke_from: InvokeFrom,
stream: Literal[True] = True,
call_depth: int = 0,
workflow_thread_pool_id: Optional[str] = None
workflow_thread_pool_id: Optional[str] = None,
) -> Generator[dict | str, None, None]: ...
@overload
@ -60,7 +60,8 @@ class WorkflowAppGenerator(BaseAppGenerator):
@overload
def generate(
self, app_model: App,
self,
app_model: App,
workflow: Workflow,
user: Union[Account, EndUser],
args: dict,
@ -143,7 +144,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
application_generate_entity: WorkflowAppGenerateEntity,
invoke_from: InvokeFrom,
stream: bool = True,
workflow_thread_pool_id: Optional[str] = None
workflow_thread_pool_id: Optional[str] = None,
) -> Union[dict, Generator[str | dict, None, None]]:
"""
Generate App response.
@ -189,12 +190,9 @@ class WorkflowAppGenerator(BaseAppGenerator):
return WorkflowAppGenerateResponseConverter.convert(response=response, invoke_from=invoke_from)
def single_iteration_generate(self, app_model: App,
workflow: Workflow,
node_id: str,
user: Account | EndUser,
args: dict,
stream: bool = True) -> dict[str, Any] | Generator[str | dict, Any, None]:
def single_iteration_generate(
self, app_model: App, workflow: Workflow, node_id: str, user: Account | EndUser, args: dict, stream: bool = True
) -> dict[str, Any] | Generator[str | dict, Any, None]:
"""
Generate App response.

View File

@ -34,8 +34,9 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter):
return cls.convert_blocking_full_response(blocking_response)
@classmethod
def convert_stream_full_response(cls, stream_response: Generator[WorkflowAppStreamResponse, None, None]) \
-> Generator[dict | str, None, None]:
def convert_stream_full_response(
cls, stream_response: Generator[WorkflowAppStreamResponse, None, None]
) -> Generator[dict | str, None, None]:
"""
Convert stream full response.
:param stream_response: stream response
@ -62,8 +63,9 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter):
yield response_chunk
@classmethod
def convert_stream_simple_response(cls, stream_response: Generator[WorkflowAppStreamResponse, None, None]) \
-> Generator[dict | str, None, None]:
def convert_stream_simple_response(
cls, stream_response: Generator[WorkflowAppStreamResponse, None, None]
) -> Generator[dict | str, None, None]:
"""
Convert stream simple response.
:param stream_response: stream response

View File

@ -2,4 +2,4 @@ from core.callback_handler.agent_tool_callback_handler import DifyAgentCallbackH
class DifyPluginCallbackHandler(DifyAgentCallbackHandler):
"""Callback Handler that prints to std out."""
"""Callback Handler that prints to std out."""

View File

@ -27,4 +27,4 @@ class ModelConfigScope(Enum):
TTS = "tts"
SPEECH2TEXT = "speech2text"
MODERATION = "moderation"
VISION = "vision"
VISION = "vision"

View File

@ -116,6 +116,7 @@ class BasicProviderConfig(BaseModel):
"""
Base model class for common provider settings like credentials
"""
class Type(Enum):
SECRET_INPUT = CommonParameterType.SECRET_INPUT.value
TEXT_INPUT = CommonParameterType.TEXT_INPUT.value
@ -135,7 +136,7 @@ class BasicProviderConfig(BaseModel):
for mode in cls:
if mode.value == value:
return mode
raise ValueError(f'invalid mode value {value}')
raise ValueError(f"invalid mode value {value}")
type: Type = Field(..., description="The type of the credentials")
name: str = Field(..., description="The name of the credentials")
@ -145,6 +146,7 @@ class ProviderConfig(BasicProviderConfig):
"""
Model class for common provider settings like credentials
"""
class Option(BaseModel):
value: str = Field(..., description="The value of the option")
label: I18nObject = Field(..., description="The label of the option")

View File

@ -3,9 +3,7 @@ from typing import TYPE_CHECKING, Any
if TYPE_CHECKING:
from core.tools.tool_file_manager import ToolFileManager
tool_file_manager: dict[str, Any] = {
'manager': None
}
tool_file_manager: dict[str, Any] = {"manager": None}
class ToolFileParser:

View File

@ -114,4 +114,4 @@ class PluginNodeBackwardsInvocation(BaseBackwardsInvocation):
"inputs": execution.inputs_dict,
"outputs": execution.outputs_dict,
"process_data": execution.process_data_dict,
}
}

View File

@ -492,8 +492,9 @@ class DatasetRetrieval:
score_threshold_enabled = retrieval_model_config.get("score_threshold_enabled")
if score_threshold_enabled:
score_threshold = retrieval_model_config.get("score_threshold")
from core.tools.utils.dataset_retriever.dataset_retriever_tool import DatasetRetrieverTool
tool = DatasetRetrieverTool.from_dataset(
dataset=dataset,
top_k=top_k,
@ -506,6 +507,7 @@ class DatasetRetrieval:
tools.append(tool)
elif retrieve_config.retrieve_strategy == DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE:
from core.tools.utils.dataset_retriever.dataset_multi_retriever_tool import DatasetMultiRetrieverTool
tool = DatasetMultiRetrieverTool.from_dataset(
dataset_ids=[dataset.id for dataset in available_datasets],
tenant_id=tenant_id,

View File

@ -375,6 +375,3 @@ class ToolInvokeFrom(Enum):
WORKFLOW = "workflow"
AGENT = "agent"

View File

@ -43,4 +43,4 @@ class WorkflowToolConfigurationUtils:
for parameter in tool_configurations:
if parameter.name not in variable_names:
raise ValueError('parameter configuration mismatch, please republish the tool to update')
raise ValueError("parameter configuration mismatch, please republish the tool to update")

View File

@ -65,7 +65,7 @@ class WorkflowToolProviderController(ToolProviderController):
@property
def provider_type(self) -> ToolProviderType:
return ToolProviderType.WORKFLOW
def _get_db_provider_tool(self, db_provider: WorkflowToolProvider, app: App) -> WorkflowTool:
"""
get db provider tool
@ -73,10 +73,11 @@ class WorkflowToolProviderController(ToolProviderController):
:param app: the app
:return: the tool
"""
workflow: Workflow | None = db.session.query(Workflow).filter(
Workflow.app_id == db_provider.app_id,
Workflow.version == db_provider.version
).first()
workflow: Workflow | None = (
db.session.query(Workflow)
.filter(Workflow.app_id == db_provider.app_id, Workflow.version == db_provider.version)
.first()
)
if not workflow:
raise ValueError("workflow not found")
@ -84,10 +85,7 @@ class WorkflowToolProviderController(ToolProviderController):
# fetch start node
graph: Mapping = workflow.graph_dict
features_dict: Mapping = workflow.features_dict
features = WorkflowAppConfigManager.convert_features(
config_dict=features_dict,
app_mode=AppMode.WORKFLOW
)
features = WorkflowAppConfigManager.convert_features(config_dict=features_dict, app_mode=AppMode.WORKFLOW)
parameters = db_provider.parameter_configurations
variables = WorkflowToolConfigurationUtils.get_workflow_graph_variables(graph)
@ -180,14 +178,18 @@ class WorkflowToolProviderController(ToolProviderController):
if self.tools is not None:
return self.tools
db_providers: WorkflowToolProvider | None = db.session.query(WorkflowToolProvider).filter(
WorkflowToolProvider.tenant_id == tenant_id,
WorkflowToolProvider.app_id == self.provider_id,
).first()
db_providers: WorkflowToolProvider | None = (
db.session.query(WorkflowToolProvider)
.filter(
WorkflowToolProvider.tenant_id == tenant_id,
WorkflowToolProvider.app_id == self.provider_id,
)
.first()
)
if not db_providers:
return []
app = db_providers.app
if not app:
raise ValueError("can not read app of workflow")

View File

@ -2,4 +2,4 @@ from sqlalchemy.orm import DeclarativeBase
class Base(DeclarativeBase):
pass
pass

View File

@ -161,7 +161,7 @@ class ApiToolManageService:
tenant_id=tenant_id,
config=provider_controller.get_credentials_schema(),
provider_type=provider_controller.provider_type.value,
provider_identity=provider_controller.entity.identity.name
provider_identity=provider_controller.entity.identity.name,
)
encrypted_credentials = tool_configuration.encrypt(credentials)
@ -293,7 +293,7 @@ class ApiToolManageService:
tenant_id=tenant_id,
config=provider_controller.get_credentials_schema(),
provider_type=provider_controller.provider_type.value,
provider_identity=provider_controller.entity.identity.name
provider_identity=provider_controller.entity.identity.name,
)
original_credentials = tool_configuration.decrypt(provider.credentials)
@ -412,7 +412,7 @@ class ApiToolManageService:
tenant_id=tenant_id,
config=provider_controller.get_credentials_schema(),
provider_type=provider_controller.provider_type.value,
provider_identity=provider_controller.entity.identity.name
provider_identity=provider_controller.entity.identity.name,
)
decrypted_credentials = tool_configuration.decrypt(credentials)
# check if the credential has changed, save the original credential

View File

@ -238,7 +238,7 @@ class WorkflowService:
db.session.commit()
return workflow_node_execution
def run_free_workflow_node(
self, node_data: dict, tenant_id: str, user_id: str, node_id: str, user_inputs: dict[str, Any]
) -> WorkflowNodeExecution:
@ -258,7 +258,7 @@ class WorkflowService:
),
start_at=start_at,
tenant_id=tenant_id,
node_id=node_id
node_id=node_id,
)
return workflow_node_execution

View File

@ -6,4 +6,3 @@ def test_fetch_all_plugin_tools(setup_http_mock):
manager = PluginToolManager()
tools = manager.fetch_tool_providers(tenant_id="test-tenant")
assert len(tools) >= 1