From 5e7c5863ef13ecb03eae8f8f5516182b363572ce Mon Sep 17 00:00:00 2001 From: Harry Date: Mon, 23 Jun 2025 16:51:28 +0800 Subject: [PATCH] refactor(tool oauth): update api implementation --- README.md | 259 ------------ .../console/workspace/model_providers.py | 1 - .../console/workspace/tool_providers.py | 118 ++++-- api/core/tools/entities/api_entities.py | 13 +- api/core/tools/entities/tool_entities.py | 33 ++ api/core/tools/tool_manager.py | 25 +- ...9_1133-222376193a49_multiple_credential.py | 39 -- ...9_1353-a9306e69af07_multiple_credential.py | 33 -- ...9_1359-6835b906335f_multiple_credential.py | 33 -- ...9_1359-e315d2a83984_multiple_credential.py | 33 -- ...9_1511-110e30078dd3_multiple_credential.py | 53 --- ...025_06_24_1705-71f5020c6470_tool_oauth.py} | 47 ++- api/models/tools.py | 56 +-- api/services/plugin/oauth_service.py | 4 +- .../tools/builtin_tools_manage_service.py | 368 ++++++++++-------- api/services/tools/tools_transform_service.py | 16 +- 16 files changed, 393 insertions(+), 738 deletions(-) delete mode 100644 README.md delete mode 100644 api/migrations/versions/2025_06_19_1133-222376193a49_multiple_credential.py delete mode 100644 api/migrations/versions/2025_06_19_1353-a9306e69af07_multiple_credential.py delete mode 100644 api/migrations/versions/2025_06_19_1359-6835b906335f_multiple_credential.py delete mode 100644 api/migrations/versions/2025_06_19_1359-e315d2a83984_multiple_credential.py delete mode 100644 api/migrations/versions/2025_06_19_1511-110e30078dd3_multiple_credential.py rename api/migrations/versions/{2025_06_18_1506-99310d2c25a6_add_tool_oauth_credentials.py => 2025_06_24_1705-71f5020c6470_tool_oauth.py} (54%) diff --git a/README.md b/README.md deleted file mode 100644 index ca09adec08..0000000000 --- a/README.md +++ /dev/null @@ -1,259 +0,0 @@ -![cover-v5-optimized](./images/GitHub_README_if.png) - -

- 📌 Introducing Dify Workflow File Upload: Recreate Google NotebookLM Podcast -

- -

- Dify Cloud · - Self-hosting · - Documentation · - Dify edition overview -

- -

- - Static Badge - - Static Badge - - chat on Discord - - join Reddit - - follow on X(Twitter) - - follow on LinkedIn - - Docker Pulls - - Commits last month - - Issues closed - - Discussion posts -

- -

- README in English - 繁體中文文件 - 简体中文版自述文件 - 日本語のREADME - README en Español - README en Français - README tlhIngan Hol - README in Korean - README بالعربية - Türkçe README - README Tiếng Việt - README in Deutsch - README in বাংলা -

