Merge branch 'refs/heads/main' into feat/workflow-parallel-support

# Conflicts:
#	api/core/app/apps/advanced_chat/app_generator.py
#	api/core/app/apps/advanced_chat/generate_task_pipeline.py
#	api/core/app/apps/workflow/app_runner.py
#	api/core/app/apps/workflow/generate_task_pipeline.py
#	api/core/app/task_pipeline/workflow_cycle_state_manager.py
#	api/core/workflow/entities/variable_pool.py
#	api/core/workflow/nodes/code/code_node.py
#	api/core/workflow/nodes/llm/llm_node.py
#	api/core/workflow/nodes/start/start_node.py
#	api/core/workflow/nodes/variable_assigner/__init__.py
#	api/tests/integration_tests/workflow/nodes/test_llm.py
#	api/tests/integration_tests/workflow/nodes/test_parameter_extractor.py
#	api/tests/unit_tests/core/workflow/nodes/test_answer.py
#	api/tests/unit_tests/core/workflow/nodes/test_if_else.py
#	api/tests/unit_tests/core/workflow/nodes/test_variable_assigner.py
This commit is contained in:
takatost 2024-08-21 16:59:23 +08:00
commit 35be41b337
252 changed files with 5991 additions and 2942 deletions

1
.gitignore vendored
View File

@ -178,3 +178,4 @@ pyrightconfig.json
api/.vscode api/.vscode
.idea/ .idea/
.vscode

View File

@ -267,4 +267,13 @@ APP_MAX_ACTIVE_REQUESTS=0
# Celery beat configuration # Celery beat configuration
CELERY_BEAT_SCHEDULER_TIME=1 CELERY_BEAT_SCHEDULER_TIME=1
# Position configuration
POSITION_TOOL_PINS=
POSITION_TOOL_INCLUDES=
POSITION_TOOL_EXCLUDES=
POSITION_PROVIDER_PINS=
POSITION_PROVIDER_INCLUDES=
POSITION_PROVIDER_EXCLUDES=

View File

Before

Width:  |  Height:  |  Size: 1.7 KiB

After

Width:  |  Height:  |  Size: 1.7 KiB

View File

View File

