mirror of https://github.com/langgenius/dify.git

commit a95cf6f8b0: merge main
.github/workflows/style.yml (vendored, 4 changed lines)
@@ -49,8 +49,8 @@ jobs:
         if: steps.changed-files.outputs.any_changed == 'true'
         run: |
           uv run --directory api ruff --version
-          uv run --directory api ruff check --diff ./
-          uv run --directory api ruff format --check --diff ./
+          uv run --directory api ruff check ./
+          uv run --directory api ruff format --check ./

       - name: Dotenv check
         if: steps.changed-files.outputs.any_changed == 'true'
@@ -4,6 +4,11 @@
 # Alternatively you can set it with `SECRET_KEY` environment variable.
 SECRET_KEY=

+# Ensure UTF-8 encoding
+LANG=en_US.UTF-8
+LC_ALL=en_US.UTF-8
+PYTHONIOENCODING=utf-8
+
 # Console API base URL
 CONSOLE_API_URL=http://localhost:5001
 CONSOLE_WEB_URL=http://localhost:3000
@@ -42,6 +42,8 @@ select = [
     "S301", # suspicious-pickle-usage, disallow use of `pickle` and its wrappers.
     "S302", # suspicious-marshal-usage, disallow use of `marshal` module
     "S311", # suspicious-non-cryptographic-random-usage
+    "G001", # don't use str format to logging messages
+    "G004", # don't use f-strings to format logging messages
 ]

 ignore = [
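For context, a minimal sketch of what the two newly enabled rules flag (the names below are illustrative, not from the commit):

    import logging

    logger = logging.getLogger(__name__)
    name, elapsed_ms = "ext_redis", 3.14

    logger.info("Loaded {} ({} ms)".format(name, elapsed_ms))  # G001: str.format inside a logging call
    logger.info(f"Loaded {name} ({elapsed_ms} ms)")            # G004: f-string inside a logging call
    logger.info("Loaded %s (%s ms)", name, elapsed_ms)         # compliant: args handed to the logger

The %s form defers string interpolation to the logging framework, which only performs it if the record is actually emitted; the bulk of the logging changes in this commit apply exactly this conversion.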
@@ -37,6 +37,11 @@ EXPOSE 5001
 # set timezone
 ENV TZ=UTC

+# Set UTF-8 locale
+ENV LANG=en_US.UTF-8
+ENV LC_ALL=en_US.UTF-8
+ENV PYTHONIOENCODING=utf-8
+
 WORKDIR /app/api

 RUN \
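A quick way to confirm the locale change takes effect inside the image (a sketch, assuming a Python 3 interpreter in the container; not part of the commit):

    import locale
    import sys

    # With LANG, LC_ALL, and PYTHONIOENCODING set as above, both of these
    # should report UTF-8 rather than an ASCII/POSIX default.
    print(sys.stdout.encoding)
    print(locale.getpreferredencoding())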
@@ -32,7 +32,7 @@ def create_app() -> DifyApp:
     initialize_extensions(app)
     end_time = time.perf_counter()
     if dify_config.DEBUG:
-        logging.info(f"Finished create_app ({round((end_time - start_time) * 1000, 2)} ms)")
+        logging.info("Finished create_app (%s ms)", round((end_time - start_time) * 1000, 2))
     return app

@@ -91,14 +91,14 @@ def initialize_extensions(app: DifyApp):
         is_enabled = ext.is_enabled() if hasattr(ext, "is_enabled") else True
         if not is_enabled:
             if dify_config.DEBUG:
-                logging.info(f"Skipped {short_name}")
+                logging.info("Skipped %s", short_name)
             continue

         start_time = time.perf_counter()
         ext.init_app(app)
         end_time = time.perf_counter()
         if dify_config.DEBUG:
-            logging.info(f"Loaded {short_name} ({round((end_time - start_time) * 1000, 2)} ms)")
+            logging.info("Loaded %s (%s ms)", short_name, round((end_time - start_time) * 1000, 2))


 def create_migrations_app():
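The lazy form also avoids paying for formatting when a record is filtered out. A small self-contained demonstration (illustrative only; note that only interpolation is deferred, the argument objects themselves are still constructed):

    import logging

    class Noisy:
        def __str__(self) -> str:
            print("__str__ ran")
            return "rendered"

    logging.basicConfig(level=logging.WARNING)
    logging.info("value: %s", Noisy())     # dropped below WARNING: __str__ never runs
    logging.warning("value: %s", Noisy())  # emitted: __str__ runs exactly once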
@@ -54,13 +54,13 @@ def reset_password(email, new_password, password_confirm):
     account = db.session.query(Account).where(Account.email == email).one_or_none()

     if not account:
-        click.echo(click.style("Account not found for email: {}".format(email), fg="red"))
+        click.echo(click.style(f"Account not found for email: {email}", fg="red"))
         return

     try:
         valid_password(new_password)
     except:
-        click.echo(click.style("Invalid password. Must match {}".format(password_pattern), fg="red"))
+        click.echo(click.style(f"Invalid password. Must match {password_pattern}", fg="red"))
         return

     # generate password salt
@@ -93,13 +93,13 @@ def reset_email(email, new_email, email_confirm):
     account = db.session.query(Account).where(Account.email == email).one_or_none()

     if not account:
-        click.echo(click.style("Account not found for email: {}".format(email), fg="red"))
+        click.echo(click.style(f"Account not found for email: {email}", fg="red"))
         return

     try:
         email_validate(new_email)
     except:
-        click.echo(click.style("Invalid email: {}".format(new_email), fg="red"))
+        click.echo(click.style(f"Invalid email: {new_email}", fg="red"))
         return

     account.email = new_email
@@ -143,7 +143,7 @@ def reset_encrypt_key_pair():

     click.echo(
         click.style(
-            "Congratulations! The asymmetric key pair of workspace {} has been reset.".format(tenant.id),
+            f"Congratulations! The asymmetric key pair of workspace {tenant.id} has been reset.",
             fg="green",
         )
     )
@@ -191,14 +191,14 @@ def migrate_annotation_vector_database():
                 f"Processing the {total_count} app {app.id}. " + f"{create_count} created, {skipped_count} skipped."
             )
             try:
-                click.echo("Creating app annotation index: {}".format(app.id))
+                click.echo(f"Creating app annotation index: {app.id}")
                 app_annotation_setting = (
                     db.session.query(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app.id).first()
                 )

                 if not app_annotation_setting:
                     skipped_count = skipped_count + 1
-                    click.echo("App annotation setting disabled: {}".format(app.id))
+                    click.echo(f"App annotation setting disabled: {app.id}")
                     continue
                 # get dataset_collection_binding info
                 dataset_collection_binding = (
@@ -207,7 +207,7 @@ def migrate_annotation_vector_database():
                     .first()
                 )
                 if not dataset_collection_binding:
-                    click.echo("App annotation collection binding not found: {}".format(app.id))
+                    click.echo(f"App annotation collection binding not found: {app.id}")
                     continue
                 annotations = db.session.query(MessageAnnotation).where(MessageAnnotation.app_id == app.id).all()
                 dataset = Dataset(
@@ -253,9 +253,7 @@ def migrate_annotation_vector_database():
                 create_count += 1
             except Exception as e:
                 click.echo(
-                    click.style(
-                        "Error creating app annotation index: {} {}".format(e.__class__.__name__, str(e)), fg="red"
-                    )
+                    click.style(f"Error creating app annotation index: {e.__class__.__name__} {str(e)}", fg="red")
                 )
                 continue

@@ -320,7 +318,7 @@ def migrate_knowledge_vector_database():
                 f"Processing the {total_count} dataset {dataset.id}. {create_count} created, {skipped_count} skipped."
             )
             try:
-                click.echo("Creating dataset vector database index: {}".format(dataset.id))
+                click.echo(f"Creating dataset vector database index: {dataset.id}")
                 if dataset.index_struct_dict:
                     if dataset.index_struct_dict["type"] == vector_type:
                         skipped_count = skipped_count + 1
@@ -424,9 +422,7 @@ def migrate_knowledge_vector_database():
                 create_count += 1
             except Exception as e:
                 db.session.rollback()
-                click.echo(
-                    click.style("Error creating dataset index: {} {}".format(e.__class__.__name__, str(e)), fg="red")
-                )
+                click.echo(click.style(f"Error creating dataset index: {e.__class__.__name__} {str(e)}", fg="red"))
                 continue

         click.echo(
@@ -477,7 +473,7 @@ def convert_to_agent_apps():
             break

         for app in apps:
-            click.echo("Converting app: {}".format(app.id))
+            click.echo(f"Converting app: {app.id}")

             try:
                 app.mode = AppMode.AGENT_CHAT.value
@@ -489,11 +485,11 @@ def convert_to_agent_apps():
                 )

                 db.session.commit()
-                click.echo(click.style("Converted app: {}".format(app.id), fg="green"))
+                click.echo(click.style(f"Converted app: {app.id}", fg="green"))
             except Exception as e:
-                click.echo(click.style("Convert app error: {} {}".format(e.__class__.__name__, str(e)), fg="red"))
+                click.echo(click.style(f"Convert app error: {e.__class__.__name__} {str(e)}", fg="red"))

-    click.echo(click.style("Conversion complete. Converted {} agent apps.".format(len(proceeded_app_ids)), fg="green"))
+    click.echo(click.style(f"Conversion complete. Converted {len(proceeded_app_ids)} agent apps.", fg="green"))


 @click.command("add-qdrant-index", help="Add Qdrant index.")
@@ -666,7 +662,7 @@ def create_tenant(email: str, language: Optional[str] = None, name: Optional[str

     click.echo(
         click.style(
-            "Account and tenant created.\nAccount: {}\nPassword: {}".format(email, new_password),
+            f"Account and tenant created.\nAccount: {email}\nPassword: {new_password}",
             fg="green",
         )
     )
@@ -727,16 +723,16 @@ where sites.id is null limit 1000"""
             if tenant:
                 accounts = tenant.get_accounts()
                 if not accounts:
-                    print("Fix failed for app {}".format(app.id))
+                    print(f"Fix failed for app {app.id}")
                     continue

                 account = accounts[0]
-                print("Fixing missing site for app {}".format(app.id))
+                print(f"Fixing missing site for app {app.id}")
                 app_was_created.send(app, account=account)
         except Exception:
             failed_app_ids.append(app_id)
-            click.echo(click.style("Failed to fix missing site for app {}".format(app_id), fg="red"))
-            logging.exception(f"Failed to fix app related site missing issue, app_id: {app_id}")
+            click.echo(click.style(f"Failed to fix missing site for app {app_id}", fg="red"))
+            logging.exception("Failed to fix app related site missing issue, app_id: %s", app_id)
             continue

         if not processed_count:
@@ -1218,7 +1214,7 @@ def setup_datasource_oauth_client(provider, client_params):
     provider_id = DatasourceProviderID(provider)
     provider_name = provider_id.provider_name
     plugin_id = provider_id.plugin_id

     try:
         # json validate
         click.echo(click.style(f"Validating client params: {client_params}", fg="yellow"))
@@ -1227,7 +1223,7 @@ def setup_datasource_oauth_client(provider, client_params):
     except Exception as e:
         click.echo(click.style(f"Error parsing client params: {str(e)}", fg="red"))
         return

     click.echo(click.style(f"Ready to delete existing oauth client params: {provider_name}", fg="yellow"))
     deleted_count = (
         db.session.query(DatasourceOauthParamConfig)
@@ -41,7 +41,7 @@ class RemoteSettingsSourceFactory(PydanticBaseSettingsSource):
             case RemoteSettingsSourceName.NACOS:
                 remote_source = NacosSettingsSource(current_state)
             case _:
-                logger.warning(f"Unsupported remote source: {remote_source_name}")
+                logger.warning("Unsupported remote source: %s", remote_source_name)
                 return {}

         d: dict[str, Any] = {}
@@ -245,11 +245,7 @@ class CeleryConfig(DatabaseConfig):

     @computed_field
     def CELERY_RESULT_BACKEND(self) -> str | None:
-        return (
-            "db+{}".format(self.SQLALCHEMY_DATABASE_URI)
-            if self.CELERY_BACKEND == "database"
-            else self.CELERY_BROKER_URL
-        )
+        return f"db+{self.SQLALCHEMY_DATABASE_URI}" if self.CELERY_BACKEND == "database" else self.CELERY_BROKER_URL

     @property
     def BROKER_USE_SSL(self) -> bool:
@@ -76,7 +76,7 @@ class ApolloClient:
         code, body = http_request(url, timeout=3, headers=self._sign_headers(url))
         if code == 200:
             if not body:
-                logger.error(f"get_json_from_net load configs failed, body is {body}")
+                logger.error("get_json_from_net load configs failed, body is %s", body)
                 return None
             data = json.loads(body)
             data = data["configurations"]
@@ -207,7 +207,7 @@ class ApolloClient:
         # if the length is 0 it is returned directly
         if len(notifications) == 0:
             return
-        url = "{}/notifications/v2".format(self.config_url)
+        url = f"{self.config_url}/notifications/v2"
         params = {
             "appId": self.app_id,
             "cluster": self.cluster,
@@ -222,7 +222,7 @@ class ApolloClient:
             return
         if http_code == 200:
             if not body:
-                logger.error(f"_long_poll load configs failed,body is {body}")
+                logger.error("_long_poll load configs failed,body is %s", body)
                 return
             data = json.loads(body)
             for entry in data:
@@ -273,12 +273,12 @@ class ApolloClient:
             time.sleep(60 * 10) # 10 minutes

     def _do_heart_beat(self, namespace):
-        url = "{}/configs/{}/{}/{}?ip={}".format(self.config_url, self.app_id, self.cluster, namespace, self.ip)
+        url = f"{self.config_url}/configs/{self.app_id}/{self.cluster}/{namespace}?ip={self.ip}"
         try:
             code, body = http_request(url, timeout=3, headers=self._sign_headers(url))
             if code == 200:
                 if not body:
-                    logger.error(f"_do_heart_beat load configs failed,body is {body}")
+                    logger.error("_do_heart_beat load configs failed,body is %s", body)
                     return None
                 data = json.loads(body)
                 if self.last_release_key == data["releaseKey"]:
@@ -24,7 +24,7 @@ def url_encode_wrapper(params):


 def no_key_cache_key(namespace, key):
-    return "{}{}{}".format(namespace, len(namespace), key)
+    return f"{namespace}{len(namespace)}{key}"


 # Returns whether the obtained value is obtained, and None if it does not
@@ -28,5 +28,5 @@ def supported_language(lang):
     if lang in languages:
         return lang

-    error = "{lang} is not a valid language.".format(lang=lang)
+    error = f"{lang} is not a valid language."
     raise ValueError(error)
@@ -77,10 +77,10 @@ from .billing import billing, compliance

 # Import datasets controllers
 from .datasets import (
+    data_source,
     datasets,
     datasets_document,
     datasets_segments,
-    data_source,
     external,
     hit_testing,
     metadata,
@@ -86,7 +86,7 @@ class AnnotationReplyActionStatusApi(Resource):
             raise Forbidden()

         job_id = str(job_id)
-        app_annotation_job_key = "{}_app_annotation_job_{}".format(action, str(job_id))
+        app_annotation_job_key = f"{action}_app_annotation_job_{str(job_id)}"
         cache_result = redis_client.get(app_annotation_job_key)
         if cache_result is None:
             raise ValueError("The job does not exist.")
@@ -94,7 +94,7 @@ class AnnotationReplyActionStatusApi(Resource):
         job_status = cache_result.decode()
         error_msg = ""
         if job_status == "error":
-            app_annotation_error_key = "{}_app_annotation_error_{}".format(action, str(job_id))
+            app_annotation_error_key = f"{action}_app_annotation_error_{str(job_id)}"
             error_msg = redis_client.get(app_annotation_error_key).decode()

         return {"job_id": job_id, "job_status": job_status, "error_msg": error_msg}, 200
@@ -123,6 +123,17 @@ class AnnotationListApi(Resource):
         }
         return response, 200

+    @setup_required
+    @login_required
+    @account_initialization_required
+    def delete(self, app_id):
+        if not current_user.is_editor:
+            raise Forbidden()
+
+        app_id = str(app_id)
+        AppAnnotationService.clear_all_annotations(app_id)
+        return {"result": "success"}, 204
+

 class AnnotationExportApi(Resource):
     @setup_required
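A hedged sketch of calling the new clear-all endpoint (the route path, host, and auth scheme are assumptions; the diff shows only the handler body):

    import requests

    app_id = "your-app-id"        # hypothetical placeholder
    token = "your-console-token"  # hypothetical placeholder

    # DELETE clears every annotation for the app; 204 means success with no body.
    resp = requests.delete(
        f"http://localhost:5001/console/api/apps/{app_id}/annotations",  # hypothetical path
        headers={"Authorization": f"Bearer {token}"},
    )
    assert resp.status_code == 204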
@@ -223,14 +234,14 @@ class AnnotationBatchImportStatusApi(Resource):
             raise Forbidden()

         job_id = str(job_id)
-        indexing_cache_key = "app_annotation_batch_import_{}".format(str(job_id))
+        indexing_cache_key = f"app_annotation_batch_import_{str(job_id)}"
         cache_result = redis_client.get(indexing_cache_key)
         if cache_result is None:
             raise ValueError("The job does not exist.")
         job_status = cache_result.decode()
         error_msg = ""
         if job_status == "error":
-            indexing_error_msg_key = "app_annotation_batch_import_error_msg_{}".format(str(job_id))
+            indexing_error_msg_key = f"app_annotation_batch_import_error_msg_{str(job_id)}"
             error_msg = redis_client.get(indexing_error_msg_key).decode()

         return {"job_id": job_id, "job_status": job_status, "error_msg": error_msg}, 200
@@ -51,8 +51,8 @@ class CompletionConversationApi(Resource):
         if args["keyword"]:
             query = query.join(Message, Message.conversation_id == Conversation.id).where(
                 or_(
-                    Message.query.ilike("%{}%".format(args["keyword"])),
-                    Message.answer.ilike("%{}%".format(args["keyword"])),
+                    Message.query.ilike(f"%{args['keyword']}%"),
+                    Message.answer.ilike(f"%{args['keyword']}%"),
                 )
             )

@@ -174,7 +174,7 @@ class ChatConversationApi(Resource):
         query = db.select(Conversation).where(Conversation.app_id == app_model.id)

         if args["keyword"]:
-            keyword_filter = "%{}%".format(args["keyword"])
+            keyword_filter = f"%{args['keyword']}%"
             query = (
                 query.join(
                     Message,
@@ -1,5 +1,3 @@
-import os
-
 from flask_login import current_user
 from flask_restful import Resource, reqparse

@@ -29,15 +27,12 @@ class RuleGenerateApi(Resource):
         args = parser.parse_args()

         account = current_user
-        PROMPT_GENERATION_MAX_TOKENS = int(os.getenv("PROMPT_GENERATION_MAX_TOKENS", "512"))
-
         try:
             rules = LLMGenerator.generate_rule_config(
                 tenant_id=account.current_tenant_id,
                 instruction=args["instruction"],
                 model_config=args["model_config"],
                 no_variable=args["no_variable"],
-                rule_config_max_tokens=PROMPT_GENERATION_MAX_TOKENS,
             )
         except ProviderTokenNotInitError as ex:
             raise ProviderNotInitializeError(ex.description)
@@ -64,14 +59,12 @@ class RuleCodeGenerateApi(Resource):
         args = parser.parse_args()

         account = current_user
-        CODE_GENERATION_MAX_TOKENS = int(os.getenv("CODE_GENERATION_MAX_TOKENS", "1024"))
         try:
             code_result = LLMGenerator.generate_code(
                 tenant_id=account.current_tenant_id,
                 instruction=args["instruction"],
                 model_config=args["model_config"],
                 code_language=args["code_language"],
-                max_tokens=CODE_GENERATION_MAX_TOKENS,
             )
         except ProviderTokenNotInitError as ex:
             raise ProviderNotInitializeError(ex.description)
@@ -81,7 +81,7 @@ class OAuthDataSourceBinding(Resource):
             oauth_provider.get_access_token(code)
         except requests.exceptions.HTTPError as e:
             logging.exception(
-                f"An error occurred during the OAuthCallback process with {provider}: {e.response.text}"
+                "An error occurred during the OAuthCallback process with %s: %s", provider, e.response.text
             )
             return {"error": "OAuth data source process failed"}, 400

@@ -103,7 +103,9 @@ class OAuthDataSourceSync(Resource):
         try:
             oauth_provider.sync_data_source(binding_id)
         except requests.exceptions.HTTPError as e:
-            logging.exception(f"An error occurred during the OAuthCallback process with {provider}: {e.response.text}")
+            logging.exception(
+                "An error occurred during the OAuthCallback process with %s: %s", provider, e.response.text
+            )
             return {"error": "OAuth data source process failed"}, 400

         return {"result": "success"}, 200
@@ -80,7 +80,7 @@ class OAuthCallback(Resource):
             user_info = oauth_provider.get_user_info(token)
         except requests.exceptions.RequestException as e:
             error_text = e.response.text if e.response else str(e)
-            logging.exception(f"An error occurred during the OAuth process with {provider}: {error_text}")
+            logging.exception("An error occurred during the OAuth process with %s: %s", provider, error_text)
             return {"error": "OAuth process failed"}, 400

         if invite_token and RegisterService.is_valid_invite_token(invite_token):
@@ -972,7 +972,7 @@ class DocumentRetryApi(DocumentResource):
                     raise DocumentAlreadyFinishedError()
                 retry_documents.append(document)
             except Exception:
-                logging.exception(f"Failed to retry document, document id: {document_id}")
+                logging.exception("Failed to retry document, document id: %s", document_id)
                 continue
         # retry document
         DocumentService.retry_document(dataset_id, retry_documents)
@@ -1,6 +1,5 @@
 import uuid

-import pandas as pd
 from flask import request
 from flask_login import current_user
 from flask_restful import Resource, marshal, reqparse
@@ -14,8 +13,6 @@ from controllers.console.datasets.error import (
     ChildChunkDeleteIndexError,
     ChildChunkIndexingError,
     InvalidActionError,
-    NoFileUploadedError,
-    TooManyFilesError,
 )
 from controllers.console.wraps import (
     account_initialization_required,
@@ -32,6 +29,7 @@ from extensions.ext_redis import redis_client
 from fields.segment_fields import child_chunk_fields, segment_fields
 from libs.login import login_required
 from models.dataset import ChildChunk, DocumentSegment
+from models.model import UploadFile
 from services.dataset_service import DatasetService, DocumentService, SegmentService
 from services.entities.knowledge_entities.knowledge_entities import ChildChunkUpdateArgs, SegmentUpdateArgs
 from services.errors.chunk import ChildChunkDeleteIndexError as ChildChunkDeleteIndexServiceError
@@ -184,7 +182,7 @@ class DatasetDocumentSegmentApi(Resource):
             raise ProviderNotInitializeError(ex.description)
         segment_ids = request.args.getlist("segment_id")

-        document_indexing_cache_key = "document_{}_indexing".format(document.id)
+        document_indexing_cache_key = f"document_{document.id}_indexing"
         cache_result = redis_client.get(document_indexing_cache_key)
         if cache_result is not None:
             raise InvalidActionError("Document is being indexed, please try again later")
@@ -365,37 +363,28 @@ class DatasetDocumentSegmentBatchImportApi(Resource):
         document = DocumentService.get_document(dataset_id, document_id)
         if not document:
             raise NotFound("Document not found.")
-        # get file from request
-        file = request.files["file"]
-        # check file
-        if "file" not in request.files:
-            raise NoFileUploadedError()
-
-        if len(request.files) > 1:
-            raise TooManyFilesError()
+        parser = reqparse.RequestParser()
+        parser.add_argument("upload_file_id", type=str, required=True, nullable=False, location="json")
+        args = parser.parse_args()
+        upload_file_id = args["upload_file_id"]
+
+        upload_file = db.session.query(UploadFile).where(UploadFile.id == upload_file_id).first()
+        if not upload_file:
+            raise NotFound("UploadFile not found.")
+
         # check file type
-        if not file.filename or not file.filename.lower().endswith(".csv"):
+        if not upload_file.name or not upload_file.name.lower().endswith(".csv"):
             raise ValueError("Invalid file type. Only CSV files are allowed")

         try:
-            # Skip the first row
-            df = pd.read_csv(file)
-            result = []
-            for index, row in df.iterrows():
-                if document.doc_form == "qa_model":
-                    data = {"content": row.iloc[0], "answer": row.iloc[1]}
-                else:
-                    data = {"content": row.iloc[0]}
-                result.append(data)
-            if len(result) == 0:
-                raise ValueError("The CSV file is empty.")
             # async job
             job_id = str(uuid.uuid4())
-            indexing_cache_key = "segment_batch_import_{}".format(str(job_id))
+            indexing_cache_key = f"segment_batch_import_{str(job_id)}"
             # send batch add segments task
             redis_client.setnx(indexing_cache_key, "waiting")
             batch_create_segment_to_index_task.delay(
-                str(job_id), result, dataset_id, document_id, current_user.current_tenant_id, current_user.id
+                str(job_id), upload_file_id, dataset_id, document_id, current_user.current_tenant_id, current_user.id
             )
         except Exception as e:
             return {"error": str(e)}, 500
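The handler now expects a JSON body referencing a previously uploaded file instead of a multipart upload, and CSV parsing moves into the async task. A client-side sketch under that assumption (base URL, token, and path are hypothetical placeholders; the diff shows only the handler):

    import requests

    base = "http://localhost:5001/console/api"  # hypothetical
    token = "your-console-token"                # hypothetical
    dataset_id, document_id = "your-dataset-id", "your-document-id"
    upload_file_id = "id-returned-by-the-file-upload-api"

    resp = requests.post(
        f"{base}/datasets/{dataset_id}/documents/{document_id}/segments/batch_import",  # hypothetical path
        headers={"Authorization": f"Bearer {token}"},
        json={"upload_file_id": upload_file_id},  # replaces the old multipart "file" field
    )
    print(resp.json())  # contains the job_id to poll on the status endpoint below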
@@ -406,7 +395,7 @@ class DatasetDocumentSegmentBatchImportApi(Resource):
     @account_initialization_required
     def get(self, job_id):
         job_id = str(job_id)
-        indexing_cache_key = "segment_batch_import_{}".format(job_id)
+        indexing_cache_key = f"segment_batch_import_{job_id}"
         cache_result = redis_client.get(indexing_cache_key)
         if cache_result is None:
             raise ValueError("The job does not exist.")
@@ -74,7 +74,7 @@ class InstalledAppsListApi(Resource):
         ):
             res.append(installed_app)
         installed_app_list = res
-        logger.debug(f"installed_app_list: {installed_app_list}, user_id: {user_id}")
+        logger.debug("installed_app_list: %s, user_id: %s", installed_app_list, user_id)

         installed_app_list.sort(
             key=lambda app: (
@@ -34,7 +34,7 @@ class VersionApi(Resource):
         try:
             response = requests.get(check_update_url, {"current_version": args.get("current_version")})
         except Exception as error:
-            logging.warning("Check update version error: {}.".format(str(error)))
+            logging.warning("Check update version error: %s.", str(error))
             result["version"] = args.get("current_version")
             return result

@@ -55,7 +55,7 @@ def _has_new_version(*, latest_version: str, current_version: str) -> bool:
         # Compare versions
         return latest > current
     except version.InvalidVersion:
-        logging.warning(f"Invalid version format: latest={latest_version}, current={current_version}")
+        logging.warning("Invalid version format: latest=%s, current=%s", latest_version, current_version)
         return False

@@ -511,6 +511,8 @@ class CheckEmailUnique(Resource):
         parser = reqparse.RequestParser()
         parser.add_argument("email", type=email, required=True, location="json")
         args = parser.parse_args()
+        if AccountService.is_account_in_freeze(args["email"]):
+            raise AccountInFreezeError()
         if not AccountService.check_email_unique(args["email"]):
             raise EmailAlreadyInUseError()
         return {"result": "success"}
@@ -73,8 +73,9 @@ class DefaultModelApi(Resource):
             )
         except Exception as ex:
             logging.exception(
-                f"Failed to update default model, model type: {model_setting['model_type']},"
-                f" model:{model_setting.get('model')}"
+                "Failed to update default model, model type: %s, model: %s",
+                model_setting["model_type"],
+                model_setting.get("model"),
             )
             raise ex

@@ -160,8 +161,10 @@ class ModelProviderModelApi(Resource):
             )
         except CredentialsValidateFailedError as ex:
             logging.exception(
-                f"Failed to save model credentials, tenant_id: {tenant_id},"
-                f" model: {args.get('model')}, model_type: {args.get('model_type')}"
+                "Failed to save model credentials, tenant_id: %s, model: %s, model_type: %s",
+                tenant_id,
+                args.get("model"),
+                args.get("model_type"),
             )
             raise ValueError(str(ex))

@@ -34,7 +34,7 @@ class AnnotationReplyActionStatusApi(Resource):
     @validate_app_token
     def get(self, app_model: App, job_id, action):
         job_id = str(job_id)
-        app_annotation_job_key = "{}_app_annotation_job_{}".format(action, str(job_id))
+        app_annotation_job_key = f"{action}_app_annotation_job_{str(job_id)}"
         cache_result = redis_client.get(app_annotation_job_key)
         if cache_result is None:
             raise ValueError("The job does not exist.")
@@ -42,7 +42,7 @@ class AnnotationReplyActionStatusApi(Resource):
         job_status = cache_result.decode()
         error_msg = ""
         if job_status == "error":
-            app_annotation_error_key = "{}_app_annotation_error_{}".format(action, str(job_id))
+            app_annotation_error_key = f"{action}_app_annotation_error_{str(job_id)}"
             error_msg = redis_client.get(app_annotation_error_key).decode()

         return {"job_id": job_id, "job_status": job_status, "error_msg": error_msg}, 200
@@ -280,7 +280,7 @@ class BaseAgentRunner(AppRunner):

     def create_agent_thought(
         self, message_id: str, message: str, tool_name: str, tool_input: str, messages_ids: list[str]
-    ) -> MessageAgentThought:
+    ) -> str:
         """
         Create agent thought
         """
@@ -313,16 +313,15 @@ class BaseAgentRunner(AppRunner):

         db.session.add(thought)
         db.session.commit()
-        db.session.refresh(thought)
+        agent_thought_id = str(thought.id)
+        self.agent_thought_count += 1
         db.session.close()

-        self.agent_thought_count += 1
-
-        return thought
+        return agent_thought_id

     def save_agent_thought(
         self,
-        agent_thought: MessageAgentThought,
+        agent_thought_id: str,
         tool_name: str | None,
         tool_input: Union[str, dict, None],
         thought: str | None,
@@ -335,12 +334,9 @@ class BaseAgentRunner(AppRunner):
         """
         Save agent thought
         """
-        updated_agent_thought = (
-            db.session.query(MessageAgentThought).where(MessageAgentThought.id == agent_thought.id).first()
-        )
-        if not updated_agent_thought:
+        agent_thought = db.session.query(MessageAgentThought).where(MessageAgentThought.id == agent_thought_id).first()
+        if not agent_thought:
             raise ValueError("agent thought not found")
-        agent_thought = updated_agent_thought

         if thought:
             agent_thought.thought += thought
@@ -355,7 +351,7 @@ class BaseAgentRunner(AppRunner):
             except Exception:
                 tool_input = json.dumps(tool_input)

-            updated_agent_thought.tool_input = tool_input
+            agent_thought.tool_input = tool_input

         if observation:
             if isinstance(observation, dict):
@@ -364,27 +360,27 @@ class BaseAgentRunner(AppRunner):
                 except Exception:
                     observation = json.dumps(observation)

-            updated_agent_thought.observation = observation
+            agent_thought.observation = observation

         if answer:
             agent_thought.answer = answer

         if messages_ids is not None and len(messages_ids) > 0:
-            updated_agent_thought.message_files = json.dumps(messages_ids)
+            agent_thought.message_files = json.dumps(messages_ids)

         if llm_usage:
-            updated_agent_thought.message_token = llm_usage.prompt_tokens
-            updated_agent_thought.message_price_unit = llm_usage.prompt_price_unit
-            updated_agent_thought.message_unit_price = llm_usage.prompt_unit_price
-            updated_agent_thought.answer_token = llm_usage.completion_tokens
-            updated_agent_thought.answer_price_unit = llm_usage.completion_price_unit
-            updated_agent_thought.answer_unit_price = llm_usage.completion_unit_price
-            updated_agent_thought.tokens = llm_usage.total_tokens
-            updated_agent_thought.total_price = llm_usage.total_price
+            agent_thought.message_token = llm_usage.prompt_tokens
+            agent_thought.message_price_unit = llm_usage.prompt_price_unit
+            agent_thought.message_unit_price = llm_usage.prompt_unit_price
+            agent_thought.answer_token = llm_usage.completion_tokens
+            agent_thought.answer_price_unit = llm_usage.completion_price_unit
+            agent_thought.answer_unit_price = llm_usage.completion_unit_price
+            agent_thought.tokens = llm_usage.total_tokens
+            agent_thought.total_price = llm_usage.total_price

         # check if tool labels is not empty
-        labels = updated_agent_thought.tool_labels or {}
-        tools = updated_agent_thought.tool.split(";") if updated_agent_thought.tool else []
+        labels = agent_thought.tool_labels or {}
+        tools = agent_thought.tool.split(";") if agent_thought.tool else []
         for tool in tools:
             if not tool:
                 continue
@@ -395,7 +391,7 @@ class BaseAgentRunner(AppRunner):
             else:
                 labels[tool] = {"en_US": tool, "zh_Hans": tool}

-        updated_agent_thought.tool_labels_str = json.dumps(labels)
+        agent_thought.tool_labels_str = json.dumps(labels)

         if tool_invoke_meta is not None:
             if isinstance(tool_invoke_meta, dict):
@@ -404,7 +400,7 @@ class BaseAgentRunner(AppRunner):
             except Exception:
                 tool_invoke_meta = json.dumps(tool_invoke_meta)

-            updated_agent_thought.tool_meta_str = tool_invoke_meta
+            agent_thought.tool_meta_str = tool_invoke_meta

         db.session.commit()
         db.session.close()
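The refactor above stops passing a live MessageAgentThought ORM instance between create_agent_thought and save_agent_thought and passes its primary key instead, re-fetching the row in whichever session is current. The commit does not state its motivation, but a common reason is that an object from a closed session is detached and unsafe to reuse. A minimal self-contained sketch of the pattern (model and names are illustrative, not dify's):

    import uuid

    from sqlalchemy import String, create_engine
    from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, sessionmaker

    class Base(DeclarativeBase):
        pass

    class Thought(Base):
        __tablename__ = "thoughts"
        id: Mapped[str] = mapped_column(String(36), primary_key=True, default=lambda: str(uuid.uuid4()))
        text: Mapped[str] = mapped_column(String(255), default="")

    engine = create_engine("sqlite://")
    Base.metadata.create_all(engine)
    SessionLocal = sessionmaker(bind=engine)

    def create_thought() -> str:
        with SessionLocal() as session:
            thought = Thought()
            session.add(thought)
            session.commit()
            return str(thought.id)  # a plain string outlives the session; the ORM object would be detached

    def save_thought(thought_id: str, text: str) -> None:
        with SessionLocal() as session:
            thought = session.get(Thought, thought_id)  # re-fetch in the current session
            if thought is None:
                raise ValueError("thought not found")
            thought.text += text
            session.commit()

    tid = create_thought()
    save_thought(tid, "hello")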
@@ -97,13 +97,13 @@ class CotAgentRunner(BaseAgentRunner, ABC):

             message_file_ids: list[str] = []

-            agent_thought = self.create_agent_thought(
+            agent_thought_id = self.create_agent_thought(
                 message_id=message.id, message="", tool_name="", tool_input="", messages_ids=message_file_ids
             )

             if iteration_step > 1:
                 self.queue_manager.publish(
-                    QueueAgentThoughtEvent(agent_thought_id=agent_thought.id), PublishFrom.APPLICATION_MANAGER
+                    QueueAgentThoughtEvent(agent_thought_id=agent_thought_id), PublishFrom.APPLICATION_MANAGER
                 )

             # recalc llm max tokens
@@ -133,7 +133,7 @@ class CotAgentRunner(BaseAgentRunner, ABC):
             # publish agent thought if it's first iteration
             if iteration_step == 1:
                 self.queue_manager.publish(
-                    QueueAgentThoughtEvent(agent_thought_id=agent_thought.id), PublishFrom.APPLICATION_MANAGER
+                    QueueAgentThoughtEvent(agent_thought_id=agent_thought_id), PublishFrom.APPLICATION_MANAGER
                 )

             for chunk in react_chunks:
@@ -168,7 +168,7 @@ class CotAgentRunner(BaseAgentRunner, ABC):
                 usage_dict["usage"] = LLMUsage.empty_usage()

             self.save_agent_thought(
-                agent_thought=agent_thought,
+                agent_thought_id=agent_thought_id,
                 tool_name=(scratchpad.action.action_name if scratchpad.action and not scratchpad.is_final() else ""),
                 tool_input={scratchpad.action.action_name: scratchpad.action.action_input} if scratchpad.action else {},
                 tool_invoke_meta={},
@@ -181,7 +181,7 @@ class CotAgentRunner(BaseAgentRunner, ABC):

             if not scratchpad.is_final():
                 self.queue_manager.publish(
-                    QueueAgentThoughtEvent(agent_thought_id=agent_thought.id), PublishFrom.APPLICATION_MANAGER
+                    QueueAgentThoughtEvent(agent_thought_id=agent_thought_id), PublishFrom.APPLICATION_MANAGER
                 )

             if not scratchpad.action:
@@ -212,7 +212,7 @@ class CotAgentRunner(BaseAgentRunner, ABC):
                 scratchpad.agent_response = tool_invoke_response

                 self.save_agent_thought(
-                    agent_thought=agent_thought,
+                    agent_thought_id=agent_thought_id,
                     tool_name=scratchpad.action.action_name,
                     tool_input={scratchpad.action.action_name: scratchpad.action.action_input},
                     thought=scratchpad.thought or "",
@@ -224,7 +224,7 @@ class CotAgentRunner(BaseAgentRunner, ABC):
                 )

                 self.queue_manager.publish(
-                    QueueAgentThoughtEvent(agent_thought_id=agent_thought.id), PublishFrom.APPLICATION_MANAGER
+                    QueueAgentThoughtEvent(agent_thought_id=agent_thought_id), PublishFrom.APPLICATION_MANAGER
                 )

             # update prompt tool message
@@ -244,7 +244,7 @@ class CotAgentRunner(BaseAgentRunner, ABC):

         # save agent thought
         self.save_agent_thought(
-            agent_thought=agent_thought,
+            agent_thought_id=agent_thought_id,
             tool_name="",
             tool_input={},
             tool_invoke_meta={},
@@ -80,7 +80,7 @@ class FunctionCallAgentRunner(BaseAgentRunner):
         prompt_messages_tools = []

         message_file_ids: list[str] = []
-        agent_thought = self.create_agent_thought(
+        agent_thought_id = self.create_agent_thought(
             message_id=message.id, message="", tool_name="", tool_input="", messages_ids=message_file_ids
         )

@@ -114,7 +114,7 @@ class FunctionCallAgentRunner(BaseAgentRunner):
             for chunk in chunks:
                 if is_first_chunk:
                     self.queue_manager.publish(
-                        QueueAgentThoughtEvent(agent_thought_id=agent_thought.id), PublishFrom.APPLICATION_MANAGER
+                        QueueAgentThoughtEvent(agent_thought_id=agent_thought_id), PublishFrom.APPLICATION_MANAGER
                     )
                     is_first_chunk = False
                 # check if there is any tool call
@@ -172,7 +172,7 @@ class FunctionCallAgentRunner(BaseAgentRunner):
                 result.message.content = ""

             self.queue_manager.publish(
-                QueueAgentThoughtEvent(agent_thought_id=agent_thought.id), PublishFrom.APPLICATION_MANAGER
+                QueueAgentThoughtEvent(agent_thought_id=agent_thought_id), PublishFrom.APPLICATION_MANAGER
             )

             yield LLMResultChunk(
@@ -205,7 +205,7 @@ class FunctionCallAgentRunner(BaseAgentRunner):

             # save thought
             self.save_agent_thought(
-                agent_thought=agent_thought,
+                agent_thought_id=agent_thought_id,
                 tool_name=tool_call_names,
                 tool_input=tool_call_inputs,
                 thought=response,
@@ -216,7 +216,7 @@ class FunctionCallAgentRunner(BaseAgentRunner):
                 llm_usage=current_llm_usage,
             )
             self.queue_manager.publish(
-                QueueAgentThoughtEvent(agent_thought_id=agent_thought.id), PublishFrom.APPLICATION_MANAGER
+                QueueAgentThoughtEvent(agent_thought_id=agent_thought_id), PublishFrom.APPLICATION_MANAGER
             )

             final_answer += response + "\n"
@@ -276,7 +276,7 @@ class FunctionCallAgentRunner(BaseAgentRunner):
             if len(tool_responses) > 0:
                 # save agent thought
                 self.save_agent_thought(
-                    agent_thought=agent_thought,
+                    agent_thought_id=agent_thought_id,
                     tool_name="",
                     tool_input="",
                     thought="",
@@ -291,7 +291,7 @@ class FunctionCallAgentRunner(BaseAgentRunner):
                     messages_ids=message_file_ids,
                 )
                 self.queue_manager.publish(
-                    QueueAgentThoughtEvent(agent_thought_id=agent_thought.id), PublishFrom.APPLICATION_MANAGER
+                    QueueAgentThoughtEvent(agent_thought_id=agent_thought_id), PublishFrom.APPLICATION_MANAGER
                 )

                 # update prompt tool
@@ -600,5 +600,5 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
             if len(e.args) > 0 and e.args[0] == "I/O operation on closed file.": # ignore this error
                 raise GenerateTaskStoppedError()
             else:
-                logger.exception(f"Failed to process generate task pipeline, conversation_id: {conversation.id}")
+                logger.exception("Failed to process generate task pipeline, conversation_id: %s", conversation.id)
                 raise e
@@ -271,7 +271,7 @@ class AdvancedChatAppGenerateTaskPipeline:
                     start_listener_time = time.time()
                     yield MessageAudioStreamResponse(audio=audio_trunk.audio, task_id=task_id)
             except Exception:
-                logger.exception(f"Failed to listen audio message, task_id: {task_id}")
+                logger.exception("Failed to listen audio message, task_id: %s", task_id)
                 break
         if tts_publisher:
             yield MessageAudioEndStreamResponse(audio="", task_id=task_id)
@@ -78,7 +78,7 @@ class MessageBasedAppGenerator(BaseAppGenerator):
             if len(e.args) > 0 and e.args[0] == "I/O operation on closed file.": # ignore this error
                 raise GenerateTaskStoppedError()
             else:
-                logger.exception(f"Failed to handle response, conversation_id: {conversation.id}")
+                logger.exception("Failed to handle response, conversation_id: %s", conversation.id)
                 raise e

     def _get_app_model_config(self, app_model: App, conversation: Optional[Conversation] = None) -> AppModelConfig:
@@ -562,7 +562,9 @@ class PipelineGenerator(BaseAppGenerator):

             if is_external_api_call:
                 # For external API calls, use end user's session ID
-                end_user = session.scalar(select(EndUser).where(EndUser.id == application_generate_entity.user_id))
+                end_user = session.scalar(
+                    select(EndUser).where(EndUser.id == application_generate_entity.user_id)
+                )
                 system_user_id = end_user.session_id if end_user else ""
             else:
                 # For internal calls, use the original user ID
@@ -483,7 +483,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
         try:
             runner.run()
         except GenerateTaskStoppedError as e:
-            logger.warning(f"Task stopped: {str(e)}")
+            logger.warning("Task stopped: %s", str(e))
             pass
         except InvokeAuthorizationError:
             queue_manager.publish_error(
@ -540,6 +540,6 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
|||||||
raise GenerateTaskStoppedError()
|
raise GenerateTaskStoppedError()
|
||||||
else:
|
else:
|
||||||
logger.exception(
|
logger.exception(
|
||||||
f"Fails to process generate task pipeline, task_id: {application_generate_entity.task_id}"
|
"Fails to process generate task pipeline, task_id: %s", application_generate_entity.task_id
|
||||||
)
|
)
|
||||||
raise e
|
raise e
|
||||||
|
|||||||
@@ -246,7 +246,7 @@ class WorkflowAppGenerateTaskPipeline:
                 else:
                     yield MessageAudioStreamResponse(audio=audio_trunk.audio, task_id=task_id)
             except Exception:
-                logger.exception(f"Fails to get audio trunk, task_id: {task_id}")
+                logger.exception("Fails to get audio trunk, task_id: %s", task_id)
                 break
         if tts_publisher:
             yield MessageAudioEndStreamResponse(audio="", task_id=task_id)
@@ -83,7 +83,7 @@ class AnnotationReplyFeature:

                     return annotation
             except Exception as e:
-                logger.warning(f"Query annotation failed, exception: {str(e)}.")
+                logger.warning("Query annotation failed, exception: %s.", str(e))
                 return None

         return None
@@ -97,7 +97,7 @@ class MessageCycleManager:
                 conversation.name = name
             except Exception as e:
                 if dify_config.DEBUG:
-                    logging.exception(f"generate conversation name failed, conversation_id: {conversation_id}")
+                    logging.exception("generate conversation name failed, conversation_id: %s", conversation_id)
                 pass

             db.session.merge(conversation)
@@ -900,7 +900,7 @@ class ProviderConfiguration(BaseModel):
                         credentials=copy_credentials,
                     )
                 except Exception as ex:
-                    logger.warning(f"get custom model schema failed, {ex}")
+                    logger.warning("get custom model schema failed, %s", ex)
                     continue

                 if not custom_model_schema:
@@ -1009,7 +1009,7 @@ class ProviderConfiguration(BaseModel):
                         credentials=model_configuration.credentials,
                     )
                 except Exception as ex:
-                    logger.warning(f"get custom model schema failed, {ex}")
+                    logger.warning("get custom model schema failed, %s", ex)
                     continue

                 if not custom_model_schema:
@@ -22,7 +22,7 @@ class APIBasedExtensionRequestor:
         :param params: the request params
         :return: the response json
         """
-        headers = {"Content-Type": "application/json", "Authorization": "Bearer {}".format(self.api_key)}
+        headers = {"Content-Type": "application/json", "Authorization": f"Bearer {self.api_key}"}

         url = self.api_endpoint

@@ -49,8 +49,6 @@ class APIBasedExtensionRequestor:
             raise ValueError("request connection error")

         if response.status_code != 200:
-            raise ValueError(
-                "request error, status_code: {}, content: {}".format(response.status_code, response.text[:100])
-            )
+            raise ValueError(f"request error, status_code: {response.status_code}, content: {response.text[:100]}")

         return cast(dict, response.json())
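The .format-to-f-string rewrites above are behavior-preserving: both spellings render the same string, including the [:100] truncation of the response body. A quick self-contained check (values invented for illustration):

status_code, text = 502, "upstream exploded " * 50
old = "request error, status_code: {}, content: {}".format(status_code, text[:100])
new = f"request error, status_code: {status_code}, content: {text[:100]}"
assert old == new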
@@ -66,7 +66,7 @@ class Extensible:

                     # Check for extension module file
                     if (extension_name + ".py") not in file_names:
-                        logging.warning(f"Missing {extension_name}.py file in {subdir_path}, Skip.")
+                        logging.warning("Missing %s.py file in %s, Skip.", extension_name, subdir_path)
                         continue

                     # Check for builtin flag and position
@@ -95,7 +95,7 @@ class Extensible:
                             break

                     if not extension_class:
-                        logging.warning(f"Missing subclass of {cls.__name__} in {module_name}, Skip.")
+                        logging.warning("Missing subclass of %s in %s, Skip.", cls.__name__, module_name)
                         continue

                     # Load schema if not builtin
@@ -103,7 +103,7 @@ class Extensible:
                     if not builtin:
                         json_path = os.path.join(subdir_path, "schema.json")
                         if not os.path.exists(json_path):
-                            logging.warning(f"Missing schema.json file in {subdir_path}, Skip.")
+                            logging.warning("Missing schema.json file in %s, Skip.", subdir_path)
                             continue

                         with open(json_path, encoding="utf-8") as f:
@@ -49,7 +49,7 @@ class ApiExternalDataTool(ExternalDataTool):
         """
         # get params from config
         if not self.config:
-            raise ValueError("config is required, config: {}".format(self.config))
+            raise ValueError(f"config is required, config: {self.config}")
         api_based_extension_id = self.config.get("api_based_extension_id")
         assert api_based_extension_id is not None, "api_based_extension_id is required"

@@ -74,7 +74,7 @@ class ApiExternalDataTool(ExternalDataTool):
             # request api
             requestor = APIBasedExtensionRequestor(api_endpoint=api_based_extension.api_endpoint, api_key=api_key)
         except Exception as e:
-            raise ValueError("[External data tool] API query failed, variable: {}, error: {}".format(self.variable, e))
+            raise ValueError(f"[External data tool] API query failed, variable: {self.variable}, error: {e}")

         response_json = requestor.request(
             point=APIBasedExtensionPoint.APP_EXTERNAL_DATA_TOOL_QUERY,
@@ -90,7 +90,7 @@ class ApiExternalDataTool(ExternalDataTool):

         if not isinstance(response_json["result"], str):
             raise ValueError(
-                "[External data tool] API query failed, variable: {}, error: result is not string".format(self.variable)
+                f"[External data tool] API query failed, variable: {self.variable}, error: result is not string"
             )

         return response_json["result"]
@@ -55,7 +55,7 @@ def check_moderation(tenant_id: str, model_config: ModelConfigWithCredentialsEnt
             if moderation_result is True:
                 return True
         except Exception:
-            logger.exception(f"Fails to check moderation, provider_name: {provider_name}")
+            logger.exception("Fails to check moderation, provider_name: %s", provider_name)
             raise InvokeBadRequestError("Rate limit exceeded, please try again later.")

     return False
@@ -30,7 +30,7 @@ def import_module_from_source(*, module_name: str, py_file_path: AnyStr, use_laz
         spec.loader.exec_module(module)
         return module
     except Exception as e:
-        logging.exception(f"Failed to load module {module_name} from script file '{py_file_path!r}'")
+        logging.exception("Failed to load module %s from script file '%s'", module_name, repr(py_file_path))
         raise e
@@ -73,10 +73,12 @@ def make_request(method, url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs):
             if response.status_code not in STATUS_FORCELIST:
                 return response
             else:
-                logging.warning(f"Received status code {response.status_code} for URL {url} which is in the force list")
+                logging.warning(
+                    "Received status code %s for URL %s which is in the force list", response.status_code, url
+                )

         except httpx.RequestError as e:
-            logging.warning(f"Request to URL {url} failed on attempt {retries + 1}: {e}")
+            logging.warning("Request to URL %s failed on attempt %s: %s", url, retries + 1, e)
             if max_retries == 0:
                 raise
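For context, the hunk above sits inside a retry loop. A minimal sketch of that shape, assuming invented STATUS_FORCELIST values and a plain httpx call in place of the project's SSRF-proxied client:

import logging
import httpx

STATUS_FORCELIST = [429, 500, 502, 503, 504]  # assumed retryable status codes


def make_request(method: str, url: str, max_retries: int = 3, **kwargs) -> httpx.Response:
    retries = 0
    while retries <= max_retries:
        try:
            response = httpx.request(method, url, **kwargs)
            if response.status_code not in STATUS_FORCELIST:
                return response
            # Retryable status: log lazily and fall through to the next attempt.
            logging.warning(
                "Received status code %s for URL %s which is in the force list", response.status_code, url
            )
        except httpx.RequestError as e:
            logging.warning("Request to URL %s failed on attempt %s: %s", url, retries + 1, e)
            if max_retries == 0:
                raise
        retries += 1
    raise ValueError(f"Unable to fetch {url} after {max_retries} retries")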
@@ -84,14 +84,14 @@ class IndexingRunner:
                     documents=documents,
                 )
             except DocumentIsPausedError:
-                raise DocumentIsPausedError("Document paused, document id: {}".format(dataset_document.id))
+                raise DocumentIsPausedError(f"Document paused, document id: {dataset_document.id}")
             except ProviderTokenNotInitError as e:
                 dataset_document.indexing_status = "error"
                 dataset_document.error = str(e.description)
                 dataset_document.stopped_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
                 db.session.commit()
             except ObjectDeletedError:
-                logging.warning("Document deleted, document id: {}".format(dataset_document.id))
+                logging.warning("Document deleted, document id: %s", dataset_document.id)
             except Exception as e:
                 logging.exception("consume document failed")
                 dataset_document.indexing_status = "error"
@@ -147,7 +147,7 @@ class IndexingRunner:
                 index_processor=index_processor, dataset=dataset, dataset_document=dataset_document, documents=documents
             )
         except DocumentIsPausedError:
-            raise DocumentIsPausedError("Document paused, document id: {}".format(dataset_document.id))
+            raise DocumentIsPausedError(f"Document paused, document id: {dataset_document.id}")
         except ProviderTokenNotInitError as e:
             dataset_document.indexing_status = "error"
             dataset_document.error = str(e.description)
@@ -222,7 +222,7 @@ class IndexingRunner:
                 index_processor=index_processor, dataset=dataset, dataset_document=dataset_document, documents=documents
             )
         except DocumentIsPausedError:
-            raise DocumentIsPausedError("Document paused, document id: {}".format(dataset_document.id))
+            raise DocumentIsPausedError(f"Document paused, document id: {dataset_document.id}")
         except ProviderTokenNotInitError as e:
             dataset_document.indexing_status = "error"
             dataset_document.error = str(e.description)
@@ -324,7 +324,8 @@ class IndexingRunner:
                 except Exception:
                     logging.exception(
                         "Delete image_files failed while indexing_estimate, \
-                        image_upload_file_is: {}".format(upload_file_id)
+                        image_upload_file_is: %s",
+                        upload_file_id,
                     )
                 db.session.delete(image_file)

@@ -649,7 +650,7 @@ class IndexingRunner:

     @staticmethod
     def _check_document_paused_status(document_id: str):
-        indexing_cache_key = "document_{}_is_paused".format(document_id)
+        indexing_cache_key = f"document_{document_id}_is_paused"
         result = redis_client.get(indexing_cache_key)
         if result:
             raise DocumentIsPausedError()
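The _check_document_paused_status hunk formats a Redis key that acts as a cross-process pause flag. A hedged sketch of the handshake, reusing names from the diff (redis_client and DocumentIsPausedError are the project's own; the setter side is assumed):

document_id = "doc-42"  # hypothetical
indexing_cache_key = f"document_{document_id}_is_paused"

# Pausing side (assumed): some API handler sets the flag.
redis_client.set(indexing_cache_key, "True")

# Indexing side (shown in the hunk): checked between steps; any truthy value aborts.
if redis_client.get(indexing_cache_key):
    raise DocumentIsPausedError()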
@@ -125,16 +125,13 @@ class LLMGenerator:
         return questions

     @classmethod
-    def generate_rule_config(
-        cls, tenant_id: str, instruction: str, model_config: dict, no_variable: bool, rule_config_max_tokens: int = 512
-    ) -> dict:
+    def generate_rule_config(cls, tenant_id: str, instruction: str, model_config: dict, no_variable: bool) -> dict:
         output_parser = RuleConfigGeneratorOutputParser()

         error = ""
         error_step = ""
         rule_config = {"prompt": "", "variables": [], "opening_statement": "", "error": ""}
-        model_parameters = {"max_tokens": rule_config_max_tokens, "temperature": 0.01}
+        model_parameters = model_config.get("completion_params", {})

         if no_variable:
             prompt_template = PromptTemplateParser(WORKFLOW_RULE_CONFIG_PROMPT_GENERATE_TEMPLATE)

@@ -170,7 +167,7 @@ class LLMGenerator:
             error = str(e)
             error_step = "generate rule config"
         except Exception as e:
-            logging.exception(f"Failed to generate rule config, model: {model_config.get('name')}")
+            logging.exception("Failed to generate rule config, model: %s", model_config.get("name"))
             rule_config["error"] = str(e)

         rule_config["error"] = f"Failed to {error_step}. Error: {error}" if error else ""
@@ -267,7 +264,7 @@ class LLMGenerator:
             error_step = "generate conversation opener"

         except Exception as e:
-            logging.exception(f"Failed to generate rule config, model: {model_config.get('name')}")
+            logging.exception("Failed to generate rule config, model: %s", model_config.get("name"))
             rule_config["error"] = str(e)

         rule_config["error"] = f"Failed to {error_step}. Error: {error}" if error else ""
@@ -276,12 +273,7 @@ class LLMGenerator:

     @classmethod
     def generate_code(
-        cls,
-        tenant_id: str,
-        instruction: str,
-        model_config: dict,
-        code_language: str = "javascript",
-        max_tokens: int = 1000,
+        cls, tenant_id: str, instruction: str, model_config: dict, code_language: str = "javascript"
     ) -> dict:
         if code_language == "python":
             prompt_template = PromptTemplateParser(PYTHON_CODE_GENERATOR_PROMPT_TEMPLATE)
@@ -305,8 +297,7 @@ class LLMGenerator:
         )

         prompt_messages = [UserPromptMessage(content=prompt)]
-        model_parameters = {"max_tokens": max_tokens, "temperature": 0.01}
+        model_parameters = model_config.get("completion_params", {})

         try:
             response = cast(
                 LLMResult,
@@ -323,7 +314,7 @@ class LLMGenerator:
             return {"code": "", "language": code_language, "error": f"Failed to generate code. Error: {error}"}
         except Exception as e:
             logging.exception(
-                f"Failed to invoke LLM model, model: {model_config.get('name')}, language: {code_language}"
+                "Failed to invoke LLM model, model: %s, language: %s", model_config.get("name"), code_language
             )
             return {"code": "", "language": code_language, "error": f"An unexpected error occurred: {str(e)}"}

@@ -395,5 +386,5 @@ class LLMGenerator:
             error = str(e)
             return {"output": "", "error": f"Failed to generate JSON Schema. Error: {error}"}
         except Exception as e:
-            logging.exception(f"Failed to invoke LLM model, model: {model_config.get('name')}")
+            logging.exception("Failed to invoke LLM model, model: %s", model_config.get("name"))
             return {"output": "", "error": f"An unexpected error occurred: {str(e)}"}
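The signature changes above drop the hardcoded rule_config_max_tokens / max_tokens defaults; generation settings now ride along in the caller-supplied model_config. A small sketch with an invented config dict showing the shape the refactored methods expect:

# Hypothetical model_config; only the "completion_params" key is read by the diff.
model_config = {
    "provider": "openai",
    "name": "gpt-4o-mini",
    "completion_params": {"max_tokens": 1000, "temperature": 0.01},
}

# Mirrors the diff: fall back to an empty dict when no params were supplied,
# letting the model runtime apply its own defaults.
model_parameters = model_config.get("completion_params", {})
assert model_parameters.get("temperature") == 0.01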
@@ -88,7 +88,7 @@ class SSETransport:
             status_queue: Queue to put status updates.
         """
         endpoint_url = urljoin(self.url, sse_data)
-        logger.info(f"Received endpoint URL: {endpoint_url}")
+        logger.info("Received endpoint URL: %s", endpoint_url)

         if not self._validate_endpoint_url(endpoint_url):
             error_msg = f"Endpoint origin does not match connection origin: {endpoint_url}"
@@ -107,7 +107,7 @@ class SSETransport:
         """
         try:
             message = types.JSONRPCMessage.model_validate_json(sse_data)
-            logger.debug(f"Received server message: {message}")
+            logger.debug("Received server message: %s", message)
             session_message = SessionMessage(message)
             read_queue.put(session_message)
         except Exception as exc:
@@ -128,7 +128,7 @@ class SSETransport:
             case "message":
                 self._handle_message_event(sse.data, read_queue)
             case _:
-                logger.warning(f"Unknown SSE event: {sse.event}")
+                logger.warning("Unknown SSE event: %s", sse.event)

     def sse_reader(self, event_source, read_queue: ReadQueue, status_queue: StatusQueue) -> None:
         """Read and process SSE events.
@@ -142,7 +142,7 @@ class SSETransport:
             for sse in event_source.iter_sse():
                 self._handle_sse_event(sse, read_queue, status_queue)
         except httpx.ReadError as exc:
-            logger.debug(f"SSE reader shutting down normally: {exc}")
+            logger.debug("SSE reader shutting down normally: %s", exc)
         except Exception as exc:
             read_queue.put(exc)
         finally:
@@ -165,7 +165,7 @@ class SSETransport:
             ),
         )
         response.raise_for_status()
-        logger.debug(f"Client message sent successfully: {response.status_code}")
+        logger.debug("Client message sent successfully: %s", response.status_code)

     def post_writer(self, client: httpx.Client, endpoint_url: str, write_queue: WriteQueue) -> None:
         """Handle writing messages to the server.
@@ -190,7 +190,7 @@ class SSETransport:
             except queue.Empty:
                 continue
             except httpx.ReadError as exc:
-                logger.debug(f"Post writer shutting down normally: {exc}")
+                logger.debug("Post writer shutting down normally: %s", exc)
             except Exception as exc:
                 logger.exception("Error writing messages")
                 write_queue.put(exc)
@@ -326,7 +326,7 @@ def send_message(http_client: httpx.Client, endpoint_url: str, session_message:
             ),
         )
         response.raise_for_status()
-        logger.debug(f"Client message sent successfully: {response.status_code}")
+        logger.debug("Client message sent successfully: %s", response.status_code)
     except Exception as exc:
         logger.exception("Error sending message")
         raise
@@ -349,13 +349,13 @@ def read_messages(
             if sse.event == "message":
                 try:
                     message = types.JSONRPCMessage.model_validate_json(sse.data)
-                    logger.debug(f"Received server message: {message}")
+                    logger.debug("Received server message: %s", message)
                     yield SessionMessage(message)
                 except Exception as exc:
                     logger.exception("Error parsing server message")
                     yield exc
             else:
-                logger.warning(f"Unknown SSE event: {sse.event}")
+                logger.warning("Unknown SSE event: %s", sse.event)
     except Exception as exc:
         logger.exception("Error reading SSE messages")
         yield exc
@@ -129,7 +129,7 @@ class StreamableHTTPTransport:
         new_session_id = response.headers.get(MCP_SESSION_ID)
         if new_session_id:
             self.session_id = new_session_id
-            logger.info(f"Received session ID: {self.session_id}")
+            logger.info("Received session ID: %s", self.session_id)

     def _handle_sse_event(
         self,
@@ -142,7 +142,7 @@ class StreamableHTTPTransport:
         if sse.event == "message":
             try:
                 message = JSONRPCMessage.model_validate_json(sse.data)
-                logger.debug(f"SSE message: {message}")
+                logger.debug("SSE message: %s", message)

                 # If this is a response and we have original_request_id, replace it
                 if original_request_id is not None and isinstance(message.root, JSONRPCResponse | JSONRPCError):
@@ -168,7 +168,7 @@ class StreamableHTTPTransport:
             logger.debug("Received ping event")
             return False
         else:
-            logger.warning(f"Unknown SSE event: {sse.event}")
+            logger.warning("Unknown SSE event: %s", sse.event)
             return False

     def handle_get_stream(
@@ -197,7 +197,7 @@ class StreamableHTTPTransport:
                     self._handle_sse_event(sse, server_to_client_queue)

         except Exception as exc:
-            logger.debug(f"GET stream error (non-fatal): {exc}")
+            logger.debug("GET stream error (non-fatal): %s", exc)

     def _handle_resumption_request(self, ctx: RequestContext) -> None:
         """Handle a resumption request using GET with SSE."""
@@ -352,7 +352,7 @@ class StreamableHTTPTransport:
         # Check if this is a resumption request
         is_resumption = bool(metadata and metadata.resumption_token)

-        logger.debug(f"Sending client message: {message}")
+        logger.debug("Sending client message: %s", message)

         # Handle initialized notification
         if self._is_initialized_notification(message):
@@ -389,9 +389,9 @@ class StreamableHTTPTransport:
             if response.status_code == 405:
                 logger.debug("Server does not allow session termination")
             elif response.status_code != 200:
-                logger.warning(f"Session termination failed: {response.status_code}")
+                logger.warning("Session termination failed: %s", response.status_code)
         except Exception as exc:
-            logger.warning(f"Session termination failed: {exc}")
+            logger.warning("Session termination failed: %s", exc)

     def get_session_id(self) -> str | None:
         """Get the current session ID."""
@@ -75,7 +75,7 @@ class MCPClient:
             self.connect_server(client_factory, method_name)
         else:
             try:
-                logger.debug(f"Not supported method {method_name} found in URL path, trying default 'mcp' method.")
+                logger.debug("Not supported method %s found in URL path, trying default 'mcp' method.", method_name)
                 self.connect_server(sse_client, "sse")
             except MCPConnectionError:
                 logger.debug("MCP connection failed with 'sse', falling back to 'mcp' method.")
@@ -368,7 +368,7 @@ class BaseSession(
                     self._handle_incoming(notification)
                 except Exception as e:
                     # For other validation errors, log and continue
-                    logging.warning(f"Failed to validate notification: {e}. Message was: {message.message.root}")
+                    logging.warning("Failed to validate notification: %s. Message was: %s", e, message.message.root)
         else:  # Response or error
             response_queue = self._response_streams.get(message.message.root.id)
             if response_queue is not None:
@@ -535,9 +535,19 @@ class LBModelManager:

             if dify_config.DEBUG:
                 logger.info(
-                    f"Model LB\nid: {config.id}\nname:{config.name}\n"
-                    f"tenant_id: {self._tenant_id}\nprovider: {self._provider}\n"
-                    f"model_type: {self._model_type.value}\nmodel: {self._model}"
+                    """Model LB
+id: %s
+name:%s
+tenant_id: %s
+provider: %s
+model_type: %s
+model: %s""",
+                    config.id,
+                    config.name,
+                    self._tenant_id,
+                    self._provider,
+                    self._model_type.value,
+                    self._model,
                 )

             return config
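The LBModelManager hunk shows that the lazy style extends to multi-line messages: the f-string concatenation becomes one triple-quoted template with a positional argument per placeholder. Order matters, since %s placeholders are filled strictly left to right; a tiny sketch with invented values:

import logging

logger = logging.getLogger(__name__)

logger.info(
    """Model LB
id: %s
name:%s""",
    "cfg-1",        # fills the first %s (hypothetical config id)
    "round-robin",  # fills the second %s (hypothetical config name)
)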
@@ -440,7 +440,9 @@ class LargeLanguageModel(AIModel):
                 if callback.raise_error:
                     raise e
                 else:
-                    logger.warning(f"Callback {callback.__class__.__name__} on_before_invoke failed with error {e}")
+                    logger.warning(
+                        "Callback %s on_before_invoke failed with error %s", callback.__class__.__name__, e
+                    )

     def _trigger_new_chunk_callbacks(
         self,
@@ -487,7 +489,7 @@ class LargeLanguageModel(AIModel):
                 if callback.raise_error:
                     raise e
                 else:
-                    logger.warning(f"Callback {callback.__class__.__name__} on_new_chunk failed with error {e}")
+                    logger.warning("Callback %s on_new_chunk failed with error %s", callback.__class__.__name__, e)

     def _trigger_after_invoke_callbacks(
         self,
@@ -535,7 +537,9 @@ class LargeLanguageModel(AIModel):
                 if callback.raise_error:
                     raise e
                 else:
-                    logger.warning(f"Callback {callback.__class__.__name__} on_after_invoke failed with error {e}")
+                    logger.warning(
+                        "Callback %s on_after_invoke failed with error %s", callback.__class__.__name__, e
+                    )

     def _trigger_invoke_error_callbacks(
         self,
@@ -583,4 +587,6 @@ class LargeLanguageModel(AIModel):
                 if callback.raise_error:
                     raise e
                 else:
-                    logger.warning(f"Callback {callback.__class__.__name__} on_invoke_error failed with error {e}")
+                    logger.warning(
+                        "Callback %s on_invoke_error failed with error %s", callback.__class__.__name__, e
+                    )
@@ -136,6 +136,6 @@ class OutputModeration(BaseModel):
             result: ModerationOutputsResult = moderation_factory.moderation_for_outputs(moderation_buffer)
             return result
         except Exception as e:
-            logger.exception(f"Moderation Output error, app_id: {app_id}")
+            logger.exception("Moderation Output error, app_id: %s", app_id)

         return None
@@ -97,7 +97,7 @@ class AliyunDataTrace(BaseTraceInstance):
         try:
             return self.trace_client.get_project_url()
         except Exception as e:
-            logger.info(f"Aliyun get run url failed: {str(e)}", exc_info=True)
+            logger.info("Aliyun get run url failed: %s", str(e), exc_info=True)
             raise ValueError(f"Aliyun get run url failed: {str(e)}")

     def workflow_trace(self, trace_info: WorkflowTraceInfo):
@@ -139,7 +139,7 @@ class AliyunDataTrace(BaseTraceInstance):
             start_time=convert_datetime_to_nanoseconds(trace_info.start_time),
             end_time=convert_datetime_to_nanoseconds(trace_info.end_time),
             attributes={
-                GEN_AI_SESSION_ID: trace_info.metadata.get("conversation_id", ""),
+                GEN_AI_SESSION_ID: trace_info.metadata.get("conversation_id") or "",
                 GEN_AI_USER_ID: str(user_id),
                 GEN_AI_SPAN_KIND: GenAISpanKind.CHAIN.value,
                 GEN_AI_FRAMEWORK: "dify",
@@ -161,12 +161,12 @@ class AliyunDataTrace(BaseTraceInstance):
             start_time=convert_datetime_to_nanoseconds(trace_info.start_time),
             end_time=convert_datetime_to_nanoseconds(trace_info.end_time),
             attributes={
-                GEN_AI_SESSION_ID: trace_info.metadata.get("conversation_id", ""),
+                GEN_AI_SESSION_ID: trace_info.metadata.get("conversation_id") or "",
                 GEN_AI_USER_ID: str(user_id),
                 GEN_AI_SPAN_KIND: GenAISpanKind.LLM.value,
                 GEN_AI_FRAMEWORK: "dify",
-                GEN_AI_MODEL_NAME: trace_info.metadata.get("ls_model_name", ""),
-                GEN_AI_SYSTEM: trace_info.metadata.get("ls_provider", ""),
+                GEN_AI_MODEL_NAME: trace_info.metadata.get("ls_model_name") or "",
+                GEN_AI_SYSTEM: trace_info.metadata.get("ls_provider") or "",
                 GEN_AI_USAGE_INPUT_TOKENS: str(trace_info.message_tokens),
                 GEN_AI_USAGE_OUTPUT_TOKENS: str(trace_info.answer_tokens),
                 GEN_AI_USAGE_TOTAL_TOKENS: str(trace_info.total_tokens),
@@ -286,7 +286,7 @@ class AliyunDataTrace(BaseTraceInstance):
             node_span = self.build_workflow_task_span(trace_id, workflow_span_id, trace_info, node_execution)
             return node_span
         except Exception as e:
-            logging.debug(f"Error occurred in build_workflow_node_span: {e}", exc_info=True)
+            logging.debug("Error occurred in build_workflow_node_span: %s", e, exc_info=True)
             return None

     def get_workflow_node_status(self, node_execution: WorkflowNodeExecution) -> Status:
@@ -386,14 +386,14 @@ class AliyunDataTrace(BaseTraceInstance):
                 GEN_AI_SESSION_ID: trace_info.metadata.get("conversation_id") or "",
                 GEN_AI_SPAN_KIND: GenAISpanKind.LLM.value,
                 GEN_AI_FRAMEWORK: "dify",
-                GEN_AI_MODEL_NAME: process_data.get("model_name", ""),
-                GEN_AI_SYSTEM: process_data.get("model_provider", ""),
+                GEN_AI_MODEL_NAME: process_data.get("model_name") or "",
+                GEN_AI_SYSTEM: process_data.get("model_provider") or "",
                 GEN_AI_USAGE_INPUT_TOKENS: str(usage_data.get("prompt_tokens", 0)),
                 GEN_AI_USAGE_OUTPUT_TOKENS: str(usage_data.get("completion_tokens", 0)),
                 GEN_AI_USAGE_TOTAL_TOKENS: str(usage_data.get("total_tokens", 0)),
                 GEN_AI_PROMPT: json.dumps(process_data.get("prompts", []), ensure_ascii=False),
                 GEN_AI_COMPLETION: str(outputs.get("text", "")),
-                GEN_AI_RESPONSE_FINISH_REASON: outputs.get("finish_reason", ""),
+                GEN_AI_RESPONSE_FINISH_REASON: outputs.get("finish_reason") or "",
                 INPUT_VALUE: json.dumps(process_data.get("prompts", []), ensure_ascii=False),
                 OUTPUT_VALUE: str(outputs.get("text", "")),
             },
@@ -421,7 +421,7 @@ class AliyunDataTrace(BaseTraceInstance):
                 GEN_AI_USER_ID: str(user_id),
                 GEN_AI_SPAN_KIND: GenAISpanKind.CHAIN.value,
                 GEN_AI_FRAMEWORK: "dify",
-                INPUT_VALUE: trace_info.workflow_run_inputs.get("sys.query", ""),
+                INPUT_VALUE: trace_info.workflow_run_inputs.get("sys.query") or "",
                 OUTPUT_VALUE: json.dumps(trace_info.workflow_run_outputs, ensure_ascii=False),
             },
             status=status,
@@ -461,8 +461,8 @@ class AliyunDataTrace(BaseTraceInstance):
             attributes={
                 GEN_AI_SPAN_KIND: GenAISpanKind.LLM.value,
                 GEN_AI_FRAMEWORK: "dify",
-                GEN_AI_MODEL_NAME: trace_info.metadata.get("ls_model_name", ""),
-                GEN_AI_SYSTEM: trace_info.metadata.get("ls_provider", ""),
+                GEN_AI_MODEL_NAME: trace_info.metadata.get("ls_model_name") or "",
+                GEN_AI_SYSTEM: trace_info.metadata.get("ls_provider") or "",
                 GEN_AI_PROMPT: json.dumps(trace_info.inputs, ensure_ascii=False),
                 GEN_AI_COMPLETION: json.dumps(trace_info.suggested_question, ensure_ascii=False),
                 INPUT_VALUE: json.dumps(trace_info.inputs, ensure_ascii=False),
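Note the deliberate semantic shift in the attribute hunks above: dict.get(key, "") only falls back when the key is absent, while dict.get(key) or "" also normalizes keys that are present but hold None (or another falsy value), which trace metadata often contains. A one-line demonstration:

metadata = {"conversation_id": None}  # key present, value None

assert metadata.get("conversation_id", "") is None    # key exists, default unused
assert (metadata.get("conversation_id") or "") == ""  # None is coerced to ""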
@@ -69,10 +69,10 @@ class TraceClient:
             if response.status_code == 405:
                 return True
             else:
-                logger.debug(f"AliyunTrace API check failed: Unexpected status code: {response.status_code}")
+                logger.debug("AliyunTrace API check failed: Unexpected status code: %s", response.status_code)
                 return False
         except requests.exceptions.RequestException as e:
-            logger.debug(f"AliyunTrace API check failed: {str(e)}")
+            logger.debug("AliyunTrace API check failed: %s", str(e))
             raise ValueError(f"AliyunTrace API check failed: {str(e)}")

     def get_project_url(self):
@@ -109,7 +109,7 @@ class TraceClient:
         try:
             self.exporter.export(spans_to_export)
         except Exception as e:
-            logger.debug(f"Error exporting spans: {e}")
+            logger.debug("Error exporting spans: %s", e)

     def shutdown(self):
         with self.condition:
@@ -77,10 +77,10 @@ def setup_tracer(arize_phoenix_config: ArizeConfig | PhoenixConfig) -> tuple[tra

         # Create a named tracer instead of setting the global provider
         tracer_name = f"arize_phoenix_tracer_{arize_phoenix_config.project}"
-        logger.info(f"[Arize/Phoenix] Created tracer with name: {tracer_name}")
+        logger.info("[Arize/Phoenix] Created tracer with name: %s", tracer_name)
         return cast(trace_sdk.Tracer, provider.get_tracer(tracer_name)), processor
     except Exception as e:
-        logger.error(f"[Arize/Phoenix] Failed to setup the tracer: {str(e)}", exc_info=True)
+        logger.error("[Arize/Phoenix] Failed to setup the tracer: %s", str(e), exc_info=True)
         raise


@@ -120,7 +120,7 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
         self.file_base_url = os.getenv("FILES_URL", "http://127.0.0.1:5001")

     def trace(self, trace_info: BaseTraceInfo):
-        logger.info(f"[Arize/Phoenix] Trace: {trace_info}")
+        logger.info("[Arize/Phoenix] Trace: %s", trace_info)
         try:
             if isinstance(trace_info, WorkflowTraceInfo):
                 self.workflow_trace(trace_info)
@@ -138,7 +138,7 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
                 self.generate_name_trace(trace_info)

         except Exception as e:
-            logger.error(f"[Arize/Phoenix] Error in the trace: {str(e)}", exc_info=True)
+            logger.error("[Arize/Phoenix] Error in the trace: %s", str(e), exc_info=True)
             raise

     def workflow_trace(self, trace_info: WorkflowTraceInfo):
@@ -570,7 +570,7 @@ class ArizePhoenixDataTrace(BaseTraceInstance):

         trace_id = uuid_to_trace_id(trace_info.message_id)
         tool_span_id = RandomIdGenerator().generate_span_id()
-        logger.info(f"[Arize/Phoenix] Creating tool trace with trace_id: {trace_id}, span_id: {tool_span_id}")
+        logger.info("[Arize/Phoenix] Creating tool trace with trace_id: %s, span_id: %s", trace_id, tool_span_id)

         # Create span context with the same trace_id as the parent
         # todo: Create with the appropriate parent span context, so that the tool span is
@@ -673,7 +673,7 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
             span.set_attribute("test", "true")
             return True
         except Exception as e:
-            logger.info(f"[Arize/Phoenix] API check failed: {str(e)}", exc_info=True)
+            logger.info("[Arize/Phoenix] API check failed: %s", str(e), exc_info=True)
             raise ValueError(f"[Arize/Phoenix] API check failed: {str(e)}")

     def get_project_url(self):
@@ -683,7 +683,7 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
             else:
                 return f"{self.arize_phoenix_config.endpoint}/projects/"
         except Exception as e:
-            logger.info(f"[Arize/Phoenix] Get run url failed: {str(e)}", exc_info=True)
+            logger.info("[Arize/Phoenix] Get run url failed: %s", str(e), exc_info=True)
             raise ValueError(f"[Arize/Phoenix] Get run url failed: {str(e)}")

     def _get_workflow_nodes(self, workflow_run_id: str):
@@ -102,7 +102,7 @@ class LangfuseConfig(BaseTracingConfig):
     @field_validator("host")
     @classmethod
     def host_validator(cls, v, info: ValidationInfo):
-        return cls.validate_endpoint_url(v, "https://api.langfuse.com")
+        return validate_url_with_path(v, "https://api.langfuse.com")


 class LangSmithConfig(BaseTracingConfig):
@@ -440,7 +440,7 @@ class LangFuseDataTrace(BaseTraceInstance):
         try:
             return self.langfuse_client.auth_check()
         except Exception as e:
-            logger.debug(f"LangFuse API check failed: {str(e)}")
+            logger.debug("LangFuse API check failed: %s", str(e))
             raise ValueError(f"LangFuse API check failed: {str(e)}")

     def get_project_key(self):
@@ -448,5 +448,5 @@ class LangFuseDataTrace(BaseTraceInstance):
             projects = self.langfuse_client.client.projects.get()
             return projects.data[0].id
         except Exception as e:
-            logger.debug(f"LangFuse get project key failed: {str(e)}")
+            logger.debug("LangFuse get project key failed: %s", str(e))
             raise ValueError(f"LangFuse get project key failed: {str(e)}")
@@ -504,7 +504,7 @@ class LangSmithDataTrace(BaseTraceInstance):
             self.langsmith_client.delete_project(project_name=random_project_name)
             return True
         except Exception as e:
-            logger.debug(f"LangSmith API check failed: {str(e)}")
+            logger.debug("LangSmith API check failed: %s", str(e))
             raise ValueError(f"LangSmith API check failed: {str(e)}")

     def get_project_url(self):
@@ -523,5 +523,5 @@ class LangSmithDataTrace(BaseTraceInstance):
             )
             return project_url.split("/r/")[0]
         except Exception as e:
-            logger.debug(f"LangSmith get run url failed: {str(e)}")
+            logger.debug("LangSmith get run url failed: %s", str(e))
             raise ValueError(f"LangSmith get run url failed: {str(e)}")
@@ -453,12 +453,12 @@ class OpikDataTrace(BaseTraceInstance):
             self.opik_client.auth_check()
             return True
         except Exception as e:
-            logger.info(f"Opik API check failed: {str(e)}", exc_info=True)
+            logger.info("Opik API check failed: %s", str(e), exc_info=True)
             raise ValueError(f"Opik API check failed: {str(e)}")

     def get_project_url(self):
         try:
             return self.opik_client.get_project_url(project_name=self.project)
         except Exception as e:
-            logger.info(f"Opik get run url failed: {str(e)}", exc_info=True)
+            logger.info("Opik get run url failed: %s", str(e), exc_info=True)
             raise ValueError(f"Opik get run url failed: {str(e)}")
@@ -287,7 +287,7 @@ class OpsTraceManager:
         # create new tracing_instance and update the cache if it absent
         tracing_instance = trace_instance(config_class(**decrypt_trace_config))
         cls.ops_trace_instances_cache[decrypt_trace_config_key] = tracing_instance
-        logging.info(f"new tracing_instance for app_id: {app_id}")
+        logging.info("new tracing_instance for app_id: %s", app_id)
         return tracing_instance

     @classmethod
@@ -843,7 +843,7 @@ class TraceQueueManager:
             trace_task.app_id = self.app_id
             trace_manager_queue.put(trace_task)
         except Exception as e:
-            logging.exception(f"Error adding trace task, trace_type {trace_task.trace_type}")
+            logging.exception("Error adding trace task, trace_type %s", trace_task.trace_type)
         finally:
             self.start_timer()
@@ -67,7 +67,13 @@ def generate_dotted_order(

 def validate_url(url: str, default_url: str, allowed_schemes: tuple = ("https", "http")) -> str:
     """
-    Validate and normalize URL with proper error handling
+    Validate and normalize URL with proper error handling.

+    NOTE: This function does not retain the `path` component of the provided URL.
+    In most cases, it is recommended to use `validate_url_with_path` instead.
+
+    This function is deprecated and retained only for compatibility purposes.
+    New implementations should use `validate_url_with_path`.
+
     Args:
         url: The URL to validate
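Based on the host_validator hunk earlier and the note above, callers migrate to validate_url_with_path(url, default_url). A hedged usage sketch; the signature beyond these two arguments, and the path-preserving behavior, are inferred from the docstring rather than shown in this diff:

# With the deprecated validate_url the "/api/public" path would be dropped during
# normalization; the replacement is expected to keep it.
host = validate_url_with_path("https://cloud.langfuse.com/api/public", "https://api.langfuse.com")
assert host.endswith("/api/public")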
@@ -66,11 +66,11 @@ class WeaveDataTrace(BaseTraceInstance):
             project_url = f"https://wandb.ai/{self.weave_client._project_id()}"
             return project_url
         except Exception as e:
-            logger.debug(f"Weave get run url failed: {str(e)}")
+            logger.debug("Weave get run url failed: %s", str(e))
             raise ValueError(f"Weave get run url failed: {str(e)}")

     def trace(self, trace_info: BaseTraceInfo):
-        logger.debug(f"Trace info: {trace_info}")
+        logger.debug("Trace info: %s", trace_info)
         if isinstance(trace_info, WorkflowTraceInfo):
             self.workflow_trace(trace_info)
         if isinstance(trace_info, MessageTraceInfo):
@@ -403,7 +403,7 @@ class WeaveDataTrace(BaseTraceInstance):
             print("Weave login successful")
             return True
         except Exception as e:
-            logger.debug(f"Weave API check failed: {str(e)}")
+            logger.debug("Weave API check failed: %s", str(e))
             raise ValueError(f"Weave API check failed: {str(e)}")

     def start_call(self, run_data: WeaveTraceModel, parent_run_id: Optional[str] = None):
@@ -361,4 +361,4 @@ class PluginDatasourceManager(BasePluginClient):
                     }
                 ],
             },
         }
@@ -1,3 +1,8 @@
+from collections.abc import Mapping
+
+from pydantic import TypeAdapter
+
+
 class PluginDaemonError(Exception):
     """Base class for all plugin daemon errors."""

@@ -36,6 +41,21 @@ class PluginDaemonBadRequestError(PluginDaemonClientSideError):
 class PluginInvokeError(PluginDaemonClientSideError):
     description: str = "Invoke Error"

+    def _get_error_object(self) -> Mapping:
+        try:
+            return TypeAdapter(Mapping).validate_json(self.description)
+        except Exception:
+            return {}
+
+    def get_error_type(self) -> str:
+        return self._get_error_object().get("error_type", "unknown")
+
+    def get_error_message(self) -> str:
+        try:
+            return self._get_error_object().get("message", "unknown")
+        except Exception:
+            return self.description
+

 class PluginUniqueIdentifierError(PluginDaemonClientSideError):
     description: str = "Unique Identifier Error"
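A hedged usage sketch for the new PluginInvokeError accessors, assuming the plugin daemon serializes a JSON object such as {"error_type": ..., "message": ...} into the exception's description attribute:

err = PluginInvokeError()
err.description = '{"error_type": "CredentialsError", "message": "invalid api key"}'
assert err.get_error_type() == "CredentialsError"
assert err.get_error_message() == "invalid api key"

# A non-JSON description degrades gracefully: _get_error_object returns {},
# so both accessors fall back to "unknown" instead of raising.
bad = PluginInvokeError()
bad.description = "plain text failure"
assert bad.get_error_type() == "unknown"
assert bad.get_error_message() == "unknown"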
@ -24,7 +24,7 @@ class Jieba(BaseKeyword):
|
|||||||
self._config = KeywordTableConfig()
|
self._config = KeywordTableConfig()
|
||||||
|
|
||||||
def create(self, texts: list[Document], **kwargs) -> BaseKeyword:
|
def create(self, texts: list[Document], **kwargs) -> BaseKeyword:
|
||||||
lock_name = "keyword_indexing_lock_{}".format(self.dataset.id)
|
lock_name = f"keyword_indexing_lock_{self.dataset.id}"
|
||||||
with redis_client.lock(lock_name, timeout=600):
|
with redis_client.lock(lock_name, timeout=600):
|
||||||
keyword_table_handler = JiebaKeywordTableHandler()
|
keyword_table_handler = JiebaKeywordTableHandler()
|
||||||
keyword_table = self._get_dataset_keyword_table()
|
keyword_table = self._get_dataset_keyword_table()
|
||||||
@ -45,7 +45,7 @@ class Jieba(BaseKeyword):
|
|||||||
return self
|
return self
|
||||||
|
|
||||||
def add_texts(self, texts: list[Document], **kwargs):
|
def add_texts(self, texts: list[Document], **kwargs):
|
||||||
lock_name = "keyword_indexing_lock_{}".format(self.dataset.id)
|
lock_name = f"keyword_indexing_lock_{self.dataset.id}"
|
||||||
with redis_client.lock(lock_name, timeout=600):
|
with redis_client.lock(lock_name, timeout=600):
|
||||||
keyword_table_handler = JiebaKeywordTableHandler()
|
keyword_table_handler = JiebaKeywordTableHandler()
|
||||||
|
|
||||||
@ -77,7 +77,7 @@ class Jieba(BaseKeyword):
|
|||||||
return id in set.union(*keyword_table.values())
|
return id in set.union(*keyword_table.values())
|
||||||
|
|
||||||
def delete_by_ids(self, ids: list[str]) -> None:
|
def delete_by_ids(self, ids: list[str]) -> None:
|
||||||
lock_name = "keyword_indexing_lock_{}".format(self.dataset.id)
|
lock_name = f"keyword_indexing_lock_{self.dataset.id}"
|
||||||
with redis_client.lock(lock_name, timeout=600):
|
with redis_client.lock(lock_name, timeout=600):
|
||||||
keyword_table = self._get_dataset_keyword_table()
|
keyword_table = self._get_dataset_keyword_table()
|
||||||
if keyword_table is not None:
|
if keyword_table is not None:
|
||||||
@ -117,7 +117,7 @@ class Jieba(BaseKeyword):
|
|||||||
return documents
|
return documents
|
||||||
|
|
||||||
def delete(self) -> None:
|
def delete(self) -> None:
|
||||||
lock_name = "keyword_indexing_lock_{}".format(self.dataset.id)
|
lock_name = f"keyword_indexing_lock_{self.dataset.id}"
|
||||||
with redis_client.lock(lock_name, timeout=600):
|
with redis_client.lock(lock_name, timeout=600):
|
||||||
dataset_keyword_table = self.dataset.dataset_keyword_table
|
dataset_keyword_table = self.dataset.dataset_keyword_table
|
||||||
if dataset_keyword_table:
|
if dataset_keyword_table:
|
||||||
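
The `lock_name` changes above only swap `str.format` for f-strings, but the surrounding pattern recurs in every vector store below: take a short-lived distributed lock, then consult a cache key so only the first worker performs the creation. A minimal sketch of that pattern, assuming a `redis_client` with the same interface (redis-py) as the one these modules import; `ensure_collection` and `create` are illustrative names, not project APIs:

    from redis import Redis

    redis_client = Redis()  # assumption: configured like the project's redis_client


    def ensure_collection(collection_name: str, create) -> None:
        # Serialize concurrent creators across workers/processes.
        lock_name = f"vector_indexing_lock_{collection_name}"
        with redis_client.lock(lock_name, timeout=20):
            # Fast path: another worker already created the collection recently.
            collection_exist_cache_key = f"vector_indexing_{collection_name}"
            if redis_client.get(collection_exist_cache_key):
                return
            create(collection_name)
            # Remember the creation for an hour so later calls skip the work.
            redis_client.set(collection_exist_cache_key, 1, ex=3600)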
@ -203,9 +203,9 @@ class BaiduVector(BaseVector):

     def _create_table(self, dimension: int) -> None:
         # Try to grab distributed lock and create table
-        lock_name = "vector_indexing_lock_{}".format(self._collection_name)
+        lock_name = f"vector_indexing_lock_{self._collection_name}"
         with redis_client.lock(lock_name, timeout=60):
-            table_exist_cache_key = "vector_indexing_{}".format(self._collection_name)
+            table_exist_cache_key = f"vector_indexing_{self._collection_name}"
             if redis_client.get(table_exist_cache_key):
                 return

@ -57,9 +57,9 @@ class ChromaVector(BaseVector):
         self.add_texts(texts, embeddings, **kwargs)

     def create_collection(self, collection_name: str):
-        lock_name = "vector_indexing_lock_{}".format(collection_name)
+        lock_name = f"vector_indexing_lock_{collection_name}"
         with redis_client.lock(lock_name, timeout=20):
-            collection_exist_cache_key = "vector_indexing_{}".format(self._collection_name)
+            collection_exist_cache_key = f"vector_indexing_{self._collection_name}"
             if redis_client.get(collection_exist_cache_key):
                 return
             self._client.get_or_create_collection(collection_name)

@ -74,9 +74,9 @@ class CouchbaseVector(BaseVector):
         self.add_texts(texts, embeddings)

     def _create_collection(self, vector_length: int, uuid: str):
-        lock_name = "vector_indexing_lock_{}".format(self._collection_name)
+        lock_name = f"vector_indexing_lock_{self._collection_name}"
         with redis_client.lock(lock_name, timeout=20):
-            collection_exist_cache_key = "vector_indexing_{}".format(self._collection_name)
+            collection_exist_cache_key = f"vector_indexing_{self._collection_name}"
             if redis_client.get(collection_exist_cache_key):
                 return
             if self._collection_exists(self._collection_name):

@ -242,7 +242,7 @@ class CouchbaseVector(BaseVector):
         try:
             self._cluster.query(query, named_parameters={"doc_ids": ids}).execute()
         except Exception as e:
-            logger.exception(f"Failed to delete documents, ids: {ids}")
+            logger.exception("Failed to delete documents, ids: %s", ids)

     def delete_by_document_id(self, document_id: str):
         query = f"""
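
The logging rewrites in this commit all follow one recipe, enforced by ruff rules G001 (no `str.format` in logging calls) and G004 (no f-strings in logging calls): pass the template and arguments separately so interpolation is deferred until a handler actually emits the record. A small illustration:

    import logging

    logger = logging.getLogger(__name__)
    ids = ["doc-1", "doc-2"]

    # Eager: an f-string is rendered even when WARNING is filtered out.
    # logger.warning(f"Failed to delete documents, ids: {ids}")

    # Lazy: %-style arguments are interpolated only if the record is emitted,
    # and keyword options such as exc_info still work unchanged.
    logger.warning("Failed to delete documents, ids: %s", ids)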
@ -29,7 +29,7 @@ class ElasticSearchJaVector(ElasticSearchVector):
         with redis_client.lock(lock_name, timeout=20):
             collection_exist_cache_key = f"vector_indexing_{self._collection_name}"
             if redis_client.get(collection_exist_cache_key):
-                logger.info(f"Collection {self._collection_name} already exists.")
+                logger.info("Collection %s already exists.", self._collection_name)
                 return

             if not self._client.indices.exists(index=self._collection_name):

@ -186,7 +186,7 @@ class ElasticSearchVector(BaseVector):
         with redis_client.lock(lock_name, timeout=20):
             collection_exist_cache_key = f"vector_indexing_{self._collection_name}"
             if redis_client.get(collection_exist_cache_key):
-                logger.info(f"Collection {self._collection_name} already exists.")
+                logger.info("Collection %s already exists.", self._collection_name)
                 return

             if not self._client.indices.exists(index=self._collection_name):

@ -164,7 +164,7 @@ class HuaweiCloudVector(BaseVector):
         with redis_client.lock(lock_name, timeout=20):
             collection_exist_cache_key = f"vector_indexing_{self._collection_name}"
             if redis_client.get(collection_exist_cache_key):
-                logger.info(f"Collection {self._collection_name} already exists.")
+                logger.info("Collection %s already exists.", self._collection_name)
                 return

             if not self._client.indices.exists(index=self._collection_name):

@ -89,7 +89,7 @@ class LindormVectorStore(BaseVector):
         timeout: int = 60,
         **kwargs,
     ):
-        logger.info(f"Total documents to add: {len(documents)}")
+        logger.info("Total documents to add: %s", len(documents))
         uuids = self._get_uuids(documents)

         total_docs = len(documents)

@ -147,7 +147,7 @@ class LindormVectorStore(BaseVector):
                     time.sleep(0.5)

             except Exception:
-                logger.exception(f"Failed to process batch {batch_num + 1}")
+                logger.exception("Failed to process batch %s", batch_num + 1)
                 raise

     def get_ids_by_metadata_field(self, key: str, value: str):

@ -180,7 +180,7 @@ class LindormVectorStore(BaseVector):

         # 1. First check if collection exists
         if not self._client.indices.exists(index=self._collection_name):
-            logger.warning(f"Collection {self._collection_name} does not exist")
+            logger.warning("Collection %s does not exist", self._collection_name)
             return

         # 2. Batch process deletions

@ -196,7 +196,7 @@ class LindormVectorStore(BaseVector):
                     }
                 )
             else:
-                logger.warning(f"DELETE BY ID: ID {id} does not exist in the index.")
+                logger.warning("DELETE BY ID: ID %s does not exist in the index.", id)

         # 3. Perform bulk deletion if there are valid documents to delete
         if actions:

@ -209,9 +209,9 @@ class LindormVectorStore(BaseVector):
                 doc_id = delete_error.get("_id")

                 if status == 404:
-                    logger.warning(f"Document not found for deletion: {doc_id}")
+                    logger.warning("Document not found for deletion: %s", doc_id)
                 else:
-                    logger.exception(f"Error deleting document: {error}")
+                    logger.exception("Error deleting document: %s", error)

     def delete(self) -> None:
         if self._using_ugc:

@ -225,7 +225,7 @@ class LindormVectorStore(BaseVector):
                 self._client.indices.delete(index=self._collection_name, params={"timeout": 60})
                 logger.info("Delete index success")
             else:
-                logger.warning(f"Index '{self._collection_name}' does not exist. No deletion performed.")
+                logger.warning("Index '%s' does not exist. No deletion performed.", self._collection_name)

     def text_exists(self, id: str) -> bool:
         try:

@ -257,7 +257,7 @@ class LindormVectorStore(BaseVector):
                 params["routing"] = self._routing  # type: ignore
             response = self._client.search(index=self._collection_name, body=query, params=params)
         except Exception:
-            logger.exception(f"Error executing vector search, query: {query}")
+            logger.exception("Error executing vector search, query: %s", query)
             raise

         docs_and_scores = []

@ -324,10 +324,10 @@ class LindormVectorStore(BaseVector):
         with redis_client.lock(lock_name, timeout=20):
             collection_exist_cache_key = f"vector_indexing_{self._collection_name}"
             if redis_client.get(collection_exist_cache_key):
-                logger.info(f"Collection {self._collection_name} already exists.")
+                logger.info("Collection %s already exists.", self._collection_name)
                 return
             if self._client.indices.exists(index=self._collection_name):
-                logger.info(f"{self._collection_name.lower()} already exists.")
+                logger.info("%s already exists.", self._collection_name.lower())
                 redis_client.set(collection_exist_cache_key, 1, ex=3600)
                 return
             if len(self.kwargs) == 0 and len(kwargs) != 0:

@ -103,7 +103,7 @@ class MilvusVector(BaseVector):
             # For standard Milvus installations, check version number
             return version.parse(milvus_version).base_version >= version.parse("2.5.0").base_version
         except Exception as e:
-            logger.warning(f"Failed to check Milvus version: {str(e)}. Disabling hybrid search.")
+            logger.warning("Failed to check Milvus version: %s. Disabling hybrid search.", str(e))
             return False

     def get_type(self) -> str:

@ -289,9 +289,9 @@ class MilvusVector(BaseVector):
         """
         Create a new collection in Milvus with the specified schema and index parameters.
         """
-        lock_name = "vector_indexing_lock_{}".format(self._collection_name)
+        lock_name = f"vector_indexing_lock_{self._collection_name}"
         with redis_client.lock(lock_name, timeout=20):
-            collection_exist_cache_key = "vector_indexing_{}".format(self._collection_name)
+            collection_exist_cache_key = f"vector_indexing_{self._collection_name}"
             if redis_client.get(collection_exist_cache_key):
                 return
             # Grab the existing collection if it exists

@ -53,7 +53,7 @@ class MyScaleVector(BaseVector):
         return self.add_texts(documents=texts, embeddings=embeddings, **kwargs)

     def _create_collection(self, dimension: int):
-        logging.info(f"create MyScale collection {self._collection_name} with dimension {dimension}")
+        logging.info("create MyScale collection %s with dimension %s", self._collection_name, dimension)
         self._client.command(f"CREATE DATABASE IF NOT EXISTS {self._config.database}")
         fts_params = f"('{self._config.fts_params}')" if self._config.fts_params else ""
         sql = f"""

@ -151,7 +151,7 @@ class MyScaleVector(BaseVector):
                 for r in self._client.query(sql).named_results()
             ]
         except Exception as e:
-            logging.exception(f"\033[91m\033[1m{type(e)}\033[0m \033[95m{str(e)}\033[0m")  # noqa:TRY401
+            logging.exception("\033[91m\033[1m%s\033[0m \033[95m%s\033[0m", type(e), str(e))  # noqa:TRY401
             return []

     def delete(self) -> None:

@ -147,7 +147,7 @@ class OceanBaseVector(BaseVector):
             logger.debug("Current OceanBase version is %s", ob_version)
             return version.parse(ob_version).base_version >= version.parse("4.3.5.1").base_version
         except Exception as e:
-            logger.warning(f"Failed to check OceanBase version: {str(e)}. Disabling hybrid search.")
+            logger.warning("Failed to check OceanBase version: %s. Disabling hybrid search.", str(e))
             return False

     def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):

@ -229,7 +229,7 @@ class OceanBaseVector(BaseVector):

             return docs
         except Exception as e:
-            logger.warning(f"Failed to fulltext search: {str(e)}.")
+            logger.warning("Failed to fulltext search: %s.", str(e))
             return []

     def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:

@ -131,7 +131,7 @@ class OpenSearchVector(BaseVector):
     def delete_by_ids(self, ids: list[str]) -> None:
         index_name = self._collection_name.lower()
         if not self._client.indices.exists(index=index_name):
-            logger.warning(f"Index {index_name} does not exist")
+            logger.warning("Index %s does not exist", index_name)
             return

         # Obtaining All Actual Documents_ID

@ -142,7 +142,7 @@ class OpenSearchVector(BaseVector):
             if es_ids:
                 actual_ids.extend(es_ids)
             else:
-                logger.warning(f"Document with metadata doc_id {doc_id} not found for deletion")
+                logger.warning("Document with metadata doc_id %s not found for deletion", doc_id)

         if actual_ids:
             actions = [{"_op_type": "delete", "_index": index_name, "_id": es_id} for es_id in actual_ids]

@ -155,9 +155,9 @@ class OpenSearchVector(BaseVector):
                 doc_id = delete_error.get("_id")

                 if status == 404:
-                    logger.warning(f"Document not found for deletion: {doc_id}")
+                    logger.warning("Document not found for deletion: %s", doc_id)
                 else:
-                    logger.exception(f"Error deleting document: {error}")
+                    logger.exception("Error deleting document: %s", error)

     def delete(self) -> None:
         self._client.indices.delete(index=self._collection_name.lower())

@ -198,7 +198,7 @@ class OpenSearchVector(BaseVector):
         try:
             response = self._client.search(index=self._collection_name.lower(), body=query)
         except Exception as e:
-            logger.exception(f"Error executing vector search, query: {query}")
+            logger.exception("Error executing vector search, query: %s", query)
             raise

         docs = []

@ -242,7 +242,7 @@ class OpenSearchVector(BaseVector):
         with redis_client.lock(lock_name, timeout=20):
             collection_exist_cache_key = f"vector_indexing_{self._collection_name.lower()}"
             if redis_client.get(collection_exist_cache_key):
-                logger.info(f"Collection {self._collection_name.lower()} already exists.")
+                logger.info("Collection %s already exists.", self._collection_name.lower())
                 return

             if not self._client.indices.exists(index=self._collection_name.lower()):

@ -272,7 +272,7 @@ class OpenSearchVector(BaseVector):
                     },
                 }

-                logger.info(f"Creating OpenSearch index {self._collection_name.lower()}")
+                logger.info("Creating OpenSearch index %s", self._collection_name.lower())
                 self._client.indices.create(index=self._collection_name.lower(), body=index_body)

             redis_client.set(collection_exist_cache_key, 1, ex=3600)

@ -82,9 +82,9 @@ class PGVectoRS(BaseVector):
         self.add_texts(texts, embeddings)

     def create_collection(self, dimension: int):
-        lock_name = "vector_indexing_lock_{}".format(self._collection_name)
+        lock_name = f"vector_indexing_lock_{self._collection_name}"
         with redis_client.lock(lock_name, timeout=20):
-            collection_exist_cache_key = "vector_indexing_{}".format(self._collection_name)
+            collection_exist_cache_key = f"vector_indexing_{self._collection_name}"
             if redis_client.get(collection_exist_cache_key):
                 return
             index_name = f"{self._collection_name}_embedding_index"

@ -155,7 +155,7 @@ class PGVector(BaseVector):
                     cur.execute(f"DELETE FROM {self.table_name} WHERE id IN %s", (tuple(ids),))
         except psycopg2.errors.UndefinedTable:
             # table not exists
-            logging.warning(f"Table {self.table_name} not found, skipping delete operation.")
+            logging.warning("Table %s not found, skipping delete operation.", self.table_name)
             return
         except Exception as e:
             raise e

@ -95,9 +95,9 @@ class QdrantVector(BaseVector):
         self.add_texts(texts, embeddings, **kwargs)

     def create_collection(self, collection_name: str, vector_size: int):
-        lock_name = "vector_indexing_lock_{}".format(collection_name)
+        lock_name = f"vector_indexing_lock_{collection_name}"
         with redis_client.lock(lock_name, timeout=20):
-            collection_exist_cache_key = "vector_indexing_{}".format(self._collection_name)
+            collection_exist_cache_key = f"vector_indexing_{self._collection_name}"
             if redis_client.get(collection_exist_cache_key):
                 return
             collection_name = collection_name or uuid.uuid4().hex

@ -70,9 +70,9 @@ class RelytVector(BaseVector):
         self.add_texts(texts, embeddings)

     def create_collection(self, dimension: int):
-        lock_name = "vector_indexing_lock_{}".format(self._collection_name)
+        lock_name = f"vector_indexing_lock_{self._collection_name}"
         with redis_client.lock(lock_name, timeout=20):
-            collection_exist_cache_key = "vector_indexing_{}".format(self._collection_name)
+            collection_exist_cache_key = f"vector_indexing_{self._collection_name}"
             if redis_client.get(collection_exist_cache_key):
                 return
             index_name = f"{self._collection_name}_embedding_index"

@ -142,7 +142,7 @@ class TableStoreVector(BaseVector):
         with redis_client.lock(lock_name, timeout=20):
             collection_exist_cache_key = f"vector_indexing_{self._collection_name}"
             if redis_client.get(collection_exist_cache_key):
-                logging.info(f"Collection {self._collection_name} already exists.")
+                logging.info("Collection %s already exists.", self._collection_name)
                 return

             self._create_table_if_not_exist()

@ -92,9 +92,9 @@ class TencentVector(BaseVector):

     def _create_collection(self, dimension: int) -> None:
         self._dimension = dimension
-        lock_name = "vector_indexing_lock_{}".format(self._collection_name)
+        lock_name = f"vector_indexing_lock_{self._collection_name}"
         with redis_client.lock(lock_name, timeout=20):
-            collection_exist_cache_key = "vector_indexing_{}".format(self._collection_name)
+            collection_exist_cache_key = f"vector_indexing_{self._collection_name}"
             if redis_client.get(collection_exist_cache_key):
                 return

@ -104,9 +104,9 @@ class TidbOnQdrantVector(BaseVector):
         self.add_texts(texts, embeddings, **kwargs)

     def create_collection(self, collection_name: str, vector_size: int):
-        lock_name = "vector_indexing_lock_{}".format(collection_name)
+        lock_name = f"vector_indexing_lock_{collection_name}"
         with redis_client.lock(lock_name, timeout=20):
-            collection_exist_cache_key = "vector_indexing_{}".format(self._collection_name)
+            collection_exist_cache_key = f"vector_indexing_{self._collection_name}"
             if redis_client.get(collection_exist_cache_key):
                 return
             collection_name = collection_name or uuid.uuid4().hex

@ -91,9 +91,9 @@ class TiDBVector(BaseVector):

     def _create_collection(self, dimension: int):
         logger.info("_create_collection, collection_name " + self._collection_name)
-        lock_name = "vector_indexing_lock_{}".format(self._collection_name)
+        lock_name = f"vector_indexing_lock_{self._collection_name}"
         with redis_client.lock(lock_name, timeout=20):
-            collection_exist_cache_key = "vector_indexing_{}".format(self._collection_name)
+            collection_exist_cache_key = f"vector_indexing_{self._collection_name}"
             if redis_client.get(collection_exist_cache_key):
                 return
             tidb_dist_func = self._get_distance_func()

@ -192,7 +192,7 @@ class TiDBVector(BaseVector):
         query_vector_str = ", ".join(format(x) for x in query_vector)
         query_vector_str = "[" + query_vector_str + "]"
         logger.debug(
-            f"_collection_name: {self._collection_name}, score_threshold: {score_threshold}, distance: {distance}"
+            "_collection_name: %s, score_threshold: %s, distance: %s", self._collection_name, score_threshold, distance
         )

         docs = []

@ -178,19 +178,19 @@ class Vector:
     def create(self, texts: Optional[list] = None, **kwargs):
         if texts:
             start = time.time()
-            logger.info(f"start embedding {len(texts)} texts {start}")
+            logger.info("start embedding %s texts %s", len(texts), start)
             batch_size = 1000
             total_batches = len(texts) + batch_size - 1
             for i in range(0, len(texts), batch_size):
                 batch = texts[i : i + batch_size]
                 batch_start = time.time()
-                logger.info(f"Processing batch {i // batch_size + 1}/{total_batches} ({len(batch)} texts)")
+                logger.info("Processing batch %s/%s (%s texts)", i // batch_size + 1, total_batches, len(batch))
                 batch_embeddings = self._embeddings.embed_documents([document.page_content for document in batch])
                 logger.info(
-                    f"Embedding batch {i // batch_size + 1}/{total_batches} took {time.time() - batch_start:.3f}s"
+                    "Embedding batch %s/%s took %s s", i // batch_size + 1, total_batches, time.time() - batch_start
                 )
                 self._vector_processor.create(texts=batch, embeddings=batch_embeddings, **kwargs)
-            logger.info(f"Embedding {len(texts)} texts took {time.time() - start:.3f}s")
+            logger.info("Embedding %s texts took %s s", len(texts), time.time() - start)

     def add_texts(self, documents: list[Document], **kwargs):
         if kwargs.get("duplicate_check", False):

@ -219,7 +219,7 @@ class Vector:
         self._vector_processor.delete()
         # delete collection redis cache
         if self._vector_processor.collection_name:
-            collection_exist_cache_key = "vector_indexing_{}".format(self._vector_processor.collection_name)
+            collection_exist_cache_key = f"vector_indexing_{self._vector_processor.collection_name}"
            redis_client.delete(collection_exist_cache_key)

     def _get_embeddings(self) -> Embeddings:

@ -92,9 +92,9 @@ class WeaviateVector(BaseVector):
         self.add_texts(texts, embeddings)

     def _create_collection(self):
-        lock_name = "vector_indexing_lock_{}".format(self._collection_name)
+        lock_name = f"vector_indexing_lock_{self._collection_name}"
         with redis_client.lock(lock_name, timeout=20):
-            collection_exist_cache_key = "vector_indexing_{}".format(self._collection_name)
+            collection_exist_cache_key = f"vector_indexing_{self._collection_name}"
             if redis_client.get(collection_exist_cache_key):
                 return
             schema = self._default_schema(self._collection_name)
@ -32,7 +32,7 @@ class DatasetDocumentStore:
         }

     @property
-    def dateset_id(self) -> Any:
+    def dataset_id(self) -> Any:
         return self._dataset.id

     @property

@ -123,13 +123,13 @@ class DatasetDocumentStore:
                 db.session.flush()
                 if save_child:
                     if doc.children:
-                        for postion, child in enumerate(doc.children, start=1):
+                        for position, child in enumerate(doc.children, start=1):
                             child_segment = ChildChunk(
                                 tenant_id=self._dataset.tenant_id,
                                 dataset_id=self._dataset.id,
                                 document_id=self._document_id,
                                 segment_id=segment_document.id,
-                                position=postion,
+                                position=position,
                                 index_node_id=child.metadata.get("doc_id"),
                                 index_node_hash=child.metadata.get("doc_hash"),
                                 content=child.page_content,

@ -69,7 +69,7 @@ class CacheEmbedding(Embeddings):
                     # stackoverflow best way: https://stackoverflow.com/questions/20319813/how-to-check-list-containing-nan
                     if np.isnan(normalized_embedding).any():
                         # for issue #11827 float values are not json compliant
-                        logger.warning(f"Normalized embedding is nan: {normalized_embedding}")
+                        logger.warning("Normalized embedding is nan: %s", normalized_embedding)
                         continue
                     embedding_queue_embeddings.append(normalized_embedding)
             except IntegrityError:

@ -122,7 +122,7 @@ class CacheEmbedding(Embeddings):
                 raise ValueError("Normalized embedding is nan please try again")
         except Exception as ex:
             if dify_config.DEBUG:
-                logging.exception(f"Failed to embed query text '{text[:10]}...({len(text)} chars)'")
+                logging.exception("Failed to embed query text '%s...(%s chars)'", text[:10], len(text))
             raise ex

         try:

@ -136,7 +136,9 @@ class CacheEmbedding(Embeddings):
             redis_client.setex(embedding_cache_key, 600, encoded_str)
         except Exception as ex:
             if dify_config.DEBUG:
-                logging.exception(f"Failed to add embedding to redis for the text '{text[:10]}...({len(text)} chars)'")
+                logging.exception(
+                    "Failed to add embedding to redis for the text '%s...(%s chars)'", text[:10], len(text)
+                )
             raise ex

         return embedding_results  # type: ignore

@ -116,7 +116,7 @@ class TextSplitter(BaseDocumentTransformer, ABC):
                 if total + _len + (separator_len if len(current_doc) > 0 else 0) > self._chunk_size:
                     if total > self._chunk_size:
                         logger.warning(
-                            f"Created a chunk of size {total}, which is longer than the specified {self._chunk_size}"
+                            "Created a chunk of size %s, which is longer than the specified %s", total, self._chunk_size
                         )
                     if len(current_doc) > 0:
                         doc = self._join_docs(current_doc, separator)

@ -153,7 +153,7 @@ class DifyCoreRepositoryFactory:
             RepositoryImportError: If the configured repository cannot be created
         """
         class_path = dify_config.CORE_WORKFLOW_EXECUTION_REPOSITORY
-        logger.debug(f"Creating WorkflowExecutionRepository from: {class_path}")
+        logger.debug("Creating WorkflowExecutionRepository from: %s", class_path)

         try:
             repository_class = cls._import_class(class_path)

@ -199,7 +199,7 @@ class DifyCoreRepositoryFactory:
             RepositoryImportError: If the configured repository cannot be created
         """
         class_path = dify_config.CORE_WORKFLOW_NODE_EXECUTION_REPOSITORY
-        logger.debug(f"Creating WorkflowNodeExecutionRepository from: {class_path}")
+        logger.debug("Creating WorkflowNodeExecutionRepository from: %s", class_path)

         try:
             repository_class = cls._import_class(class_path)

@ -203,5 +203,5 @@ class SQLAlchemyWorkflowExecutionRepository(WorkflowExecutionRepository):
             session.commit()

             # Update the in-memory cache for faster subsequent lookups
-            logger.debug(f"Updating cache for execution_id: {db_model.id}")
+            logger.debug("Updating cache for execution_id: %s", db_model.id)
             self._execution_cache[db_model.id] = db_model

@ -215,7 +215,7 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository)
            # Update the in-memory cache for faster subsequent lookups
             # Only cache if we have a node_execution_id to use as the cache key
             if db_model.node_execution_id:
-                logger.debug(f"Updating cache for node_execution_id: {db_model.node_execution_id}")
+                logger.debug("Updating cache for node_execution_id: %s", db_model.node_execution_id)
                 self._node_execution_cache[db_model.node_execution_id] = db_model

     def get_db_models_by_workflow_run(

@ -206,7 +206,7 @@ class ToolManager:
             )
         except Exception as e:
             builtin_provider = None
-            logger.info(f"Error getting builtin provider {credential_id}:{e}", exc_info=True)
+            logger.info("Error getting builtin provider %s:%s", credential_id, e, exc_info=True)
         # if the provider has been deleted, raise an error
         if builtin_provider is None:
             raise ToolProviderNotFoundError(f"provider has been deleted: {credential_id}")

@ -237,7 +237,7 @@ class ToolManager:
             if builtin_provider is None:
                 raise ToolProviderNotFoundError(f"builtin provider {provider_id} not found")

-            encrypter, _ = create_provider_encrypter(
+            encrypter, cache = create_provider_encrypter(
                 tenant_id=tenant_id,
                 config=[
                     x.to_basic_provider_config()

@ -281,6 +281,7 @@ class ToolManager:
                     builtin_provider.expires_at = refreshed_credentials.expires_at
                     db.session.commit()
                     decrypted_credentials = refreshed_credentials.credentials
+                    cache.delete()

             return cast(
                 BuiltinTool,
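
Note on the two ToolManager hunks above: `create_provider_encrypter` returns an (encrypter, cache) pair, and the previously discarded cache handle is now kept so it can be invalidated right after refreshed credentials are committed; otherwise the encrypter could keep serving the stale, pre-refresh credentials. A schematic of the intended sequence, with object shapes assumed from the diff rather than taken from the project's API:

    def refresh_credentials(builtin_provider, refreshed_credentials, cache, db):
        # Persist the refreshed credential state first...
        builtin_provider.expires_at = refreshed_credentials.expires_at
        db.session.commit()
        decrypted_credentials = refreshed_credentials.credentials
        # ...then drop the encrypter's cached copy so the next lookup
        # decrypts the fresh values rather than a stale cache entry.
        cache.delete()
        return decrypted_credentials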
@ -569,7 +570,7 @@ class ToolManager:
                     yield provider

                 except Exception:
-                    logger.exception(f"load builtin provider {provider_path}")
+                    logger.exception("load builtin provider %s", provider_path)
                     continue
         # set builtin providers loaded
         cls._builtin_providers_loaded = True

@ -105,6 +105,29 @@ class ApiBasedToolSchemaParser:
                     # overwrite the content
                     interface["operation"]["requestBody"]["content"][content_type]["schema"] = root
+
+                    # handle allOf reference in schema properties
+                    for prop_dict in root.get("properties", {}).values():
+                        for item in prop_dict.get("allOf", []):
+                            if "$ref" in item:
+                                ref_schema = openapi
+                                reference = item["$ref"].split("/")[1:]
+                                for ref in reference:
+                                    ref_schema = ref_schema[ref]
+                            else:
+                                ref_schema = item
+                            for key, value in ref_schema.items():
+                                if isinstance(value, list):
+                                    if key not in prop_dict:
+                                        prop_dict[key] = []
+                                    # extends list field
+                                    if isinstance(prop_dict[key], list):
+                                        prop_dict[key].extend(value)
+                                elif key not in prop_dict:
+                                    # add new field
+                                    prop_dict[key] = value
+                        if "allOf" in prop_dict:
+                            del prop_dict["allOf"]

                 # parse body parameters
                 if "schema" in interface["operation"]["requestBody"]["content"][content_type]:
                     body_schema = interface["operation"]["requestBody"]["content"][content_type]["schema"]
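
The block added above flattens `allOf` members, whether inline or `$ref` pointers, into each property so that downstream parameter extraction sees plain fields. A self-contained sketch of the same merge on a toy OpenAPI fragment; the variable names mirror the diff, but the input data is illustrative:

    openapi = {
        "components": {"schemas": {"Tag": {"type": "object", "required": ["name"]}}},
    }
    # A property composed via allOf: one $ref member, one inline member.
    prop_dict = {"allOf": [{"$ref": "#/components/schemas/Tag"}, {"description": "a tag"}]}

    for item in prop_dict.get("allOf", []):
        if "$ref" in item:
            ref_schema = openapi
            # "#/components/schemas/Tag" -> walk ["components", "schemas", "Tag"]
            for ref in item["$ref"].split("/")[1:]:
                ref_schema = ref_schema[ref]
        else:
            ref_schema = item
        for key, value in ref_schema.items():
            if isinstance(value, list):
                prop_dict.setdefault(key, [])
                if isinstance(prop_dict[key], list):
                    prop_dict[key].extend(value)
            elif key not in prop_dict:
                prop_dict[key] = value
    if "allOf" in prop_dict:
        del prop_dict["allOf"]

    assert prop_dict == {"type": "object", "required": ["name"], "description": "a tag"}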
@ -55,7 +55,7 @@ def get_url(url: str, user_agent: Optional[str] = None) -> str:
     main_content_type = mimetypes.guess_type(filename)[0]

     if main_content_type not in supported_content_types:
-        return "Unsupported content-type [{}] of URL.".format(main_content_type)
+        return f"Unsupported content-type [{main_content_type}] of URL."

     if main_content_type in extract_processor.SUPPORT_URL_CONTENT_TYPES:
         return cast(str, ExtractProcessor.load_from_url(url, return_text=True))

@ -67,7 +67,7 @@ def get_url(url: str, user_agent: Optional[str] = None) -> str:
     response = scraper.get(url, headers=headers, follow_redirects=True, timeout=(120, 300))  # type: ignore

     if response.status_code != 200:
-        return "URL returned status code {}.".format(response.status_code)
+        return f"URL returned status code {response.status_code}."

     # Detect encoding using chardet
     detected_encoding = chardet.detect(response.content)

@ -194,7 +194,7 @@ class WorkflowTool(Tool):

                         files.append(file_dict)
                     except Exception:
-                        logger.exception(f"Failed to transform file {file}")
+                        logger.exception("Failed to transform file %s", file)
             else:
                 parameters_result[parameter.name] = tool_parameters.get(parameter.name)

@ -238,13 +238,13 @@ class GraphEngine:
         while True:
             # max steps reached
             if self.graph_runtime_state.node_run_steps > self.max_execution_steps:
-                raise GraphRunFailedError("Max steps {} reached.".format(self.max_execution_steps))
+                raise GraphRunFailedError(f"Max steps {self.max_execution_steps} reached.")

             # or max execution time reached
             if self._is_timed_out(
                 start_at=self.graph_runtime_state.start_at, max_execution_time=self.max_execution_time
             ):
-                raise GraphRunFailedError("Max execution time {}s reached.".format(self.max_execution_time))
+                raise GraphRunFailedError(f"Max execution time {self.max_execution_time}s reached.")

             # init route node state
             route_node_state = self.graph_runtime_state.node_run_state.create_node_state(node_id=next_node_id)

@ -377,7 +377,7 @@ class GraphEngine:

             edge = cast(GraphEdge, sub_edge_mappings[0])
             if edge.run_condition is None:
-                logger.warning(f"Edge {edge.target_node_id} run condition is None")
+                logger.warning("Edge %s run condition is None", edge.target_node_id)
                 continue

             result = ConditionManager.get_condition_handler(

@ -848,7 +848,7 @@ class GraphEngine:
             )
             return
         except Exception as e:
-            logger.exception(f"Node {node.title} run failed")
+            logger.exception("Node %s run failed", node.title)
             raise e

     def _append_variables_recursively(self, node_id: str, variable_key_list: list[str], variable_value: VariableValue):

@ -50,6 +50,7 @@ from .exc import (
     AgentInputTypeError,
     AgentInvocationError,
     AgentMessageTransformError,
+    AgentNodeError,
     AgentVariableNotFoundError,
     AgentVariableTypeError,
     ToolFileNotFoundError,

@ -593,7 +594,14 @@ class AgentNode(BaseNode):
                 variables[variable_name] = variable_value
             elif message.type == ToolInvokeMessage.MessageType.FILE:
                 assert message.meta is not None
-                assert isinstance(message.meta, File)
+                assert isinstance(message.meta, dict)
+                # Validate that meta contains a 'file' key
+                if "file" not in message.meta:
+                    raise AgentNodeError("File message is missing 'file' key in meta")
+
+                # Validate that the file is an instance of File
+                if not isinstance(message.meta["file"], File):
+                    raise AgentNodeError(f"Expected File object but got {type(message.meta['file']).__name__}")
                 files.append(message.meta["file"])
             elif message.type == ToolInvokeMessage.MessageType.LOG:
                 assert isinstance(message.message, ToolInvokeMessage.LogMessage)
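
The AgentNode hunk above replaces a bare `isinstance(message.meta, File)` assertion with explicit checks that raise `AgentNodeError`, so a malformed FILE message fails with a diagnosable message instead of an opaque `AssertionError`. A reduced sketch of the new control flow; the `File` and error classes here are stand-ins, not the project's types:

    class AgentNodeError(Exception):  # stand-in for the node's error type
        pass


    class File:  # stand-in for the project's File model
        pass


    def extract_file(meta):
        # Mirrors the stricter checks added above: fail loudly and descriptively.
        if meta is None or not isinstance(meta, dict):
            raise AgentNodeError("File message meta must be a dict")
        if "file" not in meta:
            raise AgentNodeError("File message is missing 'file' key in meta")
        if not isinstance(meta["file"], File):
            raise AgentNodeError(f"Expected File object but got {type(meta['file']).__name__}")
        return meta["file"]


    try:
        extract_file({"file": "not-a-file"})
    except AgentNodeError as exc:
        print(exc)  # Expected File object but got str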
@ -36,7 +36,7 @@ class StreamProcessor(ABC):
         reachable_node_ids: list[str] = []
         unreachable_first_node_ids: list[str] = []
         if finished_node_id not in self.graph.edge_mapping:
-            logger.warning(f"node {finished_node_id} has no edge mapping")
+            logger.warning("node %s has no edge mapping", finished_node_id)
             return
         for edge in self.graph.edge_mapping[finished_node_id]:
             if (

@ -65,7 +65,7 @@ class BaseNode:
         try:
             result = self._run()
         except Exception as e:
-            logger.exception(f"Node {self.node_id} failed to run")
+            logger.exception("Node %s failed to run", self.node_id)
             result = NodeRunResult(
                 status=WorkflowNodeExecutionStatus.FAILED,
                 error=str(e),

@ -47,7 +47,6 @@ class DatasourceNode(BaseNode):
     _node_data: DatasourceNodeData
     _node_type = NodeType.DATASOURCE
-

     def init_node_data(self, data: Mapping[str, Any]) -> None:
         self._node_data = DatasourceNodeData.model_validate(data)
Some files were not shown because too many files have changed in this diff.