mirror of https://github.com/langgenius/dify.git
feat: implement OAuth credentials refresh mechanism and update expires_at handling
This commit is contained in:
parent
5d986c2cdf
commit
7fa952b1a2
|
|
@ -739,7 +739,7 @@ class ToolOAuthCallback(Resource):
|
|||
raise Forbidden("no oauth available client config found for this tool provider")
|
||||
|
||||
redirect_uri = f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/{provider}/tool/callback"
|
||||
credentials = oauth_handler.get_credentials(
|
||||
credentials_response = oauth_handler.get_credentials(
|
||||
tenant_id=tenant_id,
|
||||
user_id=user_id,
|
||||
plugin_id=plugin_id,
|
||||
|
|
@ -747,7 +747,10 @@ class ToolOAuthCallback(Resource):
|
|||
redirect_uri=redirect_uri,
|
||||
system_credentials=oauth_client_params,
|
||||
request=request,
|
||||
).credentials
|
||||
)
|
||||
|
||||
credentials = credentials_response.credentials
|
||||
expires_at = credentials_response.expires_at
|
||||
|
||||
if not credentials:
|
||||
raise Exception("the plugin credentials failed")
|
||||
|
|
@ -758,8 +761,8 @@ class ToolOAuthCallback(Resource):
|
|||
tenant_id=tenant_id,
|
||||
provider=provider,
|
||||
credentials=dict(credentials),
|
||||
expires_at=expires_at,
|
||||
api_type=CredentialType.OAUTH2,
|
||||
expires_at=credentials.get("expires_at"),
|
||||
)
|
||||
return redirect(f"{dify_config.CONSOLE_WEB_URL}/oauth-callback")
|
||||
|
||||
|
|
|
|||
|
|
@ -187,6 +187,9 @@ class PluginOAuthCredentialsResponse(BaseModel):
|
|||
)
|
||||
expires_at: int | None = Field(description="The expires at time of the credentials. UTC timestamp.")
|
||||
credentials: Mapping[str, Any] = Field(description="The credentials of the OAuth.")
|
||||
expires_at: int = Field(
|
||||
default=-1, description="The timestamp of the credentials expires at, -1 means never expires."
|
||||
)
|
||||
|
||||
|
||||
class PluginListResponse(BaseModel):
|
||||
|
|
|
|||
|
|
@ -84,6 +84,41 @@ class OAuthHandler(BasePluginClient):
|
|||
except Exception as e:
|
||||
raise ValueError(f"Error getting credentials: {e}")
|
||||
|
||||
def refresh_credentials(
|
||||
self,
|
||||
tenant_id: str,
|
||||
user_id: str,
|
||||
plugin_id: str,
|
||||
provider: str,
|
||||
redirect_uri: str,
|
||||
system_credentials: Mapping[str, Any],
|
||||
credentials: Mapping[str, Any],
|
||||
) -> PluginOAuthCredentialsResponse:
|
||||
try:
|
||||
response = self._request_with_plugin_daemon_response_stream(
|
||||
"POST",
|
||||
f"plugin/{tenant_id}/dispatch/oauth/refresh_credentials",
|
||||
PluginOAuthCredentialsResponse,
|
||||
data={
|
||||
"user_id": user_id,
|
||||
"data": {
|
||||
"provider": provider,
|
||||
"redirect_uri": redirect_uri,
|
||||
"system_credentials": system_credentials,
|
||||
"credentials": credentials,
|
||||
},
|
||||
},
|
||||
headers={
|
||||
"X-Plugin-ID": plugin_id,
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
)
|
||||
for resp in response:
|
||||
return resp
|
||||
raise ValueError("No response received from plugin daemon for refresh credentials request.")
|
||||
except Exception as e:
|
||||
raise ValueError(f"Error refreshing credentials: {e}")
|
||||
|
||||
def _convert_request_to_raw_data(self, request: Request) -> bytes:
|
||||
"""
|
||||
Convert a Request object to raw HTTP data.
|
||||
|
|
|
|||
|
|
@ -1,16 +1,19 @@
|
|||
import json
|
||||
import logging
|
||||
import mimetypes
|
||||
from collections.abc import Generator
|
||||
import time
|
||||
from collections.abc import Generator, Mapping
|
||||
from os import listdir, path
|
||||
from threading import Lock
|
||||
from typing import TYPE_CHECKING, Any, Literal, Optional, Union, cast
|
||||
|
||||
from pydantic import TypeAdapter
|
||||
from yarl import URL
|
||||
|
||||
import contexts
|
||||
from core.helper.provider_cache import ToolProviderCredentialsCache
|
||||
from core.plugin.entities.plugin import ToolProviderID
|
||||
from core.plugin.impl.oauth import OAuthHandler
|
||||
from core.plugin.impl.tool import PluginToolManager
|
||||
from core.tools.__base.tool_provider import ToolProviderController
|
||||
from core.tools.__base.tool_runtime import ToolRuntime
|
||||
|
|
@ -21,6 +24,7 @@ from core.tools.plugin_tool.tool import PluginTool
|
|||
from core.tools.utils.uuid_utils import is_valid_uuid
|
||||
from core.tools.workflow_as_tool.provider import WorkflowToolProviderController
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from services.tools.builtin_tools_manage_service import BuiltinToolManageService
|
||||
from services.tools.mcp_tools_manage_service import MCPToolManageService
|
||||
|
||||
if TYPE_CHECKING:
|
||||
|
|
@ -244,12 +248,44 @@ class ToolManager:
|
|||
tenant_id=tenant_id, provider=provider_id, credential_id=builtin_provider.id
|
||||
),
|
||||
)
|
||||
|
||||
# decrypt the credentials
|
||||
decrypted_credentials: Mapping[str, Any] = encrypter.decrypt(builtin_provider.credentials)
|
||||
|
||||
# check if the credentials is expired
|
||||
if builtin_provider.expires_at != -1 and builtin_provider.expires_at < int(time.time()):
|
||||
# refresh the credentials
|
||||
tool_provider = ToolProviderID(provider_id)
|
||||
provider_name = tool_provider.provider_name
|
||||
redirect_uri = f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/{provider_name}/tool/callback"
|
||||
system_credentials = BuiltinToolManageService.get_oauth_client(tenant_id, provider_name)
|
||||
oauth_handler = OAuthHandler()
|
||||
# refresh the credentials
|
||||
refreshed_credentials = oauth_handler.refresh_credentials(
|
||||
tenant_id=tenant_id,
|
||||
user_id=builtin_provider.user_id,
|
||||
plugin_id=tool_provider.plugin_id,
|
||||
provider=provider_name,
|
||||
redirect_uri=redirect_uri,
|
||||
system_credentials=system_credentials or {},
|
||||
credentials=decrypted_credentials,
|
||||
)
|
||||
# update the credentials
|
||||
builtin_provider.encrypted_credentials = (
|
||||
TypeAdapter(dict[str, Any])
|
||||
.dump_json(encrypter.encrypt(dict(refreshed_credentials.credentials)))
|
||||
.decode("utf-8")
|
||||
)
|
||||
builtin_provider.expires_at = refreshed_credentials.expires_at
|
||||
db.session.commit()
|
||||
decrypted_credentials = refreshed_credentials.credentials
|
||||
|
||||
return cast(
|
||||
BuiltinTool,
|
||||
builtin_tool.fork_tool_runtime(
|
||||
runtime=ToolRuntime(
|
||||
tenant_id=tenant_id,
|
||||
credentials=encrypter.decrypt(builtin_provider.credentials),
|
||||
credentials=dict(decrypted_credentials),
|
||||
credential_type=CredentialType.of(builtin_provider.credential_type),
|
||||
runtime_parameters={},
|
||||
invoke_from=invoke_from,
|
||||
|
|
|
|||
|
|
@ -20,7 +20,7 @@ depends_on = None
|
|||
def upgrade():
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
with op.batch_alter_table('tool_builtin_providers', schema=None) as batch_op:
|
||||
batch_op.add_column(sa.Column('expires_at', sa.Integer(), server_default=sa.text('2147483647'), nullable=False))
|
||||
batch_op.add_column(sa.Column('expires_at', sa.BigInteger(), server_default=sa.text('2147483647'), nullable=False))
|
||||
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
|
|
|||
|
|
@ -93,7 +93,7 @@ class BuiltinToolProvider(Base):
|
|||
credential_type: Mapped[str] = mapped_column(
|
||||
db.String(32), nullable=False, server_default=db.text("'api-key'::character varying")
|
||||
)
|
||||
expires_at: Mapped[int] = mapped_column(db.Integer, nullable=False, server_default=db.text("2147483647"))
|
||||
expires_at: Mapped[int] = mapped_column(db.BigInteger, nullable=False, server_default=db.text("-1"))
|
||||
|
||||
@property
|
||||
def credentials(self) -> dict:
|
||||
|
|
|
|||
|
|
@ -213,8 +213,8 @@ class BuiltinToolManageService:
|
|||
tenant_id: str,
|
||||
provider: str,
|
||||
credentials: dict,
|
||||
expires_at: int = -1,
|
||||
name: str | None = None,
|
||||
expires_at: int | None = None,
|
||||
):
|
||||
"""
|
||||
add builtin tool provider
|
||||
|
|
|
|||
Loading…
Reference in New Issue