@ -5,8 +5,8 @@
"name": "Python: Flask", "name": "Python: Flask",
"type": "debugpy", "type": "debugpy",
"request": "launch", "request": "launch",
"python": "${workspaceFolder}/api/.venv/bin/python", "python": "${workspaceFolder}/.venv/bin/python",
"cwd": "${workspaceFolder}/api", "cwd": "${workspaceFolder}",
"envFile": ".env", "envFile": ".env",
"module": "flask", "module": "flask",
"justMyCode": true, "justMyCode": true,
@ -18,15 +18,15 @@
"args": [ "args": [
"run", "run",
"--host=0.0.0.0", "--host=0.0.0.0",
"--port=5001", "--port=5001"
] ]
}, },
{ {
"name": "Python: Celery", "name": "Python: Celery",
"type": "debugpy", "type": "debugpy",
"request": "launch", "request": "launch",
"python": "${workspaceFolder}/api/.venv/bin/python", "python": "${workspaceFolder}/.venv/bin/python",
"cwd": "${workspaceFolder}/api", "cwd": "${workspaceFolder}",
"module": "celery", "module": "celery",
"justMyCode": true, "justMyCode": true,
"envFile": ".env", "envFile": ".env",

View File

@ -37,6 +37,8 @@ class DifyConfig(
CODE_MAX_NUMBER: int = 9223372036854775807 CODE_MAX_NUMBER: int = 9223372036854775807
CODE_MIN_NUMBER: int = -9223372036854775808 CODE_MIN_NUMBER: int = -9223372036854775808
CODE_MAX_DEPTH: int = 5
CODE_MAX_PRECISION: int = 20
CODE_MAX_STRING_LENGTH: int = 80000 CODE_MAX_STRING_LENGTH: int = 80000
CODE_MAX_STRING_ARRAY_LENGTH: int = 30 CODE_MAX_STRING_ARRAY_LENGTH: int = 30
CODE_MAX_OBJECT_ARRAY_LENGTH: int = 30 CODE_MAX_OBJECT_ARRAY_LENGTH: int = 30

View File

@ -406,6 +406,7 @@ class DataSetConfig(BaseSettings):
default=False, default=False,
) )
class WorkspaceConfig(BaseSettings): class WorkspaceConfig(BaseSettings):
""" """
Workspace configs Workspace configs
@ -442,6 +443,63 @@ class CeleryBeatConfig(BaseSettings):
) )
class PositionConfig(BaseSettings):
POSITION_PROVIDER_PINS: str = Field(
description='The heads of model providers',
default='',
)
POSITION_PROVIDER_INCLUDES: str = Field(
description='The included model providers',
default='',
)
POSITION_PROVIDER_EXCLUDES: str = Field(
description='The excluded model providers',
default='',
)
POSITION_TOOL_PINS: str = Field(
description='The heads of tools',
default='',
)
POSITION_TOOL_INCLUDES: str = Field(
description='The included tools',
default='',
)
POSITION_TOOL_EXCLUDES: str = Field(
description='The excluded tools',
default='',
)
@computed_field
def POSITION_PROVIDER_PINS_LIST(self) -> list[str]:
return [item.strip() for item in self.POSITION_PROVIDER_PINS.split(',') if item.strip() != '']
@computed_field
def POSITION_PROVIDER_INCLUDES_SET(self) -> set[str]:
return {item.strip() for item in self.POSITION_PROVIDER_INCLUDES.split(',') if item.strip() != ''}
@computed_field
def POSITION_PROVIDER_EXCLUDES_SET(self) -> set[str]:
return {item.strip() for item in self.POSITION_PROVIDER_EXCLUDES.split(',') if item.strip() != ''}
@computed_field
def POSITION_TOOL_PINS_LIST(self) -> list[str]:
return [item.strip() for item in self.POSITION_TOOL_PINS.split(',') if item.strip() != '']
@computed_field
def POSITION_TOOL_INCLUDES_SET(self) -> set[str]:
return {item.strip() for item in self.POSITION_TOOL_INCLUDES.split(',') if item.strip() != ''}
@computed_field
def POSITION_TOOL_EXCLUDES_SET(self) -> set[str]:
return {item.strip() for item in self.POSITION_TOOL_EXCLUDES.split(',') if item.strip() != ''}
class FeatureConfig( class FeatureConfig(
# place the configs in alphabet order # place the configs in alphabet order
AppExecutionConfig, AppExecutionConfig,
@ -466,6 +524,7 @@ class FeatureConfig(
UpdateConfig, UpdateConfig,
WorkflowConfig, WorkflowConfig,
WorkspaceConfig, WorkspaceConfig,
PositionConfig,
# hosted services config # hosted services config
HostedServiceConfig, HostedServiceConfig,

View File

@ -9,7 +9,7 @@ class PackagingInfo(BaseSettings):
CURRENT_VERSION: str = Field( CURRENT_VERSION: str = Field(
description='Dify version', description='Dify version',
default='0.7.0', default='0.7.1',
) )
COMMIT_SHA: str = Field( COMMIT_SHA: str = Field(

View File

@ -61,6 +61,7 @@ class AppListApi(Resource):
parser.add_argument('name', type=str, required=True, location='json') parser.add_argument('name', type=str, required=True, location='json')
parser.add_argument('description', type=str, location='json') parser.add_argument('description', type=str, location='json')
parser.add_argument('mode', type=str, choices=ALLOW_CREATE_APP_MODES, location='json') parser.add_argument('mode', type=str, choices=ALLOW_CREATE_APP_MODES, location='json')
parser.add_argument('icon_type', type=str, location='json')
parser.add_argument('icon', type=str, location='json') parser.add_argument('icon', type=str, location='json')
parser.add_argument('icon_background', type=str, location='json') parser.add_argument('icon_background', type=str, location='json')
args = parser.parse_args() args = parser.parse_args()
@ -94,6 +95,7 @@ class AppImportApi(Resource):
parser.add_argument('data', type=str, required=True, nullable=False, location='json') parser.add_argument('data', type=str, required=True, nullable=False, location='json')
parser.add_argument('name', type=str, location='json') parser.add_argument('name', type=str, location='json')
parser.add_argument('description', type=str, location='json') parser.add_argument('description', type=str, location='json')
parser.add_argument('icon_type', type=str, location='json')
parser.add_argument('icon', type=str, location='json') parser.add_argument('icon', type=str, location='json')
parser.add_argument('icon_background', type=str, location='json') parser.add_argument('icon_background', type=str, location='json')
args = parser.parse_args() args = parser.parse_args()
@ -167,6 +169,7 @@ class AppApi(Resource):
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument('name', type=str, required=True, nullable=False, location='json') parser.add_argument('name', type=str, required=True, nullable=False, location='json')
parser.add_argument('description', type=str, location='json') parser.add_argument('description', type=str, location='json')
parser.add_argument('icon_type', type=str, location='json')
parser.add_argument('icon', type=str, location='json') parser.add_argument('icon', type=str, location='json')
parser.add_argument('icon_background', type=str, location='json') parser.add_argument('icon_background', type=str, location='json')
parser.add_argument('max_active_requests', type=int, location='json') parser.add_argument('max_active_requests', type=int, location='json')
@ -208,6 +211,7 @@ class AppCopyApi(Resource):
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument('name', type=str, location='json') parser.add_argument('name', type=str, location='json')
parser.add_argument('description', type=str, location='json') parser.add_argument('description', type=str, location='json')
parser.add_argument('icon_type', type=str, location='json')
parser.add_argument('icon', type=str, location='json') parser.add_argument('icon', type=str, location='json')
parser.add_argument('icon_background', type=str, location='json') parser.add_argument('icon_background', type=str, location='json')
args = parser.parse_args() args = parser.parse_args()

View File

@ -154,6 +154,8 @@ class ChatConversationApi(Resource):
parser.add_argument('message_count_gte', type=int_range(1, 99999), required=False, location='args') parser.add_argument('message_count_gte', type=int_range(1, 99999), required=False, location='args')
parser.add_argument('page', type=int_range(1, 99999), required=False, default=1, location='args') parser.add_argument('page', type=int_range(1, 99999), required=False, default=1, location='args')
parser.add_argument('limit', type=int_range(1, 100), required=False, default=20, location='args') parser.add_argument('limit', type=int_range(1, 100), required=False, default=20, location='args')
parser.add_argument('sort_by', type=str, choices=['created_at', '-created_at', 'updated_at', '-updated_at'],
required=False, default='-updated_at', location='args')
args = parser.parse_args() args = parser.parse_args()
subquery = ( subquery = (
@ -225,7 +227,17 @@ class ChatConversationApi(Resource):
if app_model.mode == AppMode.ADVANCED_CHAT.value: if app_model.mode == AppMode.ADVANCED_CHAT.value:
query = query.where(Conversation.invoke_from != InvokeFrom.DEBUGGER.value) query = query.where(Conversation.invoke_from != InvokeFrom.DEBUGGER.value)
query = query.order_by(Conversation.created_at.desc()) match args['sort_by']:
case 'created_at':
query = query.order_by(Conversation.created_at.asc())
case '-created_at':
query = query.order_by(Conversation.created_at.desc())
case 'updated_at':
query = query.order_by(Conversation.updated_at.asc())
case '-updated_at':
query = query.order_by(Conversation.updated_at.desc())
case _:
query = query.order_by(Conversation.created_at.desc())
conversations = db.paginate( conversations = db.paginate(
query, query,

View File

@ -16,6 +16,7 @@ from models.model import Site
def parse_app_site_args(): def parse_app_site_args():
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument('title', type=str, required=False, location='json') parser.add_argument('title', type=str, required=False, location='json')
parser.add_argument('icon_type', type=str, required=False, location='json')
parser.add_argument('icon', type=str, required=False, location='json') parser.add_argument('icon', type=str, required=False, location='json')
parser.add_argument('icon_background', type=str, required=False, location='json') parser.add_argument('icon_background', type=str, required=False, location='json')
parser.add_argument('description', type=str, required=False, location='json') parser.add_argument('description', type=str, required=False, location='json')
@ -53,6 +54,7 @@ class AppSite(Resource):
for attr_name in [ for attr_name in [
'title', 'title',
'icon_type',
'icon', 'icon',
'icon_background', 'icon_background',
'description', 'description',

View File

@ -460,6 +460,7 @@ class ConvertToWorkflowApi(Resource):
if request.data: if request.data:
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument('name', type=str, required=False, nullable=True, location='json') parser.add_argument('name', type=str, required=False, nullable=True, location='json')
parser.add_argument('icon_type', type=str, required=False, nullable=True, location='json')
parser.add_argument('icon', type=str, required=False, nullable=True, location='json') parser.add_argument('icon', type=str, required=False, nullable=True, location='json')
parser.add_argument('icon_background', type=str, required=False, nullable=True, location='json') parser.add_argument('icon_background', type=str, required=False, nullable=True, location='json')
args = parser.parse_args() args = parser.parse_args()

View File

@ -573,13 +573,13 @@ class DatasetRetrievalSettingMockApi(Resource):
@account_initialization_required @account_initialization_required
def get(self, vector_type): def get(self, vector_type):
match vector_type: match vector_type:
case VectorType.MILVUS | VectorType.RELYT | VectorType.PGVECTOR | VectorType.TIDB_VECTOR | VectorType.CHROMA | VectorType.TENCENT: case VectorType.MILVUS | VectorType.RELYT | VectorType.TIDB_VECTOR | VectorType.CHROMA | VectorType.TENCENT | VectorType.PGVECTO_RS:
return { return {
'retrieval_method': [ 'retrieval_method': [
RetrievalMethod.SEMANTIC_SEARCH.value RetrievalMethod.SEMANTIC_SEARCH.value
] ]
} }
case VectorType.QDRANT | VectorType.WEAVIATE | VectorType.OPENSEARCH| VectorType.ANALYTICDB | VectorType.MYSCALE | VectorType.ORACLE | VectorType.ELASTICSEARCH: case VectorType.QDRANT | VectorType.WEAVIATE | VectorType.OPENSEARCH | VectorType.ANALYTICDB | VectorType.MYSCALE | VectorType.ORACLE | VectorType.ELASTICSEARCH | VectorType.PGVECTOR:
return { return {
'retrieval_method': [ 'retrieval_method': [
RetrievalMethod.SEMANTIC_SEARCH.value, RetrievalMethod.SEMANTIC_SEARCH.value,

View File

@ -25,6 +25,8 @@ class ConversationApi(Resource):
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument('last_id', type=uuid_value, location='args') parser.add_argument('last_id', type=uuid_value, location='args')
parser.add_argument('limit', type=int_range(1, 100), required=False, default=20, location='args') parser.add_argument('limit', type=int_range(1, 100), required=False, default=20, location='args')
parser.add_argument('sort_by', type=str, choices=['created_at', '-created_at', 'updated_at', '-updated_at'],
required=False, default='-updated_at', location='args')
args = parser.parse_args() args = parser.parse_args()
try: try:
@ -33,7 +35,8 @@ class ConversationApi(Resource):
user=end_user, user=end_user,
last_id=args['last_id'], last_id=args['last_id'],
limit=args['limit'], limit=args['limit'],
invoke_from=InvokeFrom.SERVICE_API invoke_from=InvokeFrom.SERVICE_API,
sort_by=args['sort_by']
) )
except services.errors.conversation.LastConversationNotExistsError: except services.errors.conversation.LastConversationNotExistsError:
raise NotFound("Last Conversation Not Exists.") raise NotFound("Last Conversation Not Exists.")

View File

@ -53,19 +53,22 @@ class SegmentApi(DatasetApiResource):
raise ProviderNotInitializeError( raise ProviderNotInitializeError(
"No Embedding Model available. Please configure a valid provider " "No Embedding Model available. Please configure a valid provider "
"in the Settings -> Model Provider.") "in the Settings -> Model Provider.")
except ProviderTokenNotInitError as ex: except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description) raise ProviderNotInitializeError(ex.description)
# validate args # validate args
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument('segments', type=list, required=False, nullable=True, location='json') parser.add_argument('segments', type=list, required=False, nullable=True, location='json')
args = parser.parse_args() args = parser.parse_args()
for args_item in args['segments']: if args['segments'] is not None:
SegmentService.segment_create_args_validate(args_item, document) for args_item in args['segments']:
segments = SegmentService.multi_create_segment(args['segments'], document, dataset) SegmentService.segment_create_args_validate(args_item, document)
return { segments = SegmentService.multi_create_segment(args['segments'], document, dataset)
'data': marshal(segments, segment_fields), return {
'doc_form': document.doc_form 'data': marshal(segments, segment_fields),
}, 200 'doc_form': document.doc_form
}, 200
else:
return {"error": "Segemtns is required"}, 400
def get(self, tenant_id, dataset_id, document_id): def get(self, tenant_id, dataset_id, document_id):
"""Create single segment.""" """Create single segment."""

View File

@ -26,6 +26,8 @@ class ConversationListApi(WebApiResource):
parser.add_argument('last_id', type=uuid_value, location='args') parser.add_argument('last_id', type=uuid_value, location='args')
parser.add_argument('limit', type=int_range(1, 100), required=False, default=20, location='args') parser.add_argument('limit', type=int_range(1, 100), required=False, default=20, location='args')
parser.add_argument('pinned', type=str, choices=['true', 'false', None], location='args') parser.add_argument('pinned', type=str, choices=['true', 'false', None], location='args')
parser.add_argument('sort_by', type=str, choices=['created_at', '-created_at', 'updated_at', '-updated_at'],
required=False, default='-updated_at', location='args')
args = parser.parse_args() args = parser.parse_args()
pinned = None pinned = None
@ -40,6 +42,7 @@ class ConversationListApi(WebApiResource):
limit=args['limit'], limit=args['limit'],
invoke_from=InvokeFrom.WEB_APP, invoke_from=InvokeFrom.WEB_APP,
pinned=pinned, pinned=pinned,
sort_by=args['sort_by']
) )
except LastConversationNotExistsError: except LastConversationNotExistsError:
raise NotFound("Last Conversation Not Exists.") raise NotFound("Last Conversation Not Exists.")

View File

@ -6,6 +6,7 @@ from configs import dify_config
from controllers.web import api from controllers.web import api
from controllers.web.wraps import WebApiResource from controllers.web.wraps import WebApiResource
from extensions.ext_database import db from extensions.ext_database import db
from libs.helper import AppIconUrlField
from models.account import TenantStatus from models.account import TenantStatus
from models.model import Site from models.model import Site
from services.feature_service import FeatureService from services.feature_service import FeatureService
@ -28,8 +29,10 @@ class AppSiteApi(WebApiResource):
'title': fields.String, 'title': fields.String,
'chat_color_theme': fields.String, 'chat_color_theme': fields.String,
'chat_color_theme_inverted': fields.Boolean, 'chat_color_theme_inverted': fields.Boolean,
'icon_type': fields.String,
'icon': fields.String, 'icon': fields.String,
'icon_background': fields.String, 'icon_background': fields.String,
'icon_url': AppIconUrlField,
'description': fields.String, 'description': fields.String,
'copyright': fields.String, 'copyright': fields.String,
'privacy_policy': fields.String, 'privacy_policy': fields.String,

View File

@ -64,15 +64,19 @@ class BaseAgentRunner(AppRunner):
""" """
Agent runner Agent runner
:param tenant_id: tenant id :param tenant_id: tenant id
:param application_generate_entity: application generate entity
:param conversation: conversation
:param app_config: app generate entity :param app_config: app generate entity
:param model_config: model config :param model_config: model config
:param config: dataset config :param config: dataset config
:param queue_manager: queue manager :param queue_manager: queue manager
:param message: message :param message: message
:param user_id: user id :param user_id: user id
:param agent_llm_callback: agent llm callback
:param callback: callback
:param memory: memory :param memory: memory
:param prompt_messages: prompt messages
:param variables_pool: variables pool
:param db_variables: db variables
:param model_instance: model instance
""" """
self.tenant_id = tenant_id self.tenant_id = tenant_id
self.application_generate_entity = application_generate_entity self.application_generate_entity = application_generate_entity
@ -445,7 +449,7 @@ class BaseAgentRunner(AppRunner):
try: try:
tool_responses = json.loads(agent_thought.observation) tool_responses = json.loads(agent_thought.observation)
except Exception as e: except Exception as e:
tool_responses = { tool: agent_thought.observation for tool in tools } tool_responses = dict.fromkeys(tools, agent_thought.observation)
for tool in tools: for tool in tools:
# generate a uuid for tool call # generate a uuid for tool call

View File

@ -292,6 +292,8 @@ class CotAgentRunner(BaseAgentRunner, ABC):
handle invoke action handle invoke action
:param action: action :param action: action
:param tool_instances: tool instances :param tool_instances: tool instances
:param message_file_ids: message file ids
:param trace_manager: trace manager
:return: observation, meta :return: observation, meta
""" """
# action is tool call, invoke tool # action is tool call, invoke tool

View File

@ -93,7 +93,7 @@ class DatasetConfigManager:
reranking_model=dataset_configs.get('reranking_model'), reranking_model=dataset_configs.get('reranking_model'),
weights=dataset_configs.get('weights'), weights=dataset_configs.get('weights'),
reranking_enabled=dataset_configs.get('reranking_enabled', True), reranking_enabled=dataset_configs.get('reranking_enabled', True),
rerank_mode=dataset_configs["reranking_mode"], rerank_mode=dataset_configs.get('rerank_mode', 'reranking_model'),
) )
) )

View File

@ -1,6 +1,6 @@
import re import re
from core.app.app_config.entities import ExternalDataVariableEntity, VariableEntity from core.app.app_config.entities import ExternalDataVariableEntity, VariableEntity, VariableEntityType
from core.external_data_tool.factory import ExternalDataToolFactory from core.external_data_tool.factory import ExternalDataToolFactory
@ -13,7 +13,7 @@ class BasicVariablesConfigManager:
:param config: model config args :param config: model config args
""" """
external_data_variables = [] external_data_variables = []
variables = [] variable_entities = []
# old external_data_tools # old external_data_tools
external_data_tools = config.get('external_data_tools', []) external_data_tools = config.get('external_data_tools', [])
@ -30,50 +30,41 @@ class BasicVariablesConfigManager:
) )
# variables and external_data_tools # variables and external_data_tools
for variable in config.get('user_input_form', []): for variables in config.get('user_input_form', []):
typ = list(variable.keys())[0] variable_type = list(variables.keys())[0]
if typ == 'external_data_tool': if variable_type == VariableEntityType.EXTERNAL_DATA_TOOL:
val = variable[typ] variable = variables[variable_type]
if 'config' not in val: if 'config' not in variable:
continue continue
external_data_variables.append( external_data_variables.append(
ExternalDataVariableEntity( ExternalDataVariableEntity(
variable=val['variable'], variable=variable['variable'],
type=val['type'], type=variable['type'],
config=val['config'] config=variable['config']
) )
) )
elif typ in [ elif variable_type in [
VariableEntity.Type.TEXT_INPUT.value, VariableEntityType.TEXT_INPUT,
VariableEntity.Type.PARAGRAPH.value, VariableEntityType.PARAGRAPH,
VariableEntity.Type.NUMBER.value, VariableEntityType.NUMBER,
VariableEntityType.SELECT,
]: ]:
variables.append( variable = variables[variable_type]
variable_entities.append(
VariableEntity( VariableEntity(
type=VariableEntity.Type.value_of(typ), type=variable_type,
variable=variable[typ].get('variable'), variable=variable.get('variable'),
description=variable[typ].get('description'), description=variable.get('description'),
label=variable[typ].get('label'), label=variable.get('label'),
required=variable[typ].get('required', False), required=variable.get('required', False),
max_length=variable[typ].get('max_length'), max_length=variable.get('max_length'),
default=variable[typ].get('default'), options=variable.get('options'),
) default=variable.get('default'),
)
elif typ == VariableEntity.Type.SELECT.value:
variables.append(
VariableEntity(
type=VariableEntity.Type.SELECT,
variable=variable[typ].get('variable'),
description=variable[typ].get('description'),
label=variable[typ].get('label'),
required=variable[typ].get('required', False),
options=variable[typ].get('options'),
default=variable[typ].get('default'),
) )
) )
return variables, external_data_variables return variable_entities, external_data_variables
@classmethod @classmethod
def validate_and_set_defaults(cls, tenant_id: str, config: dict) -> tuple[dict, list[str]]: def validate_and_set_defaults(cls, tenant_id: str, config: dict) -> tuple[dict, list[str]]:
@ -183,4 +174,4 @@ class BasicVariablesConfigManager:
config=config config=config
) )
return config, ["external_data_tools"] return config, ["external_data_tools"]

View File

@ -82,43 +82,29 @@ class PromptTemplateEntity(BaseModel):
advanced_completion_prompt_template: Optional[AdvancedCompletionPromptTemplateEntity] = None advanced_completion_prompt_template: Optional[AdvancedCompletionPromptTemplateEntity] = None
class VariableEntityType(str, Enum):
TEXT_INPUT = "text-input"
SELECT = "select"
PARAGRAPH = "paragraph"
NUMBER = "number"
EXTERNAL_DATA_TOOL = "external-data-tool"
class VariableEntity(BaseModel): class VariableEntity(BaseModel):
""" """
Variable Entity. Variable Entity.
""" """
class Type(Enum):
TEXT_INPUT = 'text-input'
SELECT = 'select'
PARAGRAPH = 'paragraph'
NUMBER = 'number'
@classmethod
def value_of(cls, value: str) -> 'VariableEntity.Type':
"""
Get value of given mode.
:param value: mode value
:return: mode
"""
for mode in cls:
if mode.value == value:
return mode
raise ValueError(f'invalid variable type value {value}')
variable: str variable: str
label: str label: str
description: Optional[str] = None description: Optional[str] = None
type: Type type: VariableEntityType
required: bool = False required: bool = False
max_length: Optional[int] = None max_length: Optional[int] = None
options: Optional[list[str]] = None options: Optional[list[str]] = None
default: Optional[str] = None default: Optional[str] = None
hint: Optional[str] = None hint: Optional[str] = None
@property
def name(self) -> str:
return self.variable
class ExternalDataVariableEntity(BaseModel): class ExternalDataVariableEntity(BaseModel):
""" """
@ -252,4 +238,4 @@ class WorkflowUIBasedAppConfig(AppConfig):
""" """
Workflow UI Based App Config Entity. Workflow UI Based App Config Entity.
""" """
workflow_id: str workflow_id: str

View File

@ -23,6 +23,8 @@ from core.app.entities.task_entities import ChatbotAppBlockingResponse, ChatbotA
from core.file.message_file_parser import MessageFileParser from core.file.message_file_parser import MessageFileParser
from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
from core.ops.ops_trace_manager import TraceQueueManager from core.ops.ops_trace_manager import TraceQueueManager
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.enums import SystemVariableKey
from extensions.ext_database import db from extensions.ext_database import db
from models.account import Account from models.account import Account
from models.model import App, Conversation, EndUser, Message from models.model import App, Conversation, EndUser, Message
@ -67,8 +69,9 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
# get conversation # get conversation
conversation = None conversation = None
if args.get('conversation_id'): conversation_id = args.get('conversation_id')
conversation = self._get_conversation_by_user(app_model, args.get('conversation_id', ''), user) if conversation_id:
conversation = self._get_conversation_by_user(app_model=app_model, conversation_id=conversation_id, user=user)
# parse files # parse files
files = args['files'] if args.get('files') else [] files = args['files'] if args.get('files') else []
@ -225,6 +228,62 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
message_id=message.id message_id=message.id
) )
# Init conversation variables
stmt = select(ConversationVariable).where(
ConversationVariable.app_id == conversation.app_id, ConversationVariable.conversation_id == conversation.id
)
with Session(db.engine) as session:
conversation_variables = session.scalars(stmt).all()
if not conversation_variables:
# Create conversation variables if they don't exist.
conversation_variables = [
ConversationVariable.from_variable(
app_id=conversation.app_id, conversation_id=conversation.id, variable=variable
)
for variable in workflow.conversation_variables
]
session.add_all(conversation_variables)
# Convert database entities to variables.
conversation_variables = [item.to_variable() for item in conversation_variables]
session.commit()
# Increment dialogue count.
conversation.dialogue_count += 1
conversation_id = conversation.id
conversation_dialogue_count = conversation.dialogue_count
db.session.commit()
db.session.refresh(conversation)
inputs = application_generate_entity.inputs
query = application_generate_entity.query
files = application_generate_entity.files
user_id = None
if application_generate_entity.invoke_from in [InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API]:
end_user = db.session.query(EndUser).filter(EndUser.id == application_generate_entity.user_id).first()
if end_user:
user_id = end_user.session_id
else:
user_id = application_generate_entity.user_id
# Create a variable pool.
system_inputs = {
SystemVariableKey.QUERY: query,
SystemVariableKey.FILES: files,
SystemVariableKey.CONVERSATION_ID: conversation_id,
SystemVariableKey.USER_ID: user_id,
SystemVariableKey.DIALOGUE_COUNT: conversation_dialogue_count,
}
variable_pool = VariablePool(
system_variables=system_inputs,
user_inputs=inputs,
environment_variables=workflow.environment_variables,
conversation_variables=conversation_variables,
)
contexts.workflow_variable_pool.set(variable_pool)
# new thread # new thread
worker_thread = threading.Thread(target=self._generate_worker, kwargs={ worker_thread = threading.Thread(target=self._generate_worker, kwargs={
'flask_app': current_app._get_current_object(), # type: ignore 'flask_app': current_app._get_current_object(), # type: ignore
@ -296,7 +355,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
logger.exception("Validation Error when generating") logger.exception("Validation Error when generating")
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER) queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
except (ValueError, InvokeError) as e: except (ValueError, InvokeError) as e:
if os.environ.get("DEBUG") and os.environ.get("DEBUG", "false").lower() == 'true': if os.environ.get("DEBUG", "false").lower() == 'true':
logger.exception("Error when generating") logger.exception("Error when generating")
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER) queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
except Exception as e: except Exception as e:

View File

@ -47,7 +47,7 @@ from core.app.task_pipeline.message_cycle_manage import MessageCycleManage
from core.app.task_pipeline.workflow_cycle_manage import WorkflowCycleManage from core.app.task_pipeline.workflow_cycle_manage import WorkflowCycleManage
from core.model_runtime.utils.encoders import jsonable_encoder from core.model_runtime.utils.encoders import jsonable_encoder
from core.ops.ops_trace_manager import TraceQueueManager from core.ops.ops_trace_manager import TraceQueueManager
from core.workflow.enums import SystemVariable from core.workflow.enums import SystemVariableKey
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
from events.message_event import message_was_created from events.message_event import message_was_created
from extensions.ext_database import db from extensions.ext_database import db
@ -69,7 +69,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
_application_generate_entity: AdvancedChatAppGenerateEntity _application_generate_entity: AdvancedChatAppGenerateEntity
_workflow: Workflow _workflow: Workflow
_user: Union[Account, EndUser] _user: Union[Account, EndUser]
_workflow_system_variables: dict[SystemVariable, Any] _workflow_system_variables: dict[SystemVariableKey, Any]
def __init__( def __init__(
self, self,
@ -102,10 +102,10 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
self._conversation = conversation self._conversation = conversation
self._message = message self._message = message
self._workflow_system_variables = { self._workflow_system_variables = {
SystemVariable.QUERY: message.query, SystemVariableKey.QUERY: message.query,
SystemVariable.FILES: application_generate_entity.files, SystemVariableKey.FILES: application_generate_entity.files,
SystemVariable.CONVERSATION_ID: conversation.id, SystemVariableKey.CONVERSATION_ID: conversation.id,
SystemVariable.USER_ID: user_id, SystemVariableKey.USER_ID: user_id,
} }
self._task_state = WorkflowTaskState() self._task_state = WorkflowTaskState()
@ -312,7 +312,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
elif isinstance(event, QueueParallelBranchRunStartedEvent): elif isinstance(event, QueueParallelBranchRunStartedEvent):
if not workflow_run: if not workflow_run:
raise Exception('Workflow run not initialized.') raise Exception('Workflow run not initialized.')
yield self._workflow_parallel_branch_start_to_stream_response( yield self._workflow_parallel_branch_start_to_stream_response(
task_id=self._application_generate_entity.task_id, task_id=self._application_generate_entity.task_id,
workflow_run=workflow_run, workflow_run=workflow_run,
@ -321,7 +321,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
elif isinstance(event, QueueParallelBranchRunSucceededEvent | QueueParallelBranchRunFailedEvent): elif isinstance(event, QueueParallelBranchRunSucceededEvent | QueueParallelBranchRunFailedEvent):
if not workflow_run: if not workflow_run:
raise Exception('Workflow run not initialized.') raise Exception('Workflow run not initialized.')
yield self._workflow_parallel_branch_finished_to_stream_response( yield self._workflow_parallel_branch_finished_to_stream_response(
task_id=self._application_generate_entity.task_id, task_id=self._application_generate_entity.task_id,
workflow_run=workflow_run, workflow_run=workflow_run,

View File

@ -1,7 +1,7 @@
from collections.abc import Mapping from collections.abc import Mapping
from typing import Any, Optional from typing import Any, Optional
from core.app.app_config.entities import AppConfig, VariableEntity from core.app.app_config.entities import AppConfig, VariableEntity, VariableEntityType
class BaseAppGenerator: class BaseAppGenerator:
@ -9,29 +9,29 @@ class BaseAppGenerator:
user_inputs = user_inputs or {} user_inputs = user_inputs or {}
# Filter input variables from form configuration, handle required fields, default values, and option values # Filter input variables from form configuration, handle required fields, default values, and option values
variables = app_config.variables variables = app_config.variables
filtered_inputs = {var.name: self._validate_input(inputs=user_inputs, var=var) for var in variables} filtered_inputs = {var.variable: self._validate_input(inputs=user_inputs, var=var) for var in variables}
filtered_inputs = {k: self._sanitize_value(v) for k, v in filtered_inputs.items()} filtered_inputs = {k: self._sanitize_value(v) for k, v in filtered_inputs.items()}
return filtered_inputs return filtered_inputs
def _validate_input(self, *, inputs: Mapping[str, Any], var: VariableEntity): def _validate_input(self, *, inputs: Mapping[str, Any], var: VariableEntity):
user_input_value = inputs.get(var.name) user_input_value = inputs.get(var.variable)
if var.required and not user_input_value: if var.required and not user_input_value:
raise ValueError(f'{var.name} is required in input form') raise ValueError(f'{var.variable} is required in input form')
if not var.required and not user_input_value: if not var.required and not user_input_value:
# TODO: should we return None here if the default value is None? # TODO: should we return None here if the default value is None?
return var.default or '' return var.default or ''
if ( if (
var.type var.type
in ( in (
VariableEntity.Type.TEXT_INPUT, VariableEntityType.TEXT_INPUT,
VariableEntity.Type.SELECT, VariableEntityType.SELECT,
VariableEntity.Type.PARAGRAPH, VariableEntityType.PARAGRAPH,
) )
and user_input_value and user_input_value
and not isinstance(user_input_value, str) and not isinstance(user_input_value, str)
): ):
raise ValueError(f"(type '{var.type}') {var.name} in input form must be a string") raise ValueError(f"(type '{var.type}') {var.variable} in input form must be a string")
if var.type == VariableEntity.Type.NUMBER and isinstance(user_input_value, str): if var.type == VariableEntityType.NUMBER and isinstance(user_input_value, str):
# may raise ValueError if user_input_value is not a valid number # may raise ValueError if user_input_value is not a valid number
try: try:
if '.' in user_input_value: if '.' in user_input_value:
@ -39,14 +39,14 @@ class BaseAppGenerator:
else: else:
return int(user_input_value) return int(user_input_value)
except ValueError: except ValueError:
raise ValueError(f"{var.name} in input form must be a valid number") raise ValueError(f"{var.variable} in input form must be a valid number")
if var.type == VariableEntity.Type.SELECT: if var.type == VariableEntityType.SELECT:
options = var.options or [] options = var.options or []
if user_input_value not in options: if user_input_value not in options:
raise ValueError(f'{var.name} in input form must be one of the following: {options}') raise ValueError(f'{var.variable} in input form must be one of the following: {options}')
elif var.type in (VariableEntity.Type.TEXT_INPUT, VariableEntity.Type.PARAGRAPH): elif var.type in (VariableEntityType.TEXT_INPUT, VariableEntityType.PARAGRAPH):
if var.max_length and user_input_value and len(user_input_value) > var.max_length: if var.max_length and user_input_value and len(user_input_value) > var.max_length:
raise ValueError(f'{var.name} in input form must be less than {var.max_length} characters') raise ValueError(f'{var.variable} in input form must be less than {var.max_length} characters')
return user_input_value return user_input_value

View File

@ -256,6 +256,7 @@ class AppRunner:
:param invoke_result: invoke result :param invoke_result: invoke result
:param queue_manager: application queue manager :param queue_manager: application queue manager
:param stream: stream :param stream: stream
:param agent: agent
:return: :return:
""" """
if not stream: if not stream:
@ -278,6 +279,7 @@ class AppRunner:
Handle invoke result direct Handle invoke result direct
:param invoke_result: invoke result :param invoke_result: invoke result
:param queue_manager: application queue manager :param queue_manager: application queue manager
:param agent: agent
:return: :return:
""" """
queue_manager.publish( queue_manager.publish(
@ -293,6 +295,7 @@ class AppRunner:
Handle invoke result Handle invoke result
:param invoke_result: invoke result :param invoke_result: invoke result
:param queue_manager: application queue manager :param queue_manager: application queue manager
:param agent: agent
:return: :return:
""" """
model = None model = None

View File

@ -1,6 +1,7 @@
import json import json
import logging import logging
from collections.abc import Generator from collections.abc import Generator
from datetime import datetime, timezone
from typing import Optional, Union from typing import Optional, Union
from sqlalchemy import and_ from sqlalchemy import and_
@ -36,17 +37,17 @@ logger = logging.getLogger(__name__)
class MessageBasedAppGenerator(BaseAppGenerator): class MessageBasedAppGenerator(BaseAppGenerator):
def _handle_response( def _handle_response(
self, application_generate_entity: Union[ self, application_generate_entity: Union[
ChatAppGenerateEntity, ChatAppGenerateEntity,
CompletionAppGenerateEntity, CompletionAppGenerateEntity,
AgentChatAppGenerateEntity, AgentChatAppGenerateEntity,
AdvancedChatAppGenerateEntity AdvancedChatAppGenerateEntity
], ],
queue_manager: AppQueueManager, queue_manager: AppQueueManager,
conversation: Conversation, conversation: Conversation,
message: Message, message: Message,
user: Union[Account, EndUser], user: Union[Account, EndUser],
stream: bool = False, stream: bool = False,
) -> Union[ ) -> Union[
ChatbotAppBlockingResponse, ChatbotAppBlockingResponse,
CompletionAppBlockingResponse, CompletionAppBlockingResponse,
@ -138,6 +139,7 @@ class MessageBasedAppGenerator(BaseAppGenerator):
""" """
Initialize generate records Initialize generate records
:param application_generate_entity: application generate entity :param application_generate_entity: application generate entity
:conversation conversation
:return: :return:
""" """
app_config = application_generate_entity.app_config app_config = application_generate_entity.app_config
@ -192,6 +194,9 @@ class MessageBasedAppGenerator(BaseAppGenerator):
db.session.add(conversation) db.session.add(conversation)
db.session.commit() db.session.commit()
db.session.refresh(conversation) db.session.refresh(conversation)
else:
conversation.updated_at = datetime.now(timezone.utc).replace(tzinfo=None)
db.session.commit()
message = Message( message = Message(
app_id=app_config.app_id, app_id=app_config.app_id,

View File

@ -13,7 +13,7 @@ from core.app.entities.app_invoke_entities import (
from core.workflow.callbacks.base_workflow_callback import WorkflowCallback from core.workflow.callbacks.base_workflow_callback import WorkflowCallback
from core.workflow.entities.node_entities import UserFrom from core.workflow.entities.node_entities import UserFrom
from core.workflow.entities.variable_pool import VariablePool from core.workflow.entities.variable_pool import VariablePool
from core.workflow.enums import SystemVariable from core.workflow.enums import SystemVariableKey
from core.workflow.workflow_entry import WorkflowEntry from core.workflow.workflow_entry import WorkflowEntry
from extensions.ext_database import db from extensions.ext_database import db
from models.model import App, EndUser from models.model import App, EndUser
@ -79,14 +79,14 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
user_inputs=self.application_generate_entity.single_iteration_run.inputs user_inputs=self.application_generate_entity.single_iteration_run.inputs
) )
else: else:
inputs = self.application_generate_entity.inputs inputs = self.application_generate_entity.inputs
files = self.application_generate_entity.files files = self.application_generate_entity.files
# Create a variable pool. # Create a variable pool.
system_inputs = { system_inputs = {
SystemVariable.FILES: files, SystemVariableKey.FILES: files,
SystemVariable.USER_ID: user_id, SystemVariableKey.USER_ID: user_id,
} }
variable_pool = VariablePool( variable_pool = VariablePool(
@ -98,7 +98,7 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
# init graph # init graph
graph = self._init_graph(graph_config=workflow.graph_dict) graph = self._init_graph(graph_config=workflow.graph_dict)
# RUN WORKFLOW # RUN WORKFLOW
workflow_entry = WorkflowEntry( workflow_entry = WorkflowEntry(
tenant_id=workflow.tenant_id, tenant_id=workflow.tenant_id,

View File

@ -41,7 +41,9 @@ from core.app.entities.task_entities import (
from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTaskPipeline from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTaskPipeline
from core.app.task_pipeline.workflow_cycle_manage import WorkflowCycleManage from core.app.task_pipeline.workflow_cycle_manage import WorkflowCycleManage
from core.ops.ops_trace_manager import TraceQueueManager from core.ops.ops_trace_manager import TraceQueueManager
from core.workflow.enums import SystemVariable from core.workflow.entities.node_entities import NodeType
from core.workflow.enums import SystemVariableKey
from core.workflow.nodes.end.end_node import EndNode
from extensions.ext_database import db from extensions.ext_database import db
from models.account import Account from models.account import Account
from models.model import EndUser from models.model import EndUser
@ -64,7 +66,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
_user: Union[Account, EndUser] _user: Union[Account, EndUser]
_task_state: WorkflowTaskState _task_state: WorkflowTaskState
_application_generate_entity: WorkflowAppGenerateEntity _application_generate_entity: WorkflowAppGenerateEntity
_workflow_system_variables: dict[SystemVariable, Any] _workflow_system_variables: dict[SystemVariableKey, Any]
def __init__(self, application_generate_entity: WorkflowAppGenerateEntity, def __init__(self, application_generate_entity: WorkflowAppGenerateEntity,
workflow: Workflow, workflow: Workflow,
@ -88,8 +90,8 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
self._workflow = workflow self._workflow = workflow
self._workflow_system_variables = { self._workflow_system_variables = {
SystemVariable.FILES: application_generate_entity.files, SystemVariableKey.FILES: application_generate_entity.files,
SystemVariable.USER_ID: user_id SystemVariableKey.USER_ID: user_id
} }
self._task_state = WorkflowTaskState() self._task_state = WorkflowTaskState()

View File

@ -4,7 +4,7 @@ from enum import Enum
from threading import Lock from threading import Lock
from typing import Literal, Optional from typing import Literal, Optional
from httpx import get, post from httpx import Timeout, get, post
from pydantic import BaseModel from pydantic import BaseModel
from yarl import URL from yarl import URL
@ -21,7 +21,7 @@ logger = logging.getLogger(__name__)
CODE_EXECUTION_ENDPOINT = dify_config.CODE_EXECUTION_ENDPOINT CODE_EXECUTION_ENDPOINT = dify_config.CODE_EXECUTION_ENDPOINT
CODE_EXECUTION_API_KEY = dify_config.CODE_EXECUTION_API_KEY CODE_EXECUTION_API_KEY = dify_config.CODE_EXECUTION_API_KEY
CODE_EXECUTION_TIMEOUT = (10, 60) CODE_EXECUTION_TIMEOUT = Timeout(connect=10, write=10, read=60, pool=None)
class CodeExecutionException(Exception): class CodeExecutionException(Exception):
pass pass
@ -116,7 +116,7 @@ class CodeExecutor:
if response.data.error: if response.data.error:
raise CodeExecutionException(response.data.error) raise CodeExecutionException(response.data.error)
return response.data.stdout return response.data.stdout or ''
@classmethod @classmethod
def execute_workflow_code_template(cls, language: CodeLanguage, code: str, inputs: dict, dependencies: Optional[list[CodeDependency]] = None) -> dict: def execute_workflow_code_template(cls, language: CodeLanguage, code: str, inputs: dict, dependencies: Optional[list[CodeDependency]] = None) -> dict:

View File

@ -13,7 +13,7 @@ class Python3CodeProvider(CodeNodeProvider):
def get_default_code(cls) -> str: def get_default_code(cls) -> str:
return dedent( return dedent(
""" """
def main(arg1: int, arg2: int) -> dict: def main(arg1: str, arg2: str) -> dict:
return { return {
"result": arg1 + arg2, "result": arg1 + arg2,
} }

View File

@ -3,6 +3,7 @@ from collections import OrderedDict
from collections.abc import Callable from collections.abc import Callable
from typing import Any from typing import Any
from configs import dify_config
from core.tools.utils.yaml_utils import load_yaml_file from core.tools.utils.yaml_utils import load_yaml_file
@ -19,6 +20,87 @@ def get_position_map(folder_path: str, *, file_name: str = "_position.yaml") ->
return {name: index for index, name in enumerate(positions)} return {name: index for index, name in enumerate(positions)}
def get_tool_position_map(folder_path: str, file_name: str = "_position.yaml") -> dict[str, int]:
"""
Get the mapping for tools from name to index from a YAML file.
:param folder_path:
:param file_name: the YAML file name, default to '_position.yaml'
:return: a dict with name as key and index as value
"""
position_map = get_position_map(folder_path, file_name=file_name)
return pin_position_map(
position_map,
pin_list=dify_config.POSITION_TOOL_PINS_LIST,
)
def get_provider_position_map(folder_path: str, file_name: str = "_position.yaml") -> dict[str, int]:
"""
Get the mapping for providers from name to index from a YAML file.
:param folder_path:
:param file_name: the YAML file name, default to '_position.yaml'
:return: a dict with name as key and index as value
"""
position_map = get_position_map(folder_path, file_name=file_name)
return pin_position_map(
position_map,
pin_list=dify_config.POSITION_PROVIDER_PINS_LIST,
)
def pin_position_map(original_position_map: dict[str, int], pin_list: list[str]) -> dict[str, int]:
"""
Pin the items in the pin list to the beginning of the position map.
Overall logic: exclude > include > pin
:param position_map: the position map to be sorted and filtered
:param pin_list: the list of pins to be put at the beginning
:return: the sorted position map
"""
positions = sorted(original_position_map.keys(), key=lambda x: original_position_map[x])
# Add pins to position map
position_map = {name: idx for idx, name in enumerate(pin_list)}
# Add remaining positions to position map
start_idx = len(position_map)
for name in positions:
if name not in position_map:
position_map[name] = start_idx
start_idx += 1
return position_map
def is_filtered(
include_set: set[str],
exclude_set: set[str],
data: Any,
name_func: Callable[[Any], str],
) -> bool:
"""
Chcek if the object should be filtered out.
Overall logic: exclude > include > pin
:param include_set: the set of names to be included
:param exclude_set: the set of names to be excluded
:param name_func: the function to get the name of the object
:param data: the data to be filtered
:return: True if the object should be filtered out, False otherwise
"""
if not data:
return False
if not include_set and not exclude_set:
return False
name = name_func(data)
if name in exclude_set: # exclude_set is prioritized
return True
if include_set and name not in include_set: # filter out only if include_set is not empty
return True
return False
def sort_by_position_map( def sort_by_position_map(
position_map: dict[str, int], position_map: dict[str, int],
data: list[Any], data: list[Any],

View File

@ -700,6 +700,7 @@ class IndexingRunner:
DatasetDocument.tokens: tokens, DatasetDocument.tokens: tokens,
DatasetDocument.completed_at: datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None), DatasetDocument.completed_at: datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None),
DatasetDocument.indexing_latency: indexing_end_at - indexing_start_at, DatasetDocument.indexing_latency: indexing_end_at - indexing_start_at,
DatasetDocument.error: None,
} }
) )

View File

@ -271,9 +271,8 @@ class ModelInstance:
:param content_text: text content to be translated :param content_text: text content to be translated
:param tenant_id: user tenant id :param tenant_id: user tenant id
:param user: unique user id
:param voice: model timbre :param voice: model timbre
:param streaming: output is streaming :param user: unique user id
:return: text for given audio file :return: text for given audio file
""" """
if not isinstance(self.model_type_instance, TTSModel): if not isinstance(self.model_type_instance, TTSModel):
@ -369,6 +368,15 @@ class ModelManager:
return ModelInstance(provider_model_bundle, model) return ModelInstance(provider_model_bundle, model)
def get_default_provider_model_name(self, tenant_id: str, model_type: ModelType) -> tuple[str, str]:
"""
Return first provider and the first model in the provider
:param tenant_id: tenant id
:param model_type: model type
:return: provider name, model name
"""
return self._provider_manager.get_first_provider_first_model(tenant_id, model_type)
def get_default_model_instance(self, tenant_id: str, model_type: ModelType) -> ModelInstance: def get_default_model_instance(self, tenant_id: str, model_type: ModelType) -> ModelInstance:
""" """
Get default model instance Get default model instance
@ -401,6 +409,10 @@ class LBModelManager:
managed_credentials: Optional[dict] = None) -> None: managed_credentials: Optional[dict] = None) -> None:
""" """
Load balancing model manager Load balancing model manager
:param tenant_id: tenant_id
:param provider: provider
:param model_type: model_type
:param model: model name
:param load_balancing_configs: all load balancing configurations :param load_balancing_configs: all load balancing configurations
:param managed_credentials: credentials if load balancing configuration name is __inherit__ :param managed_credentials: credentials if load balancing configuration name is __inherit__
""" """
@ -499,7 +511,6 @@ class LBModelManager:
config.id config.id
) )
res = redis_client.exists(cooldown_cache_key) res = redis_client.exists(cooldown_cache_key)
res = cast(bool, res) res = cast(bool, res)
return res return res

View File

@ -151,9 +151,9 @@ class AIModel(ABC):
os.path.join(provider_model_type_path, model_schema_yaml) os.path.join(provider_model_type_path, model_schema_yaml)
for model_schema_yaml in os.listdir(provider_model_type_path) for model_schema_yaml in os.listdir(provider_model_type_path)
if not model_schema_yaml.startswith('__') if not model_schema_yaml.startswith('__')
and not model_schema_yaml.startswith('_') and not model_schema_yaml.startswith('_')
and os.path.isfile(os.path.join(provider_model_type_path, model_schema_yaml)) and os.path.isfile(os.path.join(provider_model_type_path, model_schema_yaml))
and model_schema_yaml.endswith('.yaml') and model_schema_yaml.endswith('.yaml')
] ]
# get _position.yaml file path # get _position.yaml file path

View File

@ -185,7 +185,7 @@ if you are not sure about the structure.
stream=stream, stream=stream,
user=user user=user
) )
model_parameters.pop("response_format") model_parameters.pop("response_format")
stop = stop or [] stop = stop or []
stop.extend(["\n```", "```\n"]) stop.extend(["\n```", "```\n"])
@ -249,10 +249,10 @@ if you are not sure about the structure.
prompt_messages=prompt_messages, prompt_messages=prompt_messages,
input_generator=new_generator() input_generator=new_generator()
) )
return response return response
def _code_block_mode_stream_processor(self, model: str, prompt_messages: list[PromptMessage], def _code_block_mode_stream_processor(self, model: str, prompt_messages: list[PromptMessage],
input_generator: Generator[LLMResultChunk, None, None] input_generator: Generator[LLMResultChunk, None, None]
) -> Generator[LLMResultChunk, None, None]: ) -> Generator[LLMResultChunk, None, None]:
""" """
@ -310,7 +310,7 @@ if you are not sure about the structure.
) )
) )
def _code_block_mode_stream_processor_with_backtick(self, model: str, prompt_messages: list, def _code_block_mode_stream_processor_with_backtick(self, model: str, prompt_messages: list,
input_generator: Generator[LLMResultChunk, None, None]) \ input_generator: Generator[LLMResultChunk, None, None]) \
-> Generator[LLMResultChunk, None, None]: -> Generator[LLMResultChunk, None, None]:
""" """
@ -470,7 +470,7 @@ if you are not sure about the structure.
:return: full response or stream response chunk generator result :return: full response or stream response chunk generator result
""" """
raise NotImplementedError raise NotImplementedError
@abstractmethod @abstractmethod
def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
tools: Optional[list[PromptMessageTool]] = None) -> int: tools: Optional[list[PromptMessageTool]] = None) -> int:
@ -792,6 +792,13 @@ if you are not sure about the structure.
if not isinstance(parameter_value, str): if not isinstance(parameter_value, str):
raise ValueError(f"Model Parameter {parameter_name} should be string.") raise ValueError(f"Model Parameter {parameter_name} should be string.")
# validate options
if parameter_rule.options and parameter_value not in parameter_rule.options:
raise ValueError(f"Model Parameter {parameter_name} should be one of {parameter_rule.options}.")
elif parameter_rule.type == ParameterType.TEXT:
if not isinstance(parameter_value, str):
raise ValueError(f"Model Parameter {parameter_name} should be text.")
# validate options # validate options
if parameter_rule.options and parameter_value not in parameter_rule.options: if parameter_rule.options and parameter_value not in parameter_rule.options:
raise ValueError(f"Model Parameter {parameter_name} should be one of {parameter_rule.options}.") raise ValueError(f"Model Parameter {parameter_name} should be one of {parameter_rule.options}.")

View File

@ -70,7 +70,7 @@ class AzureOpenAIText2SpeechModel(_CommonAzureOpenAI, TTSModel):
# doc: https://platform.openai.com/docs/guides/text-to-speech # doc: https://platform.openai.com/docs/guides/text-to-speech
credentials_kwargs = self._to_credential_kwargs(credentials) credentials_kwargs = self._to_credential_kwargs(credentials)
client = AzureOpenAI(**credentials_kwargs) client = AzureOpenAI(**credentials_kwargs)
# max font is 4096,there is 3500 limit for each request # max length is 4096 characters, there is 3500 limit for each request
max_length = 3500 max_length = 3500
if len(content_text) > max_length: if len(content_text) > max_length:
sentences = self._split_text_into_sentences(content_text, max_length=max_length) sentences = self._split_text_into_sentences(content_text, max_length=max_length)

View File

@ -6,7 +6,7 @@ from typing import Optional
from pydantic import BaseModel, ConfigDict from pydantic import BaseModel, ConfigDict
from core.helper.module_import_helper import load_single_subclass_from_source from core.helper.module_import_helper import load_single_subclass_from_source
from core.helper.position_helper import get_position_map, sort_to_dict_by_position_map from core.helper.position_helper import get_provider_position_map, sort_to_dict_by_position_map
from core.model_runtime.entities.model_entities import ModelType from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.entities.provider_entities import ProviderConfig, ProviderEntity, SimpleProviderEntity from core.model_runtime.entities.provider_entities import ProviderConfig, ProviderEntity, SimpleProviderEntity
from core.model_runtime.model_providers.__base.model_provider import ModelProvider from core.model_runtime.model_providers.__base.model_provider import ModelProvider
@ -234,7 +234,7 @@ class ModelProviderFactory:
] ]
# get _position.yaml file path # get _position.yaml file path
position_map = get_position_map(model_providers_path) position_map = get_provider_position_map(model_providers_path)
# traverse all model_provider_dir_paths # traverse all model_provider_dir_paths
model_providers: list[ModelProviderExtension] = [] model_providers: list[ModelProviderExtension] = []

