mirror of
https://github.com/langgenius/dify.git
synced 2026-04-21 06:46:30 +08:00
add more current_user typing (#24612)
This commit is contained in:
parent
22b11e4b43
commit
4cd00efe3b
@ -8,20 +8,21 @@ from uuid import UUID
|
|||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import pytz
|
import pytz
|
||||||
from flask_login import current_user
|
|
||||||
|
|
||||||
from core.file import File, FileTransferMethod, FileType
|
from core.file import File, FileTransferMethod, FileType
|
||||||
from core.tools.entities.tool_entities import ToolInvokeMessage
|
from core.tools.entities.tool_entities import ToolInvokeMessage
|
||||||
from core.tools.tool_file_manager import ToolFileManager
|
from core.tools.tool_file_manager import ToolFileManager
|
||||||
|
from libs.login import current_user
|
||||||
|
from models.account import Account
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def safe_json_value(v):
|
def safe_json_value(v):
|
||||||
if isinstance(v, datetime):
|
if isinstance(v, datetime):
|
||||||
tz_name = getattr(current_user, "timezone", None) if current_user is not None else None
|
tz_name = "UTC"
|
||||||
if not tz_name:
|
if isinstance(current_user, Account) and current_user.timezone is not None:
|
||||||
tz_name = "UTC"
|
tz_name = current_user.timezone
|
||||||
return v.astimezone(pytz.timezone(tz_name)).isoformat()
|
return v.astimezone(pytz.timezone(tz_name)).isoformat()
|
||||||
elif isinstance(v, date):
|
elif isinstance(v, date):
|
||||||
return v.isoformat()
|
return v.isoformat()
|
||||||
@ -46,7 +47,7 @@ def safe_json_value(v):
|
|||||||
return v
|
return v
|
||||||
|
|
||||||
|
|
||||||
def safe_json_dict(d):
|
def safe_json_dict(d: dict):
|
||||||
if not isinstance(d, dict):
|
if not isinstance(d, dict):
|
||||||
raise TypeError("safe_json_dict() expects a dictionary (dict) as input")
|
raise TypeError("safe_json_dict() expects a dictionary (dict) as input")
|
||||||
return {k: safe_json_value(v) for k, v in d.items()}
|
return {k: safe_json_value(v) for k, v in d.items()}
|
||||||
|
|||||||
@ -3,8 +3,6 @@ import logging
|
|||||||
from collections.abc import Generator
|
from collections.abc import Generator
|
||||||
from typing import Any, Optional, cast
|
from typing import Any, Optional, cast
|
||||||
|
|
||||||
from flask_login import current_user
|
|
||||||
|
|
||||||
from core.file import FILE_MODEL_IDENTITY, File, FileTransferMethod
|
from core.file import FILE_MODEL_IDENTITY, File, FileTransferMethod
|
||||||
from core.tools.__base.tool import Tool
|
from core.tools.__base.tool import Tool
|
||||||
from core.tools.__base.tool_runtime import ToolRuntime
|
from core.tools.__base.tool_runtime import ToolRuntime
|
||||||
@ -17,8 +15,8 @@ from core.tools.entities.tool_entities import (
|
|||||||
from core.tools.errors import ToolInvokeError
|
from core.tools.errors import ToolInvokeError
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
from factories.file_factory import build_from_mapping
|
from factories.file_factory import build_from_mapping
|
||||||
from models.account import Account
|
from libs.login import current_user
|
||||||
from models.model import App, EndUser
|
from models.model import App
|
||||||
from models.workflow import Workflow
|
from models.workflow import Workflow
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@ -81,11 +79,11 @@ class WorkflowTool(Tool):
|
|||||||
generator = WorkflowAppGenerator()
|
generator = WorkflowAppGenerator()
|
||||||
assert self.runtime is not None
|
assert self.runtime is not None
|
||||||
assert self.runtime.invoke_from is not None
|
assert self.runtime.invoke_from is not None
|
||||||
|
assert current_user is not None
|
||||||
result = generator.generate(
|
result = generator.generate(
|
||||||
app_model=app,
|
app_model=app,
|
||||||
workflow=workflow,
|
workflow=workflow,
|
||||||
user=cast("Account | EndUser", current_user),
|
user=current_user,
|
||||||
args={"inputs": tool_parameters, "files": files},
|
args={"inputs": tool_parameters, "files": files},
|
||||||
invoke_from=self.runtime.invoke_from,
|
invoke_from=self.runtime.invoke_from,
|
||||||
streaming=False,
|
streaming=False,
|
||||||
|
|||||||
@ -39,7 +39,7 @@ def test_page_result(text, cursor, maxlen, expected):
|
|||||||
# Tests: get_url
|
# Tests: get_url
|
||||||
# ---------------------------
|
# ---------------------------
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def stub_support_types(monkeypatch):
|
def stub_support_types(monkeypatch: pytest.MonkeyPatch):
|
||||||
"""Stub supported content types list."""
|
"""Stub supported content types list."""
|
||||||
import core.tools.utils.web_reader_tool as mod
|
import core.tools.utils.web_reader_tool as mod
|
||||||
|
|
||||||
@ -48,7 +48,7 @@ def stub_support_types(monkeypatch):
|
|||||||
return mod
|
return mod
|
||||||
|
|
||||||
|
|
||||||
def test_get_url_unsupported_content_type(monkeypatch, stub_support_types):
|
def test_get_url_unsupported_content_type(monkeypatch: pytest.MonkeyPatch, stub_support_types):
|
||||||
# HEAD 200 but content-type not supported and not text/html
|
# HEAD 200 but content-type not supported and not text/html
|
||||||
def fake_head(url, headers=None, follow_redirects=True, timeout=None):
|
def fake_head(url, headers=None, follow_redirects=True, timeout=None):
|
||||||
return FakeResponse(
|
return FakeResponse(
|
||||||
@ -62,7 +62,7 @@ def test_get_url_unsupported_content_type(monkeypatch, stub_support_types):
|
|||||||
assert result == "Unsupported content-type [image/png] of URL."
|
assert result == "Unsupported content-type [image/png] of URL."
|
||||||
|
|
||||||
|
|
||||||
def test_get_url_supported_binary_type_uses_extract_processor(monkeypatch, stub_support_types):
|
def test_get_url_supported_binary_type_uses_extract_processor(monkeypatch: pytest.MonkeyPatch, stub_support_types):
|
||||||
"""
|
"""
|
||||||
When content-type is in SUPPORT_URL_CONTENT_TYPES,
|
When content-type is in SUPPORT_URL_CONTENT_TYPES,
|
||||||
should call ExtractProcessor.load_from_url and return its text.
|
should call ExtractProcessor.load_from_url and return its text.
|
||||||
@ -88,7 +88,7 @@ def test_get_url_supported_binary_type_uses_extract_processor(monkeypatch, stub_
|
|||||||
assert result == "PDF extracted text"
|
assert result == "PDF extracted text"
|
||||||
|
|
||||||
|
|
||||||
def test_get_url_html_flow_with_chardet_and_readability(monkeypatch, stub_support_types):
|
def test_get_url_html_flow_with_chardet_and_readability(monkeypatch: pytest.MonkeyPatch, stub_support_types):
|
||||||
"""200 + text/html → GET, chardet detects encoding, readability returns article which is templated."""
|
"""200 + text/html → GET, chardet detects encoding, readability returns article which is templated."""
|
||||||
|
|
||||||
def fake_head(url, headers=None, follow_redirects=True, timeout=None):
|
def fake_head(url, headers=None, follow_redirects=True, timeout=None):
|
||||||
@ -121,7 +121,7 @@ def test_get_url_html_flow_with_chardet_and_readability(monkeypatch, stub_suppor
|
|||||||
assert "Hello world" in out
|
assert "Hello world" in out
|
||||||
|
|
||||||
|
|
||||||
def test_get_url_html_flow_empty_article_text_returns_empty(monkeypatch, stub_support_types):
|
def test_get_url_html_flow_empty_article_text_returns_empty(monkeypatch: pytest.MonkeyPatch, stub_support_types):
|
||||||
"""If readability returns no text, should return empty string."""
|
"""If readability returns no text, should return empty string."""
|
||||||
|
|
||||||
def fake_head(url, headers=None, follow_redirects=True, timeout=None):
|
def fake_head(url, headers=None, follow_redirects=True, timeout=None):
|
||||||
@ -142,7 +142,7 @@ def test_get_url_html_flow_empty_article_text_returns_empty(monkeypatch, stub_su
|
|||||||
assert out == ""
|
assert out == ""
|
||||||
|
|
||||||
|
|
||||||
def test_get_url_403_cloudscraper_fallback(monkeypatch, stub_support_types):
|
def test_get_url_403_cloudscraper_fallback(monkeypatch: pytest.MonkeyPatch, stub_support_types):
|
||||||
"""HEAD 403 → use cloudscraper.get via ssrf_proxy.make_request, then proceed."""
|
"""HEAD 403 → use cloudscraper.get via ssrf_proxy.make_request, then proceed."""
|
||||||
|
|
||||||
def fake_head(url, headers=None, follow_redirects=True, timeout=None):
|
def fake_head(url, headers=None, follow_redirects=True, timeout=None):
|
||||||
@ -175,7 +175,7 @@ def test_get_url_403_cloudscraper_fallback(monkeypatch, stub_support_types):
|
|||||||
assert "X" in out
|
assert "X" in out
|
||||||
|
|
||||||
|
|
||||||
def test_get_url_head_non_200_returns_status(monkeypatch, stub_support_types):
|
def test_get_url_head_non_200_returns_status(monkeypatch: pytest.MonkeyPatch, stub_support_types):
|
||||||
"""HEAD returns non-200 and non-403 → should directly return code message."""
|
"""HEAD returns non-200 and non-403 → should directly return code message."""
|
||||||
|
|
||||||
def fake_head(url, headers=None, follow_redirects=True, timeout=None):
|
def fake_head(url, headers=None, follow_redirects=True, timeout=None):
|
||||||
@ -189,7 +189,7 @@ def test_get_url_head_non_200_returns_status(monkeypatch, stub_support_types):
|
|||||||
assert out == "URL returned status code 500."
|
assert out == "URL returned status code 500."
|
||||||
|
|
||||||
|
|
||||||
def test_get_url_content_disposition_filename_detection(monkeypatch, stub_support_types):
|
def test_get_url_content_disposition_filename_detection(monkeypatch: pytest.MonkeyPatch, stub_support_types):
|
||||||
"""
|
"""
|
||||||
If HEAD 200 with no Content-Type but Content-Disposition filename suggests a supported type,
|
If HEAD 200 with no Content-Type but Content-Disposition filename suggests a supported type,
|
||||||
it should route to ExtractProcessor.load_from_url.
|
it should route to ExtractProcessor.load_from_url.
|
||||||
@ -213,7 +213,7 @@ def test_get_url_content_disposition_filename_detection(monkeypatch, stub_suppor
|
|||||||
assert out == "From ExtractProcessor via filename"
|
assert out == "From ExtractProcessor via filename"
|
||||||
|
|
||||||
|
|
||||||
def test_get_url_html_encoding_fallback_when_decode_fails(monkeypatch, stub_support_types):
|
def test_get_url_html_encoding_fallback_when_decode_fails(monkeypatch: pytest.MonkeyPatch, stub_support_types):
|
||||||
"""
|
"""
|
||||||
If chardet returns an encoding but content.decode raises, should fallback to response.text.
|
If chardet returns an encoding but content.decode raises, should fallback to response.text.
|
||||||
"""
|
"""
|
||||||
@ -250,7 +250,7 @@ def test_get_url_html_encoding_fallback_when_decode_fails(monkeypatch, stub_supp
|
|||||||
# ---------------------------
|
# ---------------------------
|
||||||
|
|
||||||
|
|
||||||
def test_extract_using_readabilipy_field_mapping_and_defaults(monkeypatch):
|
def test_extract_using_readabilipy_field_mapping_and_defaults(monkeypatch: pytest.MonkeyPatch):
|
||||||
# stub readabilipy.simple_json_from_html_string
|
# stub readabilipy.simple_json_from_html_string
|
||||||
def fake_simple_json_from_html_string(html, use_readability=True):
|
def fake_simple_json_from_html_string(html, use_readability=True):
|
||||||
return {
|
return {
|
||||||
@ -271,7 +271,7 @@ def test_extract_using_readabilipy_field_mapping_and_defaults(monkeypatch):
|
|||||||
assert article.text[0]["text"] == "world"
|
assert article.text[0]["text"] == "world"
|
||||||
|
|
||||||
|
|
||||||
def test_extract_using_readabilipy_defaults_when_missing(monkeypatch):
|
def test_extract_using_readabilipy_defaults_when_missing(monkeypatch: pytest.MonkeyPatch):
|
||||||
def fake_simple_json_from_html_string(html, use_readability=True):
|
def fake_simple_json_from_html_string(html, use_readability=True):
|
||||||
return {} # all missing
|
return {} # all missing
|
||||||
|
|
||||||
|
|||||||
@ -8,7 +8,7 @@ from core.tools.errors import ToolInvokeError
|
|||||||
from core.tools.workflow_as_tool.tool import WorkflowTool
|
from core.tools.workflow_as_tool.tool import WorkflowTool
|
||||||
|
|
||||||
|
|
||||||
def test_workflow_tool_should_raise_tool_invoke_error_when_result_has_error_field(monkeypatch):
|
def test_workflow_tool_should_raise_tool_invoke_error_when_result_has_error_field(monkeypatch: pytest.MonkeyPatch):
|
||||||
"""Ensure that WorkflowTool will throw a `ToolInvokeError` exception when
|
"""Ensure that WorkflowTool will throw a `ToolInvokeError` exception when
|
||||||
`WorkflowAppGenerator.generate` returns a result with `error` key inside
|
`WorkflowAppGenerator.generate` returns a result with `error` key inside
|
||||||
the `data` element.
|
the `data` element.
|
||||||
@ -40,7 +40,7 @@ def test_workflow_tool_should_raise_tool_invoke_error_when_result_has_error_fiel
|
|||||||
"core.app.apps.workflow.app_generator.WorkflowAppGenerator.generate",
|
"core.app.apps.workflow.app_generator.WorkflowAppGenerator.generate",
|
||||||
lambda *args, **kwargs: {"data": {"error": "oops"}},
|
lambda *args, **kwargs: {"data": {"error": "oops"}},
|
||||||
)
|
)
|
||||||
monkeypatch.setattr("flask_login.current_user", lambda *args, **kwargs: None)
|
monkeypatch.setattr("libs.login.current_user", lambda *args, **kwargs: None)
|
||||||
|
|
||||||
with pytest.raises(ToolInvokeError) as exc_info:
|
with pytest.raises(ToolInvokeError) as exc_info:
|
||||||
# WorkflowTool always returns a generator, so we need to iterate to
|
# WorkflowTool always returns a generator, so we need to iterate to
|
||||||
|
|||||||
@ -1,4 +1,5 @@
|
|||||||
import httpx
|
import httpx
|
||||||
|
import pytest
|
||||||
|
|
||||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||||
from core.file import File, FileTransferMethod, FileType
|
from core.file import File, FileTransferMethod, FileType
|
||||||
@ -20,7 +21,7 @@ from models.enums import UserFrom
|
|||||||
from models.workflow import WorkflowType
|
from models.workflow import WorkflowType
|
||||||
|
|
||||||
|
|
||||||
def test_http_request_node_binary_file(monkeypatch):
|
def test_http_request_node_binary_file(monkeypatch: pytest.MonkeyPatch):
|
||||||
data = HttpRequestNodeData(
|
data = HttpRequestNodeData(
|
||||||
title="test",
|
title="test",
|
||||||
method="post",
|
method="post",
|
||||||
@ -110,7 +111,7 @@ def test_http_request_node_binary_file(monkeypatch):
|
|||||||
assert result.outputs["body"] == "test"
|
assert result.outputs["body"] == "test"
|
||||||
|
|
||||||
|
|
||||||
def test_http_request_node_form_with_file(monkeypatch):
|
def test_http_request_node_form_with_file(monkeypatch: pytest.MonkeyPatch):
|
||||||
data = HttpRequestNodeData(
|
data = HttpRequestNodeData(
|
||||||
title="test",
|
title="test",
|
||||||
method="post",
|
method="post",
|
||||||
@ -211,7 +212,7 @@ def test_http_request_node_form_with_file(monkeypatch):
|
|||||||
assert result.outputs["body"] == ""
|
assert result.outputs["body"] == ""
|
||||||
|
|
||||||
|
|
||||||
def test_http_request_node_form_with_multiple_files(monkeypatch):
|
def test_http_request_node_form_with_multiple_files(monkeypatch: pytest.MonkeyPatch):
|
||||||
data = HttpRequestNodeData(
|
data = HttpRequestNodeData(
|
||||||
title="test",
|
title="test",
|
||||||
method="post",
|
method="post",
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user