- -Dify is an open-source LLM app development platform. Its intuitive interface combines agentic AI workflow, RAG pipeline, agent capabilities, model management, observability features, and more, allowing you to quickly move from prototype to production. - -## Quick start - -> Before installing Dify, make sure your machine meets the following minimum system requirements: -> -> - CPU >= 2 Core -> - RAM >= 4 GiB - -
- -The easiest way to start the Dify server is through [docker compose](docker/docker-compose.yaml). Before running Dify with the following commands, make sure that [Docker](https://docs.docker.com/get-docker/) and [Docker Compose](https://docs.docker.com/compose/install/) are installed on your machine: - -```bash -cd dify -cd docker -cp .env.example .env -docker compose up -d -``` - -After running, you can access the Dify dashboard in your browser at [http://localhost/install](http://localhost/install) and start the initialization process. - -#### Seeking help - -Please refer to our [FAQ](https://docs.dify.ai/getting-started/install-self-hosted/faqs) if you encounter problems setting up Dify. Reach out to [the community and us](#community--contact) if you are still having issues. - -> If you'd like to contribute to Dify or do additional development, refer to our [guide to deploying from source code](https://docs.dify.ai/getting-started/install-self-hosted/local-source-code) - -## Key features - -**1. Workflow**: -Build and test powerful AI workflows on a visual canvas, leveraging all the following features and beyond. - -**2. Comprehensive model support**: -Seamless integration with hundreds of proprietary / open-source LLMs from dozens of inference providers and self-hosted solutions, covering GPT, Mistral, Llama3, and any OpenAI API-compatible models. A full list of supported model providers can be found [here](https://docs.dify.ai/getting-started/readme/model-providers). - -![providers-v5](https://github.com/langgenius/dify/assets/13230914/5a17bdbe-097a-4100-8363-40255b70f6e3) - -**3. Prompt IDE**: -Intuitive interface for crafting prompts, comparing model performance, and adding additional features such as text-to-speech to a chat-based app. - -**4. RAG Pipeline**: -Extensive RAG capabilities that cover everything from document ingestion to retrieval, with out-of-box support for text extraction from PDFs, PPTs, and other common document formats. - -**5. Agent capabilities**: -You can define agents based on LLM Function Calling or ReAct, and add pre-built or custom tools for the agent. Dify provides 50+ built-in tools for AI agents, such as Google Search, DALL·E, Stable Diffusion and WolframAlpha. - -**6. LLMOps**: -Monitor and analyze application logs and performance over time. You could continuously improve prompts, datasets, and models based on production data and annotations. - -**7. Backend-as-a-Service**: -All of Dify's offerings come with corresponding APIs, so you could effortlessly integrate Dify into your own business logic. - -## Feature Comparison - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
FeatureDify.AILangChainFlowiseOpenAI Assistants API
Programming ApproachAPI + App-orientedPython CodeApp-orientedAPI-oriented
Supported LLMsRich VarietyRich VarietyRich VarietyOpenAI-only
RAG Engine
Agent
Workflow
Observability
Enterprise Feature (SSO/Access control)
Local Deployment
- -## Using Dify - -- **Cloud
** - We host a [Dify Cloud](https://dify.ai) service for anyone to try with zero setup. It provides all the capabilities of the self-deployed version, and includes 200 free GPT-4 calls in the sandbox plan. - -- **Self-hosting Dify Community Edition
** - Quickly get Dify running in your environment with this [starter guide](#quick-start). - Use our [documentation](https://docs.dify.ai) for further references and more in-depth instructions. - -- **Dify for enterprise / organizations
** - We provide additional enterprise-centric features. [Log your questions for us through this chatbot](https://udify.app/chat/22L1zSxg6yW1cWQg) or [send us an email](mailto:business@dify.ai?subject=[GitHub]Business%20License%20Inquiry) to discuss enterprise needs.
- > For startups and small businesses using AWS, check out [Dify Premium on AWS Marketplace](https://aws.amazon.com/marketplace/pp/prodview-t22mebxzwjhu6) and deploy it to your own AWS VPC with one click. It's an affordable AMI offering with the option to create apps with custom logo and branding. - -## Staying ahead - -Star Dify on GitHub and be instantly notified of new releases. - -![star-us](https://github.com/langgenius/dify/assets/13230914/b823edc1-6388-4e25-ad45-2f6b187adbb4) - -## Advanced Setup - -If you need to customize the configuration, please refer to the comments in our [.env.example](docker/.env.example) file and update the corresponding values in your `.env` file. Additionally, you might need to make adjustments to the `docker-compose.yaml` file itself, such as changing image versions, port mappings, or volume mounts, based on your specific deployment environment and requirements. After making any changes, please re-run `docker-compose up -d`. You can find the full list of available environment variables [here](https://docs.dify.ai/getting-started/install-self-hosted/environments). - -If you'd like to configure a highly-available setup, there are community-contributed [Helm Charts](https://helm.sh/) and YAML files which allow Dify to be deployed on Kubernetes. - -- [Helm Chart by @LeoQuote](https://github.com/douban/charts/tree/master/charts/dify) -- [Helm Chart by @BorisPolonsky](https://github.com/BorisPolonsky/dify-helm) -- [Helm Chart by @magicsong](https://github.com/magicsong/ai-charts) -- [YAML file by @Winson-030](https://github.com/Winson-030/dify-kubernetes) -- [YAML file by @wyy-holding](https://github.com/wyy-holding/dify-k8s) - -#### Using Terraform for Deployment - -Deploy Dify to Cloud Platform with a single click using [terraform](https://www.terraform.io/) - -##### Azure Global - -- [Azure Terraform by @nikawang](https://github.com/nikawang/dify-azure-terraform) - -##### Google Cloud - -- [Google Cloud Terraform by @sotazum](https://github.com/DeNA/dify-google-cloud-terraform) - -#### Using AWS CDK for Deployment - -Deploy Dify to AWS with [CDK](https://aws.amazon.com/cdk/) - -##### AWS - -- [AWS CDK by @KevinZhao](https://github.com/aws-samples/solution-for-deploying-dify-on-aws) - -## Contributing - -For those who'd like to contribute code, see our [Contribution Guide](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md). -At the same time, please consider supporting Dify by sharing it on social media and at events and conferences. - -> We are looking for contributors to help translate Dify into languages other than Mandarin or English. If you are interested in helping, please see the [i18n README](https://github.com/langgenius/dify/blob/main/web/i18n/README.md) for more information, and leave us a comment in the `global-users` channel of our [Discord Community Server](https://discord.gg/8Tpq4AcN9c). - -## Community & contact - -- [GitHub Discussion](https://github.com/langgenius/dify/discussions). Best for: sharing feedback and asking questions. -- [GitHub Issues](https://github.com/langgenius/dify/issues). Best for: bugs you encounter using Dify.AI, and feature proposals. See our [Contribution Guide](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md). -- [Discord](https://discord.gg/FngNHpbcY7). Best for: sharing your applications and hanging out with the community. -- [X(Twitter)](https://twitter.com/dify_ai). Best for: sharing your applications and hanging out with the community. - -**Contributors** - - - - - -## Star history - -[![Star History Chart](https://api.star-history.com/svg?repos=langgenius/dify&type=Date)](https://star-history.com/#langgenius/dify&Date) - -## Security disclosure - -To protect your privacy, please avoid posting security issues on GitHub. Instead, send your questions to security@dify.ai and we will provide you with a more detailed answer. - -## License - -This repository is available under the [Dify Open Source License](LICENSE), which is essentially Apache 2.0 with a few additional restrictions. diff --git a/api/controllers/console/workspace/model_providers.py b/api/controllers/console/workspace/model_providers.py index ff0fcbda6e..32139781b0 100644 --- a/api/controllers/console/workspace/model_providers.py +++ b/api/controllers/console/workspace/model_providers.py @@ -35,7 +35,6 @@ class ModelProviderListApi(Resource): model_provider_service = ModelProviderService() provider_list = model_provider_service.get_provider_list(tenant_id=tenant_id, model_type=args.get("model_type")) - return jsonable_encoder({"data": provider_list}) diff --git a/api/controllers/console/workspace/tool_providers.py b/api/controllers/console/workspace/tool_providers.py index a46071059f..a4839fe8a1 100644 --- a/api/controllers/console/workspace/tool_providers.py +++ b/api/controllers/console/workspace/tool_providers.py @@ -18,6 +18,7 @@ from controllers.console.wraps import ( ) from core.model_runtime.utils.encoders import jsonable_encoder from core.plugin.impl.oauth import OAuthHandler +from core.tools.entities.tool_entities import ToolProviderCredentialType from extensions.ext_database import db from libs.helper import alphanumeric, uuid_value from libs.login import login_required @@ -89,17 +90,47 @@ class ToolBuiltinProviderDeleteApi(Resource): @account_initialization_required def post(self, provider): user = current_user - if not user.is_admin_or_owner: raise Forbidden() + tenant_id = user.current_tenant_id + req = reqparse.RequestParser() + req.add_argument("credential_id", type=str, required=True, nullable=False, location="json") + args = req.parse_args() + + return BuiltinToolManageService.delete_builtin_tool_provider( + tenant_id, + provider, + args["credential_id"], + ) + + +class ToolBuiltinProviderAddApi(Resource): + @setup_required + @login_required + @account_initialization_required + def post(self, provider): + user = current_user + user_id = user.id tenant_id = user.current_tenant_id - return BuiltinToolManageService.delete_builtin_tool_provider( - user_id, - tenant_id, - provider, + parser = reqparse.RequestParser() + parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json") + parser.add_argument("name", type=str, required=False, nullable=False, location="json") + parser.add_argument("type", type=str, required=True, nullable=False, location="json") + args = parser.parse_args() + + if args["type"] not in ToolProviderCredentialType.values(): + raise ValueError(f"Invalid credential type: {args['type']}") + + return BuiltinToolManageService.add_builtin_tool_provider( + user_id=user_id, + tenant_id=tenant_id, + provider_name=provider, + credentials=args["credentials"], + name=args["name"], + api_type=ToolProviderCredentialType.of(args["type"]), ) @@ -143,9 +174,11 @@ class ToolBuiltinProviderGetCredentialsApi(Resource): def get(self, provider): tenant_id = current_user.current_tenant_id - return BuiltinToolManageService.get_builtin_tool_provider_credentials( - tenant_id=tenant_id, - provider_name=provider, + return jsonable_encoder( + BuiltinToolManageService.get_builtin_tool_provider_credentials( + tenant_id=tenant_id, + provider_name=provider, + ) ) @@ -567,9 +600,9 @@ class ToolBuiltinListApi(Resource): [ provider.to_dict() for provider in BuiltinToolManageService.list_builtin_tools( - user_id, - tenant_id, - ) + user_id, + tenant_id, + ) ] ) @@ -588,9 +621,9 @@ class ToolApiListApi(Resource): [ provider.to_dict() for provider in ApiToolManageService.list_api_tools( - user_id, - tenant_id, - ) + user_id, + tenant_id, + ) ] ) @@ -609,9 +642,9 @@ class ToolWorkflowListApi(Resource): [ provider.to_dict() for provider in WorkflowToolManageService.list_tenant_workflow_tools( - user_id, - tenant_id, - ) + user_id, + tenant_id, + ) ] ) @@ -656,14 +689,13 @@ class ToolPluginOAuthApi(Resource): ) oauth_handler = OAuthHandler() - context_id = OAuthProxyService.create_proxy_context(user_id=current_user.id, - tenant_id=tenant_id, - plugin_id=plugin_id, - provider=provider) + context_id = OAuthProxyService.create_proxy_context( + user_id=current_user.id, tenant_id=tenant_id, plugin_id=plugin_id, provider=provider + ) # todo decrypt oauth params oauth_params = plugin_oauth_config.oauth_params - oauth_params[ - 'redirect_uri'] = f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/tool/callback?context_id={context_id}" + redirect_uri = f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/tool/callback?context_id={context_id}" + oauth_params["redirect_uri"] = redirect_uri response = oauth_handler.get_authorization_url( tenant_id, @@ -676,14 +708,13 @@ class ToolPluginOAuthApi(Resource): class ToolOAuthCallback(Resource): - @setup_required def get(self): - args = (reqparse - .RequestParser() - .add_argument("context_id", type=str, required=True, nullable=False, location="args") - .parse_args() - ) + args = ( + reqparse.RequestParser() + .add_argument("context_id", type=str, required=True, nullable=False, location="args") + .parse_args() + ) context_id = args["context_id"] context = OAuthProxyService.use_proxy_context(context_id) if context is None: @@ -703,7 +734,8 @@ class ToolOAuthCallback(Resource): plugin_id=plugin_id, ) oauth_params = plugin_oauth_config.oauth_params - oauth_params['redirect_uri'] = f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/tool/callback?context_id={context_id}" + redirect_uri = f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/tool/callback?context_id={context_id}" + oauth_params["redirect_uri"] = redirect_uri credentials = oauth_handler.get_credentials( tenant_id, @@ -712,12 +744,20 @@ class ToolOAuthCallback(Resource): provider, system_credentials=oauth_params, request=request, - ) + ).credentials if not credentials: - raise Exception("no credentials found for this plugin") + raise Exception("the plugin credentials failed") - #TODO add credentials to database + # add credentials to database + BuiltinToolManageService.add_builtin_tool_provider( + user_id=user_id, + tenant_id=tenant_id, + provider_name=provider, + credentials=dict(credentials), + name=provider, + api_type=ToolProviderCredentialType.OAUTH2, + ) return redirect(f"{dify_config.CONSOLE_WEB_URL}") @@ -730,10 +770,8 @@ class ToolBuiltinProviderSetDefaultApi(Resource): parser.add_argument("id", type=str, required=True, nullable=False, location="json") args = parser.parse_args() return BuiltinToolManageService.set_default_provider( - tenant_id=current_user.current_tenant_id, - user_id=current_user.id, - provider=provider, - id=args["id"]) + tenant_id=current_user.current_tenant_id, user_id=current_user.id, provider=provider, id=args["id"] + ) # tool oauth @@ -746,10 +784,12 @@ api.add_resource(ToolProviderListApi, "/workspaces/current/tool-providers") # builtin tool provider api.add_resource(ToolBuiltinProviderListToolsApi, "/workspaces/current/tool-provider/builtin//tools") api.add_resource(ToolBuiltinProviderInfoApi, "/workspaces/current/tool-provider/builtin//info") +api.add_resource(ToolBuiltinProviderAddApi, "/workspaces/current/tool-provider/builtin//add") api.add_resource(ToolBuiltinProviderDeleteApi, "/workspaces/current/tool-provider/builtin//delete") api.add_resource(ToolBuiltinProviderUpdateApi, "/workspaces/current/tool-provider/builtin//update") -api.add_resource(ToolBuiltinProviderSetDefaultApi, - "/workspaces/current/tool-provider/builtin//set-default") +api.add_resource( + ToolBuiltinProviderSetDefaultApi, "/workspaces/current/tool-provider/builtin//set-default" +) api.add_resource( ToolBuiltinProviderGetCredentialsApi, "/workspaces/current/tool-provider/builtin//credentials" ) diff --git a/api/core/tools/entities/api_entities.py b/api/core/tools/entities/api_entities.py index b96c994cff..eaadd4d214 100644 --- a/api/core/tools/entities/api_entities.py +++ b/api/core/tools/entities/api_entities.py @@ -5,7 +5,7 @@ from pydantic import BaseModel, Field, field_validator from core.model_runtime.utils.encoders import jsonable_encoder from core.tools.__base.tool import ToolParameter from core.tools.entities.common_entities import I18nObject -from core.tools.entities.tool_entities import ToolProviderType +from core.tools.entities.tool_entities import ToolProviderCredentialType, ToolProviderType class ToolApiEntity(BaseModel): @@ -70,3 +70,14 @@ class ToolProviderApiEntity(BaseModel): "tools": tools, "labels": self.labels, } + + +class ToolProviderCredentialApiEntity(BaseModel): + id: str = Field(description="The unique id of the credential") + name: str = Field(description="The name of the credential") + provider: str = Field(description="The provider of the credential") + credential_type: ToolProviderCredentialType = Field(description="The type of the credential") + is_default: bool = Field( + default=False, description="Whether the credential is the default credential for the provider in the workspace" + ) + credentials: dict = Field(description="The credentials of the provider") diff --git a/api/core/tools/entities/tool_entities.py b/api/core/tools/entities/tool_entities.py index 03047c0545..5094519b6f 100644 --- a/api/core/tools/entities/tool_entities.py +++ b/api/core/tools/entities/tool_entities.py @@ -434,3 +434,36 @@ class ToolSelector(BaseModel): def to_plugin_parameter(self) -> dict[str, Any]: return self.model_dump() + + +class ToolProviderCredentialType(enum.StrEnum): + API_KEY = "api_key" + OAUTH2 = "oauth2" + + def get_name(self): + if self == ToolProviderCredentialType.API_KEY: + return "API KEY" + elif self == ToolProviderCredentialType.OAUTH2: + return "AUTH" + else: + return self.value.replace("_", " ").upper() + + def is_editable(self): + return self == ToolProviderCredentialType.API_KEY + + def is_validate_allowed(self): + return self == ToolProviderCredentialType.API_KEY + + @classmethod + def values(cls): + return [item.value for item in cls] + + @classmethod + def of(cls, credential_type: str) -> "ToolProviderCredentialType": + type_name = credential_type.lower() + if type_name == "api_key": + return cls.API_KEY + elif type_name == "oauth2": + return cls.OAUTH2 + else: + raise ValueError(f"Invalid credential type: {credential_type}") diff --git a/api/core/tools/tool_manager.py b/api/core/tools/tool_manager.py index 0bfe6329b1..f25267dbf6 100644 --- a/api/core/tools/tool_manager.py +++ b/api/core/tools/tool_manager.py @@ -200,6 +200,7 @@ class ToolManager: (BuiltinToolProvider.provider == str(provider_id_entity)) | (BuiltinToolProvider.provider == provider_id_entity.provider_name), ) + .order_by(BuiltinToolProvider.is_default.desc(), BuiltinToolProvider.created_at.asc()) .first() ) @@ -209,6 +210,7 @@ class ToolManager: builtin_provider = ( db.session.query(BuiltinToolProvider) .filter(BuiltinToolProvider.tenant_id == tenant_id, (BuiltinToolProvider.provider == provider_id)) + .order_by(BuiltinToolProvider.is_default.desc(), BuiltinToolProvider.created_at.asc()) .first() ) @@ -575,18 +577,27 @@ class ToolManager: with db.session.no_autoflush: if "builtin" in filters: - # get builtin providers + + def get_builtin_providers(tenant_id): + # according to multi credentials, select the one with is_default=True first, then created_at oldest + # for compatibility with old version + sql = """ + SELECT DISTINCT ON (tenant_id, provider) id + FROM tool_builtin_providers + WHERE tenant_id = :tenant_id + ORDER BY tenant_id, provider, is_default DESC, created_at DESC + """ + ids = [row.id for row in db.session.execute(db.text(sql), {"tenant_id": tenant_id}).all()] + return db.session.query(BuiltinToolProvider).filter(BuiltinToolProvider.id.in_(ids)).all() + builtin_providers = cls.list_builtin_providers(tenant_id) - # get db builtin providers - db_builtin_providers: list[BuiltinToolProvider] = ( - db.session.query(BuiltinToolProvider).filter(BuiltinToolProvider.tenant_id == tenant_id).all() - ) + # get builtin providers + db_builtin_providers = get_builtin_providers(tenant_id) # rewrite db_builtin_providers for db_provider in db_builtin_providers: - tool_provider_id = str(ToolProviderID(db_provider.provider)) - db_provider.provider = tool_provider_id + db_provider.provider = str(ToolProviderID(db_provider.provider)) def find_db_builtin_provider(provider): return next((x for x in db_builtin_providers if x.provider == provider), None) diff --git a/api/migrations/versions/2025_06_19_1133-222376193a49_multiple_credential.py b/api/migrations/versions/2025_06_19_1133-222376193a49_multiple_credential.py deleted file mode 100644 index 82e812cb3d..0000000000 --- a/api/migrations/versions/2025_06_19_1133-222376193a49_multiple_credential.py +++ /dev/null @@ -1,39 +0,0 @@ -"""multiple credential - -Revision ID: 222376193a49 -Revises: 99310d2c25a6 -Create Date: 2025-06-19 11:33:46.400455 - -""" -from alembic import op -import models as models -import sqlalchemy as sa - - -# revision identifiers, used by Alembic. -revision = '222376193a49' -down_revision = '99310d2c25a6' -branch_labels = None -depends_on = None - - -def upgrade(): - # ### commands auto generated by Alembic - please adjust! ### - with op.batch_alter_table('tool_builtin_providers', schema=None) as batch_op: - batch_op.drop_constraint(batch_op.f('unique_builtin_tool_provider'), type_='unique') - - with op.batch_alter_table('tool_oauth_user_clients', schema=None) as batch_op: - batch_op.add_column(sa.Column('owner_type', sa.Text(), nullable=False)) - - # ### end Alembic commands ### - - -def downgrade(): - # ### commands auto generated by Alembic - please adjust! ### - with op.batch_alter_table('tool_oauth_user_clients', schema=None) as batch_op: - batch_op.drop_column('owner_type') - - with op.batch_alter_table('tool_builtin_providers', schema=None) as batch_op: - batch_op.create_unique_constraint(batch_op.f('unique_builtin_tool_provider'), ['tenant_id', 'provider', 'credential_type']) - - # ### end Alembic commands ### diff --git a/api/migrations/versions/2025_06_19_1353-a9306e69af07_multiple_credential.py b/api/migrations/versions/2025_06_19_1353-a9306e69af07_multiple_credential.py deleted file mode 100644 index 216661550a..0000000000 --- a/api/migrations/versions/2025_06_19_1353-a9306e69af07_multiple_credential.py +++ /dev/null @@ -1,33 +0,0 @@ -"""multiple credential - -Revision ID: a9306e69af07 -Revises: 222376193a49 -Create Date: 2025-06-19 13:53:41.554159 - -""" -from alembic import op -import models as models -import sqlalchemy as sa - - -# revision identifiers, used by Alembic. -revision = 'a9306e69af07' -down_revision = '222376193a49' -branch_labels = None -depends_on = None - - -def upgrade(): - # ### commands auto generated by Alembic - please adjust! ### - with op.batch_alter_table('tool_builtin_providers', schema=None) as batch_op: - batch_op.create_unique_constraint('unique_builtin_tool_provider', ['provider', 'tenant_id', 'default']) - - # ### end Alembic commands ### - - -def downgrade(): - # ### commands auto generated by Alembic - please adjust! ### - with op.batch_alter_table('tool_builtin_providers', schema=None) as batch_op: - batch_op.drop_constraint('unique_builtin_tool_provider', type_='unique') - - # ### end Alembic commands ### diff --git a/api/migrations/versions/2025_06_19_1359-6835b906335f_multiple_credential.py b/api/migrations/versions/2025_06_19_1359-6835b906335f_multiple_credential.py deleted file mode 100644 index d90e0d178e..0000000000 --- a/api/migrations/versions/2025_06_19_1359-6835b906335f_multiple_credential.py +++ /dev/null @@ -1,33 +0,0 @@ -"""multiple credential - -Revision ID: 6835b906335f -Revises: e315d2a83984 -Create Date: 2025-06-19 13:59:58.107955 - -""" -from alembic import op -import models as models -import sqlalchemy as sa - - -# revision identifiers, used by Alembic. -revision = '6835b906335f' -down_revision = 'e315d2a83984' -branch_labels = None -depends_on = None - - -def upgrade(): - # ### commands auto generated by Alembic - please adjust! ### - with op.batch_alter_table('tool_builtin_providers', schema=None) as batch_op: - batch_op.drop_constraint(batch_op.f('unique_builtin_tool_provider'), type_='unique') - - # ### end Alembic commands ### - - -def downgrade(): - # ### commands auto generated by Alembic - please adjust! ### - with op.batch_alter_table('tool_builtin_providers', schema=None) as batch_op: - batch_op.create_unique_constraint(batch_op.f('unique_builtin_tool_provider'), ['provider', 'tenant_id', 'default']) - - # ### end Alembic commands ### diff --git a/api/migrations/versions/2025_06_19_1359-e315d2a83984_multiple_credential.py b/api/migrations/versions/2025_06_19_1359-e315d2a83984_multiple_credential.py deleted file mode 100644 index 2f0caeaf0d..0000000000 --- a/api/migrations/versions/2025_06_19_1359-e315d2a83984_multiple_credential.py +++ /dev/null @@ -1,33 +0,0 @@ -"""multiple credential - -Revision ID: e315d2a83984 -Revises: a9306e69af07 -Create Date: 2025-06-19 13:59:13.860523 - -""" -from alembic import op -import models as models -import sqlalchemy as sa - - -# revision identifiers, used by Alembic. -revision = 'e315d2a83984' -down_revision = 'a9306e69af07' -branch_labels = None -depends_on = None - - -def upgrade(): - # ### commands auto generated by Alembic - please adjust! ### - with op.batch_alter_table('tool_api_providers', schema=None) as batch_op: - batch_op.drop_constraint(batch_op.f('unique_api_tool_provider'), type_='unique') - - # ### end Alembic commands ### - - -def downgrade(): - # ### commands auto generated by Alembic - please adjust! ### - with op.batch_alter_table('tool_api_providers', schema=None) as batch_op: - batch_op.create_unique_constraint(batch_op.f('unique_api_tool_provider'), ['name', 'tenant_id']) - - # ### end Alembic commands ### diff --git a/api/migrations/versions/2025_06_19_1511-110e30078dd3_multiple_credential.py b/api/migrations/versions/2025_06_19_1511-110e30078dd3_multiple_credential.py deleted file mode 100644 index 84a5461a4d..0000000000 --- a/api/migrations/versions/2025_06_19_1511-110e30078dd3_multiple_credential.py +++ /dev/null @@ -1,53 +0,0 @@ -"""multiple credential - -Revision ID: 110e30078dd3 -Revises: 6835b906335f -Create Date: 2025-06-19 15:11:42.688478 - -""" -from alembic import op -import models as models -import sqlalchemy as sa - - -# revision identifiers, used by Alembic. -revision = '110e30078dd3' -down_revision = '6835b906335f' -branch_labels = None -depends_on = None - - -def upgrade(): - # ### commands auto generated by Alembic - please adjust! ### - with op.batch_alter_table('tool_oauth_system_clients', schema=None) as batch_op: - batch_op.alter_column('plugin_id', - existing_type=sa.UUID(), - type_=sa.String(length=512), - existing_nullable=False) - - with op.batch_alter_table('tool_oauth_user_clients', schema=None) as batch_op: - batch_op.add_column(sa.Column('enabled', sa.Boolean(), server_default=sa.text('true'), nullable=False)) - batch_op.alter_column('plugin_id', - existing_type=sa.UUID(), - type_=sa.String(length=512), - existing_nullable=False) - - # ### end Alembic commands ### - - -def downgrade(): - # ### commands auto generated by Alembic - please adjust! ### - with op.batch_alter_table('tool_oauth_user_clients', schema=None) as batch_op: - batch_op.alter_column('plugin_id', - existing_type=sa.String(length=512), - type_=sa.UUID(), - existing_nullable=False) - batch_op.drop_column('enabled') - - with op.batch_alter_table('tool_oauth_system_clients', schema=None) as batch_op: - batch_op.alter_column('plugin_id', - existing_type=sa.String(length=512), - type_=sa.UUID(), - existing_nullable=False) - - # ### end Alembic commands ### diff --git a/api/migrations/versions/2025_06_18_1506-99310d2c25a6_add_tool_oauth_credentials.py b/api/migrations/versions/2025_06_24_1705-71f5020c6470_tool_oauth.py similarity index 54% rename from api/migrations/versions/2025_06_18_1506-99310d2c25a6_add_tool_oauth_credentials.py rename to api/migrations/versions/2025_06_24_1705-71f5020c6470_tool_oauth.py index 95e74571d5..ffb4fffe56 100644 --- a/api/migrations/versions/2025_06_18_1506-99310d2c25a6_add_tool_oauth_credentials.py +++ b/api/migrations/versions/2025_06_24_1705-71f5020c6470_tool_oauth.py @@ -1,8 +1,8 @@ -"""add tool oauth credentials +"""tool oauth -Revision ID: 99310d2c25a6 +Revision ID: 71f5020c6470 Revises: 4474872b0ee6 -Create Date: 2025-06-18 15:06:15.261915 +Create Date: 2025-06-24 17:05:43.118647 """ from alembic import op @@ -11,7 +11,7 @@ import sqlalchemy as sa # revision identifiers, used by Alembic. -revision = '99310d2c25a6' +revision = '71f5020c6470' down_revision = '4474872b0ee6' branch_labels = None depends_on = None @@ -21,30 +21,30 @@ def upgrade(): # ### commands auto generated by Alembic - please adjust! ### op.create_table('tool_oauth_system_clients', sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False), - sa.Column('plugin_id', models.types.StringUUID(), nullable=False), + sa.Column('plugin_id', sa.String(length=512), nullable=False), sa.Column('provider', sa.String(length=255), nullable=False), sa.Column('encrypted_oauth_params', sa.Text(), nullable=False), sa.PrimaryKeyConstraint('id', name='tool_oauth_system_client_pkey'), sa.UniqueConstraint('plugin_id', 'provider', name='tool_oauth_system_client_plugin_id_provider_idx') ) - op.create_table('tool_oauth_user_clients', + op.create_table('tool_oauth_tenant_clients', sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False), sa.Column('tenant_id', models.types.StringUUID(), nullable=False), - sa.Column('plugin_id', models.types.StringUUID(), nullable=False), + sa.Column('plugin_id', sa.String(length=512), nullable=False), sa.Column('provider', sa.String(length=255), nullable=False), + sa.Column('enabled', sa.Boolean(), server_default=sa.text('true'), nullable=False), sa.Column('encrypted_oauth_params', sa.Text(), nullable=False), - sa.PrimaryKeyConstraint('id', name='tool_oauth_user_client_pkey'), - sa.UniqueConstraint('tenant_id', 'plugin_id', 'provider', name='unique_tool_oauth_user_client') + sa.PrimaryKeyConstraint('id', name='tool_oauth_tenant_client_pkey'), + sa.UniqueConstraint('tenant_id', 'plugin_id', 'provider', name='unique_tool_oauth_tenant_client') ) + with op.batch_alter_table('tool_api_providers', schema=None) as batch_op: + batch_op.drop_constraint(batch_op.f('unique_api_tool_provider'), type_='unique') + with op.batch_alter_table('tool_builtin_providers', schema=None) as batch_op: - batch_op.add_column(sa.Column('default', sa.Boolean(), server_default=sa.text('false'), nullable=False)) - batch_op.alter_column('credential_type', - existing_type=sa.VARCHAR(length=255), - type_=sa.String(length=32), - existing_nullable=False, - existing_server_default=sa.text("'api_key'::character varying")) + batch_op.add_column(sa.Column('name', sa.String(length=256), server_default=sa.text("'API KEY 1'::character varying"), nullable=False)) + batch_op.add_column(sa.Column('is_default', sa.Boolean(), server_default=sa.text('false'), nullable=False)) + batch_op.add_column(sa.Column('credential_type', sa.String(length=32), server_default=sa.text("'api_key'::character varying"), nullable=False)) batch_op.drop_constraint(batch_op.f('unique_builtin_tool_provider'), type_='unique') - batch_op.create_unique_constraint('unique_builtin_tool_provider', ['tenant_id', 'provider', 'credential_type']) # ### end Alembic commands ### @@ -52,15 +52,14 @@ def upgrade(): def downgrade(): # ### commands auto generated by Alembic - please adjust! ### with op.batch_alter_table('tool_builtin_providers', schema=None) as batch_op: - batch_op.drop_constraint('unique_builtin_tool_provider', type_='unique') batch_op.create_unique_constraint(batch_op.f('unique_builtin_tool_provider'), ['tenant_id', 'provider']) - batch_op.alter_column('credential_type', - existing_type=sa.String(length=32), - type_=sa.VARCHAR(length=255), - existing_nullable=False, - existing_server_default=sa.text("'api_key'::character varying")) - batch_op.drop_column('default') + batch_op.drop_column('credential_type') + batch_op.drop_column('is_default') + batch_op.drop_column('name') - op.drop_table('tool_oauth_user_clients') + with op.batch_alter_table('tool_api_providers', schema=None) as batch_op: + batch_op.create_unique_constraint(batch_op.f('unique_api_tool_provider'), ['name', 'tenant_id']) + + op.drop_table('tool_oauth_tenant_clients') op.drop_table('tool_oauth_system_clients') # ### end Alembic commands ### diff --git a/api/models/tools.py b/api/models/tools.py index 9e50cec52f..b2979a69dc 100644 --- a/api/models/tools.py +++ b/api/models/tools.py @@ -1,4 +1,3 @@ -import enum import json from datetime import datetime from typing import Any, cast @@ -18,25 +17,6 @@ from .model import Account, App, Tenant from .types import StringUUID -class ToolProviderCredentialType(enum.StrEnum): - API_KEY = "api_key" - OAUTH2 = "oauth2" - - def get_name(self): - return self.value.replace("_", " ").upper() - - def is_editable(self): - return self == ToolProviderCredentialType.API_KEY - - @classmethod - def get_credential_type(cls, credential_type: str) -> "ToolProviderCredentialType": - if credential_type == "api_key": - return cls.API_KEY - elif credential_type == "oauth2": - return cls.OAUTH2 - else: - raise ValueError(f"Invalid credential type: {credential_type}") - # system level tool oauth client params (client_id, client_secret, etc.) class ToolOAuthSystemClient(Base): __tablename__ = "tool_oauth_system_clients" @@ -48,8 +28,6 @@ class ToolOAuthSystemClient(Base): id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) plugin_id: Mapped[str] = mapped_column(db.String(512), nullable=False) provider: Mapped[str] = mapped_column(db.String(255), nullable=False) - # owner type, e.g., "system", "user" - # oauth params of the tool provider encrypted_oauth_params: Mapped[str] = mapped_column(db.Text, nullable=False) @@ -58,12 +36,12 @@ class ToolOAuthSystemClient(Base): return cast(dict, json.loads(self.encrypted_oauth_params)) -# user level tool oauth client params (client_id, client_secret, etc.) -class ToolOAuthUserClient(Base): - __tablename__ = "tool_oauth_user_clients" +# tenant level tool oauth client params (client_id, client_secret, etc.) +class ToolOAuthTenantClient(Base): + __tablename__ = "tool_oauth_tenant_clients" __table_args__ = ( - db.PrimaryKeyConstraint("id", name="tool_oauth_user_client_pkey"), - db.UniqueConstraint("tenant_id", "plugin_id", "provider", name="unique_tool_oauth_user_client"), + db.PrimaryKeyConstraint("id", name="tool_oauth_tenant_client_pkey"), + db.UniqueConstraint("tenant_id", "plugin_id", "provider", name="unique_tool_oauth_tenant_client"), ) id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) @@ -71,7 +49,6 @@ class ToolOAuthUserClient(Base): tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False) plugin_id: Mapped[str] = mapped_column(db.String(512), nullable=False) provider: Mapped[str] = mapped_column(db.String(255), nullable=False) - owner_type: Mapped[str] = mapped_column(db.Text, nullable=False) enabled: Mapped[bool] = mapped_column(db.Boolean, nullable=False, server_default=db.text("true")) # oauth params of the tool provider encrypted_oauth_params: Mapped[str] = mapped_column(db.Text, nullable=False) @@ -80,19 +57,20 @@ class ToolOAuthUserClient(Base): def oauth_params(self) -> dict: return cast(dict, json.loads(self.encrypted_oauth_params)) + class BuiltinToolProvider(Base): """ This table stores the tool provider information for built-in tools for each tenant. """ __tablename__ = "tool_builtin_providers" - __table_args__ = ( - db.PrimaryKeyConstraint("id", name="tool_builtin_provider_pkey"), - ) + __table_args__ = (db.PrimaryKeyConstraint("id", name="tool_builtin_provider_pkey"),) # id of the tool provider id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) - name: Mapped[str] = mapped_column(db.String(256), nullable=False) + name: Mapped[str] = mapped_column( + db.String(256), nullable=False, server_default=db.text("'API KEY 1'::character varying") + ) # id of the tenant tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=True) # who created this tool provider @@ -107,11 +85,11 @@ class BuiltinToolProvider(Base): updated_at: Mapped[datetime] = mapped_column( db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)") ) - default: Mapped[bool] = mapped_column( - db.Boolean, nullable=False, server_default=db.text("false") - ) + is_default: Mapped[bool] = mapped_column(db.Boolean, nullable=False, server_default=db.text("false")) # credential type, e.g., "api_key", "oauth2" - credential_type: Mapped[str] = mapped_column(db.String(32), nullable=False, server_default=db.text("'api_key'::character varying")) + credential_type: Mapped[str] = mapped_column( + db.String(32), nullable=False, server_default=db.text("'api_key'::character varying") + ) @property def credentials(self) -> dict: @@ -124,13 +102,11 @@ class ApiToolProvider(Base): """ __tablename__ = "tool_api_providers" - __table_args__ = ( - db.PrimaryKeyConstraint("id", name="tool_api_provider_pkey"), - ) + __table_args__ = (db.PrimaryKeyConstraint("id", name="tool_api_provider_pkey"),) id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) # name of the api provider - name = db.Column(db.String(255), nullable=False) + name = db.Column(db.String(255), nullable=False, server_default=db.text("'API KEY 1'::character varying")) # icon icon = db.Column(db.String(255), nullable=False) # original schema diff --git a/api/services/plugin/oauth_service.py b/api/services/plugin/oauth_service.py index dcc14a8fad..4d340e2396 100644 --- a/api/services/plugin/oauth_service.py +++ b/api/services/plugin/oauth_service.py @@ -23,7 +23,7 @@ class OAuthProxyService(BasePluginClient): is used to verify the state, ensuring the request's integrity and authenticity, and mitigating replay attacks. """ - seconds, microseconds = redis_client.time() + seconds, _ = redis_client.time() context_id = str(uuid.uuid4()) data = { "user_id": user_id, @@ -55,7 +55,7 @@ class OAuthProxyService(BasePluginClient): if not data: raise ValueError("context_id is invalid") # check if data is expired - seconds, microseconds = redis_client.time() + seconds, _ = redis_client.time() state = json.loads(data) if state.get("timestamp") < seconds - max_age: raise ValueError("context_id is expired") diff --git a/api/services/tools/builtin_tools_manage_service.py b/api/services/tools/builtin_tools_manage_service.py index 7dc3e4c0f8..6728a19391 100644 --- a/api/services/tools/builtin_tools_manage_service.py +++ b/api/services/tools/builtin_tools_manage_service.py @@ -1,20 +1,26 @@ import json import logging +import re from pathlib import Path +from sqlalchemy import ColumnExpressionArgument +from sqlalchemy.orm import Session + from configs import dify_config from core.helper.position_helper import is_filtered from core.model_runtime.utils.encoders import jsonable_encoder -from core.plugin.entities.plugin import GenericProviderID, ToolProviderID +from core.plugin.entities.plugin import ToolProviderID from core.plugin.impl.exc import PluginDaemonClientSideError from core.tools.builtin_tool.providers._positions import BuiltinToolProviderSort -from core.tools.entities.api_entities import ToolApiEntity, ToolProviderApiEntity +from core.tools.entities.api_entities import ToolApiEntity, ToolProviderApiEntity, ToolProviderCredentialApiEntity +from core.tools.entities.tool_entities import ToolProviderCredentialType from core.tools.errors import ToolNotFoundError, ToolProviderCredentialValidationError, ToolProviderNotFoundError from core.tools.tool_label_manager import ToolLabelManager from core.tools.tool_manager import ToolManager from core.tools.utils.configuration import ProviderConfigEncrypter from extensions.ext_database import db -from models.tools import BuiltinToolProvider, ToolOAuthSystemClient, ToolOAuthUserClient, ToolProviderCredentialType +from extensions.ext_redis import redis_client +from models.tools import BuiltinToolProvider, ToolOAuthSystemClient, ToolOAuthTenantClient from services.tools.tools_transform_service import ToolTransformService logger = logging.getLogger(__name__) @@ -107,7 +113,7 @@ class BuiltinToolManageService: @staticmethod def update_builtin_tool_provider( - user_id: str, tenant_id: str, provider_name:str, credentials: dict, credential_id: str, name: str | None = None + user_id: str, tenant_id: str, provider_name: str, credentials: dict, credential_id: str, name: str | None = None ): """ update builtin tool provider @@ -119,7 +125,7 @@ class BuiltinToolManageService: raise ValueError(f"you have not added provider {provider_name}") try: - if ToolProviderCredentialType.get_credential_type(provider.credential_type).is_editable(): + if ToolProviderCredentialType.of(provider.credential_type).is_editable(): provider_controller = ToolManager.get_builtin_provider(provider_name, tenant_id) if not provider_controller.need_credentials: raise ValueError(f"provider {provider_name} does not need credentials") @@ -132,18 +138,20 @@ class BuiltinToolManageService: ) # Decrypt and restore original credentials for masked values - credentials = BuiltinToolManageService._decrypt_and_restore_credentials( - provider_controller, tool_configuration, provider, credentials - ) + original_credentials = tool_configuration.decrypt(provider.credentials) + masked_credentials = tool_configuration.mask_tool_credentials(original_credentials) + + # check if the credential has changed, save the original credential + for name, value in credentials.items(): + if name in masked_credentials and value == masked_credentials[name]: # type: ignore + credentials[name] = original_credentials[name] # type: ignore # Encrypt and save the credentials BuiltinToolManageService._encrypt_and_save_credentials( provider_controller, tool_configuration, provider, credentials, user_id ) else: - raise ValueError( - f"provider {provider_name} is not editable, you can only delete it and add a new one" - ) + raise ValueError(f"provider {provider_name} is not editable, you can only delete it and add a new one") # update name if provided if name is not None and provider.name != name: @@ -151,10 +159,10 @@ class BuiltinToolManageService: db.session.commit() except ( - PluginDaemonClientSideError, - ToolProviderNotFoundError, - ToolNotFoundError, - ToolProviderCredentialValidationError, + PluginDaemonClientSideError, + ToolProviderNotFoundError, + ToolNotFoundError, + ToolProviderCredentialValidationError, ) as e: raise ValueError(str(e)) @@ -162,94 +170,136 @@ class BuiltinToolManageService: @staticmethod def add_builtin_tool_provider( - user_id: str, type: ToolProviderCredentialType, tenant_id: str, provider_name:str, credentials: dict, name: str | None = None + user_id: str, + api_type: ToolProviderCredentialType, + tenant_id: str, + provider_name: str, + credentials: dict, + name: str | None = None, ): """ add builtin tool provider """ - if name is None: - name = BuiltinToolManageService.get_next_builtin_tool_provider_name(tenant_id, type) - - provider = BuiltinToolProvider( - tenant_id=tenant_id, - user_id=user_id, - provider=provider_name, - credential_type=type.value, - credentials=json.dumps(credentials), - name=name, - ) - - provider_controller = ToolManager.get_builtin_provider(provider_name, tenant_id) - if not provider_controller.need_credentials: - raise ValueError(f"provider {provider_name} does not need credentials") + lock_name = f"builtin_tool_provider_credential_lock_{tenant_id}_{provider_name}_{api_type.value}" + with redis_client.lock(lock_name, timeout=20): + if name is None: + name = BuiltinToolManageService.get_next_builtin_tool_provider_name(tenant_id, provider_name, api_type) - tool_configuration = ProviderConfigEncrypter( - tenant_id=tenant_id, - config=[x.to_basic_provider_config() for x in provider_controller.get_credentials_schema()], - provider_type=provider_controller.provider_type.value, - provider_identity=provider_controller.entity.identity.name, - ) - - # Encrypt and save the credentials - BuiltinToolManageService._encrypt_and_save_credentials( - provider_controller, tool_configuration, provider, credentials, user_id - ) - db.session.add(provider) + provider = BuiltinToolProvider( + tenant_id=tenant_id, + user_id=user_id, + provider=provider_name, + encrypted_credentials=json.dumps(credentials), + credential_type=api_type.value, + name=name, + ) + + provider_controller = ToolManager.get_builtin_provider(provider_name, tenant_id) + if not provider_controller.need_credentials: + raise ValueError(f"provider {provider_name} does not need credentials") + + tool_configuration = ProviderConfigEncrypter( + tenant_id=tenant_id, + config=[x.to_basic_provider_config() for x in provider_controller.get_credentials_schema()], + provider_type=provider_controller.provider_type.value, + provider_identity=provider_controller.entity.identity.name, + ) + + # Encrypt and save the credentials + BuiltinToolManageService._encrypt_and_save_credentials( + provider_controller, tool_configuration, provider, credentials, user_id + ) + db.session.add(provider) + db.session.commit() return {"result": "success"} @staticmethod - def get_next_builtin_tool_provider_name(tenant_id: str, type: ToolProviderCredentialType) -> str: - """ - next name = max(provider_names) + 1 - """ - provider_names = db.session.query(BuiltinToolProvider).filter_by( - tenant_id=tenant_id, - credential_type=type.value, - ).all() - if not provider_names: - return f"{type.value} 1" - # OAuth 1 then OAuth 2, if don't have OAuth 1, then return OAuth 1 - # if dont have number, then get name and add 1 - for provider_name in provider_names: - if provider_name.provider.startswith(type.value): - return f"{type.value} {int(provider_name.provider.split(' ')[1]) + 1}" - return f"{type.value} 1" + def get_next_builtin_tool_provider_name( + tenant_id: str, provider_name: str, type: ToolProviderCredentialType + ) -> str: + try: + providers = ( + db.session.query(BuiltinToolProvider) + .filter_by( + tenant_id=tenant_id, + provider=provider_name, + credential_type=type.value, + ) + .order_by(BuiltinToolProvider.created_at.desc()) + .limit(10) + .all() + ) + # Get the default name pattern + default_pattern = type.get_name() + + # Find all names that match the default pattern: "{default_pattern} {number}" + pattern = rf"^{re.escape(default_pattern)}\s+(\d+)$" + numbers = [] + + for provider in providers: + if provider.name: + match = re.match(pattern, provider.name.strip()) + if match: + numbers.append(int(match.group(1))) + + # If no default pattern names found, start with 1 + if not numbers: + return f"{default_pattern} 1" + + # Find the next number + max_number = max(numbers) + return f"{default_pattern} {max_number + 1}" + except Exception as e: + logger.warning(f"Error generating next provider name for {provider_name}: {str(e)}") + # fallback + return f"{type.get_name()} 1" @staticmethod - def get_builtin_tool_provider_credentials(tenant_id: str, provider_name: str): + def get_builtin_tool_provider_credentials( + tenant_id: str, provider_name: str + ) -> list[ToolProviderCredentialApiEntity]: """ get builtin tool provider credentials """ - provider_obj = BuiltinToolManageService._fetch_builtin_provider(provider_name, tenant_id) + providers = db.session.query(BuiltinToolProvider).filter_by(tenant_id=tenant_id, provider=provider_name).all() - if provider_obj is None: - return {} + if len(providers) == 0: + return [] - provider_controller = ToolManager.get_builtin_provider(provider_obj.provider, tenant_id) + provider_controller = ToolManager.get_builtin_provider(providers[0].provider, tenant_id) tool_configuration = ProviderConfigEncrypter( tenant_id=tenant_id, config=[x.to_basic_provider_config() for x in provider_controller.get_credentials_schema()], provider_type=provider_controller.provider_type.value, provider_identity=provider_controller.entity.identity.name, ) - credentials = tool_configuration.decrypt(provider_obj.credentials) - credentials = tool_configuration.mask_tool_credentials(credentials) + credentials: list[ToolProviderCredentialApiEntity] = [] + for provider in providers: + decrypt_credential = tool_configuration.mask_tool_credentials( + tool_configuration.decrypt(provider.credentials) + ) + credentials.append( + ToolTransformService.convert_builtin_provider_to_credential_api_entity( + provider=provider, + credentials=decrypt_credential, + ) + ) return credentials @staticmethod - def delete_builtin_tool_provider(user_id: str, tenant_id: str, provider_name: str): + def delete_builtin_tool_provider(tenant_id: str, provider_name: str, credential_id: str): """ delete tool provider """ - provider_obj = BuiltinToolManageService._fetch_builtin_provider(provider_name, tenant_id) + provider_obj = BuiltinToolManageService._fetch_builtin_provider_by_id(tenant_id, credential_id) if provider_obj is None: raise ValueError(f"you have not added provider {provider_name}") db.session.delete(provider_obj) db.session.commit() - + # delete cache provider_controller = ToolManager.get_builtin_provider(provider_name, tenant_id) tool_configuration = ProviderConfigEncrypter( @@ -267,70 +317,45 @@ class BuiltinToolManageService: """ set default provider """ - # get provider - target_provider = db.session.query(BuiltinToolProvider).filter_by(id=id).first() - if target_provider is None: - raise ValueError("provider not found") + with Session(db.engine) as session: + # get provider + target_provider = session.query(BuiltinToolProvider).filter_by(id=id).first() + if target_provider is None: + raise ValueError("provider not found") - # clear default provider - db.session.query(BuiltinToolProvider).filter_by( - tenant_id=tenant_id, - user_id=user_id, - provider=provider, - default=True - ).update({"default": False}) + # clear default provider + session.query(BuiltinToolProvider).filter_by( + tenant_id=tenant_id, user_id=user_id, provider=provider, default=True + ).update({"default": False}) - # set new default provider - target_provider.default = True - db.session.commit() + # set new default provider + target_provider.is_default = True + session.commit() return {"result": "success"} - @staticmethod - def fetch_default_provider(tenant_id: str, user_id: str, provider_name: str): - """ - fetch default provider - if there is no explicitly set default provider, return the oldest provider as default - """ - # 1. check if default provider exists - default_provider = db.session.query(BuiltinToolProvider).filter_by( - tenant_id=tenant_id, - user_id=user_id, - provider=provider_name, - default=True - ).first() - if default_provider: - return default_provider - - # 2. if no default provider, set the oldest provider as default - oldest_provider = (db.session.query(BuiltinToolProvider) - .filter_by(tenant_id=tenant_id, user_id=user_id, provider=provider_name) - .order_by(BuiltinToolProvider.created_at) - .first() - ) - if oldest_provider: - return oldest_provider - - raise ValueError(f"no default provider found for {provider_name}") - @staticmethod def get_builtin_tool_provider(tenant_id: str, user_id: str, provider: str, plugin_id: str): """ get builtin tool provider """ - user_client = db.session.query(ToolOAuthUserClient).filter_by( - tenant_id=tenant_id, - provider=provider, - plugin_id=plugin_id, - enabled=True, - ).first() + with Session(db.engine) as session: + user_client = ( + session.query(ToolOAuthTenantClient) + .filter_by( + tenant_id=tenant_id, + provider=provider, + plugin_id=plugin_id, + enabled=True, + ) + .first() + ) + if user_client: + plugin_oauth_config = user_client + else: + plugin_oauth_config = session.query(ToolOAuthSystemClient).filter_by(provider=provider).first() - if user_client: - plugin_oauth_config = user_client - else: - plugin_oauth_config = db.session.query(ToolOAuthSystemClient).filter_by(provider=provider).first() - - if plugin_oauth_config: - return plugin_oauth_config + if plugin_oauth_config: + return plugin_oauth_config raise ValueError("no oauth available config found for this plugin") @@ -408,73 +433,69 @@ class BuiltinToolManageService: @staticmethod def _fetch_builtin_provider_by_id(tenant_id: str, credential_id: str) -> BuiltinToolProvider | None: - provider = (db.session.query(BuiltinToolProvider) - .filter(BuiltinToolProvider.tenant_id == tenant_id, - BuiltinToolProvider.id == credential_id, - ) - .first()) + provider = ( + db.session.query(BuiltinToolProvider) + .filter( + BuiltinToolProvider.tenant_id == tenant_id, + BuiltinToolProvider.id == credential_id, + ) + .first() + ) return provider @staticmethod def _fetch_builtin_provider(provider_name: str, tenant_id: str) -> BuiltinToolProvider | None: - try: - full_provider_name = provider_name - provider_id_entity = GenericProviderID(provider_name) - provider_name = provider_id_entity.provider_name - if provider_id_entity.organization != "langgenius": - provider_obj = ( - db.session.query(BuiltinToolProvider) - .filter( - BuiltinToolProvider.tenant_id == tenant_id, - BuiltinToolProvider.provider == full_provider_name, - ) - .first() - ) - else: - provider_obj = ( - db.session.query(BuiltinToolProvider) - .filter( - BuiltinToolProvider.tenant_id == tenant_id, - (BuiltinToolProvider.provider == provider_name) - | (BuiltinToolProvider.provider == full_provider_name), - ) - .first() - ) - - if provider_obj is None: - return None - - provider_obj.provider = GenericProviderID(provider_obj.provider).to_string() - return provider_obj - except Exception: - # it's an old provider without organization + """ + This method is used to fetch the builtin provider from the database + 1.if the default provider exists, return the default provider + 2.if the default provider does not exist, return the oldest provider + """ + def _query(provider_filters: list[ColumnExpressionArgument[bool]]): return ( db.session.query(BuiltinToolProvider) - .filter( - BuiltinToolProvider.tenant_id == tenant_id, - (BuiltinToolProvider.provider == provider_name), + .filter(BuiltinToolProvider.tenant_id == tenant_id, *provider_filters) + .order_by( + BuiltinToolProvider.is_default.desc(), # default=True first + BuiltinToolProvider.created_at.asc(), # oldest first ) .first() ) + try: + full_provider_name = provider_name + provider_id_entity = ToolProviderID(provider_name) + provider_name = provider_id_entity.provider_name + + if provider_id_entity.organization != "langgenius": + provider = _query([BuiltinToolProvider.provider == full_provider_name]) + else: + provider = _query( + [ + (BuiltinToolProvider.provider == provider_name) + | (BuiltinToolProvider.provider == full_provider_name) + ] + ) + + if provider is None: + return None + + provider.provider = ToolProviderID(provider.provider).to_string() + return provider + except Exception: + # it's an old provider without organization + provider_obj = _query([BuiltinToolProvider.provider == provider_name]) + return provider_obj + @staticmethod - def _decrypt_and_restore_credentials(provider_controller, tool_configuration, provider, credentials): + def _decrypt_and_restore_credentials(tool_configuration, provider, credentials): """ Decrypt original credentials and restore masked values from the input credentials - :param provider_controller: the provider controller :param tool_configuration: the tool configuration encrypter :param provider: the provider object from database :param credentials: the input credentials from user :return: the processed credentials with original values restored """ - original_credentials = tool_configuration.decrypt(provider.credentials) - masked_credentials = tool_configuration.mask_tool_credentials(original_credentials) - - # check if the credential has changed, save the original credential - for name, value in credentials.items(): - if name in masked_credentials and value == masked_credentials[name]: # type: ignore - credentials[name] = original_credentials[name] # type: ignore return credentials @@ -489,8 +510,9 @@ class BuiltinToolManageService: :param credentials: the credentials to encrypt and save :param user_id: the user id for validation """ - # validate credentials - provider_controller.validate_credentials(user_id, credentials) + if ToolProviderCredentialType.of(provider.credential_type).is_validate_allowed(): + provider_controller.validate_credentials(user_id, credentials) + # encrypt credentials encrypted_credentials = tool_configuration.encrypt(credentials) provider.encrypted_credentials = json.dumps(encrypted_credentials) diff --git a/api/services/tools/tools_transform_service.py b/api/services/tools/tools_transform_service.py index 367121125b..b896f6c88f 100644 --- a/api/services/tools/tools_transform_service.py +++ b/api/services/tools/tools_transform_service.py @@ -9,12 +9,13 @@ from core.tools.__base.tool import Tool from core.tools.__base.tool_runtime import ToolRuntime from core.tools.builtin_tool.provider import BuiltinToolProviderController from core.tools.custom_tool.provider import ApiToolProviderController -from core.tools.entities.api_entities import ToolApiEntity, ToolProviderApiEntity +from core.tools.entities.api_entities import ToolApiEntity, ToolProviderApiEntity, ToolProviderCredentialApiEntity from core.tools.entities.common_entities import I18nObject from core.tools.entities.tool_bundle import ApiToolBundle from core.tools.entities.tool_entities import ( ApiProviderAuthType, ToolParameter, + ToolProviderCredentialType, ToolProviderType, ) from core.tools.plugin_tool.provider import PluginToolProviderController @@ -304,3 +305,16 @@ class ToolTransformService: parameters=tool.parameters, labels=labels or [], ) + + @staticmethod + def convert_builtin_provider_to_credential_api_entity( + provider: BuiltinToolProvider, credentials: dict + ) -> ToolProviderCredentialApiEntity: + return ToolProviderCredentialApiEntity( + id=provider.id, + name=provider.name, + provider=provider.provider, + credential_type=ToolProviderCredentialType.of(provider.credential_type), + is_default=provider.is_default, + credentials=credentials, + )