View File

@ -84,7 +84,8 @@ class MoonshotLargeLanguageModel(OAIAPICompatLargeLanguageModel):
def _add_custom_parameters(self, credentials: dict) -> None: def _add_custom_parameters(self, credentials: dict) -> None:
credentials['mode'] = 'chat' credentials['mode'] = 'chat'
credentials['endpoint_url'] = 'https://api.moonshot.cn/v1' if 'endpoint_url' not in credentials or credentials['endpoint_url'] == "":
credentials['endpoint_url'] = 'https://api.moonshot.cn/v1'
def _add_function_call(self, model: str, credentials: dict) -> None: def _add_function_call(self, model: str, credentials: dict) -> None:
model_schema = self.get_model_schema(model, credentials) model_schema = self.get_model_schema(model, credentials)

View File

@ -31,6 +31,14 @@ provider_credential_schema:
placeholder: placeholder:
zh_Hans: 在此输入您的 API Key zh_Hans: 在此输入您的 API Key
en_US: Enter your API Key en_US: Enter your API Key
- variable: endpoint_url
label:
en_US: API Base
type: text-input
required: false
placeholder:
zh_Hans: Base URL, 如https://api.moonshot.cn/v1
en_US: Base URL, e.g. https://api.moonshot.cn/v1
model_credential_schema: model_credential_schema:
model: model:
label: label:

View File

@ -37,6 +37,9 @@ parameter_rules:
options: options:
- text - text
- json_object - json_object
- json_schema
- name: json_schema
use_template: json_schema
pricing: pricing:
input: '0.15' input: '0.15'
output: '0.60' output: '0.60'

View File

@ -0,0 +1,44 @@
model: gpt-4o-2024-08-06
label:
zh_Hans: gpt-4o-2024-08-06
en_US: gpt-4o-2024-08-06
model_type: llm
features:
- multi-tool-call
- agent-thought
- stream-tool-call
- vision
model_properties:
mode: chat
context_size: 128000
parameter_rules:
- name: temperature
use_template: temperature
- name: top_p
use_template: top_p
- name: presence_penalty
use_template: presence_penalty
- name: frequency_penalty
use_template: frequency_penalty
- name: max_tokens
use_template: max_tokens
default: 512
min: 1
max: 16384
- name: response_format
label:
zh_Hans: 回复格式
en_US: response_format
type: string
help:
zh_Hans: 指定模型必须输出的格式
en_US: specifying the format that the model must output
required: false
options:
- text
- json_object
pricing:
input: '2.50'
output: '10.00'
unit: '0.000001'
currency: USD

View File

@ -0,0 +1,4 @@
model: netease-youdao/bce-reranker-base_v1
model_type: rerank
model_properties:
context_size: 512

View File

@ -0,0 +1,4 @@
model: BAAI/bge-reranker-v2-m3
model_type: rerank
model_properties:
context_size: 8192

View File

@ -0,0 +1,87 @@
from typing import Optional
import httpx
from core.model_runtime.entities.rerank_entities import RerankDocument, RerankResult
from core.model_runtime.errors.invoke import (
InvokeAuthorizationError,
InvokeBadRequestError,
InvokeConnectionError,
InvokeError,
InvokeRateLimitError,
InvokeServerUnavailableError,
)
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.rerank_model import RerankModel
class SiliconflowRerankModel(RerankModel):
    """Rerank model implementation backed by the SiliconFlow ``/rerank`` endpoint."""

    def _invoke(self, model: str, credentials: dict, query: str, docs: list[str],
                score_threshold: Optional[float] = None, top_n: Optional[int] = None,
                user: Optional[str] = None) -> RerankResult:
        """Rerank *docs* against *query* via the SiliconFlow API.

        :param model: name of the rerank model to invoke
        :param credentials: must contain ``api_key``; ``base_url`` is optional
        :param query: the search query to score documents against
        :param docs: candidate documents; an empty list short-circuits the call
        :param score_threshold: drop results whose relevance score is below this
        :param top_n: maximum number of results the API should return
        :param user: unused, kept for interface compatibility
        :return: RerankResult with the (optionally filtered) scored documents
        :raises InvokeServerUnavailableError: on any non-2xx HTTP response
        """
        # Nothing to rank — avoid a pointless network round trip.
        if not docs:
            return RerankResult(model=model, docs=[])

        endpoint_root = credentials.get('base_url', 'https://api.siliconflow.cn/v1')
        # Normalise a single trailing slash so the path join below is clean.
        if endpoint_root.endswith('/'):
            endpoint_root = endpoint_root[:-1]

        payload = {
            "model": model,
            "query": query,
            "documents": docs,
            "top_n": top_n,
            "return_documents": True
        }
        auth_headers = {"Authorization": f"Bearer {credentials.get('api_key')}"}

        try:
            response = httpx.post(
                endpoint_root + '/rerank',
                json=payload,
                headers=auth_headers
            )
            response.raise_for_status()
            results = response.json()

            reranked: list[RerankDocument] = []
            for entry in results['results']:
                # Apply the caller's threshold before keeping the document.
                if score_threshold is not None and entry['relevance_score'] < score_threshold:
                    continue
                reranked.append(RerankDocument(
                    index=entry['index'],
                    text=entry['document']['text'],
                    score=entry['relevance_score'],
                ))

            return RerankResult(model=model, docs=reranked)
        except httpx.HTTPStatusError as e:
            raise InvokeServerUnavailableError(str(e))

    def validate_credentials(self, model: str, credentials: dict) -> None:
        """Validate credentials by issuing a tiny real rerank request.

        :raises CredentialsValidateFailedError: wrapping any failure from the probe call
        """
        probe_docs = [
            "Carson City is the capital city of the American state of Nevada. At the 2010 United States "
            "Census, Carson City had a population of 55,274.",
            "The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean that "
            "are a political division controlled by the United States. Its capital is Saipan.",
        ]
        try:
            self._invoke(
                model=model,
                credentials=credentials,
                query="What is the capital of the United States?",
                docs=probe_docs,
                score_threshold=0.8
            )
        except Exception as ex:
            raise CredentialsValidateFailedError(str(ex))

    @property
    def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
        """
        Map model invoke error to unified error
        """
        return {
            InvokeConnectionError: [httpx.ConnectError],
            InvokeServerUnavailableError: [httpx.RemoteProtocolError],
            InvokeRateLimitError: [],
            InvokeAuthorizationError: [httpx.HTTPStatusError],
            InvokeBadRequestError: [httpx.RequestError]
        }

View File

@ -12,10 +12,11 @@ help:
en_US: Get your API Key from SiliconFlow en_US: Get your API Key from SiliconFlow
zh_Hans: 从 SiliconFlow 获取 API Key zh_Hans: 从 SiliconFlow 获取 API Key
url: url:
en_US: https://cloud.siliconflow.cn/keys en_US: https://cloud.siliconflow.cn/account/ak
supported_model_types: supported_model_types:
- llm - llm
- text-embedding - text-embedding
- rerank
- speech2text - speech2text
configurate_methods: configurate_methods:
- predefined-model - predefined-model

View File

