diff --git a/.agents/skills/how-to-write-component/SKILL.md b/.agents/skills/how-to-write-component/SKILL.md index 8a480c8fd09..bf523070740 100644 --- a/.agents/skills/how-to-write-component/SKILL.md +++ b/.agents/skills/how-to-write-component/SKILL.md @@ -36,28 +36,34 @@ Use this as the decision guide for React/TypeScript component structure. Existin - Avoid prop drilling. One pass-through layer is acceptable; repeated forwarding means ownership should move down or into feature-scoped Jotai UI state. Keep server/cache state in query and API data flow. - Do not replace prop drilling with one top-level hook that returns a large view model and then thread that object through section props. Move each hook, query, derived value, and handler to the concrete section that consumes it, or use feature-scoped Jotai atoms for simple shared form/UI state when siblings need the same source of truth. - When using feature-scoped Jotai state for a form, drawer, or other secondary surface, scope the store to that surface instance when stale cross-instance state is possible. Initialize stable config at the owning boundary, then let descendants read only the atoms or purpose-named hooks they actually need. -- For Jotai-backed surfaces, put shared query atoms, mutation atoms, derived state, and write actions in the feature state file when they coordinate multiple descendants. The lowest-owner rule still applies to independent visual surfaces that do not participate in shared state. +- For Jotai-backed surfaces, put shared query atoms, mutation atoms, derived state, and write actions in the feature state file when they coordinate multiple descendants. Do not create a query or mutation atom only because the surrounding feature uses Jotai. If the query or mutation does not read atom state, feed another atom, or participate in shared workflow orchestration, use `useQuery` or `useMutation` directly at the lowest owner. - For repeated row/menu action surfaces that need reset, hydrate the stable identity at the surface entry and scope only the primitives that truly need per-instance reset, such as open flags, drafts, or selected local options. - Keep callbacks in a parent only for workflow coordination such as form submission, shared selection, batch behavior, or navigation. Otherwise let the child or row own its action. -- Prefer uncontrolled DOM state and CSS variables before adding controlled props. +- Default to uncontrolled form and DOM state. Add controlled props or atom-backed drafts only when live cross-component reactions, multi-step persistence, or external synchronization require them. ## Feature-Scoped Jotai State -- A module's feature-local state lives in one state file for Jotai-backed features: primitive atoms, query atoms, derived atoms, write-only action atoms, mutation atoms, submission orchestration, provider exports, and optional scope configuration. -- Keep state local when one component owns it, even inside Jotai-backed features. Dialog open flags, menu/popover visibility, confirmation visibility, form/input drafts, row-local pending flags, and in-flight refs usually belong in component state. +- A module's feature-local state lives in one state file for Jotai-backed features: primitive atoms, shared query atoms, derived atoms, write-only action atoms, shared mutation atoms, submission orchestration, provider exports, and optional scope configuration. +- Keep synchronous UI state local when one component owns it, even inside Jotai-backed features. Dialog open flags, menu/popover visibility, confirmation visibility, form/input drafts, and selected local options usually belong in component state. +- Do not put simple form drafts in Jotai atoms. For edit/create forms whose fields are only read at submit time, use uncontrolled `@langgenius/dify-ui/form` and `@langgenius/dify-ui/field` controls with `defaultValue`, browser/form validation, and keyed remounts for query-backed initial values. +- Promote form state to Jotai only when another component must react to in-progress field changes, the draft must survive unmount/remount within the same scoped workflow, or multiple steps/surfaces share the same editable draft before submit. +- Keep submit-time normalization, dirty checks, and payload shaping beside the form submit handler. Do not create form atoms, field atoms, or derived can-save atoms only to mirror uncontrolled form values or disable a submit button. +- In Jotai-backed feature surfaces, never hand-roll async loading, error, or in-flight guards with `useState` or `useRef`. For async work that depends on atom state, feeds derived atoms, or participates in shared submission orchestration, model the work with `atomWithQuery` or `atomWithMutation`; write atoms should only update the inputs that drive those atoms. For component-owned remote work that does not participate in atom state, use TanStack Query hooks directly. +- Row-local async state should belong to the row owner. Use `useQuery` or `useMutation` directly for row actions that do not depend on atom state and are not consumed by other atoms. Use a per-instance query or mutation atom only when the row action participates in a Jotai-backed shared workflow or needs atom-scoped reset semantics. - Promote UI state to an atom only when siblings need the same source of truth, the value drives a query or mutation atom, a parent workflow coordinates the state, or the state intentionally persists across hidden or unmounted descendants within a scoped surface. -- Reflect atom-backed surface-wide locks or invariants in every affected trigger. If only one row, menu, or dialog should be disabled, keep the pending or lock state local to that row, menu, or dialog. +- Reflect atom-backed surface-wide locks or invariants in every affected trigger. If only one row, menu, or dialog should be disabled, keep the pending or lock scope local to that row, menu, or dialog with the lowest-owner query/mutation hook unless it genuinely participates in shared atom state. - Atom order in the state file follows the dependency graph: types/constants, editable primitives, query atoms, query-data derived atoms, readiness/business derived atoms, write actions, mutation atoms, submission orchestration, provider exports. - Derived atom names read as business facts. Write atom names read as user or workflow commands. - UI components read and write the exact atom they use with `useAtomValue` or `useSetAtom`. Repeated workflow semantics live in named derived atoms or write atoms. - Non-query derived atoms return a narrow value with a clear domain name; avoid pass-through aliases or bundling unrelated UI facts. Query atoms expose the TanStack Query result object so loading, error, fetch, and pagination state stay attached to the query contract. -- Write-only atoms own synchronous state transitions that update multiple primitives, reset dependent state, or advance the workflow. Async work with loading, error, caching, retry, or stale-result concerns should be modeled as query or mutation atoms, with write atoms only changing the inputs that drive them. +- Write-only atoms own synchronous state transitions that update multiple primitives, reset dependent state, or advance the workflow. Async work with loading, error, caching, retry, stale-result, or in-flight concerns should be modeled as query or mutation atoms, with write atoms only changing the inputs that drive them. - Avoid feature hooks that aggregate form values, query results, derived state, and commands for sibling components. Prefer named derived atoms and write atoms so UI components read the exact shared fact or command they need. - When a form library owns validation, keep submit orchestration in feature state when post-submit result or error state is shared by the surface. Avoid duplicating validation gates or request shaping in UI hooks. -- `jotai-tanstack-query` atoms use the same QueryClient as the React Query provider. Query atoms belong in feature state when atoms are the feature's local state surface. +- `jotai-tanstack-query` atoms use the same QueryClient as the React Query provider. Query atoms belong in feature state only when they need atom inputs, provide data to derived atoms, or coordinate a shared Jotai-backed workflow. - Jotai scope is an optional instance-isolation tool for secondary surfaces with independent local state. Query and mutation atoms keep shared cache behavior through the shared QueryClient. - Do not put `atomWithQuery`, `atomWithInfiniteQuery`, `atomWithMutation`, or broad derived orchestration atoms in a `ScopeProvider` just to reset a surface. Scoped derived atoms implicitly scope their dependencies, which can duplicate query client access and break shared invalidation. Leave query/mutation atoms unscoped; let them read scoped primitive inputs. - Scope providers should list resettable primitive atoms and explicit hydration tuples. If a derived atom must be scoped, confirm that every dependency it implicitly scopes is meant to be private to that surface. +- For scoped primitives that are always hydrated by a `ScopeProvider` tuple, prefer `atomWithLazy(() => { throw new Error(...) })` when consumers should see a non-null type. This keeps missing provider hydration as a runtime invariant without leaking `T | undefined` or adding pass-through "required" derived atoms only for narrowing. - Keep independent dialog lifecycles separate. Avoid a single discriminated "current action dialog" atom when edit, delete, and other dialogs have their own open state, loading guard, or reset behavior. - Route-derived stable identities that do not need instance reset or scoped isolation can be hydrated at the route or layout boundary into a feature route atom. Use scoped atoms only when stale cross-instance state or per-surface reset semantics are needed. @@ -102,7 +108,9 @@ Use this as the decision guide for React/TypeScript component structure. Existin - Keep `web/contract/*` as the single source of truth for API shape; follow existing domain/router patterns and the `{ params, query?, body? }` input shape. - Consume queries directly with `useQuery(consoleQuery.xxx.queryOptions(...))` or `useQuery(marketplaceQuery.xxx.queryOptions(...))`. +- Do not promote a query or mutation to an atom just because the feature already has a state file. Use `atomWithQuery` or `atomWithMutation` only when the query/mutation reads atom state, is consumed by another atom, or is part of shared workflow orchestration. - In `atomWithQuery` and `atomWithInfiniteQuery`, return generated `queryOptions()` or `infiniteOptions()` directly. Pass `enabled`, `retry`, `placeholderData`, `select`, and pagination options into that call instead of spreading generated options into a hand-built object. +- When prefetch and render consume the same server request, extract local query options or a query-options atom so `queryClient.prefetchQuery(...)` and `useQuery`/`atomWithQuery` share the exact generated query options. - In `atomWithMutation`, return generated `mutationOptions()` directly when using generated clients. Put request shaping and submit orchestration in write atoms; do not rebuild mutation option objects just to pass through the generated mutation function. - For custom query functions that do not come from generated clients, wrap the options object with TanStack `queryOptions(...)` so query atoms still return a query options contract. - Avoid pass-through hooks and thin `web/service/use-*` wrappers that only rename `queryOptions()` or `mutationOptions()`. Extract a small `queryOptions` helper only when repeated call-site options justify it. @@ -110,9 +118,10 @@ Use this as the decision guide for React/TypeScript component structure. Existin - For TanStack cache data, use generated or query-derived types; do not create local wrappers for `getQueryData` or `getQueriesData`. - For generated oRPC `queryOptions()` / `infiniteOptions()`, keep returning the generated options directly. When required input is missing, use a whole-input branch such as `input: condition ? validInput : skipToken` together with `enabled: Boolean(condition)` so no request runs and no fake payload is built. - Do not put `skipToken` inside a nested placeholder payload, such as `{ params: { appInstanceId: skipToken } }`. Do not create hand-written "missing queryOptions" objects or coerce required IDs to `''`. -- Consume mutations directly with `useMutation(consoleQuery.xxx.mutationOptions(...))` or `useMutation(marketplaceQuery.xxx.mutationOptions(...))`; use oRPC clients as `mutationFn` only for custom flows. +- Consume mutations directly with `useMutation(consoleQuery.xxx.mutationOptions(...))` or `useMutation(marketplaceQuery.xxx.mutationOptions(...))` when the mutation is owned by one component, menu, dialog, or row and its pending/error state is not consumed by feature atoms. In Jotai-backed workflow orchestration, expose mutations from feature state with `atomWithMutation` so pending/error state stays attached to the mutation atom. For component-owned custom mutation functions, use `useMutation(mutationOptions(...))` at the owner. - Put shared cache behavior in `createTanstackQueryUtils(...experimental_defaults...)`; components may add UI feedback callbacks, but should not own shared invalidation rules. - Component or atom mutation callbacks can handle local UI feedback such as toasts, closing dialogs, or navigation. They should not replace shared invalidation or add local cache patches for shared server state. +- For overlays that may open a heavier secondary surface, prefetch server data from the trigger/menu open event with `queryClient.prefetchQuery(queryOptions)` when the primitive exposes `onOpenChange`. Do not mount a hidden component or subscribe to a query only to warm the cache. Do not make an otherwise uncontrolled menu controlled only for prefetching. - Do not use deprecated `useInvalid` or `useReset`. - Prefer `mutate(...)`; use `mutateAsync(...)` only when Promise semantics are required, and wrap awaited calls in `try/catch`. @@ -125,6 +134,9 @@ Use this as the decision guide for React/TypeScript component structure. Existin - Separate hidden secondary surfaces from the trigger's main flow. For dialogs, dropdowns, popovers, and similar branches, extract a small local component that owns the trigger, open state, and hidden content when it would obscure the parent flow. - Preserve composability by separating behavior ownership from layout ownership. A dropdown action may own its trigger, open state, and menu content; the caller owns placement such as slots, offsets, and alignment. - When a dialog, dropdown, or popover component already accepts controlled `open` state, mount the surface unconditionally unless unmounting is required for performance or reset semantics. Use keyed scope or local state reset for reset behavior instead of `{open && }` wrappers. +- When opening a dialog from a menu item, keep the menu and dialog as sibling surfaces. Let the menu item command open the dialog through local state or scoped atoms, and mount the dialog outside the menu popup content. Avoid wrapping menu items with dialog triggers when the menu primitive already owns item activation and dismissal behavior. +- For dialogs and alert dialogs, keep the root component responsible for `open` wiring and put query/mutation hooks inside the content component when the work should only mount after the overlay opens. Do not put closed-surface remote work in the root just because the root owns the open atom. +- Prefer uncontrolled overlay roots when the library can own their open state. Use `onOpenChange` for side effects such as prefetching, and CSS/data selectors for visual open-state styling instead of adding controlled state only for observation. - Avoid unnecessary DOM hierarchy. Do not add wrapper elements unless they provide layout, semantics, accessibility, state ownership, or integration with a library API; prefer fragments or styling an existing element when possible. - Avoid shallow wrappers, hook-to-props adapter components, layout-only render-prop wrappers, children-as-pass-through composition, and prop renaming unless the wrapper adds validation, orchestration, error handling, state ownership, or a real semantic boundary. If a component only calls a hook, forwards props, or passes trigger/content through to one child, move the logic into that child or make the wrapper own a real surface. diff --git a/api/commands/account.py b/api/commands/account.py index 0d99ce7a0fa..7f4f0a744f3 100644 --- a/api/commands/account.py +++ b/api/commands/account.py @@ -25,7 +25,7 @@ def reset_password(email, new_password, password_confirm): return normalized_email = email.strip().lower() - account = AccountService.get_account_by_email_with_case_fallback(email.strip()) + account = AccountService.get_account_by_email_with_case_fallback(db.session, email.strip()) if not account: click.echo(click.style(f"Account not found for email: {email}", fg="red")) @@ -67,7 +67,7 @@ def reset_email(email, new_email, email_confirm): return normalized_new_email = new_email.strip().lower() - account = AccountService.get_account_by_email_with_case_fallback(email.strip()) + account = AccountService.get_account_by_email_with_case_fallback(db.session, email.strip()) if not account: click.echo(click.style(f"Account not found for email: {email}", fg="red")) diff --git a/api/controllers/console/agent/composer.py b/api/controllers/console/agent/composer.py index 975586c635c..b54cf4b6daf 100644 --- a/api/controllers/console/agent/composer.py +++ b/api/controllers/console/agent/composer.py @@ -28,9 +28,9 @@ from libs.login import login_required from models.model import App, AppMode from services.agent.composer_service import AgentComposerService from services.agent.composer_validator import ComposerConfigValidator -from services.entities.agent_entities import ComposerSavePayload +from services.entities.agent_entities import ComposerSavePayload, WorkflowComposerCopyFromRosterPayload -register_schema_models(console_ns, ComposerSavePayload) +register_schema_models(console_ns, ComposerSavePayload, WorkflowComposerCopyFromRosterPayload) register_response_schema_models( console_ns, AgentAppComposerResponse, @@ -91,6 +91,38 @@ class WorkflowAgentComposerApi(Resource): ) +@console_ns.route("/apps//workflows/draft/nodes//agent-composer/copy-from-roster") +class WorkflowAgentComposerCopyFromRosterApi(Resource): + @console_ns.expect(console_ns.models[WorkflowComposerCopyFromRosterPayload.__name__]) + @console_ns.response( + 200, + "Workflow roster agent copied to inline agent", + console_ns.models[WorkflowAgentComposerResponse.__name__], + ) + @setup_required + @login_required + @account_initialization_required + @edit_permission_required + @rbac_permission_required(RBACResourceScope.APP, RBACPermission.APP_EDIT) + @get_app_model(mode=[AppMode.WORKFLOW, AppMode.ADVANCED_CHAT]) + @with_current_user_id + @with_current_tenant_id + def post(self, tenant_id: str, account_id: str, app_model: App, node_id: str): + payload = WorkflowComposerCopyFromRosterPayload.model_validate(console_ns.payload or {}) + return dump_response( + WorkflowAgentComposerResponse, + AgentComposerService.copy_workflow_composer_from_roster( + tenant_id=tenant_id, + app_id=app_model.id, + node_id=node_id, + account_id=account_id, + source_agent_id=payload.source_agent_id, + source_snapshot_id=payload.source_snapshot_id, + idempotency_key=payload.idempotency_key, + ), + ) + + @console_ns.route("/apps//workflows/draft/nodes//agent-composer/validate") class WorkflowAgentComposerValidateApi(Resource): @console_ns.expect(console_ns.models[ComposerSavePayload.__name__]) diff --git a/api/controllers/console/app/agent_app_feature.py b/api/controllers/console/app/agent_app_feature.py index d155dae6ac3..358e552beb0 100644 --- a/api/controllers/console/app/agent_app_feature.py +++ b/api/controllers/console/app/agent_app_feature.py @@ -91,7 +91,10 @@ class AgentAppFeatureConfigResource(Resource): args = AgentAppFeaturesPayload.model_validate(console_ns.payload or {}) new_app_model_config = AgentAppFeatureConfigService.update_features( - app_model=app_model, account=current_user, config=args.model_dump(exclude_none=True), session=db.session + app_model=app_model, + account=current_user, + config=args.model_dump(exclude_none=True), + session=db.session, ) app_model_config_was_updated.send(app_model, app_model_config=new_app_model_config) diff --git a/api/controllers/console/app/audio.py b/api/controllers/console/app/audio.py index 43b41903f60..b66c97c274c 100644 --- a/api/controllers/console/app/audio.py +++ b/api/controllers/console/app/audio.py @@ -30,6 +30,7 @@ from controllers.console.wraps import ( setup_required, ) from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError +from extensions.ext_database import db from graphon.model_runtime.errors.invoke import InvokeError from libs.login import login_required from models import App, AppMode @@ -142,6 +143,7 @@ class ChatMessageTextApi(Resource): response = AudioService.transcript_tts( app_model=app_model, + session=db.session, text=payload.text, voice=payload.voice, message_id=payload.message_id, diff --git a/api/controllers/console/app/message.py b/api/controllers/console/app/message.py index 726bd94cd7e..195a41f2888 100644 --- a/api/controllers/console/app/message.py +++ b/api/controllers/console/app/message.py @@ -341,8 +341,8 @@ class MessageFeedbackExportApi(Resource): try: export_data = FeedbackService.export_feedbacks( - db.session(), - app_id=app_model.id, + app_model.id, + session=db.session(), from_source=args.from_source, rating=args.rating, has_comment=args.has_comment, diff --git a/api/controllers/console/auth/data_source_bearer_auth.py b/api/controllers/console/auth/data_source_bearer_auth.py index 1de206c73db..11fab84a831 100644 --- a/api/controllers/console/auth/data_source_bearer_auth.py +++ b/api/controllers/console/auth/data_source_bearer_auth.py @@ -5,6 +5,7 @@ from pydantic import BaseModel, Field from controllers.common.fields import SimpleResultResponse from controllers.common.schema import register_response_schema_models, register_schema_models +from extensions.ext_database import db from fields.base import ResponseModel from libs.login import login_required from services.auth.api_key_auth_service import ApiKeyAuthService @@ -58,7 +59,7 @@ class ApiKeyAuthDataSource(Resource): @account_initialization_required @with_current_tenant_id def get(self, current_tenant_id: str): - data_source_api_key_bindings = ApiKeyAuthService.get_provider_auth_list(current_tenant_id) + data_source_api_key_bindings = ApiKeyAuthService.get_provider_auth_list(db.session(), current_tenant_id) if data_source_api_key_bindings: return { "sources": [ @@ -92,7 +93,7 @@ class ApiKeyAuthDataSourceBinding(Resource): data = payload.model_dump() ApiKeyAuthService.validate_api_key_auth_args(data) try: - ApiKeyAuthService.create_provider_auth(current_tenant_id, data) + ApiKeyAuthService.create_provider_auth(db.session(), current_tenant_id, data) except Exception as e: raise ApiKeyAuthFailedError(str(e)) return {"result": "success"}, 200 @@ -109,6 +110,6 @@ class ApiKeyAuthDataSourceBindingDelete(Resource): @with_current_tenant_id def delete(self, current_tenant_id: str, binding_id: UUID): # The role of the current user in the table must be admin or owner - ApiKeyAuthService.delete_provider_auth(current_tenant_id, str(binding_id)) + ApiKeyAuthService.delete_provider_auth(db.session(), current_tenant_id, str(binding_id)) return "", 204 diff --git a/api/controllers/console/auth/email_register.py b/api/controllers/console/auth/email_register.py index 912eb26574c..ccbe9405fe5 100644 --- a/api/controllers/console/auth/email_register.py +++ b/api/controllers/console/auth/email_register.py @@ -15,6 +15,7 @@ from controllers.console.auth.error import ( InvalidTokenError, PasswordMismatchError, ) +from extensions.ext_database import db from fields.base import ResponseModel from libs.helper import EmailStr, extract_remote_ip from libs.helper import timezone as validate_timezone_string @@ -100,7 +101,7 @@ class EmailRegisterSendEmailApi(Resource): if dify_config.BILLING_ENABLED and BillingService.is_email_in_freeze(normalized_email): raise AccountInFreezeError() - account = AccountService.get_account_by_email_with_case_fallback(args.email) + account = AccountService.get_account_by_email_with_case_fallback(db.session, args.email) token = AccountService.send_email_register_email(email=normalized_email, account=account, language=language) return {"result": "success", "data": token} @@ -175,7 +176,7 @@ class EmailRegisterResetApi(Resource): email = register_data.get("email", "") normalized_email = email.lower() - account = AccountService.get_account_by_email_with_case_fallback(email) + account = AccountService.get_account_by_email_with_case_fallback(db.session, email) if account: raise EmailAlreadyInUseError() diff --git a/api/controllers/console/auth/forgot_password.py b/api/controllers/console/auth/forgot_password.py index d82f63c11db..061c29a13a2 100644 --- a/api/controllers/console/auth/forgot_password.py +++ b/api/controllers/console/auth/forgot_password.py @@ -82,7 +82,7 @@ class ForgotPasswordSendEmailApi(Resource): else: language = "en-US" - account = AccountService.get_account_by_email_with_case_fallback(args.email) + account = AccountService.get_account_by_email_with_case_fallback(db.session, args.email) token = AccountService.send_reset_password_email( account=account, @@ -180,7 +180,7 @@ class ForgotPasswordResetApi(Resource): password_hashed = hash_password(args.new_password, salt) email = reset_data.get("email", "") - account = AccountService.get_account_by_email_with_case_fallback(email) + account = AccountService.get_account_by_email_with_case_fallback(db.session, email) if account: account = db.session.merge(account) diff --git a/api/controllers/console/auth/oauth.py b/api/controllers/console/auth/oauth.py index 78d1583fde9..670d1c7818d 100644 --- a/api/controllers/console/auth/oauth.py +++ b/api/controllers/console/auth/oauth.py @@ -224,7 +224,7 @@ def _get_account_by_openid_or_email(provider: str, user_info: OAuthUserInfo) -> account: Account | None = Account.get_by_openid(provider, user_info.id) if not account: - account = AccountService.get_account_by_email_with_case_fallback(user_info.email) + account = AccountService.get_account_by_email_with_case_fallback(db.session, user_info.email) return account diff --git a/api/controllers/console/datasets/metadata.py b/api/controllers/console/datasets/metadata.py index 7195fe066fd..ebb490cd9e8 100644 --- a/api/controllers/console/datasets/metadata.py +++ b/api/controllers/console/datasets/metadata.py @@ -17,6 +17,7 @@ from controllers.console.wraps import ( with_current_tenant_id, with_current_user, ) +from extensions.ext_database import db from fields.dataset_fields import ( DatasetMetadataBuiltInFieldsResponse, DatasetMetadataListResponse, @@ -65,7 +66,9 @@ class DatasetMetadataCreateApi(Resource): raise NotFound("Dataset not found.") DatasetService.check_dataset_permission(dataset, current_user) - metadata = MetadataService.create_metadata(dataset_id_str, metadata_args, current_user, current_tenant_id) + metadata = MetadataService.create_metadata( + db.session(), dataset_id_str, metadata_args, current_user, current_tenant_id + ) return dump_response(DatasetMetadataResponse, metadata), 201 @setup_required @@ -81,7 +84,7 @@ class DatasetMetadataCreateApi(Resource): dataset = DatasetService.get_dataset(dataset_id_str) if dataset is None: raise NotFound("Dataset not found.") - metadata = MetadataService.get_dataset_metadatas(dataset) + metadata = MetadataService.get_dataset_metadatas(db.session(), dataset) return dump_response(DatasetMetadataListResponse, metadata), 200 @@ -108,7 +111,7 @@ class DatasetMetadataApi(Resource): DatasetService.check_dataset_permission(dataset, current_user) metadata = MetadataService.update_metadata_name( - dataset_id_str, metadata_id_str, name, current_user, current_tenant_id + db.session(), dataset_id_str, metadata_id_str, name, current_user, current_tenant_id ) return dump_response(DatasetMetadataResponse, metadata), 200 @@ -127,7 +130,7 @@ class DatasetMetadataApi(Resource): raise NotFound("Dataset not found.") DatasetService.check_dataset_permission(dataset, current_user) - MetadataService.delete_metadata(dataset_id_str, metadata_id_str) + MetadataService.delete_metadata(db.session(), dataset_id_str, metadata_id_str) # Frontend callers only await success and invalidate metadata caches; no response body is consumed. return "", 204 @@ -166,9 +169,9 @@ class DatasetMetadataBuiltInFieldActionApi(Resource): match action: case "enable": - MetadataService.enable_built_in_field(dataset) + MetadataService.enable_built_in_field(db.session(), dataset) case "disable": - MetadataService.disable_built_in_field(dataset) + MetadataService.disable_built_in_field(db.session(), dataset) # Frontend callers only await success and invalidate metadata caches; no response body is consumed. return "", 204 @@ -195,7 +198,7 @@ class DocumentMetadataEditApi(Resource): metadata_args = MetadataOperationData.model_validate(console_ns.payload or {}) - MetadataService.update_documents_metadata(dataset, metadata_args, current_user) + MetadataService.update_documents_metadata(db.session(), dataset, metadata_args, current_user) # Frontend callers only await success and invalidate caches; no response body is consumed. return "", 204 diff --git a/api/controllers/console/explore/audio.py b/api/controllers/console/explore/audio.py index 756dfe84f6c..c2104ccfc61 100644 --- a/api/controllers/console/explore/audio.py +++ b/api/controllers/console/explore/audio.py @@ -20,6 +20,7 @@ from controllers.console.app.error import ( ) from controllers.console.explore.wraps import InstalledAppResource from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError +from extensions.ext_database import db from graphon.model_runtime.errors.invoke import InvokeError from models.model import InstalledApp from services.audio_service import AudioService @@ -99,7 +100,13 @@ class ChatTextApi(InstalledAppResource): text = payload.text voice = payload.voice - response = AudioService.transcript_tts(app_model=app_model, text=text, voice=voice, message_id=message_id) + response = AudioService.transcript_tts( + app_model=app_model, + session=db.session, + text=text, + voice=voice, + message_id=message_id, + ) return response except services.errors.app_model_config.AppModelConfigBrokenError: logger.exception("App model config broken.") diff --git a/api/controllers/console/explore/trial.py b/api/controllers/console/explore/trial.py index ad98dd303fb..6aef9129780 100644 --- a/api/controllers/console/explore/trial.py +++ b/api/controllers/console/explore/trial.py @@ -419,7 +419,13 @@ class TrialChatTextApi(TrialAppResource): app_id = app_model.id user_id = current_user.id - response = AudioService.transcript_tts(app_model=app_model, text=text, voice=voice, message_id=message_id) + response = AudioService.transcript_tts( + app_model=app_model, + session=db.session, + text=text, + voice=voice, + message_id=message_id, + ) RecommendedAppService.add_trial_app_record(db.session, app_id, user_id) return response except services.errors.app_model_config.AppModelConfigBrokenError: diff --git a/api/controllers/console/workspace/members.py b/api/controllers/console/workspace/members.py index b3230d77e69..4ea77e04b96 100644 --- a/api/controllers/console/workspace/members.py +++ b/api/controllers/console/workspace/members.py @@ -131,7 +131,7 @@ def _normalize_invitee_emails(emails: list[str]) -> list[str]: def _count_new_member_invites(tenant_id: str, emails: list[str]) -> int: new_member_count = 0 for email in emails: - account = AccountService.get_account_by_email_with_case_fallback(email) + account = AccountService.get_account_by_email_with_case_fallback(db.session, email) if not account: new_member_count += 1 continue diff --git a/api/controllers/inner_api/knowledge/retrieval.py b/api/controllers/inner_api/knowledge/retrieval.py index 1c1320fde42..e34dedea286 100644 --- a/api/controllers/inner_api/knowledge/retrieval.py +++ b/api/controllers/inner_api/knowledge/retrieval.py @@ -14,6 +14,7 @@ from controllers.common.schema import register_response_schema_models, register_ from controllers.inner_api import inner_api_ns from controllers.inner_api.wraps import plugin_inner_api_only from core.workflow.nodes.knowledge_retrieval import exc as retrieval_exc +from extensions.ext_database import db from libs.exception import BaseHTTPException from services.entities.knowledge_retrieval_inner import InnerKnowledgeRetrieveRequest, InnerKnowledgeRetrieveResponse from services.errors.knowledge_retrieval import ExternalKnowledgeRetrievalError, InnerKnowledgeRetrievalServiceError @@ -81,7 +82,7 @@ class InnerKnowledgeRetrieveApi(Resource): ) from exc try: - response = InnerKnowledgeRetrievalService().retrieve(payload) + response = InnerKnowledgeRetrievalService().retrieve(payload, session=db.session) except InnerKnowledgeRetrievalServiceError as exc: raise InnerKnowledgeRetrievalHttpError( error_code=exc.error_code, diff --git a/api/controllers/openapi/workspaces.py b/api/controllers/openapi/workspaces.py index 0ff225271df..5653fbae432 100644 --- a/api/controllers/openapi/workspaces.py +++ b/api/controllers/openapi/workspaces.py @@ -193,7 +193,7 @@ class WorkspaceMembersApi(Resource): raise BadRequest(str(exc)) normalized_email = body.email.lower() - member = AccountService.get_account_by_email_with_case_fallback(normalized_email) + member = AccountService.get_account_by_email_with_case_fallback(db.session, normalized_email) if member is None: # invite_new_member just created or fetched this account. raise RuntimeError("invited member missing from DB after invite") diff --git a/api/controllers/service_api/app/audio.py b/api/controllers/service_api/app/audio.py index 2b5a9ba83a1..59ed4b4a4b1 100644 --- a/api/controllers/service_api/app/audio.py +++ b/api/controllers/service_api/app/audio.py @@ -23,6 +23,7 @@ from controllers.service_api.app.error import ( from controllers.service_api.schema import binary_response, expect_with_user, multipart_file_params from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError +from extensions.ext_database import db from graphon.model_runtime.errors.invoke import InvokeError from models.model import App, EndUser from services.audio_service import AudioService @@ -177,7 +178,12 @@ class TextApi(Resource): text = payload.text voice = payload.voice response = AudioService.transcript_tts( - app_model=app_model, text=text, voice=voice, end_user=end_user.external_user_id, message_id=message_id + app_model=app_model, + session=db.session, + text=text, + voice=voice, + end_user=end_user.external_user_id, + message_id=message_id, ) return response diff --git a/api/controllers/service_api/dataset/metadata.py b/api/controllers/service_api/dataset/metadata.py index 426d008c412..7363e6bdfd4 100644 --- a/api/controllers/service_api/dataset/metadata.py +++ b/api/controllers/service_api/dataset/metadata.py @@ -8,6 +8,7 @@ from controllers.common.controller_schemas import MetadataUpdatePayload from controllers.common.schema import register_response_schema_models, register_schema_model, register_schema_models from controllers.service_api import service_api_ns from controllers.service_api.wraps import DatasetApiResource, cloud_edition_billing_rate_limit_check +from extensions.ext_database import db from fields.dataset_fields import ( DatasetMetadataActionResponse, DatasetMetadataBuiltInFieldsResponse, @@ -85,7 +86,7 @@ class DatasetMetadataCreateServiceApi(DatasetApiResource): raise NotFound("Dataset not found.") DatasetService.check_dataset_permission(dataset, current_user) - metadata = MetadataService.create_metadata(dataset_id_str, metadata_args) + metadata = MetadataService.create_metadata(db.session(), dataset_id_str, metadata_args) return dump_response(DatasetMetadataResponse, metadata), 201 @service_api_ns.doc( @@ -118,7 +119,7 @@ class DatasetMetadataCreateServiceApi(DatasetApiResource): dataset = DatasetService.get_dataset(dataset_id_str) if dataset is None: raise NotFound("Dataset not found.") - metadata = MetadataService.get_dataset_metadatas(dataset) + metadata = MetadataService.get_dataset_metadatas(db.session(), dataset) return dump_response(DatasetMetadataListResponse, metadata), 200 @@ -158,7 +159,7 @@ class DatasetMetadataServiceApi(DatasetApiResource): raise NotFound("Dataset not found.") DatasetService.check_dataset_permission(dataset, current_user) - metadata = MetadataService.update_metadata_name(dataset_id_str, metadata_id_str, payload.name) + metadata = MetadataService.update_metadata_name(db.session(), dataset_id_str, metadata_id_str, payload.name) return dump_response(DatasetMetadataResponse, metadata), 200 @service_api_ns.doc( @@ -193,7 +194,7 @@ class DatasetMetadataServiceApi(DatasetApiResource): raise NotFound("Dataset not found.") DatasetService.check_dataset_permission(dataset, current_user) - MetadataService.delete_metadata(dataset_id_str, metadata_id_str) + MetadataService.delete_metadata(db.session(), dataset_id_str, metadata_id_str) return "", 204 @@ -263,9 +264,9 @@ class DatasetMetadataBuiltInFieldActionServiceApi(DatasetApiResource): match action: case "enable": - MetadataService.enable_built_in_field(dataset) + MetadataService.enable_built_in_field(db.session(), dataset) case "disable": - MetadataService.disable_built_in_field(dataset) + MetadataService.disable_built_in_field(db.session(), dataset) return dump_response(DatasetMetadataActionResponse, {"result": "success"}), 200 @@ -309,6 +310,6 @@ class DocumentMetadataEditServiceApi(DatasetApiResource): metadata_args = MetadataOperationData.model_validate(service_api_ns.payload or {}) - MetadataService.update_documents_metadata(dataset, metadata_args) + MetadataService.update_documents_metadata(db.session(), dataset, metadata_args) return dump_response(DatasetMetadataActionResponse, {"result": "success"}), 200 diff --git a/api/controllers/web/audio.py b/api/controllers/web/audio.py index c762c914861..801c1f5a629 100644 --- a/api/controllers/web/audio.py +++ b/api/controllers/web/audio.py @@ -22,6 +22,7 @@ from controllers.web.error import ( ) from controllers.web.wraps import WebApiResource from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError +from extensions.ext_database import db from graphon.model_runtime.errors.invoke import InvokeError from libs.helper import uuid_value from models.model import App, EndUser @@ -130,7 +131,12 @@ class TextApi(WebApiResource): text = payload.text voice = payload.voice response = AudioService.transcript_tts( - app_model=app_model, text=text, voice=voice, end_user=end_user.external_user_id, message_id=message_id + app_model=app_model, + session=db.session, + text=text, + voice=voice, + end_user=end_user.external_user_id, + message_id=message_id, ) return response diff --git a/api/controllers/web/forgot_password.py b/api/controllers/web/forgot_password.py index d0e023e40ee..ecc91113c32 100644 --- a/api/controllers/web/forgot_password.py +++ b/api/controllers/web/forgot_password.py @@ -69,7 +69,7 @@ class ForgotPasswordSendEmailApi(Resource): else: language = "en-US" - account = AccountService.get_account_by_email_with_case_fallback(request_email) + account = AccountService.get_account_by_email_with_case_fallback(db.session, request_email) if account is None: raise AuthenticationFailedError() else: @@ -168,7 +168,7 @@ class ForgotPasswordResetApi(Resource): email = reset_data.get("email", "") - account = AccountService.get_account_by_email_with_case_fallback(email) + account = AccountService.get_account_by_email_with_case_fallback(db.session, email) if account: account = db.session.merge(account) diff --git a/api/core/app/apps/agent_app/runtime_request_builder.py b/api/core/app/apps/agent_app/runtime_request_builder.py index fc1fcb0b168..9790f2fbca0 100644 --- a/api/core/app/apps/agent_app/runtime_request_builder.py +++ b/api/core/app/apps/agent_app/runtime_request_builder.py @@ -197,7 +197,7 @@ class AgentAppRuntimeRequestBuilder: def _plugin_daemon_plugin_id(*, plugin_id: str, model_provider: str) -> str: """Return the transport plugin id expected by plugin-daemon headers.""" if plugin_id.count("/") == 1: - return plugin_id + return plugin_id.split(":", 1)[0].split("@", 1)[0] if plugin_id: return ModelProviderID(plugin_id).plugin_id return ModelProviderID(model_provider).plugin_id diff --git a/api/core/workflow/nodes/agent_v2/runtime_request_builder.py b/api/core/workflow/nodes/agent_v2/runtime_request_builder.py index e3c2dcee839..e5a541ed350 100644 --- a/api/core/workflow/nodes/agent_v2/runtime_request_builder.py +++ b/api/core/workflow/nodes/agent_v2/runtime_request_builder.py @@ -265,7 +265,7 @@ class WorkflowAgentRuntimeRequestBuilder: def _plugin_daemon_plugin_id(*, plugin_id: str, model_provider: str) -> str: """Return the transport plugin id expected by plugin-daemon headers.""" if plugin_id.count("/") == 1: - return plugin_id + return plugin_id.split(":", 1)[0].split("@", 1)[0] if plugin_id: return ModelProviderID(plugin_id).plugin_id return ModelProviderID(model_provider).plugin_id diff --git a/api/openapi/markdown/console-openapi.md b/api/openapi/markdown/console-openapi.md index cff0286ad8f..6041662f303 100644 --- a/api/openapi/markdown/console-openapi.md +++ b/api/openapi/markdown/console-openapi.md @@ -3807,6 +3807,26 @@ Submit human input form preview for workflow | ---- | ----------- | ------ | | 200 | Workflow agent composer candidates | **application/json**: [AgentComposerCandidatesResponse](#agentcomposercandidatesresponse)
| +### [POST] /apps/{app_id}/workflows/draft/nodes/{node_id}/agent-composer/copy-from-roster +#### Parameters + +| Name | Located in | Description | Required | Schema | +| ---- | ---------- | ----------- | -------- | ------ | +| app_id | path | | Yes | string (uuid) | +| node_id | path | | Yes | string | + +#### Request Body + +| Required | Schema | +| -------- | ------ | +| Yes | **application/json**: [WorkflowComposerCopyFromRosterPayload](#workflowcomposercopyfromrosterpayload)
| + +#### Responses + +| Code | Description | Schema | +| ---- | ----------- | ------ | +| 200 | Workflow roster agent copied to inline agent | **application/json**: [WorkflowAgentComposerResponse](#workflowagentcomposerresponse)
| + ### [POST] /apps/{app_id}/workflows/draft/nodes/{node_id}/agent-composer/impact #### Parameters @@ -14385,9 +14405,14 @@ Button styles for user actions. | agent_soul | [AgentSoulConfig](#agentsoulconfig) | | No | | binding | [ComposerBindingPayload](#composerbindingpayload) | | No | | client_revision_id | string | | No | +| description | string | | No | +| icon | string | | No | +| icon_background | string | | No | +| icon_type | [AgentIconType](#agenticontype) | | No | | idempotency_key | string | | No | | new_agent_name | string | | No | | node_job | [WorkflowNodeJobConfig](#workflownodejobconfig) | | No | +| role | string | | No | | save_strategy | [ComposerSaveStrategy](#composersavestrategy) | | Yes | | soul_lock | [ComposerSoulLockPayload](#composersoullockpayload) | | No | | variant | [ComposerVariant](#composervariant) | | Yes | @@ -20560,6 +20585,14 @@ How a workflow node is bound to an Agent. | position_x | number | Comment X position | No | | position_y | number | Comment Y position | No | +#### WorkflowComposerCopyFromRosterPayload + +| Name | Type | Description | Required | +| ---- | ---- | ----------- | -------- | +| idempotency_key | string | | No | +| source_agent_id | string | | Yes | +| source_snapshot_id | string | | No | + #### WorkflowConversationVariableResponse | Name | Type | Description | Required | diff --git a/api/services/account_service.py b/api/services/account_service.py index 21b5f1eedba..80411dd288e 100644 --- a/api/services/account_service.py +++ b/api/services/account_service.py @@ -14,7 +14,6 @@ from werkzeug.exceptions import Unauthorized from configs import dify_config from constants.languages import get_valid_language, language_timezone_mapping -from core.db.session_factory import session_factory from events.tenant_event import tenant_was_created from extensions.ext_database import db from extensions.ext_redis import redis_client, redis_fallback @@ -981,19 +980,18 @@ class AccountService: return token @staticmethod - def get_account_by_email_with_case_fallback(email: str) -> Account | None: + def get_account_by_email_with_case_fallback(session: Session | scoped_session, email: str) -> Account | None: """ Retrieve an account by email and fall back to the lowercase email if the original lookup fails. This keeps backward compatibility for older records that stored uppercase emails while the rest of the system gradually normalizes new inputs. """ - with session_factory.create_session() as session: - account = session.execute(select(Account).where(Account.email == email)).scalar_one_or_none() - if account or email == email.lower(): - return account + account = session.execute(select(Account).where(Account.email == email)).scalar_one_or_none() + if account or email == email.lower(): + return account - return session.execute(select(Account).where(Account.email == email.lower())).scalar_one_or_none() + return session.execute(select(Account).where(Account.email == email.lower())).scalar_one_or_none() @classmethod def get_email_code_login_data(cls, token: str) -> dict[str, Any] | None: @@ -1958,7 +1956,7 @@ class RegisterService: check_workspace_member_invite_permission(tenant.id) - account = AccountService.get_account_by_email_with_case_fallback(email) + account = AccountService.get_account_by_email_with_case_fallback(db.session, email) requires_setup = False if not account: diff --git a/api/services/agent/composer_service.py b/api/services/agent/composer_service.py index 0a17c06300f..8c83ee80031 100644 --- a/api/services/agent/composer_service.py +++ b/api/services/agent/composer_service.py @@ -4,15 +4,18 @@ from typing import Any from sqlalchemy import func, or_, select from sqlalchemy.exc import IntegrityError +from sqlalchemy.sql.elements import ColumnElement from extensions.ext_database import db from libs.helper import to_timestamp +from models import Account from models.agent import ( Agent, AgentConfigRevision, AgentConfigRevisionOperation, AgentConfigSnapshot, AgentDriveFile, + AgentIconType, AgentKind, AgentScope, AgentSource, @@ -20,9 +23,7 @@ from models.agent import ( WorkflowAgentBindingType, WorkflowAgentNodeBinding, ) -from models.agent_config_entities import ( - DeclaredOutputConfig, -) +from models.agent_config_entities import DeclaredOutputConfig from models.agent_config_entities import ( effective_declared_outputs as _effective_declared_outputs, ) @@ -32,8 +33,12 @@ from services.agent.composer_validator import ComposerConfigValidator from services.agent.errors import ( AgentNameConflictError, AgentNotFoundError, + AgentVersionConflictError, AgentVersionNotFoundError, + InvalidComposerConfigError, ) +from services.agent.roster_service import AgentRosterService +from services.app_service import AppService, CreateAppParams from services.entities.agent_entities import ( AgentSoulConfig, ComposerCandidatesResponse, @@ -172,6 +177,86 @@ class AgentComposerService: ) return state + @classmethod + def copy_workflow_composer_from_roster( + cls, + *, + tenant_id: str, + app_id: str, + node_id: str, + account_id: str, + source_agent_id: str, + source_snapshot_id: str | None = None, + idempotency_key: str | None = None, + ) -> dict[str, Any]: + workflow = cls._get_draft_workflow(tenant_id=tenant_id, app_id=app_id) + binding = cls._require_binding( + cls._get_workflow_binding(tenant_id=tenant_id, workflow_id=workflow.id, node_id=node_id) + ) + + if binding.binding_type == WorkflowAgentBindingType.INLINE_AGENT and idempotency_key: + agent = cls._get_agent_if_present(tenant_id=tenant_id, agent_id=binding.agent_id) + version = cls._get_version_if_present( + tenant_id=tenant_id, + agent_id=agent.id if agent else None, + version_id=binding.current_snapshot_id, + ) + return cls._serialize_workflow_state(binding=binding, agent=agent, version=version) + + if binding.binding_type != WorkflowAgentBindingType.ROSTER_AGENT: + raise InvalidComposerConfigError("Workflow agent node must be bound to a roster agent.") + if binding.agent_id != source_agent_id: + raise InvalidComposerConfigError("Source agent does not match the current workflow node binding.") + + source_agent = cls._require_agent(tenant_id=tenant_id, agent_id=source_agent_id) + if source_agent.scope != AgentScope.ROSTER or source_agent.status != AgentStatus.ACTIVE: + raise InvalidComposerConfigError("Source agent must be an active roster agent.") + source_version = cls._require_version( + tenant_id=tenant_id, + agent_id=source_agent.id, + version_id=source_agent.active_config_snapshot_id, + ) + if source_snapshot_id and source_snapshot_id != source_version.id: + raise AgentVersionConflictError() + + agent_soul = AgentSoulConfig.model_validate(source_version.config_snapshot_dict) + inline_agent = cls._create_workflow_only_agent( + tenant_id=tenant_id, + app_id=app_id, + workflow_id=workflow.id, + node_id=node_id, + account_id=account_id, + agent_soul=agent_soul, + name=source_agent.name, + description=source_agent.description, + role=source_agent.role, + icon_type=source_agent.icon_type, + icon=source_agent.icon, + icon_background=source_agent.icon_background, + ) + cls._copy_agent_drive_rows( + tenant_id=tenant_id, + source_agent_id=source_agent.id, + target_agent_id=inline_agent.id, + account_id=account_id, + agent_soul=agent_soul, + node_job=WorkflowNodeJobConfig.model_validate(binding.node_job_config_dict), + ) + + binding.binding_type = WorkflowAgentBindingType.INLINE_AGENT + binding.agent_id = inline_agent.id + binding.current_snapshot_id = inline_agent.active_config_snapshot_id + binding.updated_by = account_id + db.session.flush() + db.session.commit() + + version = cls._require_version( + tenant_id=tenant_id, + agent_id=inline_agent.id, + version_id=inline_agent.active_config_snapshot_id, + ) + return cls._serialize_workflow_state(binding=binding, agent=inline_agent, version=version) + @classmethod def load_agent_app_composer(cls, *, tenant_id: str, app_id: str) -> dict[str, Any]: agent = db.session.scalar( @@ -849,6 +934,11 @@ class AgentComposerService: tenant_id=tenant_id, account_id=account_id, name=agent_name, + description=payload.description or "", + role=payload.role or "", + icon_type=payload.icon_type, + icon=payload.icon, + icon_background=payload.icon_background, agent_soul=payload.agent_soul, operation=AgentConfigRevisionOperation.SAVE_NEW_AGENT, version_note=payload.version_note, @@ -894,10 +984,25 @@ class AgentComposerService: tenant_id=tenant_id, account_id=account_id, name=agent_name, + description=payload.description if payload.description is not None else source_agent.description, + role=payload.role if payload.role is not None else source_agent.role, + icon_type=payload.icon_type if payload.icon_type is not None else source_agent.icon_type, + icon=payload.icon if payload.icon is not None else source_agent.icon, + icon_background=payload.icon_background + if payload.icon_background is not None + else source_agent.icon_background, agent_soul=agent_soul, operation=AgentConfigRevisionOperation.SAVE_TO_ROSTER, version_note=payload.version_note, ) + cls._copy_agent_drive_rows( + tenant_id=tenant_id, + source_agent_id=source_agent.id, + target_agent_id=roster_agent.id, + account_id=account_id, + agent_soul=agent_soul, + node_job=payload.node_job or WorkflowNodeJobConfig.model_validate(binding.node_job_config_dict), + ) binding.binding_type = WorkflowAgentBindingType.ROSTER_AGENT binding.agent_id = roster_agent.id binding.current_snapshot_id = roster_agent.active_config_snapshot_id @@ -916,11 +1021,21 @@ class AgentComposerService: node_id: str, account_id: str, agent_soul: AgentSoulConfig, + name: str | None = None, + description: str = "", + role: str = "", + icon_type: Any | None = None, + icon: str | None = None, + icon_background: str | None = None, ) -> Agent: agent = Agent( tenant_id=tenant_id, - name=f"Workflow Agent {node_id}", - description="", + name=name or f"Workflow Agent {node_id}", + description=description, + role=role, + icon_type=icon_type, + icon=icon, + icon_background=icon_background, agent_kind=AgentKind.DIFY_AGENT, scope=AgentScope.WORKFLOW_ONLY, source=AgentSource.WORKFLOW, @@ -945,6 +1060,98 @@ class AgentComposerService: agent.active_config_has_model = agent_soul_has_model(agent_soul) return agent + @classmethod + def _copy_agent_drive_rows( + cls, + *, + tenant_id: str, + source_agent_id: str, + target_agent_id: str, + account_id: str, + agent_soul: AgentSoulConfig, + node_job: WorkflowNodeJobConfig | None = None, + ) -> None: + exact_keys, prefixes = cls._drive_copy_scopes_from_agent_configs(agent_soul=agent_soul, node_job=node_job) + predicates: list[ColumnElement[bool]] = [] + if exact_keys: + predicates.append(AgentDriveFile.key.in_(sorted(exact_keys))) + predicates.extend(AgentDriveFile.key.startswith(prefix) for prefix in sorted(prefixes)) + if not predicates: + return + + source_rows = list( + db.session.scalars( + select(AgentDriveFile).where( + AgentDriveFile.tenant_id == tenant_id, + AgentDriveFile.agent_id == source_agent_id, + or_(*predicates), + ) + ).all() + ) + if not source_rows: + return + + existing_target_keys = set( + db.session.scalars( + select(AgentDriveFile.key).where( + AgentDriveFile.tenant_id == tenant_id, + AgentDriveFile.agent_id == target_agent_id, + AgentDriveFile.key.in_([row.key for row in source_rows]), + ) + ).all() + ) + for row in source_rows: + if row.key in existing_target_keys: + continue + db.session.add( + AgentDriveFile( + tenant_id=tenant_id, + agent_id=target_agent_id, + key=row.key, + file_kind=row.file_kind, + file_id=row.file_id, + value_owned_by_drive=row.value_owned_by_drive, + is_skill=row.is_skill, + skill_metadata=row.skill_metadata, + size=row.size, + hash=row.hash, + mime_type=row.mime_type, + created_by=account_id, + ) + ) + + @staticmethod + def _drive_copy_scopes_from_agent_configs( + *, agent_soul: AgentSoulConfig, node_job: WorkflowNodeJobConfig | None = None + ) -> tuple[set[str], set[str]]: + from services.agent.prompt_mentions import MentionKind, parse_prompt_mentions + from services.agent_drive_service import decode_drive_mention_ref + + exact_keys: set[str] = set() + prefixes: set[str] = set() + + for mention in parse_prompt_mentions(agent_soul.prompt.system_prompt): + if mention.kind not in {MentionKind.SKILL, MentionKind.FILE}: + continue + drive_key = decode_drive_mention_ref(mention.ref_id) + if not drive_key: + continue + if mention.kind == MentionKind.SKILL and "/" in drive_key: + prefixes.add(f"{drive_key.rsplit('/', 1)[0]}/") + else: + exact_keys.add(drive_key) + + if node_job is not None: + for file_ref in node_job.metadata.file_refs or []: + if file_ref.drive_key: + exact_keys.add(file_ref.drive_key) + for output in node_job.declared_outputs: + benchmark_ref = output.check.benchmark_file_ref if output.check and output.check.enabled else None + if benchmark_ref and benchmark_ref.drive_key: + exact_keys.add(benchmark_ref.drive_key) + + return exact_keys, prefixes + @classmethod def _create_roster_agent_for_composer( cls, @@ -955,27 +1162,42 @@ class AgentComposerService: agent_soul: AgentSoulConfig, operation: AgentConfigRevisionOperation, version_note: str | None, + description: str = "", + role: str = "", + icon_type: AgentIconType | None = None, + icon: str | None = None, + icon_background: str | None = None, ) -> Agent: - agent = Agent( - tenant_id=tenant_id, - name=name, - description="", - agent_kind=AgentKind.DIFY_AGENT, - scope=AgentScope.ROSTER, - source=AgentSource.WORKFLOW, - status=AgentStatus.ACTIVE, - created_by=account_id, - updated_by=account_id, - ) - db.session.add(agent) + account = cls._require_account(account_id=account_id) try: - db.session.flush() + app = AppService().create_app( + tenant_id, + CreateAppParams( + name=name, + description=description, + mode="agent", + agent_role=role, + icon_type=icon_type.value if isinstance(icon_type, AgentIconType) else icon_type, + icon=icon, + icon_background=icon_background, + ), + account, + ) except IntegrityError as exc: db.session.rollback() raise AgentNameConflictError() from exc - version = cls._create_config_version( + + agent = AgentRosterService(db.session).get_app_backing_agent(tenant_id=tenant_id, app_id=app.id) + if agent is None: + raise AgentNotFoundError() + + current_snapshot = cls._require_version( tenant_id=tenant_id, agent_id=agent.id, + version_id=agent.active_config_snapshot_id, + ) + version = cls._update_current_version( + current_snapshot=current_snapshot, account_id=account_id, agent_soul=agent_soul, operation=operation, @@ -983,6 +1205,7 @@ class AgentComposerService: ) agent.active_config_snapshot_id = version.id agent.active_config_has_model = agent_soul_has_model(agent_soul) + agent.updated_by = account_id return agent @classmethod @@ -1111,6 +1334,13 @@ class AgentComposerService: raise AgentNotFoundError() return agent + @classmethod + def _require_account(cls, *, account_id: str) -> Account: + account = db.session.get(Account, account_id) + if not account: + raise ValueError("Account not found") + return account + @classmethod def _get_agent_if_present(cls, *, tenant_id: str, agent_id: str | None) -> Agent | None: if not agent_id: diff --git a/api/services/agent/errors.py b/api/services/agent/errors.py index dcc8f69961f..6a1dc6fb628 100644 --- a/api/services/agent/errors.py +++ b/api/services/agent/errors.py @@ -17,6 +17,10 @@ class AgentArchivedError(Conflict): description = "Archived agent cannot be modified." +class AgentVersionConflictError(Conflict): + description = "Agent config version changed. Please reload and try again." + + class AgentSoulLockedError(BadRequest): description = "Agent Soul is locked for this workflow node." diff --git a/api/services/agent/roster_service.py b/api/services/agent/roster_service.py index 6a9d5818647..97d91b50770 100644 --- a/api/services/agent/roster_service.py +++ b/api/services/agent/roster_service.py @@ -837,6 +837,7 @@ class AgentRosterService: if agent.source == AgentSource.AGENT_APP: return { AgentConfigRevisionOperation.SAVE_NEW_VERSION, + AgentConfigRevisionOperation.SAVE_TO_ROSTER, AgentConfigRevisionOperation.RESTORE_VERSION, } return { diff --git a/api/services/agent_app_feature_service.py b/api/services/agent_app_feature_service.py index b8e98653c8e..5fd794bb10f 100644 --- a/api/services/agent_app_feature_service.py +++ b/api/services/agent_app_feature_service.py @@ -69,7 +69,12 @@ class AgentAppFeatureConfigService: @classmethod def update_features( - cls, *, app_model: App, account: Account, config: dict[str, Any], session: scoped_session + cls, + *, + app_model: App, + account: Account, + config: dict[str, Any], + session: scoped_session, ) -> AppModelConfig: """Persist the presentation features as a new app_model_config version. diff --git a/api/services/audio_service.py b/api/services/audio_service.py index a9024eb3bdd..14c5c0111e5 100644 --- a/api/services/audio_service.py +++ b/api/services/audio_service.py @@ -5,11 +5,11 @@ from collections.abc import Generator from typing import cast from flask import Response, stream_with_context +from sqlalchemy.orm import Session, scoped_session from werkzeug.datastructures import FileStorage from constants import AUDIO_EXTENSIONS from core.model_manager import ModelManager -from extensions.ext_database import db from graphon.model_runtime.entities.model_entities import ModelType from models.enums import MessageStatus from models.model import App, AppMode, Message @@ -77,6 +77,8 @@ class AudioService: def transcript_tts( cls, app_model: App, + *, + session: Session | scoped_session, text: str | None = None, voice: str | None = None, end_user: str | None = None, @@ -87,7 +89,7 @@ class AudioService: if voice is None: if app_model.mode in {AppMode.ADVANCED_CHAT, AppMode.WORKFLOW}: if is_draft: - workflow = WorkflowService().get_draft_workflow(app_model=app_model) + workflow = WorkflowService().get_draft_workflow(app_model=app_model, session=session) else: workflow = app_model.workflow if ( @@ -132,7 +134,7 @@ class AudioService: uuid.UUID(message_id) except ValueError: return None - message = db.session.get(Message, message_id) + message = session.get(Message, message_id) if message is None: return None if message.answer == "" and message.status in {MessageStatus.NORMAL, MessageStatus.PAUSED}: diff --git a/api/services/auth/api_key_auth_service.py b/api/services/auth/api_key_auth_service.py index 36b15170567..42f1d4d8d40 100644 --- a/api/services/auth/api_key_auth_service.py +++ b/api/services/auth/api_key_auth_service.py @@ -2,17 +2,17 @@ import json from typing import Any from sqlalchemy import select +from sqlalchemy.orm import Session from core.helper import encrypter -from extensions.ext_database import db from models.source import DataSourceApiKeyAuthBinding from services.auth.api_key_auth_factory import ApiKeyAuthFactory class ApiKeyAuthService: @staticmethod - def get_provider_auth_list(tenant_id: str): - data_source_api_key_bindings = db.session.scalars( + def get_provider_auth_list(session: Session, tenant_id: str): + data_source_api_key_bindings = session.scalars( select(DataSourceApiKeyAuthBinding).where( DataSourceApiKeyAuthBinding.tenant_id == tenant_id, DataSourceApiKeyAuthBinding.disabled.is_(False) ) @@ -20,7 +20,7 @@ class ApiKeyAuthService: return data_source_api_key_bindings @staticmethod - def create_provider_auth(tenant_id: str, args: dict[str, Any]): + def create_provider_auth(session: Session, tenant_id: str, args: dict[str, Any]): auth_result = ApiKeyAuthFactory(args["provider"], args["credentials"]).validate_credentials() if auth_result: # Encrypt the api key @@ -31,12 +31,12 @@ class ApiKeyAuthService: tenant_id=tenant_id, category=args["category"], provider=args["provider"] ) data_source_api_key_binding.credentials = json.dumps(args["credentials"], ensure_ascii=False) - db.session.add(data_source_api_key_binding) - db.session.commit() + session.add(data_source_api_key_binding) + session.commit() @staticmethod - def get_auth_credentials(tenant_id: str, category: str, provider: str): - data_source_api_key_bindings = db.session.scalar( + def get_auth_credentials(session: Session, tenant_id: str, category: str, provider: str): + data_source_api_key_bindings = session.scalar( select(DataSourceApiKeyAuthBinding).where( DataSourceApiKeyAuthBinding.tenant_id == tenant_id, DataSourceApiKeyAuthBinding.category == category, @@ -52,16 +52,16 @@ class ApiKeyAuthService: return credentials @staticmethod - def delete_provider_auth(tenant_id: str, binding_id: str): - data_source_api_key_binding = db.session.scalar( + def delete_provider_auth(session: Session, tenant_id: str, binding_id: str): + data_source_api_key_binding = session.scalar( select(DataSourceApiKeyAuthBinding).where( DataSourceApiKeyAuthBinding.tenant_id == tenant_id, DataSourceApiKeyAuthBinding.id == binding_id, ) ) if data_source_api_key_binding: - db.session.delete(data_source_api_key_binding) - db.session.commit() + session.delete(data_source_api_key_binding) + session.commit() @classmethod def validate_api_key_auth_args(cls, args): diff --git a/api/services/entities/agent_entities.py b/api/services/entities/agent_entities.py index e7b5cbd7c6d..a8634bceb09 100644 --- a/api/services/entities/agent_entities.py +++ b/api/services/entities/agent_entities.py @@ -42,6 +42,11 @@ class ComposerSavePayload(BaseModel): idempotency_key: str | None = None client_revision_id: str | None = None new_agent_name: str | None = Field(default=None, min_length=1, max_length=255) + description: str | None = None + role: str | None = Field(default=None, max_length=255) + icon_type: AgentIconType | None = None + icon: str | None = Field(default=None, max_length=255) + icon_background: str | None = Field(default=None, max_length=255) @model_validator(mode="after") def validate_variant_sections(self) -> "ComposerSavePayload": @@ -58,6 +63,12 @@ class ComposerSavePayload(BaseModel): return self +class WorkflowComposerCopyFromRosterPayload(BaseModel): + source_agent_id: str = Field(min_length=1, max_length=255) + source_snapshot_id: str | None = Field(default=None, max_length=255) + idempotency_key: str | None = Field(default=None, max_length=255) + + class RosterAgentCreatePayload(BaseModel): name: str = Field(min_length=1, max_length=255) mode: Literal["agent"] = "agent" diff --git a/api/services/feedback_service.py b/api/services/feedback_service.py index 24cfb8aa852..62885c901b7 100644 --- a/api/services/feedback_service.py +++ b/api/services/feedback_service.py @@ -14,8 +14,9 @@ from models.model import Account, App, Conversation, Message, MessageFeedback class FeedbackService: @staticmethod def export_feedbacks( - session: Session, app_id: str, + *, + session: Session, from_source: str | None = None, rating: str | None = None, has_comment: bool | None = None, @@ -28,6 +29,7 @@ class FeedbackService: Args: app_id: Application ID + session: Database session used to run the export query from_source: Filter by feedback source ('user' or 'admin') rating: Filter by rating ('like' or 'dislike') has_comment: Only include feedback with comments diff --git a/api/services/knowledge_retrieval_inner_service.py b/api/services/knowledge_retrieval_inner_service.py index fccc81c4a29..8759413f533 100644 --- a/api/services/knowledge_retrieval_inner_service.py +++ b/api/services/knowledge_retrieval_inner_service.py @@ -13,11 +13,11 @@ of a separate validation error. """ from sqlalchemy import select +from sqlalchemy.orm import scoped_session from core.rag.entities.metadata_entities import Condition, MetadataFilteringCondition from core.rag.retrieval.dataset_retrieval import DatasetRetrieval from core.workflow.nodes.knowledge_retrieval.retrieval import KnowledgeRetrievalRequest -from extensions.ext_database import db from graphon.model_runtime.utils.encoders import jsonable_encoder from graphon.nodes.llm.entities import ModelConfig from models.dataset import Dataset @@ -38,7 +38,11 @@ from services.errors.knowledge_retrieval import ( class InnerKnowledgeRetrievalService: """Validate inner caller scope and delegate to workflow dataset retrieval.""" - def retrieve(self, request: InnerKnowledgeRetrieveRequest) -> InnerKnowledgeRetrieveResponse: + def retrieve( + self, + request: InnerKnowledgeRetrieveRequest, + session: scoped_session, + ) -> InnerKnowledgeRetrieveResponse: """Run tenant-scoped retrieval for a trusted internal caller. This method only rejects caller app existence/tenant mismatches and @@ -56,8 +60,8 @@ class InnerKnowledgeRetrievalService: InnerKnowledgeRetrieveDatasetTenantMismatchError: At least one requested dataset is outside the caller tenant. """ - self._validate_caller_app(tenant_id=request.caller.tenant_id, app_id=request.caller.app_id) - self._validate_datasets(tenant_id=request.caller.tenant_id, dataset_ids=request.dataset_ids) + self._validate_caller_app(tenant_id=request.caller.tenant_id, app_id=request.caller.app_id, session=session) + self._validate_datasets(tenant_id=request.caller.tenant_id, dataset_ids=request.dataset_ids, session=session) rag = DatasetRetrieval() results = rag.knowledge_retrieval(request=self._to_rag_request(request)) @@ -66,8 +70,8 @@ class InnerKnowledgeRetrievalService: usage=InnerKnowledgeRetrieveUsage.model_validate(jsonable_encoder(rag.llm_usage)), ) - def _validate_caller_app(self, *, tenant_id: str, app_id: str) -> None: - app = db.session.scalar(select(App).where(App.id == app_id).limit(1)) + def _validate_caller_app(self, *, tenant_id: str, app_id: str, session: scoped_session) -> None: + app = session.scalar(select(App).where(App.id == app_id).limit(1)) if app is None: raise InnerKnowledgeRetrieveAppNotFoundError(f"App '{app_id}' not found") if app.tenant_id != tenant_id: @@ -75,8 +79,8 @@ class InnerKnowledgeRetrievalService: f"App '{app_id}' does not belong to tenant '{tenant_id}'" ) - def _validate_datasets(self, *, tenant_id: str, dataset_ids: list[str]) -> None: - datasets = db.session.scalars(select(Dataset).where(Dataset.id.in_(dataset_ids))).all() + def _validate_datasets(self, *, tenant_id: str, dataset_ids: list[str], session: scoped_session) -> None: + datasets = session.scalars(select(Dataset).where(Dataset.id.in_(dataset_ids))).all() found_ids = {dataset.id for dataset in datasets} missing_ids = sorted(set(dataset_ids) - found_ids) diff --git a/api/services/metadata_service.py b/api/services/metadata_service.py index f9dcfd25c7f..d9cd65b2b39 100644 --- a/api/services/metadata_service.py +++ b/api/services/metadata_service.py @@ -2,9 +2,9 @@ import copy import logging from sqlalchemy import delete, func, select +from sqlalchemy.orm import Session from core.rag.index_processor.constant.built_in_field import BuiltInField, MetadataDataSource -from extensions.ext_database import db from extensions.ext_redis import redis_client from libs.datetime_utils import naive_utc_now from libs.login import resolve_account_fallback @@ -23,6 +23,7 @@ logger = logging.getLogger(__name__) class MetadataService: @staticmethod def create_metadata( + session: Session, dataset_id: str, metadata_args: MetadataArgs, current_user: Account | None = None, # TODO: the service_api is not migrated yet @@ -33,7 +34,7 @@ class MetadataService: raise ValueError("Metadata name cannot exceed 255 characters.") current_user, current_tenant_id = resolve_account_fallback(current_user, current_tenant_id) # check if metadata name already exists - if db.session.scalar( + if session.scalar( select(DatasetMetadata) .where( DatasetMetadata.tenant_id == current_tenant_id, @@ -53,12 +54,13 @@ class MetadataService: name=metadata_args.name, created_by=current_user.id, ) - db.session.add(metadata) - db.session.commit() + session.add(metadata) + session.commit() return metadata @staticmethod def update_metadata_name( + session: Session, dataset_id: str, metadata_id: str, name: str, @@ -72,7 +74,7 @@ class MetadataService: lock_key = f"dataset_metadata_lock_{dataset_id}" # check if metadata name already exists current_user, current_tenant_id = resolve_account_fallback(current_user, current_tenant_id) - if db.session.scalar( + if session.scalar( select(DatasetMetadata) .where( DatasetMetadata.tenant_id == current_tenant_id, @@ -87,7 +89,7 @@ class MetadataService: raise ValueError("Metadata name already exists in Built-in fields.") try: MetadataService.knowledge_base_metadata_lock_check(dataset_id, None) - metadata = db.session.scalar( + metadata = session.scalar( select(DatasetMetadata) .where(DatasetMetadata.id == metadata_id, DatasetMetadata.dataset_id == dataset_id) .limit(1) @@ -100,7 +102,7 @@ class MetadataService: metadata.updated_at = naive_utc_now() # update related documents - dataset_metadata_bindings = db.session.scalars( + dataset_metadata_bindings = session.scalars( select(DatasetMetadataBinding).where(DatasetMetadataBinding.metadata_id == metadata_id) ).all() if dataset_metadata_bindings: @@ -114,8 +116,8 @@ class MetadataService: value = doc_metadata.pop(old_name, None) doc_metadata[name] = value document.doc_metadata = doc_metadata - db.session.add(document) - db.session.commit() + session.add(document) + session.commit() return metadata except Exception: logger.exception("Update metadata name failed") @@ -124,21 +126,21 @@ class MetadataService: redis_client.delete(lock_key) @staticmethod - def delete_metadata(dataset_id: str, metadata_id: str): + def delete_metadata(session: Session, dataset_id: str, metadata_id: str): lock_key = f"dataset_metadata_lock_{dataset_id}" try: MetadataService.knowledge_base_metadata_lock_check(dataset_id, None) - metadata = db.session.scalar( + metadata = session.scalar( select(DatasetMetadata) .where(DatasetMetadata.id == metadata_id, DatasetMetadata.dataset_id == dataset_id) .limit(1) ) if metadata is None: raise ValueError("Metadata not found.") - db.session.delete(metadata) + session.delete(metadata) # deal related documents - dataset_metadata_bindings = db.session.scalars( + dataset_metadata_bindings = session.scalars( select(DatasetMetadataBinding).where(DatasetMetadataBinding.metadata_id == metadata_id) ).all() if dataset_metadata_bindings: @@ -151,8 +153,8 @@ class MetadataService: doc_metadata = copy.deepcopy(document.doc_metadata) doc_metadata.pop(metadata.name, None) document.doc_metadata = doc_metadata - db.session.add(document) - db.session.commit() + session.add(document) + session.commit() return metadata except Exception: logger.exception("Delete metadata failed") @@ -170,13 +172,13 @@ class MetadataService: ] @staticmethod - def enable_built_in_field(dataset: Dataset): + def enable_built_in_field(session: Session, dataset: Dataset): if dataset.built_in_field_enabled: return lock_key = f"dataset_metadata_lock_{dataset.id}" try: MetadataService.knowledge_base_metadata_lock_check(dataset.id, None) - db.session.add(dataset) + session.add(dataset) documents = DocumentService.get_working_documents_by_dataset_id(dataset.id) if documents: for document in documents: @@ -190,22 +192,22 @@ class MetadataService: doc_metadata[BuiltInField.last_update_date] = document.last_update_date.timestamp() doc_metadata[BuiltInField.source] = MetadataDataSource[document.data_source_type] document.doc_metadata = doc_metadata - db.session.add(document) + session.add(document) dataset.built_in_field_enabled = True - db.session.commit() + session.commit() except Exception: logger.exception("Enable built-in field failed") finally: redis_client.delete(lock_key) @staticmethod - def disable_built_in_field(dataset: Dataset): + def disable_built_in_field(session: Session, dataset: Dataset): if not dataset.built_in_field_enabled: return lock_key = f"dataset_metadata_lock_{dataset.id}" try: MetadataService.knowledge_base_metadata_lock_check(dataset.id, None) - db.session.add(dataset) + session.add(dataset) documents = DocumentService.get_working_documents_by_dataset_id(dataset.id) document_ids = [] if documents: @@ -220,10 +222,10 @@ class MetadataService: doc_metadata.pop(BuiltInField.last_update_date, None) doc_metadata.pop(BuiltInField.source, None) document.doc_metadata = doc_metadata - db.session.add(document) + session.add(document) document_ids.append(document.id) dataset.built_in_field_enabled = False - db.session.commit() + session.commit() except Exception: logger.exception("Disable built-in field failed") finally: @@ -231,6 +233,7 @@ class MetadataService: @staticmethod def update_documents_metadata( + session: Session, dataset: Dataset, metadata_args: MetadataOperationData, current_user: Account | None = None, # TODO: the service_api is not migrated yet @@ -259,11 +262,11 @@ class MetadataService: doc_metadata[BuiltInField.last_update_date] = document.last_update_date.timestamp() doc_metadata[BuiltInField.source] = MetadataDataSource[document.data_source_type] document.doc_metadata = doc_metadata - db.session.add(document) + session.add(document) # deal metadata binding (in the same transaction as the doc_metadata update) if not operation.partial_update: - db.session.execute( + session.execute( delete(DatasetMetadataBinding).where( DatasetMetadataBinding.document_id == operation.document_id ) @@ -272,7 +275,7 @@ class MetadataService: for metadata_value in operation.metadata_list: # check if binding already exists if operation.partial_update: - existing_binding = db.session.scalar( + existing_binding = session.scalar( select(DatasetMetadataBinding) .where( DatasetMetadataBinding.document_id == operation.document_id, @@ -290,10 +293,10 @@ class MetadataService: metadata_id=metadata_value.id, created_by=current_user.id, ) - db.session.add(dataset_metadata_binding) - db.session.commit() + session.add(dataset_metadata_binding) + session.commit() except Exception: - db.session.rollback() + session.rollback() logger.exception("Update documents metadata failed") raise finally: @@ -313,14 +316,14 @@ class MetadataService: redis_client.set(lock_key, 1, ex=3600) @staticmethod - def get_dataset_metadatas(dataset: Dataset): + def get_dataset_metadatas(session: Session, dataset: Dataset): return { "doc_metadata": [ { "id": item.get("id"), "name": item.get("name"), "type": item.get("type"), - "count": db.session.scalar( + "count": session.scalar( select(func.count(DatasetMetadataBinding.id)).where( DatasetMetadataBinding.metadata_id == item.get("id"), DatasetMetadataBinding.dataset_id == dataset.id, diff --git a/api/services/webapp_auth_service.py b/api/services/webapp_auth_service.py index 2b63d9171e9..6ecc8eb8bc9 100644 --- a/api/services/webapp_auth_service.py +++ b/api/services/webapp_auth_service.py @@ -35,7 +35,7 @@ class WebAppAuthService: @staticmethod def authenticate(email: str, password: str) -> Account: """authenticate account with email and password""" - account = AccountService.get_account_by_email_with_case_fallback(email) + account = AccountService.get_account_by_email_with_case_fallback(db.session, email) if not account: raise AccountNotFoundError() @@ -55,7 +55,7 @@ class WebAppAuthService: @classmethod def get_user_through_email(cls, email: str): - account = AccountService.get_account_by_email_with_case_fallback(email) + account = AccountService.get_account_by_email_with_case_fallback(db.session, email) if not account: return None diff --git a/api/services/workflow_service.py b/api/services/workflow_service.py index 9f8e4b83093..262ccc18f83 100644 --- a/api/services/workflow_service.py +++ b/api/services/workflow_service.py @@ -6,7 +6,7 @@ from collections.abc import Callable, Generator, Mapping, Sequence from typing import Any, cast from sqlalchemy import exists, select -from sqlalchemy.orm import Session, sessionmaker +from sqlalchemy.orm import Session, scoped_session, sessionmaker from configs import dify_config from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfigManager @@ -142,7 +142,7 @@ class WorkflowService: return db.session.execute(stmt).scalar_one() def get_draft_workflow( - self, app_model: App, workflow_id: str | None = None, session: Session | None = None + self, app_model: App, workflow_id: str | None = None, session: Session | scoped_session | None = None ) -> Workflow | None: """ Get draft workflow @@ -169,7 +169,7 @@ class WorkflowService: return workflow def get_published_workflow_by_id( - self, app_model: App, workflow_id: str, session: Session | None = None + self, app_model: App, workflow_id: str, session: Session | scoped_session | None = None ) -> Workflow | None: """ fetch published workflow by workflow_id diff --git a/api/tests/test_containers_integration_tests/controllers/console/auth/test_data_source_bearer_auth.py b/api/tests/test_containers_integration_tests/controllers/console/auth/test_data_source_bearer_auth.py index 5eb9f71e695..e55b46d38bf 100644 --- a/api/tests/test_containers_integration_tests/controllers/console/auth/test_data_source_bearer_auth.py +++ b/api/tests/test_containers_integration_tests/controllers/console/auth/test_data_source_bearer_auth.py @@ -1,7 +1,7 @@ """Controller integration tests for API key data source auth routes.""" import json -from unittest.mock import patch +from unittest.mock import ANY, patch from flask.testing import FlaskClient from sqlalchemy import select @@ -85,7 +85,7 @@ def test_create_binding_successful( assert response.status_code == 200 assert response.get_json() == {"result": "success"} - create_auth.assert_called_once_with(tenant_id, payload) + create_auth.assert_called_once_with(ANY, tenant_id, payload) def test_create_binding_failure( diff --git a/api/tests/test_containers_integration_tests/controllers/console/auth/test_email_register.py b/api/tests/test_containers_integration_tests/controllers/console/auth/test_email_register.py index bb7921a5f45..109332e16c9 100644 --- a/api/tests/test_containers_integration_tests/controllers/console/auth/test_email_register.py +++ b/api/tests/test_containers_integration_tests/controllers/console/auth/test_email_register.py @@ -270,10 +270,7 @@ def test_get_account_by_email_with_case_fallback_falls_back_to_lowercase(): second_result.scalar_one_or_none.return_value = expected_account mock_session.execute.side_effect = [first_result, second_result] - with patch("services.account_service.session_factory") as mock_factory: - mock_factory.create_session.return_value.__enter__ = MagicMock(return_value=mock_session) - mock_factory.create_session.return_value.__exit__ = MagicMock(return_value=False) - result = AccountService.get_account_by_email_with_case_fallback("Case@Test.com") + result = AccountService.get_account_by_email_with_case_fallback(mock_session, "Case@Test.com") assert result is expected_account assert mock_session.execute.call_count == 2 diff --git a/api/tests/test_containers_integration_tests/controllers/console/auth/test_forgot_password.py b/api/tests/test_containers_integration_tests/controllers/console/auth/test_forgot_password.py index 014c1588fee..812aa299c1b 100644 --- a/api/tests/test_containers_integration_tests/controllers/console/auth/test_forgot_password.py +++ b/api/tests/test_containers_integration_tests/controllers/console/auth/test_forgot_password.py @@ -165,10 +165,7 @@ def test_get_account_by_email_with_case_fallback_falls_back_to_lowercase(): second_result.scalar_one_or_none.return_value = expected_account mock_session.execute.side_effect = [first_result, second_result] - with patch("services.account_service.session_factory") as mock_factory: - mock_factory.create_session.return_value.__enter__ = MagicMock(return_value=mock_session) - mock_factory.create_session.return_value.__exit__ = MagicMock(return_value=False) - result = AccountService.get_account_by_email_with_case_fallback("Mixed@Test.com") + result = AccountService.get_account_by_email_with_case_fallback(mock_session, "Mixed@Test.com") assert result is expected_account assert mock_session.execute.call_count == 2 diff --git a/api/tests/test_containers_integration_tests/controllers/console/auth/test_oauth.py b/api/tests/test_containers_integration_tests/controllers/console/auth/test_oauth.py index d043c0d413a..d87afb87669 100644 --- a/api/tests/test_containers_integration_tests/controllers/console/auth/test_oauth.py +++ b/api/tests/test_containers_integration_tests/controllers/console/auth/test_oauth.py @@ -494,10 +494,7 @@ class TestAccountGeneration: second_result.scalar_one_or_none.return_value = expected_account mock_session.execute.side_effect = [first_result, second_result] - with patch("services.account_service.session_factory") as mock_factory: - mock_factory.create_session.return_value.__enter__ = MagicMock(return_value=mock_session) - mock_factory.create_session.return_value.__exit__ = MagicMock(return_value=False) - result = AccountService.get_account_by_email_with_case_fallback("Case@Test.com") + result = AccountService.get_account_by_email_with_case_fallback(mock_session, "Case@Test.com") assert result is expected_account assert mock_session.execute.call_count == 2 diff --git a/api/tests/test_containers_integration_tests/controllers/web/test_web_forgot_password.py b/api/tests/test_containers_integration_tests/controllers/web/test_web_forgot_password.py index 2c6a9902401..d568a1c0b04 100644 --- a/api/tests/test_containers_integration_tests/controllers/web/test_web_forgot_password.py +++ b/api/tests/test_containers_integration_tests/controllers/web/test_web_forgot_password.py @@ -4,7 +4,7 @@ from __future__ import annotations import base64 from types import SimpleNamespace -from unittest.mock import MagicMock, patch +from unittest.mock import ANY, MagicMock, patch import pytest from flask import Flask @@ -57,7 +57,7 @@ class TestForgotPasswordSendEmailApi: response = ForgotPasswordSendEmailApi().post() assert response == {"result": "success", "data": "token-123"} - mock_get_account.assert_called_once_with("User@Example.com") + mock_get_account.assert_called_once_with(ANY, "User@Example.com") mock_send_mail.assert_called_once_with(account=mock_account, email="user@example.com", language="zh-Hans") mock_extract_ip.assert_called_once() mock_rate_limit.assert_called_once_with("127.0.0.1") @@ -177,7 +177,7 @@ class TestForgotPasswordResetApi: response = ForgotPasswordResetApi().post() assert response == {"result": "success"} - mock_get_account.assert_called_once_with("User@Example.com") + mock_get_account.assert_called_once_with(ANY, "User@Example.com") mock_update_account.assert_called_once() mock_revoke_token.assert_called_once_with("token-123") diff --git a/api/tests/test_containers_integration_tests/services/auth/test_api_key_auth_service.py b/api/tests/test_containers_integration_tests/services/auth/test_api_key_auth_service.py index c93e61b2bfb..e2f8c8fc703 100644 --- a/api/tests/test_containers_integration_tests/services/auth/test_api_key_auth_service.py +++ b/api/tests/test_containers_integration_tests/services/auth/test_api_key_auth_service.py @@ -51,7 +51,7 @@ class TestApiKeyAuthService: self._create_binding(db_session_with_containers, tenant_id=tenant_id, category=category, provider=provider) db_session_with_containers.expire_all() - result = ApiKeyAuthService.get_provider_auth_list(tenant_id) + result = ApiKeyAuthService.get_provider_auth_list(db_session_with_containers, tenant_id) assert len(result) >= 1 tenant_results = [r for r in result if r.tenant_id == tenant_id] @@ -61,7 +61,7 @@ class TestApiKeyAuthService: def test_get_provider_auth_list_empty( self, flask_app_with_containers: Flask, db_session_with_containers: Session, tenant_id ): - result = ApiKeyAuthService.get_provider_auth_list(tenant_id) + result = ApiKeyAuthService.get_provider_auth_list(db_session_with_containers, tenant_id) tenant_results = [r for r in result if r.tenant_id == tenant_id] assert tenant_results == [] @@ -74,7 +74,7 @@ class TestApiKeyAuthService: ) db_session_with_containers.expire_all() - result = ApiKeyAuthService.get_provider_auth_list(tenant_id) + result = ApiKeyAuthService.get_provider_auth_list(db_session_with_containers, tenant_id) tenant_results = [r for r in result if r.tenant_id == tenant_id] assert tenant_results == [] @@ -95,7 +95,7 @@ class TestApiKeyAuthService: mock_factory.return_value = mock_auth_instance mock_encrypter.encrypt_token.return_value = "encrypted_test_key_123" - ApiKeyAuthService.create_provider_auth(tenant_id, mock_args) + ApiKeyAuthService.create_provider_auth(db_session_with_containers, tenant_id, mock_args) mock_factory.assert_called_once() mock_auth_instance.validate_credentials.assert_called_once() @@ -118,7 +118,7 @@ class TestApiKeyAuthService: mock_auth_instance.validate_credentials.return_value = False mock_factory.return_value = mock_auth_instance - ApiKeyAuthService.create_provider_auth(tenant_id, mock_args) + ApiKeyAuthService.create_provider_auth(db_session_with_containers, tenant_id, mock_args) db_session_with_containers.expire_all() bindings = db_session_with_containers.query(DataSourceApiKeyAuthBinding).filter_by(tenant_id=tenant_id).all() @@ -142,7 +142,7 @@ class TestApiKeyAuthService: original_key = mock_args["credentials"]["config"]["api_key"] - ApiKeyAuthService.create_provider_auth(tenant_id, mock_args) + ApiKeyAuthService.create_provider_auth(db_session_with_containers, tenant_id, mock_args) assert mock_args["credentials"]["config"]["api_key"] == "encrypted_test_key_123" assert mock_args["credentials"]["config"]["api_key"] != original_key @@ -166,14 +166,14 @@ class TestApiKeyAuthService: ) db_session_with_containers.expire_all() - result = ApiKeyAuthService.get_auth_credentials(tenant_id, category, provider) + result = ApiKeyAuthService.get_auth_credentials(db_session_with_containers, tenant_id, category, provider) assert result == mock_credentials def test_get_auth_credentials_not_found( self, flask_app_with_containers: Flask, db_session_with_containers: Session, tenant_id, category, provider ): - result = ApiKeyAuthService.get_auth_credentials(tenant_id, category, provider) + result = ApiKeyAuthService.get_auth_credentials(db_session_with_containers, tenant_id, category, provider) assert result is None @@ -190,7 +190,7 @@ class TestApiKeyAuthService: ) db_session_with_containers.expire_all() - result = ApiKeyAuthService.get_auth_credentials(tenant_id, category, provider) + result = ApiKeyAuthService.get_auth_credentials(db_session_with_containers, tenant_id, category, provider) assert result == special_credentials assert result["config"]["api_key"] == "key_with_中文_and_special_chars_!@#$%" @@ -204,7 +204,7 @@ class TestApiKeyAuthService: binding_id = binding.id db_session_with_containers.expire_all() - ApiKeyAuthService.delete_provider_auth(tenant_id, binding_id) + ApiKeyAuthService.delete_provider_auth(db_session_with_containers, tenant_id, binding_id) db_session_with_containers.expire_all() remaining = db_session_with_containers.query(DataSourceApiKeyAuthBinding).filter_by(id=binding_id).first() @@ -214,7 +214,7 @@ class TestApiKeyAuthService: self, flask_app_with_containers: Flask, db_session_with_containers: Session, tenant_id ): # Should not raise when binding not found - ApiKeyAuthService.delete_provider_auth(tenant_id, str(uuid4())) + ApiKeyAuthService.delete_provider_auth(db_session_with_containers, tenant_id, str(uuid4())) def test_validate_api_key_auth_args_success(self, mock_args): ApiKeyAuthService.validate_api_key_auth_args(mock_args) @@ -288,16 +288,16 @@ class TestApiKeyAuthService: mock_factory.return_value = mock_auth_instance mock_encrypter.encrypt_token.return_value = "encrypted_key" - with patch("services.auth.api_key_auth_service.db.session") as mock_session: - mock_session.commit.side_effect = Exception("Database error") - with pytest.raises(Exception, match="Database error"): - ApiKeyAuthService.create_provider_auth(tenant_id, mock_args) + mock_session = MagicMock() + mock_session.commit.side_effect = Exception("Database error") + with pytest.raises(Exception, match="Database error"): + ApiKeyAuthService.create_provider_auth(mock_session, tenant_id, mock_args) @patch("services.auth.api_key_auth_service.ApiKeyAuthFactory") def test_create_provider_auth_factory_exception(self, mock_factory: MagicMock, tenant_id, mock_args): mock_factory.side_effect = Exception("Factory error") with pytest.raises(Exception, match="Factory error"): - ApiKeyAuthService.create_provider_auth(tenant_id, mock_args) + ApiKeyAuthService.create_provider_auth(MagicMock(), tenant_id, mock_args) @patch("services.auth.api_key_auth_service.ApiKeyAuthFactory") @patch("services.auth.api_key_auth_service.encrypter") @@ -307,7 +307,7 @@ class TestApiKeyAuthService: mock_factory.return_value = mock_auth_instance mock_encrypter.encrypt_token.side_effect = Exception("Encryption error") with pytest.raises(Exception, match="Encryption error"): - ApiKeyAuthService.create_provider_auth(tenant_id, mock_args) + ApiKeyAuthService.create_provider_auth(MagicMock(), tenant_id, mock_args) def test_validate_api_key_auth_args_none_input(self): with pytest.raises(TypeError): diff --git a/api/tests/test_containers_integration_tests/services/auth/test_auth_integration.py b/api/tests/test_containers_integration_tests/services/auth/test_auth_integration.py index 1de9ce38a0b..9b86ab41f2b 100644 --- a/api/tests/test_containers_integration_tests/services/auth/test_auth_integration.py +++ b/api/tests/test_containers_integration_tests/services/auth/test_auth_integration.py @@ -13,6 +13,7 @@ import pytest from flask import Flask from sqlalchemy.orm import Session +from extensions.ext_database import db from models.source import DataSourceApiKeyAuthBinding from services.auth.api_key_auth_factory import ApiKeyAuthFactory from services.auth.api_key_auth_service import ApiKeyAuthService @@ -56,7 +57,7 @@ class TestAuthIntegration: mock_encrypt.return_value = "encrypted_fc_test_key_123" args = {"category": category, "provider": AuthType.FIRECRAWL, "credentials": firecrawl_credentials} - ApiKeyAuthService.create_provider_auth(tenant_id_1, args) + ApiKeyAuthService.create_provider_auth(db_session_with_containers, tenant_id_1, args) mock_http.assert_called_once() call_args = mock_http.call_args @@ -100,15 +101,15 @@ class TestAuthIntegration: mock_encrypt.return_value = "encrypted_key" args1 = {"category": category, "provider": AuthType.FIRECRAWL, "credentials": firecrawl_credentials} - ApiKeyAuthService.create_provider_auth(tenant_id_1, args1) + ApiKeyAuthService.create_provider_auth(db_session_with_containers, tenant_id_1, args1) args2 = {"category": category, "provider": AuthType.JINA, "credentials": jina_credentials} - ApiKeyAuthService.create_provider_auth(tenant_id_2, args2) + ApiKeyAuthService.create_provider_auth(db_session_with_containers, tenant_id_2, args2) db_session_with_containers.expire_all() - result1 = ApiKeyAuthService.get_provider_auth_list(tenant_id_1) - result2 = ApiKeyAuthService.get_provider_auth_list(tenant_id_2) + result1 = ApiKeyAuthService.get_provider_auth_list(db_session_with_containers, tenant_id_1) + result2 = ApiKeyAuthService.get_provider_auth_list(db_session_with_containers, tenant_id_2) assert len(result1) == 1 assert result1[0].tenant_id == tenant_id_1 @@ -118,7 +119,9 @@ class TestAuthIntegration: def test_cross_tenant_access_prevention( self, flask_app_with_containers: Flask, db_session_with_containers: Session, tenant_id_2, category ): - result = ApiKeyAuthService.get_auth_credentials(tenant_id_2, category, AuthType.FIRECRAWL) + result = ApiKeyAuthService.get_auth_credentials( + db_session_with_containers, tenant_id_2, category, AuthType.FIRECRAWL + ) assert result is None @@ -160,7 +163,7 @@ class TestAuthIntegration: "provider": AuthType.FIRECRAWL, "credentials": {"auth_type": "bearer", "config": {"api_key": "fc_test_key_123"}}, } - ApiKeyAuthService.create_provider_auth(tenant_id_1, thread_args) + ApiKeyAuthService.create_provider_auth(db.session(), tenant_id_1, thread_args) results.append("success") except Exception as e: exceptions.append(e) @@ -213,7 +216,7 @@ class TestAuthIntegration: args = {"category": category, "provider": AuthType.FIRECRAWL, "credentials": firecrawl_credentials} with pytest.raises(httpx.RequestError): - ApiKeyAuthService.create_provider_auth(tenant_id_1, args) + ApiKeyAuthService.create_provider_auth(db_session_with_containers, tenant_id_1, args) db_session_with_containers.expire_all() bindings = db_session_with_containers.query(DataSourceApiKeyAuthBinding).filter_by(tenant_id=tenant_id_1).all() @@ -250,11 +253,13 @@ class TestAuthIntegration: mock_encrypt.return_value = "encrypted_key" args = {"category": category, "provider": AuthType.FIRECRAWL, "credentials": firecrawl_credentials} - ApiKeyAuthService.create_provider_auth(tenant_id_1, args) + ApiKeyAuthService.create_provider_auth(db_session_with_containers, tenant_id_1, args) db_session_with_containers.expire_all() - result = ApiKeyAuthService.get_auth_credentials(tenant_id_1, category, AuthType.FIRECRAWL) + result = ApiKeyAuthService.get_auth_credentials( + db_session_with_containers, tenant_id_1, category, AuthType.FIRECRAWL + ) assert result is not None assert result["config"]["api_key"] == "encrypted_key" diff --git a/api/tests/test_containers_integration_tests/services/test_audio_service_db.py b/api/tests/test_containers_integration_tests/services/test_audio_service_db.py index 2593b53fe84..c9cf60bcfb1 100644 --- a/api/tests/test_containers_integration_tests/services/test_audio_service_db.py +++ b/api/tests/test_containers_integration_tests/services/test_audio_service_db.py @@ -158,6 +158,7 @@ class TestAudioServiceTranscriptTTSMessageLookup: with patch("services.audio_service.ModelManager.for_tenant", return_value=mock_model_manager): result = AudioService.transcript_tts( app_model=app, + session=db_session_with_containers, message_id=message.id, voice="en-US-Neural", ) @@ -174,6 +175,7 @@ class TestAudioServiceTranscriptTTSMessageLookup: result = AudioService.transcript_tts( app_model=app, + session=db_session_with_containers, message_id="invalid-uuid", ) @@ -185,6 +187,7 @@ class TestAudioServiceTranscriptTTSMessageLookup: result = AudioService.transcript_tts( app_model=app, + session=db_session_with_containers, message_id=str(uuid4()), ) @@ -205,6 +208,7 @@ class TestAudioServiceTranscriptTTSMessageLookup: result = AudioService.transcript_tts( app_model=app, + session=db_session_with_containers, message_id=message.id, ) diff --git a/api/tests/test_containers_integration_tests/services/test_feedback_service.py b/api/tests/test_containers_integration_tests/services/test_feedback_service.py index e4fd81b53e7..c2b0385fc74 100644 --- a/api/tests/test_containers_integration_tests/services/test_feedback_service.py +++ b/api/tests/test_containers_integration_tests/services/test_feedback_service.py @@ -97,8 +97,9 @@ class TestFeedbackService: ) # Test CSV export - result = FeedbackService.export_feedbacks(mock_db_session, app_id=sample_data["app"].id, format_type="csv") - + result = FeedbackService.export_feedbacks( + app_id=sample_data["app"].id, session=mock_db_session, format_type="csv" + ) # Verify response structure assert hasattr(result, "headers") assert "text/csv" in result.headers["Content-Type"] @@ -128,7 +129,9 @@ class TestFeedbackService: ) # Test JSON export - result = FeedbackService.export_feedbacks(mock_db_session, app_id=sample_data["app"].id, format_type="json") + result = FeedbackService.export_feedbacks( + app_id=sample_data["app"].id, session=mock_db_session, format_type="json" + ) # Verify response structure assert hasattr(result, "headers") @@ -158,8 +161,8 @@ class TestFeedbackService: # Test with filters result = FeedbackService.export_feedbacks( - mock_db_session, app_id=sample_data["app"].id, + session=mock_db_session, from_source=FeedbackFromSource.ADMIN, rating=FeedbackRating.DISLIKE, has_comment=True, @@ -175,7 +178,9 @@ class TestFeedbackService: """Test exporting feedback when no data exists.""" mock_db_session.execute.return_value = _execute_result([]) - result = FeedbackService.export_feedbacks(mock_db_session, app_id=sample_data["app"].id, format_type="csv") + result = FeedbackService.export_feedbacks( + app_id=sample_data["app"].id, session=mock_db_session, format_type="csv" + ) # Should return an empty CSV with headers only assert hasattr(result, "headers") @@ -194,13 +199,13 @@ class TestFeedbackService: # Test with invalid start_date with pytest.raises(ValueError, match="Invalid start_date format"): FeedbackService.export_feedbacks( - mock_db_session, app_id=sample_data["app"].id, start_date="invalid-date-format" + app_id=sample_data["app"].id, session=mock_db_session, start_date="invalid-date-format" ) # Test with invalid end_date with pytest.raises(ValueError, match="Invalid end_date format"): FeedbackService.export_feedbacks( - mock_db_session, app_id=sample_data["app"].id, end_date="invalid-date-format" + app_id=sample_data["app"].id, session=mock_db_session, end_date="invalid-date-format" ) def test_export_feedbacks_invalid_format(self, mock_db_session, sample_data): @@ -208,8 +213,8 @@ class TestFeedbackService: with pytest.raises(ValueError, match="Unsupported format"): FeedbackService.export_feedbacks( - mock_db_session, app_id=sample_data["app"].id, + session=mock_db_session, format_type="xml", # Unsupported format ) @@ -239,7 +244,9 @@ class TestFeedbackService: ) # Test export - result = FeedbackService.export_feedbacks(mock_db_session, app_id=sample_data["app"].id, format_type="json") + result = FeedbackService.export_feedbacks( + app_id=sample_data["app"].id, session=mock_db_session, format_type="json" + ) # Check JSON content json_content = json.loads(result.get_data(as_text=True)) @@ -290,7 +297,9 @@ class TestFeedbackService: ) # Test export - result = FeedbackService.export_feedbacks(mock_db_session, app_id=sample_data["app"].id, format_type="csv") + result = FeedbackService.export_feedbacks( + app_id=sample_data["app"].id, session=mock_db_session, format_type="csv" + ) # Check that unicode content is preserved csv_content = result.get_data(as_text=True) @@ -320,7 +329,9 @@ class TestFeedbackService: ) # Test export - result = FeedbackService.export_feedbacks(mock_db_session, app_id=sample_data["app"].id, format_type="json") + result = FeedbackService.export_feedbacks( + app_id=sample_data["app"].id, session=mock_db_session, format_type="json" + ) # Check JSON content for emoji ratings json_content = json.loads(result.get_data(as_text=True)) diff --git a/api/tests/test_containers_integration_tests/services/test_metadata_partial_update.py b/api/tests/test_containers_integration_tests/services/test_metadata_partial_update.py index 5844441e6a5..fbdc265265d 100644 --- a/api/tests/test_containers_integration_tests/services/test_metadata_partial_update.py +++ b/api/tests/test_containers_integration_tests/services/test_metadata_partial_update.py @@ -95,7 +95,7 @@ class TestMetadataPartialUpdate: ) metadata_args = MetadataOperationData(operation_data=[operation]) - MetadataService.update_documents_metadata(dataset, metadata_args, current_account) + MetadataService.update_documents_metadata(db_session_with_containers, dataset, metadata_args, current_account) db_session_with_containers.expire_all() updated_doc = db_session_with_containers.get(Document, document.id) @@ -126,7 +126,7 @@ class TestMetadataPartialUpdate: ) metadata_args = MetadataOperationData(operation_data=[operation]) - MetadataService.update_documents_metadata(dataset, metadata_args, current_account) + MetadataService.update_documents_metadata(db_session_with_containers, dataset, metadata_args, current_account) db_session_with_containers.expire_all() updated_doc = db_session_with_containers.get(Document, document.id) @@ -168,7 +168,7 @@ class TestMetadataPartialUpdate: ) metadata_args = MetadataOperationData(operation_data=[operation]) - MetadataService.update_documents_metadata(dataset, metadata_args, current_account) + MetadataService.update_documents_metadata(db_session_with_containers, dataset, metadata_args, current_account) db_session_with_containers.expire_all() bindings = db_session_with_containers.scalars( @@ -202,6 +202,8 @@ class TestMetadataPartialUpdate: ) metadata_args = MetadataOperationData(operation_data=[operation]) - with patch("services.metadata_service.db.session.commit", side_effect=RuntimeError("database connection lost")): + with patch.object(db_session_with_containers, "commit", side_effect=RuntimeError("database connection lost")): with pytest.raises(RuntimeError, match="database connection lost"): - MetadataService.update_documents_metadata(dataset, metadata_args, current_account) + MetadataService.update_documents_metadata( + db_session_with_containers, dataset, metadata_args, current_account + ) diff --git a/api/tests/test_containers_integration_tests/services/test_metadata_service.py b/api/tests/test_containers_integration_tests/services/test_metadata_service.py index 0c9e3830430..7cc9fc7e696 100644 --- a/api/tests/test_containers_integration_tests/services/test_metadata_service.py +++ b/api/tests/test_containers_integration_tests/services/test_metadata_service.py @@ -183,7 +183,9 @@ class TestMetadataService: metadata_args = MetadataArgs(type="string", name="test_metadata") # Act: Execute the method under test - result = MetadataService.create_metadata(dataset.id, metadata_args, account, tenant.id) + result = MetadataService.create_metadata( + db_session_with_containers, dataset.id, metadata_args, account, tenant.id + ) # Assert: Verify the expected outcomes assert result is not None @@ -218,7 +220,7 @@ class TestMetadataService: # Act & Assert: Verify proper error handling with pytest.raises(ValueError, match="Metadata name cannot exceed 255 characters."): - MetadataService.create_metadata(dataset.id, metadata_args, account, tenant.id) + MetadataService.create_metadata(db_session_with_containers, dataset.id, metadata_args, account, tenant.id) def test_create_metadata_name_already_exists( self, db_session_with_containers: Session, mock_external_service_dependencies: MetadataServiceDeps @@ -236,14 +238,16 @@ class TestMetadataService: # Create first metadata first_metadata_args = MetadataArgs(type="string", name="duplicate_name") - MetadataService.create_metadata(dataset.id, first_metadata_args, account, tenant.id) + MetadataService.create_metadata(db_session_with_containers, dataset.id, first_metadata_args, account, tenant.id) # Try to create second metadata with same name second_metadata_args = MetadataArgs(type="number", name="duplicate_name") # Act & Assert: Verify proper error handling with pytest.raises(ValueError, match="Metadata name already exists."): - MetadataService.create_metadata(dataset.id, second_metadata_args, account, tenant.id) + MetadataService.create_metadata( + db_session_with_containers, dataset.id, second_metadata_args, account, tenant.id + ) def test_create_metadata_name_conflicts_with_built_in_field( self, db_session_with_containers: Session, mock_external_service_dependencies: MetadataServiceDeps @@ -265,7 +269,7 @@ class TestMetadataService: # Act & Assert: Verify proper error handling with pytest.raises(ValueError, match="Metadata name already exists in Built-in fields."): - MetadataService.create_metadata(dataset.id, metadata_args, account, tenant.id) + MetadataService.create_metadata(db_session_with_containers, dataset.id, metadata_args, account, tenant.id) def test_update_metadata_name_success( self, db_session_with_containers: Session, mock_external_service_dependencies: MetadataServiceDeps @@ -283,11 +287,15 @@ class TestMetadataService: # Create metadata first metadata_args = MetadataArgs(type="string", name="old_name") - metadata = MetadataService.create_metadata(dataset.id, metadata_args, account, tenant.id) + metadata = MetadataService.create_metadata( + db_session_with_containers, dataset.id, metadata_args, account, tenant.id + ) # Act: Execute the method under test new_name = "new_name" - result = MetadataService.update_metadata_name(dataset.id, metadata.id, new_name, account, tenant.id) + result = MetadataService.update_metadata_name( + db_session_with_containers, dataset.id, metadata.id, new_name, account, tenant.id + ) # Assert: Verify the expected outcomes assert result is not None @@ -316,14 +324,18 @@ class TestMetadataService: # Create metadata first metadata_args = MetadataArgs(type="string", name="old_name") - metadata = MetadataService.create_metadata(dataset.id, metadata_args, account, tenant.id) + metadata = MetadataService.create_metadata( + db_session_with_containers, dataset.id, metadata_args, account, tenant.id + ) # Try to update with too long name long_name = "a" * 256 # 256 characters, exceeding 255 limit # Act & Assert: Verify proper error handling with pytest.raises(ValueError, match="Metadata name cannot exceed 255 characters."): - MetadataService.update_metadata_name(dataset.id, metadata.id, long_name, account, tenant.id) + MetadataService.update_metadata_name( + db_session_with_containers, dataset.id, metadata.id, long_name, account, tenant.id + ) def test_update_metadata_name_already_exists( self, db_session_with_containers: Session, mock_external_service_dependencies: MetadataServiceDeps @@ -341,14 +353,20 @@ class TestMetadataService: # Create two metadata entries first_metadata_args = MetadataArgs(type="string", name="first_metadata") - first_metadata = MetadataService.create_metadata(dataset.id, first_metadata_args, account, tenant.id) + first_metadata = MetadataService.create_metadata( + db_session_with_containers, dataset.id, first_metadata_args, account, tenant.id + ) second_metadata_args = MetadataArgs(type="number", name="second_metadata") - second_metadata = MetadataService.create_metadata(dataset.id, second_metadata_args, account, tenant.id) + second_metadata = MetadataService.create_metadata( + db_session_with_containers, dataset.id, second_metadata_args, account, tenant.id + ) # Try to update first metadata with second metadata's name with pytest.raises(ValueError, match="Metadata name already exists."): - MetadataService.update_metadata_name(dataset.id, first_metadata.id, "second_metadata", account, tenant.id) + MetadataService.update_metadata_name( + db_session_with_containers, dataset.id, first_metadata.id, "second_metadata", account, tenant.id + ) def test_update_metadata_name_conflicts_with_built_in_field( self, db_session_with_containers: Session, mock_external_service_dependencies: MetadataServiceDeps @@ -366,13 +384,17 @@ class TestMetadataService: # Create metadata first metadata_args = MetadataArgs(type="string", name="old_name") - metadata = MetadataService.create_metadata(dataset.id, metadata_args, account, tenant.id) + metadata = MetadataService.create_metadata( + db_session_with_containers, dataset.id, metadata_args, account, tenant.id + ) # Try to update with built-in field name built_in_field_name = BuiltInField.document_name with pytest.raises(ValueError, match="Metadata name already exists in Built-in fields."): - MetadataService.update_metadata_name(dataset.id, metadata.id, built_in_field_name, account, tenant.id) + MetadataService.update_metadata_name( + db_session_with_containers, dataset.id, metadata.id, built_in_field_name, account, tenant.id + ) def test_update_metadata_name_not_found( self, db_session_with_containers: Session, mock_external_service_dependencies: MetadataServiceDeps @@ -395,7 +417,9 @@ class TestMetadataService: new_name = "new_name" # Act: Execute the method under test - result = MetadataService.update_metadata_name(dataset.id, fake_metadata_id, new_name, account, tenant.id) + result = MetadataService.update_metadata_name( + db_session_with_containers, dataset.id, fake_metadata_id, new_name, account, tenant.id + ) # Assert: Verify the method returns None when metadata is not found assert result is None @@ -416,10 +440,12 @@ class TestMetadataService: # Create metadata first metadata_args = MetadataArgs(type="string", name="to_be_deleted") - metadata = MetadataService.create_metadata(dataset.id, metadata_args, account, tenant.id) + metadata = MetadataService.create_metadata( + db_session_with_containers, dataset.id, metadata_args, account, tenant.id + ) # Act: Execute the method under test - result = MetadataService.delete_metadata(dataset.id, metadata.id) + result = MetadataService.delete_metadata(db_session_with_containers, dataset.id, metadata.id) # Assert: Verify the expected outcomes assert result is not None @@ -450,7 +476,7 @@ class TestMetadataService: fake_metadata_id = str(uuid.uuid4()) # Use valid UUID format # Act: Execute the method under test - result = MetadataService.delete_metadata(dataset.id, fake_metadata_id) + result = MetadataService.delete_metadata(db_session_with_containers, dataset.id, fake_metadata_id) # Assert: Verify the method returns None when metadata is not found assert result is None @@ -474,7 +500,9 @@ class TestMetadataService: # Create metadata metadata_args = MetadataArgs(type="string", name="test_metadata") - metadata = MetadataService.create_metadata(dataset.id, metadata_args, account, tenant.id) + metadata = MetadataService.create_metadata( + db_session_with_containers, dataset.id, metadata_args, account, tenant.id + ) # Create metadata binding binding = DatasetMetadataBinding( @@ -494,7 +522,7 @@ class TestMetadataService: db_session_with_containers.commit() # Act: Execute the method under test - result = MetadataService.delete_metadata(dataset.id, metadata.id) + result = MetadataService.delete_metadata(db_session_with_containers, dataset.id, metadata.id) # Assert: Verify the expected outcomes assert result is not None @@ -559,7 +587,7 @@ class TestMetadataService: assert dataset.built_in_field_enabled is False # Act: Execute the method under test - MetadataService.enable_built_in_field(dataset) + MetadataService.enable_built_in_field(db_session_with_containers, dataset) # Assert: Verify the expected outcomes @@ -595,7 +623,7 @@ class TestMetadataService: ]() # Act: Execute the method under test - MetadataService.enable_built_in_field(dataset) + MetadataService.enable_built_in_field(db_session_with_containers, dataset) # Assert: Verify the method returns early without changes db_session_with_containers.refresh(dataset) @@ -621,7 +649,7 @@ class TestMetadataService: ]() # Act: Execute the method under test - MetadataService.enable_built_in_field(dataset) + MetadataService.enable_built_in_field(db_session_with_containers, dataset) # Assert: Verify the expected outcomes @@ -668,7 +696,7 @@ class TestMetadataService: ] # Act: Execute the method under test - MetadataService.disable_built_in_field(dataset) + MetadataService.disable_built_in_field(db_session_with_containers, dataset) # Assert: Verify the expected outcomes db_session_with_containers.refresh(dataset) @@ -700,7 +728,7 @@ class TestMetadataService: ]() # Act: Execute the method under test - MetadataService.disable_built_in_field(dataset) + MetadataService.disable_built_in_field(db_session_with_containers, dataset) # Assert: Verify the method returns early without changes @@ -733,7 +761,7 @@ class TestMetadataService: ]() # Act: Execute the method under test - MetadataService.disable_built_in_field(dataset) + MetadataService.disable_built_in_field(db_session_with_containers, dataset) # Assert: Verify the expected outcomes db_session_with_containers.refresh(dataset) @@ -758,7 +786,9 @@ class TestMetadataService: # Create metadata metadata_args = MetadataArgs(type="string", name="test_metadata") - metadata = MetadataService.create_metadata(dataset.id, metadata_args, account, tenant.id) + metadata = MetadataService.create_metadata( + db_session_with_containers, dataset.id, metadata_args, account, tenant.id + ) # Mock DocumentService.get_document mock_external_service_dependencies["document_service"].get_document.return_value = document @@ -777,7 +807,7 @@ class TestMetadataService: operation_data = MetadataOperationData(operation_data=[operation]) # Act: Execute the method under test - MetadataService.update_documents_metadata(dataset, operation_data, account) + MetadataService.update_documents_metadata(db_session_with_containers, dataset, operation_data, account) # Assert: Verify the expected outcomes @@ -822,7 +852,9 @@ class TestMetadataService: # Create metadata metadata_args = MetadataArgs(type="string", name="test_metadata") - metadata = MetadataService.create_metadata(dataset.id, metadata_args, account, tenant.id) + metadata = MetadataService.create_metadata( + db_session_with_containers, dataset.id, metadata_args, account, tenant.id + ) # Mock DocumentService.get_document mock_external_service_dependencies["document_service"].get_document.return_value = document @@ -841,7 +873,7 @@ class TestMetadataService: operation_data = MetadataOperationData(operation_data=[operation]) # Act: Execute the method under test - MetadataService.update_documents_metadata(dataset, operation_data, account) + MetadataService.update_documents_metadata(db_session_with_containers, dataset, operation_data, account) # Assert: Verify the expected outcomes # Verify document metadata was updated with both custom and built-in fields @@ -869,7 +901,9 @@ class TestMetadataService: # Create metadata metadata_args = MetadataArgs(type="string", name="test_metadata") - metadata = MetadataService.create_metadata(dataset.id, metadata_args, account, tenant.id) + metadata = MetadataService.create_metadata( + db_session_with_containers, dataset.id, metadata_args, account, tenant.id + ) # Create metadata operation data from services.entities.knowledge_entities.knowledge_entities import ( @@ -890,7 +924,7 @@ class TestMetadataService: # Act & Assert: The method should raise ValueError("Document not found.") # because the exception is now re-raised after rollback with pytest.raises(ValueError, match="Document not found"): - MetadataService.update_documents_metadata(dataset, operation_data, account) + MetadataService.update_documents_metadata(db_session_with_containers, dataset, operation_data, account) def test_knowledge_base_metadata_lock_check_dataset_id( self, db_session_with_containers: Session, mock_external_service_dependencies: MetadataServiceDeps @@ -986,7 +1020,9 @@ class TestMetadataService: # Create metadata metadata_args = MetadataArgs(type="string", name="test_metadata") - metadata = MetadataService.create_metadata(dataset.id, metadata_args, account, tenant.id) + metadata = MetadataService.create_metadata( + db_session_with_containers, dataset.id, metadata_args, account, tenant.id + ) # Create document and metadata binding document = self._create_test_document( @@ -1005,7 +1041,7 @@ class TestMetadataService: db_session_with_containers.commit() # Act: Execute the method under test - result = MetadataService.get_dataset_metadatas(dataset) + result = MetadataService.get_dataset_metadatas(db_session_with_containers, dataset) # Assert: Verify the expected outcomes assert result is not None @@ -1045,10 +1081,12 @@ class TestMetadataService: # Create metadata metadata_args = MetadataArgs(type="string", name="test_metadata") - metadata = MetadataService.create_metadata(dataset.id, metadata_args, account, tenant.id) + metadata = MetadataService.create_metadata( + db_session_with_containers, dataset.id, metadata_args, account, tenant.id + ) # Act: Execute the method under test - result = MetadataService.get_dataset_metadatas(dataset) + result = MetadataService.get_dataset_metadatas(db_session_with_containers, dataset) # Assert: Verify the expected outcomes assert result is not None @@ -1077,7 +1115,7 @@ class TestMetadataService: ) # Act: Execute the method under test - result = MetadataService.get_dataset_metadatas(dataset) + result = MetadataService.get_dataset_metadatas(db_session_with_containers, dataset) # Assert: Verify the expected outcomes assert result is not None diff --git a/api/tests/unit_tests/controllers/console/agent/test_agent_controllers.py b/api/tests/unit_tests/controllers/console/agent/test_agent_controllers.py index 27bb75e21f8..3d84f899379 100644 --- a/api/tests/unit_tests/controllers/console/agent/test_agent_controllers.py +++ b/api/tests/unit_tests/controllers/console/agent/test_agent_controllers.py @@ -15,6 +15,7 @@ from controllers.console.agent.composer import ( AgentComposerValidateApi, WorkflowAgentComposerApi, WorkflowAgentComposerCandidatesApi, + WorkflowAgentComposerCopyFromRosterApi, WorkflowAgentComposerImpactApi, WorkflowAgentComposerSaveToRosterApi, WorkflowAgentComposerValidateApi, @@ -1017,6 +1018,58 @@ def test_workflow_composer_get_put_validate_candidates_impact_and_save( )["save_options"] == ["node_job_only"] +def test_workflow_composer_copy_from_roster(app: Flask, monkeypatch: pytest.MonkeyPatch, account_id: str) -> None: + app_model = SimpleNamespace(id="app-1") + captured: dict[str, object] = {} + + def fake_copy_from_roster(**kwargs): + captured.update(kwargs) + return _workflow_composer_response( + binding={ + "id": "binding-1", + "binding_type": "inline_agent", + "agent_id": "inline-agent-1", + "current_snapshot_id": "inline-version-1", + "workflow_id": "workflow-1", + "node_id": kwargs["node_id"], + }, + agent={ + "id": "inline-agent-1", + "name": "Nadia", + "description": "", + "scope": "workflow_only", + "status": "active", + }, + active_config_snapshot={"id": "inline-version-1", "version": 1}, + ) + + monkeypatch.setattr( + composer_controller.AgentComposerService, "copy_workflow_composer_from_roster", fake_copy_from_roster + ) + + with app.test_request_context( + json={ + "source_agent_id": "roster-agent-1", + "source_snapshot_id": "roster-version-1", + "idempotency_key": "copy-1", + } + ): + result = unwrap(WorkflowAgentComposerCopyFromRosterApi.post)( + WorkflowAgentComposerCopyFromRosterApi(), "tenant-1", account_id, app_model, "node-1" + ) + + assert result["binding"]["binding_type"] == "inline_agent" + assert captured == { + "tenant_id": "tenant-1", + "app_id": "app-1", + "node_id": "node-1", + "account_id": account_id, + "source_agent_id": "roster-agent-1", + "source_snapshot_id": "roster-version-1", + "idempotency_key": "copy-1", + } + + def test_workflow_impact_returns_empty_without_version(app: Flask) -> None: payload = {"variant": ComposerVariant.WORKFLOW.value, "save_strategy": ComposerSaveStrategy.NODE_JOB_ONLY.value} diff --git a/api/tests/unit_tests/controllers/console/auth/test_data_source_bearer_auth.py b/api/tests/unit_tests/controllers/console/auth/test_data_source_bearer_auth.py index 7f449bb376e..21d1932f820 100644 --- a/api/tests/unit_tests/controllers/console/auth/test_data_source_bearer_auth.py +++ b/api/tests/unit_tests/controllers/console/auth/test_data_source_bearer_auth.py @@ -3,7 +3,7 @@ from __future__ import annotations from datetime import UTC, datetime from inspect import unwrap from types import SimpleNamespace -from unittest.mock import PropertyMock, patch +from unittest.mock import ANY, PropertyMock, patch from controllers.console import console_ns from controllers.console.auth.data_source_bearer_auth import ( @@ -34,13 +34,16 @@ def test_list_data_source_auth_uses_injected_tenant_id() -> None: updated_at=datetime(2026, 1, 2, tzinfo=UTC), ) - with patch( - "controllers.console.auth.data_source_bearer_auth.ApiKeyAuthService.get_provider_auth_list", - return_value=[binding], - ) as get_provider_auth_list: + with ( + patch("controllers.console.auth.data_source_bearer_auth.db"), + patch( + "controllers.console.auth.data_source_bearer_auth.ApiKeyAuthService.get_provider_auth_list", + return_value=[binding], + ) as get_provider_auth_list, + ): result = method(api, "tenant-1") - get_provider_auth_list.assert_called_once_with("tenant-1") + get_provider_auth_list.assert_called_once_with(ANY, "tenant-1") assert result["sources"][0]["id"] == "binding-1" assert result["sources"][0]["provider"] == "custom" @@ -56,12 +59,13 @@ def test_create_data_source_auth_binding_uses_injected_tenant_id() -> None: with ( _payload_patch(payload), + patch("controllers.console.auth.data_source_bearer_auth.db"), patch("controllers.console.auth.data_source_bearer_auth.ApiKeyAuthService.validate_api_key_auth_args"), patch("controllers.console.auth.data_source_bearer_auth.ApiKeyAuthService.create_provider_auth") as create_auth, ): result, status = method(api, "tenant-1") - create_auth.assert_called_once_with("tenant-1", payload) + create_auth.assert_called_once_with(ANY, "tenant-1", payload) assert result == {"result": "success"} assert status == 200 @@ -70,11 +74,14 @@ def test_delete_data_source_auth_binding_uses_injected_tenant_id() -> None: api = ApiKeyAuthDataSourceBindingDelete() method = unwrap(api.delete) - with patch( - "controllers.console.auth.data_source_bearer_auth.ApiKeyAuthService.delete_provider_auth" - ) as delete_provider_auth: + with ( + patch("controllers.console.auth.data_source_bearer_auth.db"), + patch( + "controllers.console.auth.data_source_bearer_auth.ApiKeyAuthService.delete_provider_auth" + ) as delete_provider_auth, + ): result, status = method(api, "tenant-1", "binding-1") - delete_provider_auth.assert_called_once_with("tenant-1", "binding-1") + delete_provider_auth.assert_called_once_with(ANY, "tenant-1", "binding-1") assert result == "" assert status == 204 diff --git a/api/tests/unit_tests/controllers/console/test_workspace_account.py b/api/tests/unit_tests/controllers/console/test_workspace_account.py index e419428ca66..1600fcda50d 100644 --- a/api/tests/unit_tests/controllers/console/test_workspace_account.py +++ b/api/tests/unit_tests/controllers/console/test_workspace_account.py @@ -692,12 +692,7 @@ def test_get_account_by_email_with_case_fallback_uses_lowercase_lookup(): second.scalar_one_or_none.return_value = expected_account mock_session.execute.side_effect = [first, second] - mock_factory = MagicMock() - mock_factory.create_session.return_value.__enter__ = MagicMock(return_value=mock_session) - mock_factory.create_session.return_value.__exit__ = MagicMock(return_value=False) - - with patch("services.account_service.session_factory", mock_factory): - result = AccountService.get_account_by_email_with_case_fallback("Mixed@Test.com") + result = AccountService.get_account_by_email_with_case_fallback(mock_session, "Mixed@Test.com") assert result is expected_account assert mock_session.execute.call_count == 2 diff --git a/api/tests/unit_tests/controllers/service_api/app/test_audio.py b/api/tests/unit_tests/controllers/service_api/app/test_audio.py index 1cfe152c864..52d050ff55a 100644 --- a/api/tests/unit_tests/controllers/service_api/app/test_audio.py +++ b/api/tests/unit_tests/controllers/service_api/app/test_audio.py @@ -176,6 +176,7 @@ class TestAudioServiceMockedBehavior: result = AudioService.transcript_tts( app_model=mock_app, + session=Mock(), text="Hello world", voice="nova", end_user="user_123", diff --git a/api/tests/unit_tests/controllers/service_api/dataset/test_metadata.py b/api/tests/unit_tests/controllers/service_api/dataset/test_metadata.py index 7a9978e742a..b77c783ae16 100644 --- a/api/tests/unit_tests/controllers/service_api/dataset/test_metadata.py +++ b/api/tests/unit_tests/controllers/service_api/dataset/test_metadata.py @@ -17,7 +17,7 @@ Decorator strategy: import uuid from inspect import unwrap -from unittest.mock import Mock, patch +from unittest.mock import ANY, Mock, patch import pytest from flask import Flask @@ -408,7 +408,7 @@ class TestDatasetMetadataBuiltInFieldAction: assert status == 200 assert response["result"] == "success" - mock_meta_svc.enable_built_in_field.assert_called_once_with(mock_dataset) + mock_meta_svc.enable_built_in_field.assert_called_once_with(ANY, mock_dataset) @patch("controllers.service_api.dataset.metadata.MetadataService") @patch("controllers.service_api.dataset.metadata.DatasetService") @@ -439,7 +439,7 @@ class TestDatasetMetadataBuiltInFieldAction: ) assert status == 200 - mock_meta_svc.disable_built_in_field.assert_called_once_with(mock_dataset) + mock_meta_svc.disable_built_in_field.assert_called_once_with(ANY, mock_dataset) @patch("controllers.service_api.dataset.metadata.DatasetService") def test_action_dataset_not_found( diff --git a/api/tests/unit_tests/core/app/apps/agent_app/test_runtime_request_builder.py b/api/tests/unit_tests/core/app/apps/agent_app/test_runtime_request_builder.py index 0d1483e1b79..4f292d90bb4 100644 --- a/api/tests/unit_tests/core/app/apps/agent_app/test_runtime_request_builder.py +++ b/api/tests/unit_tests/core/app/apps/agent_app/test_runtime_request_builder.py @@ -144,6 +144,22 @@ class TestAgentAppRuntimeRequestBuilder: assert result.redacted_request["composition"]["layers"][-1]["config"]["credentials"] == "[REDACTED]" assert result.metadata["conversation_id"] == "conv-1" + def test_build_normalizes_marketplace_model_plugin_id(self): + soul = _soul_with_model() + soul.model.plugin_id = ( + "langgenius/openai:0.4.2@21195ee1321849e0a7d4b3f6b2fd8c2be23ea6c7182e1b444ecc4c1711b52468" + ) + builder = AgentAppRuntimeRequestBuilder( + credentials_provider=_FakeCredentialsProvider(), + plugin_tools_builder=_NoToolsBuilder(), # type: ignore[arg-type] + ) + + result = builder.build(_ctx(soul)) + + llm = next(layer for layer in result.request.composition.layers if layer.name == "llm") + assert llm.config.plugin_id == "langgenius/openai" + assert llm.config.model_provider == "openai" + def test_build_maps_agent_soul_knowledge_to_knowledge_layer(self): soul = AgentSoulConfig.model_validate( { diff --git a/api/tests/unit_tests/core/workflow/nodes/agent_v2/test_runtime_request_builder.py b/api/tests/unit_tests/core/workflow/nodes/agent_v2/test_runtime_request_builder.py index 78e49769159..ffa7ccdbca7 100644 --- a/api/tests/unit_tests/core/workflow/nodes/agent_v2/test_runtime_request_builder.py +++ b/api/tests/unit_tests/core/workflow/nodes/agent_v2/test_runtime_request_builder.py @@ -189,7 +189,7 @@ def test_normalizes_langgenius_model_provider_for_agent_backend_transport(): context.snapshot.config_snapshot = AgentSoulConfig( prompt={"system_prompt": "You are careful."}, model=AgentSoulModelConfig( - plugin_id="langgenius/openai/openai", + plugin_id="langgenius/openai:0.4.2@21195ee1321849e0a7d4b3f6b2fd8c2be23ea6c7182e1b444ecc4c1711b52468", model_provider="langgenius/openai/openai", model="gpt-test", ), diff --git a/api/tests/unit_tests/services/agent/test_agent_composer_entities.py b/api/tests/unit_tests/services/agent/test_agent_composer_entities.py index e82ba92029b..23988c2ec20 100644 --- a/api/tests/unit_tests/services/agent/test_agent_composer_entities.py +++ b/api/tests/unit_tests/services/agent/test_agent_composer_entities.py @@ -105,6 +105,28 @@ def test_agent_app_soul_allows_app_features_and_variables(): assert payload.agent_soul.app_variables[0].name == "company_name" +def test_composer_save_payload_accepts_new_roster_metadata(): + payload = ComposerSavePayload.model_validate( + { + "variant": ComposerVariant.WORKFLOW, + "save_strategy": ComposerSaveStrategy.SAVE_TO_ROSTER, + "new_agent_name": "Research Agent", + "description": "Finds relevant sources.", + "role": "Research Assistant", + "icon_type": "emoji", + "icon": "search", + "icon_background": "#E0F2FE", + } + ) + + assert payload.new_agent_name == "Research Agent" + assert payload.description == "Finds relevant sources." + assert payload.role == "Research Assistant" + assert payload.icon_type == "emoji" + assert payload.icon == "search" + assert payload.icon_background == "#E0F2FE" + + def test_knowledge_query_mode_uses_stable_backend_enums(): config = AgentSoulConfig.model_validate( { diff --git a/api/tests/unit_tests/services/agent/test_agent_services.py b/api/tests/unit_tests/services/agent/test_agent_services.py index 9ba62d60375..1bac183c39f 100644 --- a/api/tests/unit_tests/services/agent/test_agent_services.py +++ b/api/tests/unit_tests/services/agent/test_agent_services.py @@ -3,6 +3,7 @@ from datetime import UTC, datetime from types import SimpleNamespace import pytest +from sqlalchemy.exc import IntegrityError from core.workflow.nodes.agent_v2.validators import WorkflowAgentNodeValidationError from models.agent import ( @@ -10,6 +11,7 @@ from models.agent import ( AgentConfigRevisionOperation, AgentConfigSnapshot, AgentDebugConversation, + AgentDriveFile, AgentKind, AgentScope, AgentSource, @@ -31,7 +33,12 @@ from services.agent import composer_service, roster_service from services.agent.agent_soul_state import agent_soul_has_model from services.agent.composer_service import AgentComposerService from services.agent.composer_validator import ComposerConfigValidator -from services.agent.errors import InvalidComposerConfigError +from services.agent.errors import ( + AgentNameConflictError, + AgentNotFoundError, + AgentVersionConflictError, + InvalidComposerConfigError, +) from services.agent.roster_service import AgentRosterService from services.agent.workflow_publish_service import WorkflowAgentPublishService from services.app_service import AppListParams, AppService @@ -415,9 +422,34 @@ def test_composer_save_helpers_create_and_rebind_agents(monkeypatch: pytest.Monk fake_session = FakeSession() monkeypatch.setattr(composer_service.db, "session", fake_session) workflow_agent = SimpleNamespace(id="inline-agent-1", active_config_snapshot_id="inline-version-1") - roster_agent = SimpleNamespace(id="roster-agent-1", active_config_snapshot_id="roster-version-1", name="Roster") + roster_agent = SimpleNamespace( + id="roster-agent-1", + active_config_snapshot_id="roster-version-1", + name="Roster", + description="Source description", + role="Source role", + icon_type="emoji", + icon="source", + icon_background="#FFFFFF", + ) + create_roster_calls = [] + copy_drive_calls = [] monkeypatch.setattr(AgentComposerService, "_create_workflow_only_agent", lambda **kwargs: workflow_agent) - monkeypatch.setattr(AgentComposerService, "_create_roster_agent_for_composer", lambda **kwargs: roster_agent) + + def fake_create_roster_agent_for_composer(**kwargs): + create_roster_calls.append(kwargs) + return roster_agent + + monkeypatch.setattr( + AgentComposerService, + "_create_roster_agent_for_composer", + fake_create_roster_agent_for_composer, + ) + monkeypatch.setattr( + AgentComposerService, + "_copy_agent_drive_rows", + lambda **kwargs: copy_drive_calls.append(kwargs), + ) monkeypatch.setattr(AgentComposerService, "_require_agent", lambda **kwargs: roster_agent) monkeypatch.setattr( AgentComposerService, @@ -443,6 +475,11 @@ def test_composer_save_helpers_create_and_rebind_agents(monkeypatch: pytest.Monk "agent_soul": {"prompt": {"system_prompt": "new"}}, "node_job": {"workflow_prompt": "use prior output"}, "new_agent_name": "Copied Agent", + "description": "Copied description", + "role": "Copied role", + "icon_type": "emoji", + "icon": "copied", + "icon_background": "#E0F2FE", } ) existing_binding = WorkflowAgentNodeBinding(agent_id="inline-agent-1", current_snapshot_id="inline-version-1") @@ -500,6 +537,24 @@ def test_composer_save_helpers_create_and_rebind_agents(monkeypatch: pytest.Monk assert new_agent_binding.binding_type == WorkflowAgentBindingType.ROSTER_AGENT assert save_to_roster_binding.agent_id == "roster-agent-1" assert new_version_binding.current_snapshot_id == "new-version-1" + assert create_roster_calls[0]["description"] == "Copied description" + assert create_roster_calls[0]["role"] == "Copied role" + assert create_roster_calls[0]["icon"] == "copied" + assert create_roster_calls[0]["icon_background"] == "#E0F2FE" + assert create_roster_calls[1]["description"] == "Copied description" + assert create_roster_calls[1]["role"] == "Copied role" + assert create_roster_calls[1]["icon"] == "copied" + assert create_roster_calls[1]["icon_background"] == "#E0F2FE" + assert copy_drive_calls == [ + { + "tenant_id": "tenant-1", + "source_agent_id": "roster-agent-1", + "target_agent_id": "roster-agent-1", + "account_id": "account-1", + "agent_soul": payload.agent_soul, + "node_job": payload.node_job, + } + ] def test_node_job_only_updates_inline_agent_soul(monkeypatch: pytest.MonkeyPatch): @@ -715,9 +770,464 @@ def test_node_job_only_rejects_inline_binding_pointing_to_roster_agent(monkeypat ) +def test_copy_workflow_composer_from_roster_creates_inline_agent_and_preserves_node_job( + monkeypatch: pytest.MonkeyPatch, +): + fake_session = FakeSession() + monkeypatch.setattr(composer_service.db, "session", fake_session) + workflow = SimpleNamespace(id="workflow-1") + node_job = WorkflowNodeJobConfig(workflow_prompt="keep this node task") + binding = WorkflowAgentNodeBinding( + tenant_id="tenant-1", + app_id="app-1", + workflow_id="workflow-1", + workflow_version="draft", + node_id="node-1", + binding_type=WorkflowAgentBindingType.ROSTER_AGENT, + agent_id="roster-agent-1", + current_snapshot_id="old-roster-version", + node_job_config=node_job, + ) + roster_agent = Agent( + id="roster-agent-1", + tenant_id="tenant-1", + name="Nadia", + description="Clarification Drafter", + role="Clarifies tenders", + scope=AgentScope.ROSTER, + source=AgentSource.AGENT_APP, + status=AgentStatus.ACTIVE, + active_config_snapshot_id="roster-version-2", + ) + source_version = AgentConfigSnapshot( + id="roster-version-2", + tenant_id="tenant-1", + agent_id="roster-agent-1", + version=2, + config_snapshot='{"prompt":{"system_prompt":"copy me"}}', + ) + inline_agent = Agent( + id="inline-agent-1", + tenant_id="tenant-1", + name="Nadia", + description="Clarification Drafter", + role="Clarifies tenders", + scope=AgentScope.WORKFLOW_ONLY, + source=AgentSource.WORKFLOW, + status=AgentStatus.ACTIVE, + active_config_snapshot_id="inline-version-1", + ) + captured: dict[str, object] = {} + + monkeypatch.setattr(AgentComposerService, "_get_draft_workflow", lambda **kwargs: workflow) + monkeypatch.setattr(AgentComposerService, "_get_workflow_binding", lambda **kwargs: binding) + monkeypatch.setattr(AgentComposerService, "_require_agent", lambda **kwargs: roster_agent) + monkeypatch.setattr(AgentComposerService, "_require_version", lambda **kwargs: source_version) + + def fake_create_workflow_only_agent(**kwargs): + captured["create"] = kwargs + return inline_agent + + def fake_copy_drive_rows(**kwargs): + captured["drive"] = kwargs + + monkeypatch.setattr(AgentComposerService, "_create_workflow_only_agent", fake_create_workflow_only_agent) + monkeypatch.setattr(AgentComposerService, "_copy_agent_drive_rows", fake_copy_drive_rows) + monkeypatch.setattr( + AgentComposerService, + "_serialize_workflow_state", + lambda **kwargs: { + "binding": { + "binding_type": kwargs["binding"].binding_type.value, + "agent_id": kwargs["binding"].agent_id, + "current_snapshot_id": kwargs["binding"].current_snapshot_id, + }, + "node_job": kwargs["binding"].node_job_config_dict, + }, + ) + + state = AgentComposerService.copy_workflow_composer_from_roster( + tenant_id="tenant-1", + app_id="app-1", + node_id="node-1", + account_id="account-1", + source_agent_id="roster-agent-1", + source_snapshot_id="roster-version-2", + ) + + assert state["binding"]["binding_type"] == WorkflowAgentBindingType.INLINE_AGENT.value + assert state["binding"]["agent_id"] == "inline-agent-1" + assert state["node_job"]["workflow_prompt"] == "keep this node task" + assert binding.node_job_config is node_job + create_kwargs = captured["create"] + assert create_kwargs["agent_soul"].prompt.system_prompt == "copy me" + assert create_kwargs["name"] == "Nadia" + assert create_kwargs["role"] == "Clarifies tenders" + drive_kwargs = captured["drive"] + assert drive_kwargs["source_agent_id"] == "roster-agent-1" + assert drive_kwargs["target_agent_id"] == "inline-agent-1" + assert fake_session.commits == 1 + + +def test_copy_workflow_composer_from_roster_rejects_stale_source_snapshot(monkeypatch: pytest.MonkeyPatch): + monkeypatch.setattr(AgentComposerService, "_get_draft_workflow", lambda **kwargs: SimpleNamespace(id="workflow-1")) + monkeypatch.setattr( + AgentComposerService, + "_get_workflow_binding", + lambda **kwargs: WorkflowAgentNodeBinding( + tenant_id="tenant-1", + app_id="app-1", + workflow_id="workflow-1", + workflow_version="draft", + node_id="node-1", + binding_type=WorkflowAgentBindingType.ROSTER_AGENT, + agent_id="roster-agent-1", + current_snapshot_id="roster-version-1", + node_job_config=WorkflowNodeJobConfig(), + ), + ) + roster_agent = Agent( + id="roster-agent-1", + tenant_id="tenant-1", + name="Nadia", + scope=AgentScope.ROSTER, + source=AgentSource.AGENT_APP, + status=AgentStatus.ACTIVE, + active_config_snapshot_id="roster-version-2", + ) + source_version = AgentConfigSnapshot( + id="roster-version-2", + tenant_id="tenant-1", + agent_id="roster-agent-1", + version=2, + config_snapshot='{"prompt":{"system_prompt":"copy me"}}', + ) + monkeypatch.setattr(AgentComposerService, "_require_agent", lambda **kwargs: roster_agent) + monkeypatch.setattr(AgentComposerService, "_require_version", lambda **kwargs: source_version) + + with pytest.raises(AgentVersionConflictError): + AgentComposerService.copy_workflow_composer_from_roster( + tenant_id="tenant-1", + app_id="app-1", + node_id="node-1", + account_id="account-1", + source_agent_id="roster-agent-1", + source_snapshot_id="roster-version-1", + ) + + +def test_copy_workflow_composer_from_roster_is_idempotent_when_already_inline(monkeypatch: pytest.MonkeyPatch): + inline_binding = WorkflowAgentNodeBinding( + tenant_id="tenant-1", + app_id="app-1", + workflow_id="workflow-1", + workflow_version="draft", + node_id="node-1", + binding_type=WorkflowAgentBindingType.INLINE_AGENT, + agent_id="inline-agent-1", + current_snapshot_id="inline-version-1", + ) + inline_agent = Agent( + id="inline-agent-1", + tenant_id="tenant-1", + name="Inline", + scope=AgentScope.WORKFLOW_ONLY, + source=AgentSource.WORKFLOW, + status=AgentStatus.ACTIVE, + active_config_snapshot_id="inline-version-1", + ) + inline_version = AgentConfigSnapshot( + id="inline-version-1", + tenant_id="tenant-1", + agent_id="inline-agent-1", + version=1, + config_snapshot='{"prompt":{"system_prompt":"inline"}}', + ) + monkeypatch.setattr(composer_service.db, "session", FakeSession()) + monkeypatch.setattr(AgentComposerService, "_get_draft_workflow", lambda **kwargs: SimpleNamespace(id="workflow-1")) + monkeypatch.setattr(AgentComposerService, "_get_workflow_binding", lambda **kwargs: inline_binding) + monkeypatch.setattr(AgentComposerService, "_get_agent_if_present", lambda **kwargs: inline_agent) + monkeypatch.setattr(AgentComposerService, "_get_version_if_present", lambda **kwargs: inline_version) + monkeypatch.setattr( + AgentComposerService, + "_serialize_workflow_state", + lambda **kwargs: {"binding_type": kwargs["binding"].binding_type.value}, + ) + + state = AgentComposerService.copy_workflow_composer_from_roster( + tenant_id="tenant-1", + app_id="app-1", + node_id="node-1", + account_id="account-1", + source_agent_id="roster-agent-1", + idempotency_key="same-click", + ) + + assert state == {"binding_type": WorkflowAgentBindingType.INLINE_AGENT.value} + + +@pytest.mark.parametrize( + ("binding_agent_id", "binding_type", "source_scope", "source_status", "expected_message"), + [ + ( + "roster-agent-1", + WorkflowAgentBindingType.INLINE_AGENT, + AgentScope.ROSTER, + AgentStatus.ACTIVE, + "must be bound to a roster agent", + ), + ( + "other-agent", + WorkflowAgentBindingType.ROSTER_AGENT, + AgentScope.ROSTER, + AgentStatus.ACTIVE, + "does not match", + ), + ( + "roster-agent-1", + WorkflowAgentBindingType.ROSTER_AGENT, + AgentScope.WORKFLOW_ONLY, + AgentStatus.ACTIVE, + "must be an active roster agent", + ), + ( + "roster-agent-1", + WorkflowAgentBindingType.ROSTER_AGENT, + AgentScope.ROSTER, + AgentStatus.ARCHIVED, + "must be an active roster agent", + ), + ], +) +def test_copy_workflow_composer_from_roster_rejects_invalid_source_binding( + monkeypatch: pytest.MonkeyPatch, + binding_agent_id: str, + binding_type: WorkflowAgentBindingType, + source_scope: AgentScope, + source_status: AgentStatus, + expected_message: str, +): + binding = WorkflowAgentNodeBinding( + tenant_id="tenant-1", + app_id="app-1", + workflow_id="workflow-1", + workflow_version="draft", + node_id="node-1", + binding_type=binding_type, + agent_id=binding_agent_id, + current_snapshot_id="version-1", + node_job_config=WorkflowNodeJobConfig(), + ) + source_agent = Agent( + id="roster-agent-1", + tenant_id="tenant-1", + name="Source", + scope=source_scope, + source=AgentSource.AGENT_APP, + status=source_status, + active_config_snapshot_id="version-1", + ) + monkeypatch.setattr(AgentComposerService, "_get_draft_workflow", lambda **kwargs: SimpleNamespace(id="workflow-1")) + monkeypatch.setattr(AgentComposerService, "_get_workflow_binding", lambda **kwargs: binding) + monkeypatch.setattr(AgentComposerService, "_require_agent", lambda **kwargs: source_agent) + + with pytest.raises(InvalidComposerConfigError, match=expected_message): + AgentComposerService.copy_workflow_composer_from_roster( + tenant_id="tenant-1", + app_id="app-1", + node_id="node-1", + account_id="account-1", + source_agent_id="roster-agent-1", + ) + + +def test_copy_agent_drive_rows_copies_skill_prefix_and_files(monkeypatch: pytest.MonkeyPatch): + skill_row = AgentDriveFile( + tenant_id="tenant-1", + agent_id="roster-agent-1", + key="tender-analyzer/SKILL.md", + file_kind="tool_file", + file_id="tool-file-1", + value_owned_by_drive=True, + is_skill=True, + skill_metadata='{"name":"Tender Analyzer"}', + size=10, + mime_type="text/markdown", + ) + script_row = AgentDriveFile( + tenant_id="tenant-1", + agent_id="roster-agent-1", + key="tender-analyzer/scripts/run.sh", + file_kind="tool_file", + file_id="tool-file-2", + value_owned_by_drive=True, + size=20, + mime_type="text/x-shellscript", + ) + file_row = AgentDriveFile( + tenant_id="tenant-1", + agent_id="roster-agent-1", + key="files/qna.pdf", + file_kind="upload_file", + file_id="upload-file-1", + value_owned_by_drive=False, + size=30, + mime_type="application/pdf", + ) + fake_session = FakeSession(scalars=[[skill_row, script_row, file_row], []]) + monkeypatch.setattr(composer_service.db, "session", fake_session) + agent_soul = AgentSoulConfig.model_validate( + { + "prompt": { + "system_prompt": "[§skill:tender-analyzer/SKILL.md:Tender Analyzer§]", + }, + } + ) + node_job = WorkflowNodeJobConfig.model_validate( + {"metadata": {"file_refs": [{"name": "qna.pdf", "drive_key": "files/qna.pdf"}]}} + ) + + AgentComposerService._copy_agent_drive_rows( + tenant_id="tenant-1", + source_agent_id="roster-agent-1", + target_agent_id="inline-agent-1", + account_id="account-1", + agent_soul=agent_soul, + node_job=node_job, + ) + + copied = [row for row in fake_session.added if isinstance(row, AgentDriveFile)] + assert [row.key for row in copied] == [ + "tender-analyzer/SKILL.md", + "tender-analyzer/scripts/run.sh", + "files/qna.pdf", + ] + assert {row.agent_id for row in copied} == {"inline-agent-1"} + assert copied[0].file_id == "tool-file-1" + assert copied[0].is_skill is True + assert copied[2].value_owned_by_drive is False + + +def test_copy_agent_drive_rows_skips_when_no_referenced_drive_keys(monkeypatch: pytest.MonkeyPatch): + fake_session = FakeSession() + monkeypatch.setattr(composer_service.db, "session", fake_session) + agent_soul = AgentSoulConfig.model_validate({"prompt": {"system_prompt": "No drive mentions."}}) + + AgentComposerService._copy_agent_drive_rows( + tenant_id="tenant-1", + source_agent_id="roster-agent-1", + target_agent_id="inline-agent-1", + account_id="account-1", + agent_soul=agent_soul, + ) + + assert fake_session.added == [] + + +def test_copy_agent_drive_rows_skips_existing_target_keys(monkeypatch: pytest.MonkeyPatch): + source_row = AgentDriveFile( + tenant_id="tenant-1", + agent_id="roster-agent-1", + key="files/qna.pdf", + file_kind="upload_file", + file_id="upload-file-1", + value_owned_by_drive=False, + size=30, + mime_type="application/pdf", + ) + fake_session = FakeSession(scalars=[[source_row], ["files/qna.pdf"]]) + monkeypatch.setattr(composer_service.db, "session", fake_session) + agent_soul = AgentSoulConfig.model_validate({"prompt": {"system_prompt": "[§file:files/qna.pdf:qna.pdf§]"}}) + + AgentComposerService._copy_agent_drive_rows( + tenant_id="tenant-1", + source_agent_id="roster-agent-1", + target_agent_id="inline-agent-1", + account_id="account-1", + agent_soul=agent_soul, + ) + + assert [row for row in fake_session.added if isinstance(row, AgentDriveFile)] == [] + + +def test_drive_copy_scopes_include_declared_output_benchmark_files(): + agent_soul = AgentSoulConfig.model_validate( + { + "prompt": { + "system_prompt": ( + "[§file:files/source.pdf:source.pdf§] " + "[§knowledge:dataset-1:Docs§] " + "[§skill:tender-analyzer/SKILL.md:Tender Analyzer§]" + ) + }, + } + ) + node_job = WorkflowNodeJobConfig.model_validate( + { + "declared_outputs": [ + { + "name": "qna_report", + "type": "file", + "check": { + "enabled": True, + "prompt": "Compare the generated file with the benchmark.", + "benchmark_file_ref": {"name": "expected.pdf", "drive_key": "files/expected.pdf"}, + }, + }, + { + "name": "summary", + "type": "string", + "check": {"enabled": False, "benchmark_file_ref": {"drive_key": "files/ignored.pdf"}}, + }, + ], + } + ) + + exact_keys, prefixes = AgentComposerService._drive_copy_scopes_from_agent_configs( + agent_soul=agent_soul, + node_job=node_job, + ) + + assert exact_keys == {"files/source.pdf", "files/expected.pdf"} + assert prefixes == {"tender-analyzer/"} + + def test_composer_create_agents_syncs_active_config_has_model(monkeypatch: pytest.MonkeyPatch): fake_session = FakeSession() monkeypatch.setattr(composer_service.db, "session", fake_session) + created_apps = [] + backing_agent = Agent( + id="roster-agent-1", + tenant_id="tenant-1", + name="Ready Agent", + scope=AgentScope.ROSTER, + source=AgentSource.AGENT_APP, + app_id="app-agent-1", + active_config_snapshot_id="empty-version-1", + ) + + class FakeAppService: + def create_app(self, tenant_id, params, account): + created_apps.append((tenant_id, params, account)) + return SimpleNamespace(id="app-agent-1") + + class FakeAgentRosterService: + def __init__(self, session): + self.session = session + + def get_app_backing_agent(self, *, tenant_id, app_id): + assert tenant_id == "tenant-1" + assert app_id == "app-agent-1" + return backing_agent + + monkeypatch.setattr(composer_service, "AppService", FakeAppService) + monkeypatch.setattr(composer_service, "AgentRosterService", FakeAgentRosterService) + monkeypatch.setattr(AgentComposerService, "_require_account", lambda **kwargs: SimpleNamespace(id="account-1")) + monkeypatch.setattr( + AgentComposerService, + "_require_version", + lambda **kwargs: SimpleNamespace(id="empty-version-1", tenant_id="tenant-1", agent_id="roster-agent-1"), + ) monkeypatch.setattr( AgentComposerService, "_create_config_version", @@ -745,6 +1255,81 @@ def test_composer_create_agents_syncs_active_config_has_model(monkeypatch: pytes assert workflow_agent.active_config_has_model is True assert roster_agent.active_config_snapshot_id == "version-with-model" assert roster_agent.active_config_has_model is True + assert roster_agent.source == AgentSource.AGENT_APP + assert roster_agent.app_id == "app-agent-1" + created_tenant_id, created_params, created_account = created_apps[0] + assert created_tenant_id == "tenant-1" + assert created_params.mode == "agent" + assert created_params.name == "Ready Agent" + assert created_account.id == "account-1" + + +def test_composer_require_account(monkeypatch: pytest.MonkeyPatch): + account = SimpleNamespace(id="account-1") + monkeypatch.setattr(composer_service.db, "session", SimpleNamespace(get=lambda model, account_id: account)) + + assert AgentComposerService._require_account(account_id="account-1") is account + + +def test_composer_require_account_raises_when_missing(monkeypatch: pytest.MonkeyPatch): + monkeypatch.setattr(composer_service.db, "session", SimpleNamespace(get=lambda model, account_id: None)) + + with pytest.raises(ValueError, match="Account not found"): + AgentComposerService._require_account(account_id="missing-account") + + +def test_composer_create_roster_agent_rolls_back_name_conflict(monkeypatch: pytest.MonkeyPatch): + fake_session = FakeSession() + monkeypatch.setattr(composer_service.db, "session", fake_session) + + class FakeAppService: + def create_app(self, tenant_id, params, account): + raise IntegrityError("insert apps", params, Exception("duplicate")) + + monkeypatch.setattr(composer_service, "AppService", FakeAppService) + monkeypatch.setattr(AgentComposerService, "_require_account", lambda **kwargs: SimpleNamespace(id="account-1")) + + with pytest.raises(AgentNameConflictError): + AgentComposerService._create_roster_agent_for_composer( + tenant_id="tenant-1", + account_id="account-1", + name="Duplicate Agent", + agent_soul=_agent_soul_with_model(), + operation=AgentConfigRevisionOperation.CREATE_VERSION, + version_note=None, + ) + + assert fake_session.rollbacks == 1 + + +def test_composer_create_roster_agent_raises_when_backing_agent_missing(monkeypatch: pytest.MonkeyPatch): + fake_session = FakeSession() + monkeypatch.setattr(composer_service.db, "session", fake_session) + + class FakeAppService: + def create_app(self, tenant_id, params, account): + return SimpleNamespace(id="app-agent-1") + + class FakeAgentRosterService: + def __init__(self, session): + self.session = session + + def get_app_backing_agent(self, *, tenant_id, app_id): + return None + + monkeypatch.setattr(composer_service, "AppService", FakeAppService) + monkeypatch.setattr(composer_service, "AgentRosterService", FakeAgentRosterService) + monkeypatch.setattr(AgentComposerService, "_require_account", lambda **kwargs: SimpleNamespace(id="account-1")) + + with pytest.raises(AgentNotFoundError): + AgentComposerService._create_roster_agent_for_composer( + tenant_id="tenant-1", + account_id="account-1", + name="Missing Backing Agent", + agent_soul=_agent_soul_with_model(), + operation=AgentConfigRevisionOperation.CREATE_VERSION, + version_note=None, + ) def test_composer_version_helpers_and_lookup_errors(monkeypatch: pytest.MonkeyPatch): @@ -1318,6 +1903,7 @@ def test_agent_app_visible_versions_exclude_draft_saves(): assert agent_app_operations == { AgentConfigRevisionOperation.SAVE_NEW_VERSION, + AgentConfigRevisionOperation.SAVE_TO_ROSTER, AgentConfigRevisionOperation.RESTORE_VERSION, } assert AgentConfigRevisionOperation.SAVE_CURRENT_VERSION not in agent_app_operations diff --git a/api/tests/unit_tests/services/test_account_service.py b/api/tests/unit_tests/services/test_account_service.py index 3b5c6cc9bd6..c748fc0962e 100644 --- a/api/tests/unit_tests/services/test_account_service.py +++ b/api/tests/unit_tests/services/test_account_service.py @@ -1821,7 +1821,7 @@ class TestRegisterService: status=AccountStatus.PENDING, is_setup=True, ) - mock_lookup.assert_called_once_with("newuser@example.com") + mock_lookup.assert_called_once_with(mock_db_dependencies["db"].session, "newuser@example.com") def test_invite_new_member_normalizes_new_account_email( self, mock_db_dependencies, mock_redis_dependencies, mock_task_dependencies @@ -1865,7 +1865,7 @@ class TestRegisterService: status=AccountStatus.PENDING, is_setup=True, ) - mock_lookup.assert_called_once_with(mixed_email) + mock_lookup.assert_called_once_with(mock_db_dependencies["db"].session, mixed_email) mock_check_permission.assert_called_once_with(mock_tenant, mock_inviter, None, "add") mock_create_member.assert_called_once_with( mock_tenant, mock_new_account, mock_db_dependencies["db"].session, "normal" @@ -1923,7 +1923,7 @@ class TestRegisterService: mock_tenant, mock_existing_account, "normal", requires_setup=True ) mock_task_dependencies.delay.assert_called_once() - mock_lookup.assert_called_once_with("existing@example.com") + mock_lookup.assert_called_once_with(mock_db_dependencies["db"].session, "existing@example.com") def test_invite_existing_active_account_requires_acceptance_before_joining( self, mock_db_dependencies, mock_redis_dependencies, mock_task_dependencies diff --git a/api/tests/unit_tests/services/test_audio_service.py b/api/tests/unit_tests/services/test_audio_service.py index 5d148974f87..788a47c5c31 100644 --- a/api/tests/unit_tests/services/test_audio_service.py +++ b/api/tests/unit_tests/services/test_audio_service.py @@ -398,6 +398,7 @@ class TestAudioServiceTTS: # Act result = AudioService.transcript_tts( app_model=app, + session=MagicMock(), text="Hello world", voice="en-US-Neural", end_user="user-123", @@ -432,6 +433,7 @@ class TestAudioServiceTTS: # Act result = AudioService.transcript_tts( app_model=app, + session=MagicMock(), text="Test", ) @@ -465,6 +467,7 @@ class TestAudioServiceTTS: # Act result = AudioService.transcript_tts( app_model=app, + session=MagicMock(), text="Test", ) @@ -496,17 +499,52 @@ class TestAudioServiceTTS: mock_model_instance = MagicMock() mock_model_instance.invoke_tts.return_value = b"draft audio" mock_model_manager.get_default_model_instance.return_value = mock_model_instance + session = MagicMock() # Act result = AudioService.transcript_tts( app_model=app, + session=session, text="Draft test", is_draft=True, ) # Assert assert result == b"draft audio" - mock_workflow_service.get_draft_workflow.assert_called_once_with(app_model=app) + mock_workflow_service.get_draft_workflow.assert_called_once_with(app_model=app, session=session) + + @patch("services.audio_service.ModelManager.for_tenant", autospec=True) + def test_transcript_tts_message_id_uses_provided_session( + self, mock_model_manager_class, factory: AudioServiceTestDataFactory + ): + """Test TTS message lookup uses the injected session.""" + # Arrange + app = factory.create_app_mock(mode=AppMode.CHAT) + message_id = "00000000-0000-0000-0000-000000000001" + message = factory.create_message_mock(message_id=message_id, answer="Message answer") + session = MagicMock() + session.get.return_value = message + + mock_model_manager = mock_model_manager_class.return_value + mock_model_instance = MagicMock() + mock_model_instance.invoke_tts.return_value = b"message audio" + mock_model_manager.get_default_model_instance.return_value = mock_model_instance + + # Act + result = AudioService.transcript_tts( + app_model=app, + session=session, + message_id=message_id, + voice="message-voice", + ) + + # Assert + assert result == b"message audio" + session.get.assert_called_once_with(Message, message_id) + mock_model_instance.invoke_tts.assert_called_once_with( + content_text="Message answer", + voice="message-voice", + ) def test_transcript_tts_raises_error_when_text_missing(self, factory: AudioServiceTestDataFactory): """Test that TTS raises error when text is missing.""" @@ -515,7 +553,7 @@ class TestAudioServiceTTS: # Act & Assert with pytest.raises(ValueError, match="Text is required"): - AudioService.transcript_tts(app_model=app, text=None) + AudioService.transcript_tts(app_model=app, session=MagicMock(), text=None) @patch("services.audio_service.ModelManager.for_tenant", autospec=True) def test_transcript_tts_raises_error_when_no_voices_available( @@ -539,7 +577,7 @@ class TestAudioServiceTTS: # Act & Assert with pytest.raises(ValueError, match="Sorry, no voice available"): - AudioService.transcript_tts(app_model=app, text="Test") + AudioService.transcript_tts(app_model=app, session=MagicMock(), text="Test") class TestAudioServiceTTSVoices: diff --git a/api/tests/unit_tests/services/test_knowledge_retrieval_inner_service.py b/api/tests/unit_tests/services/test_knowledge_retrieval_inner_service.py index 287d787ad70..7a8efe85f13 100644 --- a/api/tests/unit_tests/services/test_knowledge_retrieval_inner_service.py +++ b/api/tests/unit_tests/services/test_knowledge_retrieval_inner_service.py @@ -74,14 +74,14 @@ def _build_source() -> Source: class TestInnerKnowledgeRetrievalService: @patch("services.knowledge_retrieval_inner_service.DatasetRetrieval") - @patch("services.knowledge_retrieval_inner_service.db") - def test_retrieve_maps_multiple_request_and_skips_enable_api_check(self, mock_db, mock_rag_cls): + def test_retrieve_maps_multiple_request_and_skips_enable_api_check(self, mock_rag_cls): request = _build_request() + mock_session = MagicMock() mock_app = MagicMock(id="app-1", tenant_id="tenant-1") dataset_1 = MagicMock(id="dataset-1", tenant_id="tenant-1", enable_api=False) dataset_2 = MagicMock(id="dataset-2", tenant_id="tenant-1", enable_api=True) - mock_db.session.scalar.return_value = mock_app - mock_db.session.scalars.return_value.all.return_value = [dataset_1, dataset_2] + mock_session.scalar.return_value = mock_app + mock_session.scalars.return_value.all.return_value = [dataset_1, dataset_2] rag = MagicMock() rag.knowledge_retrieval.return_value = [_build_source()] @@ -101,7 +101,7 @@ class TestInnerKnowledgeRetrievalService: } mock_rag_cls.return_value = rag - response = InnerKnowledgeRetrievalService().retrieve(request) + response = InnerKnowledgeRetrievalService().retrieve(request, mock_session) rag_request = rag.knowledge_retrieval.call_args.kwargs["request"] assert rag_request.tenant_id == "tenant-1" @@ -127,8 +127,7 @@ class TestInnerKnowledgeRetrievalService: assert response.usage.currency == "USD" @patch("services.knowledge_retrieval_inner_service.DatasetRetrieval") - @patch("services.knowledge_retrieval_inner_service.db") - def test_retrieve_maps_single_request(self, mock_db, mock_rag_cls): + def test_retrieve_maps_single_request(self, mock_rag_cls): request = _build_request( dataset_ids=["dataset-1"], retrieval={ @@ -151,8 +150,9 @@ class TestInnerKnowledgeRetrievalService: }, attachment_ids=[], ) - mock_db.session.scalar.return_value = MagicMock(id="app-1", tenant_id="tenant-1") - mock_db.session.scalars.return_value.all.return_value = [MagicMock(id="dataset-1", tenant_id="tenant-1")] + mock_session = MagicMock() + mock_session.scalar.return_value = MagicMock(id="app-1", tenant_id="tenant-1") + mock_session.scalars.return_value.all.return_value = [MagicMock(id="dataset-1", tenant_id="tenant-1")] rag = MagicMock() rag.knowledge_retrieval.return_value = [] @@ -172,7 +172,7 @@ class TestInnerKnowledgeRetrievalService: } mock_rag_cls.return_value = rag - InnerKnowledgeRetrievalService().retrieve(request) + InnerKnowledgeRetrievalService().retrieve(request, mock_session) rag_request = rag.knowledge_retrieval.call_args.kwargs["request"] assert rag_request.retrieval_mode == "single" @@ -184,35 +184,35 @@ class TestInnerKnowledgeRetrievalService: assert rag_request.metadata_model_config is not None assert rag_request.metadata_model_config.provider == "openai" - @patch("services.knowledge_retrieval_inner_service.db") - def test_retrieve_raises_when_app_missing(self, mock_db): - mock_db.session.scalar.return_value = None + def test_retrieve_raises_when_app_missing(self): + mock_session = MagicMock() + mock_session.scalar.return_value = None with pytest.raises(InnerKnowledgeRetrieveAppNotFoundError): - InnerKnowledgeRetrievalService().retrieve(_build_request()) + InnerKnowledgeRetrievalService().retrieve(_build_request(), mock_session) - @patch("services.knowledge_retrieval_inner_service.db") - def test_retrieve_raises_when_app_belongs_to_other_tenant(self, mock_db): - mock_db.session.scalar.return_value = MagicMock(id="app-1", tenant_id="tenant-2") + def test_retrieve_raises_when_app_belongs_to_other_tenant(self): + mock_session = MagicMock() + mock_session.scalar.return_value = MagicMock(id="app-1", tenant_id="tenant-2") with pytest.raises(InnerKnowledgeRetrieveAppTenantMismatchError): - InnerKnowledgeRetrievalService().retrieve(_build_request()) + InnerKnowledgeRetrievalService().retrieve(_build_request(), mock_session) - @patch("services.knowledge_retrieval_inner_service.db") - def test_retrieve_raises_when_dataset_missing(self, mock_db): - mock_db.session.scalar.return_value = MagicMock(id="app-1", tenant_id="tenant-1") - mock_db.session.scalars.return_value.all.return_value = [MagicMock(id="dataset-1", tenant_id="tenant-1")] + def test_retrieve_raises_when_dataset_missing(self): + mock_session = MagicMock() + mock_session.scalar.return_value = MagicMock(id="app-1", tenant_id="tenant-1") + mock_session.scalars.return_value.all.return_value = [MagicMock(id="dataset-1", tenant_id="tenant-1")] with pytest.raises(InnerKnowledgeRetrieveDatasetNotFoundError): - InnerKnowledgeRetrievalService().retrieve(_build_request()) + InnerKnowledgeRetrievalService().retrieve(_build_request(), mock_session) - @patch("services.knowledge_retrieval_inner_service.db") - def test_retrieve_raises_when_dataset_belongs_to_other_tenant(self, mock_db): - mock_db.session.scalar.return_value = MagicMock(id="app-1", tenant_id="tenant-1") - mock_db.session.scalars.return_value.all.return_value = [ + def test_retrieve_raises_when_dataset_belongs_to_other_tenant(self): + mock_session = MagicMock() + mock_session.scalar.return_value = MagicMock(id="app-1", tenant_id="tenant-1") + mock_session.scalars.return_value.all.return_value = [ MagicMock(id="dataset-1", tenant_id="tenant-1"), MagicMock(id="dataset-2", tenant_id="tenant-2"), ] with pytest.raises(InnerKnowledgeRetrieveDatasetTenantMismatchError): - InnerKnowledgeRetrievalService().retrieve(_build_request()) + InnerKnowledgeRetrievalService().retrieve(_build_request(), mock_session) diff --git a/api/tests/unit_tests/services/test_metadata_bug_complete.py b/api/tests/unit_tests/services/test_metadata_bug_complete.py index 36ea1fac1a4..6792243e9d0 100644 --- a/api/tests/unit_tests/services/test_metadata_bug_complete.py +++ b/api/tests/unit_tests/services/test_metadata_bug_complete.py @@ -48,13 +48,15 @@ class TestMetadataBugCompleteValidation: account = _make_account() # Should crash with TypeError with pytest.raises(TypeError, match="object of type 'NoneType' has no len"): - MetadataService.create_metadata("dataset-123", mock_metadata_args, account, "tenant-123") + MetadataService.create_metadata(Mock(), "dataset-123", mock_metadata_args, account, "tenant-123") # Test update method as well account = _make_account() none_name = cast(str, None) with pytest.raises(TypeError, match="object of type 'NoneType' has no len"): - MetadataService.update_metadata_name("dataset-123", "metadata-456", none_name, account, "tenant-123") + MetadataService.update_metadata_name( + Mock(), "dataset-123", "metadata-456", none_name, account, "tenant-123" + ) def test_3_database_constraints_verification(self) -> None: """Test Layer 3: Verify database model has nullable=False constraints.""" @@ -97,7 +99,7 @@ class TestMetadataBugCompleteValidation: account = _make_account() with pytest.raises(TypeError, match="object of type 'NoneType' has no len"): - MetadataService.create_metadata("dataset-123", mock_metadata_args, account, "tenant-123") + MetadataService.create_metadata(Mock(), "dataset-123", mock_metadata_args, account, "tenant-123") def test_7_end_to_end_validation_layers(self) -> None: """Test all validation layers work together correctly.""" diff --git a/api/tests/unit_tests/services/test_metadata_nullable_bug.py b/api/tests/unit_tests/services/test_metadata_nullable_bug.py index 27570a86f1a..ae93fe5ef51 100644 --- a/api/tests/unit_tests/services/test_metadata_nullable_bug.py +++ b/api/tests/unit_tests/services/test_metadata_nullable_bug.py @@ -37,7 +37,7 @@ class TestMetadataNullableBug: account = _make_account() # This should crash with TypeError when calling len(None) with pytest.raises(TypeError, match="object of type 'NoneType' has no len"): - MetadataService.create_metadata("dataset-123", mock_metadata_args, account, "tenant-123") + MetadataService.create_metadata(Mock(), "dataset-123", mock_metadata_args, account, "tenant-123") def test_metadata_service_update_with_none_name_crashes(self) -> None: """Test that MetadataService.update_metadata_name crashes when name is None.""" @@ -45,7 +45,9 @@ class TestMetadataNullableBug: none_name = cast(str, None) # This should crash with TypeError when calling len(None) with pytest.raises(TypeError, match="object of type 'NoneType' has no len"): - MetadataService.update_metadata_name("dataset-123", "metadata-456", none_name, account, "tenant-123") + MetadataService.update_metadata_name( + Mock(), "dataset-123", "metadata-456", none_name, account, "tenant-123" + ) def test_api_layer_now_uses_pydantic_validation(self) -> None: """Verify that API layer relies on Pydantic validation instead of reqparse.""" diff --git a/dify-agent/src/dify_agent/layers/shell/layer.py b/dify-agent/src/dify_agent/layers/shell/layer.py index 5db17d68499..a8f46d628a6 100644 --- a/dify-agent/src/dify_agent/layers/shell/layer.py +++ b/dify-agent/src/dify_agent/layers/shell/layer.py @@ -18,9 +18,11 @@ side-effecting ``on_context_resume`` attempt fails after issuing shellctl jobs, Agenton still exits ``resource_context()`` but never transitions the layer to ``ACTIVE``. In that failed-enter path, normal suspend/delete hooks do not run, so the enter hook itself must perform best-effort business compensation before -re-raising the failure. Agent Stub env injection uses shellctl's native per-run -``env`` argument for user-visible ``shell.run`` and for trusted server-owned -fixed scripts executed through ``run_remote_script()``. +re-raising the failure. Agent Soul shell env is injected into user-visible +commands and CLI bootstrap commands without persisting a workspace env file. +Agent Stub env injection uses shellctl's native per-run ``env`` argument for +user-visible ``shell.run`` and for trusted server-owned fixed scripts executed +through ``run_remote_script()``. """ from __future__ import annotations @@ -475,7 +477,7 @@ class DifyShellLayer(PydanticAILayer[DifyShellLayerDeps, object, DifyShellLayerC try: client = self._require_client() result = await client.run( - _wrap_user_script(script), + _wrap_user_script(script, self.config), cwd=self._require_workspace_cwd(), env=self._build_user_shell_run_env(), timeout=timeout, @@ -536,9 +538,9 @@ class DifyShellLayer(PydanticAILayer[DifyShellLayerDeps, object, DifyShellLayerC and optional Agent Stub env injection. Unlike model-visible ``shell.run``, this server-owned boundary does not - source ``.dify/env.sh``. That file is user-controlled shell config, so - sourcing it here would let sandbox code clobber trusted Agent Stub env - values before ``dify-agent file upload`` executes. + inject Agent Soul shell env. Keeping the user-controlled shell env out + of this path prevents sandbox code from clobbering trusted Agent Stub + env values before ``dify-agent file upload`` executes. """ env = None if inject_agent_stub_env: @@ -833,16 +835,18 @@ def _workspace_cwd(session_id: str) -> str: def _workspace_bootstrap_script(config: DifyShellLayerConfig) -> str: - """Return the workspace bootstrap script for env + CLI tool declarations.""" - has_bootstrap = bool(config.env or config.secret_refs or config.cli_tools or config.sandbox is not None) - if not has_bootstrap: + """Return the workspace bootstrap script for CLI tool declarations.""" + install_commands = [command for tool in config.cli_tools for command in tool.install_commands] + if not install_commands: return "" - lines: list[str] = [ - "set -eu", - 'mkdir -p ".dify"', - "cat > \".dify/env.sh\" <<'DIFY_ENV_EOF'", - ] + lines: list[str] = ["set -eu", *_shell_config_export_lines(config), *install_commands] + return "\n".join(lines) + + +def _shell_config_export_lines(config: DifyShellLayerConfig) -> list[str]: + """Return ephemeral Agent Soul shell exports for one shellctl command.""" + lines: list[str] = [] for env_var in config.env: lines.append(f"export {env_var.name}={_shquote(env_var.value)}") for secret_ref in config.secret_refs: @@ -860,32 +864,15 @@ def _workspace_bootstrap_script(config: DifyShellLayerConfig) -> str: if config.sandbox.config: sandbox_config = json.dumps(config.sandbox.config, ensure_ascii=True, sort_keys=True) lines.append(f"export DIFY_SANDBOX_CONFIG_JSON={_shquote(sandbox_config)}") - lines.extend( - [ - "DIFY_ENV_EOF", - 'chmod 600 ".dify/env.sh"', - '. ".dify/env.sh"', - ] - ) - for tool in config.cli_tools: - for command in tool.install_commands: - lines.append(command) - return "\n".join(lines) + return lines -def _wrap_user_script(script: str) -> str: - """Source Agent Soul env before executing a model-requested shell command.""" - # TODO: refactor - return "\n".join( - [ - 'if [ -f ".dify/env.sh" ]; then', - " set -a", - ' . ".dify/env.sh"', - " set +a", - "fi", - script, - ] - ) +def _wrap_user_script(script: str, config: DifyShellLayerConfig) -> str: + """Inject Agent Soul env before executing a model-requested shell command.""" + lines = _shell_config_export_lines(config) + if not lines: + return script + return "\n".join([*lines, script]) def _workspace_mkdir_script(*, session_id: str) -> str: diff --git a/dify-agent/tests/local/dify_agent/layers/shell/test_layer.py b/dify-agent/tests/local/dify_agent/layers/shell/test_layer.py index 30352d87c5e..c7d2599b63c 100644 --- a/dify-agent/tests/local/dify_agent/layers/shell/test_layer.py +++ b/dify-agent/tests/local/dify_agent/layers/shell/test_layer.py @@ -3,6 +3,7 @@ from collections.abc import Callable, Mapping import secrets import time from dataclasses import dataclass +from typing import cast import pytest @@ -454,7 +455,6 @@ def test_shell_layer_create_bootstraps_agent_soul_shell_config(monkeypatch: pyte assert 'export GITHUB_TOKEN="${GITHUB_TOKEN:-}"' in script assert "export DIFY_SANDBOX_PROVIDER='independent'" in script assert "export DIFY_SANDBOX_CONFIG_JSON='{\"cpu\": 2}'" in script - assert '. ".dify/env.sh"' in script assert "apt-get install -y ripgrep" in script return _job_result("bootstrap-job", status=JobStatusName.EXITED, done=True, exit_code=0) @@ -489,10 +489,60 @@ def test_shell_layer_create_bootstraps_agent_soul_shell_config(monkeypatch: pyte assert layer.runtime_state.job_ids == ["mkdir-job", "bootstrap-job"] +def test_shell_layer_injects_agent_soul_env_without_workspace_env_file(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr(time, "time", lambda: 0xABC12) + + def token_hex(_nbytes: int) -> str: + return "ff" + + monkeypatch.setattr(secrets, "token_hex", token_hex) + + def run_handler(script: str, cwd: str | None, env: Mapping[str, str] | None, timeout: float) -> JobResult: + del timeout + assert env is None + if cwd is None: + return _job_result("mkdir-job", status=JobStatusName.EXITED, done=True, exit_code=0) + + assert cwd == "~/workspace/abc12ff" + assert "export PROJECT_NAME='demo project'" in script + assert 'export OPENAI_API_KEY="${OPENAI_API_KEY:-}"' in script + assert "export DIFY_SANDBOX_PROVIDER='independent'" in script + assert "export DIFY_SANDBOX_CONFIG_JSON='{\"cpu\": 2}'" in script + assert script.endswith("\npwd") + return _job_result("user-job", status=JobStatusName.EXITED, done=True, exit_code=0) + + client = FakeShellctlClient(run_handler=run_handler) + layer = _shell_layer( + client_factory=lambda _entrypoint: client, + config=DifyShellLayerConfig( + env=[DifyShellEnvVarConfig(name="PROJECT_NAME", value="demo project")], + secret_refs=[DifyShellSecretRefConfig(name="OPENAI_API_KEY", ref="secret-1")], + sandbox=DifyShellSandboxConfig(provider="independent", config={"cpu": 2}), + ), + ) + tools = {tool.name: tool for tool in layer.tools} + + async def scenario() -> None: + async with layer.resource_context(): + await layer.on_context_create() + run_result = cast( + Mapping[str, object], + await tools["shell_run"].function_schema.call( + {"script": "pwd"}, + None, # pyright: ignore[reportArgumentType] + ), + ) + assert run_result["job_id"] == "user-job" + + asyncio.run(scenario()) + + assert [call.cwd for call in client.run_calls] == [None, "~/workspace/abc12ff"] + assert layer.runtime_state.job_ids == ["mkdir-job", "user-job"] + + def test_shell_layer_tools_map_inputs_to_shellctl_calls_and_maintain_offsets() -> None: def run_handler(script: str, cwd: str | None, env: Mapping[str, str] | None, timeout: float) -> JobResult: - assert script.endswith("\npwd") - assert '. ".dify/env.sh"' in script + assert script == "pwd" assert cwd == "~/workspace/abc12ff" assert env is None assert timeout == 2.5 @@ -608,8 +658,7 @@ def test_shell_layer_tools_map_inputs_to_shellctl_calls_and_maintain_offsets() - def test_shell_layer_injects_agent_stub_env_only_for_user_visible_shell_run() -> None: def run_handler(script: str, cwd: str | None, env: Mapping[str, str] | None, timeout: float) -> JobResult: del cwd, timeout - if script.endswith("\npwd"): - assert '. ".dify/env.sh"' in script + if script == "pwd": assert env is not None return _job_result("user-job", status=JobStatusName.EXITED, done=True, exit_code=0) assert env is None @@ -639,8 +688,8 @@ def test_shell_layer_injects_agent_stub_env_only_for_user_visible_shell_run() -> asyncio.run(scenario()) - user_run_call = next(call for call in client.run_calls if call.script.endswith("\npwd")) - internal_run_calls = [call for call in client.run_calls if not call.script.endswith("\npwd")] + user_run_call = next(call for call in client.run_calls if call.script == "pwd") + internal_run_calls = [call for call in client.run_calls if call.script != "pwd"] assert user_run_call.env == { AGENT_STUB_API_BASE_URL_ENV_VAR: "https://agent.example.com/agent-stub", diff --git a/eslint-suppressions.json b/eslint-suppressions.json index 5c1d04ee120..48f94044d8d 100644 --- a/eslint-suppressions.json +++ b/eslint-suppressions.json @@ -3668,11 +3668,6 @@ "count": 1 } }, - "web/app/components/header/account-setting/language-page/__tests__/index.spec.tsx": { - "jsx-a11y/role-has-required-aria-props": { - "count": 1 - } - }, "web/app/components/header/account-setting/members-page/edit-workspace-modal/index.tsx": { "jsx-a11y/no-autofocus": { "count": 1 diff --git a/packages/contracts/generated/api/console/agent/types.gen.ts b/packages/contracts/generated/api/console/agent/types.gen.ts index aa21f2ce651..43119c4f1f4 100644 --- a/packages/contracts/generated/api/console/agent/types.gen.ts +++ b/packages/contracts/generated/api/console/agent/types.gen.ts @@ -134,9 +134,14 @@ export type ComposerSavePayload = { agent_soul?: AgentSoulConfig | null binding?: ComposerBindingPayload | null client_revision_id?: string | null + description?: string | null + icon?: string | null + icon_background?: string | null + icon_type?: AgentIconType | null idempotency_key?: string | null new_agent_name?: string | null node_job?: WorkflowNodeJobConfig | null + role?: string | null save_strategy: ComposerSaveStrategy soul_lock?: ComposerSoulLockPayload variant: ComposerVariant @@ -536,6 +541,8 @@ export type ComposerBindingPayload = { current_snapshot_id?: string | null } +export type AgentIconType = 'emoji' | 'image' | 'link' + export type WorkflowNodeJobConfig = { declared_outputs?: Array human_contacts?: Array @@ -876,8 +883,6 @@ export type LlmMode = 'chat' | 'completion' export type AgentKind = 'dify_agent' -export type AgentIconType = 'emoji' | 'image' | 'link' - export type AgentPublishedReferenceResponse = { app_icon?: string | null app_icon_background?: string | null diff --git a/packages/contracts/generated/api/console/agent/zod.gen.ts b/packages/contracts/generated/api/console/agent/zod.gen.ts index cb4107f2d53..d7f5681ffc4 100644 --- a/packages/contracts/generated/api/console/agent/zod.gen.ts +++ b/packages/contracts/generated/api/console/agent/zod.gen.ts @@ -282,6 +282,13 @@ export const zComposerBindingPayload = z.object({ current_snapshot_id: z.string().nullish(), }) +/** + * AgentIconType + * + * Supported icon storage formats for Agent roster entries. + */ +export const zAgentIconType = z.enum(['emoji', 'image', 'link']) + /** * ComposerSoulLockPayload */ @@ -830,13 +837,6 @@ export const zAgentAppDetailWithSite = z.object({ */ export const zAgentKind = z.enum(['dify_agent']) -/** - * AgentIconType - * - * Supported icon storage formats for Agent roster entries. - */ -export const zAgentIconType = z.enum(['emoji', 'image', 'link']) - /** * AgentPublishedReferenceResponse */ @@ -1876,9 +1876,14 @@ export const zComposerSavePayload = z.object({ agent_soul: zAgentSoulConfig.nullish(), binding: zComposerBindingPayload.nullish(), client_revision_id: z.string().nullish(), + description: z.string().nullish(), + icon: z.string().max(255).nullish(), + icon_background: z.string().max(255).nullish(), + icon_type: zAgentIconType.nullish(), idempotency_key: z.string().nullish(), new_agent_name: z.string().min(1).max(255).nullish(), node_job: zWorkflowNodeJobConfig.nullish(), + role: z.string().max(255).nullish(), save_strategy: zComposerSaveStrategy, soul_lock: zComposerSoulLockPayload.optional(), variant: zComposerVariant, diff --git a/packages/contracts/generated/api/console/apps/orpc.gen.ts b/packages/contracts/generated/api/console/apps/orpc.gen.ts index 7a93572885e..ea72df28458 100644 --- a/packages/contracts/generated/api/console/apps/orpc.gen.ts +++ b/packages/contracts/generated/api/console/apps/orpc.gen.ts @@ -392,6 +392,9 @@ import { zPostAppsByAppIdWorkflowsDraftLoopNodesByNodeIdRunBody, zPostAppsByAppIdWorkflowsDraftLoopNodesByNodeIdRunPath, zPostAppsByAppIdWorkflowsDraftLoopNodesByNodeIdRunResponse, + zPostAppsByAppIdWorkflowsDraftNodesByNodeIdAgentComposerCopyFromRosterBody, + zPostAppsByAppIdWorkflowsDraftNodesByNodeIdAgentComposerCopyFromRosterPath, + zPostAppsByAppIdWorkflowsDraftNodesByNodeIdAgentComposerCopyFromRosterResponse, zPostAppsByAppIdWorkflowsDraftNodesByNodeIdAgentComposerImpactBody, zPostAppsByAppIdWorkflowsDraftNodesByNodeIdAgentComposerImpactPath, zPostAppsByAppIdWorkflowsDraftNodesByNodeIdAgentComposerImpactResponse, @@ -3479,6 +3482,26 @@ export const candidates = { } export const post51 = oc + .route({ + inputStructure: 'detailed', + method: 'POST', + operationId: 'postAppsByAppIdWorkflowsDraftNodesByNodeIdAgentComposerCopyFromRoster', + path: '/apps/{app_id}/workflows/draft/nodes/{node_id}/agent-composer/copy-from-roster', + tags: ['console'], + }) + .input( + z.object({ + body: zPostAppsByAppIdWorkflowsDraftNodesByNodeIdAgentComposerCopyFromRosterBody, + params: zPostAppsByAppIdWorkflowsDraftNodesByNodeIdAgentComposerCopyFromRosterPath, + }), + ) + .output(zPostAppsByAppIdWorkflowsDraftNodesByNodeIdAgentComposerCopyFromRosterResponse) + +export const copyFromRoster = { + post: post51, +} + +export const post52 = oc .route({ inputStructure: 'detailed', method: 'POST', @@ -3495,10 +3518,10 @@ export const post51 = oc .output(zPostAppsByAppIdWorkflowsDraftNodesByNodeIdAgentComposerImpactResponse) export const impact = { - post: post51, + post: post52, } -export const post52 = oc +export const post53 = oc .route({ inputStructure: 'detailed', method: 'POST', @@ -3515,10 +3538,10 @@ export const post52 = oc .output(zPostAppsByAppIdWorkflowsDraftNodesByNodeIdAgentComposerSaveToRosterResponse) export const saveToRoster = { - post: post52, + post: post53, } -export const post53 = oc +export const post54 = oc .route({ inputStructure: 'detailed', method: 'POST', @@ -3535,7 +3558,7 @@ export const post53 = oc .output(zPostAppsByAppIdWorkflowsDraftNodesByNodeIdAgentComposerValidateResponse) export const validate = { - post: post53, + post: post54, } export const get62 = oc @@ -3569,6 +3592,7 @@ export const agentComposer = { get: get62, put: put4, candidates, + copyFromRoster, impact, saveToRoster, validate, @@ -3598,7 +3622,7 @@ export const lastRun = { * * Run draft workflow node */ -export const post54 = oc +export const post55 = oc .route({ description: 'Run draft workflow node', inputStructure: 'detailed', @@ -3617,7 +3641,7 @@ export const post54 = oc .output(zPostAppsByAppIdWorkflowsDraftNodesByNodeIdRunResponse) export const run8 = { - post: post54, + post: post55, } /** @@ -3625,7 +3649,7 @@ export const run8 = { * * Poll for trigger events and execute single node when event arrives */ -export const post55 = oc +export const post56 = oc .route({ description: 'Poll for trigger events and execute single node when event arrives', inputStructure: 'detailed', @@ -3639,7 +3663,7 @@ export const post55 = oc .output(zPostAppsByAppIdWorkflowsDraftNodesByNodeIdTriggerRunResponse) export const run9 = { - post: post55, + post: post56, } export const trigger = { @@ -3699,7 +3723,7 @@ export const nodes7 = { * * Run draft workflow */ -export const post56 = oc +export const post57 = oc .route({ description: 'Run draft workflow', inputStructure: 'detailed', @@ -3718,7 +3742,7 @@ export const post56 = oc .output(zPostAppsByAppIdWorkflowsDraftRunResponse) export const run10 = { - post: post56, + post: post57, } /** @@ -3840,7 +3864,7 @@ export const systemVariables = { * * Poll for trigger events and execute full workflow when event arrives */ -export const post57 = oc +export const post58 = oc .route({ description: 'Poll for trigger events and execute full workflow when event arrives', inputStructure: 'detailed', @@ -3859,7 +3883,7 @@ export const post57 = oc .output(zPostAppsByAppIdWorkflowsDraftTriggerRunResponse) export const run11 = { - post: post57, + post: post58, } /** @@ -3867,7 +3891,7 @@ export const run11 = { * * Full workflow debug when the start node is a trigger */ -export const post58 = oc +export const post59 = oc .route({ description: 'Full workflow debug when the start node is a trigger', inputStructure: 'detailed', @@ -3886,7 +3910,7 @@ export const post58 = oc .output(zPostAppsByAppIdWorkflowsDraftTriggerRunAllResponse) export const runAll = { - post: post58, + post: post59, } export const trigger2 = { @@ -4039,7 +4063,7 @@ export const get72 = oc * * Sync draft workflow configuration */ -export const post59 = oc +export const post60 = oc .route({ description: 'Sync draft workflow configuration', inputStructure: 'detailed', @@ -4059,7 +4083,7 @@ export const post59 = oc export const draft2 = { get: get72, - post: post59, + post: post60, conversationVariables: conversationVariables2, environmentVariables, features, @@ -4095,7 +4119,7 @@ export const get73 = oc /** * Publish workflow */ -export const post60 = oc +export const post61 = oc .route({ inputStructure: 'detailed', method: 'POST', @@ -4114,7 +4138,7 @@ export const post60 = oc export const publish = { get: get73, - post: post60, + post: post61, } /** @@ -4251,7 +4275,7 @@ export const triggers2 = { /** * Restore a published workflow version into the draft workflow */ -export const post61 = oc +export const post62 = oc .route({ description: 'Restore a published workflow version into the draft workflow', inputStructure: 'detailed', @@ -4264,7 +4288,7 @@ export const post61 = oc .output(zPostAppsByAppIdWorkflowsByWorkflowIdRestoreResponse) export const restore = { - post: post61, + post: post62, } /** @@ -4489,7 +4513,7 @@ export const get81 = oc * * Create a new API key for an app */ -export const post62 = oc +export const post63 = oc .route({ description: 'Create a new API key for an app', inputStructure: 'detailed', @@ -4505,7 +4529,7 @@ export const post62 = oc export const apiKeys = { get: get81, - post: post62, + post: post63, byApiKeyId, } @@ -4563,7 +4587,7 @@ export const get83 = oc * * Create a new application */ -export const post63 = oc +export const post64 = oc .route({ description: 'Create a new application', inputStructure: 'detailed', @@ -4579,7 +4603,7 @@ export const post63 = oc export const apps = { get: get83, - post: post63, + post: post64, imports, starred, workflows, diff --git a/packages/contracts/generated/api/console/apps/types.gen.ts b/packages/contracts/generated/api/console/apps/types.gen.ts index fa56590f0a4..9e79518f3cd 100644 --- a/packages/contracts/generated/api/console/apps/types.gen.ts +++ b/packages/contracts/generated/api/console/apps/types.gen.ts @@ -986,9 +986,14 @@ export type ComposerSavePayload = { agent_soul?: AgentSoulConfig | null binding?: ComposerBindingPayload | null client_revision_id?: string | null + description?: string | null + icon?: string | null + icon_background?: string | null + icon_type?: AgentIconType | null idempotency_key?: string | null new_agent_name?: string | null node_job?: WorkflowNodeJobConfig | null + role?: string | null save_strategy: ComposerSaveStrategy soul_lock?: ComposerSoulLockPayload variant: ComposerVariant @@ -1003,6 +1008,12 @@ export type AgentComposerCandidatesResponse = { variant: ComposerVariant } +export type WorkflowComposerCopyFromRosterPayload = { + idempotency_key?: string | null + source_agent_id: string + source_snapshot_id?: string | null +} + export type AgentComposerImpactResponse = { bindings?: Array current_snapshot_id?: string | null @@ -1873,6 +1884,8 @@ export type ComposerBindingPayload = { current_snapshot_id?: string | null } +export type AgentIconType = 'emoji' | 'image' | 'link' + export type ComposerSoulLockPayload = { locked?: boolean unlocked_from_version_id?: string | null @@ -5415,6 +5428,23 @@ export type GetAppsByAppIdWorkflowsDraftNodesByNodeIdAgentComposerCandidatesResp export type GetAppsByAppIdWorkflowsDraftNodesByNodeIdAgentComposerCandidatesResponse = GetAppsByAppIdWorkflowsDraftNodesByNodeIdAgentComposerCandidatesResponses[keyof GetAppsByAppIdWorkflowsDraftNodesByNodeIdAgentComposerCandidatesResponses] +export type PostAppsByAppIdWorkflowsDraftNodesByNodeIdAgentComposerCopyFromRosterData = { + body: WorkflowComposerCopyFromRosterPayload + path: { + app_id: string + node_id: string + } + query?: never + url: '/apps/{app_id}/workflows/draft/nodes/{node_id}/agent-composer/copy-from-roster' +} + +export type PostAppsByAppIdWorkflowsDraftNodesByNodeIdAgentComposerCopyFromRosterResponses = { + 200: WorkflowAgentComposerResponse +} + +export type PostAppsByAppIdWorkflowsDraftNodesByNodeIdAgentComposerCopyFromRosterResponse + = PostAppsByAppIdWorkflowsDraftNodesByNodeIdAgentComposerCopyFromRosterResponses[keyof PostAppsByAppIdWorkflowsDraftNodesByNodeIdAgentComposerCopyFromRosterResponses] + export type PostAppsByAppIdWorkflowsDraftNodesByNodeIdAgentComposerImpactData = { body: ComposerSavePayload path: { diff --git a/packages/contracts/generated/api/console/apps/zod.gen.ts b/packages/contracts/generated/api/console/apps/zod.gen.ts index 043fc11261f..9b86fda0a62 100644 --- a/packages/contracts/generated/api/console/apps/zod.gen.ts +++ b/packages/contracts/generated/api/console/apps/zod.gen.ts @@ -642,6 +642,15 @@ export const zHumanInputDeliveryTestPayload = z.object({ */ export const zEmptyObjectResponse = z.record(z.string(), z.unknown()) +/** + * WorkflowComposerCopyFromRosterPayload + */ +export const zWorkflowComposerCopyFromRosterPayload = z.object({ + idempotency_key: z.string().max(255).nullish(), + source_agent_id: z.string().min(1).max(255), + source_snapshot_id: z.string().max(255).nullish(), +}) + /** * DraftWorkflowNodeRunPayload */ @@ -1835,6 +1844,13 @@ export const zComposerBindingPayload = z.object({ current_snapshot_id: z.string().nullish(), }) +/** + * AgentIconType + * + * Supported icon storage formats for Agent roster entries. + */ +export const zAgentIconType = z.enum(['emoji', 'image', 'link']) + /** * ComposerSoulLockPayload */ @@ -3336,9 +3352,14 @@ export const zComposerSavePayload = z.object({ agent_soul: zAgentSoulConfig.nullish(), binding: zComposerBindingPayload.nullish(), client_revision_id: z.string().nullish(), + description: z.string().nullish(), + icon: z.string().max(255).nullish(), + icon_background: z.string().max(255).nullish(), + icon_type: zAgentIconType.nullish(), idempotency_key: z.string().nullish(), new_agent_name: z.string().min(1).max(255).nullish(), node_job: zWorkflowNodeJobConfig.nullish(), + role: z.string().max(255).nullish(), save_strategy: zComposerSaveStrategy, soul_lock: zComposerSoulLockPayload.optional(), variant: zComposerVariant, @@ -5342,6 +5363,20 @@ export const zGetAppsByAppIdWorkflowsDraftNodesByNodeIdAgentComposerCandidatesPa export const zGetAppsByAppIdWorkflowsDraftNodesByNodeIdAgentComposerCandidatesResponse = zAgentComposerCandidatesResponse +export const zPostAppsByAppIdWorkflowsDraftNodesByNodeIdAgentComposerCopyFromRosterBody + = zWorkflowComposerCopyFromRosterPayload + +export const zPostAppsByAppIdWorkflowsDraftNodesByNodeIdAgentComposerCopyFromRosterPath = z.object({ + app_id: z.uuid(), + node_id: z.string(), +}) + +/** + * Workflow roster agent copied to inline agent + */ +export const zPostAppsByAppIdWorkflowsDraftNodesByNodeIdAgentComposerCopyFromRosterResponse + = zWorkflowAgentComposerResponse + export const zPostAppsByAppIdWorkflowsDraftNodesByNodeIdAgentComposerImpactBody = zComposerSavePayload diff --git a/packages/contracts/generated/enterprise/orpc.gen.ts b/packages/contracts/generated/enterprise/orpc.gen.ts index 61503a7f742..764cea4fdc1 100644 --- a/packages/contracts/generated/enterprise/orpc.gen.ts +++ b/packages/contracts/generated/enterprise/orpc.gen.ts @@ -380,12 +380,8 @@ export const listRollbackTargets = oc ) .output(zDeploymentServiceListRollbackTargetsResponse) -/** - * CancelDeployment cancels the in-flight deployment on the environment. - */ export const cancelDeployment = oc .route({ - description: 'CancelDeployment cancels the in-flight deployment on the environment.', inputStructure: 'detailed', method: 'POST', operationId: 'DeploymentService_CancelDeployment', @@ -607,13 +603,8 @@ export const releaseService = { precheckRelease, } -/** - * ListEnvironments returns only the environments the current user can - * deploy to. - */ export const listEnvironments = oc .route({ - description: 'ListEnvironments returns only the environments the current user can\n deploy to.', inputStructure: 'detailed', method: 'GET', operationId: 'EnvironmentService_ListEnvironments', diff --git a/packages/contracts/generated/enterprise/types.gen.ts b/packages/contracts/generated/enterprise/types.gen.ts index 600f8975678..882a85465b9 100644 --- a/packages/contracts/generated/enterprise/types.gen.ts +++ b/packages/contracts/generated/enterprise/types.gen.ts @@ -13,13 +13,13 @@ export const AccessMode = { export type AccessMode = (typeof AccessMode)[keyof typeof AccessMode] -export const SubjectType = { - SUBJECT_TYPE_UNSPECIFIED: 'SUBJECT_TYPE_UNSPECIFIED', - SUBJECT_TYPE_ACCOUNT: 'SUBJECT_TYPE_ACCOUNT', - SUBJECT_TYPE_GROUP: 'SUBJECT_TYPE_GROUP', +export const AccessSubjectType = { + ACCESS_SUBJECT_TYPE_UNSPECIFIED: 'ACCESS_SUBJECT_TYPE_UNSPECIFIED', + ACCESS_SUBJECT_TYPE_ACCOUNT: 'ACCESS_SUBJECT_TYPE_ACCOUNT', + ACCESS_SUBJECT_TYPE_GROUP: 'ACCESS_SUBJECT_TYPE_GROUP', } as const -export type SubjectType = (typeof SubjectType)[keyof typeof SubjectType] +export type AccessSubjectType = (typeof AccessSubjectType)[keyof typeof AccessSubjectType] export const AppRunnerLogStatus = { APP_RUNNER_LOG_STATUS_UNSPECIFIED: 'APP_RUNNER_LOG_STATUS_UNSPECIFIED', @@ -295,7 +295,7 @@ export type AccessPolicy = { } export type AccessSubject = { - subjectType: SubjectType + subjectType: AccessSubjectType subjectId: string } @@ -598,7 +598,6 @@ export type Environment = { status: EnvironmentStatus statusMessage: string lastError?: Error - apiServer?: string namespace?: string managedBy?: string runtimeEndpoint?: string @@ -741,9 +740,6 @@ export type GetReleaseResponse = { export type K8sEnvironmentConfig = { namespace?: string - apiServer?: string - caBundle?: string - bearerToken?: string } export type ListApiKeysResponse = { @@ -832,6 +828,7 @@ export type PrecheckReleaseResponse = { canCreate: boolean matchedRelease?: ReleaseContentMatch unsupportedNodes: Array + unsupportedToolProviders: Array } export type PromoteRequest = { @@ -998,6 +995,14 @@ export type UnsupportedDslNode = { type: string } +export type UnsupportedToolProvider = { + nodeId: string + providerType: string + providerId?: string + providerName?: string + toolName?: string +} + export type UpdateAccessChannelsRequest = { appInstanceId?: string webAppEnabled?: boolean @@ -1362,7 +1367,6 @@ export type InfoConfigReply = { Branding?: BrandingInfo WebAppAuth?: WebAppAuthInfo PluginInstallationPermission?: PluginInstallationPermissionInfo - EnableAppDeploy?: boolean } export type InnerAdmission = { @@ -1458,6 +1462,19 @@ export type IsUserAllowedToAccessWebAppRes = { result?: boolean } +export type IssueMcpTokenReply = { + token?: string + expiresAt?: string + tokenType?: string +} + +export type IssueMcpTokenReq = { + userId?: string + tenantId?: string + appId?: string + audience?: string +} + export type JoinWorkspaceReply = { message?: string } @@ -1466,6 +1483,7 @@ export type JoinWorkspaceReq = { id?: string email?: string role?: string + rbacRole?: string } export type LicenseInfo = { @@ -1667,12 +1685,9 @@ export type PluginInstallationSettingsReply = { export type RbacRole = { id?: string - type?: string name?: string description?: string - isBuiltin?: boolean - category?: string - permissionKeys?: Array + permissions?: Array } export type ResetMemberPasswordReply = { @@ -1813,7 +1828,7 @@ export type SetDefaultWorkspaceReq = { export type Subject = { subjectId?: string - subjectType?: SubjectType + subjectType?: string accountData?: SubjectAccountData groupData?: SubjectGroupData } diff --git a/packages/contracts/generated/enterprise/zod.gen.ts b/packages/contracts/generated/enterprise/zod.gen.ts index d7a42b35d4c..85f74b22121 100644 --- a/packages/contracts/generated/enterprise/zod.gen.ts +++ b/packages/contracts/generated/enterprise/zod.gen.ts @@ -9,10 +9,10 @@ export const zAccessMode = z.enum([ 'ACCESS_MODE_PRIVATE_ALL', ]) -export const zSubjectType = z.enum([ - 'SUBJECT_TYPE_UNSPECIFIED', - 'SUBJECT_TYPE_ACCOUNT', - 'SUBJECT_TYPE_GROUP', +export const zAccessSubjectType = z.enum([ + 'ACCESS_SUBJECT_TYPE_UNSPECIFIED', + 'ACCESS_SUBJECT_TYPE_ACCOUNT', + 'ACCESS_SUBJECT_TYPE_GROUP', ]) export const zAppRunnerLogStatus = z.enum([ @@ -203,7 +203,7 @@ export const zLimitStatus = z.enum([ ]) export const zAccessSubject = z.object({ - subjectType: zSubjectType, + subjectType: zAccessSubjectType, subjectId: z.string(), }) @@ -254,10 +254,6 @@ export const zAppInstance = z.object({ updatedAt: z.iso.datetime(), }) -/** - * BootstrapAssignment is one runtime_instance assignment in a runner's startup - * baseline. - */ export const zBootstrapAssignment = z.object({ appId: z.string().optional(), environmentId: z.string().optional(), @@ -322,10 +318,6 @@ export const zCreateReleaseRequest = z.object({ sourceAppId: z.string().optional(), }) -/** - * CredentialCandidate is one tenant-visible credential a frontend may - * pick for a credential slot. It carries no secret. - */ export const zCredentialCandidate = z.object({ credentialId: z.string(), providerId: z.string(), @@ -334,20 +326,12 @@ export const zCredentialCandidate = z.object({ fromEnterprise: z.boolean(), }) -/** - * CredentialSelectionInput is one deploy-time plugin-credential - * selection: a shared credential id chosen for a required DSL slot. - */ export const zCredentialSelectionInput = z.object({ providerId: z.string(), category: zPluginCategory.optional(), credentialId: z.string(), }) -/** - * CredentialSlot is one model/tool plugin-credential requirement a - * Release's DSL declares, paired with the candidates selectable for it. - */ export const zCredentialSlot = z.object({ providerId: z.string(), category: zPluginCategory, @@ -406,10 +390,6 @@ export const zEnvironmentDeploymentRecord = z.object({ finalizedAt: z.iso.datetime().optional(), }) -/** - * Error is the package-wide failure shape, carried wherever an operation or - * resource reports an error. - */ export const zError = z.object({ code: z.string().optional(), message: z.string().optional(), @@ -445,7 +425,6 @@ export const zEnvironment = z.object({ status: zEnvironmentStatus, statusMessage: z.string(), lastError: zError.optional(), - apiServer: z.string().optional(), namespace: z.string().optional(), managedBy: z.string().optional(), runtimeEndpoint: z.string().optional(), @@ -523,9 +502,6 @@ export const zGetEnvironmentResponse = z.object({ export const zK8sEnvironmentConfig = z.object({ namespace: z.string().optional(), - apiServer: z.string().optional(), - caBundle: z.string().optional(), - bearerToken: z.string().optional(), }) export const zCreateEnvironmentRequest = z.object({ @@ -571,9 +547,6 @@ export const zDeployRequest = z.object({ expectedDslDigest: z.string().optional(), }) -/** - * Operator is who triggered the run (the "END USER OR ACCOUNT" column). - */ export const zOperator = z.object({ type: zOperatorType, id: z.string(), @@ -620,10 +593,6 @@ export const zPromoteRequest = z.object({ idempotencyKey: z.string(), }) -/** - * ReleaseContentMatch identifies an existing release whose DSL content is - * identical to the checked content. - */ export const zReleaseContentMatch = z.object({ releaseId: z.string(), displayName: z.string(), @@ -638,11 +607,6 @@ export const zReleaseEnvironmentAction = z.object({ currentReleaseId: z.string(), }) -/** - * ReleaseEnvironmentDeployment is an environment where the release is the - * active deployment, paired with that environment's runtime status so the - * version history can show running vs failed vs deploying. - */ export const zReleaseEnvironmentDeployment = z.object({ environment: zEnvironment, status: zRuntimeInstanceStatus, @@ -663,10 +627,6 @@ export const zReportRuntimeAssignmentStatusResponse = z.object({ stale: z.boolean().optional(), }) -/** - * RequiredSlot is an input requirement extracted from a Release's - * DSL. - */ export const zRequiredSlot = z.object({ type: zSlotType, providerId: z.string(), @@ -715,10 +675,6 @@ export const zDeployResponse = z.object({ deployment: zDeployment, }) -/** - * EnvironmentAppInstance is one app instance as seen from a single environment: - * its current release, runtime status, and derived last error in THIS env. - */ export const zEnvironmentAppInstance = z.object({ appInstance: zAppInstance.optional(), currentRelease: zRelease.optional(), @@ -759,10 +715,6 @@ export const zComputeReleaseDeploymentViewResponse = z.object({ options: zDeploymentOptions.optional(), }) -/** - * EnvironmentDeploymentHistoryItem is one deployment row in an environment's - * history, with a thin reference to the owning app instance. - */ export const zEnvironmentDeploymentHistoryItem = z.object({ deployment: zDeployment.optional(), appInstanceId: z.string().optional(), @@ -904,20 +856,25 @@ export const zUndeployResponse = z.object({ deployment: zDeployment, }) -/** - * UnsupportedDslNode identifies a workflow node whose type the app runner - * cannot execute. - */ export const zUnsupportedDslNode = z.object({ id: z.string(), type: z.string(), }) +export const zUnsupportedToolProvider = z.object({ + nodeId: z.string(), + providerType: z.string(), + providerId: z.string().optional(), + providerName: z.string().optional(), + toolName: z.string().optional(), +}) + export const zPrecheckReleaseResponse = z.object({ gateCommitId: z.string(), canCreate: z.boolean(), matchedRelease: zReleaseContentMatch.optional(), unsupportedNodes: z.array(zUnsupportedDslNode), + unsupportedToolProviders: z.array(zUnsupportedToolProvider), }) export const zUpdateAccessChannelsRequest = z.object({ @@ -1302,6 +1259,19 @@ export const zIsUserAllowedToAccessWebAppRes = z.object({ result: z.boolean().optional(), }) +export const zIssueMcpTokenReply = z.object({ + token: z.string().optional(), + expiresAt: z.string().optional(), + tokenType: z.string().optional(), +}) + +export const zIssueMcpTokenReq = z.object({ + userId: z.string().optional(), + tenantId: z.string().optional(), + appId: z.string().optional(), + audience: z.string().optional(), +}) + export const zJoinWorkspaceReply = z.object({ message: z.string().optional(), }) @@ -1313,6 +1283,7 @@ export const zJoinWorkspaceReq = z.object({ id: z.string().optional(), email: z.string().optional(), role: z.string().optional(), + rbacRole: z.string().optional(), }) export const zLimitConfig = z.object({ @@ -1494,12 +1465,9 @@ export const zPluginInstallationSettingsReply = z.object({ export const zRbacRole = z.object({ id: z.string().optional(), - type: z.string().optional(), name: z.string().optional(), description: z.string().optional(), - isBuiltin: z.boolean().optional(), - category: z.string().optional(), - permissionKeys: z.array(z.string()).optional(), + permissions: z.array(z.string()).optional(), }) export const zGetMemberRbacRolesReply = z.object({ @@ -1778,7 +1746,7 @@ export const zGetWebAppWhitelistSubjectsRes = z.object({ */ export const zSubject = z.object({ subjectId: z.string().optional(), - subjectType: zSubjectType.optional(), + subjectType: z.string().optional(), accountData: zSubjectAccountData.optional(), groupData: zSubjectGroupData.optional(), }) @@ -2104,7 +2072,6 @@ export const zInfoConfigReply = z.object({ Branding: zBrandingInfo.optional(), WebAppAuth: zWebAppAuthInfo.optional(), PluginInstallationPermission: zPluginInstallationPermissionInfo.optional(), - EnableAppDeploy: z.boolean().optional(), }) export const zWebOAuth2LoginReply = z.object({ diff --git a/packages/dify-ui/src/themes/dark.css b/packages/dify-ui/src/themes/dark.css index 1b24e8fb489..3f4a163a725 100644 --- a/packages/dify-ui/src/themes/dark.css +++ b/packages/dify-ui/src/themes/dark.css @@ -162,6 +162,7 @@ html[data-theme="dark"] { --color-components-main-nav-glass-surface-middle-2: #0033ff1a; --color-components-main-nav-glass-surface-end: #0033ff14; --color-components-main-nav-glass-edge-highlight-first: #fffffffa; + --color-components-main-nav-glass-edge-highlight-middle: #ffffff00; --color-components-main-nav-glass-edge-highlight-end: #ffffff6b; --color-components-main-nav-glass-edge-reflection-first: #0033ff00; --color-components-main-nav-glass-edge-reflection-middle: #0033ff99; diff --git a/packages/dify-ui/src/themes/light.css b/packages/dify-ui/src/themes/light.css index 3feb4afb47f..dd3252f3614 100644 --- a/packages/dify-ui/src/themes/light.css +++ b/packages/dify-ui/src/themes/light.css @@ -162,6 +162,7 @@ html[data-theme="light"] { --color-components-main-nav-glass-surface-middle-2: #0033ff1a; --color-components-main-nav-glass-surface-end: #0033ff14; --color-components-main-nav-glass-edge-highlight-first: #fffffffa; + --color-components-main-nav-glass-edge-highlight-middle: #ffffff00; --color-components-main-nav-glass-edge-highlight-end: #ffffff6b; --color-components-main-nav-glass-edge-reflection-first: #0033ff00; --color-components-main-nav-glass-edge-reflection-middle: #0033ff99; diff --git a/packages/dify-ui/src/themes/theme.css b/packages/dify-ui/src/themes/theme.css index 3e35feb8eb8..c14e54ea549 100644 --- a/packages/dify-ui/src/themes/theme.css +++ b/packages/dify-ui/src/themes/theme.css @@ -169,6 +169,7 @@ --color-components-main-nav-glass-surface-middle-2: var(--color-components-main-nav-glass-surface-middle-2); --color-components-main-nav-glass-surface-end: var(--color-components-main-nav-glass-surface-end); --color-components-main-nav-glass-edge-highlight-first: var(--color-components-main-nav-glass-edge-highlight-first); + --color-components-main-nav-glass-edge-highlight-middle: var(--color-components-main-nav-glass-edge-highlight-middle); --color-components-main-nav-glass-edge-highlight-end: var(--color-components-main-nav-glass-edge-highlight-end); --color-components-main-nav-glass-edge-reflection-first: var(--color-components-main-nav-glass-edge-reflection-first); --color-components-main-nav-glass-edge-reflection-middle: var(--color-components-main-nav-glass-edge-reflection-middle); diff --git a/web/.env.example b/web/.env.example index 112232e529c..7363ce628f3 100644 --- a/web/.env.example +++ b/web/.env.example @@ -116,3 +116,4 @@ NEXT_PUBLIC_ENABLE_CHANGE_EMAIL=true NEXT_PUBLIC_CREATORS_PLATFORM_FEATURES_ENABLED=true NEXT_PUBLIC_ENABLE_TRIAL_APP=true NEXT_PUBLIC_ENABLE_EXPLORE_BANNER=true +NEXT_PUBLIC_RBAC_ENABLED=false diff --git a/web/__tests__/env.spec.ts b/web/__tests__/env.spec.ts index 89781d32685..419dcadf030 100644 --- a/web/__tests__/env.spec.ts +++ b/web/__tests__/env.spec.ts @@ -1,5 +1,6 @@ describe('env runtime transport', () => { const originalAgentV2Env = process.env.NEXT_PUBLIC_ENABLE_AGENT_V2 + const originalRbacEnv = process.env.NEXT_PUBLIC_RBAC_ENABLED beforeEach(() => { vi.clearAllMocks() @@ -7,7 +8,9 @@ describe('env runtime transport', () => { vi.doUnmock('../utils/client') document.body.removeAttribute('data-enable-agent-v2') document.body.removeAttribute('data-enable-agent-v-2') + document.body.removeAttribute('data-rbac-enabled') delete process.env.NEXT_PUBLIC_ENABLE_AGENT_V2 + delete process.env.NEXT_PUBLIC_RBAC_ENABLED }) afterAll(() => { @@ -15,6 +18,11 @@ describe('env runtime transport', () => { delete process.env.NEXT_PUBLIC_ENABLE_AGENT_V2 else process.env.NEXT_PUBLIC_ENABLE_AGENT_V2 = originalAgentV2Env + + if (originalRbacEnv === undefined) + delete process.env.NEXT_PUBLIC_RBAC_ENABLED + else + process.env.NEXT_PUBLIC_RBAC_ENABLED = originalRbacEnv }) it('should read NEXT_PUBLIC_ENABLE_AGENT_V2 from the browser runtime dataset key', async () => { @@ -25,6 +33,14 @@ describe('env runtime transport', () => { expect(env.NEXT_PUBLIC_ENABLE_AGENT_V2).toBe(true) }) + it('should read NEXT_PUBLIC_RBAC_ENABLED from the browser runtime dataset key', async () => { + document.body.setAttribute('data-rbac-enabled', 'true') + + const { env } = await import('../env') + + expect(env.NEXT_PUBLIC_RBAC_ENABLED).toBe(true) + }) + it('should emit the Agent v2 runtime dataset attribute from getDatasetMap on the server', async () => { process.env.NEXT_PUBLIC_ENABLE_AGENT_V2 = 'true' @@ -39,4 +55,18 @@ describe('env runtime transport', () => { expect(datasetMap['data-enable-agent-v2']).toBe(true) expect(datasetMap['data-enable-agent-v-2']).toBeUndefined() }) + + it('should emit the RBAC runtime dataset attribute from getDatasetMap on the server', async () => { + process.env.NEXT_PUBLIC_RBAC_ENABLED = 'true' + + vi.doMock('../utils/client', () => ({ + isClient: false, + isServer: true, + })) + + const { getDatasetMap } = await import('../env') + const datasetMap = getDatasetMap() + + expect(datasetMap['data-rbac-enabled']).toBe(true) + }) }) diff --git a/web/app/(commonLayout)/layout.tsx b/web/app/(commonLayout)/layout.tsx index 6092618bfb3..71baf83c795 100644 --- a/web/app/(commonLayout)/layout.tsx +++ b/web/app/(commonLayout)/layout.tsx @@ -27,25 +27,26 @@ export default async function Layout({ children }: { children: ReactNode }) { - - - - - - - - {children} - - - - - - - - - - - + + + + + + + + {children} + + + + + + + + + + + + diff --git a/web/app/components/app/app-access-control/__tests__/access-control.spec.tsx b/web/app/components/app/app-access-control/__tests__/access-control.spec.tsx index c8bce5401e6..7cfa29fe97a 100644 --- a/web/app/components/app/app-access-control/__tests__/access-control.spec.tsx +++ b/web/app/components/app/app-access-control/__tests__/access-control.spec.tsx @@ -1,6 +1,6 @@ import type { AccessControlAccount, AccessControlGroup, Subject } from '@/models/access-control' import type { App } from '@/types/app' -import { SubjectType as EnterpriseSubjectType } from '@dify/contracts/enterprise/types.gen' +import { AccessSubjectType as EnterpriseSubjectType } from '@dify/contracts/enterprise/types.gen' import { toast } from '@langgenius/dify-ui/toast' import { fireEvent, screen, waitFor } from '@testing-library/react' import userEvent from '@testing-library/user-event' @@ -375,8 +375,8 @@ describe('AccessControl', () => { appId: app.id, accessMode: AccessMode.SPECIFIC_GROUPS_MEMBERS, subjects: [ - { subjectId: baseGroup.id, subjectType: EnterpriseSubjectType.SUBJECT_TYPE_GROUP }, - { subjectId: baseMember.id, subjectType: EnterpriseSubjectType.SUBJECT_TYPE_ACCOUNT }, + { subjectId: baseGroup.id, subjectType: EnterpriseSubjectType.ACCESS_SUBJECT_TYPE_GROUP }, + { subjectId: baseMember.id, subjectType: EnterpriseSubjectType.ACCESS_SUBJECT_TYPE_ACCOUNT }, ], }, }, diff --git a/web/app/components/app/app-access-control/index.tsx b/web/app/components/app/app-access-control/index.tsx index 18c28d12757..13aa079092a 100644 --- a/web/app/components/app/app-access-control/index.tsx +++ b/web/app/components/app/app-access-control/index.tsx @@ -1,7 +1,7 @@ 'use client' import type { Subject as EnterpriseSubject } from '@dify/contracts/enterprise/types.gen' import type { App } from '@/types/app' -import { SubjectType as EnterpriseSubjectType } from '@dify/contracts/enterprise/types.gen' +import { AccessSubjectType as EnterpriseSubjectType } from '@dify/contracts/enterprise/types.gen' import { toast } from '@langgenius/dify-ui/toast' import { useMutation, useSuspenseQuery } from '@tanstack/react-query' import { useTranslation } from 'react-i18next' @@ -94,12 +94,12 @@ function AccessControlForm({ if (currentMenu === AccessMode.SPECIFIC_GROUPS_MEMBERS) { const subjects: Pick[] = [] specificGroups.forEach((group) => { - subjects.push({ subjectId: group.id, subjectType: EnterpriseSubjectType.SUBJECT_TYPE_GROUP }) + subjects.push({ subjectId: group.id, subjectType: EnterpriseSubjectType.ACCESS_SUBJECT_TYPE_GROUP }) }) specificMembers.forEach((member) => { subjects.push({ subjectId: member.id, - subjectType: EnterpriseSubjectType.SUBJECT_TYPE_ACCOUNT, + subjectType: EnterpriseSubjectType.ACCESS_SUBJECT_TYPE_ACCOUNT, }) }) submitData.subjects = subjects diff --git a/web/app/components/header/account-dropdown/__tests__/index.spec.tsx b/web/app/components/header/account-dropdown/__tests__/index.spec.tsx index 81b08046a94..741580f6a5c 100644 --- a/web/app/components/header/account-dropdown/__tests__/index.spec.tsx +++ b/web/app/components/header/account-dropdown/__tests__/index.spec.tsx @@ -248,7 +248,7 @@ describe('AccountDropdown', () => { fireEvent.click(screen.getByText('common.settings.preferences')) // Assert - expect(mockSetShowAccountSettingModal).toHaveBeenCalledWith({ payload: ACCOUNT_SETTING_TAB.LANGUAGE }) + expect(mockSetShowAccountSettingModal).toHaveBeenCalledWith({ payload: ACCOUNT_SETTING_TAB.PREFERENCES }) }) it('should show Appearance after Preferences in the main nav account dropdown', () => { diff --git a/web/app/components/header/account-dropdown/main-nav-menu-content.tsx b/web/app/components/header/account-dropdown/main-nav-menu-content.tsx index ee2ebcc981e..1b84397a1ed 100644 --- a/web/app/components/header/account-dropdown/main-nav-menu-content.tsx +++ b/web/app/components/header/account-dropdown/main-nav-menu-content.tsx @@ -127,7 +127,7 @@ export function MainNavMenuContent({ setShowAccountSettingModal({ payload: ACCOUNT_SETTING_TAB.LANGUAGE })} + onClick={() => setShowAccountSettingModal({ payload: ACCOUNT_SETTING_TAB.PREFERENCES })} > { expect(ACCOUNT_SETTING_TAB.DATA_SOURCE).toBe('data-source') expect(ACCOUNT_SETTING_TAB.API_BASED_EXTENSION).toBe('custom-endpoint') expect(ACCOUNT_SETTING_TAB.CUSTOM).toBe('custom') + expect(ACCOUNT_SETTING_TAB.PREFERENCES).toBe('preferences') expect(ACCOUNT_SETTING_TAB.LANGUAGE).toBe('language') }) @@ -42,6 +43,7 @@ describe('AccountSetting Constants', () => { expect(isValidAccountSettingTab('data-source')).toBe(true) expect(isValidAccountSettingTab('custom-endpoint')).toBe(true) expect(isValidAccountSettingTab('custom')).toBe(true) + expect(isValidAccountSettingTab('preferences')).toBe(true) expect(isValidAccountSettingTab('language')).toBe(true) }) @@ -55,6 +57,7 @@ describe('AccountSetting Constants', () => { expect(isValidSettingsTab('permissions')).toBe(true) expect(isValidSettingsTab('access-rules')).toBe(true) expect(isValidSettingsTab('billing')).toBe(true) + expect(isValidSettingsTab('preferences')).toBe(true) expect(isValidSettingsTab('language')).toBe(true) expect(isValidSettingsTab('provider')).toBe(true) expect(isValidSettingsTab('mcp')).toBe(true) diff --git a/web/app/components/header/account-setting/__tests__/index.spec.tsx b/web/app/components/header/account-setting/__tests__/index.spec.tsx index ec1ac9887d2..7113714ac4f 100644 --- a/web/app/components/header/account-setting/__tests__/index.spec.tsx +++ b/web/app/components/header/account-setting/__tests__/index.spec.tsx @@ -257,6 +257,17 @@ describe('AccountSetting', () => { expect(screen.getByText('common.settings.dataSource'))!.toBeInTheDocument() }) + it('should normalize legacy language tab entries to preferences', () => { + // Act + renderAccountSetting({ initialTab: ACCOUNT_SETTING_TAB.LANGUAGE }) + + // Assert + const preferencesButton = screen.getByRole('button', { name: 'common.settings.preferences' }) + expect(preferencesButton.querySelector('.i-ri-equalizer-2-fill')).toBeInTheDocument() + expect(screen.getByText('common.account.general')).toBeInTheDocument() + expect(screen.getByText('common.account.appearanceLabel')).toBeInTheDocument() + }) + it('should hide sidebar labels on mobile', () => { // Arrange vi.mocked(useBreakpoints).mockReturnValue(MediaType.mobile) diff --git a/web/app/components/header/account-setting/__tests__/update-setting-dialog-form.spec.tsx b/web/app/components/header/account-setting/__tests__/update-setting-dialog-form.spec.tsx new file mode 100644 index 00000000000..7ca34890a35 --- /dev/null +++ b/web/app/components/header/account-setting/__tests__/update-setting-dialog-form.spec.tsx @@ -0,0 +1,88 @@ +import { fireEvent, render, screen } from '@testing-library/react' +import * as React from 'react' +import { AUTO_UPDATE_MODE, AUTO_UPDATE_STRATEGY } from '@/app/components/plugins/reference-setting-modal/auto-update-setting/types' +import { PluginCategoryEnum } from '@/app/components/plugins/types' +import { ACCOUNT_SETTING_TAB } from '../constants' +import UpdateSettingDialogForm from '../update-setting-dialog-form' + +const mockSetShowAccountSettingModal = vi.fn() + +vi.mock('@/context/modal-context', () => ({ + useModalContextSelector: (selector: (s: { setShowAccountSettingModal: typeof mockSetShowAccountSettingModal }) => typeof mockSetShowAccountSettingModal) => { + return selector({ setShowAccountSettingModal: mockSetShowAccountSettingModal }) + }, +})) + +vi.mock('react-i18next', () => ({ + useTranslation: (defaultNs?: string) => ({ + t: (key: string, options?: Record) => { + const ns = (options?.ns as string | undefined) ?? defaultNs + return `${ns ? `${ns}.` : ''}${key}` + }, + i18n: { + language: 'en', + changeLanguage: vi.fn(), + }, + }), + Trans: ({ i18nKey, components }: { + i18nKey: string + components?: Record + }) => { + const setTimezone = components?.setTimezone + if (setTimezone) + return React.cloneElement(setTimezone, undefined, i18nKey) + + return {i18nKey} + }, +})) + +vi.mock('@/app/components/base/date-and-time-picker/time-picker', () => ({ + default: () =>
, +})) + +vi.mock('@/app/components/plugins/reference-setting-modal/auto-update-setting/plugins-picker', () => ({ + default: () =>
, +})) + +describe('UpdateSettingDialogForm', () => { + beforeEach(() => { + vi.clearAllMocks() + }) + + it('should open preferences after closing the update setting dialog when timezone link is clicked', () => { + const onRequestClose = vi.fn() + + render( + minutes} + onAutoUpgradeChange={vi.fn()} + onPluginsChange={vi.fn()} + onRequestClose={onRequestClose} + onUpdateTimeChange={vi.fn()} + renderTimePickerTrigger={() => } + />, + ) + + fireEvent.click(screen.getByText('autoUpdate.changeTimezone')) + + expect(onRequestClose).toHaveBeenCalledTimes(1) + expect(mockSetShowAccountSettingModal).toHaveBeenCalledWith({ payload: ACCOUNT_SETTING_TAB.PREFERENCES }) + }) +}) diff --git a/web/app/components/header/account-setting/constants.ts b/web/app/components/header/account-setting/constants.ts index d8128631306..67a89afc0c2 100644 --- a/web/app/components/header/account-setting/constants.ts +++ b/web/app/components/header/account-setting/constants.ts @@ -12,6 +12,7 @@ export const ACCOUNT_SETTING_TAB = { DATA_SOURCE: 'data-source', API_BASED_EXTENSION: 'custom-endpoint', CUSTOM: 'custom', + PREFERENCES: 'preferences', LANGUAGE: 'language', } as const @@ -30,6 +31,7 @@ const WORKSPACE_SETTING_TAB_VALUES = [ export type WorkspaceSettingTab = typeof WORKSPACE_SETTING_TAB_VALUES[number] const USER_SETTING_TAB_VALUES = [ + ACCOUNT_SETTING_TAB.PREFERENCES, ACCOUNT_SETTING_TAB.LANGUAGE, ] as const diff --git a/web/app/components/header/account-setting/index.tsx b/web/app/components/header/account-setting/index.tsx index b04404bde17..a03fb30cea2 100644 --- a/web/app/components/header/account-setting/index.tsx +++ b/web/app/components/header/account-setting/index.tsx @@ -20,11 +20,11 @@ import { BillingPermission, hasPermission } from '@/utils/permission' import AccessRulesPage from './access-rules-page' import { ApiBasedExtensionPage } from './api-based-extension-page' import DataSourcePage from './data-source-page-new' -import LanguagePage from './language-page' import MembersPage from './members-page' import ModelProviderPage from './model-provider-page' import { useResetModelProviderListExpanded } from './model-provider-page/atoms' import PermissionsPage from './permissions-page' +import PreferencePage from './preference-page' const iconClassName = ` w-4 h-4 mr-2 @@ -58,12 +58,14 @@ export default function AccountSetting({ const isRbacEnabled = systemFeatures.rbac_enabled const canManageWorkspaceRoles = isRbacEnabled && hasPermission(workspacePermissionKeys, 'workspace.role.manage') const canViewBilling = enableBilling && hasPermission(workspacePermissionKeys, BillingPermission.View) + // Keep legacy `language` deep links opening Preferences during the tab rename migration. + const normalizedActiveTab = activeTab === ACCOUNT_SETTING_TAB.LANGUAGE ? ACCOUNT_SETTING_TAB.PREFERENCES : activeTab const activeMenu = (() => { - if (activeTab === ACCOUNT_SETTING_TAB.BILLING && !canViewBilling) - return ACCOUNT_SETTING_TAB.LANGUAGE - if ((activeTab === ACCOUNT_SETTING_TAB.PERMISSIONS || activeTab === ACCOUNT_SETTING_TAB.ACCESS_RULES) && !canManageWorkspaceRoles) + if (normalizedActiveTab === ACCOUNT_SETTING_TAB.BILLING && !canViewBilling) + return ACCOUNT_SETTING_TAB.PREFERENCES + if ((normalizedActiveTab === ACCOUNT_SETTING_TAB.PERMISSIONS || normalizedActiveTab === ACCOUNT_SETTING_TAB.ACCESS_RULES) && !canManageWorkspaceRoles) return ACCOUNT_SETTING_TAB.MEMBERS - return activeTab + return normalizedActiveTab })() const scrollContainerRef = useRef(null) @@ -119,7 +121,7 @@ export default function AccountSetting({ activeIcon: , }, { - key: ACCOUNT_SETTING_TAB.LANGUAGE, + key: ACCOUNT_SETTING_TAB.PREFERENCES, name: t('settings.preferences', { ns: 'common' }), title: t('account.general', { ns: 'common' }), icon: , @@ -151,7 +153,7 @@ export default function AccountSetting({ const media = useBreakpoints() const isMobile = media === MediaType.mobile - const languageItem = settingItems.find(item => item.key === ACCOUNT_SETTING_TAB.LANGUAGE) + const preferenceItem = settingItems.find(item => item.key === ACCOUNT_SETTING_TAB.PREFERENCES) const menuItems = [ { @@ -161,7 +163,7 @@ export default function AccountSetting({ }, { key: 'user-group', - items: languageItem ? [languageItem] : [], + items: preferenceItem ? [preferenceItem] : [], }, ] @@ -266,7 +268,7 @@ export default function AccountSetting({ {activeMenu === ACCOUNT_SETTING_TAB.DATA_SOURCE && } {activeMenu === ACCOUNT_SETTING_TAB.API_BASED_EXTENSION && } {activeMenu === ACCOUNT_SETTING_TAB.CUSTOM && } - {activeMenu === ACCOUNT_SETTING_TAB.LANGUAGE && } + {activeMenu === ACCOUNT_SETTING_TAB.PREFERENCES && }
diff --git a/web/app/components/header/account-setting/language-page/__tests__/index.spec.tsx b/web/app/components/header/account-setting/preference-page/__tests__/index.spec.tsx similarity index 96% rename from web/app/components/header/account-setting/language-page/__tests__/index.spec.tsx rename to web/app/components/header/account-setting/preference-page/__tests__/index.spec.tsx index edeb14cb1c3..55ee81481ad 100644 --- a/web/app/components/header/account-setting/language-page/__tests__/index.spec.tsx +++ b/web/app/components/header/account-setting/preference-page/__tests__/index.spec.tsx @@ -4,7 +4,7 @@ import { act, fireEvent, render, screen, waitFor, within } from '@testing-librar import { languages } from '@/i18n-config/language' import { updateUserProfile } from '@/service/common' import { timezones } from '@/utils/timezone' -import LanguagePage from '../index' +import PreferencePage from '../index' const mockRefresh = vi.fn() const mockMutateUserProfile = vi.fn() @@ -54,7 +54,7 @@ vi.mock('@langgenius/dify-ui/select', async () => { SelectItem: ({ children, value }: { children: React.ReactNode, value: string }) => { const context = React.useContext(SelectContext) return ( - ) @@ -104,7 +104,7 @@ const createUserProfile = (overrides: Partial = {}): const renderPage = () => { render( <> - + , ) @@ -150,7 +150,7 @@ beforeEach(() => { }) // Rendering -describe('LanguagePage - Rendering', () => { +describe('PreferencePage - Rendering', () => { it('should render default language and timezone labels', () => { const english = getLanguageOption('en-US') const niueTimezone = getTimezoneOption('Pacific/Niue') @@ -182,7 +182,7 @@ describe('LanguagePage - Rendering', () => { }) // Interactions -describe('LanguagePage - Interactions', () => { +describe('PreferencePage - Interactions', () => { it('should show success toast when language updates', async () => { const chinese = getLanguageOption('zh-Hans') mockUserProfile = createUserProfile({ interface_language: 'en-US' }) diff --git a/web/app/components/header/account-setting/language-page/index.tsx b/web/app/components/header/account-setting/preference-page/index.tsx similarity index 99% rename from web/app/components/header/account-setting/language-page/index.tsx rename to web/app/components/header/account-setting/preference-page/index.tsx index fabd84b4c71..ce59fe59740 100644 --- a/web/app/components/header/account-setting/language-page/index.tsx +++ b/web/app/components/header/account-setting/preference-page/index.tsx @@ -33,7 +33,7 @@ const isThemeOption = (value: string): value is ThemeOption => { return (themes as readonly string[]).includes(value) } -export default function LanguagePage() { +export default function PreferencePage() { const locale = useLocale() const { userProfile, mutateUserProfile } = useAppContext() const [editing, setEditing] = useState(false) diff --git a/web/app/components/header/account-setting/update-setting-dialog-form.tsx b/web/app/components/header/account-setting/update-setting-dialog-form.tsx index 65efcf9c9d1..bdfbd5720ab 100644 --- a/web/app/components/header/account-setting/update-setting-dialog-form.tsx +++ b/web/app/components/header/account-setting/update-setting-dialog-form.tsx @@ -53,7 +53,7 @@ function SettingTimeZone({ className="cursor-pointer border-none bg-transparent p-0 text-left body-xs-regular text-text-accent focus-visible:ring-1 focus-visible:ring-components-input-border-active focus-visible:outline-hidden" onClick={() => { onRequestClose() - setShowAccountSettingModal({ payload: ACCOUNT_SETTING_TAB.LANGUAGE }) + setShowAccountSettingModal({ payload: ACCOUNT_SETTING_TAB.PREFERENCES }) }} > {children} diff --git a/web/app/components/main-nav/__tests__/index.spec.tsx b/web/app/components/main-nav/__tests__/index.spec.tsx index 0dad69b1d16..497b31ea314 100644 --- a/web/app/components/main-nav/__tests__/index.spec.tsx +++ b/web/app/components/main-nav/__tests__/index.spec.tsx @@ -25,7 +25,8 @@ import { AppModeEnum } from '@/types/app' import MainNav from '../index' import { DETAIL_SIDEBAR_STORAGE_KEY } from '../storage' -const activeEdgeClassName = 'before:pointer-events-none' +const activeGradientMaskClassName = 'aria-[current=page]:main-nav-active-glass' +const activeStackingClassName = 'aria-[current=page]:z-1' type SnippetNavigationTestState = { onFieldsChange?: (fields: SnippetInputField[]) => void @@ -503,7 +504,7 @@ describe('MainNav', () => { expect(logoLink.parentElement).toHaveClass('pt-3', 'pr-2', 'pb-2', 'pl-4') const homeLink = screen.getByRole('link', { name: /common.mainNav.home/ }) - expect(homeLink.closest('nav')).toHaveClass('flex', 'flex-col', 'gap-px', 'p-2') + expect(homeLink.closest('nav')).toHaveClass('isolate', 'flex', 'flex-col', 'gap-px', 'p-2') expect(homeLink).toHaveClass('h-8', 'w-full', 'rounded-[10px]', 'px-2', 'py-1.5') const webAppsButton = screen.getByRole('button', { name: 'explore.sidebar.webApps' }) @@ -641,8 +642,7 @@ describe('MainNav', () => { renderMainNav() const datasetsLink = screen.getByRole('link', { name: /common.menus.datasets/ }) - expect(datasetsLink.className).toContain('bg-[linear-gradient(98.077deg') - expect(datasetsLink).toHaveClass(activeEdgeClassName) + expect(datasetsLink).toHaveClass(activeGradientMaskClassName) expect(datasetsLink).toHaveAttribute('aria-current', 'page') expect(screen.getByRole('link', { name: /common.mainNav.home/ })).not.toHaveAttribute('aria-current') }) @@ -653,7 +653,7 @@ describe('MainNav', () => { renderMainNav() const studioLink = screen.getByRole('link', { name: /common.menus.apps/ }) - expect(studioLink).toHaveClass(activeEdgeClassName) + expect(studioLink).toHaveClass(activeGradientMaskClassName) expect(studioLink).toHaveAttribute('aria-current', 'page') expect(screen.getByRole('link', { name: /common.mainNav.home/ })).not.toHaveAttribute('aria-current') }) @@ -959,7 +959,7 @@ describe('MainNav', () => { renderMainNav() const marketplaceLink = screen.getByRole('link', { name: /common.mainNav.marketplace/ }) - expect(marketplaceLink).toHaveClass(activeEdgeClassName) + expect(marketplaceLink).toHaveClass(activeGradientMaskClassName) }) it('marks roster active on roster routes', () => { @@ -968,7 +968,7 @@ describe('MainNav', () => { renderMainNav() const rosterLink = screen.getByRole('link', { name: /common.menus.roster/ }) - expect(rosterLink).toHaveClass(activeEdgeClassName) + expect(rosterLink).toHaveClass(activeGradientMaskClassName) expect(rosterLink).toHaveAttribute('aria-current', 'page') }) @@ -979,13 +979,8 @@ describe('MainNav', () => { const homeLink = screen.getByRole('link', { name: /common.mainNav.home/ }) - expect(homeLink).toHaveClass( - 'backdrop-blur-[5px]', - 'text-saas-dify-blue-inverted', - activeEdgeClassName, - 'after:border-components-main-nav-glass-edge-highlight-first', - ) - expect(homeLink.className).toContain('var(--color-components-main-nav-glass-surface-first)') + expect(homeLink).toHaveClass(activeGradientMaskClassName) + expect(homeLink).toHaveClass(activeStackingClassName) }) it('keeps Home active on the legacy explore apps route only', () => { diff --git a/web/app/components/main-nav/components/nav-link.css b/web/app/components/main-nav/components/nav-link.css new file mode 100644 index 00000000000..20839c746ea --- /dev/null +++ b/web/app/components/main-nav/components/nav-link.css @@ -0,0 +1,48 @@ +@utility main-nav-active-glass { + @apply overflow-hidden system-md-semibold text-saas-dify-blue-inverted backdrop-blur-[5px]; + + background-image: linear-gradient( + 91.46deg, + var(--color-components-main-nav-glass-surface-first, rgb(0 51 255 / 0.08)) 0%, + var(--color-components-main-nav-glass-surface-middle-1, rgb(0 51 255 / 0.12)) 17.98%, + var(--color-components-main-nav-glass-surface-middle-2, rgb(0 51 255 / 0.1)) 58.75%, + var(--color-components-main-nav-glass-surface-end, rgb(0 51 255 / 0.08)) 101.09% + ); + box-shadow: + 0 4px 8px 0 var(--color-components-main-nav-glass-shadow-reflection-glow), + 0 10px 12px -4px var(--color-shadow-shadow-4), + 0 3px 5px -2px var(--color-shadow-shadow-1), + 0 8px 16px -4px var(--color-components-main-nav-glass-shadow-reflection); + + &::before { + content: ""; + pointer-events: none; + position: absolute; + inset: 0; + border-radius: inherit; + border: 1px solid transparent; + background: linear-gradient( + 0deg, + var(--color-components-main-nav-glass-edge-reflection-first, rgb(0 51 255 / 0)) 0%, + var(--color-components-main-nav-glass-edge-reflection-middle, rgb(0 51 255 / 0.6)) 50%, + var(--color-components-main-nav-glass-edge-reflection-end, rgb(0 51 255 / 0)) 100% + ) border-box; + -webkit-mask: linear-gradient(#fff 0 0) padding-box, linear-gradient(#fff 0 0); + -webkit-mask-composite: destination-out; + mask-composite: exclude; + } + + &::after { + content: ""; + pointer-events: none; + position: absolute; + inset: 0; + border-radius: inherit; + border: 1px solid transparent; + background: linear-gradient(180deg, var(--color-components-main-nav-glass-edge-highlight-first, rgb(255 255 255 / 0.98)) 0%, var(--color-components-main-nav-glass-edge-highlight-middle, rgb(255 255 255 / 0)) 18%, var(--color-components-main-nav-glass-edge-highlight-end, rgb(255 255 255 / 0.42)) 100%) border-box; + -webkit-mask: linear-gradient(#fff 0 0) padding-box, linear-gradient(#fff 0 0); + -webkit-mask-composite: destination-out; + mask-composite: exclude; + box-shadow: inset 0 0 8px 0 var(--color-components-main-nav-glass-inner-glow); + } +} diff --git a/web/app/components/main-nav/components/nav-link.tsx b/web/app/components/main-nav/components/nav-link.tsx index 10865c94d18..628be088263 100644 --- a/web/app/components/main-nav/components/nav-link.tsx +++ b/web/app/components/main-nav/components/nav-link.tsx @@ -4,21 +4,6 @@ import type { MainNavItem } from '../types' import { cn } from '@langgenius/dify-ui/cn' import Link from '@/next/link' -const navItemClassName = 'group relative flex h-8 w-full items-center gap-2 rounded-[10px] px-2 py-1.5 outline-hidden transition-colors focus-visible:ring-2 focus-visible:ring-inset focus-visible:ring-state-accent-solid' - -const activeNavItemClassName = cn( - 'overflow-hidden', - 'bg-[linear-gradient(98.077deg,var(--color-components-main-nav-glass-surface-first)_0%,var(--color-components-main-nav-glass-surface-middle-1)_17.98%,var(--color-components-main-nav-glass-surface-middle-2)_58.75%,var(--color-components-main-nav-glass-surface-end)_101.09%)]', - 'system-md-semibold text-saas-dify-blue-inverted backdrop-blur-[5px]', - 'shadow-[0px_4px_8px_0px_var(--color-components-main-nav-glass-shadow-reflection-glow),0px_12px_16px_-4px_var(--color-shadow-shadow-4),0px_4px_6px_-2px_var(--color-shadow-shadow-1),0px_10px_16px_-4px_var(--color-components-main-nav-glass-shadow-reflection)]', - 'before:pointer-events-none before:absolute before:inset-0 before:rounded-[inherit] before:p-px before:content-[\'\']', - 'before:bg-[linear-gradient(var(--color-components-main-nav-glass-edge-highlight-first),var(--color-components-main-nav-glass-edge-highlight-first))_top/100%_1px_no-repeat,linear-gradient(var(--color-components-main-nav-glass-edge-highlight-end),var(--color-components-main-nav-glass-edge-highlight-end))_bottom/100%_1px_no-repeat,linear-gradient(180deg,var(--color-components-main-nav-glass-edge-reflection-first)_0%,var(--color-components-main-nav-glass-edge-reflection-middle)_50%,var(--color-components-main-nav-glass-edge-reflection-end)_100%)_left/1px_100%_no-repeat,linear-gradient(180deg,var(--color-components-main-nav-glass-edge-reflection-first)_0%,var(--color-components-main-nav-glass-edge-reflection-middle)_50%,var(--color-components-main-nav-glass-edge-reflection-end)_100%)_right/1px_100%_no-repeat]', - 'before:[mask-composite:exclude] before:[-webkit-mask-composite:xor] before:[-webkit-mask:linear-gradient(#000_0_0)_content-box,linear-gradient(#000_0_0)] before:[mask:linear-gradient(#000_0_0)_content-box,linear-gradient(#000_0_0)]', - 'after:pointer-events-none after:absolute after:inset-[-1px] after:rounded-[inherit] after:border after:border-components-main-nav-glass-edge-highlight-first after:shadow-[inset_0_0_8px_0_var(--color-components-main-nav-glass-inner-glow)] after:content-[\'\']', -) - -const inactiveNavItemClassName = 'system-md-medium bg-components-main-nav-nav-button-bg text-components-main-nav-nav-button-text hover:bg-components-main-nav-nav-button-bg-hover hover:text-components-main-nav-nav-button-text' - const NavIcon = ({ icon, className, @@ -46,12 +31,14 @@ const MainNavLink = ({ aria-current={activated ? 'page' : undefined} title={item.label} className={cn( - navItemClassName, - activated ? activeNavItemClassName : inactiveNavItemClassName, + 'group relative flex h-8 w-full items-center gap-2 rounded-[10px] px-2 py-1.5 outline-hidden transition-colors focus-visible:ring-2 focus-visible:ring-state-accent-solid focus-visible:ring-inset', + 'not-aria-[current=page]:bg-components-main-nav-nav-button-bg not-aria-[current=page]:system-md-medium not-aria-[current=page]:text-components-main-nav-nav-button-text not-aria-[current=page]:hover:bg-components-main-nav-nav-button-bg-hover not-aria-[current=page]:hover:text-components-main-nav-nav-button-text', + 'aria-[current=page]:main-nav-active-glass aria-[current=page]:z-1', )} > - - {item.label} + + + {item.label} ) } diff --git a/web/app/components/main-nav/index.tsx b/web/app/components/main-nav/index.tsx index fb2db5f3ff3..207e40fdc30 100644 --- a/web/app/components/main-nav/index.tsx +++ b/web/app/components/main-nav/index.tsx @@ -320,7 +320,7 @@ const MainNav = ({ : : ( <> -