@ -35,7 +35,10 @@ from core.model_runtime.model_providers.volcengine_maas.errors import (
RateLimitErrors, RateLimitErrors,
ServerUnavailableErrors, ServerUnavailableErrors,
) )
from core.model_runtime.model_providers.volcengine_maas.llm.models import ModelConfigs from core.model_runtime.model_providers.volcengine_maas.llm.models import (
get_model_config,
get_v2_req_params,
)
from core.model_runtime.model_providers.volcengine_maas.volc_sdk import MaasException from core.model_runtime.model_providers.volcengine_maas.volc_sdk import MaasException
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -95,37 +98,12 @@ class VolcengineMaaSLargeLanguageModel(LargeLanguageModel):
-> LLMResult | Generator: -> LLMResult | Generator:
client = MaaSClient.from_credential(credentials) client = MaaSClient.from_credential(credentials)
req_params = get_v2_req_params(credentials, model_parameters, stop)
req_params = ModelConfigs.get(
credentials['base_model_name'], {}).get('req_params', {}).copy()
if credentials.get('context_size'):
req_params['max_prompt_tokens'] = credentials.get('context_size')
if credentials.get('max_tokens'):
req_params['max_new_tokens'] = credentials.get('max_tokens')
if model_parameters.get('max_tokens'):
req_params['max_new_tokens'] = model_parameters.get('max_tokens')
if model_parameters.get('temperature'):
req_params['temperature'] = model_parameters.get('temperature')
if model_parameters.get('top_p'):
req_params['top_p'] = model_parameters.get('top_p')
if model_parameters.get('top_k'):
req_params['top_k'] = model_parameters.get('top_k')
if model_parameters.get('presence_penalty'):
req_params['presence_penalty'] = model_parameters.get(
'presence_penalty')
if model_parameters.get('frequency_penalty'):
req_params['frequency_penalty'] = model_parameters.get(
'frequency_penalty')
if stop:
req_params['stop'] = stop
extra_model_kwargs = {} extra_model_kwargs = {}
if tools: if tools:
extra_model_kwargs['tools'] = [ extra_model_kwargs['tools'] = [
MaaSClient.transform_tool_prompt_to_maas_config(tool) for tool in tools MaaSClient.transform_tool_prompt_to_maas_config(tool) for tool in tools
] ]
resp = MaaSClient.wrap_exception( resp = MaaSClient.wrap_exception(
lambda: client.chat(req_params, prompt_messages, stream, **extra_model_kwargs)) lambda: client.chat(req_params, prompt_messages, stream, **extra_model_kwargs))
if not stream: if not stream:
@ -197,10 +175,8 @@ class VolcengineMaaSLargeLanguageModel(LargeLanguageModel):
""" """
used to define customizable model schema used to define customizable model schema
""" """
max_tokens = ModelConfigs.get( model_config = get_model_config(credentials)
credentials['base_model_name'], {}).get('req_params', {}).get('max_new_tokens')
if credentials.get('max_tokens'):
max_tokens = int(credentials.get('max_tokens'))
rules = [ rules = [
ParameterRule( ParameterRule(
name='temperature', name='temperature',
@ -234,10 +210,10 @@ class VolcengineMaaSLargeLanguageModel(LargeLanguageModel):
name='presence_penalty', name='presence_penalty',
type=ParameterType.FLOAT, type=ParameterType.FLOAT,
use_template='presence_penalty', use_template='presence_penalty',
label={ label=I18nObject(
'en_US': 'Presence Penalty', en_US='Presence Penalty',
'zh_Hans': '存在惩罚', zh_Hans= '存在惩罚',
}, ),
min=-2.0, min=-2.0,
max=2.0, max=2.0,
), ),
@ -245,10 +221,10 @@ class VolcengineMaaSLargeLanguageModel(LargeLanguageModel):
name='frequency_penalty', name='frequency_penalty',
type=ParameterType.FLOAT, type=ParameterType.FLOAT,
use_template='frequency_penalty', use_template='frequency_penalty',
label={ label=I18nObject(
'en_US': 'Frequency Penalty', en_US= 'Frequency Penalty',
'zh_Hans': '频率惩罚', zh_Hans= '频率惩罚',
}, ),
min=-2.0, min=-2.0,
max=2.0, max=2.0,
), ),
@ -257,7 +233,7 @@ class VolcengineMaaSLargeLanguageModel(LargeLanguageModel):
type=ParameterType.INT, type=ParameterType.INT,
use_template='max_tokens', use_template='max_tokens',
min=1, min=1,
max=max_tokens, max=model_config.properties.max_tokens,
default=512, default=512,
label=I18nObject( label=I18nObject(
zh_Hans='最大生成长度', zh_Hans='最大生成长度',
@ -266,17 +242,10 @@ class VolcengineMaaSLargeLanguageModel(LargeLanguageModel):
), ),
] ]
model_properties = ModelConfigs.get( model_properties = {}
credentials['base_model_name'], {}).get('model_properties', {}).copy() model_properties[ModelPropertyKey.CONTEXT_SIZE] = model_config.properties.context_size
if credentials.get('mode'): model_properties[ModelPropertyKey.MODE] = model_config.properties.mode.value
model_properties[ModelPropertyKey.MODE] = credentials.get('mode')
if credentials.get('context_size'):
model_properties[ModelPropertyKey.CONTEXT_SIZE] = int(
credentials.get('context_size', 4096))
model_features = ModelConfigs.get(
credentials['base_model_name'], {}).get('features', [])
entity = AIModelEntity( entity = AIModelEntity(
model=model, model=model,
label=I18nObject( label=I18nObject(
@ -286,7 +255,7 @@ class VolcengineMaaSLargeLanguageModel(LargeLanguageModel):
model_type=ModelType.LLM, model_type=ModelType.LLM,
model_properties=model_properties, model_properties=model_properties,
parameter_rules=rules, parameter_rules=rules,
features=model_features, features=model_config.features,
) )
return entity return entity

View File

@ -1,181 +1,123 @@
from pydantic import BaseModel
from core.model_runtime.entities.llm_entities import LLMMode
from core.model_runtime.entities.model_entities import ModelFeature from core.model_runtime.entities.model_entities import ModelFeature
ModelConfigs = {
'Doubao-pro-4k': { class ModelProperties(BaseModel):
'req_params': { context_size: int
'max_prompt_tokens': 4096, max_tokens: int
'max_new_tokens': 4096, mode: LLMMode
},
'model_properties': { class ModelConfig(BaseModel):
'context_size': 4096, properties: ModelProperties
'mode': 'chat', features: list[ModelFeature]
},
'features': [
ModelFeature.TOOL_CALL configs: dict[str, ModelConfig] = {
], 'Doubao-pro-4k': ModelConfig(
}, properties=ModelProperties(context_size=4096, max_tokens=4096, mode=LLMMode.CHAT),
'Doubao-lite-4k': { features=[ModelFeature.TOOL_CALL]
'req_params': { ),
'max_prompt_tokens': 4096, 'Doubao-lite-4k': ModelConfig(
'max_new_tokens': 4096, properties=ModelProperties(context_size=4096, max_tokens=4096, mode=LLMMode.CHAT),
}, features=[ModelFeature.TOOL_CALL]
'model_properties': { ),
'context_size': 4096, 'Doubao-pro-32k': ModelConfig(
'mode': 'chat', properties=ModelProperties(context_size=32768, max_tokens=32768, mode=LLMMode.CHAT),
}, features=[ModelFeature.TOOL_CALL]
'features': [ ),
ModelFeature.TOOL_CALL 'Doubao-lite-32k': ModelConfig(
], properties=ModelProperties(context_size=32768, max_tokens=32768, mode=LLMMode.CHAT),
}, features=[ModelFeature.TOOL_CALL]
'Doubao-pro-32k': { ),
'req_params': { 'Doubao-pro-128k': ModelConfig(
'max_prompt_tokens': 32768, properties=ModelProperties(context_size=131072, max_tokens=131072, mode=LLMMode.CHAT),
'max_new_tokens': 32768, features=[ModelFeature.TOOL_CALL]
}, ),
'model_properties': { 'Doubao-lite-128k': ModelConfig(
'context_size': 32768, properties=ModelProperties(context_size=131072, max_tokens=131072, mode=LLMMode.CHAT),
'mode': 'chat', features=[ModelFeature.TOOL_CALL]
}, ),
'features': [ 'Skylark2-pro-4k': ModelConfig(
ModelFeature.TOOL_CALL properties=ModelProperties(context_size=4096, max_tokens=4000, mode=LLMMode.CHAT),
], features=[]
}, ),
'Doubao-lite-32k': { 'Llama3-8B': ModelConfig(
'req_params': { properties=ModelProperties(context_size=8192, max_tokens=8192, mode=LLMMode.CHAT),
'max_prompt_tokens': 32768, features=[]
'max_new_tokens': 32768, ),
}, 'Llama3-70B': ModelConfig(
'model_properties': { properties=ModelProperties(context_size=8192, max_tokens=8192, mode=LLMMode.CHAT),
'context_size': 32768, features=[]
'mode': 'chat', ),
}, 'Moonshot-v1-8k': ModelConfig(
'features': [ properties=ModelProperties(context_size=8192, max_tokens=4096, mode=LLMMode.CHAT),
ModelFeature.TOOL_CALL features=[]
], ),
}, 'Moonshot-v1-32k': ModelConfig(
'Doubao-pro-128k': { properties=ModelProperties(context_size=32768, max_tokens=16384, mode=LLMMode.CHAT),
'req_params': { features=[]
'max_prompt_tokens': 131072, ),
'max_new_tokens': 131072, 'Moonshot-v1-128k': ModelConfig(
}, properties=ModelProperties(context_size=131072, max_tokens=65536, mode=LLMMode.CHAT),
'model_properties': { features=[]
'context_size': 131072, ),
'mode': 'chat', 'GLM3-130B': ModelConfig(
}, properties=ModelProperties(context_size=8192, max_tokens=4096, mode=LLMMode.CHAT),
'features': [ features=[]
ModelFeature.TOOL_CALL ),
], 'GLM3-130B-Fin': ModelConfig(
}, properties=ModelProperties(context_size=8192, max_tokens=4096, mode=LLMMode.CHAT),
'Doubao-lite-128k': { features=[]
'req_params': { ),
'max_prompt_tokens': 131072, 'Mistral-7B': ModelConfig(
'max_new_tokens': 131072, properties=ModelProperties(context_size=8192, max_tokens=2048, mode=LLMMode.CHAT),
}, features=[]
'model_properties': { )
'context_size': 131072,
'mode': 'chat',
},
'features': [
ModelFeature.TOOL_CALL
],
},
'Skylark2-pro-4k': {
'req_params': {
'max_prompt_tokens': 4096,
'max_new_tokens': 4000,
},
'model_properties': {
'context_size': 4096,
'mode': 'chat',
},
'features': [],
},
'Llama3-8B': {
'req_params': {
'max_prompt_tokens': 8192,
'max_new_tokens': 8192,
},
'model_properties': {
'context_size': 8192,
'mode': 'chat',
},
'features': [],
},
'Llama3-70B': {
'req_params': {
'max_prompt_tokens': 8192,
'max_new_tokens': 8192,
},
'model_properties': {
'context_size': 8192,
'mode': 'chat',
},
'features': [],
},
'Moonshot-v1-8k': {
'req_params': {
'max_prompt_tokens': 8192,
'max_new_tokens': 4096,
},
'model_properties': {
'context_size': 8192,
'mode': 'chat',
},
'features': [],
},
'Moonshot-v1-32k': {
'req_params': {
'max_prompt_tokens': 32768,
'max_new_tokens': 16384,
},
'model_properties': {
'context_size': 32768,
'mode': 'chat',
},
'features': [],
},
'Moonshot-v1-128k': {
'req_params': {
'max_prompt_tokens': 131072,
'max_new_tokens': 65536,
},
'model_properties': {
'context_size': 131072,
'mode': 'chat',
},
'features': [],
},
'GLM3-130B': {
'req_params': {
'max_prompt_tokens': 8192,
'max_new_tokens': 4096,
},
'model_properties': {
'context_size': 8192,
'mode': 'chat',
},
'features': [],
},
'GLM3-130B-Fin': {
'req_params': {
'max_prompt_tokens': 8192,
'max_new_tokens': 4096,
},
'model_properties': {
'context_size': 8192,
'mode': 'chat',
},
'features': [],
},
'Mistral-7B': {
'req_params': {
'max_prompt_tokens': 8192,
'max_new_tokens': 2048,
},
'model_properties': {
'context_size': 8192,
'mode': 'chat',
},
'features': [],
}
} }
def get_model_config(credentials: dict) -> ModelConfig:
    """Resolve the ModelConfig for the base model named in *credentials*.

    Looks the model up in the predefined ``configs`` table; when it is not a
    known base model, falls back to a config built from the credential fields
    themselves (``context_size``, ``max_tokens``, ``mode``).
    """
    predefined = configs.get(credentials.get('base_model_name', ''))
    if predefined:
        return predefined

    # Unknown base model: trust the user-supplied limits (0 when absent).
    return ModelConfig(
        properties=ModelProperties(
            context_size=int(credentials.get('context_size', 0)),
            max_tokens=int(credentials.get('max_tokens', 0)),
            mode=LLMMode.value_of(credentials.get('mode', 'chat')),
        ),
        features=[],
    )
def get_v2_req_params(credentials: dict, model_parameters: dict,
                      stop: list[str] | None = None) -> dict:
    """Build the request-parameter dict for a v2 MaaS chat call.

    Starts from the predefined token limits of the configured base model,
    then overlays the sampling parameters the caller actually supplied.

    :param credentials: provider credentials, used to resolve the base model config
    :param model_parameters: user-tunable parameters (max_tokens, temperature, ...)
    :param stop: optional stop sequences
    :return: dict of request parameters ready to pass to the MaaS client
    """
    model_config = get_model_config(credentials)
    # Predefined properties of the base model (get_model_config always
    # returns a config, falling back to credential-derived values).
    req_params = {
        'max_prompt_tokens': model_config.properties.context_size,
        'max_new_tokens': model_config.properties.max_tokens,
    }

    # Overlay caller-supplied parameters. Use `is not None` rather than
    # truthiness so legitimate zero values (e.g. temperature=0.0, top_p=0,
    # presence_penalty=0) are not silently dropped.
    if model_parameters.get('max_tokens') is not None:
        req_params['max_new_tokens'] = model_parameters.get('max_tokens')
    if model_parameters.get('temperature') is not None:
        req_params['temperature'] = model_parameters.get('temperature')
    if model_parameters.get('top_p') is not None:
        req_params['top_p'] = model_parameters.get('top_p')
    if model_parameters.get('top_k') is not None:
        req_params['top_k'] = model_parameters.get('top_k')
    if model_parameters.get('presence_penalty') is not None:
        req_params['presence_penalty'] = model_parameters.get('presence_penalty')
    if model_parameters.get('frequency_penalty') is not None:
        req_params['frequency_penalty'] = model_parameters.get('frequency_penalty')

    if stop:
        req_params['stop'] = stop

    return req_params

View File

@ -1,9 +1,27 @@
from pydantic import BaseModel
class ModelProperties(BaseModel):
context_size: int
max_chunks: int
class ModelConfig(BaseModel):
properties: ModelProperties
ModelConfigs = { ModelConfigs = {
'Doubao-embedding': { 'Doubao-embedding': ModelConfig(
'req_params': {}, properties=ModelProperties(context_size=4096, max_chunks=1)
'model_properties': { ),
'context_size': 4096,
'max_chunks': 1,
}
},
} }
def get_model_config(credentials: dict) -> ModelConfig:
    """Resolve the embedding ModelConfig for the base model named in *credentials*.

    Returns the predefined entry from ``ModelConfigs`` when the base model is
    known; otherwise builds a config from the credential fields themselves.
    """
    predefined = ModelConfigs.get(credentials.get('base_model_name', ''))
    if predefined:
        return predefined

    # Unknown base model: fall back to user-supplied limits (0 when absent).
    return ModelConfig(
        properties=ModelProperties(
            context_size=int(credentials.get('context_size', 0)),
            max_chunks=int(credentials.get('max_chunks', 0)),
        )
    )

View File

@ -30,7 +30,7 @@ from core.model_runtime.model_providers.volcengine_maas.errors import (
RateLimitErrors, RateLimitErrors,
ServerUnavailableErrors, ServerUnavailableErrors,
) )
from core.model_runtime.model_providers.volcengine_maas.text_embedding.models import ModelConfigs from core.model_runtime.model_providers.volcengine_maas.text_embedding.models import get_model_config
from core.model_runtime.model_providers.volcengine_maas.volc_sdk import MaasException from core.model_runtime.model_providers.volcengine_maas.volc_sdk import MaasException
@ -115,14 +115,10 @@ class VolcengineMaaSTextEmbeddingModel(TextEmbeddingModel):
""" """
generate custom model entities from credentials generate custom model entities from credentials
""" """
model_properties = ModelConfigs.get( model_config = get_model_config(credentials)
credentials['base_model_name'], {}).get('model_properties', {}).copy() model_properties = {}
if credentials.get('context_size'): model_properties[ModelPropertyKey.CONTEXT_SIZE] = model_config.properties.context_size
model_properties[ModelPropertyKey.CONTEXT_SIZE] = int( model_properties[ModelPropertyKey.MAX_CHUNKS] = model_config.properties.max_chunks
credentials.get('context_size', 4096))
if credentials.get('max_chunks'):
model_properties[ModelPropertyKey.MAX_CHUNKS] = int(
credentials.get('max_chunks', 4096))
entity = AIModelEntity( entity = AIModelEntity(
model=model, model=model,
label=I18nObject(en_US=model), label=I18nObject(en_US=model),

View File

@ -0,0 +1,198 @@
from datetime import datetime, timedelta
from threading import Lock
from requests import post
from core.model_runtime.model_providers.wenxin.wenxin_errors import (
BadRequestError,
InternalServerError,
InvalidAPIKeyError,
InvalidAuthenticationError,
RateLimitReachedError,
)
# Process-wide cache of Baidu access tokens, keyed by api_key.
baidu_access_tokens: dict[str, 'BaiduAccessToken'] = {}
# Guards concurrent access to the cache above.
baidu_access_tokens_lock = Lock()
class BaiduAccessToken:
    """Cached OAuth access token for the Baidu Wenxin API.

    Instances live in the module-level ``baidu_access_tokens`` cache, keyed by
    ``api_key``, and are refreshed lazily by :meth:`get_access_token`.
    """

    # The API key this token was issued for.
    api_key: str
    # The OAuth token string; empty until the first successful refresh.
    access_token: str
    # Local eviction deadline (deliberately shorter than Baidu's validity window).
    expires: datetime

    def __init__(self, api_key: str) -> None:
        self.api_key = api_key
        self.access_token = ''
        self.expires = datetime.now() + timedelta(days=3)

    @staticmethod
    def _get_access_token(api_key: str, secret_key: str) -> str:
        """
        Request a fresh access token from Baidu's OAuth endpoint.

        :raises InvalidAuthenticationError: when the HTTP request itself fails
        :raises InvalidAPIKeyError / InternalServerError / BadRequestError /
            RateLimitReachedError: mapped from the ``error`` field of the response
        """
        try:
            response = post(
                url=f'https://aip.baidubce.com/oauth/2.0/token?grant_type=client_credentials&client_id={api_key}&client_secret={secret_key}',
                headers={
                    'Content-Type': 'application/json',
                    'Accept': 'application/json'
                },
            )
        except Exception as e:
            raise InvalidAuthenticationError(f'Failed to get access token from Baidu: {e}')

        resp = response.json()
        # Map Baidu's OAuth error codes onto the provider's exception types.
        if 'error' in resp:
            if resp['error'] == 'invalid_client':
                raise InvalidAPIKeyError(f'Invalid API key or secret key: {resp["error_description"]}')
            elif resp['error'] == 'unknown_error':
                raise InternalServerError(f'Internal server error: {resp["error_description"]}')
            elif resp['error'] == 'invalid_request':
                raise BadRequestError(f'Bad request: {resp["error_description"]}')
            elif resp['error'] == 'rate_limit_exceeded':
                raise RateLimitReachedError(f'Rate limit reached: {resp["error_description"]}')
            else:
                raise Exception(f'Unknown error: {resp["error_description"]}')

        return resp['access_token']

    @staticmethod
    def get_access_token(api_key: str, secret_key: str) -> 'BaiduAccessToken':
        """
        Return a cached (or freshly fetched) access token for *api_key*.

        LLM from Baidu requires access token to invoke the API.
        however, we have api_key and secret_key, and access token is valid for 30 days.
        so we can cache the access token for 3 days. (avoid memory leak)

        it may be more efficient to use a ticker to refresh access token, but it will cause
        more complexity, so we just refresh access tokens when get_access_token is called.
        """
        # Look up the cache, evicting any expired tokens first.
        baidu_access_tokens_lock.acquire()
        now = datetime.now()
        for key in list(baidu_access_tokens.keys()):
            token = baidu_access_tokens[key]
            if token.expires < now:
                baidu_access_tokens.pop(key)

        if api_key not in baidu_access_tokens:
            # if access token not in cache, request it
            token = BaiduAccessToken(api_key)
            baidu_access_tokens[api_key] = token
            # release it to enhance performance
            # btw, _get_access_token will raise exception if failed, release lock here to avoid deadlock
            # NOTE(review): the token is published to the cache BEFORE its
            # access_token is filled in below, so a concurrent caller may
            # briefly observe an empty access_token — confirm callers tolerate this.
            baidu_access_tokens_lock.release()
            # try to get access token
            token_str = BaiduAccessToken._get_access_token(api_key, secret_key)
            token.access_token = token_str
            token.expires = now + timedelta(days=3)
            return token
        else:
            # if access token in cache, return it
            token = baidu_access_tokens[api_key]
            baidu_access_tokens_lock.release()
            return token
class _CommonWenxin:
    """Shared plumbing for Wenxin (Baidu) model implementations.

    Holds the per-model API endpoint table, the feature list for function
    calling, credential handling, and the error-code → exception mapping.
    """

    # Chat/embedding endpoint per model name. Note several aliases point at
    # the same underlying endpoint (e.g. ernie-bot-4 and ernie-4.0-8k).
    api_bases = {
        'ernie-bot': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-3.5-4k-0205',
        'ernie-bot-4': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions_pro',
        'ernie-bot-8k': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions',
        'ernie-bot-turbo': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/eb-instant',
        'ernie-3.5-8k': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions',
        'ernie-3.5-8k-0205': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-3.5-8k-0205',
        'ernie-3.5-8k-1222': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-3.5-8k-1222',
        'ernie-3.5-4k-0205': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-3.5-4k-0205',
        'ernie-3.5-128k': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-3.5-128k',
        'ernie-4.0-8k': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions_pro',
        'ernie-4.0-8k-latest': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions_pro',
        'ernie-speed-8k': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie_speed',
        'ernie-speed-128k': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-speed-128k',
        'ernie-speed-appbuilder': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ai_apaas',
        'ernie-lite-8k-0922': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/eb-instant',
        'ernie-lite-8k-0308': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-lite-8k',
        'ernie-character-8k': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-char-8k',
        'ernie-character-8k-0321': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-char-8k',
        'ernie-4.0-turbo-8k': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-4.0-turbo-8k',
        'ernie-4.0-turbo-8k-preview': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-4.0-turbo-8k-preview',
        'yi_34b_chat': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/yi_34b_chat',
        'embedding-v1': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/embeddings/embedding-v1',
        'bge-large-en': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/embeddings/bge_large_en',
        'bge-large-zh': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/embeddings/bge_large_zh',
        'tao-8k': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/embeddings/tao_8k',
    }

    # Models that support tool/function calling.
    function_calling_supports = [
        'ernie-bot',
        'ernie-bot-8k',
        'ernie-3.5-8k',
        'ernie-3.5-8k-0205',
        'ernie-3.5-8k-1222',
        'ernie-3.5-4k-0205',
        'ernie-3.5-128k',
        'ernie-4.0-8k',
        'ernie-4.0-turbo-8k',
        'ernie-4.0-turbo-8k-preview',
        'yi_34b_chat'
    ]

    # Credential pair used to obtain access tokens.
    api_key: str = ''
    secret_key: str = ''

    def __init__(self, api_key: str, secret_key: str):
        self.api_key = api_key
        self.secret_key = secret_key

    @staticmethod
    def _to_credential_kwargs(credentials: dict) -> dict:
        """Extract the api_key/secret_key pair from a raw credentials dict."""
        credentials_kwargs = {
            "api_key": credentials['api_key'],
            "secret_key": credentials['secret_key']
        }

        return credentials_kwargs

    def _handle_error(self, code: int, msg: str):
        """Raise the provider exception matching a Baidu error *code*.

        Unknown codes are surfaced as InternalServerError.
        """
        error_map = {
            1: InternalServerError,
            2: InternalServerError,
            3: BadRequestError,
            4: RateLimitReachedError,
            6: InvalidAuthenticationError,
            13: InvalidAPIKeyError,
            14: InvalidAPIKeyError,
            15: InvalidAPIKeyError,
            17: RateLimitReachedError,
            18: RateLimitReachedError,
            19: RateLimitReachedError,
            100: InvalidAPIKeyError,
            111: InvalidAPIKeyError,
            200: InternalServerError,
            336000: InternalServerError,
            336001: BadRequestError,
            336002: BadRequestError,
            336003: BadRequestError,
            336004: InvalidAuthenticationError,
            336005: InvalidAPIKeyError,
            336006: BadRequestError,
            336007: BadRequestError,
            336008: BadRequestError,
            336100: InternalServerError,
            336101: BadRequestError,
            336102: BadRequestError,
            336103: BadRequestError,
            336104: BadRequestError,
            336105: BadRequestError,
            336200: InternalServerError,
            336303: BadRequestError,
            337006: BadRequestError
        }

        if code in error_map:
            raise error_map[code](msg)
        else:
            raise InternalServerError(f'Unknown error: {msg}')

    def _get_access_token(self) -> str:
        """Fetch (or reuse) the cached access token for this credential pair."""
        token = BaiduAccessToken.get_access_token(self.api_key, self.secret_key)
        return token.access_token

View File

@ -1,102 +1,17 @@
from collections.abc import Generator from collections.abc import Generator
from datetime import datetime, timedelta
from enum import Enum from enum import Enum
from json import dumps, loads from json import dumps, loads
from threading import Lock
from typing import Any, Union from typing import Any, Union
from requests import Response, post from requests import Response, post
from core.model_runtime.entities.message_entities import PromptMessageTool from core.model_runtime.entities.message_entities import PromptMessageTool
from core.model_runtime.model_providers.wenxin.llm.ernie_bot_errors import ( from core.model_runtime.model_providers.wenxin._common import _CommonWenxin
from core.model_runtime.model_providers.wenxin.wenxin_errors import (
BadRequestError, BadRequestError,
InternalServerError, InternalServerError,
InvalidAPIKeyError,
InvalidAuthenticationError,
RateLimitReachedError,
) )
# map api_key to access_token
baidu_access_tokens: dict[str, 'BaiduAccessToken'] = {}
baidu_access_tokens_lock = Lock()
class BaiduAccessToken:
api_key: str
access_token: str
expires: datetime
def __init__(self, api_key: str) -> None:
self.api_key = api_key
self.access_token = ''
self.expires = datetime.now() + timedelta(days=3)
def _get_access_token(api_key: str, secret_key: str) -> str:
"""
request access token from Baidu
"""
try:
response = post(
url=f'https://aip.baidubce.com/oauth/2.0/token?grant_type=client_credentials&client_id={api_key}&client_secret={secret_key}',
headers={
'Content-Type': 'application/json',
'Accept': 'application/json'
},
)
except Exception as e:
raise InvalidAuthenticationError(f'Failed to get access token from Baidu: {e}')
resp = response.json()
if 'error' in resp:
if resp['error'] == 'invalid_client':
raise InvalidAPIKeyError(f'Invalid API key or secret key: {resp["error_description"]}')
elif resp['error'] == 'unknown_error':
raise InternalServerError(f'Internal server error: {resp["error_description"]}')
elif resp['error'] == 'invalid_request':
raise BadRequestError(f'Bad request: {resp["error_description"]}')
elif resp['error'] == 'rate_limit_exceeded':
raise RateLimitReachedError(f'Rate limit reached: {resp["error_description"]}')
else:
raise Exception(f'Unknown error: {resp["error_description"]}')
return resp['access_token']
@staticmethod
def get_access_token(api_key: str, secret_key: str) -> 'BaiduAccessToken':
"""
LLM from Baidu requires access token to invoke the API.
however, we have api_key and secret_key, and access token is valid for 30 days.
so we can cache the access token for 3 days. (avoid memory leak)
it may be more efficient to use a ticker to refresh access token, but it will cause
more complexity, so we just refresh access tokens when get_access_token is called.
"""
# loop up cache, remove expired access token
baidu_access_tokens_lock.acquire()
now = datetime.now()
for key in list(baidu_access_tokens.keys()):
token = baidu_access_tokens[key]
if token.expires < now:
baidu_access_tokens.pop(key)
if api_key not in baidu_access_tokens:
# if access token not in cache, request it
token = BaiduAccessToken(api_key)
baidu_access_tokens[api_key] = token
# release it to enhance performance
# btw, _get_access_token will raise exception if failed, release lock here to avoid deadlock
baidu_access_tokens_lock.release()
# try to get access token
token_str = BaiduAccessToken._get_access_token(api_key, secret_key)
token.access_token = token_str
token.expires = now + timedelta(days=3)
return token
else:
# if access token in cache, return it
token = baidu_access_tokens[api_key]
baidu_access_tokens_lock.release()
return token
class ErnieMessage: class ErnieMessage:
class Role(Enum): class Role(Enum):
@ -120,51 +35,7 @@ class ErnieMessage:
self.content = content self.content = content
self.role = role self.role = role
class ErnieBotModel: class ErnieBotModel(_CommonWenxin):
api_bases = {
'ernie-bot': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-3.5-4k-0205',
'ernie-bot-4': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions_pro',
'ernie-bot-8k': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions',
'ernie-bot-turbo': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/eb-instant',
'ernie-3.5-8k': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions',
'ernie-3.5-8k-0205': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-3.5-8k-0205',
'ernie-3.5-8k-1222': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-3.5-8k-1222',
'ernie-3.5-4k-0205': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-3.5-4k-0205',
'ernie-3.5-128k': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-3.5-128k',
'ernie-4.0-8k': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions_pro',
'ernie-4.0-8k-latest': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions_pro',
'ernie-speed-8k': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie_speed',
'ernie-speed-128k': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-speed-128k',
'ernie-speed-appbuilder': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ai_apaas',
'ernie-lite-8k-0922': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/eb-instant',
'ernie-lite-8k-0308': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-lite-8k',
'ernie-character-8k': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-char-8k',
'ernie-character-8k-0321': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-char-8k',
'ernie-4.0-turbo-8k': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-4.0-turbo-8k',
'ernie-4.0-turbo-8k-preview': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-4.0-turbo-8k-preview',
'yi_34b_chat': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/yi_34b_chat',
}
function_calling_supports = [
'ernie-bot',
'ernie-bot-8k',
'ernie-3.5-8k',
'ernie-3.5-8k-0205',
'ernie-3.5-8k-1222',
'ernie-3.5-4k-0205',
'ernie-3.5-128k',
'ernie-4.0-8k',
'ernie-4.0-turbo-8k',
'ernie-4.0-turbo-8k-preview',
'yi_34b_chat'
]
api_key: str = ''
secret_key: str = ''
def __init__(self, api_key: str, secret_key: str):
self.api_key = api_key
self.secret_key = secret_key
def generate(self, model: str, stream: bool, messages: list[ErnieMessage], def generate(self, model: str, stream: bool, messages: list[ErnieMessage],
parameters: dict[str, Any], timeout: int, tools: list[PromptMessageTool], \ parameters: dict[str, Any], timeout: int, tools: list[PromptMessageTool], \
@ -199,51 +70,6 @@ class ErnieBotModel:
return self._handle_chat_stream_generate_response(resp) return self._handle_chat_stream_generate_response(resp)
return self._handle_chat_generate_response(resp) return self._handle_chat_generate_response(resp)
def _handle_error(self, code: int, msg: str):
error_map = {
1: InternalServerError,
2: InternalServerError,
3: BadRequestError,
4: RateLimitReachedError,
6: InvalidAuthenticationError,
13: InvalidAPIKeyError,
14: InvalidAPIKeyError,
15: InvalidAPIKeyError,
17: RateLimitReachedError,
18: RateLimitReachedError,
19: RateLimitReachedError,
100: InvalidAPIKeyError,
111: InvalidAPIKeyError,
200: InternalServerError,
336000: InternalServerError,
336001: BadRequestError,
336002: BadRequestError,
336003: BadRequestError,
336004: InvalidAuthenticationError,
336005: InvalidAPIKeyError,
336006: BadRequestError,
336007: BadRequestError,
336008: BadRequestError,
336100: InternalServerError,
336101: BadRequestError,
336102: BadRequestError,
336103: BadRequestError,
336104: BadRequestError,
336105: BadRequestError,
336200: InternalServerError,
336303: BadRequestError,
337006: BadRequestError
}
if code in error_map:
raise error_map[code](msg)
else:
raise InternalServerError(f'Unknown error: {msg}')
def _get_access_token(self) -> str:
token = BaiduAccessToken.get_access_token(self.api_key, self.secret_key)
return token.access_token
def _copy_messages(self, messages: list[ErnieMessage]) -> list[ErnieMessage]: def _copy_messages(self, messages: list[ErnieMessage]) -> list[ErnieMessage]:
return [ErnieMessage(message.content, message.role) for message in messages] return [ErnieMessage(message.content, message.role) for message in messages]

View File

@ -1,17 +0,0 @@
# Provider-specific error types raised by the Wenxin integration; callers map
# these onto the unified InvokeError hierarchy.
class InvalidAuthenticationError(Exception):
    pass

class InvalidAPIKeyError(Exception):
    pass

class RateLimitReachedError(Exception):
    pass

class InsufficientAccountBalance(Exception):
    pass

class InternalServerError(Exception):
    pass

class BadRequestError(Exception):
    pass

View File

@ -11,24 +11,13 @@ from core.model_runtime.entities.message_entities import (
UserPromptMessage, UserPromptMessage,
) )
from core.model_runtime.errors.invoke import ( from core.model_runtime.errors.invoke import (
InvokeAuthorizationError,
InvokeBadRequestError,
InvokeConnectionError,
InvokeError, InvokeError,
InvokeRateLimitError,
InvokeServerUnavailableError,
) )
from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.model_runtime.model_providers.wenxin.llm.ernie_bot import BaiduAccessToken, ErnieBotModel, ErnieMessage from core.model_runtime.model_providers.wenxin._common import BaiduAccessToken
from core.model_runtime.model_providers.wenxin.llm.ernie_bot_errors import ( from core.model_runtime.model_providers.wenxin.llm.ernie_bot import ErnieBotModel, ErnieMessage
BadRequestError, from core.model_runtime.model_providers.wenxin.wenxin_errors import invoke_error_mapping
InsufficientAccountBalance,
InternalServerError,
InvalidAPIKeyError,
InvalidAuthenticationError,
RateLimitReachedError,
)
ERNIE_BOT_BLOCK_MODE_PROMPT = """You should always follow the instructions and output a valid {{block}} object. ERNIE_BOT_BLOCK_MODE_PROMPT = """You should always follow the instructions and output a valid {{block}} object.
The structure of the {{block}} object you can found in the instructions, use {"answer": "$your_answer"} as the default structure The structure of the {{block}} object you can found in the instructions, use {"answer": "$your_answer"} as the default structure
@ -140,7 +129,7 @@ class ErnieBotLargeLanguageModel(LargeLanguageModel):
api_key = credentials['api_key'] api_key = credentials['api_key']
secret_key = credentials['secret_key'] secret_key = credentials['secret_key']
try: try:
BaiduAccessToken._get_access_token(api_key, secret_key) BaiduAccessToken.get_access_token(api_key, secret_key)
except Exception as e: except Exception as e:
raise CredentialsValidateFailedError(f'Credentials validation failed: {e}') raise CredentialsValidateFailedError(f'Credentials validation failed: {e}')
@ -254,22 +243,4 @@ class ErnieBotLargeLanguageModel(LargeLanguageModel):
:return: Invoke error mapping :return: Invoke error mapping
""" """
return { return invoke_error_mapping()
InvokeConnectionError: [
],
InvokeServerUnavailableError: [
InternalServerError
],
InvokeRateLimitError: [
RateLimitReachedError
],
InvokeAuthorizationError: [
InvalidAuthenticationError,
InsufficientAccountBalance,
InvalidAPIKeyError,
],
InvokeBadRequestError: [
BadRequestError,
KeyError
]
}

View File

@ -0,0 +1,9 @@
model: bge-large-en
model_type: text-embedding
model_properties:
context_size: 512
max_chunks: 16
pricing:
input: '0.0005'
unit: '0.001'
currency: RMB

View File

@ -0,0 +1,9 @@
model: bge-large-zh
model_type: text-embedding
model_properties:
context_size: 512
max_chunks: 16
pricing:
input: '0.0005'
unit: '0.001'
currency: RMB

View File

@ -0,0 +1,9 @@
model: embedding-v1
model_type: text-embedding
model_properties:
context_size: 384
max_chunks: 16
pricing:
input: '0.0005'
unit: '0.001'
currency: RMB

View File

@ -0,0 +1,9 @@
model: tao-8k
model_type: text-embedding
model_properties:
context_size: 8192
max_chunks: 1
pricing:
input: '0.0005'
unit: '0.001'
currency: RMB

View File

@ -0,0 +1,184 @@
import time
from abc import abstractmethod
from collections.abc import Mapping
from json import dumps
from typing import Any, Optional
import numpy as np
from requests import Response, post
from core.model_runtime.entities.model_entities import PriceType
from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult
from core.model_runtime.errors.invoke import InvokeError
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
from core.model_runtime.model_providers.wenxin._common import BaiduAccessToken, _CommonWenxin
from core.model_runtime.model_providers.wenxin.wenxin_errors import (
BadRequestError,
InternalServerError,
invoke_error_mapping,
)
class TextEmbedding:
    """Interface implemented by Wenxin embedding backends."""

    @abstractmethod
    def embed_documents(self, model: str, texts: list[str], user: str) -> tuple[list[list[float]], int, int]:
        """
        Embed *texts* with *model*.

        :param model: model name
        :param texts: texts to embed
        :param user: unique user id forwarded to the API
        :return: (embeddings, prompt tokens, total tokens)
        """
        # NOTE: the annotation was a tuple *literal* `(list[list[float]], int, int)`,
        # which is not a valid type expression; `tuple[...]` is the correct spelling.
        raise NotImplementedError
class WenxinTextEmbedding(_CommonWenxin, TextEmbedding):
    """Wenxin (ERNIE) HTTP implementation of the TextEmbedding interface."""

    def embed_documents(self, model: str, texts: list[str], user: str) -> tuple[list[list[float]], int, int]:
        """
        Call the Wenxin embedding endpoint for *model*.

        :param model: model name, used to look up the endpoint in ``self.api_bases``
        :param texts: texts to embed; must not be empty
        :param user: unique user id forwarded to the API
        :return: (embeddings, prompt tokens, total tokens)
        :raises InternalServerError: when the HTTP status is not 200
        """
        access_token = self._get_access_token()
        url = f'{self.api_bases[model]}?access_token={access_token}'
        body = self._build_embed_request_body(model, texts, user)
        headers = {
            'Content-Type': 'application/json',
        }

        resp = post(url, data=dumps(body), headers=headers)
        if resp.status_code != 200:
            raise InternalServerError(f'Failed to invoke ernie bot: {resp.text}')
        return self._handle_embed_response(model, resp)

    def _build_embed_request_body(self, model: str, texts: list[str], user: str) -> dict[str, Any]:
        """Build the JSON payload; rejects an empty batch early."""
        if len(texts) == 0:
            raise BadRequestError('The number of texts should not be zero.')
        body = {
            'input': texts,
            'user_id': user,
        }
        return body

    def _handle_embed_response(self, model: str, response: Response) -> tuple[list[list[float]], int, int]:
        """
        Parse a 200 response into (embeddings, prompt tokens, total tokens).

        Responses carrying an ``error_code`` field are routed through the
        inherited ``_handle_error``, which raises the mapped exception.
        """
        data = response.json()
        if 'error_code' in data:
            code = data['error_code']
            msg = data['error_msg']
            # raise error
            self._handle_error(code, msg)

        embeddings = [v['embedding'] for v in data['data']]
        _usage = data['usage']
        tokens = _usage['prompt_tokens']
        total_tokens = _usage['total_tokens']

        return embeddings, tokens, total_tokens
class WenxinTextEmbeddingModel(TextEmbeddingModel):
    """Model-runtime entry point for Wenxin text-embedding models."""

    def _create_text_embedding(self, api_key: str, secret_key: str) -> TextEmbedding:
        """Factory hook; overridable in tests to inject a fake backend."""
        return WenxinTextEmbedding(api_key, secret_key)

    def _invoke(self, model: str, credentials: dict, texts: list[str],
                user: Optional[str] = None) -> TextEmbeddingResult:
        """
        Invoke text embedding model

        :param model: model name
        :param credentials: model credentials
        :param texts: texts to embed
        :param user: unique user id
        :return: embeddings result
        """
        api_key = credentials['api_key']
        secret_key = credentials['secret_key']
        embedding: TextEmbedding = self._create_text_embedding(api_key, secret_key)
        user = user if user else 'ErnieBotDefault'

        context_size = self._get_context_size(model, credentials)
        max_chunks = self._get_max_chunks(model, credentials)
        inputs = []
        used_tokens = 0
        used_total_tokens = 0

        for text in texts:
            # Here token count is only an approximation based on the GPT2 tokenizer
            num_tokens = self._get_num_tokens_by_gpt2(text)

            if num_tokens >= context_size:
                cutoff = int(np.floor(len(text) * (context_size / num_tokens)))
                # if num tokens is larger than context length, only use the start
                inputs.append(text[0:cutoff])
            else:
                inputs.append(text)
        # (the original also built an `indices` list here; it was never read, so it is removed)

        batched_embeddings = []
        # send the inputs in batches no larger than the model's max_chunks
        for i in range(0, len(inputs), max_chunks):
            embeddings_batch, _used_tokens, _total_used_tokens = embedding.embed_documents(
                model,
                inputs[i: i + max_chunks],
                user)
            used_tokens += _used_tokens
            used_total_tokens += _total_used_tokens
            batched_embeddings += embeddings_batch

        usage = self._calc_response_usage(model, credentials, used_tokens, used_total_tokens)
        return TextEmbeddingResult(
            model=model,
            embeddings=batched_embeddings,
            usage=usage,
        )

    def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int:
        """
        Get number of tokens for given prompt messages

        :param model: model name
        :param credentials: model credentials
        :param texts: texts to embed
        :return: approximate token count (GPT-2 tokenizer)
        """
        # sum() over an empty iterable is 0, so no special empty-list branch is needed
        return sum(self._get_num_tokens_by_gpt2(text) for text in texts)

    def validate_credentials(self, model: str, credentials: Mapping) -> None:
        """
        Validate credentials by fetching a Baidu access token.

        :raises CredentialsValidateFailedError: when the token request fails
        """
        api_key = credentials['api_key']
        secret_key = credentials['secret_key']
        try:
            BaiduAccessToken.get_access_token(api_key, secret_key)
        except Exception as e:
            raise CredentialsValidateFailedError(f'Credentials validation failed: {e}')

    @property
    def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
        """Delegate to the provider-wide Wenxin error mapping."""
        return invoke_error_mapping()

    def _calc_response_usage(self, model: str, credentials: dict, tokens: int, total_tokens: int) -> EmbeddingUsage:
        """
        Calculate response usage

        :param model: model name
        :param credentials: model credentials
        :param tokens: input tokens
        :param total_tokens: total tokens reported by the API
        :return: usage
        """
        # get input price info
        input_price_info = self.get_price(
            model=model,
            credentials=credentials,
            price_type=PriceType.INPUT,
            tokens=tokens
        )

        # transform usage
        usage = EmbeddingUsage(
            tokens=tokens,
            total_tokens=total_tokens,
            unit_price=input_price_info.unit_price,
            price_unit=input_price_info.unit,
            total_price=input_price_info.total_amount,
            currency=input_price_info.currency,
            # NOTE(review): assumes self.started_at is set by the base class before
            # invocation reaches here — confirm against TextEmbeddingModel
            latency=time.perf_counter() - self.started_at
        )

        return usage

View File

@ -17,6 +17,7 @@ help:
en_US: https://cloud.baidu.com/wenxin.html en_US: https://cloud.baidu.com/wenxin.html
supported_model_types: supported_model_types:
- llm - llm
- text-embedding
configurate_methods: configurate_methods:
- predefined-model - predefined-model
provider_credential_schema: provider_credential_schema:

View File

@ -0,0 +1,57 @@
from core.model_runtime.errors.invoke import (
InvokeAuthorizationError,
InvokeBadRequestError,
InvokeConnectionError,
InvokeError,
InvokeRateLimitError,
InvokeServerUnavailableError,
)
def invoke_error_mapping() -> dict[type[InvokeError], list[type[Exception]]]:
    """
    Map model invoke error to unified error
    The key is the error type thrown to the caller
    The value is the error type thrown by the model,
    which needs to be converted into a unified error type for the caller.

    :return: Invoke error mapping
    """
    mapping: dict[type[InvokeError], list[type[Exception]]] = {}
    mapping[InvokeConnectionError] = []
    mapping[InvokeServerUnavailableError] = [InternalServerError]
    mapping[InvokeRateLimitError] = [RateLimitReachedError]
    mapping[InvokeAuthorizationError] = [
        InvalidAuthenticationError,
        InsufficientAccountBalance,
        InvalidAPIKeyError,
    ]
    # KeyError covers malformed payloads surfacing as missing-field lookups
    mapping[InvokeBadRequestError] = [BadRequestError, KeyError]
    return mapping
class InvalidAuthenticationError(Exception):
    """Authentication failure; mapped to InvokeAuthorizationError."""
    pass


class InvalidAPIKeyError(Exception):
    """Rejected API key; mapped to InvokeAuthorizationError."""
    pass


class RateLimitReachedError(Exception):
    """Rate limit hit; mapped to InvokeRateLimitError."""
    pass


class InsufficientAccountBalance(Exception):
    """Account balance too low; mapped to InvokeAuthorizationError."""
    pass


class InternalServerError(Exception):
    """Server-side failure (also raised for non-200 HTTP responses); mapped to InvokeServerUnavailableError."""
    pass


class BadRequestError(Exception):
    """Malformed request, e.g. an empty text batch; mapped to InvokeBadRequestError."""
    pass

View File

@ -85,7 +85,8 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
tools=tools, stop=stop, stream=stream, user=user, tools=tools, stop=stop, stream=stream, user=user,
extra_model_kwargs=XinferenceHelper.get_xinference_extra_parameter( extra_model_kwargs=XinferenceHelper.get_xinference_extra_parameter(
server_url=credentials['server_url'], server_url=credentials['server_url'],
model_uid=credentials['model_uid'] model_uid=credentials['model_uid'],
api_key=credentials.get('api_key'),
) )
) )
@ -106,7 +107,8 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
extra_param = XinferenceHelper.get_xinference_extra_parameter( extra_param = XinferenceHelper.get_xinference_extra_parameter(
server_url=credentials['server_url'], server_url=credentials['server_url'],
model_uid=credentials['model_uid'] model_uid=credentials['model_uid'],
api_key=credentials.get('api_key')
) )
if 'completion_type' not in credentials: if 'completion_type' not in credentials:
if 'chat' in extra_param.model_ability: if 'chat' in extra_param.model_ability:
@ -396,7 +398,8 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
else: else:
extra_args = XinferenceHelper.get_xinference_extra_parameter( extra_args = XinferenceHelper.get_xinference_extra_parameter(
server_url=credentials['server_url'], server_url=credentials['server_url'],
model_uid=credentials['model_uid'] model_uid=credentials['model_uid'],
api_key=credentials.get('api_key')
) )
if 'chat' in extra_args.model_ability: if 'chat' in extra_args.model_ability:
@ -464,6 +467,7 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
xinference_client = Client( xinference_client = Client(
base_url=credentials['server_url'], base_url=credentials['server_url'],
api_key=credentials.get('api_key'),
) )
xinference_model = xinference_client.get_model(credentials['model_uid']) xinference_model = xinference_client.get_model(credentials['model_uid'])

View File

@ -108,7 +108,8 @@ class XinferenceRerankModel(RerankModel):
# initialize client # initialize client
client = Client( client = Client(
base_url=credentials['server_url'] base_url=credentials['server_url'],
api_key=credentials.get('api_key'),
) )
xinference_client = client.get_model(model_uid=credentials['model_uid']) xinference_client = client.get_model(model_uid=credentials['model_uid'])

View File

@ -52,7 +52,8 @@ class XinferenceSpeech2TextModel(Speech2TextModel):
# initialize client # initialize client
client = Client( client = Client(
base_url=credentials['server_url'] base_url=credentials['server_url'],
api_key=credentials.get('api_key'),
) )
xinference_client = client.get_model(model_uid=credentials['model_uid']) xinference_client = client.get_model(model_uid=credentials['model_uid'])

View File

@ -110,14 +110,22 @@ class XinferenceTextEmbeddingModel(TextEmbeddingModel):
server_url = credentials['server_url'] server_url = credentials['server_url']
model_uid = credentials['model_uid'] model_uid = credentials['model_uid']
extra_args = XinferenceHelper.get_xinference_extra_parameter(server_url=server_url, model_uid=model_uid) api_key = credentials.get('api_key')
extra_args = XinferenceHelper.get_xinference_extra_parameter(
server_url=server_url,
model_uid=model_uid,
api_key=api_key,
)
if extra_args.max_tokens: if extra_args.max_tokens:
credentials['max_tokens'] = extra_args.max_tokens credentials['max_tokens'] = extra_args.max_tokens
if server_url.endswith('/'): if server_url.endswith('/'):
server_url = server_url[:-1] server_url = server_url[:-1]
client = Client(base_url=server_url) client = Client(
base_url=server_url,
api_key=api_key,
)
try: try:
handle = client.get_model(model_uid=model_uid) handle = client.get_model(model_uid=model_uid)

View File

@ -81,7 +81,8 @@ class XinferenceText2SpeechModel(TTSModel):
extra_param = XinferenceHelper.get_xinference_extra_parameter( extra_param = XinferenceHelper.get_xinference_extra_parameter(
server_url=credentials['server_url'], server_url=credentials['server_url'],
model_uid=credentials['model_uid'] model_uid=credentials['model_uid'],
api_key=credentials.get('api_key'),
) )
if 'text-to-audio' not in extra_param.model_ability: if 'text-to-audio' not in extra_param.model_ability:
@ -203,7 +204,11 @@ class XinferenceText2SpeechModel(TTSModel):
credentials['server_url'] = credentials['server_url'][:-1] credentials['server_url'] = credentials['server_url'][:-1]
try: try:
handle = RESTfulAudioModelHandle(credentials['model_uid'], credentials['server_url'], auth_headers={}) api_key = credentials.get('api_key')
auth_headers = {'Authorization': f'Bearer {api_key}'} if api_key else {}
handle = RESTfulAudioModelHandle(
credentials['model_uid'], credentials['server_url'], auth_headers=auth_headers
)
model_support_voice = [x.get("value") for x in model_support_voice = [x.get("value") for x in
self.get_tts_model_voices(model=model, credentials=credentials)] self.get_tts_model_voices(model=model, credentials=credentials)]

View File

@ -35,13 +35,13 @@ cache_lock = Lock()
class XinferenceHelper: class XinferenceHelper:
@staticmethod @staticmethod
def get_xinference_extra_parameter(server_url: str, model_uid: str) -> XinferenceModelExtraParameter: def get_xinference_extra_parameter(server_url: str, model_uid: str, api_key: str) -> XinferenceModelExtraParameter:
XinferenceHelper._clean_cache() XinferenceHelper._clean_cache()
with cache_lock: with cache_lock:
if model_uid not in cache: if model_uid not in cache:
cache[model_uid] = { cache[model_uid] = {
'expires': time() + 300, 'expires': time() + 300,
'value': XinferenceHelper._get_xinference_extra_parameter(server_url, model_uid) 'value': XinferenceHelper._get_xinference_extra_parameter(server_url, model_uid, api_key)
} }
return cache[model_uid]['value'] return cache[model_uid]['value']
@ -56,7 +56,7 @@ class XinferenceHelper:
pass pass
@staticmethod @staticmethod
def _get_xinference_extra_parameter(server_url: str, model_uid: str) -> XinferenceModelExtraParameter: def _get_xinference_extra_parameter(server_url: str, model_uid: str, api_key: str) -> XinferenceModelExtraParameter:
""" """
get xinference model extra parameter like model_format and model_handle_type get xinference model extra parameter like model_format and model_handle_type
""" """
@ -70,9 +70,10 @@ class XinferenceHelper:
session = Session() session = Session()
session.mount('http://', HTTPAdapter(max_retries=3)) session.mount('http://', HTTPAdapter(max_retries=3))
session.mount('https://', HTTPAdapter(max_retries=3)) session.mount('https://', HTTPAdapter(max_retries=3))
headers = {'Authorization': f'Bearer {api_key}'} if api_key else {}
try: try:
response = session.get(url, timeout=10) response = session.get(url, headers=headers, timeout=10)
except (MissingSchema, ConnectionError, Timeout) as e: except (MissingSchema, ConnectionError, Timeout) as e:
raise RuntimeError(f'get xinference model extra parameter failed, url: {url}, error: {e}') raise RuntimeError(f'get xinference model extra parameter failed, url: {url}, error: {e}')
if response.status_code != 200: if response.status_code != 200:

View File

@ -5,6 +5,7 @@ from typing import Optional
from sqlalchemy.exc import IntegrityError from sqlalchemy.exc import IntegrityError
from configs import dify_config
from core.entities.model_entities import DefaultModelEntity, DefaultModelProviderEntity from core.entities.model_entities import DefaultModelEntity, DefaultModelProviderEntity
from core.entities.provider_configuration import ProviderConfiguration, ProviderConfigurations, ProviderModelBundle from core.entities.provider_configuration import ProviderConfiguration, ProviderConfigurations, ProviderModelBundle
from core.entities.provider_entities import ( from core.entities.provider_entities import (
@ -18,12 +19,9 @@ from core.entities.provider_entities import (
) )
from core.helper import encrypter from core.helper import encrypter
from core.helper.model_provider_cache import ProviderCredentialsCache, ProviderCredentialsCacheType from core.helper.model_provider_cache import ProviderCredentialsCache, ProviderCredentialsCacheType
from core.helper.position_helper import is_filtered
from core.model_runtime.entities.model_entities import ModelType from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.entities.provider_entities import ( from core.model_runtime.entities.provider_entities import CredentialFormSchema, FormType, ProviderEntity
CredentialFormSchema,
FormType,
ProviderEntity,
)
from core.model_runtime.model_providers import model_provider_factory from core.model_runtime.model_providers import model_provider_factory
from extensions import ext_hosting_provider from extensions import ext_hosting_provider
from extensions.ext_database import db from extensions.ext_database import db
@ -45,6 +43,7 @@ class ProviderManager:
""" """
ProviderManager is a class that manages the model providers includes Hosting and Customize Model Providers. ProviderManager is a class that manages the model providers includes Hosting and Customize Model Providers.
""" """
def __init__(self) -> None: def __init__(self) -> None:
self.decoding_rsa_key = None self.decoding_rsa_key = None
self.decoding_cipher_rsa = None self.decoding_cipher_rsa = None
@ -117,6 +116,16 @@ class ProviderManager:
# Construct ProviderConfiguration objects for each provider # Construct ProviderConfiguration objects for each provider
for provider_entity in provider_entities: for provider_entity in provider_entities:
# handle include, exclude
if is_filtered(
include_set=dify_config.POSITION_PROVIDER_INCLUDES_SET,
exclude_set=dify_config.POSITION_PROVIDER_EXCLUDES_SET,
data=provider_entity,
name_func=lambda x: x.provider,
):
continue
provider_name = provider_entity.provider provider_name = provider_entity.provider
provider_records = provider_name_to_provider_records_dict.get(provider_entity.provider, []) provider_records = provider_name_to_provider_records_dict.get(provider_entity.provider, [])
provider_model_records = provider_name_to_provider_model_records_dict.get(provider_entity.provider, []) provider_model_records = provider_name_to_provider_model_records_dict.get(provider_entity.provider, [])
@ -271,6 +280,24 @@ class ProviderManager:
) )
) )
def get_first_provider_first_model(self, tenant_id: str, model_type: ModelType) -> tuple[str, str]:
"""
Get names of first model and its provider
:param tenant_id: workspace id
:param model_type: model type
:return: provider name, model name
"""
provider_configurations = self.get_configurations(tenant_id)
# get available models from provider_configurations
all_models = provider_configurations.get_models(
model_type=model_type,
only_active=False
)
return all_models[0].provider.provider, all_models[0].model
def update_default_model_record(self, tenant_id: str, model_type: ModelType, provider: str, model: str) \ def update_default_model_record(self, tenant_id: str, model_type: ModelType, provider: str, model: str) \
-> TenantDefaultModel: -> TenantDefaultModel:
""" """
@ -811,7 +838,7 @@ class ProviderManager:
-> list[ModelSettings]: -> list[ModelSettings]:
""" """
Convert to model settings. Convert to model settings.
:param provider_entity: provider entity
:param provider_model_settings: provider model settings include enabled, load balancing enabled :param provider_model_settings: provider model settings include enabled, load balancing enabled
:param load_balancing_model_configs: load balancing model configs :param load_balancing_model_configs: load balancing model configs
:return: :return:

View File

@ -152,8 +152,27 @@ class PGVector(BaseVector):
return docs return docs
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]: def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
# do not support bm25 search top_k = kwargs.get("top_k", 5)
return []
with self._get_cursor() as cur:
cur.execute(
f"""SELECT meta, text, ts_rank(to_tsvector(coalesce(text, '')), to_tsquery(%s)) AS score
FROM {self.table_name}
WHERE to_tsvector(text) @@ plainto_tsquery(%s)
ORDER BY score DESC
LIMIT {top_k}""",
# f"'{query}'" is required in order to account for whitespace in query
(f"'{query}'", f"'{query}'"),
)
docs = []
for record in cur:
metadata, text, score = record
metadata["score"] = score
docs.append(Document(page_content=text, metadata=metadata))
return docs
def delete(self) -> None: def delete(self) -> None:
with self._get_cursor() as cur: with self._get_cursor() as cur:

View File

@ -21,6 +21,7 @@ Dify支持`文本` `链接` `图片` `文件BLOB` `JSON` 等多种消息类型
create an image message create an image message
:param image: the url of the image :param image: the url of the image
:param save_as: save as
:return: the image message :return: the image message
""" """
``` ```
@ -34,6 +35,7 @@ Dify支持`文本` `链接` `图片` `文件BLOB` `JSON` 等多种消息类型
create a link message create a link message
:param link: the url of the link :param link: the url of the link
:param save_as: save as
:return: the link message :return: the link message
""" """
``` ```
@ -47,6 +49,7 @@ Dify支持`文本` `链接` `图片` `文件BLOB` `JSON` 等多种消息类型
create a text message create a text message
:param text: the text of the message :param text: the text of the message
:param save_as: save as
:return: the text message :return: the text message
""" """
``` ```
@ -63,6 +66,8 @@ Dify支持`文本` `链接` `图片` `文件BLOB` `JSON` 等多种消息类型
create a blob message create a blob message
:param blob: the blob :param blob: the blob
:param meta: meta
:param save_as: save as
:return: the blob message :return: the blob message
""" """
``` ```

View File

@ -1,6 +1,6 @@
import os.path import os.path
from core.helper.position_helper import get_position_map, sort_by_position_map from core.helper.position_helper import get_tool_position_map, sort_by_position_map
from core.tools.entities.api_entities import UserToolProvider from core.tools.entities.api_entities import UserToolProvider
@ -10,11 +10,11 @@ class BuiltinToolProviderSort:
@classmethod @classmethod
def sort(cls, providers: list[UserToolProvider]) -> list[UserToolProvider]: def sort(cls, providers: list[UserToolProvider]) -> list[UserToolProvider]:
if not cls._position: if not cls._position:
cls._position = get_position_map(os.path.join(os.path.dirname(__file__), '..')) cls._position = get_tool_position_map(os.path.join(os.path.dirname(__file__), '..'))
def name_func(provider: UserToolProvider) -> str: def name_func(provider: UserToolProvider) -> str:
return provider.name return provider.name
sorted_providers = sort_by_position_map(cls._position, providers, name_func) sorted_providers = sort_by_position_map(cls._position, providers, name_func)
return sorted_providers return sorted_providers

View File

@ -0,0 +1,49 @@
<?xml version="1.0" encoding="utf-8"?>
<!-- Generator: Adobe Illustrator 19.2.1, SVG Export Plug-In . SVG Version: 6.00 Build 0) -->
<!DOCTYPE svg PUBLIC "-//W3C//DTD SVG 1.1//EN" "http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd">
<svg version="1.1" id="Layer_1" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" x="0px" y="0px"
viewBox="0 0 200 130.2" style="enable-background:new 0 0 200 130.2;" xml:space="preserve">
<style type="text/css">
.st0{fill:#3EB1C8;}
.st1{fill:#D8D2C4;}
.st2{fill:#4F5858;}
.st3{fill:#FFC72C;}
.st4{fill:#EF3340;}
</style>
<g>
<polygon class="st0" points="111.8,95.5 111.8,66.8 135.4,59 177.2,73.3 "/>
<polygon class="st1" points="153.6,36.8 111.8,51.2 135.4,59 177.2,44.6 "/>
<polygon class="st2" points="135.4,59 177.2,44.6 177.2,73.3 "/>
<polygon class="st3" points="177.2,0.3 177.2,29 153.6,36.8 111.8,22.5 "/>
<polygon class="st4" points="153.6,36.8 111.8,51.2 111.8,22.5 "/>
<g>
<g>
<g>
<g>
<path class="st2" d="M26.3,104.8c-0.5-3.7-4.1-6.5-8.1-6.5c-7.3,0-10.1,6.2-10.1,12.7c0,6.2,2.8,12.4,10.1,12.4
c5,0,7.8-3.4,8.4-8.3h7.9c-0.8,9.2-7.2,15.2-16.3,15.2C6.8,130.2,0,121.7,0,111c0-11,6.8-19.6,18.2-19.6c8.2,0,15,4.8,16,13.3
H26.3z"/>
<path class="st2" d="M37.4,102.5h7v5h0.1c1.4-3.4,5-5.7,8.6-5.7c0.5,0,1.1,0.1,1.6,0.3v6.9c-0.7-0.2-1.8-0.3-2.6-0.3
c-5.4,0-7.3,3.9-7.3,8.6v12.1h-7.4V102.5z"/>
<path class="st2" d="M68.7,101.8c8.5,0,13.9,5.6,13.9,14.2c0,8.5-5.5,14.1-13.9,14.1c-8.4,0-13.9-5.6-13.9-14.1
C54.9,107.4,60.3,101.8,68.7,101.8z M68.7,124.5c5,0,6.5-4.3,6.5-8.6c0-4.3-1.5-8.6-6.5-8.6c-5,0-6.5,4.3-6.5,8.6
C62.2,120.2,63.8,124.5,68.7,124.5z"/>
<path class="st2" d="M91.2,120.6c0.1,3.2,2.8,4.5,5.7,4.5c2.1,0,4.8-0.8,4.8-3.4c0-2.2-3.1-3-8.4-4.2c-4.3-0.9-8.5-2.4-8.5-7.2
c0-6.9,5.9-8.6,11.7-8.6c5.9,0,11.3,2,11.8,8.6h-7c-0.2-2.9-2.4-3.6-5-3.6c-1.7,0-4.1,0.3-4.1,2.5c0,2.6,4.2,3,8.4,4
c4.3,1,8.5,2.5,8.5,7.5c0,7.1-6.1,9.3-12.3,9.3c-6.2,0-12.3-2.3-12.6-9.5H91.2z"/>
<path class="st2" d="M118.1,120.6c0.1,3.2,2.8,4.5,5.7,4.5c2.1,0,4.8-0.8,4.8-3.4c0-2.2-3.1-3-8.4-4.2
c-4.3-0.9-8.5-2.4-8.5-7.2c0-6.9,5.9-8.6,11.7-8.6c5.9,0,11.3,2,11.8,8.6h-7c-0.2-2.9-2.4-3.6-5-3.6c-1.7,0-4.1,0.3-4.1,2.5
c0,2.6,4.2,3,8.4,4c4.3,1,8.5,2.5,8.5,7.5c0,7.1-6.1,9.3-12.3,9.3c-6.2,0-12.3-2.3-12.6-9.5H118.1z"/>
<path class="st2" d="M138.4,102.5h7v5h0.1c1.4-3.4,5-5.7,8.6-5.7c0.5,0,1.1,0.1,1.6,0.3v6.9c-0.7-0.2-1.8-0.3-2.6-0.3
c-5.4,0-7.3,3.9-7.3,8.6v12.1h-7.4V102.5z"/>
<path class="st2" d="M163.7,117.7c0.2,4.7,2.5,6.8,6.6,6.8c3,0,5.3-1.8,5.8-3.5h6.5c-2.1,6.3-6.5,9-12.6,9
c-8.5,0-13.7-5.8-13.7-14.1c0-8,5.6-14.2,13.7-14.2c9.1,0,13.6,7.7,13,15.9H163.7z M175.7,113.1c-0.7-3.7-2.3-5.7-5.9-5.7
c-4.7,0-6,3.6-6.1,5.7H175.7z"/>
<path class="st2" d="M187.2,107.5h-4.4v-4.9h4.4v-2.1c0-4.7,3-8.2,9-8.2c1.3,0,2.6,0.2,3.9,0.2V98c-0.9-0.1-1.8-0.2-2.7-0.2
c-2,0-2.8,0.8-2.8,3.1v1.6h5.1v4.9h-5.1v21.9h-7.4V107.5z"/>
</g>
</g>
</g>
</g>
</g>
</svg>

After

Width:  |  Height:  |  Size: 3.0 KiB

View File

@ -0,0 +1,20 @@
from core.tools.errors import ToolProviderCredentialValidationError
from core.tools.provider.builtin.crossref.tools.query_doi import CrossRefQueryDOITool
from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController
class CrossRefProvider(BuiltinToolProviderController):
    """Builtin tool provider for CrossRef; validates credentials with a real DOI lookup."""

    def _validate_credentials(self, credentials: dict) -> None:
        """
        Verify *credentials* by invoking the DOI query tool against a known DOI.

        :raises ToolProviderCredentialValidationError: when the test invocation fails
        """
        try:
            CrossRefQueryDOITool().fork_tool_runtime(
                runtime={
                    "credentials": credentials,
                }
            ).invoke(
                user_id='',
                tool_parameters={
                    "doi": '10.1007/s00894-022-05373-8',
                },
            )
        except Exception as e:
            # chain the original exception so the root cause stays visible in tracebacks
            raise ToolProviderCredentialValidationError(str(e)) from e

View File

@ -0,0 +1,29 @@
identity:
author: Sakura4036
name: crossref
label:
en_US: CrossRef
zh_Hans: CrossRef
description:
en_US: Crossref is a cross-publisher reference linking registration query system using DOI technology created in 2000. Crossref establishes cross-database links between the reference list and citation full text of papers, making it very convenient for readers to access the full text of papers.
zh_Hans: Crossref是于2000年创建的使用DOI技术的跨出版商参考文献链接注册查询系统。Crossref建立了在论文的参考文献列表和引文全文之间的跨数据库链接使得读者能够非常便捷地获取文献全文。
icon: icon.svg
tags:
- search
credentials_for_provider:
mailto:
type: text-input
required: true
label:
en_US: email address
zh_Hans: email地址
pt_BR: email address
placeholder:
en_US: Please input your email address
zh_Hans: 请输入你的email地址
pt_BR: Please input your email address
help:
en_US: According to the requirements of Crossref, an email address is required
zh_Hans: 根据Crossref的要求需要提供一个邮箱地址
pt_BR: According to the requirements of Crossref, an email address is required
url: https://api.crossref.org/swagger-ui/index.html

View File

@ -0,0 +1,25 @@
from typing import Any, Union
import requests
from core.tools.entities.tool_entities import ToolInvokeMessage
from core.tools.errors import ToolParameterValidationError
from core.tools.tool.builtin_tool import BuiltinTool
class CrossRefQueryDOITool(BuiltinTool):
    """
    Tool for querying the metadata of a publication using its DOI.
    """

    def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
        """
        Fetch a work's metadata from the Crossref REST API by DOI.

        :param user_id: invoking user id (unused by the API call)
        :param tool_parameters: must contain a non-empty 'doi'
        :return: JSON message with the Crossref 'message' payload
        :raises ToolParameterValidationError: when 'doi' is missing or empty
        """
        doi = tool_parameters.get('doi')
        if not doi:
            raise ToolParameterValidationError('doi is required.')
        # doc: https://github.com/CrossRef/rest-api-doc
        url = f"https://api.crossref.org/works/{doi}"
        # requests has no default timeout; without one a stalled server hangs the tool forever
        response = requests.get(url, timeout=10)
        response.raise_for_status()
        # keep the HTTP response and the decoded payload in separate names
        message = response.json().get('message', {})
        return self.create_json_message(message)

View File

@ -0,0 +1,23 @@
identity:
name: crossref_query_doi
author: Sakura4036
label:
en_US: CrossRef Query DOI
zh_Hans: CrossRef DOI 查询
pt_BR: CrossRef Query DOI
description:
human:
en_US: A tool for searching literature information using CrossRef by DOI.
zh_Hans: 一个使用CrossRef通过DOI获取文献信息的工具。
pt_BR: A tool for searching literature information using CrossRef by DOI.
llm: A tool for searching literature information using CrossRef by DOI.
parameters:
- name: doi
type: string
required: true
label:
en_US: DOI
zh_Hans: DOI
pt_BR: DOI
llm_description: DOI for searching in CrossRef
form: llm

View File

@ -0,0 +1,120 @@
import time
from typing import Any, Union
import requests
from core.tools.entities.tool_entities import ToolInvokeMessage
from core.tools.tool.builtin_tool import BuiltinTool
def convert_time_str_to_seconds(time_str: str) -> int:
    """
    Convert a time string to seconds.
    example: 1s -> 1, 1m30s -> 90, 1h30m -> 5400, 1h30m30s -> 5430
    """
    # normalize: case-insensitive, ignore all whitespace
    remainder = time_str.lower().strip().replace(' ', '')
    total_seconds = 0
    # consume the hour and minute components left-to-right, if present
    for unit, factor in (('h', 3600), ('m', 60)):
        if unit in remainder:
            amount, remainder = remainder.split(unit)
            total_seconds += int(amount) * factor
    # whatever is left can only be the seconds component
    if 's' in remainder:
        total_seconds += int(remainder.replace('s', ''))
    return total_seconds
class CrossRefQueryTitleAPI:
"""
Tool for querying the metadata of a publication using its title.
Crossref API doc: https://github.com/CrossRef/rest-api-doc
"""
query_url_template: str = "https://api.crossref.org/works?query.bibliographic={query}&rows={rows}&offset={offset}&sort={sort}&order={order}&mailto={mailto}"
rate_limit: int = 50
rate_interval: float = 1
max_limit: int = 1000
def __init__(self, mailto: str):
    """Store the contact email that Crossref requires clients to send with queries."""
    self.mailto = mailto
def _query(self, query: str, rows: int = 5, offset: int = 0, sort: str = 'relevance', order: str = 'desc', fuzzy_query: bool = False) -> list[dict]:
"""
Query the metadata of a publication using its title.
:param query: the title of the publication
:param rows: the number of results to return
:param sort: the sort field
:param order: the sort order
:param fuzzy_query: whether to return all items that match the query
"""
url = self.query_url_template.format(query=query, rows=rows, offset=offset, sort=sort, order=order, mailto=self.mailto)
response = requests.get(url)
response.raise_for_status()
rate_limit = int(response.headers['x-ratelimit-limit'])
# convert time string to seconds
rate_interval = convert_time_str_to_seconds(response.headers['x-ratelimit-interval'])
self.rate_limit = rate_limit
self.rate_interval = rate_interval
response = response.json()
if response['status'] != 'ok':
return []
message = response['message']
if fuzzy_query:
# fuzzy query return all items
return message['items']
else:
for paper in message['items']:
title = paper['title'][0]
if title.lower() != query.lower():
continue
return [paper]
return []
def query(self, query: str, rows: int = 5, sort: str = 'relevance', order: str = 'desc', fuzzy_query: bool = False) -> list[dict]:
"""
Query the metadata of a publication using its title.
:param query: the title of the publication
:param rows: the number of results to return
:param sort: the sort field
:param order: the sort order
:param fuzzy_query: whether to return all items that match the query
"""
rows = min(rows, self.max_limit)
if rows > self.rate_limit:
# query multiple times
query_times = rows // self.rate_limit + 1
results = []
for i in range(query_times):
result = self._query(query, rows=self.rate_limit, offset=i * self.rate_limit, sort=sort, order=order, fuzzy_query=fuzzy_query)
if fuzzy_query:
results.extend(result)
else:
# fuzzy_query=False, only one result
if result:
return result
time.sleep(self.rate_interval)
return results
else:
# query once
return self._query(query, rows, sort=sort, order=order, fuzzy_query=fuzzy_query)
class CrossRefQueryTitleTool(BuiltinTool):
    """
    Tool for querying the metadata of a publication using its title.
    """

    def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
        # Collect the tool parameters, falling back to the documented defaults.
        query = tool_parameters.get('query')
        fuzzy_query = tool_parameters.get('fuzzy_query', False)
        rows = tool_parameters.get('rows', 3)
        sort = tool_parameters.get('sort', 'relevance')
        order = tool_parameters.get('order', 'desc')
        mailto = self.runtime.credentials['mailto']

        api = CrossRefQueryTitleAPI(mailto)
        papers = api.query(query, rows=rows, sort=sort, order=order, fuzzy_query=fuzzy_query)
        # Wrap every matching record in its own JSON message.
        return [self.create_json_message(paper) for paper in papers]

View File

@ -0,0 +1,105 @@
identity:
name: crossref_query_title
author: Sakura4036
label:
en_US: CrossRef Title Query
zh_Hans: CrossRef 标题查询
pt_BR: CrossRef Title Query
description:
human:
en_US: A tool for querying literature information using CrossRef by title.
zh_Hans: 一个使用CrossRef通过标题搜索文献信息的工具。
pt_BR: A tool for querying literature information using CrossRef by title.
llm: A tool for querying literature information using CrossRef by title.
parameters:
- name: query
type: string
required: true
label:
en_US: Query
zh_Hans: 查询语句
pt_BR: Query
human_description:
en_US: Query bibliographic information, useful for citation look up. Includes titles, authors, ISSNs and publication years
zh_Hans: 用于搜索文献信息有助于查找引用。包括标题作者ISSN和出版年份
pt_BR: Query bibliographic information, useful for citation look up. Includes titles, authors, ISSNs and publication years
llm_description: key words for querying in CrossRef
form: llm
- name: fuzzy_query
type: boolean
default: false
label:
en_US: Whether to fuzzy search
zh_Hans: 是否模糊搜索
pt_BR: Whether to fuzzy search
human_description:
en_US: used for selecting the query type, fuzzy query returns more results, precise query returns 1 or none
zh_Hans: 用于选择搜索类型模糊搜索返回更多结果精确搜索返回1条结果或无
pt_BR: used for selecting the query type, fuzzy query returns more results, precise query returns 1 or none
form: form
- name: limit
type: number
required: false
label:
en_US: max query number
zh_Hans: 最大搜索数
pt_BR: max query number
human_description:
en_US: max query number(fuzzy search returns the maximum number of results or precise search the maximum number of matches)
zh_Hans: 最大搜索数(模糊搜索返回的最大结果数或精确搜索最大匹配数)
pt_BR: max query number(fuzzy search returns the maximum number of results or precise search the maximum number of matches)
form: llm
default: 50
- name: sort
type: select
required: true
options:
- value: relevance
label:
en_US: relevance
zh_Hans: 相关性
pt_BR: relevance
- value: published
label:
en_US: publication date
zh_Hans: 出版日期
pt_BR: publication date
- value: references-count
label:
en_US: references-count
zh_Hans: 引用次数
pt_BR: references-count
default: relevance
label:
en_US: sorting field
zh_Hans: 排序字段
pt_BR: sorting field
human_description:
en_US: Sorting of query results
zh_Hans: 检索结果的排序字段
pt_BR: Sorting of query results
form: form
- name: order
type: select
required: true
options:
- value: desc
label:
en_US: descending
zh_Hans: 降序
pt_BR: descending
- value: asc
label:
en_US: ascending
zh_Hans: 升序
pt_BR: ascending
default: desc
label:
en_US: Order
zh_Hans: 排序
pt_BR: Order
human_description:
en_US: Order of query results
zh_Hans: 检索结果的排序方式
pt_BR: Order of query results
form: form

View File

@ -29,6 +29,6 @@ class GitlabProvider(BuiltinToolProviderController):
if response.status_code != 200: if response.status_code != 200:
raise ToolProviderCredentialValidationError((response.json()).get('message')) raise ToolProviderCredentialValidationError((response.json()).get('message'))
except Exception as e: except Exception as e:
raise ToolProviderCredentialValidationError("Gitlab Access Tokens and Api Version is invalid. {}".format(e)) raise ToolProviderCredentialValidationError("Gitlab Access Tokens is invalid. {}".format(e))
except Exception as e: except Exception as e:
raise ToolProviderCredentialValidationError(str(e)) raise ToolProviderCredentialValidationError(str(e))

View File

@ -2,37 +2,37 @@ identity:
author: Leo.Wang author: Leo.Wang
name: gitlab name: gitlab
label: label:
en_US: Gitlab en_US: GitLab
zh_Hans: Gitlab zh_Hans: GitLab
description: description:
en_US: Gitlab plugin for commit en_US: GitLab plugin, API v4 only.
zh_Hans: 用于获取Gitlab commit的插件 zh_Hans: 用于获取GitLab内容的插件目前仅支持 API v4。
icon: gitlab.svg icon: gitlab.svg
credentials_for_provider: credentials_for_provider:
access_tokens: access_tokens:
type: secret-input type: secret-input
required: true required: true
label: label:
en_US: Gitlab access token en_US: GitLab access token
zh_Hans: Gitlab access token zh_Hans: GitLab access token
placeholder: placeholder:
en_US: Please input your Gitlab access token en_US: Please input your GitLab access token
zh_Hans: 请输入你的 Gitlab access token zh_Hans: 请输入你的 GitLab access token
help: help:
en_US: Get your Gitlab access token from Gitlab en_US: Get your GitLab access token from GitLab
zh_Hans: 从 Gitlab 获取您的 access token zh_Hans: 从 GitLab 获取您的 access token
url: https://docs.gitlab.com/16.9/ee/api/oauth2.html url: https://docs.gitlab.com/16.9/ee/api/oauth2.html
site_url: site_url:
type: text-input type: text-input
required: false required: false
default: 'https://gitlab.com' default: 'https://gitlab.com'
label: label:
en_US: Gitlab site url en_US: GitLab site url
zh_Hans: Gitlab site url zh_Hans: GitLab site url
placeholder: placeholder:
en_US: Please input your Gitlab site url en_US: Please input your GitLab site url
zh_Hans: 请输入你的 Gitlab site url zh_Hans: 请输入你的 GitLab site url
help: help:
en_US: Find your Gitlab url en_US: Find your GitLab url
zh_Hans: 找到你的Gitlab url zh_Hans: 找到你的 GitLab url
url: https://gitlab.com/help url: https://gitlab.com/help

View File

@ -18,6 +18,7 @@ class GitlabCommitsTool(BuiltinTool):
employee = tool_parameters.get('employee', '') employee = tool_parameters.get('employee', '')
start_time = tool_parameters.get('start_time', '') start_time = tool_parameters.get('start_time', '')
end_time = tool_parameters.get('end_time', '') end_time = tool_parameters.get('end_time', '')
change_type = tool_parameters.get('change_type', 'all')
if not project: if not project:
return self.create_text_message('Project is required') return self.create_text_message('Project is required')
@ -36,11 +37,11 @@ class GitlabCommitsTool(BuiltinTool):
site_url = 'https://gitlab.com' site_url = 'https://gitlab.com'
# Get commit content # Get commit content
result = self.fetch(user_id, site_url, access_token, project, employee, start_time, end_time) result = self.fetch(user_id, site_url, access_token, project, employee, start_time, end_time, change_type)
return self.create_text_message(json.dumps(result, ensure_ascii=False)) return [self.create_json_message(item) for item in result]
def fetch(self,user_id: str, site_url: str, access_token: str, project: str, employee: str = None, start_time: str = '', end_time: str = '') -> list[dict[str, Any]]: def fetch(self,user_id: str, site_url: str, access_token: str, project: str, employee: str = None, start_time: str = '', end_time: str = '', change_type: str = '') -> list[dict[str, Any]]:
domain = site_url domain = site_url
headers = {"PRIVATE-TOKEN": access_token} headers = {"PRIVATE-TOKEN": access_token}
results = [] results = []
@ -74,7 +75,7 @@ class GitlabCommitsTool(BuiltinTool):
for commit in commits: for commit in commits:
commit_sha = commit['id'] commit_sha = commit['id']
print(f"\tCommit SHA: {commit_sha}") author_name = commit['author_name']
diff_url = f"{domain}/api/v4/projects/{project_id}/repository/commits/{commit_sha}/diff" diff_url = f"{domain}/api/v4/projects/{project_id}/repository/commits/{commit_sha}/diff"
diff_response = requests.get(diff_url, headers=headers) diff_response = requests.get(diff_url, headers=headers)
@ -87,14 +88,23 @@ class GitlabCommitsTool(BuiltinTool):
removed_lines = diff['diff'].count('\n-') removed_lines = diff['diff'].count('\n-')
total_changes = added_lines + removed_lines total_changes = added_lines + removed_lines
if total_changes > 1: if change_type == "new":
final_code = ''.join([line[1:] for line in diff['diff'].split('\n') if line.startswith('+') and not line.startswith('+++')]) if added_lines > 1:
results.append({ final_code = ''.join([line[1:] for line in diff['diff'].split('\n') if line.startswith('+') and not line.startswith('+++')])
"project": project_name, results.append({
"commit_sha": commit_sha, "commit_sha": commit_sha,
"diff": final_code "author_name": author_name,
}) "diff": final_code
print(f"Commit code:{final_code}") })
else:
if total_changes > 1:
final_code = ''.join([line[1:] for line in diff['diff'].split('\n') if (line.startswith('+') or line.startswith('-')) and not line.startswith('+++') and not line.startswith('---')])
final_code_escaped = json.dumps(final_code)[1:-1] # Escape the final code
results.append({
"commit_sha": commit_sha,
"author_name": author_name,
"diff": final_code_escaped
})
except requests.RequestException as e: except requests.RequestException as e:
print(f"Error fetching data from GitLab: {e}") print(f"Error fetching data from GitLab: {e}")

View File

@ -2,24 +2,24 @@ identity:
name: gitlab_commits name: gitlab_commits
author: Leo.Wang author: Leo.Wang
label: label:
en_US: Gitlab Commits en_US: GitLab Commits
zh_Hans: Gitlab代码提交内容 zh_Hans: GitLab 提交内容查询
description: description:
human: human:
en_US: A tool for query gitlab commits. Input should be a exists username. en_US: A tool for query GitLab commits, Input should be a exists username or projec.
zh_Hans: 一个用于查询gitlab代码提交记录的的工具,输入的内容应该是一个已存在的用户名或者项目名。 zh_Hans: 一个用于查询 GitLab 代码提交内容的工具,输入的内容应该是一个已存在的用户名或者项目名。
llm: A tool for query gitlab commits. Input should be a exists username or project. llm: A tool for query GitLab commits, Input should be a exists username or project.
parameters: parameters:
- name: employee - name: username
type: string type: string
required: false required: false
label: label:
en_US: employee en_US: username
zh_Hans: 员工用户名 zh_Hans: 员工用户名
human_description: human_description:
en_US: employee en_US: username
zh_Hans: 员工用户名 zh_Hans: 员工用户名
llm_description: employee for gitlab llm_description: User name for GitLab
form: llm form: llm
- name: project - name: project
type: string type: string
@ -30,7 +30,7 @@ parameters:
human_description: human_description:
en_US: project en_US: project
zh_Hans: 项目名 zh_Hans: 项目名
llm_description: project for gitlab llm_description: project for GitLab
form: llm form: llm
- name: start_time - name: start_time
type: string type: string
@ -41,7 +41,7 @@ parameters:
human_description: human_description:
en_US: start_time en_US: start_time
zh_Hans: 开始时间 zh_Hans: 开始时间
llm_description: start_time for gitlab llm_description: Start time for GitLab
form: llm form: llm
- name: end_time - name: end_time
type: string type: string
@ -52,5 +52,26 @@ parameters:
human_description: human_description:
en_US: end_time en_US: end_time
zh_Hans: 结束时间 zh_Hans: 结束时间
llm_description: end_time for gitlab llm_description: End time for GitLab
form: llm
- name: change_type
type: select
required: false
options:
- value: all
label:
en_US: all
zh_Hans: 所有
- value: new
label:
en_US: new
zh_Hans: 新增
default: all
label:
en_US: change_type
zh_Hans: 变更类型
human_description:
en_US: change_type
zh_Hans: 变更类型
llm_description: Content change type for GitLab
form: llm form: llm

View File

@ -0,0 +1,95 @@
from typing import Any, Union
import requests
from core.tools.entities.tool_entities import ToolInvokeMessage
from core.tools.tool.builtin_tool import BuiltinTool
class GitlabFilesTool(BuiltinTool):
    def _invoke(self,
                user_id: str,
                tool_parameters: dict[str, Any]
                ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
        """
        Recursively list the files under `path` on `branch` of a GitLab
        project and return one JSON message per file with its raw content.
        """
        project = tool_parameters.get('project', '')
        branch = tool_parameters.get('branch', '')
        path = tool_parameters.get('path', '')

        if not project:
            return self.create_text_message('Project is required')
        if not branch:
            return self.create_text_message('Branch is required')
        if not path:
            return self.create_text_message('Path is required')

        access_token = self.runtime.credentials.get('access_tokens')
        site_url = self.runtime.credentials.get('site_url')

        if 'access_tokens' not in self.runtime.credentials or not self.runtime.credentials.get('access_tokens'):
            return self.create_text_message("Gitlab API Access Tokens is required.")
        if 'site_url' not in self.runtime.credentials or not self.runtime.credentials.get('site_url'):
            # Default to the public GitLab instance when no site url is configured.
            site_url = 'https://gitlab.com'

        # Resolve the project name to its numeric GitLab project ID.
        project_id = self.get_project_id(site_url, access_token, project)
        if not project_id:
            return self.create_text_message(f"Project '{project}' not found.")

        # Fetch the file contents.
        result = self.fetch(user_id, project_id, site_url, access_token, branch, path)

        return [self.create_json_message(item) for item in result]

    def extract_project_name_and_path(self, path: str) -> tuple[str, str]:
        """Split "<project>/<rest>" into (project, rest); (None, None) when there is no '/'."""
        parts = path.split('/', 1)
        if len(parts) < 2:
            return None, None
        return parts[0], parts[1]

    def get_project_id(self, site_url: str, access_token: str, project_name: str) -> Union[str, None]:
        """Look up a project's numeric ID by exact name match; None when not found."""
        headers = {"PRIVATE-TOKEN": access_token}
        try:
            url = f"{site_url}/api/v4/projects?search={project_name}"
            response = requests.get(url, headers=headers)
            response.raise_for_status()
            projects = response.json()
            for project in projects:
                if project['name'] == project_name:
                    return project['id']
        except requests.RequestException as e:
            print(f"Error fetching project ID from GitLab: {e}")
        return None

    def fetch(self, user_id: str, project_id: str, site_url: str, access_token: str, branch: str, path: str = None) -> list[dict[str, Any]]:
        """
        Recursively walk the repository tree at `path` on `branch`, returning
        a list of {"path", "branch", "content"} dicts, one per file.
        """
        domain = site_url
        headers = {"PRIVATE-TOKEN": access_token}
        results = []

        try:
            # List files and directories in the given path.
            # NOTE(review): `path` and `item_path` are interpolated without URL
            # encoding; paths containing special characters may need quoting — confirm.
            url = f"{domain}/api/v4/projects/{project_id}/repository/tree?path={path}&ref={branch}"
            response = requests.get(url, headers=headers)
            response.raise_for_status()
            items = response.json()

            for item in items:
                item_path = item['path']
                if item['type'] == 'tree':  # It's a directory
                    # Fix: recurse with the same argument order as the signature —
                    # the original call dropped `user_id`, shifting every argument
                    # one position (project_id became user_id, etc.).
                    results.extend(self.fetch(user_id, project_id, site_url, access_token, branch, item_path))
                else:  # It's a file
                    file_url = f"{domain}/api/v4/projects/{project_id}/repository/files/{item_path}/raw?ref={branch}"
                    file_response = requests.get(file_url, headers=headers)
                    file_response.raise_for_status()
                    file_content = file_response.text
                    results.append({
                        "path": item_path,
                        "branch": branch,
                        "content": file_content
                    })
        except requests.RequestException as e:
            print(f"Error fetching data from GitLab: {e}")

        return results

View File

@ -0,0 +1,45 @@
identity:
name: gitlab_files
author: Leo.Wang
label:
en_US: GitLab Files
zh_Hans: GitLab 文件获取
description:
human:
en_US: A tool for query GitLab files, Input should be branch and a exists file or directory path.
zh_Hans: 一个用于查询 GitLab 文件的工具,输入的内容应该是分支和一个已存在文件或者文件夹路径。
llm: A tool for query GitLab files, Input should be a exists file or directory path.
parameters:
- name: project
type: string
required: true
label:
en_US: project
zh_Hans: 项目
human_description:
en_US: project
zh_Hans: 项目
llm_description: Project for GitLab
form: llm
- name: branch
type: string
required: true
label:
en_US: branch
zh_Hans: 分支
human_description:
en_US: branch
zh_Hans: 分支
llm_description: Branch for GitLab
form: llm
- name: path
type: string
required: true
label:
en_US: path
zh_Hans: 文件路径
human_description:
en_US: path
zh_Hans: 文件路径
llm_description: File path for GitLab
form: llm

View File

@ -0,0 +1,43 @@
from typing import Any
from core.helper import ssrf_proxy
from core.tools.entities.tool_entities import ToolInvokeMessage
from core.tools.tool.builtin_tool import BuiltinTool
class JinaTokenizerTool(BuiltinTool):
    # Public Jina tokenizer endpoint; works without credentials, an API key is optional.
    _jina_tokenizer_endpoint = 'https://tokenize.jina.ai/'

    def _invoke(
        self,
        user_id: str,
        tool_parameters: dict[str, Any],
    ) -> ToolInvokeMessage:
        """Tokenize / segment the given content via the Jina tokenizer API."""
        body = {
            "content": tool_parameters['content']
        }
        headers = {
            'Content-Type': 'application/json'
        }

        # Attach the bearer token when an API key is configured.
        api_key = self.runtime.credentials.get('api_key')
        if 'api_key' in self.runtime.credentials and api_key:
            headers['Authorization'] = "Bearer " + api_key

        # Optional boolean flags are only sent when explicitly enabled.
        for flag in ('return_chunks', 'return_tokens'):
            if tool_parameters.get(flag, False):
                body[flag] = True

        tokenizer = tool_parameters.get('tokenizer')
        if tokenizer:
            body['tokenizer'] = tokenizer

        # Requests go through the SSRF-safe proxy helper rather than plain requests.
        response = ssrf_proxy.post(
            self._jina_tokenizer_endpoint,
            headers=headers,
            json=body,
        )
        return self.create_json_message(response.json())

View File

@ -0,0 +1,70 @@
identity:
name: jina_tokenizer
author: hjlarry
label:
en_US: JinaTokenizer
description:
human:
en_US: Free API to tokenize text and segment long text into chunks.
zh_Hans: 免费的API可以将文本tokenize也可以将长文本分割成多个部分。
llm: Free API to tokenize text and segment long text into chunks.
parameters:
- name: content
type: string
required: true
label:
en_US: Content
zh_Hans: 内容
llm_description: the content which need to tokenize or segment
form: llm
- name: return_tokens
type: boolean
required: false
label:
en_US: Return the tokens
zh_Hans: 是否返回tokens
human_description:
en_US: Return the tokens and their corresponding ids in the response.
zh_Hans: 返回tokens及其对应的ids。
form: form
- name: return_chunks
type: boolean
label:
en_US: Return the chunks
zh_Hans: 是否分块
human_description:
en_US: Chunking the input into semantically meaningful segments while handling a wide variety of text types and edge cases based on common structural cues.
zh_Hans: 将输入分块为具有语义意义的片段,同时根据常见的结构线索处理各种文本类型和边缘情况。
form: form
- name: tokenizer
type: select
options:
- value: cl100k_base
label:
en_US: cl100k_base
- value: o200k_base
label:
en_US: o200k_base
- value: p50k_base
label:
en_US: p50k_base
- value: r50k_base
label:
en_US: r50k_base
- value: p50k_edit
label:
en_US: p50k_edit
- value: gpt2
label:
en_US: gpt2
label:
en_US: Tokenizer
human_description:
en_US: |
· cl100k_base --- gpt-4, gpt-3.5-turbo, gpt-3.5
· o200k_base --- gpt-4o, gpt-4o-mini
· p50k_base --- text-davinci-003, text-davinci-002
· r50k_base --- text-davinci-001, text-curie-001
· p50k_edit --- text-davinci-edit-001, code-davinci-edit-001
· gpt2 --- gpt-2
form: form

View File

@ -0,0 +1,73 @@
from novita_client import (
Txt2ImgV3Embedding,
Txt2ImgV3HiresFix,
Txt2ImgV3LoRA,
Txt2ImgV3Refiner,
V3TaskImage,
)
class NovitaAiToolBase:
def _extract_loras(self, loras_str: str):
if not loras_str:
return []
loras_ori_list = lora_str.strip().split(';')
result_list = []
for lora_str in loras_ori_list:
lora_info = lora_str.strip().split(',')
lora = Txt2ImgV3LoRA(
model_name=lora_info[0].strip(),
strength=float(lora_info[1]),
)
result_list.append(lora)
return result_list
def _extract_embeddings(self, embeddings_str: str):
if not embeddings_str:
return []
embeddings_ori_list = embeddings_str.strip().split(';')
result_list = []
for embedding_str in embeddings_ori_list:
embedding = Txt2ImgV3Embedding(
model_name=embedding_str.strip()
)
result_list.append(embedding)
return result_list
def _extract_hires_fix(self, hires_fix_str: str):
hires_fix_info = hires_fix_str.strip().split(',')
if 'upscaler' in hires_fix_info:
hires_fix = Txt2ImgV3HiresFix(
target_width=int(hires_fix_info[0]),
target_height=int(hires_fix_info[1]),
strength=float(hires_fix_info[2]),
upscaler=hires_fix_info[3].strip()
)
else:
hires_fix = Txt2ImgV3HiresFix(
target_width=int(hires_fix_info[0]),
target_height=int(hires_fix_info[1]),
strength=float(hires_fix_info[2])
)
return hires_fix
def _extract_refiner(self, switch_at: str):
refiner = Txt2ImgV3Refiner(
switch_at=float(switch_at)
)
return refiner
def _is_hit_nsfw_detection(self, image: V3TaskImage, confidence_threshold: float) -> bool:
"""
is hit nsfw
"""
if image.nsfw_detection_result is None:
return False
if image.nsfw_detection_result.valid and image.nsfw_detection_result.confidence >= confidence_threshold:
return True
return False

View File

@ -4,19 +4,15 @@ from typing import Any, Union
from novita_client import ( from novita_client import (
NovitaClient, NovitaClient,
Txt2ImgV3Embedding,
Txt2ImgV3HiresFix,
Txt2ImgV3LoRA,
Txt2ImgV3Refiner,
V3TaskImage,
) )
from core.tools.entities.tool_entities import ToolInvokeMessage from core.tools.entities.tool_entities import ToolInvokeMessage
from core.tools.errors import ToolProviderCredentialValidationError from core.tools.errors import ToolProviderCredentialValidationError
from core.tools.provider.builtin.novitaai._novita_tool_base import NovitaAiToolBase
from core.tools.tool.builtin_tool import BuiltinTool from core.tools.tool.builtin_tool import BuiltinTool
class NovitaAiTxt2ImgTool(BuiltinTool): class NovitaAiTxt2ImgTool(BuiltinTool, NovitaAiToolBase):
def _invoke(self, def _invoke(self,
user_id: str, user_id: str,
tool_parameters: dict[str, Any], tool_parameters: dict[str, Any],
@ -73,65 +69,19 @@ class NovitaAiTxt2ImgTool(BuiltinTool):
# process loras # process loras
if 'loras' in res_parameters: if 'loras' in res_parameters:
loras_ori_list = res_parameters.get('loras').strip().split(';') res_parameters['loras'] = self._extract_loras(res_parameters.get('loras'))
locals_list = []
for lora_str in loras_ori_list:
lora_info = lora_str.strip().split(',')
lora = Txt2ImgV3LoRA(
model_name=lora_info[0].strip(),
strength=float(lora_info[1]),
)
locals_list.append(lora)
res_parameters['loras'] = locals_list
# process embeddings # process embeddings
if 'embeddings' in res_parameters: if 'embeddings' in res_parameters:
embeddings_ori_list = res_parameters.get('embeddings').strip().split(';') res_parameters['embeddings'] = self._extract_embeddings(res_parameters.get('embeddings'))
locals_list = []
for embedding_str in embeddings_ori_list:
embedding = Txt2ImgV3Embedding(
model_name=embedding_str.strip()
)
locals_list.append(embedding)
res_parameters['embeddings'] = locals_list
# process hires_fix # process hires_fix
if 'hires_fix' in res_parameters: if 'hires_fix' in res_parameters:
hires_fix_ori = res_parameters.get('hires_fix') res_parameters['hires_fix'] = self._extract_hires_fix(res_parameters.get('hires_fix'))
hires_fix_info = hires_fix_ori.strip().split(',')
if 'upscaler' in hires_fix_info:
hires_fix = Txt2ImgV3HiresFix(
target_width=int(hires_fix_info[0]),
target_height=int(hires_fix_info[1]),
strength=float(hires_fix_info[2]),
upscaler=hires_fix_info[3].strip()
)
else:
hires_fix = Txt2ImgV3HiresFix(
target_width=int(hires_fix_info[0]),
target_height=int(hires_fix_info[1]),
strength=float(hires_fix_info[2])
)
res_parameters['hires_fix'] = hires_fix # process refiner
if 'refiner_switch_at' in res_parameters:
if 'refiner_switch_at' in res_parameters: res_parameters['refiner'] = self._extract_refiner(res_parameters.get('refiner_switch_at'))
refiner = Txt2ImgV3Refiner( del res_parameters['refiner_switch_at']
switch_at=float(res_parameters.get('refiner_switch_at'))
)
del res_parameters['refiner_switch_at']
res_parameters['refiner'] = refiner
return res_parameters return res_parameters
def _is_hit_nsfw_detection(self, image: V3TaskImage, confidence_threshold: float) -> bool:
"""
is hit nsfw
"""
if image.nsfw_detection_result is None:
return False
if image.nsfw_detection_result.valid and image.nsfw_detection_result.confidence >= confidence_threshold:
return True
return False

View File

@ -1,6 +1,6 @@
from typing import Optional from typing import Optional
from core.app.app_config.entities import VariableEntity from core.app.app_config.entities import VariableEntity, VariableEntityType
from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager
from core.tools.entities.common_entities import I18nObject from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_entities import ( from core.tools.entities.tool_entities import (
@ -18,6 +18,13 @@ from models.model import App, AppMode
from models.tools import WorkflowToolProvider from models.tools import WorkflowToolProvider
from models.workflow import Workflow from models.workflow import Workflow
VARIABLE_TO_PARAMETER_TYPE_MAPPING = {
VariableEntityType.TEXT_INPUT: ToolParameter.ToolParameterType.STRING,
VariableEntityType.PARAGRAPH: ToolParameter.ToolParameterType.STRING,
VariableEntityType.SELECT: ToolParameter.ToolParameterType.SELECT,
VariableEntityType.NUMBER: ToolParameter.ToolParameterType.NUMBER,
}
class WorkflowToolProviderController(ToolProviderController): class WorkflowToolProviderController(ToolProviderController):
provider_id: str provider_id: str
@ -28,7 +35,7 @@ class WorkflowToolProviderController(ToolProviderController):
if not app: if not app:
raise ValueError('app not found') raise ValueError('app not found')
controller = WorkflowToolProviderController(**{ controller = WorkflowToolProviderController(**{
'identity': { 'identity': {
'author': db_provider.user.name if db_provider.user_id and db_provider.user else '', 'author': db_provider.user.name if db_provider.user_id and db_provider.user else '',
@ -46,7 +53,7 @@ class WorkflowToolProviderController(ToolProviderController):
'credentials_schema': {}, 'credentials_schema': {},
'provider_id': db_provider.id or '', 'provider_id': db_provider.id or '',
}) })
# init tools # init tools
controller.tools = [controller._get_db_provider_tool(db_provider, app)] controller.tools = [controller._get_db_provider_tool(db_provider, app)]
@ -56,7 +63,7 @@ class WorkflowToolProviderController(ToolProviderController):
@property @property
def provider_type(self) -> ToolProviderType: def provider_type(self) -> ToolProviderType:
return ToolProviderType.WORKFLOW return ToolProviderType.WORKFLOW
def _get_db_provider_tool(self, db_provider: WorkflowToolProvider, app: App) -> WorkflowTool: def _get_db_provider_tool(self, db_provider: WorkflowToolProvider, app: App) -> WorkflowTool:
""" """
get db provider tool get db provider tool
@ -93,23 +100,11 @@ class WorkflowToolProviderController(ToolProviderController):
if variable: if variable:
parameter_type = None parameter_type = None
options = None options = None
if variable.type in [ if variable.type not in VARIABLE_TO_PARAMETER_TYPE_MAPPING:
VariableEntity.Type.TEXT_INPUT,
VariableEntity.Type.PARAGRAPH,
]:
parameter_type = ToolParameter.ToolParameterType.STRING
elif variable.type in [
VariableEntity.Type.SELECT
]:
parameter_type = ToolParameter.ToolParameterType.SELECT
elif variable.type in [
VariableEntity.Type.NUMBER
]:
parameter_type = ToolParameter.ToolParameterType.NUMBER
else:
raise ValueError(f'unsupported variable type {variable.type}') raise ValueError(f'unsupported variable type {variable.type}')
parameter_type = VARIABLE_TO_PARAMETER_TYPE_MAPPING[variable.type]
if variable.type == VariableEntity.Type.SELECT and variable.options:
if variable.type == VariableEntityType.SELECT and variable.options:
options = [ options = [
ToolParameterOption( ToolParameterOption(
value=option, value=option,
@ -200,7 +195,7 @@ class WorkflowToolProviderController(ToolProviderController):
""" """
if self.tools is not None: if self.tools is not None:
return self.tools return self.tools
db_providers: WorkflowToolProvider = db.session.query(WorkflowToolProvider).filter( db_providers: WorkflowToolProvider = db.session.query(WorkflowToolProvider).filter(
WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.tenant_id == tenant_id,
WorkflowToolProvider.app_id == self.provider_id, WorkflowToolProvider.app_id == self.provider_id,
@ -208,11 +203,11 @@ class WorkflowToolProviderController(ToolProviderController):
if not db_providers: if not db_providers:
return [] return []
self.tools = [self._get_db_provider_tool(db_providers, db_providers.app)] self.tools = [self._get_db_provider_tool(db_providers, db_providers.app)]
return self.tools return self.tools
def get_tool(self, tool_name: str) -> Optional[WorkflowTool]: def get_tool(self, tool_name: str) -> Optional[WorkflowTool]:
""" """
get tool by name get tool by name
@ -226,5 +221,5 @@ class WorkflowToolProviderController(ToolProviderController):
for tool in self.tools: for tool in self.tools:
if tool.identity.name == tool_name: if tool.identity.name == tool_name:
return tool return tool
return None return None

View File

@ -10,14 +10,11 @@ from configs import dify_config
from core.agent.entities import AgentToolEntity from core.agent.entities import AgentToolEntity
from core.app.entities.app_invoke_entities import InvokeFrom from core.app.entities.app_invoke_entities import InvokeFrom
from core.helper.module_import_helper import load_single_subclass_from_source from core.helper.module_import_helper import load_single_subclass_from_source
from core.helper.position_helper import is_filtered
from core.model_runtime.utils.encoders import jsonable_encoder from core.model_runtime.utils.encoders import jsonable_encoder
from core.tools.entities.api_entities import UserToolProvider, UserToolProviderTypeLiteral from core.tools.entities.api_entities import UserToolProvider, UserToolProviderTypeLiteral
from core.tools.entities.common_entities import I18nObject from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_entities import ( from core.tools.entities.tool_entities import ApiProviderAuthType, ToolInvokeFrom, ToolParameter
ApiProviderAuthType,
ToolInvokeFrom,
ToolParameter,
)
from core.tools.errors import ToolProviderNotFoundError from core.tools.errors import ToolProviderNotFoundError
from core.tools.provider.api_tool_provider import ApiToolProviderController from core.tools.provider.api_tool_provider import ApiToolProviderController
from core.tools.provider.builtin._positions import BuiltinToolProviderSort from core.tools.provider.builtin._positions import BuiltinToolProviderSort
@ -26,10 +23,7 @@ from core.tools.tool.api_tool import ApiTool
from core.tools.tool.builtin_tool import BuiltinTool from core.tools.tool.builtin_tool import BuiltinTool
from core.tools.tool.tool import Tool from core.tools.tool.tool import Tool
from core.tools.tool_label_manager import ToolLabelManager from core.tools.tool_label_manager import ToolLabelManager
from core.tools.utils.configuration import ( from core.tools.utils.configuration import ToolConfigurationManager, ToolParameterConfigurationManager
ToolConfigurationManager,
ToolParameterConfigurationManager,
)
from core.tools.utils.tool_parameter_converter import ToolParameterConverter from core.tools.utils.tool_parameter_converter import ToolParameterConverter
from extensions.ext_database import db from extensions.ext_database import db
from models.tools import ApiToolProvider, BuiltinToolProvider, WorkflowToolProvider from models.tools import ApiToolProvider, BuiltinToolProvider, WorkflowToolProvider
@ -37,6 +31,7 @@ from services.tools.tools_transform_service import ToolTransformService
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class ToolManager: class ToolManager:
_builtin_provider_lock = Lock() _builtin_provider_lock = Lock()
_builtin_providers = {} _builtin_providers = {}
@ -106,7 +101,7 @@ class ToolManager:
tenant_id: str, tenant_id: str,
invoke_from: InvokeFrom = InvokeFrom.DEBUGGER, invoke_from: InvokeFrom = InvokeFrom.DEBUGGER,
tool_invoke_from: ToolInvokeFrom = ToolInvokeFrom.AGENT) \ tool_invoke_from: ToolInvokeFrom = ToolInvokeFrom.AGENT) \
-> Union[BuiltinTool, ApiTool]: -> Union[BuiltinTool, ApiTool]:
""" """
get the tool runtime get the tool runtime
@ -345,7 +340,7 @@ class ToolManager:
provider_class = load_single_subclass_from_source( provider_class = load_single_subclass_from_source(
module_name=f'core.tools.provider.builtin.{provider}.{provider}', module_name=f'core.tools.provider.builtin.{provider}.{provider}',
script_path=path.join(path.dirname(path.realpath(__file__)), script_path=path.join(path.dirname(path.realpath(__file__)),
'provider', 'builtin', provider, f'{provider}.py'), 'provider', 'builtin', provider, f'{provider}.py'),
parent_type=BuiltinToolProviderController) parent_type=BuiltinToolProviderController)
provider: BuiltinToolProviderController = provider_class() provider: BuiltinToolProviderController = provider_class()
cls._builtin_providers[provider.identity.name] = provider cls._builtin_providers[provider.identity.name] = provider
@ -413,6 +408,15 @@ class ToolManager:
# append builtin providers # append builtin providers
for provider in builtin_providers: for provider in builtin_providers:
# handle include, exclude
if is_filtered(
include_set=dify_config.POSITION_TOOL_INCLUDES_SET,
exclude_set=dify_config.POSITION_TOOL_EXCLUDES_SET,
data=provider,
name_func=lambda x: x.identity.name
):
continue
user_provider = ToolTransformService.builtin_provider_to_user_provider( user_provider = ToolTransformService.builtin_provider_to_user_provider(
provider_controller=provider, provider_controller=provider,
db_provider=find_db_builtin_provider(provider.identity.name), db_provider=find_db_builtin_provider(provider.identity.name),
@ -472,7 +476,7 @@ class ToolManager:
@classmethod @classmethod
def get_api_provider_controller(cls, tenant_id: str, provider_id: str) -> tuple[ def get_api_provider_controller(cls, tenant_id: str, provider_id: str) -> tuple[
ApiToolProviderController, dict[str, Any]]: ApiToolProviderController, dict[str, Any]]:
""" """
get the api provider get the api provider
@ -592,4 +596,5 @@ class ToolManager:
else: else:
raise ValueError(f"provider type {provider_type} not found") raise ValueError(f"provider type {provider_type} not found")
ToolManager.load_builtin_providers_cache() ToolManager.load_builtin_providers_cache()

View File

@ -7,14 +7,14 @@ from typing_extensions import deprecated
from core.app.segments import Segment, Variable, factory from core.app.segments import Segment, Variable, factory
from core.file.file_obj import FileVar from core.file.file_obj import FileVar
from core.workflow.enums import SystemVariable from core.workflow.enums import SystemVariableKey
VariableValue = Union[str, int, float, dict, list, FileVar] VariableValue = Union[str, int, float, dict, list, FileVar]
SYSTEM_VARIABLE_NODE_ID = 'sys' SYSTEM_VARIABLE_NODE_ID = "sys"
ENVIRONMENT_VARIABLE_NODE_ID = 'env' ENVIRONMENT_VARIABLE_NODE_ID = "env"
CONVERSATION_VARIABLE_NODE_ID = 'conversation' CONVERSATION_VARIABLE_NODE_ID = "conversation"
class VariablePool(BaseModel): class VariablePool(BaseModel):
@ -32,7 +32,7 @@ class VariablePool(BaseModel):
description='User inputs', description='User inputs',
) )
system_variables: Mapping[SystemVariable, Any] = Field( system_variables: Mapping[SystemVariableKey, Any] = Field(
description='System variables', description='System variables',
) )
@ -78,7 +78,7 @@ class VariablePool(BaseModel):
None None
""" """
if len(selector) < 2: if len(selector) < 2:
raise ValueError('Invalid selector') raise ValueError("Invalid selector")
if value is None: if value is None:
return return
@ -105,13 +105,13 @@ class VariablePool(BaseModel):
ValueError: If the selector is invalid. ValueError: If the selector is invalid.
""" """
if len(selector) < 2: if len(selector) < 2:
raise ValueError('Invalid selector') raise ValueError("Invalid selector")
hash_key = hash(tuple(selector[1:])) hash_key = hash(tuple(selector[1:]))
value = self.variable_dictionary[selector[0]].get(hash_key) value = self.variable_dictionary[selector[0]].get(hash_key)
return value return value
@deprecated('This method is deprecated, use `get` instead.') @deprecated("This method is deprecated, use `get` instead.")
def get_any(self, selector: Sequence[str], /) -> Any | None: def get_any(self, selector: Sequence[str], /) -> Any | None:
""" """
Retrieves the value from the variable pool based on the given selector. Retrieves the value from the variable pool based on the given selector.
@ -126,7 +126,7 @@ class VariablePool(BaseModel):
ValueError: If the selector is invalid. ValueError: If the selector is invalid.
""" """
if len(selector) < 2: if len(selector) < 2:
raise ValueError('Invalid selector') raise ValueError("Invalid selector")
hash_key = hash(tuple(selector[1:])) hash_key = hash(tuple(selector[1:]))
value = self.variable_dictionary[selector[0]].get(hash_key) value = self.variable_dictionary[selector[0]].get(hash_key)
return value.to_object() if value else None return value.to_object() if value else None

View File

@ -1,25 +1,13 @@
from enum import Enum from enum import Enum
class SystemVariable(str, Enum): class SystemVariableKey(str, Enum):
""" """
System Variables. System Variables.
""" """
QUERY = 'query'
FILES = 'files'
CONVERSATION_ID = 'conversation_id'
USER_ID = 'user_id'
DIALOGUE_COUNT = 'dialogue_count'
@classmethod QUERY = "query"
def value_of(cls, value: str): FILES = "files"
""" CONVERSATION_ID = "conversation_id"
Get value of given system variable. USER_ID = "user_id"
DIALOGUE_COUNT = "dialogue_count"
:param value: system variable value
:return: system variable
"""
for system_variable in cls:
if system_variable.value == value:
return system_variable
raise ValueError(f'invalid system variable value {value}')

View File

@ -13,8 +13,8 @@ from models.workflow import WorkflowNodeExecutionStatus
MAX_NUMBER = dify_config.CODE_MAX_NUMBER MAX_NUMBER = dify_config.CODE_MAX_NUMBER
MIN_NUMBER = dify_config.CODE_MIN_NUMBER MIN_NUMBER = dify_config.CODE_MIN_NUMBER
MAX_PRECISION = 20 MAX_PRECISION = dify_config.CODE_MAX_PRECISION
MAX_DEPTH = 5 MAX_DEPTH = dify_config.CODE_MAX_DEPTH
MAX_STRING_LENGTH = dify_config.CODE_MAX_STRING_LENGTH MAX_STRING_LENGTH = dify_config.CODE_MAX_STRING_LENGTH
MAX_STRING_ARRAY_LENGTH = dify_config.CODE_MAX_STRING_ARRAY_LENGTH MAX_STRING_ARRAY_LENGTH = dify_config.CODE_MAX_STRING_ARRAY_LENGTH
MAX_OBJECT_ARRAY_LENGTH = dify_config.CODE_MAX_OBJECT_ARRAY_LENGTH MAX_OBJECT_ARRAY_LENGTH = dify_config.CODE_MAX_OBJECT_ARRAY_LENGTH
@ -23,7 +23,7 @@ MAX_NUMBER_ARRAY_LENGTH = dify_config.CODE_MAX_NUMBER_ARRAY_LENGTH
class CodeNode(BaseNode): class CodeNode(BaseNode):
_node_data_cls = CodeNodeData _node_data_cls = CodeNodeData
node_type = NodeType.CODE _node_type = NodeType.CODE
@classmethod @classmethod
def get_default_config(cls, filters: Optional[dict] = None) -> dict: def get_default_config(cls, filters: Optional[dict] = None) -> dict:
@ -316,8 +316,8 @@ class CodeNode(BaseNode):
@classmethod @classmethod
def _extract_variable_selector_to_variable_mapping( def _extract_variable_selector_to_variable_mapping(
cls, cls,
graph_config: Mapping[str, Any], graph_config: Mapping[str, Any],
node_id: str, node_id: str,
node_data: CodeNodeData node_data: CodeNodeData
) -> Mapping[str, Sequence[str]]: ) -> Mapping[str, Sequence[str]]:

View File

@ -25,7 +25,7 @@ from core.prompt.entities.advanced_prompt_entities import CompletionModelPromptT
from core.prompt.utils.prompt_message_util import PromptMessageUtil from core.prompt.utils.prompt_message_util import PromptMessageUtil
from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult, NodeType from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult, NodeType
from core.workflow.entities.variable_pool import VariablePool from core.workflow.entities.variable_pool import VariablePool
from core.workflow.enums import SystemVariable from core.workflow.enums import SystemVariableKey
from core.workflow.graph_engine.entities.event import InNodeEvent from core.workflow.graph_engine.entities.event import InNodeEvent
from core.workflow.nodes.base_node import BaseNode from core.workflow.nodes.base_node import BaseNode
from core.workflow.nodes.event import RunCompletedEvent, RunEvent, RunRetrieverResourceEvent, RunStreamChunkEvent from core.workflow.nodes.event import RunCompletedEvent, RunEvent, RunRetrieverResourceEvent, RunStreamChunkEvent
@ -110,7 +110,7 @@ class LLMNode(BaseNode):
# fetch prompt messages # fetch prompt messages
prompt_messages, stop = self._fetch_prompt_messages( prompt_messages, stop = self._fetch_prompt_messages(
node_data=node_data, node_data=node_data,
query=variable_pool.get_any(['sys', SystemVariable.QUERY.value]) query=variable_pool.get_any(['sys', SystemVariableKey.QUERY.value])
if node_data.memory else None, if node_data.memory else None,
query_prompt_template=node_data.memory.query_prompt_template if node_data.memory else None, query_prompt_template=node_data.memory.query_prompt_template if node_data.memory else None,
inputs=inputs, inputs=inputs,
@ -370,7 +370,7 @@ class LLMNode(BaseNode):
inputs[variable_selector.variable] = variable_value inputs[variable_selector.variable] = variable_value
return inputs # type: ignore return inputs
def _fetch_files(self, node_data: LLMNodeData, variable_pool: VariablePool) -> list["FileVar"]: def _fetch_files(self, node_data: LLMNodeData, variable_pool: VariablePool) -> list["FileVar"]:
""" """
@ -382,7 +382,7 @@ class LLMNode(BaseNode):
if not node_data.vision.enabled: if not node_data.vision.enabled:
return [] return []
files = variable_pool.get_any(['sys', SystemVariable.FILES.value]) files = variable_pool.get_any(['sys', SystemVariableKey.FILES.value])
if not files: if not files:
return [] return []
@ -543,7 +543,7 @@ class LLMNode(BaseNode):
return None return None
# get conversation id # get conversation id
conversation_id = variable_pool.get_any(['sys', SystemVariable.CONVERSATION_ID.value]) conversation_id = variable_pool.get_any(['sys', SystemVariableKey.CONVERSATION_ID.value])
if conversation_id is None: if conversation_id is None:
return None return None
@ -722,10 +722,10 @@ class LLMNode(BaseNode):
variable_mapping['#context#'] = node_data.context.variable_selector variable_mapping['#context#'] = node_data.context.variable_selector
if node_data.vision.enabled: if node_data.vision.enabled:
variable_mapping['#files#'] = ['sys', SystemVariable.FILES.value] variable_mapping['#files#'] = ['sys', SystemVariableKey.FILES.value]
if node_data.memory: if node_data.memory:
variable_mapping['#sys.query#'] = ['sys', SystemVariable.QUERY.value] variable_mapping['#sys.query#'] = ['sys', SystemVariableKey.QUERY.value]
if node_data.prompt_config: if node_data.prompt_config:
enable_jinja = False enable_jinja = False

View File

@ -1,3 +1,7 @@
from collections.abc import Sequence
from pydantic import Field
from core.app.app_config.entities import VariableEntity from core.app.app_config.entities import VariableEntity
from core.workflow.entities.base_node_data_entities import BaseNodeData from core.workflow.entities.base_node_data_entities import BaseNodeData
@ -6,4 +10,4 @@ class StartNodeData(BaseNodeData):
""" """
Start Node Data Start Node Data
""" """
variables: list[VariableEntity] = [] variables: Sequence[VariableEntity] = Field(default_factory=list)

View File

@ -3,6 +3,7 @@ from collections.abc import Mapping, Sequence
from typing import Any from typing import Any
from core.workflow.entities.node_entities import NodeRunResult, NodeType from core.workflow.entities.node_entities import NodeRunResult, NodeType
from core.workflow.entities.variable_pool import SYSTEM_VARIABLE_NODE_ID, VariablePool
from core.workflow.nodes.base_node import BaseNode from core.workflow.nodes.base_node import BaseNode
from core.workflow.nodes.start.entities import StartNodeData from core.workflow.nodes.start.entities import StartNodeData
from models.workflow import WorkflowNodeExecutionStatus from models.workflow import WorkflowNodeExecutionStatus
@ -17,22 +18,22 @@ class StartNode(BaseNode):
Run node Run node
:return: :return:
""" """
# Get cleaned inputs node_inputs = dict(self.graph_runtime_state.variable_pool.user_inputs)
cleaned_inputs = dict(self.graph_runtime_state.variable_pool.user_inputs) system_inputs = self.graph_runtime_state.variable_pool.system_variables
for var in self.graph_runtime_state.variable_pool.system_variables: for var in system_inputs:
cleaned_inputs['sys.' + var.value] = self.graph_runtime_state.variable_pool.system_variables[var] node_inputs[SYSTEM_VARIABLE_NODE_ID + '.' + var] = system_inputs[var]
return NodeRunResult( return NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED, status=WorkflowNodeExecutionStatus.SUCCEEDED,
inputs=cleaned_inputs, inputs=node_inputs,
outputs=cleaned_inputs outputs=node_inputs
) )
@classmethod @classmethod
def _extract_variable_selector_to_variable_mapping( def _extract_variable_selector_to_variable_mapping(
cls, cls,
graph_config: Mapping[str, Any], graph_config: Mapping[str, Any],
node_id: str, node_id: str,
node_data: StartNodeData node_data: StartNodeData
) -> Mapping[str, Sequence[str]]: ) -> Mapping[str, Sequence[str]]:

View File

@ -2,7 +2,7 @@ from collections.abc import Mapping, Sequence
from os import path from os import path
from typing import Any, cast from typing import Any, cast
from core.app.segments import ArrayAnyVariable, parser from core.app.segments import ArrayAnySegment, ArrayAnyVariable, parser
from core.callback_handler.workflow_tool_callback_handler import DifyWorkflowCallbackHandler from core.callback_handler.workflow_tool_callback_handler import DifyWorkflowCallbackHandler
from core.file.file_obj import FileTransferMethod, FileType, FileVar from core.file.file_obj import FileTransferMethod, FileType, FileVar
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter
@ -11,7 +11,7 @@ from core.tools.tool_manager import ToolManager
from core.tools.utils.message_transformer import ToolFileMessageTransformer from core.tools.utils.message_transformer import ToolFileMessageTransformer
from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult, NodeType from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult, NodeType
from core.workflow.entities.variable_pool import VariablePool from core.workflow.entities.variable_pool import VariablePool
from core.workflow.enums import SystemVariable from core.workflow.enums import SystemVariableKey
from core.workflow.nodes.base_node import BaseNode from core.workflow.nodes.base_node import BaseNode
from core.workflow.nodes.tool.entities import ToolNodeData from core.workflow.nodes.tool.entities import ToolNodeData
from core.workflow.utils.variable_template_parser import VariableTemplateParser from core.workflow.utils.variable_template_parser import VariableTemplateParser
@ -141,8 +141,8 @@ class ToolNode(BaseNode):
return result return result
def _fetch_files(self, variable_pool: VariablePool) -> list[FileVar]: def _fetch_files(self, variable_pool: VariablePool) -> list[FileVar]:
variable = variable_pool.get(['sys', SystemVariable.FILES.value]) variable = variable_pool.get(['sys', SystemVariableKey.FILES.value])
assert isinstance(variable, ArrayAnyVariable) assert isinstance(variable, ArrayAnyVariable | ArrayAnySegment)
return list(variable.value) if variable else [] return list(variable.value) if variable else []
def _convert_tool_messages(self, messages: list[ToolInvokeMessage])\ def _convert_tool_messages(self, messages: list[ToolInvokeMessage])\

View File

@ -1,109 +1,8 @@
from collections.abc import Sequence from .node import VariableAssignerNode
from enum import Enum from .node_data import VariableAssignerData, WriteMode
from typing import Optional, cast
from sqlalchemy import select __all__ = [
from sqlalchemy.orm import Session 'VariableAssignerNode',
'VariableAssignerData',
from core.app.segments import SegmentType, Variable, factory 'WriteMode',
from core.workflow.entities.base_node_data_entities import BaseNodeData ]
from core.workflow.entities.node_entities import NodeRunResult, NodeType
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.nodes.base_node import BaseNode
from extensions.ext_database import db
from models import ConversationVariable, WorkflowNodeExecutionStatus
class VariableAssignerNodeError(Exception):
pass
class WriteMode(str, Enum):
OVER_WRITE = 'over-write'
APPEND = 'append'
CLEAR = 'clear'
class VariableAssignerData(BaseNodeData):
title: str = 'Variable Assigner'
desc: Optional[str] = 'Assign a value to a variable'
assigned_variable_selector: Sequence[str]
write_mode: WriteMode
input_variable_selector: Sequence[str]
class VariableAssignerNode(BaseNode):
_node_data_cls: type[BaseNodeData] = VariableAssignerData
_node_type: NodeType = NodeType.CONVERSATION_VARIABLE_ASSIGNER
def _run(self) -> NodeRunResult:
data = cast(VariableAssignerData, self.node_data)
# Should be String, Number, Object, ArrayString, ArrayNumber, ArrayObject
original_variable = self.graph_runtime_state.variable_pool.get(data.assigned_variable_selector)
if not isinstance(original_variable, Variable):
raise VariableAssignerNodeError('assigned variable not found')
match data.write_mode:
case WriteMode.OVER_WRITE:
income_value = self.graph_runtime_state.variable_pool.get(data.input_variable_selector)
if not income_value:
raise VariableAssignerNodeError('input value not found')
updated_variable = original_variable.model_copy(update={'value': income_value.value})
case WriteMode.APPEND:
income_value = self.graph_runtime_state.variable_pool.get(data.input_variable_selector)
if not income_value:
raise VariableAssignerNodeError('input value not found')
updated_value = original_variable.value + [income_value.value]
updated_variable = original_variable.model_copy(update={'value': updated_value})
case WriteMode.CLEAR:
income_value = get_zero_value(original_variable.value_type)
updated_variable = original_variable.model_copy(update={'value': income_value.to_object()})
case _:
raise VariableAssignerNodeError(f'unsupported write mode: {data.write_mode}')
# Over write the variable.
self.graph_runtime_state.variable_pool.add(data.assigned_variable_selector, updated_variable)
# Update conversation variable.
# TODO: Find a better way to use the database.
conversation_id = self.graph_runtime_state.variable_pool.get(['sys', 'conversation_id'])
if not conversation_id:
raise VariableAssignerNodeError('conversation_id not found')
update_conversation_variable(conversation_id=conversation_id.text, variable=updated_variable)
return NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
inputs={
'value': income_value.to_object(),
},
)
def update_conversation_variable(conversation_id: str, variable: Variable):
stmt = select(ConversationVariable).where(
ConversationVariable.id == variable.id, ConversationVariable.conversation_id == conversation_id
)
with Session(db.engine) as session:
row = session.scalar(stmt)
if not row:
raise VariableAssignerNodeError('conversation variable not found in the database')
row.data = variable.model_dump_json()
session.commit()
def get_zero_value(t: SegmentType):
match t:
case SegmentType.ARRAY_OBJECT | SegmentType.ARRAY_STRING | SegmentType.ARRAY_NUMBER:
return factory.build_segment([])
case SegmentType.OBJECT:
return factory.build_segment({})
case SegmentType.STRING:
return factory.build_segment('')
case SegmentType.NUMBER:
return factory.build_segment(0)
case _:
raise VariableAssignerNodeError(f'unsupported variable type: {t}')

Some files were not shown because too many files have changed in this diff Show More