mirror of
https://github.com/langgenius/dify.git
synced 2026-04-29 12:37:20 +08:00
Merge branch 'main' into feat/new-login
* main: chore: refurbish Python code by applying refurb linter rules (#8296) chore: apply ruff E501 line-too-long linter rule (#8275) fix(workflow): missing content in the answer node stream output during iterations (#8292) chore: cleanup ruff flake8-simplify linter rules (#8286) fix: markdown paragraph margin (#8289)
This commit is contained in:
commit
619eeec9b1
@ -411,7 +411,8 @@ def migrate_knowledge_vector_database():
|
|||||||
try:
|
try:
|
||||||
click.echo(
|
click.echo(
|
||||||
click.style(
|
click.style(
|
||||||
f"Start to created vector index with {len(documents)} documents of {segments_count} segments for dataset {dataset.id}.",
|
f"Start to created vector index with {len(documents)} documents of {segments_count}"
|
||||||
|
f" segments for dataset {dataset.id}.",
|
||||||
fg="green",
|
fg="green",
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|||||||
@ -60,23 +60,15 @@ class InsertExploreAppListApi(Resource):
|
|||||||
|
|
||||||
site = app.site
|
site = app.site
|
||||||
if not site:
|
if not site:
|
||||||
desc = args["desc"] if args["desc"] else ""
|
desc = args["desc"] or ""
|
||||||
copy_right = args["copyright"] if args["copyright"] else ""
|
copy_right = args["copyright"] or ""
|
||||||
privacy_policy = args["privacy_policy"] if args["privacy_policy"] else ""
|
privacy_policy = args["privacy_policy"] or ""
|
||||||
custom_disclaimer = args["custom_disclaimer"] if args["custom_disclaimer"] else ""
|
custom_disclaimer = args["custom_disclaimer"] or ""
|
||||||
else:
|
else:
|
||||||
desc = site.description if site.description else args["desc"] if args["desc"] else ""
|
desc = site.description or args["desc"] or ""
|
||||||
copy_right = site.copyright if site.copyright else args["copyright"] if args["copyright"] else ""
|
copy_right = site.copyright or args["copyright"] or ""
|
||||||
privacy_policy = (
|
privacy_policy = site.privacy_policy or args["privacy_policy"] or ""
|
||||||
site.privacy_policy if site.privacy_policy else args["privacy_policy"] if args["privacy_policy"] else ""
|
custom_disclaimer = site.custom_disclaimer or args["custom_disclaimer"] or ""
|
||||||
)
|
|
||||||
custom_disclaimer = (
|
|
||||||
site.custom_disclaimer
|
|
||||||
if site.custom_disclaimer
|
|
||||||
else args["custom_disclaimer"]
|
|
||||||
if args["custom_disclaimer"]
|
|
||||||
else ""
|
|
||||||
)
|
|
||||||
|
|
||||||
recommended_app = RecommendedApp.query.filter(RecommendedApp.app_id == args["app_id"]).first()
|
recommended_app = RecommendedApp.query.filter(RecommendedApp.app_id == args["app_id"]).first()
|
||||||
|
|
||||||
|
|||||||
@ -99,14 +99,10 @@ class ChatMessageTextApi(Resource):
|
|||||||
and app_model.workflow.features_dict
|
and app_model.workflow.features_dict
|
||||||
):
|
):
|
||||||
text_to_speech = app_model.workflow.features_dict.get("text_to_speech")
|
text_to_speech = app_model.workflow.features_dict.get("text_to_speech")
|
||||||
voice = args.get("voice") if args.get("voice") else text_to_speech.get("voice")
|
voice = args.get("voice") or text_to_speech.get("voice")
|
||||||
else:
|
else:
|
||||||
try:
|
try:
|
||||||
voice = (
|
voice = args.get("voice") or app_model.app_model_config.text_to_speech_dict.get("voice")
|
||||||
args.get("voice")
|
|
||||||
if args.get("voice")
|
|
||||||
else app_model.app_model_config.text_to_speech_dict.get("voice")
|
|
||||||
)
|
|
||||||
except Exception:
|
except Exception:
|
||||||
voice = None
|
voice = None
|
||||||
response = AudioService.transcript_tts(app_model=app_model, text=text, message_id=message_id, voice=voice)
|
response = AudioService.transcript_tts(app_model=app_model, text=text, message_id=message_id, voice=voice)
|
||||||
|
|||||||
@ -29,10 +29,13 @@ class DailyMessageStatistic(Resource):
|
|||||||
parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
sql_query = """
|
sql_query = """SELECT
|
||||||
SELECT date(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date, count(*) AS message_count
|
DATE(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date,
|
||||||
FROM messages where app_id = :app_id
|
COUNT(*) AS message_count
|
||||||
"""
|
FROM
|
||||||
|
messages
|
||||||
|
WHERE
|
||||||
|
app_id = :app_id"""
|
||||||
arg_dict = {"tz": account.timezone, "app_id": app_model.id}
|
arg_dict = {"tz": account.timezone, "app_id": app_model.id}
|
||||||
|
|
||||||
timezone = pytz.timezone(account.timezone)
|
timezone = pytz.timezone(account.timezone)
|
||||||
@ -45,7 +48,7 @@ class DailyMessageStatistic(Resource):
|
|||||||
start_datetime_timezone = timezone.localize(start_datetime)
|
start_datetime_timezone = timezone.localize(start_datetime)
|
||||||
start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
|
start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
|
||||||
|
|
||||||
sql_query += " and created_at >= :start"
|
sql_query += " AND created_at >= :start"
|
||||||
arg_dict["start"] = start_datetime_utc
|
arg_dict["start"] = start_datetime_utc
|
||||||
|
|
||||||
if args["end"]:
|
if args["end"]:
|
||||||
@ -55,10 +58,10 @@ class DailyMessageStatistic(Resource):
|
|||||||
end_datetime_timezone = timezone.localize(end_datetime)
|
end_datetime_timezone = timezone.localize(end_datetime)
|
||||||
end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
|
end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
|
||||||
|
|
||||||
sql_query += " and created_at < :end"
|
sql_query += " AND created_at < :end"
|
||||||
arg_dict["end"] = end_datetime_utc
|
arg_dict["end"] = end_datetime_utc
|
||||||
|
|
||||||
sql_query += " GROUP BY date order by date"
|
sql_query += " GROUP BY date ORDER BY date"
|
||||||
|
|
||||||
response_data = []
|
response_data = []
|
||||||
|
|
||||||
@ -83,10 +86,13 @@ class DailyConversationStatistic(Resource):
|
|||||||
parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
sql_query = """
|
sql_query = """SELECT
|
||||||
SELECT date(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date, count(distinct messages.conversation_id) AS conversation_count
|
DATE(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date,
|
||||||
FROM messages where app_id = :app_id
|
COUNT(DISTINCT messages.conversation_id) AS conversation_count
|
||||||
"""
|
FROM
|
||||||
|
messages
|
||||||
|
WHERE
|
||||||
|
app_id = :app_id"""
|
||||||
arg_dict = {"tz": account.timezone, "app_id": app_model.id}
|
arg_dict = {"tz": account.timezone, "app_id": app_model.id}
|
||||||
|
|
||||||
timezone = pytz.timezone(account.timezone)
|
timezone = pytz.timezone(account.timezone)
|
||||||
@ -99,7 +105,7 @@ class DailyConversationStatistic(Resource):
|
|||||||
start_datetime_timezone = timezone.localize(start_datetime)
|
start_datetime_timezone = timezone.localize(start_datetime)
|
||||||
start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
|
start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
|
||||||
|
|
||||||
sql_query += " and created_at >= :start"
|
sql_query += " AND created_at >= :start"
|
||||||
arg_dict["start"] = start_datetime_utc
|
arg_dict["start"] = start_datetime_utc
|
||||||
|
|
||||||
if args["end"]:
|
if args["end"]:
|
||||||
@ -109,10 +115,10 @@ class DailyConversationStatistic(Resource):
|
|||||||
end_datetime_timezone = timezone.localize(end_datetime)
|
end_datetime_timezone = timezone.localize(end_datetime)
|
||||||
end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
|
end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
|
||||||
|
|
||||||
sql_query += " and created_at < :end"
|
sql_query += " AND created_at < :end"
|
||||||
arg_dict["end"] = end_datetime_utc
|
arg_dict["end"] = end_datetime_utc
|
||||||
|
|
||||||
sql_query += " GROUP BY date order by date"
|
sql_query += " GROUP BY date ORDER BY date"
|
||||||
|
|
||||||
response_data = []
|
response_data = []
|
||||||
|
|
||||||
@ -137,10 +143,13 @@ class DailyTerminalsStatistic(Resource):
|
|||||||
parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
sql_query = """
|
sql_query = """SELECT
|
||||||
SELECT date(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date, count(distinct messages.from_end_user_id) AS terminal_count
|
DATE(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date,
|
||||||
FROM messages where app_id = :app_id
|
COUNT(DISTINCT messages.from_end_user_id) AS terminal_count
|
||||||
"""
|
FROM
|
||||||
|
messages
|
||||||
|
WHERE
|
||||||
|
app_id = :app_id"""
|
||||||
arg_dict = {"tz": account.timezone, "app_id": app_model.id}
|
arg_dict = {"tz": account.timezone, "app_id": app_model.id}
|
||||||
|
|
||||||
timezone = pytz.timezone(account.timezone)
|
timezone = pytz.timezone(account.timezone)
|
||||||
@ -153,7 +162,7 @@ class DailyTerminalsStatistic(Resource):
|
|||||||
start_datetime_timezone = timezone.localize(start_datetime)
|
start_datetime_timezone = timezone.localize(start_datetime)
|
||||||
start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
|
start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
|
||||||
|
|
||||||
sql_query += " and created_at >= :start"
|
sql_query += " AND created_at >= :start"
|
||||||
arg_dict["start"] = start_datetime_utc
|
arg_dict["start"] = start_datetime_utc
|
||||||
|
|
||||||
if args["end"]:
|
if args["end"]:
|
||||||
@ -163,10 +172,10 @@ class DailyTerminalsStatistic(Resource):
|
|||||||
end_datetime_timezone = timezone.localize(end_datetime)
|
end_datetime_timezone = timezone.localize(end_datetime)
|
||||||
end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
|
end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
|
||||||
|
|
||||||
sql_query += " and created_at < :end"
|
sql_query += " AND created_at < :end"
|
||||||
arg_dict["end"] = end_datetime_utc
|
arg_dict["end"] = end_datetime_utc
|
||||||
|
|
||||||
sql_query += " GROUP BY date order by date"
|
sql_query += " GROUP BY date ORDER BY date"
|
||||||
|
|
||||||
response_data = []
|
response_data = []
|
||||||
|
|
||||||
@ -191,12 +200,14 @@ class DailyTokenCostStatistic(Resource):
|
|||||||
parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
sql_query = """
|
sql_query = """SELECT
|
||||||
SELECT date(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date,
|
DATE(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date,
|
||||||
(sum(messages.message_tokens) + sum(messages.answer_tokens)) as token_count,
|
(SUM(messages.message_tokens) + SUM(messages.answer_tokens)) AS token_count,
|
||||||
sum(total_price) as total_price
|
SUM(total_price) AS total_price
|
||||||
FROM messages where app_id = :app_id
|
FROM
|
||||||
"""
|
messages
|
||||||
|
WHERE
|
||||||
|
app_id = :app_id"""
|
||||||
arg_dict = {"tz": account.timezone, "app_id": app_model.id}
|
arg_dict = {"tz": account.timezone, "app_id": app_model.id}
|
||||||
|
|
||||||
timezone = pytz.timezone(account.timezone)
|
timezone = pytz.timezone(account.timezone)
|
||||||
@ -209,7 +220,7 @@ class DailyTokenCostStatistic(Resource):
|
|||||||
start_datetime_timezone = timezone.localize(start_datetime)
|
start_datetime_timezone = timezone.localize(start_datetime)
|
||||||
start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
|
start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
|
||||||
|
|
||||||
sql_query += " and created_at >= :start"
|
sql_query += " AND created_at >= :start"
|
||||||
arg_dict["start"] = start_datetime_utc
|
arg_dict["start"] = start_datetime_utc
|
||||||
|
|
||||||
if args["end"]:
|
if args["end"]:
|
||||||
@ -219,10 +230,10 @@ class DailyTokenCostStatistic(Resource):
|
|||||||
end_datetime_timezone = timezone.localize(end_datetime)
|
end_datetime_timezone = timezone.localize(end_datetime)
|
||||||
end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
|
end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
|
||||||
|
|
||||||
sql_query += " and created_at < :end"
|
sql_query += " AND created_at < :end"
|
||||||
arg_dict["end"] = end_datetime_utc
|
arg_dict["end"] = end_datetime_utc
|
||||||
|
|
||||||
sql_query += " GROUP BY date order by date"
|
sql_query += " GROUP BY date ORDER BY date"
|
||||||
|
|
||||||
response_data = []
|
response_data = []
|
||||||
|
|
||||||
@ -249,12 +260,22 @@ class AverageSessionInteractionStatistic(Resource):
|
|||||||
parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
sql_query = """SELECT date(DATE_TRUNC('day', c.created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date,
|
sql_query = """SELECT
|
||||||
AVG(subquery.message_count) AS interactions
|
DATE(DATE_TRUNC('day', c.created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date,
|
||||||
FROM (SELECT m.conversation_id, COUNT(m.id) AS message_count
|
AVG(subquery.message_count) AS interactions
|
||||||
FROM conversations c
|
FROM
|
||||||
JOIN messages m ON c.id = m.conversation_id
|
(
|
||||||
WHERE c.override_model_configs IS NULL AND c.app_id = :app_id"""
|
SELECT
|
||||||
|
m.conversation_id,
|
||||||
|
COUNT(m.id) AS message_count
|
||||||
|
FROM
|
||||||
|
conversations c
|
||||||
|
JOIN
|
||||||
|
messages m
|
||||||
|
ON c.id = m.conversation_id
|
||||||
|
WHERE
|
||||||
|
c.override_model_configs IS NULL
|
||||||
|
AND c.app_id = :app_id"""
|
||||||
arg_dict = {"tz": account.timezone, "app_id": app_model.id}
|
arg_dict = {"tz": account.timezone, "app_id": app_model.id}
|
||||||
|
|
||||||
timezone = pytz.timezone(account.timezone)
|
timezone = pytz.timezone(account.timezone)
|
||||||
@ -267,7 +288,7 @@ FROM (SELECT m.conversation_id, COUNT(m.id) AS message_count
|
|||||||
start_datetime_timezone = timezone.localize(start_datetime)
|
start_datetime_timezone = timezone.localize(start_datetime)
|
||||||
start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
|
start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
|
||||||
|
|
||||||
sql_query += " and c.created_at >= :start"
|
sql_query += " AND c.created_at >= :start"
|
||||||
arg_dict["start"] = start_datetime_utc
|
arg_dict["start"] = start_datetime_utc
|
||||||
|
|
||||||
if args["end"]:
|
if args["end"]:
|
||||||
@ -277,14 +298,19 @@ FROM (SELECT m.conversation_id, COUNT(m.id) AS message_count
|
|||||||
end_datetime_timezone = timezone.localize(end_datetime)
|
end_datetime_timezone = timezone.localize(end_datetime)
|
||||||
end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
|
end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
|
||||||
|
|
||||||
sql_query += " and c.created_at < :end"
|
sql_query += " AND c.created_at < :end"
|
||||||
arg_dict["end"] = end_datetime_utc
|
arg_dict["end"] = end_datetime_utc
|
||||||
|
|
||||||
sql_query += """
|
sql_query += """
|
||||||
GROUP BY m.conversation_id) subquery
|
GROUP BY m.conversation_id
|
||||||
LEFT JOIN conversations c on c.id=subquery.conversation_id
|
) subquery
|
||||||
GROUP BY date
|
LEFT JOIN
|
||||||
ORDER BY date"""
|
conversations c
|
||||||
|
ON c.id = subquery.conversation_id
|
||||||
|
GROUP BY
|
||||||
|
date
|
||||||
|
ORDER BY
|
||||||
|
date"""
|
||||||
|
|
||||||
response_data = []
|
response_data = []
|
||||||
|
|
||||||
@ -311,13 +337,17 @@ class UserSatisfactionRateStatistic(Resource):
|
|||||||
parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
sql_query = """
|
sql_query = """SELECT
|
||||||
SELECT date(DATE_TRUNC('day', m.created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date,
|
DATE(DATE_TRUNC('day', m.created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date,
|
||||||
COUNT(m.id) as message_count, COUNT(mf.id) as feedback_count
|
COUNT(m.id) AS message_count,
|
||||||
FROM messages m
|
COUNT(mf.id) AS feedback_count
|
||||||
LEFT JOIN message_feedbacks mf on mf.message_id=m.id and mf.rating='like'
|
FROM
|
||||||
WHERE m.app_id = :app_id
|
messages m
|
||||||
"""
|
LEFT JOIN
|
||||||
|
message_feedbacks mf
|
||||||
|
ON mf.message_id=m.id AND mf.rating='like'
|
||||||
|
WHERE
|
||||||
|
m.app_id = :app_id"""
|
||||||
arg_dict = {"tz": account.timezone, "app_id": app_model.id}
|
arg_dict = {"tz": account.timezone, "app_id": app_model.id}
|
||||||
|
|
||||||
timezone = pytz.timezone(account.timezone)
|
timezone = pytz.timezone(account.timezone)
|
||||||
@ -330,7 +360,7 @@ class UserSatisfactionRateStatistic(Resource):
|
|||||||
start_datetime_timezone = timezone.localize(start_datetime)
|
start_datetime_timezone = timezone.localize(start_datetime)
|
||||||
start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
|
start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
|
||||||
|
|
||||||
sql_query += " and m.created_at >= :start"
|
sql_query += " AND m.created_at >= :start"
|
||||||
arg_dict["start"] = start_datetime_utc
|
arg_dict["start"] = start_datetime_utc
|
||||||
|
|
||||||
if args["end"]:
|
if args["end"]:
|
||||||
@ -340,10 +370,10 @@ class UserSatisfactionRateStatistic(Resource):
|
|||||||
end_datetime_timezone = timezone.localize(end_datetime)
|
end_datetime_timezone = timezone.localize(end_datetime)
|
||||||
end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
|
end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
|
||||||
|
|
||||||
sql_query += " and m.created_at < :end"
|
sql_query += " AND m.created_at < :end"
|
||||||
arg_dict["end"] = end_datetime_utc
|
arg_dict["end"] = end_datetime_utc
|
||||||
|
|
||||||
sql_query += " GROUP BY date order by date"
|
sql_query += " GROUP BY date ORDER BY date"
|
||||||
|
|
||||||
response_data = []
|
response_data = []
|
||||||
|
|
||||||
@ -373,12 +403,13 @@ class AverageResponseTimeStatistic(Resource):
|
|||||||
parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
sql_query = """
|
sql_query = """SELECT
|
||||||
SELECT date(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date,
|
DATE(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date,
|
||||||
AVG(provider_response_latency) as latency
|
AVG(provider_response_latency) AS latency
|
||||||
FROM messages
|
FROM
|
||||||
WHERE app_id = :app_id
|
messages
|
||||||
"""
|
WHERE
|
||||||
|
app_id = :app_id"""
|
||||||
arg_dict = {"tz": account.timezone, "app_id": app_model.id}
|
arg_dict = {"tz": account.timezone, "app_id": app_model.id}
|
||||||
|
|
||||||
timezone = pytz.timezone(account.timezone)
|
timezone = pytz.timezone(account.timezone)
|
||||||
@ -391,7 +422,7 @@ class AverageResponseTimeStatistic(Resource):
|
|||||||
start_datetime_timezone = timezone.localize(start_datetime)
|
start_datetime_timezone = timezone.localize(start_datetime)
|
||||||
start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
|
start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
|
||||||
|
|
||||||
sql_query += " and created_at >= :start"
|
sql_query += " AND created_at >= :start"
|
||||||
arg_dict["start"] = start_datetime_utc
|
arg_dict["start"] = start_datetime_utc
|
||||||
|
|
||||||
if args["end"]:
|
if args["end"]:
|
||||||
@ -401,10 +432,10 @@ class AverageResponseTimeStatistic(Resource):
|
|||||||
end_datetime_timezone = timezone.localize(end_datetime)
|
end_datetime_timezone = timezone.localize(end_datetime)
|
||||||
end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
|
end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
|
||||||
|
|
||||||
sql_query += " and created_at < :end"
|
sql_query += " AND created_at < :end"
|
||||||
arg_dict["end"] = end_datetime_utc
|
arg_dict["end"] = end_datetime_utc
|
||||||
|
|
||||||
sql_query += " GROUP BY date order by date"
|
sql_query += " GROUP BY date ORDER BY date"
|
||||||
|
|
||||||
response_data = []
|
response_data = []
|
||||||
|
|
||||||
@ -429,13 +460,16 @@ class TokensPerSecondStatistic(Resource):
|
|||||||
parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
sql_query = """SELECT date(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date,
|
sql_query = """SELECT
|
||||||
CASE
|
DATE(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date,
|
||||||
|
CASE
|
||||||
WHEN SUM(provider_response_latency) = 0 THEN 0
|
WHEN SUM(provider_response_latency) = 0 THEN 0
|
||||||
ELSE (SUM(answer_tokens) / SUM(provider_response_latency))
|
ELSE (SUM(answer_tokens) / SUM(provider_response_latency))
|
||||||
END as tokens_per_second
|
END as tokens_per_second
|
||||||
FROM messages
|
FROM
|
||||||
WHERE app_id = :app_id"""
|
messages
|
||||||
|
WHERE
|
||||||
|
app_id = :app_id"""
|
||||||
arg_dict = {"tz": account.timezone, "app_id": app_model.id}
|
arg_dict = {"tz": account.timezone, "app_id": app_model.id}
|
||||||
|
|
||||||
timezone = pytz.timezone(account.timezone)
|
timezone = pytz.timezone(account.timezone)
|
||||||
@ -448,7 +482,7 @@ WHERE app_id = :app_id"""
|
|||||||
start_datetime_timezone = timezone.localize(start_datetime)
|
start_datetime_timezone = timezone.localize(start_datetime)
|
||||||
start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
|
start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
|
||||||
|
|
||||||
sql_query += " and created_at >= :start"
|
sql_query += " AND created_at >= :start"
|
||||||
arg_dict["start"] = start_datetime_utc
|
arg_dict["start"] = start_datetime_utc
|
||||||
|
|
||||||
if args["end"]:
|
if args["end"]:
|
||||||
@ -458,10 +492,10 @@ WHERE app_id = :app_id"""
|
|||||||
end_datetime_timezone = timezone.localize(end_datetime)
|
end_datetime_timezone = timezone.localize(end_datetime)
|
||||||
end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
|
end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
|
||||||
|
|
||||||
sql_query += " and created_at < :end"
|
sql_query += " AND created_at < :end"
|
||||||
arg_dict["end"] = end_datetime_utc
|
arg_dict["end"] = end_datetime_utc
|
||||||
|
|
||||||
sql_query += " GROUP BY date order by date"
|
sql_query += " GROUP BY date ORDER BY date"
|
||||||
|
|
||||||
response_data = []
|
response_data = []
|
||||||
|
|
||||||
|
|||||||
@ -30,12 +30,14 @@ class WorkflowDailyRunsStatistic(Resource):
|
|||||||
parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
sql_query = """
|
sql_query = """SELECT
|
||||||
SELECT date(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date, count(id) AS runs
|
DATE(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date,
|
||||||
FROM workflow_runs
|
COUNT(id) AS runs
|
||||||
WHERE app_id = :app_id
|
FROM
|
||||||
AND triggered_from = :triggered_from
|
workflow_runs
|
||||||
"""
|
WHERE
|
||||||
|
app_id = :app_id
|
||||||
|
AND triggered_from = :triggered_from"""
|
||||||
arg_dict = {
|
arg_dict = {
|
||||||
"tz": account.timezone,
|
"tz": account.timezone,
|
||||||
"app_id": app_model.id,
|
"app_id": app_model.id,
|
||||||
@ -52,7 +54,7 @@ class WorkflowDailyRunsStatistic(Resource):
|
|||||||
start_datetime_timezone = timezone.localize(start_datetime)
|
start_datetime_timezone = timezone.localize(start_datetime)
|
||||||
start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
|
start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
|
||||||
|
|
||||||
sql_query += " and created_at >= :start"
|
sql_query += " AND created_at >= :start"
|
||||||
arg_dict["start"] = start_datetime_utc
|
arg_dict["start"] = start_datetime_utc
|
||||||
|
|
||||||
if args["end"]:
|
if args["end"]:
|
||||||
@ -62,10 +64,10 @@ class WorkflowDailyRunsStatistic(Resource):
|
|||||||
end_datetime_timezone = timezone.localize(end_datetime)
|
end_datetime_timezone = timezone.localize(end_datetime)
|
||||||
end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
|
end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
|
||||||
|
|
||||||
sql_query += " and created_at < :end"
|
sql_query += " AND created_at < :end"
|
||||||
arg_dict["end"] = end_datetime_utc
|
arg_dict["end"] = end_datetime_utc
|
||||||
|
|
||||||
sql_query += " GROUP BY date order by date"
|
sql_query += " GROUP BY date ORDER BY date"
|
||||||
|
|
||||||
response_data = []
|
response_data = []
|
||||||
|
|
||||||
@ -90,12 +92,14 @@ class WorkflowDailyTerminalsStatistic(Resource):
|
|||||||
parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
sql_query = """
|
sql_query = """SELECT
|
||||||
SELECT date(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date, count(distinct workflow_runs.created_by) AS terminal_count
|
DATE(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date,
|
||||||
FROM workflow_runs
|
COUNT(DISTINCT workflow_runs.created_by) AS terminal_count
|
||||||
WHERE app_id = :app_id
|
FROM
|
||||||
AND triggered_from = :triggered_from
|
workflow_runs
|
||||||
"""
|
WHERE
|
||||||
|
app_id = :app_id
|
||||||
|
AND triggered_from = :triggered_from"""
|
||||||
arg_dict = {
|
arg_dict = {
|
||||||
"tz": account.timezone,
|
"tz": account.timezone,
|
||||||
"app_id": app_model.id,
|
"app_id": app_model.id,
|
||||||
@ -112,7 +116,7 @@ class WorkflowDailyTerminalsStatistic(Resource):
|
|||||||
start_datetime_timezone = timezone.localize(start_datetime)
|
start_datetime_timezone = timezone.localize(start_datetime)
|
||||||
start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
|
start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
|
||||||
|
|
||||||
sql_query += " and created_at >= :start"
|
sql_query += " AND created_at >= :start"
|
||||||
arg_dict["start"] = start_datetime_utc
|
arg_dict["start"] = start_datetime_utc
|
||||||
|
|
||||||
if args["end"]:
|
if args["end"]:
|
||||||
@ -122,10 +126,10 @@ class WorkflowDailyTerminalsStatistic(Resource):
|
|||||||
end_datetime_timezone = timezone.localize(end_datetime)
|
end_datetime_timezone = timezone.localize(end_datetime)
|
||||||
end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
|
end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
|
||||||
|
|
||||||
sql_query += " and created_at < :end"
|
sql_query += " AND created_at < :end"
|
||||||
arg_dict["end"] = end_datetime_utc
|
arg_dict["end"] = end_datetime_utc
|
||||||
|
|
||||||
sql_query += " GROUP BY date order by date"
|
sql_query += " GROUP BY date ORDER BY date"
|
||||||
|
|
||||||
response_data = []
|
response_data = []
|
||||||
|
|
||||||
@ -150,14 +154,14 @@ class WorkflowDailyTokenCostStatistic(Resource):
|
|||||||
parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
sql_query = """
|
sql_query = """SELECT
|
||||||
SELECT
|
DATE(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date,
|
||||||
date(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date,
|
SUM(workflow_runs.total_tokens) AS token_count
|
||||||
SUM(workflow_runs.total_tokens) as token_count
|
FROM
|
||||||
FROM workflow_runs
|
workflow_runs
|
||||||
WHERE app_id = :app_id
|
WHERE
|
||||||
AND triggered_from = :triggered_from
|
app_id = :app_id
|
||||||
"""
|
AND triggered_from = :triggered_from"""
|
||||||
arg_dict = {
|
arg_dict = {
|
||||||
"tz": account.timezone,
|
"tz": account.timezone,
|
||||||
"app_id": app_model.id,
|
"app_id": app_model.id,
|
||||||
@ -174,7 +178,7 @@ class WorkflowDailyTokenCostStatistic(Resource):
|
|||||||
start_datetime_timezone = timezone.localize(start_datetime)
|
start_datetime_timezone = timezone.localize(start_datetime)
|
||||||
start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
|
start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
|
||||||
|
|
||||||
sql_query += " and created_at >= :start"
|
sql_query += " AND created_at >= :start"
|
||||||
arg_dict["start"] = start_datetime_utc
|
arg_dict["start"] = start_datetime_utc
|
||||||
|
|
||||||
if args["end"]:
|
if args["end"]:
|
||||||
@ -184,10 +188,10 @@ class WorkflowDailyTokenCostStatistic(Resource):
|
|||||||
end_datetime_timezone = timezone.localize(end_datetime)
|
end_datetime_timezone = timezone.localize(end_datetime)
|
||||||
end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
|
end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
|
||||||
|
|
||||||
sql_query += " and created_at < :end"
|
sql_query += " AND created_at < :end"
|
||||||
arg_dict["end"] = end_datetime_utc
|
arg_dict["end"] = end_datetime_utc
|
||||||
|
|
||||||
sql_query += " GROUP BY date order by date"
|
sql_query += " GROUP BY date ORDER BY date"
|
||||||
|
|
||||||
response_data = []
|
response_data = []
|
||||||
|
|
||||||
@ -217,23 +221,27 @@ class WorkflowAverageAppInteractionStatistic(Resource):
|
|||||||
parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
sql_query = """
|
sql_query = """SELECT
|
||||||
SELECT
|
AVG(sub.interactions) AS interactions,
|
||||||
AVG(sub.interactions) as interactions,
|
sub.date
|
||||||
sub.date
|
FROM
|
||||||
FROM
|
(
|
||||||
(SELECT
|
SELECT
|
||||||
date(DATE_TRUNC('day', c.created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date,
|
DATE(DATE_TRUNC('day', c.created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date,
|
||||||
c.created_by,
|
c.created_by,
|
||||||
COUNT(c.id) AS interactions
|
COUNT(c.id) AS interactions
|
||||||
FROM workflow_runs c
|
FROM
|
||||||
WHERE c.app_id = :app_id
|
workflow_runs c
|
||||||
AND c.triggered_from = :triggered_from
|
WHERE
|
||||||
{{start}}
|
c.app_id = :app_id
|
||||||
{{end}}
|
AND c.triggered_from = :triggered_from
|
||||||
GROUP BY date, c.created_by) sub
|
{{start}}
|
||||||
GROUP BY sub.date
|
{{end}}
|
||||||
"""
|
GROUP BY
|
||||||
|
date, c.created_by
|
||||||
|
) sub
|
||||||
|
GROUP BY
|
||||||
|
sub.date"""
|
||||||
arg_dict = {
|
arg_dict = {
|
||||||
"tz": account.timezone,
|
"tz": account.timezone,
|
||||||
"app_id": app_model.id,
|
"app_id": app_model.id,
|
||||||
@ -262,7 +270,7 @@ class WorkflowAverageAppInteractionStatistic(Resource):
|
|||||||
end_datetime_timezone = timezone.localize(end_datetime)
|
end_datetime_timezone = timezone.localize(end_datetime)
|
||||||
end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
|
end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
|
||||||
|
|
||||||
sql_query = sql_query.replace("{{end}}", " and c.created_at < :end")
|
sql_query = sql_query.replace("{{end}}", " AND c.created_at < :end")
|
||||||
arg_dict["end"] = end_datetime_utc
|
arg_dict["end"] = end_datetime_utc
|
||||||
else:
|
else:
|
||||||
sql_query = sql_query.replace("{{end}}", "")
|
sql_query = sql_query.replace("{{end}}", "")
|
||||||
|
|||||||
@ -150,7 +150,8 @@ class EmailCodeLoginApi(Resource):
|
|||||||
)
|
)
|
||||||
except WorkSpaceNotAllowedCreateError:
|
except WorkSpaceNotAllowedCreateError:
|
||||||
return redirect(
|
return redirect(
|
||||||
f"{dify_config.CONSOLE_WEB_URL}/signin?message=Workspace not found, please contact system admin to invite you to join in a workspace."
|
f"{dify_config.CONSOLE_WEB_URL}/signin"
|
||||||
|
"?message=Workspace not found, please contact system admin to invite you to join in a workspace."
|
||||||
)
|
)
|
||||||
token = AccountService.login(account, ip_address=get_remote_ip(request))
|
token = AccountService.login(account, ip_address=get_remote_ip(request))
|
||||||
|
|
||||||
|
|||||||
@ -96,7 +96,8 @@ class OAuthCallback(Resource):
|
|||||||
return redirect(f"{dify_config.CONSOLE_WEB_URL}/signin?message=WorkspaceNotFound")
|
return redirect(f"{dify_config.CONSOLE_WEB_URL}/signin?message=WorkspaceNotFound")
|
||||||
except WorkSpaceNotAllowedCreateError:
|
except WorkSpaceNotAllowedCreateError:
|
||||||
return redirect(
|
return redirect(
|
||||||
f"{dify_config.CONSOLE_WEB_URL}/signin?message=Workspace not found, please contact system admin to invite you to join in a workspace."
|
f"{dify_config.CONSOLE_WEB_URL}/signin"
|
||||||
|
"?message=Workspace not found, please contact system admin to invite you to join in a workspace."
|
||||||
)
|
)
|
||||||
|
|
||||||
# Check account status
|
# Check account status
|
||||||
@ -114,7 +115,8 @@ class OAuthCallback(Resource):
|
|||||||
return redirect(f"{dify_config.CONSOLE_WEB_URL}/signin?message=WorkspaceNotFound")
|
return redirect(f"{dify_config.CONSOLE_WEB_URL}/signin?message=WorkspaceNotFound")
|
||||||
except WorkSpaceNotAllowedCreateError:
|
except WorkSpaceNotAllowedCreateError:
|
||||||
return redirect(
|
return redirect(
|
||||||
f"{dify_config.CONSOLE_WEB_URL}/signin?message=Workspace not found, please contact system admin to invite you to join in a workspace."
|
f"{dify_config.CONSOLE_WEB_URL}/signin"
|
||||||
|
"?message=Workspace not found, please contact system admin to invite you to join in a workspace."
|
||||||
)
|
)
|
||||||
|
|
||||||
token = AccountService.login(account, ip_address=get_remote_ip(request))
|
token = AccountService.login(account, ip_address=get_remote_ip(request))
|
||||||
@ -149,7 +151,7 @@ def _generate_account(provider: str, user_info: OAuthUserInfo):
|
|||||||
if not account:
|
if not account:
|
||||||
if not dify_config.ALLOW_REGISTER:
|
if not dify_config.ALLOW_REGISTER:
|
||||||
raise AccountNotFoundError()
|
raise AccountNotFoundError()
|
||||||
account_name = user_info.name if user_info.name else "Dify"
|
account_name = user_info.name or "Dify"
|
||||||
account = RegisterService.register(
|
account = RegisterService.register(
|
||||||
email=user_info.email, name=account_name, password=None, open_id=user_info.id, provider=provider
|
email=user_info.email, name=account_name, password=None, open_id=user_info.id, provider=provider
|
||||||
)
|
)
|
||||||
|
|||||||
@ -550,12 +550,7 @@ class DatasetApiBaseUrlApi(Resource):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def get(self):
|
def get(self):
|
||||||
return {
|
return {"api_base_url": (dify_config.SERVICE_API_URL or request.host_url.rstrip("/")) + "/v1"}
|
||||||
"api_base_url": (
|
|
||||||
dify_config.SERVICE_API_URL if dify_config.SERVICE_API_URL else request.host_url.rstrip("/")
|
|
||||||
)
|
|
||||||
+ "/v1"
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
class DatasetRetrievalSettingApi(Resource):
|
class DatasetRetrievalSettingApi(Resource):
|
||||||
|
|||||||
@ -86,14 +86,10 @@ class ChatTextApi(InstalledAppResource):
|
|||||||
and app_model.workflow.features_dict
|
and app_model.workflow.features_dict
|
||||||
):
|
):
|
||||||
text_to_speech = app_model.workflow.features_dict.get("text_to_speech")
|
text_to_speech = app_model.workflow.features_dict.get("text_to_speech")
|
||||||
voice = args.get("voice") if args.get("voice") else text_to_speech.get("voice")
|
voice = args.get("voice") or text_to_speech.get("voice")
|
||||||
else:
|
else:
|
||||||
try:
|
try:
|
||||||
voice = (
|
voice = args.get("voice") or app_model.app_model_config.text_to_speech_dict.get("voice")
|
||||||
args.get("voice")
|
|
||||||
if args.get("voice")
|
|
||||||
else app_model.app_model_config.text_to_speech_dict.get("voice")
|
|
||||||
)
|
|
||||||
except Exception:
|
except Exception:
|
||||||
voice = None
|
voice = None
|
||||||
response = AudioService.transcript_tts(app_model=app_model, message_id=message_id, voice=voice, text=text)
|
response = AudioService.transcript_tts(app_model=app_model, message_id=message_id, voice=voice, text=text)
|
||||||
|
|||||||
@ -327,7 +327,7 @@ class ToolApiProviderPreviousTestApi(Resource):
|
|||||||
|
|
||||||
return ApiToolManageService.test_api_tool_preview(
|
return ApiToolManageService.test_api_tool_preview(
|
||||||
current_user.current_tenant_id,
|
current_user.current_tenant_id,
|
||||||
args["provider_name"] if args["provider_name"] else "",
|
args["provider_name"] or "",
|
||||||
args["tool_name"],
|
args["tool_name"],
|
||||||
args["credentials"],
|
args["credentials"],
|
||||||
args["parameters"],
|
args["parameters"],
|
||||||
|
|||||||
@ -64,7 +64,8 @@ def cloud_edition_billing_resource_check(resource: str):
|
|||||||
elif resource == "vector_space" and 0 < vector_space.limit <= vector_space.size:
|
elif resource == "vector_space" and 0 < vector_space.limit <= vector_space.size:
|
||||||
abort(403, "The capacity of the vector space has reached the limit of your subscription.")
|
abort(403, "The capacity of the vector space has reached the limit of your subscription.")
|
||||||
elif resource == "documents" and 0 < documents_upload_quota.limit <= documents_upload_quota.size:
|
elif resource == "documents" and 0 < documents_upload_quota.limit <= documents_upload_quota.size:
|
||||||
# The api of file upload is used in the multiple places, so we need to check the source of the request from datasets
|
# The api of file upload is used in the multiple places,
|
||||||
|
# so we need to check the source of the request from datasets
|
||||||
source = request.args.get("source")
|
source = request.args.get("source")
|
||||||
if source == "datasets":
|
if source == "datasets":
|
||||||
abort(403, "The number of documents has reached the limit of your subscription.")
|
abort(403, "The number of documents has reached the limit of your subscription.")
|
||||||
|
|||||||
@ -84,14 +84,10 @@ class TextApi(Resource):
|
|||||||
and app_model.workflow.features_dict
|
and app_model.workflow.features_dict
|
||||||
):
|
):
|
||||||
text_to_speech = app_model.workflow.features_dict.get("text_to_speech")
|
text_to_speech = app_model.workflow.features_dict.get("text_to_speech")
|
||||||
voice = args.get("voice") if args.get("voice") else text_to_speech.get("voice")
|
voice = args.get("voice") or text_to_speech.get("voice")
|
||||||
else:
|
else:
|
||||||
try:
|
try:
|
||||||
voice = (
|
voice = args.get("voice") or app_model.app_model_config.text_to_speech_dict.get("voice")
|
||||||
args.get("voice")
|
|
||||||
if args.get("voice")
|
|
||||||
else app_model.app_model_config.text_to_speech_dict.get("voice")
|
|
||||||
)
|
|
||||||
except Exception:
|
except Exception:
|
||||||
voice = None
|
voice = None
|
||||||
response = AudioService.transcript_tts(
|
response = AudioService.transcript_tts(
|
||||||
|
|||||||
@ -83,14 +83,10 @@ class TextApi(WebApiResource):
|
|||||||
and app_model.workflow.features_dict
|
and app_model.workflow.features_dict
|
||||||
):
|
):
|
||||||
text_to_speech = app_model.workflow.features_dict.get("text_to_speech")
|
text_to_speech = app_model.workflow.features_dict.get("text_to_speech")
|
||||||
voice = args.get("voice") if args.get("voice") else text_to_speech.get("voice")
|
voice = args.get("voice") or text_to_speech.get("voice")
|
||||||
else:
|
else:
|
||||||
try:
|
try:
|
||||||
voice = (
|
voice = args.get("voice") or app_model.app_model_config.text_to_speech_dict.get("voice")
|
||||||
args.get("voice")
|
|
||||||
if args.get("voice")
|
|
||||||
else app_model.app_model_config.text_to_speech_dict.get("voice")
|
|
||||||
)
|
|
||||||
except Exception:
|
except Exception:
|
||||||
voice = None
|
voice = None
|
||||||
|
|
||||||
|
|||||||
@ -80,7 +80,8 @@ def _validate_web_sso_token(decoded, system_features, app_code):
|
|||||||
if not source or source != "sso":
|
if not source or source != "sso":
|
||||||
raise WebSSOAuthRequiredError()
|
raise WebSSOAuthRequiredError()
|
||||||
|
|
||||||
# Check if SSO is not enforced for web, and if the token source is SSO, raise an error and redirect to normal passport login
|
# Check if SSO is not enforced for web, and if the token source is SSO,
|
||||||
|
# raise an error and redirect to normal passport login
|
||||||
if not system_features.sso_enforced_for_web or not app_web_sso_enabled:
|
if not system_features.sso_enforced_for_web or not app_web_sso_enabled:
|
||||||
source = decoded.get("token_source")
|
source = decoded.get("token_source")
|
||||||
if source and source == "sso":
|
if source and source == "sso":
|
||||||
|
|||||||
@ -256,7 +256,7 @@ class CotAgentRunner(BaseAgentRunner, ABC):
|
|||||||
model=model_instance.model,
|
model=model_instance.model,
|
||||||
prompt_messages=prompt_messages,
|
prompt_messages=prompt_messages,
|
||||||
message=AssistantPromptMessage(content=final_answer),
|
message=AssistantPromptMessage(content=final_answer),
|
||||||
usage=llm_usage["usage"] if llm_usage["usage"] else LLMUsage.empty_usage(),
|
usage=llm_usage["usage"] or LLMUsage.empty_usage(),
|
||||||
system_fingerprint="",
|
system_fingerprint="",
|
||||||
)
|
)
|
||||||
),
|
),
|
||||||
|
|||||||
@ -298,7 +298,7 @@ class FunctionCallAgentRunner(BaseAgentRunner):
|
|||||||
model=model_instance.model,
|
model=model_instance.model,
|
||||||
prompt_messages=prompt_messages,
|
prompt_messages=prompt_messages,
|
||||||
message=AssistantPromptMessage(content=final_answer),
|
message=AssistantPromptMessage(content=final_answer),
|
||||||
usage=llm_usage["usage"] if llm_usage["usage"] else LLMUsage.empty_usage(),
|
usage=llm_usage["usage"] or LLMUsage.empty_usage(),
|
||||||
system_fingerprint="",
|
system_fingerprint="",
|
||||||
)
|
)
|
||||||
),
|
),
|
||||||
|
|||||||
@ -41,7 +41,8 @@ Begin! Reminder to ALWAYS respond with a valid json blob of a single action. Use
|
|||||||
{{historic_messages}}
|
{{historic_messages}}
|
||||||
Question: {{query}}
|
Question: {{query}}
|
||||||
{{agent_scratchpad}}
|
{{agent_scratchpad}}
|
||||||
Thought:"""
|
Thought:""" # noqa: E501
|
||||||
|
|
||||||
|
|
||||||
ENGLISH_REACT_COMPLETION_AGENT_SCRATCHPAD_TEMPLATES = """Observation: {{observation}}
|
ENGLISH_REACT_COMPLETION_AGENT_SCRATCHPAD_TEMPLATES = """Observation: {{observation}}
|
||||||
Thought:"""
|
Thought:"""
|
||||||
@ -86,7 +87,8 @@ Action:
|
|||||||
```
|
```
|
||||||
|
|
||||||
Begin! Reminder to ALWAYS respond with a valid json blob of a single action. Use tools if necessary. Respond directly if appropriate. Format is Action:```$JSON_BLOB```then Observation:.
|
Begin! Reminder to ALWAYS respond with a valid json blob of a single action. Use tools if necessary. Respond directly if appropriate. Format is Action:```$JSON_BLOB```then Observation:.
|
||||||
"""
|
""" # noqa: E501
|
||||||
|
|
||||||
|
|
||||||
ENGLISH_REACT_CHAT_AGENT_SCRATCHPAD_TEMPLATES = ""
|
ENGLISH_REACT_CHAT_AGENT_SCRATCHPAD_TEMPLATES = ""
|
||||||
|
|
||||||
|
|||||||
@ -161,7 +161,7 @@ class AppRunner:
|
|||||||
app_mode=AppMode.value_of(app_record.mode),
|
app_mode=AppMode.value_of(app_record.mode),
|
||||||
prompt_template_entity=prompt_template_entity,
|
prompt_template_entity=prompt_template_entity,
|
||||||
inputs=inputs,
|
inputs=inputs,
|
||||||
query=query if query else "",
|
query=query or "",
|
||||||
files=files,
|
files=files,
|
||||||
context=context,
|
context=context,
|
||||||
memory=memory,
|
memory=memory,
|
||||||
@ -189,7 +189,7 @@ class AppRunner:
|
|||||||
prompt_messages = prompt_transform.get_prompt(
|
prompt_messages = prompt_transform.get_prompt(
|
||||||
prompt_template=prompt_template,
|
prompt_template=prompt_template,
|
||||||
inputs=inputs,
|
inputs=inputs,
|
||||||
query=query if query else "",
|
query=query or "",
|
||||||
files=files,
|
files=files,
|
||||||
context=context,
|
context=context,
|
||||||
memory_config=memory_config,
|
memory_config=memory_config,
|
||||||
@ -238,7 +238,7 @@ class AppRunner:
|
|||||||
model=app_generate_entity.model_conf.model,
|
model=app_generate_entity.model_conf.model,
|
||||||
prompt_messages=prompt_messages,
|
prompt_messages=prompt_messages,
|
||||||
message=AssistantPromptMessage(content=text),
|
message=AssistantPromptMessage(content=text),
|
||||||
usage=usage if usage else LLMUsage.empty_usage(),
|
usage=usage or LLMUsage.empty_usage(),
|
||||||
),
|
),
|
||||||
),
|
),
|
||||||
PublishFrom.APPLICATION_MANAGER,
|
PublishFrom.APPLICATION_MANAGER,
|
||||||
@ -351,7 +351,7 @@ class AppRunner:
|
|||||||
tenant_id=tenant_id,
|
tenant_id=tenant_id,
|
||||||
app_config=app_generate_entity.app_config,
|
app_config=app_generate_entity.app_config,
|
||||||
inputs=inputs,
|
inputs=inputs,
|
||||||
query=query if query else "",
|
query=query or "",
|
||||||
message_id=message_id,
|
message_id=message_id,
|
||||||
trace_manager=app_generate_entity.trace_manager,
|
trace_manager=app_generate_entity.trace_manager,
|
||||||
)
|
)
|
||||||
|
|||||||
@ -84,10 +84,12 @@ class WorkflowLoggingCallback(WorkflowCallback):
|
|||||||
if route_node_state.node_run_result:
|
if route_node_state.node_run_result:
|
||||||
node_run_result = route_node_state.node_run_result
|
node_run_result = route_node_state.node_run_result
|
||||||
self.print_text(
|
self.print_text(
|
||||||
f"Inputs: {jsonable_encoder(node_run_result.inputs) if node_run_result.inputs else ''}", color="green"
|
f"Inputs: " f"{jsonable_encoder(node_run_result.inputs) if node_run_result.inputs else ''}",
|
||||||
|
color="green",
|
||||||
)
|
)
|
||||||
self.print_text(
|
self.print_text(
|
||||||
f"Process Data: {jsonable_encoder(node_run_result.process_data) if node_run_result.process_data else ''}",
|
f"Process Data: "
|
||||||
|
f"{jsonable_encoder(node_run_result.process_data) if node_run_result.process_data else ''}",
|
||||||
color="green",
|
color="green",
|
||||||
)
|
)
|
||||||
self.print_text(
|
self.print_text(
|
||||||
@ -114,14 +116,17 @@ class WorkflowLoggingCallback(WorkflowCallback):
|
|||||||
node_run_result = route_node_state.node_run_result
|
node_run_result = route_node_state.node_run_result
|
||||||
self.print_text(f"Error: {node_run_result.error}", color="red")
|
self.print_text(f"Error: {node_run_result.error}", color="red")
|
||||||
self.print_text(
|
self.print_text(
|
||||||
f"Inputs: {jsonable_encoder(node_run_result.inputs) if node_run_result.inputs else ''}", color="red"
|
f"Inputs: " f"" f"{jsonable_encoder(node_run_result.inputs) if node_run_result.inputs else ''}",
|
||||||
)
|
|
||||||
self.print_text(
|
|
||||||
f"Process Data: {jsonable_encoder(node_run_result.process_data) if node_run_result.process_data else ''}",
|
|
||||||
color="red",
|
color="red",
|
||||||
)
|
)
|
||||||
self.print_text(
|
self.print_text(
|
||||||
f"Outputs: {jsonable_encoder(node_run_result.outputs) if node_run_result.outputs else ''}", color="red"
|
f"Process Data: "
|
||||||
|
f"{jsonable_encoder(node_run_result.process_data) if node_run_result.process_data else ''}",
|
||||||
|
color="red",
|
||||||
|
)
|
||||||
|
self.print_text(
|
||||||
|
f"Outputs: " f"{jsonable_encoder(node_run_result.outputs) if node_run_result.outputs else ''}",
|
||||||
|
color="red",
|
||||||
)
|
)
|
||||||
|
|
||||||
def on_node_text_chunk(self, event: NodeRunStreamChunkEvent) -> None:
|
def on_node_text_chunk(self, event: NodeRunStreamChunkEvent) -> None:
|
||||||
|
|||||||
@ -65,7 +65,7 @@ class BasedGenerateTaskPipeline:
|
|||||||
|
|
||||||
if isinstance(e, InvokeAuthorizationError):
|
if isinstance(e, InvokeAuthorizationError):
|
||||||
err = InvokeAuthorizationError("Incorrect API key provided")
|
err = InvokeAuthorizationError("Incorrect API key provided")
|
||||||
elif isinstance(e, InvokeError) or isinstance(e, ValueError):
|
elif isinstance(e, InvokeError | ValueError):
|
||||||
err = e
|
err = e
|
||||||
else:
|
else:
|
||||||
err = Exception(e.description if getattr(e, "description", None) is not None else str(e))
|
err = Exception(e.description if getattr(e, "description", None) is not None else str(e))
|
||||||
|
|||||||
@ -3,6 +3,7 @@ import importlib.util
|
|||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
|
from pathlib import Path
|
||||||
from typing import Any, Optional
|
from typing import Any, Optional
|
||||||
|
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
@ -63,8 +64,7 @@ class Extensible:
|
|||||||
|
|
||||||
builtin_file_path = os.path.join(subdir_path, "__builtin__")
|
builtin_file_path = os.path.join(subdir_path, "__builtin__")
|
||||||
if os.path.exists(builtin_file_path):
|
if os.path.exists(builtin_file_path):
|
||||||
with open(builtin_file_path, encoding="utf-8") as f:
|
position = int(Path(builtin_file_path).read_text(encoding="utf-8").strip())
|
||||||
position = int(f.read().strip())
|
|
||||||
position_map[extension_name] = position
|
position_map[extension_name] = position
|
||||||
|
|
||||||
if (extension_name + ".py") not in file_names:
|
if (extension_name + ".py") not in file_names:
|
||||||
|
|||||||
@ -188,7 +188,8 @@ class MessageFileParser:
|
|||||||
def _check_image_remote_url(self, url):
|
def _check_image_remote_url(self, url):
|
||||||
try:
|
try:
|
||||||
headers = {
|
headers = {
|
||||||
"User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36"
|
"User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko)"
|
||||||
|
" Chrome/91.0.4472.124 Safari/537.36"
|
||||||
}
|
}
|
||||||
|
|
||||||
def is_s3_presigned_url(url):
|
def is_s3_presigned_url(url):
|
||||||
|
|||||||
@ -89,7 +89,8 @@ class CodeExecutor:
|
|||||||
raise CodeExecutionError("Code execution service is unavailable")
|
raise CodeExecutionError("Code execution service is unavailable")
|
||||||
elif response.status_code != 200:
|
elif response.status_code != 200:
|
||||||
raise Exception(
|
raise Exception(
|
||||||
f"Failed to execute code, got status code {response.status_code}, please check if the sandbox service is running"
|
f"Failed to execute code, got status code {response.status_code},"
|
||||||
|
f" please check if the sandbox service is running"
|
||||||
)
|
)
|
||||||
except CodeExecutionError as e:
|
except CodeExecutionError as e:
|
||||||
raise e
|
raise e
|
||||||
|
|||||||
@ -14,7 +14,10 @@ class ToolParameterCache:
|
|||||||
def __init__(
|
def __init__(
|
||||||
self, tenant_id: str, provider: str, tool_name: str, cache_type: ToolParameterCacheType, identity_id: str
|
self, tenant_id: str, provider: str, tool_name: str, cache_type: ToolParameterCacheType, identity_id: str
|
||||||
):
|
):
|
||||||
self.cache_key = f"{cache_type.value}_secret:tenant_id:{tenant_id}:provider:{provider}:tool_name:{tool_name}:identity_id:{identity_id}"
|
self.cache_key = (
|
||||||
|
f"{cache_type.value}_secret:tenant_id:{tenant_id}:provider:{provider}:tool_name:{tool_name}"
|
||||||
|
f":identity_id:{identity_id}"
|
||||||
|
)
|
||||||
|
|
||||||
def get(self) -> Optional[dict]:
|
def get(self) -> Optional[dict]:
|
||||||
"""
|
"""
|
||||||
|
|||||||
@ -59,24 +59,27 @@ User Input: yo, 你今天咋样?
|
|||||||
}
|
}
|
||||||
|
|
||||||
User Input:
|
User Input:
|
||||||
"""
|
""" # noqa: E501
|
||||||
|
|
||||||
SUGGESTED_QUESTIONS_AFTER_ANSWER_INSTRUCTION_PROMPT = (
|
SUGGESTED_QUESTIONS_AFTER_ANSWER_INSTRUCTION_PROMPT = (
|
||||||
"Please help me predict the three most likely questions that human would ask, "
|
"Please help me predict the three most likely questions that human would ask, "
|
||||||
"and keeping each question under 20 characters.\n"
|
"and keeping each question under 20 characters.\n"
|
||||||
"MAKE SURE your output is the SAME language as the Assistant's latest response(if the main response is written in Chinese, then the language of your output must be using Chinese.)!\n"
|
"MAKE SURE your output is the SAME language as the Assistant's latest response"
|
||||||
|
"(if the main response is written in Chinese, then the language of your output must be using Chinese.)!\n"
|
||||||
"The output must be an array in JSON format following the specified schema:\n"
|
"The output must be an array in JSON format following the specified schema:\n"
|
||||||
'["question1","question2","question3"]\n'
|
'["question1","question2","question3"]\n'
|
||||||
)
|
)
|
||||||
|
|
||||||
GENERATOR_QA_PROMPT = (
|
GENERATOR_QA_PROMPT = (
|
||||||
"<Task> The user will send a long text. Generate a Question and Answer pairs only using the knowledge in the long text. Please think step by step."
|
"<Task> The user will send a long text. Generate a Question and Answer pairs only using the knowledge"
|
||||||
|
" in the long text. Please think step by step."
|
||||||
"Step 1: Understand and summarize the main content of this text.\n"
|
"Step 1: Understand and summarize the main content of this text.\n"
|
||||||
"Step 2: What key information or concepts are mentioned in this text?\n"
|
"Step 2: What key information or concepts are mentioned in this text?\n"
|
||||||
"Step 3: Decompose or combine multiple pieces of information and concepts.\n"
|
"Step 3: Decompose or combine multiple pieces of information and concepts.\n"
|
||||||
"Step 4: Generate questions and answers based on these key information and concepts.\n"
|
"Step 4: Generate questions and answers based on these key information and concepts.\n"
|
||||||
"<Constraints> The questions should be clear and detailed, and the answers should be detailed and complete. "
|
"<Constraints> The questions should be clear and detailed, and the answers should be detailed and complete. "
|
||||||
"You must answer in {language}, in a style that is clear and detailed in {language}. No language other than {language} should be used. \n"
|
"You must answer in {language}, in a style that is clear and detailed in {language}."
|
||||||
|
" No language other than {language} should be used. \n"
|
||||||
"<Format> Use the following format: Q1:\nA1:\nQ2:\nA2:...\n"
|
"<Format> Use the following format: Q1:\nA1:\nQ2:\nA2:...\n"
|
||||||
"<QA Pairs>"
|
"<QA Pairs>"
|
||||||
)
|
)
|
||||||
@ -94,7 +97,7 @@ Based on task description, please create a well-structured prompt template that
|
|||||||
- Use the same language as task description.
|
- Use the same language as task description.
|
||||||
- Output in ``` xml ``` and start with <instruction>
|
- Output in ``` xml ``` and start with <instruction>
|
||||||
Please generate the full prompt template with at least 300 words and output only the prompt template.
|
Please generate the full prompt template with at least 300 words and output only the prompt template.
|
||||||
"""
|
""" # noqa: E501
|
||||||
|
|
||||||
RULE_CONFIG_PROMPT_GENERATE_TEMPLATE = """
|
RULE_CONFIG_PROMPT_GENERATE_TEMPLATE = """
|
||||||
Here is a task description for which I would like you to create a high-quality prompt template for:
|
Here is a task description for which I would like you to create a high-quality prompt template for:
|
||||||
@ -109,7 +112,7 @@ Based on task description, please create a well-structured prompt template that
|
|||||||
- Use the same language as task description.
|
- Use the same language as task description.
|
||||||
- Output in ``` xml ``` and start with <instruction>
|
- Output in ``` xml ``` and start with <instruction>
|
||||||
Please generate the full prompt template and output only the prompt template.
|
Please generate the full prompt template and output only the prompt template.
|
||||||
"""
|
""" # noqa: E501
|
||||||
|
|
||||||
RULE_CONFIG_PARAMETER_GENERATE_TEMPLATE = """
|
RULE_CONFIG_PARAMETER_GENERATE_TEMPLATE = """
|
||||||
I need to extract the following information from the input text. The <information to be extracted> tag specifies the 'type', 'description' and 'required' of the information to be extracted.
|
I need to extract the following information from the input text. The <information to be extracted> tag specifies the 'type', 'description' and 'required' of the information to be extracted.
|
||||||
@ -134,7 +137,7 @@ Inside <text></text> XML tags, there is a text that I should extract parameters
|
|||||||
|
|
||||||
### Answer
|
### Answer
|
||||||
I should always output a valid list. Output nothing other than the list of variable_name. Output an empty list if there is no variable name in input text.
|
I should always output a valid list. Output nothing other than the list of variable_name. Output an empty list if there is no variable name in input text.
|
||||||
"""
|
""" # noqa: E501
|
||||||
|
|
||||||
RULE_CONFIG_STATEMENT_GENERATE_TEMPLATE = """
|
RULE_CONFIG_STATEMENT_GENERATE_TEMPLATE = """
|
||||||
<instruction>
|
<instruction>
|
||||||
@ -150,4 +153,4 @@ Welcome! I'm here to assist you with any questions or issues you might have with
|
|||||||
Here is the task description: {{INPUT_TEXT}}
|
Here is the task description: {{INPUT_TEXT}}
|
||||||
|
|
||||||
You just need to generate the output
|
You just need to generate the output
|
||||||
"""
|
""" # noqa: E501
|
||||||
|
|||||||
@ -39,7 +39,7 @@ class TokenBufferMemory:
|
|||||||
)
|
)
|
||||||
|
|
||||||
if message_limit and message_limit > 0:
|
if message_limit and message_limit > 0:
|
||||||
message_limit = message_limit if message_limit <= 500 else 500
|
message_limit = min(message_limit, 500)
|
||||||
else:
|
else:
|
||||||
message_limit = 500
|
message_limit = 500
|
||||||
|
|
||||||
|
|||||||
@ -8,8 +8,11 @@ PARAMETER_RULE_TEMPLATE: dict[DefaultParameterName, dict] = {
|
|||||||
},
|
},
|
||||||
"type": "float",
|
"type": "float",
|
||||||
"help": {
|
"help": {
|
||||||
"en_US": "Controls randomness. Lower temperature results in less random completions. As the temperature approaches zero, the model will become deterministic and repetitive. Higher temperature results in more random completions.",
|
"en_US": "Controls randomness. Lower temperature results in less random completions."
|
||||||
"zh_Hans": "温度控制随机性。较低的温度会导致较少的随机完成。随着温度接近零,模型将变得确定性和重复性。较高的温度会导致更多的随机完成。",
|
" As the temperature approaches zero, the model will become deterministic and repetitive."
|
||||||
|
" Higher temperature results in more random completions.",
|
||||||
|
"zh_Hans": "温度控制随机性。较低的温度会导致较少的随机完成。随着温度接近零,模型将变得确定性和重复性。"
|
||||||
|
"较高的温度会导致更多的随机完成。",
|
||||||
},
|
},
|
||||||
"required": False,
|
"required": False,
|
||||||
"default": 0.0,
|
"default": 0.0,
|
||||||
@ -24,7 +27,8 @@ PARAMETER_RULE_TEMPLATE: dict[DefaultParameterName, dict] = {
|
|||||||
},
|
},
|
||||||
"type": "float",
|
"type": "float",
|
||||||
"help": {
|
"help": {
|
||||||
"en_US": "Controls diversity via nucleus sampling: 0.5 means half of all likelihood-weighted options are considered.",
|
"en_US": "Controls diversity via nucleus sampling: 0.5 means half of all likelihood-weighted options"
|
||||||
|
" are considered.",
|
||||||
"zh_Hans": "通过核心采样控制多样性:0.5表示考虑了一半的所有可能性加权选项。",
|
"zh_Hans": "通过核心采样控制多样性:0.5表示考虑了一半的所有可能性加权选项。",
|
||||||
},
|
},
|
||||||
"required": False,
|
"required": False,
|
||||||
@ -88,7 +92,8 @@ PARAMETER_RULE_TEMPLATE: dict[DefaultParameterName, dict] = {
|
|||||||
},
|
},
|
||||||
"type": "int",
|
"type": "int",
|
||||||
"help": {
|
"help": {
|
||||||
"en_US": "Specifies the upper limit on the length of generated results. If the generated results are truncated, you can increase this parameter.",
|
"en_US": "Specifies the upper limit on the length of generated results."
|
||||||
|
" If the generated results are truncated, you can increase this parameter.",
|
||||||
"zh_Hans": "指定生成结果长度的上限。如果生成结果截断,可以调大该参数。",
|
"zh_Hans": "指定生成结果长度的上限。如果生成结果截断,可以调大该参数。",
|
||||||
},
|
},
|
||||||
"required": False,
|
"required": False,
|
||||||
@ -104,7 +109,8 @@ PARAMETER_RULE_TEMPLATE: dict[DefaultParameterName, dict] = {
|
|||||||
},
|
},
|
||||||
"type": "string",
|
"type": "string",
|
||||||
"help": {
|
"help": {
|
||||||
"en_US": "Set a response format, ensure the output from llm is a valid code block as possible, such as JSON, XML, etc.",
|
"en_US": "Set a response format, ensure the output from llm is a valid code block as possible,"
|
||||||
|
" such as JSON, XML, etc.",
|
||||||
"zh_Hans": "设置一个返回格式,确保llm的输出尽可能是有效的代码块,如JSON、XML等",
|
"zh_Hans": "设置一个返回格式,确保llm的输出尽可能是有效的代码块,如JSON、XML等",
|
||||||
},
|
},
|
||||||
"required": False,
|
"required": False,
|
||||||
|
|||||||
@ -72,7 +72,9 @@ class AIModel(ABC):
|
|||||||
if isinstance(error, tuple(model_errors)):
|
if isinstance(error, tuple(model_errors)):
|
||||||
if invoke_error == InvokeAuthorizationError:
|
if invoke_error == InvokeAuthorizationError:
|
||||||
return invoke_error(
|
return invoke_error(
|
||||||
description=f"[{provider_name}] Incorrect model credentials provided, please check and try again. "
|
description=(
|
||||||
|
f"[{provider_name}] Incorrect model credentials provided, please check and try again."
|
||||||
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
return invoke_error(description=f"[{provider_name}] {invoke_error.description}, {str(error)}")
|
return invoke_error(description=f"[{provider_name}] {invoke_error.description}, {str(error)}")
|
||||||
|
|||||||
@ -187,7 +187,7 @@ if you are not sure about the structure.
|
|||||||
<instructions>
|
<instructions>
|
||||||
{{instructions}}
|
{{instructions}}
|
||||||
</instructions>
|
</instructions>
|
||||||
"""
|
""" # noqa: E501
|
||||||
|
|
||||||
code_block = model_parameters.get("response_format", "")
|
code_block = model_parameters.get("response_format", "")
|
||||||
if not code_block:
|
if not code_block:
|
||||||
@ -449,7 +449,7 @@ if you are not sure about the structure.
|
|||||||
model=real_model,
|
model=real_model,
|
||||||
prompt_messages=prompt_messages,
|
prompt_messages=prompt_messages,
|
||||||
message=prompt_message,
|
message=prompt_message,
|
||||||
usage=usage if usage else LLMUsage.empty_usage(),
|
usage=usage or LLMUsage.empty_usage(),
|
||||||
system_fingerprint=system_fingerprint,
|
system_fingerprint=system_fingerprint,
|
||||||
),
|
),
|
||||||
credentials=credentials,
|
credentials=credentials,
|
||||||
@ -830,7 +830,8 @@ if you are not sure about the structure.
|
|||||||
else:
|
else:
|
||||||
if parameter_value != round(parameter_value, parameter_rule.precision):
|
if parameter_value != round(parameter_value, parameter_rule.precision):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Model Parameter {parameter_name} should be round to {parameter_rule.precision} decimal places."
|
f"Model Parameter {parameter_name} should be round to {parameter_rule.precision}"
|
||||||
|
f" decimal places."
|
||||||
)
|
)
|
||||||
|
|
||||||
# validate parameter value range
|
# validate parameter value range
|
||||||
|
|||||||
@ -51,7 +51,7 @@ if you are not sure about the structure.
|
|||||||
<instructions>
|
<instructions>
|
||||||
{{instructions}}
|
{{instructions}}
|
||||||
</instructions>
|
</instructions>
|
||||||
"""
|
""" # noqa: E501
|
||||||
|
|
||||||
|
|
||||||
class AnthropicLargeLanguageModel(LargeLanguageModel):
|
class AnthropicLargeLanguageModel(LargeLanguageModel):
|
||||||
@ -409,7 +409,7 @@ class AnthropicLargeLanguageModel(LargeLanguageModel):
|
|||||||
),
|
),
|
||||||
)
|
)
|
||||||
elif isinstance(chunk, ContentBlockDeltaEvent):
|
elif isinstance(chunk, ContentBlockDeltaEvent):
|
||||||
chunk_text = chunk.delta.text if chunk.delta.text else ""
|
chunk_text = chunk.delta.text or ""
|
||||||
full_assistant_content += chunk_text
|
full_assistant_content += chunk_text
|
||||||
|
|
||||||
# transform assistant message to prompt message
|
# transform assistant message to prompt message
|
||||||
|
|||||||
@ -213,7 +213,7 @@ class AzureAIStudioLargeLanguageModel(LargeLanguageModel):
|
|||||||
model=real_model,
|
model=real_model,
|
||||||
prompt_messages=prompt_messages,
|
prompt_messages=prompt_messages,
|
||||||
message=prompt_message,
|
message=prompt_message,
|
||||||
usage=usage if usage else LLMUsage.empty_usage(),
|
usage=usage or LLMUsage.empty_usage(),
|
||||||
system_fingerprint=system_fingerprint,
|
system_fingerprint=system_fingerprint,
|
||||||
),
|
),
|
||||||
credentials=credentials,
|
credentials=credentials,
|
||||||
|
|||||||
@ -16,6 +16,15 @@ from core.model_runtime.entities.model_entities import (
|
|||||||
|
|
||||||
AZURE_OPENAI_API_VERSION = "2024-02-15-preview"
|
AZURE_OPENAI_API_VERSION = "2024-02-15-preview"
|
||||||
|
|
||||||
|
AZURE_DEFAULT_PARAM_SEED_HELP = I18nObject(
|
||||||
|
zh_Hans="如果指定,模型将尽最大努力进行确定性采样,使得重复的具有相同种子和参数的请求应该返回相同的结果。不能保证确定性,"
|
||||||
|
"您应该参考 system_fingerprint 响应参数来监视变化。",
|
||||||
|
en_US="If specified, model will make a best effort to sample deterministically,"
|
||||||
|
" such that repeated requests with the same seed and parameters should return the same result."
|
||||||
|
" Determinism is not guaranteed, and you should refer to the system_fingerprint response parameter"
|
||||||
|
" to monitor changes in the backend.",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def _get_max_tokens(default: int, min_val: int, max_val: int) -> ParameterRule:
|
def _get_max_tokens(default: int, min_val: int, max_val: int) -> ParameterRule:
|
||||||
rule = ParameterRule(
|
rule = ParameterRule(
|
||||||
@ -229,10 +238,7 @@ LLM_BASE_MODELS = [
|
|||||||
name="seed",
|
name="seed",
|
||||||
label=I18nObject(zh_Hans="种子", en_US="Seed"),
|
label=I18nObject(zh_Hans="种子", en_US="Seed"),
|
||||||
type="int",
|
type="int",
|
||||||
help=I18nObject(
|
help=AZURE_DEFAULT_PARAM_SEED_HELP,
|
||||||
zh_Hans="如果指定,模型将尽最大努力进行确定性采样,使得重复的具有相同种子和参数的请求应该返回相同的结果。不能保证确定性,您应该参考 system_fingerprint 响应参数来监视变化。",
|
|
||||||
en_US="If specified, model will make a best effort to sample deterministically, such that repeated requests with the same seed and parameters should return the same result. Determinism is not guaranteed, and you should refer to the system_fingerprint response parameter to monitor changes in the backend.",
|
|
||||||
),
|
|
||||||
required=False,
|
required=False,
|
||||||
precision=2,
|
precision=2,
|
||||||
min=0,
|
min=0,
|
||||||
@ -297,10 +303,7 @@ LLM_BASE_MODELS = [
|
|||||||
name="seed",
|
name="seed",
|
||||||
label=I18nObject(zh_Hans="种子", en_US="Seed"),
|
label=I18nObject(zh_Hans="种子", en_US="Seed"),
|
||||||
type="int",
|
type="int",
|
||||||
help=I18nObject(
|
help=AZURE_DEFAULT_PARAM_SEED_HELP,
|
||||||
zh_Hans="如果指定,模型将尽最大努力进行确定性采样,使得重复的具有相同种子和参数的请求应该返回相同的结果。不能保证确定性,您应该参考 system_fingerprint 响应参数来监视变化。",
|
|
||||||
en_US="If specified, model will make a best effort to sample deterministically, such that repeated requests with the same seed and parameters should return the same result. Determinism is not guaranteed, and you should refer to the system_fingerprint response parameter to monitor changes in the backend.",
|
|
||||||
),
|
|
||||||
required=False,
|
required=False,
|
||||||
precision=2,
|
precision=2,
|
||||||
min=0,
|
min=0,
|
||||||
@ -365,10 +368,7 @@ LLM_BASE_MODELS = [
|
|||||||
name="seed",
|
name="seed",
|
||||||
label=I18nObject(zh_Hans="种子", en_US="Seed"),
|
label=I18nObject(zh_Hans="种子", en_US="Seed"),
|
||||||
type="int",
|
type="int",
|
||||||
help=I18nObject(
|
help=AZURE_DEFAULT_PARAM_SEED_HELP,
|
||||||
zh_Hans="如果指定,模型将尽最大努力进行确定性采样,使得重复的具有相同种子和参数的请求应该返回相同的结果。不能保证确定性,您应该参考 system_fingerprint 响应参数来监视变化。",
|
|
||||||
en_US="If specified, model will make a best effort to sample deterministically, such that repeated requests with the same seed and parameters should return the same result. Determinism is not guaranteed, and you should refer to the system_fingerprint response parameter to monitor changes in the backend.",
|
|
||||||
),
|
|
||||||
required=False,
|
required=False,
|
||||||
precision=2,
|
precision=2,
|
||||||
min=0,
|
min=0,
|
||||||
@ -433,10 +433,7 @@ LLM_BASE_MODELS = [
|
|||||||
name="seed",
|
name="seed",
|
||||||
label=I18nObject(zh_Hans="种子", en_US="Seed"),
|
label=I18nObject(zh_Hans="种子", en_US="Seed"),
|
||||||
type="int",
|
type="int",
|
||||||
help=I18nObject(
|
help=AZURE_DEFAULT_PARAM_SEED_HELP,
|
||||||
zh_Hans="如果指定,模型将尽最大努力进行确定性采样,使得重复的具有相同种子和参数的请求应该返回相同的结果。不能保证确定性,您应该参考 system_fingerprint 响应参数来监视变化。",
|
|
||||||
en_US="If specified, model will make a best effort to sample deterministically, such that repeated requests with the same seed and parameters should return the same result. Determinism is not guaranteed, and you should refer to the system_fingerprint response parameter to monitor changes in the backend.",
|
|
||||||
),
|
|
||||||
required=False,
|
required=False,
|
||||||
precision=2,
|
precision=2,
|
||||||
min=0,
|
min=0,
|
||||||
@ -502,10 +499,7 @@ LLM_BASE_MODELS = [
|
|||||||
name="seed",
|
name="seed",
|
||||||
label=I18nObject(zh_Hans="种子", en_US="Seed"),
|
label=I18nObject(zh_Hans="种子", en_US="Seed"),
|
||||||
type="int",
|
type="int",
|
||||||
help=I18nObject(
|
help=AZURE_DEFAULT_PARAM_SEED_HELP,
|
||||||
zh_Hans="如果指定,模型将尽最大努力进行确定性采样,使得重复的具有相同种子和参数的请求应该返回相同的结果。不能保证确定性,您应该参考 system_fingerprint 响应参数来监视变化。",
|
|
||||||
en_US="If specified, model will make a best effort to sample deterministically, such that repeated requests with the same seed and parameters should return the same result. Determinism is not guaranteed, and you should refer to the system_fingerprint response parameter to monitor changes in the backend.",
|
|
||||||
),
|
|
||||||
required=False,
|
required=False,
|
||||||
precision=2,
|
precision=2,
|
||||||
min=0,
|
min=0,
|
||||||
@ -571,10 +565,7 @@ LLM_BASE_MODELS = [
|
|||||||
name="seed",
|
name="seed",
|
||||||
label=I18nObject(zh_Hans="种子", en_US="Seed"),
|
label=I18nObject(zh_Hans="种子", en_US="Seed"),
|
||||||
type="int",
|
type="int",
|
||||||
help=I18nObject(
|
help=AZURE_DEFAULT_PARAM_SEED_HELP,
|
||||||
zh_Hans="如果指定,模型将尽最大努力进行确定性采样,使得重复的具有相同种子和参数的请求应该返回相同的结果。不能保证确定性,您应该参考 system_fingerprint 响应参数来监视变化。",
|
|
||||||
en_US="If specified, model will make a best effort to sample deterministically, such that repeated requests with the same seed and parameters should return the same result. Determinism is not guaranteed, and you should refer to the system_fingerprint response parameter to monitor changes in the backend.",
|
|
||||||
),
|
|
||||||
required=False,
|
required=False,
|
||||||
precision=2,
|
precision=2,
|
||||||
min=0,
|
min=0,
|
||||||
@ -650,10 +641,7 @@ LLM_BASE_MODELS = [
|
|||||||
name="seed",
|
name="seed",
|
||||||
label=I18nObject(zh_Hans="种子", en_US="Seed"),
|
label=I18nObject(zh_Hans="种子", en_US="Seed"),
|
||||||
type="int",
|
type="int",
|
||||||
help=I18nObject(
|
help=AZURE_DEFAULT_PARAM_SEED_HELP,
|
||||||
zh_Hans="如果指定,模型将尽最大努力进行确定性采样,使得重复的具有相同种子和参数的请求应该返回相同的结果。不能保证确定性,您应该参考 system_fingerprint 响应参数来监视变化。",
|
|
||||||
en_US="If specified, model will make a best effort to sample deterministically, such that repeated requests with the same seed and parameters should return the same result. Determinism is not guaranteed, and you should refer to the system_fingerprint response parameter to monitor changes in the backend.",
|
|
||||||
),
|
|
||||||
required=False,
|
required=False,
|
||||||
precision=2,
|
precision=2,
|
||||||
min=0,
|
min=0,
|
||||||
@ -719,10 +707,7 @@ LLM_BASE_MODELS = [
|
|||||||
name="seed",
|
name="seed",
|
||||||
label=I18nObject(zh_Hans="种子", en_US="Seed"),
|
label=I18nObject(zh_Hans="种子", en_US="Seed"),
|
||||||
type="int",
|
type="int",
|
||||||
help=I18nObject(
|
help=AZURE_DEFAULT_PARAM_SEED_HELP,
|
||||||
zh_Hans="如果指定,模型将尽最大努力进行确定性采样,使得重复的具有相同种子和参数的请求应该返回相同的结果。不能保证确定性,您应该参考 system_fingerprint 响应参数来监视变化。",
|
|
||||||
en_US="If specified, model will make a best effort to sample deterministically, such that repeated requests with the same seed and parameters should return the same result. Determinism is not guaranteed, and you should refer to the system_fingerprint response parameter to monitor changes in the backend.",
|
|
||||||
),
|
|
||||||
required=False,
|
required=False,
|
||||||
precision=2,
|
precision=2,
|
||||||
min=0,
|
min=0,
|
||||||
@ -788,10 +773,7 @@ LLM_BASE_MODELS = [
|
|||||||
name="seed",
|
name="seed",
|
||||||
label=I18nObject(zh_Hans="种子", en_US="Seed"),
|
label=I18nObject(zh_Hans="种子", en_US="Seed"),
|
||||||
type="int",
|
type="int",
|
||||||
help=I18nObject(
|
help=AZURE_DEFAULT_PARAM_SEED_HELP,
|
||||||
zh_Hans="如果指定,模型将尽最大努力进行确定性采样,使得重复的具有相同种子和参数的请求应该返回相同的结果。不能保证确定性,您应该参考 system_fingerprint 响应参数来监视变化。",
|
|
||||||
en_US="If specified, model will make a best effort to sample deterministically, such that repeated requests with the same seed and parameters should return the same result. Determinism is not guaranteed, and you should refer to the system_fingerprint response parameter to monitor changes in the backend.",
|
|
||||||
),
|
|
||||||
required=False,
|
required=False,
|
||||||
precision=2,
|
precision=2,
|
||||||
min=0,
|
min=0,
|
||||||
@ -867,10 +849,7 @@ LLM_BASE_MODELS = [
|
|||||||
name="seed",
|
name="seed",
|
||||||
label=I18nObject(zh_Hans="种子", en_US="Seed"),
|
label=I18nObject(zh_Hans="种子", en_US="Seed"),
|
||||||
type="int",
|
type="int",
|
||||||
help=I18nObject(
|
help=AZURE_DEFAULT_PARAM_SEED_HELP,
|
||||||
zh_Hans="如果指定,模型将尽最大努力进行确定性采样,使得重复的具有相同种子和参数的请求应该返回相同的结果。不能保证确定性,您应该参考 system_fingerprint 响应参数来监视变化。",
|
|
||||||
en_US="If specified, model will make a best effort to sample deterministically, such that repeated requests with the same seed and parameters should return the same result. Determinism is not guaranteed, and you should refer to the system_fingerprint response parameter to monitor changes in the backend.",
|
|
||||||
),
|
|
||||||
required=False,
|
required=False,
|
||||||
precision=2,
|
precision=2,
|
||||||
min=0,
|
min=0,
|
||||||
@ -936,10 +915,7 @@ LLM_BASE_MODELS = [
|
|||||||
name="seed",
|
name="seed",
|
||||||
label=I18nObject(zh_Hans="种子", en_US="Seed"),
|
label=I18nObject(zh_Hans="种子", en_US="Seed"),
|
||||||
type="int",
|
type="int",
|
||||||
help=I18nObject(
|
help=AZURE_DEFAULT_PARAM_SEED_HELP,
|
||||||
zh_Hans="如果指定,模型将尽最大努力进行确定性采样,使得重复的具有相同种子和参数的请求应该返回相同的结果。不能保证确定性,您应该参考 system_fingerprint 响应参数来监视变化。",
|
|
||||||
en_US="If specified, model will make a best effort to sample deterministically, such that repeated requests with the same seed and parameters should return the same result. Determinism is not guaranteed, and you should refer to the system_fingerprint response parameter to monitor changes in the backend.",
|
|
||||||
),
|
|
||||||
required=False,
|
required=False,
|
||||||
precision=2,
|
precision=2,
|
||||||
min=0,
|
min=0,
|
||||||
@ -1000,10 +976,7 @@ LLM_BASE_MODELS = [
|
|||||||
name="seed",
|
name="seed",
|
||||||
label=I18nObject(zh_Hans="种子", en_US="Seed"),
|
label=I18nObject(zh_Hans="种子", en_US="Seed"),
|
||||||
type="int",
|
type="int",
|
||||||
help=I18nObject(
|
help=AZURE_DEFAULT_PARAM_SEED_HELP,
|
||||||
zh_Hans="如果指定,模型将尽最大努力进行确定性采样,使得重复的具有相同种子和参数的请求应该返回相同的结果。不能保证确定性,您应该参考 system_fingerprint 响应参数来监视变化。",
|
|
||||||
en_US="If specified, model will make a best effort to sample deterministically, such that repeated requests with the same seed and parameters should return the same result. Determinism is not guaranteed, and you should refer to the system_fingerprint response parameter to monitor changes in the backend.",
|
|
||||||
),
|
|
||||||
required=False,
|
required=False,
|
||||||
precision=2,
|
precision=2,
|
||||||
min=0,
|
min=0,
|
||||||
|
|||||||
@ -225,7 +225,7 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
|
|||||||
continue
|
continue
|
||||||
|
|
||||||
# transform assistant message to prompt message
|
# transform assistant message to prompt message
|
||||||
text = delta.text if delta.text else ""
|
text = delta.text or ""
|
||||||
assistant_prompt_message = AssistantPromptMessage(content=text)
|
assistant_prompt_message = AssistantPromptMessage(content=text)
|
||||||
|
|
||||||
full_text += text
|
full_text += text
|
||||||
@ -400,15 +400,13 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
|
|||||||
continue
|
continue
|
||||||
|
|
||||||
# transform assistant message to prompt message
|
# transform assistant message to prompt message
|
||||||
assistant_prompt_message = AssistantPromptMessage(
|
assistant_prompt_message = AssistantPromptMessage(content=delta.delta.content or "", tool_calls=tool_calls)
|
||||||
content=delta.delta.content if delta.delta.content else "", tool_calls=tool_calls
|
|
||||||
)
|
|
||||||
|
|
||||||
full_assistant_content += delta.delta.content if delta.delta.content else ""
|
full_assistant_content += delta.delta.content or ""
|
||||||
|
|
||||||
real_model = chunk.model
|
real_model = chunk.model
|
||||||
system_fingerprint = chunk.system_fingerprint
|
system_fingerprint = chunk.system_fingerprint
|
||||||
completion += delta.delta.content if delta.delta.content else ""
|
completion += delta.delta.content or ""
|
||||||
|
|
||||||
yield LLMResultChunk(
|
yield LLMResultChunk(
|
||||||
model=real_model,
|
model=real_model,
|
||||||
|
|||||||
@ -84,7 +84,7 @@ class AzureOpenAIText2SpeechModel(_CommonAzureOpenAI, TTSModel):
|
|||||||
)
|
)
|
||||||
for i in range(len(sentences))
|
for i in range(len(sentences))
|
||||||
]
|
]
|
||||||
for index, future in enumerate(futures):
|
for future in futures:
|
||||||
yield from future.result().__enter__().iter_bytes(1024)
|
yield from future.result().__enter__().iter_bytes(1024)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
|
|||||||
@ -15,6 +15,7 @@ class BaichuanTokenizer:
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _get_num_tokens(cls, text: str) -> int:
|
def _get_num_tokens(cls, text: str) -> int:
|
||||||
# tokens = number of Chinese characters + number of English words * 1.3 (for estimation only, subject to actual return)
|
# tokens = number of Chinese characters + number of English words * 1.3
|
||||||
|
# (for estimation only, subject to actual return)
|
||||||
# https://platform.baichuan-ai.com/docs/text-Embedding
|
# https://platform.baichuan-ai.com/docs/text-Embedding
|
||||||
return int(cls.count_chinese_characters(text) + cls.count_english_vocabularies(text) * 1.3)
|
return int(cls.count_chinese_characters(text) + cls.count_english_vocabularies(text) * 1.3)
|
||||||
|
|||||||
@ -45,7 +45,7 @@ class BaichuanModel:
|
|||||||
parameters: dict[str, Any],
|
parameters: dict[str, Any],
|
||||||
tools: Optional[list[PromptMessageTool]] = None,
|
tools: Optional[list[PromptMessageTool]] = None,
|
||||||
) -> dict[str, Any]:
|
) -> dict[str, Any]:
|
||||||
if model in self._model_mapping.keys():
|
if model in self._model_mapping:
|
||||||
# the LargeLanguageModel._code_block_mode_wrapper() method will remove the response_format of parameters.
|
# the LargeLanguageModel._code_block_mode_wrapper() method will remove the response_format of parameters.
|
||||||
# we need to rename it to res_format to get its value
|
# we need to rename it to res_format to get its value
|
||||||
if parameters.get("res_format") == "json_object":
|
if parameters.get("res_format") == "json_object":
|
||||||
@ -94,7 +94,7 @@ class BaichuanModel:
|
|||||||
timeout: int,
|
timeout: int,
|
||||||
tools: Optional[list[PromptMessageTool]] = None,
|
tools: Optional[list[PromptMessageTool]] = None,
|
||||||
) -> Union[Iterator, dict]:
|
) -> Union[Iterator, dict]:
|
||||||
if model in self._model_mapping.keys():
|
if model in self._model_mapping:
|
||||||
api_base = "https://api.baichuan-ai.com/v1/chat/completions"
|
api_base = "https://api.baichuan-ai.com/v1/chat/completions"
|
||||||
else:
|
else:
|
||||||
raise BadRequestError(f"Unknown model: {model}")
|
raise BadRequestError(f"Unknown model: {model}")
|
||||||
|
|||||||
@ -52,7 +52,7 @@ if you are not sure about the structure.
|
|||||||
<instructions>
|
<instructions>
|
||||||
{{instructions}}
|
{{instructions}}
|
||||||
</instructions>
|
</instructions>
|
||||||
"""
|
""" # noqa: E501
|
||||||
|
|
||||||
|
|
||||||
class BedrockLargeLanguageModel(LargeLanguageModel):
|
class BedrockLargeLanguageModel(LargeLanguageModel):
|
||||||
@ -331,10 +331,10 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
|
|||||||
elif "contentBlockDelta" in chunk:
|
elif "contentBlockDelta" in chunk:
|
||||||
delta = chunk["contentBlockDelta"]["delta"]
|
delta = chunk["contentBlockDelta"]["delta"]
|
||||||
if "text" in delta:
|
if "text" in delta:
|
||||||
chunk_text = delta["text"] if delta["text"] else ""
|
chunk_text = delta["text"] or ""
|
||||||
full_assistant_content += chunk_text
|
full_assistant_content += chunk_text
|
||||||
assistant_prompt_message = AssistantPromptMessage(
|
assistant_prompt_message = AssistantPromptMessage(
|
||||||
content=chunk_text if chunk_text else "",
|
content=chunk_text or "",
|
||||||
)
|
)
|
||||||
index = chunk["contentBlockDelta"]["contentBlockIndex"]
|
index = chunk["contentBlockDelta"]["contentBlockIndex"]
|
||||||
yield LLMResultChunk(
|
yield LLMResultChunk(
|
||||||
@ -541,7 +541,9 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
|
|||||||
"max_tokens": 32,
|
"max_tokens": 32,
|
||||||
}
|
}
|
||||||
elif "ai21" in model:
|
elif "ai21" in model:
|
||||||
# ValidationException: Malformed input request: #/temperature: expected type: Number, found: Null#/maxTokens: expected type: Integer, found: Null#/topP: expected type: Number, found: Null, please reformat your input and try again.
|
# ValidationException: Malformed input request: #/temperature: expected type: Number,
|
||||||
|
# found: Null#/maxTokens: expected type: Integer, found: Null#/topP: expected type: Number, found: Null,
|
||||||
|
# please reformat your input and try again.
|
||||||
required_params = {
|
required_params = {
|
||||||
"temperature": 0.7,
|
"temperature": 0.7,
|
||||||
"topP": 0.9,
|
"topP": 0.9,
|
||||||
@ -749,7 +751,7 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
|
|||||||
elif model_prefix == "cohere":
|
elif model_prefix == "cohere":
|
||||||
output = response_body.get("generations")[0].get("text")
|
output = response_body.get("generations")[0].get("text")
|
||||||
prompt_tokens = self.get_num_tokens(model, credentials, prompt_messages)
|
prompt_tokens = self.get_num_tokens(model, credentials, prompt_messages)
|
||||||
completion_tokens = self.get_num_tokens(model, credentials, output if output else "")
|
completion_tokens = self.get_num_tokens(model, credentials, output or "")
|
||||||
|
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Got unknown model prefix {model_prefix} when handling block response")
|
raise ValueError(f"Got unknown model prefix {model_prefix} when handling block response")
|
||||||
@ -826,7 +828,7 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
|
|||||||
|
|
||||||
# transform assistant message to prompt message
|
# transform assistant message to prompt message
|
||||||
assistant_prompt_message = AssistantPromptMessage(
|
assistant_prompt_message = AssistantPromptMessage(
|
||||||
content=content_delta if content_delta else "",
|
content=content_delta or "",
|
||||||
)
|
)
|
||||||
index += 1
|
index += 1
|
||||||
|
|
||||||
|
|||||||
@ -302,11 +302,11 @@ class ChatGLMLargeLanguageModel(LargeLanguageModel):
|
|||||||
if delta.delta.function_call:
|
if delta.delta.function_call:
|
||||||
function_calls = [delta.delta.function_call]
|
function_calls = [delta.delta.function_call]
|
||||||
|
|
||||||
assistant_message_tool_calls = self._extract_response_tool_calls(function_calls if function_calls else [])
|
assistant_message_tool_calls = self._extract_response_tool_calls(function_calls or [])
|
||||||
|
|
||||||
# transform assistant message to prompt message
|
# transform assistant message to prompt message
|
||||||
assistant_prompt_message = AssistantPromptMessage(
|
assistant_prompt_message = AssistantPromptMessage(
|
||||||
content=delta.delta.content if delta.delta.content else "", tool_calls=assistant_message_tool_calls
|
content=delta.delta.content or "", tool_calls=assistant_message_tool_calls
|
||||||
)
|
)
|
||||||
|
|
||||||
if delta.finish_reason is not None:
|
if delta.finish_reason is not None:
|
||||||
|
|||||||
@ -45,7 +45,7 @@ if you are not sure about the structure.
|
|||||||
<instructions>
|
<instructions>
|
||||||
{{instructions}}
|
{{instructions}}
|
||||||
</instructions>
|
</instructions>
|
||||||
"""
|
""" # noqa: E501
|
||||||
|
|
||||||
|
|
||||||
class GoogleLargeLanguageModel(LargeLanguageModel):
|
class GoogleLargeLanguageModel(LargeLanguageModel):
|
||||||
@ -337,9 +337,7 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
|
|||||||
message_text = f"{human_prompt} {content}"
|
message_text = f"{human_prompt} {content}"
|
||||||
elif isinstance(message, AssistantPromptMessage):
|
elif isinstance(message, AssistantPromptMessage):
|
||||||
message_text = f"{ai_prompt} {content}"
|
message_text = f"{ai_prompt} {content}"
|
||||||
elif isinstance(message, SystemPromptMessage):
|
elif isinstance(message, SystemPromptMessage | ToolPromptMessage):
|
||||||
message_text = f"{human_prompt} {content}"
|
|
||||||
elif isinstance(message, ToolPromptMessage):
|
|
||||||
message_text = f"{human_prompt} {content}"
|
message_text = f"{human_prompt} {content}"
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Got unknown type {message}")
|
raise ValueError(f"Got unknown type {message}")
|
||||||
|
|||||||
@ -54,7 +54,8 @@ class TeiHelper:
|
|||||||
|
|
||||||
url = str(URL(server_url) / "info")
|
url = str(URL(server_url) / "info")
|
||||||
|
|
||||||
# this method is surrounded by a lock, and default requests may hang forever, so we just set a Adapter with max_retries=3
|
# this method is surrounded by a lock, and default requests may hang forever,
|
||||||
|
# so we just set a Adapter with max_retries=3
|
||||||
session = Session()
|
session = Session()
|
||||||
session.mount("http://", HTTPAdapter(max_retries=3))
|
session.mount("http://", HTTPAdapter(max_retries=3))
|
||||||
session.mount("https://", HTTPAdapter(max_retries=3))
|
session.mount("https://", HTTPAdapter(max_retries=3))
|
||||||
|
|||||||
@ -131,7 +131,8 @@ class HunyuanLargeLanguageModel(LargeLanguageModel):
|
|||||||
{
|
{
|
||||||
"Role": message.role.value,
|
"Role": message.role.value,
|
||||||
# fix set content = "" while tool_call request
|
# fix set content = "" while tool_call request
|
||||||
# fix [hunyuan] None, [TencentCloudSDKException] code:InvalidParameter message:Messages Content and Contents not allowed empty at the same time.
|
# fix [hunyuan] None, [TencentCloudSDKException] code:InvalidParameter
|
||||||
|
# message:Messages Content and Contents not allowed empty at the same time.
|
||||||
"Content": " ", # message.content if (message.content is not None) else "",
|
"Content": " ", # message.content if (message.content is not None) else "",
|
||||||
"ToolCalls": dict_tool_calls,
|
"ToolCalls": dict_tool_calls,
|
||||||
}
|
}
|
||||||
|
|||||||
@ -511,7 +511,7 @@ class LocalAILanguageModel(LargeLanguageModel):
|
|||||||
delta = chunk.choices[0]
|
delta = chunk.choices[0]
|
||||||
|
|
||||||
# transform assistant message to prompt message
|
# transform assistant message to prompt message
|
||||||
assistant_prompt_message = AssistantPromptMessage(content=delta.text if delta.text else "", tool_calls=[])
|
assistant_prompt_message = AssistantPromptMessage(content=delta.text or "", tool_calls=[])
|
||||||
|
|
||||||
if delta.finish_reason is not None:
|
if delta.finish_reason is not None:
|
||||||
# temp_assistant_prompt_message is used to calculate usage
|
# temp_assistant_prompt_message is used to calculate usage
|
||||||
@ -578,11 +578,11 @@ class LocalAILanguageModel(LargeLanguageModel):
|
|||||||
if delta.delta.function_call:
|
if delta.delta.function_call:
|
||||||
function_calls = [delta.delta.function_call]
|
function_calls = [delta.delta.function_call]
|
||||||
|
|
||||||
assistant_message_tool_calls = self._extract_response_tool_calls(function_calls if function_calls else [])
|
assistant_message_tool_calls = self._extract_response_tool_calls(function_calls or [])
|
||||||
|
|
||||||
# transform assistant message to prompt message
|
# transform assistant message to prompt message
|
||||||
assistant_prompt_message = AssistantPromptMessage(
|
assistant_prompt_message = AssistantPromptMessage(
|
||||||
content=delta.delta.content if delta.delta.content else "", tool_calls=assistant_message_tool_calls
|
content=delta.delta.content or "", tool_calls=assistant_message_tool_calls
|
||||||
)
|
)
|
||||||
|
|
||||||
if delta.finish_reason is not None:
|
if delta.finish_reason is not None:
|
||||||
|
|||||||
@ -211,7 +211,7 @@ class MinimaxLargeLanguageModel(LargeLanguageModel):
|
|||||||
index=0,
|
index=0,
|
||||||
message=AssistantPromptMessage(content=message.content, tool_calls=[]),
|
message=AssistantPromptMessage(content=message.content, tool_calls=[]),
|
||||||
usage=usage,
|
usage=usage,
|
||||||
finish_reason=message.stop_reason if message.stop_reason else None,
|
finish_reason=message.stop_reason or None,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
elif message.function_call:
|
elif message.function_call:
|
||||||
@ -244,7 +244,7 @@ class MinimaxLargeLanguageModel(LargeLanguageModel):
|
|||||||
delta=LLMResultChunkDelta(
|
delta=LLMResultChunkDelta(
|
||||||
index=0,
|
index=0,
|
||||||
message=AssistantPromptMessage(content=message.content, tool_calls=[]),
|
message=AssistantPromptMessage(content=message.content, tool_calls=[]),
|
||||||
finish_reason=message.stop_reason if message.stop_reason else None,
|
finish_reason=message.stop_reason or None,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@ -93,7 +93,8 @@ class NVIDIALargeLanguageModel(OAIAPICompatLargeLanguageModel):
|
|||||||
|
|
||||||
def _validate_credentials(self, model: str, credentials: dict) -> None:
|
def _validate_credentials(self, model: str, credentials: dict) -> None:
|
||||||
"""
|
"""
|
||||||
Validate model credentials using requests to ensure compatibility with all providers following OpenAI's API standard.
|
Validate model credentials using requests to ensure compatibility with all providers following
|
||||||
|
OpenAI's API standard.
|
||||||
|
|
||||||
:param model: model name
|
:param model: model name
|
||||||
:param credentials: model credentials
|
:param credentials: model credentials
|
||||||
|
|||||||
@ -239,7 +239,8 @@ class OCILargeLanguageModel(LargeLanguageModel):
|
|||||||
config_items = oci_config_content.split("/")
|
config_items = oci_config_content.split("/")
|
||||||
if len(config_items) != 5:
|
if len(config_items) != 5:
|
||||||
raise CredentialsValidateFailedError(
|
raise CredentialsValidateFailedError(
|
||||||
"oci_config_content should be base64.b64encode('user_ocid/fingerprint/tenancy_ocid/region/compartment_ocid'.encode('utf-8'))"
|
"oci_config_content should be base64.b64encode("
|
||||||
|
"'user_ocid/fingerprint/tenancy_ocid/region/compartment_ocid'.encode('utf-8'))"
|
||||||
)
|
)
|
||||||
oci_config["user"] = config_items[0]
|
oci_config["user"] = config_items[0]
|
||||||
oci_config["fingerprint"] = config_items[1]
|
oci_config["fingerprint"] = config_items[1]
|
||||||
@ -442,9 +443,7 @@ class OCILargeLanguageModel(LargeLanguageModel):
|
|||||||
message_text = f"{human_prompt} {content}"
|
message_text = f"{human_prompt} {content}"
|
||||||
elif isinstance(message, AssistantPromptMessage):
|
elif isinstance(message, AssistantPromptMessage):
|
||||||
message_text = f"{ai_prompt} {content}"
|
message_text = f"{ai_prompt} {content}"
|
||||||
elif isinstance(message, SystemPromptMessage):
|
elif isinstance(message, SystemPromptMessage | ToolPromptMessage):
|
||||||
message_text = f"{human_prompt} {content}"
|
|
||||||
elif isinstance(message, ToolPromptMessage):
|
|
||||||
message_text = f"{human_prompt} {content}"
|
message_text = f"{human_prompt} {content}"
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Got unknown type {message}")
|
raise ValueError(f"Got unknown type {message}")
|
||||||
|
|||||||
@ -146,7 +146,8 @@ class OCITextEmbeddingModel(TextEmbeddingModel):
|
|||||||
config_items = oci_config_content.split("/")
|
config_items = oci_config_content.split("/")
|
||||||
if len(config_items) != 5:
|
if len(config_items) != 5:
|
||||||
raise CredentialsValidateFailedError(
|
raise CredentialsValidateFailedError(
|
||||||
"oci_config_content should be base64.b64encode('user_ocid/fingerprint/tenancy_ocid/region/compartment_ocid'.encode('utf-8'))"
|
"oci_config_content should be base64.b64encode("
|
||||||
|
"'user_ocid/fingerprint/tenancy_ocid/region/compartment_ocid'.encode('utf-8'))"
|
||||||
)
|
)
|
||||||
oci_config["user"] = config_items[0]
|
oci_config["user"] = config_items[0]
|
||||||
oci_config["fingerprint"] = config_items[1]
|
oci_config["fingerprint"] = config_items[1]
|
||||||
|
|||||||
@ -639,9 +639,10 @@ class OllamaLargeLanguageModel(LargeLanguageModel):
|
|||||||
type=ParameterType.STRING,
|
type=ParameterType.STRING,
|
||||||
help=I18nObject(
|
help=I18nObject(
|
||||||
en_US="Sets how long the model is kept in memory after generating a response. "
|
en_US="Sets how long the model is kept in memory after generating a response. "
|
||||||
"This must be a duration string with a unit (e.g., '10m' for 10 minutes or '24h' for 24 hours). "
|
"This must be a duration string with a unit (e.g., '10m' for 10 minutes or '24h' for 24 hours)."
|
||||||
"A negative number keeps the model loaded indefinitely, and '0' unloads the model immediately after generating a response. "
|
" A negative number keeps the model loaded indefinitely, and '0' unloads the model"
|
||||||
"Valid time units are 's','m','h'. (Default: 5m)"
|
" immediately after generating a response."
|
||||||
|
" Valid time units are 's','m','h'. (Default: 5m)"
|
||||||
),
|
),
|
||||||
),
|
),
|
||||||
ParameterRule(
|
ParameterRule(
|
||||||
|
|||||||
@ -65,7 +65,7 @@ class OllamaEmbeddingModel(TextEmbeddingModel):
|
|||||||
inputs = []
|
inputs = []
|
||||||
used_tokens = 0
|
used_tokens = 0
|
||||||
|
|
||||||
for i, text in enumerate(texts):
|
for text in texts:
|
||||||
# Here token count is only an approximation based on the GPT2 tokenizer
|
# Here token count is only an approximation based on the GPT2 tokenizer
|
||||||
num_tokens = self._get_num_tokens_by_gpt2(text)
|
num_tokens = self._get_num_tokens_by_gpt2(text)
|
||||||
|
|
||||||
|
|||||||
@ -37,7 +37,7 @@ if you are not sure about the structure.
|
|||||||
<instructions>
|
<instructions>
|
||||||
{{instructions}}
|
{{instructions}}
|
||||||
</instructions>
|
</instructions>
|
||||||
"""
|
""" # noqa: E501
|
||||||
|
|
||||||
|
|
||||||
class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel):
|
class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel):
|
||||||
@ -508,7 +508,7 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel):
|
|||||||
continue
|
continue
|
||||||
|
|
||||||
# transform assistant message to prompt message
|
# transform assistant message to prompt message
|
||||||
text = delta.text if delta.text else ""
|
text = delta.text or ""
|
||||||
assistant_prompt_message = AssistantPromptMessage(content=text)
|
assistant_prompt_message = AssistantPromptMessage(content=text)
|
||||||
|
|
||||||
full_text += text
|
full_text += text
|
||||||
@ -760,11 +760,9 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel):
|
|||||||
final_tool_calls.extend(tool_calls)
|
final_tool_calls.extend(tool_calls)
|
||||||
|
|
||||||
# transform assistant message to prompt message
|
# transform assistant message to prompt message
|
||||||
assistant_prompt_message = AssistantPromptMessage(
|
assistant_prompt_message = AssistantPromptMessage(content=delta.delta.content or "", tool_calls=tool_calls)
|
||||||
content=delta.delta.content if delta.delta.content else "", tool_calls=tool_calls
|
|
||||||
)
|
|
||||||
|
|
||||||
full_assistant_content += delta.delta.content if delta.delta.content else ""
|
full_assistant_content += delta.delta.content or ""
|
||||||
|
|
||||||
if has_finish_reason:
|
if has_finish_reason:
|
||||||
final_chunk = LLMResultChunk(
|
final_chunk = LLMResultChunk(
|
||||||
|
|||||||
@ -88,7 +88,7 @@ class OpenAIText2SpeechModel(_CommonOpenAI, TTSModel):
|
|||||||
)
|
)
|
||||||
for i in range(len(sentences))
|
for i in range(len(sentences))
|
||||||
]
|
]
|
||||||
for index, future in enumerate(futures):
|
for future in futures:
|
||||||
yield from future.result().__enter__().iter_bytes(1024)
|
yield from future.result().__enter__().iter_bytes(1024)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
|
|||||||
@ -103,7 +103,8 @@ class OAIAPICompatLargeLanguageModel(_CommonOaiApiCompat, LargeLanguageModel):
|
|||||||
|
|
||||||
def validate_credentials(self, model: str, credentials: dict) -> None:
|
def validate_credentials(self, model: str, credentials: dict) -> None:
|
||||||
"""
|
"""
|
||||||
Validate model credentials using requests to ensure compatibility with all providers following OpenAI's API standard.
|
Validate model credentials using requests to ensure compatibility with all providers following
|
||||||
|
OpenAI's API standard.
|
||||||
|
|
||||||
:param model: model name
|
:param model: model name
|
||||||
:param credentials: model credentials
|
:param credentials: model credentials
|
||||||
@ -178,9 +179,9 @@ class OAIAPICompatLargeLanguageModel(_CommonOaiApiCompat, LargeLanguageModel):
|
|||||||
features = []
|
features = []
|
||||||
|
|
||||||
function_calling_type = credentials.get("function_calling_type", "no_call")
|
function_calling_type = credentials.get("function_calling_type", "no_call")
|
||||||
if function_calling_type in ["function_call"]:
|
if function_calling_type == "function_call":
|
||||||
features.append(ModelFeature.TOOL_CALL)
|
features.append(ModelFeature.TOOL_CALL)
|
||||||
elif function_calling_type in ["tool_call"]:
|
elif function_calling_type == "tool_call":
|
||||||
features.append(ModelFeature.MULTI_TOOL_CALL)
|
features.append(ModelFeature.MULTI_TOOL_CALL)
|
||||||
|
|
||||||
stream_function_calling = credentials.get("stream_function_calling", "supported")
|
stream_function_calling = credentials.get("stream_function_calling", "supported")
|
||||||
@ -262,7 +263,8 @@ class OAIAPICompatLargeLanguageModel(_CommonOaiApiCompat, LargeLanguageModel):
|
|||||||
|
|
||||||
return entity
|
return entity
|
||||||
|
|
||||||
# validate_credentials method has been rewritten to use the requests library for compatibility with all providers following OpenAI's API standard.
|
# validate_credentials method has been rewritten to use the requests library for compatibility with all providers
|
||||||
|
# following OpenAI's API standard.
|
||||||
def _generate(
|
def _generate(
|
||||||
self,
|
self,
|
||||||
model: str,
|
model: str,
|
||||||
|
|||||||
@ -179,7 +179,7 @@ class OpenLLMLargeLanguageModel(LargeLanguageModel):
|
|||||||
index=0,
|
index=0,
|
||||||
message=AssistantPromptMessage(content=message.content, tool_calls=[]),
|
message=AssistantPromptMessage(content=message.content, tool_calls=[]),
|
||||||
usage=usage,
|
usage=usage,
|
||||||
finish_reason=message.stop_reason if message.stop_reason else None,
|
finish_reason=message.stop_reason or None,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
@ -189,7 +189,7 @@ class OpenLLMLargeLanguageModel(LargeLanguageModel):
|
|||||||
delta=LLMResultChunkDelta(
|
delta=LLMResultChunkDelta(
|
||||||
index=0,
|
index=0,
|
||||||
message=AssistantPromptMessage(content=message.content, tool_calls=[]),
|
message=AssistantPromptMessage(content=message.content, tool_calls=[]),
|
||||||
finish_reason=message.stop_reason if message.stop_reason else None,
|
finish_reason=message.stop_reason or None,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@ -106,7 +106,7 @@ class OpenLLMGenerate:
|
|||||||
timeout = 120
|
timeout = 120
|
||||||
|
|
||||||
data = {
|
data = {
|
||||||
"stop": stop if stop else [],
|
"stop": stop or [],
|
||||||
"prompt": "\n".join([message.content for message in prompt_messages]),
|
"prompt": "\n".join([message.content for message in prompt_messages]),
|
||||||
"llm_config": default_llm_config,
|
"llm_config": default_llm_config,
|
||||||
}
|
}
|
||||||
|
|||||||
@ -214,7 +214,7 @@ class ReplicateLargeLanguageModel(_CommonReplicate, LargeLanguageModel):
|
|||||||
|
|
||||||
index += 1
|
index += 1
|
||||||
|
|
||||||
assistant_prompt_message = AssistantPromptMessage(content=output if output else "")
|
assistant_prompt_message = AssistantPromptMessage(content=output or "")
|
||||||
|
|
||||||
if index < prediction_output_length:
|
if index < prediction_output_length:
|
||||||
yield LLMResultChunk(
|
yield LLMResultChunk(
|
||||||
|
|||||||
@ -1,5 +1,6 @@
|
|||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
|
import operator
|
||||||
from typing import Any, Optional
|
from typing import Any, Optional
|
||||||
|
|
||||||
import boto3
|
import boto3
|
||||||
@ -94,7 +95,7 @@ class SageMakerRerankModel(RerankModel):
|
|||||||
for idx in range(len(scores)):
|
for idx in range(len(scores)):
|
||||||
candidate_docs.append({"content": docs[idx], "score": scores[idx]})
|
candidate_docs.append({"content": docs[idx], "score": scores[idx]})
|
||||||
|
|
||||||
sorted(candidate_docs, key=lambda x: x["score"], reverse=True)
|
sorted(candidate_docs, key=operator.itemgetter("score"), reverse=True)
|
||||||
|
|
||||||
line = 3
|
line = 3
|
||||||
rerank_documents = []
|
rerank_documents = []
|
||||||
|
|||||||
@ -260,7 +260,7 @@ class SageMakerText2SpeechModel(TTSModel):
|
|||||||
for payload in payloads
|
for payload in payloads
|
||||||
]
|
]
|
||||||
|
|
||||||
for index, future in enumerate(futures):
|
for future in futures:
|
||||||
resp = future.result()
|
resp = future.result()
|
||||||
audio_bytes = requests.get(resp.get("s3_presign_url")).content
|
audio_bytes = requests.get(resp.get("s3_presign_url")).content
|
||||||
for i in range(0, len(audio_bytes), 1024):
|
for i in range(0, len(audio_bytes), 1024):
|
||||||
|
|||||||
@ -61,7 +61,10 @@ class SparkLLMClient:
|
|||||||
|
|
||||||
signature_sha_base64 = base64.b64encode(signature_sha).decode(encoding="utf-8")
|
signature_sha_base64 = base64.b64encode(signature_sha).decode(encoding="utf-8")
|
||||||
|
|
||||||
authorization_origin = f'api_key="{api_key}", algorithm="hmac-sha256", headers="host date request-line", signature="{signature_sha_base64}"'
|
authorization_origin = (
|
||||||
|
f'api_key="{api_key}", algorithm="hmac-sha256", headers="host date request-line",'
|
||||||
|
f' signature="{signature_sha_base64}"'
|
||||||
|
)
|
||||||
|
|
||||||
authorization = base64.b64encode(authorization_origin.encode("utf-8")).decode(encoding="utf-8")
|
authorization = base64.b64encode(authorization_origin.encode("utf-8")).decode(encoding="utf-8")
|
||||||
|
|
||||||
|
|||||||
@ -220,7 +220,7 @@ class SparkLargeLanguageModel(LargeLanguageModel):
|
|||||||
delta = content
|
delta = content
|
||||||
|
|
||||||
assistant_prompt_message = AssistantPromptMessage(
|
assistant_prompt_message = AssistantPromptMessage(
|
||||||
content=delta if delta else "",
|
content=delta or "",
|
||||||
)
|
)
|
||||||
|
|
||||||
prompt_tokens = self.get_num_tokens(model, credentials, prompt_messages)
|
prompt_tokens = self.get_num_tokens(model, credentials, prompt_messages)
|
||||||
|
|||||||
@ -1,6 +1,7 @@
|
|||||||
import base64
|
import base64
|
||||||
import hashlib
|
import hashlib
|
||||||
import hmac
|
import hmac
|
||||||
|
import operator
|
||||||
import time
|
import time
|
||||||
|
|
||||||
import requests
|
import requests
|
||||||
@ -127,7 +128,7 @@ class FlashRecognizer:
|
|||||||
return s
|
return s
|
||||||
|
|
||||||
def _build_req_with_signature(self, secret_key, params, header):
|
def _build_req_with_signature(self, secret_key, params, header):
|
||||||
query = sorted(params.items(), key=lambda d: d[0])
|
query = sorted(params.items(), key=operator.itemgetter(0))
|
||||||
signstr = self._format_sign_string(query)
|
signstr = self._format_sign_string(query)
|
||||||
signature = self._sign(signstr, secret_key)
|
signature = self._sign(signstr, secret_key)
|
||||||
header["Authorization"] = signature
|
header["Authorization"] = signature
|
||||||
|
|||||||
@ -4,6 +4,7 @@ import tempfile
|
|||||||
import uuid
|
import uuid
|
||||||
from collections.abc import Generator
|
from collections.abc import Generator
|
||||||
from http import HTTPStatus
|
from http import HTTPStatus
|
||||||
|
from pathlib import Path
|
||||||
from typing import Optional, Union, cast
|
from typing import Optional, Union, cast
|
||||||
|
|
||||||
from dashscope import Generation, MultiModalConversation, get_tokenizer
|
from dashscope import Generation, MultiModalConversation, get_tokenizer
|
||||||
@ -350,9 +351,7 @@ class TongyiLargeLanguageModel(LargeLanguageModel):
|
|||||||
break
|
break
|
||||||
elif isinstance(message, AssistantPromptMessage):
|
elif isinstance(message, AssistantPromptMessage):
|
||||||
message_text = f"{ai_prompt} {content}"
|
message_text = f"{ai_prompt} {content}"
|
||||||
elif isinstance(message, SystemPromptMessage):
|
elif isinstance(message, SystemPromptMessage | ToolPromptMessage):
|
||||||
message_text = content
|
|
||||||
elif isinstance(message, ToolPromptMessage):
|
|
||||||
message_text = content
|
message_text = content
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Got unknown type {message}")
|
raise ValueError(f"Got unknown type {message}")
|
||||||
@ -456,8 +455,7 @@ class TongyiLargeLanguageModel(LargeLanguageModel):
|
|||||||
|
|
||||||
file_path = os.path.join(temp_dir, f"{uuid.uuid4()}.{mime_type.split('/')[1]}")
|
file_path = os.path.join(temp_dir, f"{uuid.uuid4()}.{mime_type.split('/')[1]}")
|
||||||
|
|
||||||
with open(file_path, "wb") as image_file:
|
Path(file_path).write_bytes(base64.b64decode(encoded_string))
|
||||||
image_file.write(base64.b64decode(encoded_string))
|
|
||||||
|
|
||||||
return f"file://{file_path}"
|
return f"file://{file_path}"
|
||||||
|
|
||||||
|
|||||||
@ -34,7 +34,7 @@ if you are not sure about the structure.
|
|||||||
<instructions>
|
<instructions>
|
||||||
{{instructions}}
|
{{instructions}}
|
||||||
</instructions>
|
</instructions>
|
||||||
"""
|
""" # noqa: E501
|
||||||
|
|
||||||
|
|
||||||
class UpstageLargeLanguageModel(_CommonUpstage, LargeLanguageModel):
|
class UpstageLargeLanguageModel(_CommonUpstage, LargeLanguageModel):
|
||||||
@ -368,11 +368,9 @@ class UpstageLargeLanguageModel(_CommonUpstage, LargeLanguageModel):
|
|||||||
final_tool_calls.extend(tool_calls)
|
final_tool_calls.extend(tool_calls)
|
||||||
|
|
||||||
# transform assistant message to prompt message
|
# transform assistant message to prompt message
|
||||||
assistant_prompt_message = AssistantPromptMessage(
|
assistant_prompt_message = AssistantPromptMessage(content=delta.delta.content or "", tool_calls=tool_calls)
|
||||||
content=delta.delta.content if delta.delta.content else "", tool_calls=tool_calls
|
|
||||||
)
|
|
||||||
|
|
||||||
full_assistant_content += delta.delta.content if delta.delta.content else ""
|
full_assistant_content += delta.delta.content or ""
|
||||||
|
|
||||||
if has_finish_reason:
|
if has_finish_reason:
|
||||||
final_chunk = LLMResultChunk(
|
final_chunk = LLMResultChunk(
|
||||||
|
|||||||
@ -114,7 +114,8 @@ class VertexAiLargeLanguageModel(LargeLanguageModel):
|
|||||||
credentials.refresh(request)
|
credentials.refresh(request)
|
||||||
token = credentials.token
|
token = credentials.token
|
||||||
|
|
||||||
# Vertex AI Anthropic Claude3 Opus model available in us-east5 region, Sonnet and Haiku available in us-central1 region
|
# Vertex AI Anthropic Claude3 Opus model available in us-east5 region, Sonnet and Haiku available
|
||||||
|
# in us-central1 region
|
||||||
if "opus" in model or "claude-3-5-sonnet" in model:
|
if "opus" in model or "claude-3-5-sonnet" in model:
|
||||||
location = "us-east5"
|
location = "us-east5"
|
||||||
else:
|
else:
|
||||||
@ -123,7 +124,8 @@ class VertexAiLargeLanguageModel(LargeLanguageModel):
|
|||||||
# use access token to authenticate
|
# use access token to authenticate
|
||||||
if token:
|
if token:
|
||||||
client = AnthropicVertex(region=location, project_id=project_id, access_token=token)
|
client = AnthropicVertex(region=location, project_id=project_id, access_token=token)
|
||||||
# When access token is empty, try to use the Google Cloud VM's built-in service account or the GOOGLE_APPLICATION_CREDENTIALS environment variable
|
# When access token is empty, try to use the Google Cloud VM's built-in service account
|
||||||
|
# or the GOOGLE_APPLICATION_CREDENTIALS environment variable
|
||||||
else:
|
else:
|
||||||
client = AnthropicVertex(
|
client = AnthropicVertex(
|
||||||
region=location,
|
region=location,
|
||||||
@ -229,10 +231,10 @@ class VertexAiLargeLanguageModel(LargeLanguageModel):
|
|||||||
),
|
),
|
||||||
)
|
)
|
||||||
elif isinstance(chunk, ContentBlockDeltaEvent):
|
elif isinstance(chunk, ContentBlockDeltaEvent):
|
||||||
chunk_text = chunk.delta.text if chunk.delta.text else ""
|
chunk_text = chunk.delta.text or ""
|
||||||
full_assistant_content += chunk_text
|
full_assistant_content += chunk_text
|
||||||
assistant_prompt_message = AssistantPromptMessage(
|
assistant_prompt_message = AssistantPromptMessage(
|
||||||
content=chunk_text if chunk_text else "",
|
content=chunk_text or "",
|
||||||
)
|
)
|
||||||
index = chunk.index
|
index = chunk.index
|
||||||
yield LLMResultChunk(
|
yield LLMResultChunk(
|
||||||
@ -633,9 +635,7 @@ class VertexAiLargeLanguageModel(LargeLanguageModel):
|
|||||||
message_text = f"{human_prompt} {content}"
|
message_text = f"{human_prompt} {content}"
|
||||||
elif isinstance(message, AssistantPromptMessage):
|
elif isinstance(message, AssistantPromptMessage):
|
||||||
message_text = f"{ai_prompt} {content}"
|
message_text = f"{ai_prompt} {content}"
|
||||||
elif isinstance(message, SystemPromptMessage):
|
elif isinstance(message, SystemPromptMessage | ToolPromptMessage):
|
||||||
message_text = f"{human_prompt} {content}"
|
|
||||||
elif isinstance(message, ToolPromptMessage):
|
|
||||||
message_text = f"{human_prompt} {content}"
|
message_text = f"{human_prompt} {content}"
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Got unknown type {message}")
|
raise ValueError(f"Got unknown type {message}")
|
||||||
|
|||||||
@ -1,5 +1,6 @@
|
|||||||
# coding : utf-8
|
# coding : utf-8
|
||||||
import datetime
|
import datetime
|
||||||
|
from itertools import starmap
|
||||||
|
|
||||||
import pytz
|
import pytz
|
||||||
|
|
||||||
@ -48,7 +49,7 @@ class SignResult:
|
|||||||
self.authorization = ""
|
self.authorization = ""
|
||||||
|
|
||||||
def __str__(self):
|
def __str__(self):
|
||||||
return "\n".join(["{}:{}".format(*item) for item in self.__dict__.items()])
|
return "\n".join(list(starmap("{}:{}".format, self.__dict__.items())))
|
||||||
|
|
||||||
|
|
||||||
class Credentials:
|
class Credentials:
|
||||||
|
|||||||
@ -1,5 +1,6 @@
|
|||||||
import hashlib
|
import hashlib
|
||||||
import hmac
|
import hmac
|
||||||
|
import operator
|
||||||
from functools import reduce
|
from functools import reduce
|
||||||
from urllib.parse import quote
|
from urllib.parse import quote
|
||||||
|
|
||||||
@ -40,4 +41,4 @@ class Util:
|
|||||||
if len(hv) == 1:
|
if len(hv) == 1:
|
||||||
hv = "0" + hv
|
hv = "0" + hv
|
||||||
lst.append(hv)
|
lst.append(hv)
|
||||||
return reduce(lambda x, y: x + y, lst)
|
return reduce(operator.add, lst)
|
||||||
|
|||||||
@ -174,9 +174,7 @@ class VolcengineMaaSLargeLanguageModel(LargeLanguageModel):
|
|||||||
prompt_messages=prompt_messages,
|
prompt_messages=prompt_messages,
|
||||||
delta=LLMResultChunkDelta(
|
delta=LLMResultChunkDelta(
|
||||||
index=index,
|
index=index,
|
||||||
message=AssistantPromptMessage(
|
message=AssistantPromptMessage(content=message["content"] or "", tool_calls=[]),
|
||||||
content=message["content"] if message["content"] else "", tool_calls=[]
|
|
||||||
),
|
|
||||||
usage=usage,
|
usage=usage,
|
||||||
finish_reason=choice.get("finish_reason"),
|
finish_reason=choice.get("finish_reason"),
|
||||||
),
|
),
|
||||||
@ -208,7 +206,7 @@ class VolcengineMaaSLargeLanguageModel(LargeLanguageModel):
|
|||||||
model=model,
|
model=model,
|
||||||
prompt_messages=prompt_messages,
|
prompt_messages=prompt_messages,
|
||||||
message=AssistantPromptMessage(
|
message=AssistantPromptMessage(
|
||||||
content=message["content"] if message["content"] else "",
|
content=message["content"] or "",
|
||||||
tool_calls=tool_calls,
|
tool_calls=tool_calls,
|
||||||
),
|
),
|
||||||
usage=self._calc_response_usage(
|
usage=self._calc_response_usage(
|
||||||
@ -284,7 +282,7 @@ class VolcengineMaaSLargeLanguageModel(LargeLanguageModel):
|
|||||||
model=model,
|
model=model,
|
||||||
prompt_messages=prompt_messages,
|
prompt_messages=prompt_messages,
|
||||||
message=AssistantPromptMessage(
|
message=AssistantPromptMessage(
|
||||||
content=message.content if message.content else "",
|
content=message.content or "",
|
||||||
tool_calls=tool_calls,
|
tool_calls=tool_calls,
|
||||||
),
|
),
|
||||||
usage=self._calc_response_usage(
|
usage=self._calc_response_usage(
|
||||||
|
|||||||
@ -28,7 +28,7 @@ if you are not sure about the structure.
|
|||||||
</instructions>
|
</instructions>
|
||||||
|
|
||||||
You should also complete the text started with ``` but not tell ``` directly.
|
You should also complete the text started with ``` but not tell ``` directly.
|
||||||
"""
|
""" # noqa: E501
|
||||||
|
|
||||||
|
|
||||||
class ErnieBotLargeLanguageModel(LargeLanguageModel):
|
class ErnieBotLargeLanguageModel(LargeLanguageModel):
|
||||||
@ -199,7 +199,7 @@ class ErnieBotLargeLanguageModel(LargeLanguageModel):
|
|||||||
secret_key=credentials["secret_key"],
|
secret_key=credentials["secret_key"],
|
||||||
)
|
)
|
||||||
|
|
||||||
user = user if user else "ErnieBotDefault"
|
user = user or "ErnieBotDefault"
|
||||||
|
|
||||||
# convert prompt messages to baichuan messages
|
# convert prompt messages to baichuan messages
|
||||||
messages = [
|
messages = [
|
||||||
@ -289,7 +289,7 @@ class ErnieBotLargeLanguageModel(LargeLanguageModel):
|
|||||||
index=0,
|
index=0,
|
||||||
message=AssistantPromptMessage(content=message.content, tool_calls=[]),
|
message=AssistantPromptMessage(content=message.content, tool_calls=[]),
|
||||||
usage=usage,
|
usage=usage,
|
||||||
finish_reason=message.stop_reason if message.stop_reason else None,
|
finish_reason=message.stop_reason or None,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
@ -299,7 +299,7 @@ class ErnieBotLargeLanguageModel(LargeLanguageModel):
|
|||||||
delta=LLMResultChunkDelta(
|
delta=LLMResultChunkDelta(
|
||||||
index=0,
|
index=0,
|
||||||
message=AssistantPromptMessage(content=message.content, tool_calls=[]),
|
message=AssistantPromptMessage(content=message.content, tool_calls=[]),
|
||||||
finish_reason=message.stop_reason if message.stop_reason else None,
|
finish_reason=message.stop_reason or None,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@ -85,7 +85,7 @@ class WenxinTextEmbeddingModel(TextEmbeddingModel):
|
|||||||
api_key = credentials["api_key"]
|
api_key = credentials["api_key"]
|
||||||
secret_key = credentials["secret_key"]
|
secret_key = credentials["secret_key"]
|
||||||
embedding: TextEmbedding = self._create_text_embedding(api_key, secret_key)
|
embedding: TextEmbedding = self._create_text_embedding(api_key, secret_key)
|
||||||
user = user if user else "ErnieBotDefault"
|
user = user or "ErnieBotDefault"
|
||||||
|
|
||||||
context_size = self._get_context_size(model, credentials)
|
context_size = self._get_context_size(model, credentials)
|
||||||
max_chunks = self._get_max_chunks(model, credentials)
|
max_chunks = self._get_max_chunks(model, credentials)
|
||||||
|
|||||||
@ -130,7 +130,8 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
|
|||||||
credentials["completion_type"] = "completion"
|
credentials["completion_type"] = "completion"
|
||||||
else:
|
else:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"xinference model ability {extra_param.model_ability} is not supported, check if you have the right model type"
|
f"xinference model ability {extra_param.model_ability} is not supported,"
|
||||||
|
f" check if you have the right model type"
|
||||||
)
|
)
|
||||||
|
|
||||||
if extra_param.support_function_call:
|
if extra_param.support_function_call:
|
||||||
@ -272,11 +273,7 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
|
|||||||
"""
|
"""
|
||||||
text = ""
|
text = ""
|
||||||
for item in message:
|
for item in message:
|
||||||
if isinstance(item, UserPromptMessage):
|
if isinstance(item, UserPromptMessage | SystemPromptMessage | AssistantPromptMessage):
|
||||||
text += item.content
|
|
||||||
elif isinstance(item, SystemPromptMessage):
|
|
||||||
text += item.content
|
|
||||||
elif isinstance(item, AssistantPromptMessage):
|
|
||||||
text += item.content
|
text += item.content
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError(f"PromptMessage type {type(item)} is not supported")
|
raise NotImplementedError(f"PromptMessage type {type(item)} is not supported")
|
||||||
@ -362,7 +359,8 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
|
|||||||
help=I18nObject(
|
help=I18nObject(
|
||||||
en_US="Number between -2.0 and 2.0. Positive values penalize new tokens based on whether they "
|
en_US="Number between -2.0 and 2.0. Positive values penalize new tokens based on whether they "
|
||||||
"appear in the text so far, increasing the model's likelihood to talk about new topics.",
|
"appear in the text so far, increasing the model's likelihood to talk about new topics.",
|
||||||
zh_Hans="介于 -2.0 和 2.0 之间的数字。正值会根据新词是否已出现在文本中对其进行惩罚,从而增加模型谈论新话题的可能性。",
|
zh_Hans="介于 -2.0 和 2.0 之间的数字。正值会根据新词是否已出现在文本中对其进行惩罚,"
|
||||||
|
"从而增加模型谈论新话题的可能性。",
|
||||||
),
|
),
|
||||||
default=0.0,
|
default=0.0,
|
||||||
min=-2.0,
|
min=-2.0,
|
||||||
@ -382,7 +380,8 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
|
|||||||
en_US="Number between -2.0 and 2.0. Positive values penalize new tokens based on their "
|
en_US="Number between -2.0 and 2.0. Positive values penalize new tokens based on their "
|
||||||
"existing frequency in the text so far, decreasing the model's likelihood to repeat the "
|
"existing frequency in the text so far, decreasing the model's likelihood to repeat the "
|
||||||
"same line verbatim.",
|
"same line verbatim.",
|
||||||
zh_Hans="介于 -2.0 和 2.0 之间的数字。正值会根据新词在文本中的现有频率对其进行惩罚,从而降低模型逐字重复相同内容的可能性。",
|
zh_Hans="介于 -2.0 和 2.0 之间的数字。正值会根据新词在文本中的现有频率对其进行惩罚,"
|
||||||
|
"从而降低模型逐字重复相同内容的可能性。",
|
||||||
),
|
),
|
||||||
default=0.0,
|
default=0.0,
|
||||||
min=-2.0,
|
min=-2.0,
|
||||||
@ -590,7 +589,7 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
|
|||||||
|
|
||||||
# convert tool call to assistant message tool call
|
# convert tool call to assistant message tool call
|
||||||
tool_calls = assistant_message.tool_calls
|
tool_calls = assistant_message.tool_calls
|
||||||
assistant_prompt_message_tool_calls = self._extract_response_tool_calls(tool_calls if tool_calls else [])
|
assistant_prompt_message_tool_calls = self._extract_response_tool_calls(tool_calls or [])
|
||||||
function_call = assistant_message.function_call
|
function_call = assistant_message.function_call
|
||||||
if function_call:
|
if function_call:
|
||||||
assistant_prompt_message_tool_calls += [self._extract_response_function_call(function_call)]
|
assistant_prompt_message_tool_calls += [self._extract_response_function_call(function_call)]
|
||||||
@ -653,7 +652,7 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
|
|||||||
|
|
||||||
# transform assistant message to prompt message
|
# transform assistant message to prompt message
|
||||||
assistant_prompt_message = AssistantPromptMessage(
|
assistant_prompt_message = AssistantPromptMessage(
|
||||||
content=delta.delta.content if delta.delta.content else "", tool_calls=assistant_message_tool_calls
|
content=delta.delta.content or "", tool_calls=assistant_message_tool_calls
|
||||||
)
|
)
|
||||||
|
|
||||||
if delta.finish_reason is not None:
|
if delta.finish_reason is not None:
|
||||||
@ -750,7 +749,7 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
|
|||||||
delta = chunk.choices[0]
|
delta = chunk.choices[0]
|
||||||
|
|
||||||
# transform assistant message to prompt message
|
# transform assistant message to prompt message
|
||||||
assistant_prompt_message = AssistantPromptMessage(content=delta.text if delta.text else "", tool_calls=[])
|
assistant_prompt_message = AssistantPromptMessage(content=delta.text or "", tool_calls=[])
|
||||||
|
|
||||||
if delta.finish_reason is not None:
|
if delta.finish_reason is not None:
|
||||||
# temp_assistant_prompt_message is used to calculate usage
|
# temp_assistant_prompt_message is used to calculate usage
|
||||||
|
|||||||
@ -101,12 +101,16 @@ class XinferenceSpeech2TextModel(Speech2TextModel):
|
|||||||
|
|
||||||
:param model: model name
|
:param model: model name
|
||||||
:param credentials: model credentials
|
:param credentials: model credentials
|
||||||
:param file: The audio file object (not file name) to transcribe, in one of these formats: flac, mp3, mp4, mpe g,mpga, m4a, ogg, wav, or webm.
|
:param file: The audio file object (not file name) to transcribe, in one of these formats: flac, mp3, mp4, mpeg,
|
||||||
|
mpga, m4a, ogg, wav, or webm.
|
||||||
:param language: The language of the input audio. Supplying the input language in ISO-639-1
|
:param language: The language of the input audio. Supplying the input language in ISO-639-1
|
||||||
:param prompt: An optional text to guide the model's style or continue a previous audio segment.
|
:param prompt: An optional text to guide the model's style or continue a previous audio segment.
|
||||||
The prompt should match the audio language.
|
The prompt should match the audio language.
|
||||||
:param response_format: The format of the transcript output, in one of these options: json, text, srt, verbose _json, or vtt.
|
:param response_format: The format of the transcript output, in one of these options: json, text, srt,
|
||||||
:param temperature: The sampling temperature, between 0 and 1. Higher values like 0.8 will make the output mor e random,while lower values like 0.2 will make it more focused and deterministic.If set to 0, the model wi ll use log probability to automatically increase the temperature until certain thresholds are hit.
|
verbose_json, or vtt.
|
||||||
|
:param temperature: The sampling temperature, between 0 and 1. Higher values like 0.8 will make the output more
|
||||||
|
random,while lower values like 0.2 will make it more focused and deterministic.If set to 0, the model will use
|
||||||
|
log probability to automatically increase the temperature until certain thresholds are hit.
|
||||||
:return: text for given audio file
|
:return: text for given audio file
|
||||||
"""
|
"""
|
||||||
server_url = credentials["server_url"]
|
server_url = credentials["server_url"]
|
||||||
|
|||||||
@ -215,7 +215,7 @@ class XinferenceText2SpeechModel(TTSModel):
|
|||||||
for i in range(len(sentences))
|
for i in range(len(sentences))
|
||||||
]
|
]
|
||||||
|
|
||||||
for index, future in enumerate(futures):
|
for future in futures:
|
||||||
response = future.result()
|
response = future.result()
|
||||||
for i in range(0, len(response), 1024):
|
for i in range(0, len(response), 1024):
|
||||||
yield response[i : i + 1024]
|
yield response[i : i + 1024]
|
||||||
|
|||||||
@ -76,7 +76,8 @@ class XinferenceHelper:
|
|||||||
|
|
||||||
url = str(URL(server_url) / "v1" / "models" / model_uid)
|
url = str(URL(server_url) / "v1" / "models" / model_uid)
|
||||||
|
|
||||||
# this method is surrounded by a lock, and default requests may hang forever, so we just set a Adapter with max_retries=3
|
# this method is surrounded by a lock, and default requests may hang forever,
|
||||||
|
# so we just set a Adapter with max_retries=3
|
||||||
session = Session()
|
session = Session()
|
||||||
session.mount("http://", HTTPAdapter(max_retries=3))
|
session.mount("http://", HTTPAdapter(max_retries=3))
|
||||||
session.mount("https://", HTTPAdapter(max_retries=3))
|
session.mount("https://", HTTPAdapter(max_retries=3))
|
||||||
@ -88,7 +89,8 @@ class XinferenceHelper:
|
|||||||
raise RuntimeError(f"get xinference model extra parameter failed, url: {url}, error: {e}")
|
raise RuntimeError(f"get xinference model extra parameter failed, url: {url}, error: {e}")
|
||||||
if response.status_code != 200:
|
if response.status_code != 200:
|
||||||
raise RuntimeError(
|
raise RuntimeError(
|
||||||
f"get xinference model extra parameter failed, status code: {response.status_code}, response: {response.text}"
|
f"get xinference model extra parameter failed, status code: {response.status_code},"
|
||||||
|
f" response: {response.text}"
|
||||||
)
|
)
|
||||||
|
|
||||||
response_json = response.json()
|
response_json = response.json()
|
||||||
|
|||||||
@ -31,7 +31,7 @@ And you should always end the block with a "```" to indicate the end of the JSON
|
|||||||
{{instructions}}
|
{{instructions}}
|
||||||
</instructions>
|
</instructions>
|
||||||
|
|
||||||
```JSON"""
|
```JSON""" # noqa: E501
|
||||||
|
|
||||||
|
|
||||||
class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel):
|
class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel):
|
||||||
@ -209,9 +209,10 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel):
|
|||||||
):
|
):
|
||||||
new_prompt_messages[-1].content += "\n\n" + copy_prompt_message.content
|
new_prompt_messages[-1].content += "\n\n" + copy_prompt_message.content
|
||||||
else:
|
else:
|
||||||
if copy_prompt_message.role == PromptMessageRole.USER:
|
if (
|
||||||
new_prompt_messages.append(copy_prompt_message)
|
copy_prompt_message.role == PromptMessageRole.USER
|
||||||
elif copy_prompt_message.role == PromptMessageRole.TOOL:
|
or copy_prompt_message.role == PromptMessageRole.TOOL
|
||||||
|
):
|
||||||
new_prompt_messages.append(copy_prompt_message)
|
new_prompt_messages.append(copy_prompt_message)
|
||||||
elif copy_prompt_message.role == PromptMessageRole.SYSTEM:
|
elif copy_prompt_message.role == PromptMessageRole.SYSTEM:
|
||||||
new_prompt_message = SystemPromptMessage(content=copy_prompt_message.content)
|
new_prompt_message = SystemPromptMessage(content=copy_prompt_message.content)
|
||||||
@ -413,10 +414,10 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel):
|
|||||||
|
|
||||||
# transform assistant message to prompt message
|
# transform assistant message to prompt message
|
||||||
assistant_prompt_message = AssistantPromptMessage(
|
assistant_prompt_message = AssistantPromptMessage(
|
||||||
content=delta.delta.content if delta.delta.content else "", tool_calls=assistant_tool_calls
|
content=delta.delta.content or "", tool_calls=assistant_tool_calls
|
||||||
)
|
)
|
||||||
|
|
||||||
full_assistant_content += delta.delta.content if delta.delta.content else ""
|
full_assistant_content += delta.delta.content or ""
|
||||||
|
|
||||||
if delta.finish_reason is not None and chunk.usage is not None:
|
if delta.finish_reason is not None and chunk.usage is not None:
|
||||||
completion_tokens = chunk.usage.completion_tokens
|
completion_tokens = chunk.usage.completion_tokens
|
||||||
@ -461,9 +462,7 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel):
|
|||||||
message_text = f"{human_prompt} {content}"
|
message_text = f"{human_prompt} {content}"
|
||||||
elif isinstance(message, AssistantPromptMessage):
|
elif isinstance(message, AssistantPromptMessage):
|
||||||
message_text = f"{ai_prompt} {content}"
|
message_text = f"{ai_prompt} {content}"
|
||||||
elif isinstance(message, SystemPromptMessage):
|
elif isinstance(message, SystemPromptMessage | ToolPromptMessage):
|
||||||
message_text = content
|
|
||||||
elif isinstance(message, ToolPromptMessage):
|
|
||||||
message_text = content
|
message_text = content
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Got unknown type {message}")
|
raise ValueError(f"Got unknown type {message}")
|
||||||
|
|||||||
@ -75,7 +75,8 @@ Headers = Mapping[str, Union[str, Omit]]
|
|||||||
|
|
||||||
ResponseT = TypeVar(
|
ResponseT = TypeVar(
|
||||||
"ResponseT",
|
"ResponseT",
|
||||||
bound="Union[str, None, BaseModel, list[Any], Dict[str, Any], Response, UnknownResponse, ModelBuilderProtocol, BinaryResponseContent]",
|
bound="Union[str, None, BaseModel, list[Any], Dict[str, Any], Response, UnknownResponse, ModelBuilderProtocol,"
|
||||||
|
" BinaryResponseContent]",
|
||||||
)
|
)
|
||||||
|
|
||||||
# for user input files
|
# for user input files
|
||||||
|
|||||||
@ -30,6 +30,8 @@ def _merge_map(map1: Mapping, map2: Mapping) -> Mapping:
|
|||||||
return {key: val for key, val in merged.items() if val is not None}
|
return {key: val for key, val in merged.items() if val is not None}
|
||||||
|
|
||||||
|
|
||||||
|
from itertools import starmap
|
||||||
|
|
||||||
from httpx._config import DEFAULT_TIMEOUT_CONFIG as HTTPX_DEFAULT_TIMEOUT
|
from httpx._config import DEFAULT_TIMEOUT_CONFIG as HTTPX_DEFAULT_TIMEOUT
|
||||||
|
|
||||||
ZHIPUAI_DEFAULT_TIMEOUT = httpx.Timeout(timeout=300.0, connect=8.0)
|
ZHIPUAI_DEFAULT_TIMEOUT = httpx.Timeout(timeout=300.0, connect=8.0)
|
||||||
@ -159,7 +161,7 @@ class HttpClient:
|
|||||||
return [(key, str_data)]
|
return [(key, str_data)]
|
||||||
|
|
||||||
def _make_multipartform(self, data: Mapping[object, object]) -> dict[str, object]:
|
def _make_multipartform(self, data: Mapping[object, object]) -> dict[str, object]:
|
||||||
items = flatten([self._object_to_formdata(k, v) for k, v in data.items()])
|
items = flatten(list(starmap(self._object_to_formdata, data.items())))
|
||||||
|
|
||||||
serialized: dict[str, object] = {}
|
serialized: dict[str, object] = {}
|
||||||
for key, value in items:
|
for key, value in items:
|
||||||
|
|||||||
@ -67,7 +67,8 @@ class CommonValidator:
|
|||||||
if credential_form_schema.max_length:
|
if credential_form_schema.max_length:
|
||||||
if len(value) > credential_form_schema.max_length:
|
if len(value) > credential_form_schema.max_length:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Variable {credential_form_schema.variable} length should not greater than {credential_form_schema.max_length}"
|
f"Variable {credential_form_schema.variable} length should not"
|
||||||
|
f" greater than {credential_form_schema.max_length}"
|
||||||
)
|
)
|
||||||
|
|
||||||
# check the type of value
|
# check the type of value
|
||||||
|
|||||||
@ -56,14 +56,7 @@ class KeywordsModeration(Moderation):
|
|||||||
)
|
)
|
||||||
|
|
||||||
def _is_violated(self, inputs: dict, keywords_list: list) -> bool:
|
def _is_violated(self, inputs: dict, keywords_list: list) -> bool:
|
||||||
for value in inputs.values():
|
return any(self._check_keywords_in_value(keywords_list, value) for value in inputs.values())
|
||||||
if self._check_keywords_in_value(keywords_list, value):
|
|
||||||
return True
|
|
||||||
|
|
||||||
return False
|
def _check_keywords_in_value(self, keywords_list, value) -> bool:
|
||||||
|
return any(keyword.lower() in value.lower() for keyword in keywords_list)
|
||||||
def _check_keywords_in_value(self, keywords_list, value):
|
|
||||||
for keyword in keywords_list:
|
|
||||||
if keyword.lower() in value.lower():
|
|
||||||
return True
|
|
||||||
return False
|
|
||||||
|
|||||||
@ -65,7 +65,7 @@ class LangFuseDataTrace(BaseTraceInstance):
|
|||||||
self.generate_name_trace(trace_info)
|
self.generate_name_trace(trace_info)
|
||||||
|
|
||||||
def workflow_trace(self, trace_info: WorkflowTraceInfo):
|
def workflow_trace(self, trace_info: WorkflowTraceInfo):
|
||||||
trace_id = trace_info.workflow_app_log_id if trace_info.workflow_app_log_id else trace_info.workflow_run_id
|
trace_id = trace_info.workflow_app_log_id or trace_info.workflow_run_id
|
||||||
user_id = trace_info.metadata.get("user_id")
|
user_id = trace_info.metadata.get("user_id")
|
||||||
if trace_info.message_id:
|
if trace_info.message_id:
|
||||||
trace_id = trace_info.message_id
|
trace_id = trace_info.message_id
|
||||||
@ -84,7 +84,7 @@ class LangFuseDataTrace(BaseTraceInstance):
|
|||||||
)
|
)
|
||||||
self.add_trace(langfuse_trace_data=trace_data)
|
self.add_trace(langfuse_trace_data=trace_data)
|
||||||
workflow_span_data = LangfuseSpan(
|
workflow_span_data = LangfuseSpan(
|
||||||
id=(trace_info.workflow_app_log_id if trace_info.workflow_app_log_id else trace_info.workflow_run_id),
|
id=(trace_info.workflow_app_log_id or trace_info.workflow_run_id),
|
||||||
name=TraceTaskName.WORKFLOW_TRACE.value,
|
name=TraceTaskName.WORKFLOW_TRACE.value,
|
||||||
input=trace_info.workflow_run_inputs,
|
input=trace_info.workflow_run_inputs,
|
||||||
output=trace_info.workflow_run_outputs,
|
output=trace_info.workflow_run_outputs,
|
||||||
@ -93,7 +93,7 @@ class LangFuseDataTrace(BaseTraceInstance):
|
|||||||
end_time=trace_info.end_time,
|
end_time=trace_info.end_time,
|
||||||
metadata=trace_info.metadata,
|
metadata=trace_info.metadata,
|
||||||
level=LevelEnum.DEFAULT if trace_info.error == "" else LevelEnum.ERROR,
|
level=LevelEnum.DEFAULT if trace_info.error == "" else LevelEnum.ERROR,
|
||||||
status_message=trace_info.error if trace_info.error else "",
|
status_message=trace_info.error or "",
|
||||||
)
|
)
|
||||||
self.add_span(langfuse_span_data=workflow_span_data)
|
self.add_span(langfuse_span_data=workflow_span_data)
|
||||||
else:
|
else:
|
||||||
@ -143,7 +143,7 @@ class LangFuseDataTrace(BaseTraceInstance):
|
|||||||
else:
|
else:
|
||||||
inputs = json.loads(node_execution.inputs) if node_execution.inputs else {}
|
inputs = json.loads(node_execution.inputs) if node_execution.inputs else {}
|
||||||
outputs = json.loads(node_execution.outputs) if node_execution.outputs else {}
|
outputs = json.loads(node_execution.outputs) if node_execution.outputs else {}
|
||||||
created_at = node_execution.created_at if node_execution.created_at else datetime.now()
|
created_at = node_execution.created_at or datetime.now()
|
||||||
elapsed_time = node_execution.elapsed_time
|
elapsed_time = node_execution.elapsed_time
|
||||||
finished_at = created_at + timedelta(seconds=elapsed_time)
|
finished_at = created_at + timedelta(seconds=elapsed_time)
|
||||||
|
|
||||||
@ -172,10 +172,8 @@ class LangFuseDataTrace(BaseTraceInstance):
|
|||||||
end_time=finished_at,
|
end_time=finished_at,
|
||||||
metadata=metadata,
|
metadata=metadata,
|
||||||
level=(LevelEnum.DEFAULT if status == "succeeded" else LevelEnum.ERROR),
|
level=(LevelEnum.DEFAULT if status == "succeeded" else LevelEnum.ERROR),
|
||||||
status_message=trace_info.error if trace_info.error else "",
|
status_message=trace_info.error or "",
|
||||||
parent_observation_id=(
|
parent_observation_id=(trace_info.workflow_app_log_id or trace_info.workflow_run_id),
|
||||||
trace_info.workflow_app_log_id if trace_info.workflow_app_log_id else trace_info.workflow_run_id
|
|
||||||
),
|
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
span_data = LangfuseSpan(
|
span_data = LangfuseSpan(
|
||||||
@ -188,7 +186,7 @@ class LangFuseDataTrace(BaseTraceInstance):
|
|||||||
end_time=finished_at,
|
end_time=finished_at,
|
||||||
metadata=metadata,
|
metadata=metadata,
|
||||||
level=(LevelEnum.DEFAULT if status == "succeeded" else LevelEnum.ERROR),
|
level=(LevelEnum.DEFAULT if status == "succeeded" else LevelEnum.ERROR),
|
||||||
status_message=trace_info.error if trace_info.error else "",
|
status_message=trace_info.error or "",
|
||||||
)
|
)
|
||||||
|
|
||||||
self.add_span(langfuse_span_data=span_data)
|
self.add_span(langfuse_span_data=span_data)
|
||||||
@ -212,7 +210,7 @@ class LangFuseDataTrace(BaseTraceInstance):
|
|||||||
output=outputs,
|
output=outputs,
|
||||||
metadata=metadata,
|
metadata=metadata,
|
||||||
level=(LevelEnum.DEFAULT if status == "succeeded" else LevelEnum.ERROR),
|
level=(LevelEnum.DEFAULT if status == "succeeded" else LevelEnum.ERROR),
|
||||||
status_message=trace_info.error if trace_info.error else "",
|
status_message=trace_info.error or "",
|
||||||
usage=generation_usage,
|
usage=generation_usage,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -277,7 +275,7 @@ class LangFuseDataTrace(BaseTraceInstance):
|
|||||||
output=message_data.answer,
|
output=message_data.answer,
|
||||||
metadata=metadata,
|
metadata=metadata,
|
||||||
level=(LevelEnum.DEFAULT if message_data.status != "error" else LevelEnum.ERROR),
|
level=(LevelEnum.DEFAULT if message_data.status != "error" else LevelEnum.ERROR),
|
||||||
status_message=message_data.error if message_data.error else "",
|
status_message=message_data.error or "",
|
||||||
usage=generation_usage,
|
usage=generation_usage,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -319,7 +317,7 @@ class LangFuseDataTrace(BaseTraceInstance):
|
|||||||
end_time=trace_info.end_time,
|
end_time=trace_info.end_time,
|
||||||
metadata=trace_info.metadata,
|
metadata=trace_info.metadata,
|
||||||
level=(LevelEnum.DEFAULT if message_data.status != "error" else LevelEnum.ERROR),
|
level=(LevelEnum.DEFAULT if message_data.status != "error" else LevelEnum.ERROR),
|
||||||
status_message=message_data.error if message_data.error else "",
|
status_message=message_data.error or "",
|
||||||
usage=generation_usage,
|
usage=generation_usage,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@ -82,7 +82,7 @@ class LangSmithDataTrace(BaseTraceInstance):
|
|||||||
langsmith_run = LangSmithRunModel(
|
langsmith_run = LangSmithRunModel(
|
||||||
file_list=trace_info.file_list,
|
file_list=trace_info.file_list,
|
||||||
total_tokens=trace_info.total_tokens,
|
total_tokens=trace_info.total_tokens,
|
||||||
id=trace_info.workflow_app_log_id if trace_info.workflow_app_log_id else trace_info.workflow_run_id,
|
id=trace_info.workflow_app_log_id or trace_info.workflow_run_id,
|
||||||
name=TraceTaskName.WORKFLOW_TRACE.value,
|
name=TraceTaskName.WORKFLOW_TRACE.value,
|
||||||
inputs=trace_info.workflow_run_inputs,
|
inputs=trace_info.workflow_run_inputs,
|
||||||
run_type=LangSmithRunType.tool,
|
run_type=LangSmithRunType.tool,
|
||||||
@ -94,7 +94,7 @@ class LangSmithDataTrace(BaseTraceInstance):
|
|||||||
},
|
},
|
||||||
error=trace_info.error,
|
error=trace_info.error,
|
||||||
tags=["workflow"],
|
tags=["workflow"],
|
||||||
parent_run_id=trace_info.message_id if trace_info.message_id else None,
|
parent_run_id=trace_info.message_id or None,
|
||||||
)
|
)
|
||||||
|
|
||||||
self.add_run(langsmith_run)
|
self.add_run(langsmith_run)
|
||||||
@ -133,7 +133,7 @@ class LangSmithDataTrace(BaseTraceInstance):
|
|||||||
else:
|
else:
|
||||||
inputs = json.loads(node_execution.inputs) if node_execution.inputs else {}
|
inputs = json.loads(node_execution.inputs) if node_execution.inputs else {}
|
||||||
outputs = json.loads(node_execution.outputs) if node_execution.outputs else {}
|
outputs = json.loads(node_execution.outputs) if node_execution.outputs else {}
|
||||||
created_at = node_execution.created_at if node_execution.created_at else datetime.now()
|
created_at = node_execution.created_at or datetime.now()
|
||||||
elapsed_time = node_execution.elapsed_time
|
elapsed_time = node_execution.elapsed_time
|
||||||
finished_at = created_at + timedelta(seconds=elapsed_time)
|
finished_at = created_at + timedelta(seconds=elapsed_time)
|
||||||
|
|
||||||
@ -180,9 +180,7 @@ class LangSmithDataTrace(BaseTraceInstance):
|
|||||||
extra={
|
extra={
|
||||||
"metadata": metadata,
|
"metadata": metadata,
|
||||||
},
|
},
|
||||||
parent_run_id=trace_info.workflow_app_log_id
|
parent_run_id=trace_info.workflow_app_log_id or trace_info.workflow_run_id,
|
||||||
if trace_info.workflow_app_log_id
|
|
||||||
else trace_info.workflow_run_id,
|
|
||||||
tags=["node_execution"],
|
tags=["node_execution"],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@ -223,7 +223,7 @@ class OpsTraceManager:
|
|||||||
:return:
|
:return:
|
||||||
"""
|
"""
|
||||||
# auth check
|
# auth check
|
||||||
if tracing_provider not in provider_config_map.keys() and tracing_provider is not None:
|
if tracing_provider not in provider_config_map and tracing_provider is not None:
|
||||||
raise ValueError(f"Invalid tracing provider: {tracing_provider}")
|
raise ValueError(f"Invalid tracing provider: {tracing_provider}")
|
||||||
|
|
||||||
app_config: App = db.session.query(App).filter(App.id == app_id).first()
|
app_config: App = db.session.query(App).filter(App.id == app_id).first()
|
||||||
@ -354,11 +354,11 @@ class TraceTask:
|
|||||||
workflow_run_inputs = json.loads(workflow_run.inputs) if workflow_run.inputs else {}
|
workflow_run_inputs = json.loads(workflow_run.inputs) if workflow_run.inputs else {}
|
||||||
workflow_run_outputs = json.loads(workflow_run.outputs) if workflow_run.outputs else {}
|
workflow_run_outputs = json.loads(workflow_run.outputs) if workflow_run.outputs else {}
|
||||||
workflow_run_version = workflow_run.version
|
workflow_run_version = workflow_run.version
|
||||||
error = workflow_run.error if workflow_run.error else ""
|
error = workflow_run.error or ""
|
||||||
|
|
||||||
total_tokens = workflow_run.total_tokens
|
total_tokens = workflow_run.total_tokens
|
||||||
|
|
||||||
file_list = workflow_run_inputs.get("sys.file") if workflow_run_inputs.get("sys.file") else []
|
file_list = workflow_run_inputs.get("sys.file") or []
|
||||||
query = workflow_run_inputs.get("query") or workflow_run_inputs.get("sys.query") or ""
|
query = workflow_run_inputs.get("query") or workflow_run_inputs.get("sys.query") or ""
|
||||||
|
|
||||||
# get workflow_app_log_id
|
# get workflow_app_log_id
|
||||||
@ -452,7 +452,7 @@ class TraceTask:
|
|||||||
message_tokens=message_tokens,
|
message_tokens=message_tokens,
|
||||||
answer_tokens=message_data.answer_tokens,
|
answer_tokens=message_data.answer_tokens,
|
||||||
total_tokens=message_tokens + message_data.answer_tokens,
|
total_tokens=message_tokens + message_data.answer_tokens,
|
||||||
error=message_data.error if message_data.error else "",
|
error=message_data.error or "",
|
||||||
inputs=inputs,
|
inputs=inputs,
|
||||||
outputs=message_data.answer,
|
outputs=message_data.answer,
|
||||||
file_list=file_list,
|
file_list=file_list,
|
||||||
@ -487,7 +487,7 @@ class TraceTask:
|
|||||||
workflow_app_log_id = str(workflow_app_log_data.id) if workflow_app_log_data else None
|
workflow_app_log_id = str(workflow_app_log_data.id) if workflow_app_log_data else None
|
||||||
|
|
||||||
moderation_trace_info = ModerationTraceInfo(
|
moderation_trace_info = ModerationTraceInfo(
|
||||||
message_id=workflow_app_log_id if workflow_app_log_id else message_id,
|
message_id=workflow_app_log_id or message_id,
|
||||||
inputs=inputs,
|
inputs=inputs,
|
||||||
message_data=message_data.to_dict(),
|
message_data=message_data.to_dict(),
|
||||||
flagged=moderation_result.flagged,
|
flagged=moderation_result.flagged,
|
||||||
@ -527,7 +527,7 @@ class TraceTask:
|
|||||||
workflow_app_log_id = str(workflow_app_log_data.id) if workflow_app_log_data else None
|
workflow_app_log_id = str(workflow_app_log_data.id) if workflow_app_log_data else None
|
||||||
|
|
||||||
suggested_question_trace_info = SuggestedQuestionTraceInfo(
|
suggested_question_trace_info = SuggestedQuestionTraceInfo(
|
||||||
message_id=workflow_app_log_id if workflow_app_log_id else message_id,
|
message_id=workflow_app_log_id or message_id,
|
||||||
message_data=message_data.to_dict(),
|
message_data=message_data.to_dict(),
|
||||||
inputs=message_data.message,
|
inputs=message_data.message,
|
||||||
outputs=message_data.answer,
|
outputs=message_data.answer,
|
||||||
@ -569,7 +569,7 @@ class TraceTask:
|
|||||||
|
|
||||||
dataset_retrieval_trace_info = DatasetRetrievalTraceInfo(
|
dataset_retrieval_trace_info = DatasetRetrievalTraceInfo(
|
||||||
message_id=message_id,
|
message_id=message_id,
|
||||||
inputs=message_data.query if message_data.query else message_data.inputs,
|
inputs=message_data.query or message_data.inputs,
|
||||||
documents=[doc.model_dump() for doc in documents],
|
documents=[doc.model_dump() for doc in documents],
|
||||||
start_time=timer.get("start"),
|
start_time=timer.get("start"),
|
||||||
end_time=timer.get("end"),
|
end_time=timer.get("end"),
|
||||||
@ -695,8 +695,7 @@ class TraceQueueManager:
|
|||||||
self.start_timer()
|
self.start_timer()
|
||||||
|
|
||||||
def add_trace_task(self, trace_task: TraceTask):
|
def add_trace_task(self, trace_task: TraceTask):
|
||||||
global trace_manager_timer
|
global trace_manager_timer, trace_manager_queue
|
||||||
global trace_manager_queue
|
|
||||||
try:
|
try:
|
||||||
if self.trace_instance:
|
if self.trace_instance:
|
||||||
trace_task.app_id = self.app_id
|
trace_task.app_id = self.app_id
|
||||||
|
|||||||
@ -1,11 +1,11 @@
|
|||||||
CONTEXT = "Use the following context as your learned knowledge, inside <context></context> XML tags.\n\n<context>\n{{#context#}}\n</context>\n\nWhen answer to user:\n- If you don't know, just say that you don't know.\n- If you don't know when you are not sure, ask for clarification.\nAvoid mentioning that you obtained the information from the context.\nAnd answer according to the language of the user's question.\n"
|
CONTEXT = "Use the following context as your learned knowledge, inside <context></context> XML tags.\n\n<context>\n{{#context#}}\n</context>\n\nWhen answer to user:\n- If you don't know, just say that you don't know.\n- If you don't know when you are not sure, ask for clarification.\nAvoid mentioning that you obtained the information from the context.\nAnd answer according to the language of the user's question.\n" # noqa: E501
|
||||||
|
|
||||||
BAICHUAN_CONTEXT = "用户在与一个客观的助手对话。助手会尊重找到的材料,给出全面专业的解释,但不会过度演绎。同时回答中不会暴露引用的材料:\n\n```\n{{#context#}}\n```\n\n"
|
BAICHUAN_CONTEXT = "用户在与一个客观的助手对话。助手会尊重找到的材料,给出全面专业的解释,但不会过度演绎。同时回答中不会暴露引用的材料:\n\n```\n{{#context#}}\n```\n\n" # noqa: E501
|
||||||
|
|
||||||
CHAT_APP_COMPLETION_PROMPT_CONFIG = {
|
CHAT_APP_COMPLETION_PROMPT_CONFIG = {
|
||||||
"completion_prompt_config": {
|
"completion_prompt_config": {
|
||||||
"prompt": {
|
"prompt": {
|
||||||
"text": "{{#pre_prompt#}}\nHere is the chat histories between human and assistant, inside <histories></histories> XML tags.\n\n<histories>\n{{#histories#}}\n</histories>\n\n\nHuman: {{#query#}}\n\nAssistant: "
|
"text": "{{#pre_prompt#}}\nHere is the chat histories between human and assistant, inside <histories></histories> XML tags.\n\n<histories>\n{{#histories#}}\n</histories>\n\n\nHuman: {{#query#}}\n\nAssistant: " # noqa: E501
|
||||||
},
|
},
|
||||||
"conversation_histories_role": {"user_prefix": "Human", "assistant_prefix": "Assistant"},
|
"conversation_histories_role": {"user_prefix": "Human", "assistant_prefix": "Assistant"},
|
||||||
},
|
},
|
||||||
@ -24,7 +24,7 @@ COMPLETION_APP_COMPLETION_PROMPT_CONFIG = {
|
|||||||
BAICHUAN_CHAT_APP_COMPLETION_PROMPT_CONFIG = {
|
BAICHUAN_CHAT_APP_COMPLETION_PROMPT_CONFIG = {
|
||||||
"completion_prompt_config": {
|
"completion_prompt_config": {
|
||||||
"prompt": {
|
"prompt": {
|
||||||
"text": "{{#pre_prompt#}}\n\n用户和助手的历史对话内容如下:\n```\n{{#histories#}}\n```\n\n\n\n用户:{{#query#}}"
|
"text": "{{#pre_prompt#}}\n\n用户和助手的历史对话内容如下:\n```\n{{#histories#}}\n```\n\n\n\n用户:{{#query#}}" # noqa: E501
|
||||||
},
|
},
|
||||||
"conversation_histories_role": {"user_prefix": "用户", "assistant_prefix": "助手"},
|
"conversation_histories_role": {"user_prefix": "用户", "assistant_prefix": "助手"},
|
||||||
},
|
},
|
||||||
|
|||||||
@ -112,11 +112,11 @@ class SimplePromptTransform(PromptTransform):
|
|||||||
for v in prompt_template_config["special_variable_keys"]:
|
for v in prompt_template_config["special_variable_keys"]:
|
||||||
# support #context#, #query# and #histories#
|
# support #context#, #query# and #histories#
|
||||||
if v == "#context#":
|
if v == "#context#":
|
||||||
variables["#context#"] = context if context else ""
|
variables["#context#"] = context or ""
|
||||||
elif v == "#query#":
|
elif v == "#query#":
|
||||||
variables["#query#"] = query if query else ""
|
variables["#query#"] = query or ""
|
||||||
elif v == "#histories#":
|
elif v == "#histories#":
|
||||||
variables["#histories#"] = histories if histories else ""
|
variables["#histories#"] = histories or ""
|
||||||
|
|
||||||
prompt_template = prompt_template_config["prompt_template"]
|
prompt_template = prompt_template_config["prompt_template"]
|
||||||
prompt = prompt_template.format(variables)
|
prompt = prompt_template.format(variables)
|
||||||
|
|||||||
@ -34,7 +34,7 @@ class BaseKeyword(ABC):
|
|||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
def _filter_duplicate_texts(self, texts: list[Document]) -> list[Document]:
|
def _filter_duplicate_texts(self, texts: list[Document]) -> list[Document]:
|
||||||
for text in texts[:]:
|
for text in texts.copy():
|
||||||
doc_id = text.metadata["doc_id"]
|
doc_id = text.metadata["doc_id"]
|
||||||
exists_duplicate_node = self.text_exists(doc_id)
|
exists_duplicate_node = self.text_exists(doc_id)
|
||||||
if exists_duplicate_node:
|
if exists_duplicate_node:
|
||||||
|
|||||||
@ -239,7 +239,7 @@ class AnalyticdbVector(BaseVector):
|
|||||||
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
|
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
|
||||||
from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
|
from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
|
||||||
|
|
||||||
score_threshold = kwargs.get("score_threshold", 0.0) if kwargs.get("score_threshold", 0.0) else 0.0
|
score_threshold = kwargs.get("score_threshold", 0.0)
|
||||||
request = gpdb_20160503_models.QueryCollectionDataRequest(
|
request = gpdb_20160503_models.QueryCollectionDataRequest(
|
||||||
dbinstance_id=self.config.instance_id,
|
dbinstance_id=self.config.instance_id,
|
||||||
region_id=self.config.region_id,
|
region_id=self.config.region_id,
|
||||||
@ -267,7 +267,7 @@ class AnalyticdbVector(BaseVector):
|
|||||||
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
|
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
|
||||||
from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
|
from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
|
||||||
|
|
||||||
score_threshold = kwargs.get("score_threshold", 0.0) if kwargs.get("score_threshold", 0.0) else 0.0
|
score_threshold = kwargs.get("score_threshold", 0.0)
|
||||||
request = gpdb_20160503_models.QueryCollectionDataRequest(
|
request = gpdb_20160503_models.QueryCollectionDataRequest(
|
||||||
dbinstance_id=self.config.instance_id,
|
dbinstance_id=self.config.instance_id,
|
||||||
region_id=self.config.region_id,
|
region_id=self.config.region_id,
|
||||||
|
|||||||
@ -92,7 +92,7 @@ class ChromaVector(BaseVector):
|
|||||||
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
|
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
|
||||||
collection = self._client.get_or_create_collection(self._collection_name)
|
collection = self._client.get_or_create_collection(self._collection_name)
|
||||||
results: QueryResult = collection.query(query_embeddings=query_vector, n_results=kwargs.get("top_k", 4))
|
results: QueryResult = collection.query(query_embeddings=query_vector, n_results=kwargs.get("top_k", 4))
|
||||||
score_threshold = kwargs.get("score_threshold", 0.0) if kwargs.get("score_threshold", 0.0) else 0.0
|
score_threshold = kwargs.get("score_threshold", 0.0)
|
||||||
|
|
||||||
ids: list[str] = results["ids"][0]
|
ids: list[str] = results["ids"][0]
|
||||||
documents: list[str] = results["documents"][0]
|
documents: list[str] = results["documents"][0]
|
||||||
|
|||||||
@ -86,8 +86,8 @@ class ElasticSearchVector(BaseVector):
|
|||||||
id=uuids[i],
|
id=uuids[i],
|
||||||
document={
|
document={
|
||||||
Field.CONTENT_KEY.value: documents[i].page_content,
|
Field.CONTENT_KEY.value: documents[i].page_content,
|
||||||
Field.VECTOR.value: embeddings[i] if embeddings[i] else None,
|
Field.VECTOR.value: embeddings[i] or None,
|
||||||
Field.METADATA_KEY.value: documents[i].metadata if documents[i].metadata else {},
|
Field.METADATA_KEY.value: documents[i].metadata or {},
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
self._client.indices.refresh(index=self._collection_name)
|
self._client.indices.refresh(index=self._collection_name)
|
||||||
@ -131,7 +131,7 @@ class ElasticSearchVector(BaseVector):
|
|||||||
|
|
||||||
docs = []
|
docs = []
|
||||||
for doc, score in docs_and_scores:
|
for doc, score in docs_and_scores:
|
||||||
score_threshold = kwargs.get("score_threshold", 0.0) if kwargs.get("score_threshold", 0.0) else 0.0
|
score_threshold = kwargs.get("score_threshold", 0.0)
|
||||||
if score > score_threshold:
|
if score > score_threshold:
|
||||||
doc.metadata["score"] = score
|
doc.metadata["score"] = score
|
||||||
docs.append(doc)
|
docs.append(doc)
|
||||||
|
|||||||
@ -141,7 +141,7 @@ class MilvusVector(BaseVector):
|
|||||||
for result in results[0]:
|
for result in results[0]:
|
||||||
metadata = result["entity"].get(Field.METADATA_KEY.value)
|
metadata = result["entity"].get(Field.METADATA_KEY.value)
|
||||||
metadata["score"] = result["distance"]
|
metadata["score"] = result["distance"]
|
||||||
score_threshold = kwargs.get("score_threshold") if kwargs.get("score_threshold") else 0.0
|
score_threshold = kwargs.get("score_threshold", 0.0)
|
||||||
if result["distance"] > score_threshold:
|
if result["distance"] > score_threshold:
|
||||||
doc = Document(page_content=result["entity"].get(Field.CONTENT_KEY.value), metadata=metadata)
|
doc = Document(page_content=result["entity"].get(Field.CONTENT_KEY.value), metadata=metadata)
|
||||||
docs.append(doc)
|
docs.append(doc)
|
||||||
|
|||||||
@ -122,7 +122,7 @@ class MyScaleVector(BaseVector):
|
|||||||
|
|
||||||
def _search(self, dist: str, order: SortOrder, **kwargs: Any) -> list[Document]:
|
def _search(self, dist: str, order: SortOrder, **kwargs: Any) -> list[Document]:
|
||||||
top_k = kwargs.get("top_k", 5)
|
top_k = kwargs.get("top_k", 5)
|
||||||
score_threshold = kwargs.get("score_threshold") or 0.0
|
score_threshold = kwargs.get("score_threshold", 0.0)
|
||||||
where_str = (
|
where_str = (
|
||||||
f"WHERE dist < {1 - score_threshold}"
|
f"WHERE dist < {1 - score_threshold}"
|
||||||
if self._metric.upper() == "COSINE" and order == SortOrder.ASC and score_threshold > 0.0
|
if self._metric.upper() == "COSINE" and order == SortOrder.ASC and score_threshold > 0.0
|
||||||
|
|||||||
@ -170,7 +170,7 @@ class OpenSearchVector(BaseVector):
|
|||||||
metadata = {}
|
metadata = {}
|
||||||
|
|
||||||
metadata["score"] = hit["_score"]
|
metadata["score"] = hit["_score"]
|
||||||
score_threshold = kwargs.get("score_threshold") if kwargs.get("score_threshold") else 0.0
|
score_threshold = kwargs.get("score_threshold", 0.0)
|
||||||
if hit["_score"] > score_threshold:
|
if hit["_score"] > score_threshold:
|
||||||
doc = Document(page_content=hit["_source"].get(Field.CONTENT_KEY.value), metadata=metadata)
|
doc = Document(page_content=hit["_source"].get(Field.CONTENT_KEY.value), metadata=metadata)
|
||||||
docs.append(doc)
|
docs.append(doc)
|
||||||
|
|||||||
@ -195,11 +195,12 @@ class OracleVector(BaseVector):
|
|||||||
top_k = kwargs.get("top_k", 5)
|
top_k = kwargs.get("top_k", 5)
|
||||||
with self._get_cursor() as cur:
|
with self._get_cursor() as cur:
|
||||||
cur.execute(
|
cur.execute(
|
||||||
f"SELECT meta, text, vector_distance(embedding,:1) AS distance FROM {self.table_name} ORDER BY distance fetch first {top_k} rows only",
|
f"SELECT meta, text, vector_distance(embedding,:1) AS distance FROM {self.table_name}"
|
||||||
|
f" ORDER BY distance fetch first {top_k} rows only",
|
||||||
[numpy.array(query_vector)],
|
[numpy.array(query_vector)],
|
||||||
)
|
)
|
||||||
docs = []
|
docs = []
|
||||||
score_threshold = kwargs.get("score_threshold") if kwargs.get("score_threshold") else 0.0
|
score_threshold = kwargs.get("score_threshold", 0.0)
|
||||||
for record in cur:
|
for record in cur:
|
||||||
metadata, text, distance = record
|
metadata, text, distance = record
|
||||||
score = 1 - distance
|
score = 1 - distance
|
||||||
@ -211,7 +212,7 @@ class OracleVector(BaseVector):
|
|||||||
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
|
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
|
||||||
top_k = kwargs.get("top_k", 5)
|
top_k = kwargs.get("top_k", 5)
|
||||||
# just not implement fetch by score_threshold now, may be later
|
# just not implement fetch by score_threshold now, may be later
|
||||||
score_threshold = kwargs.get("score_threshold") if kwargs.get("score_threshold") else 0.0
|
score_threshold = kwargs.get("score_threshold", 0.0)
|
||||||
if len(query) > 0:
|
if len(query) > 0:
|
||||||
# Check which language the query is in
|
# Check which language the query is in
|
||||||
zh_pattern = re.compile("[\u4e00-\u9fa5]+")
|
zh_pattern = re.compile("[\u4e00-\u9fa5]+")
|
||||||
@ -254,7 +255,8 @@ class OracleVector(BaseVector):
|
|||||||
entities.append(token)
|
entities.append(token)
|
||||||
with self._get_cursor() as cur:
|
with self._get_cursor() as cur:
|
||||||
cur.execute(
|
cur.execute(
|
||||||
f"select meta, text, embedding FROM {self.table_name} WHERE CONTAINS(text, :1, 1) > 0 order by score(1) desc fetch first {top_k} rows only",
|
f"select meta, text, embedding FROM {self.table_name}"
|
||||||
|
f" WHERE CONTAINS(text, :1, 1) > 0 order by score(1) desc fetch first {top_k} rows only",
|
||||||
[" ACCUM ".join(entities)],
|
[" ACCUM ".join(entities)],
|
||||||
)
|
)
|
||||||
docs = []
|
docs = []
|
||||||
|
|||||||
@ -198,7 +198,7 @@ class PGVectoRS(BaseVector):
|
|||||||
metadata = record.meta
|
metadata = record.meta
|
||||||
score = 1 - dis
|
score = 1 - dis
|
||||||
metadata["score"] = score
|
metadata["score"] = score
|
||||||
score_threshold = kwargs.get("score_threshold") if kwargs.get("score_threshold") else 0.0
|
score_threshold = kwargs.get("score_threshold", 0.0)
|
||||||
if score > score_threshold:
|
if score > score_threshold:
|
||||||
doc = Document(page_content=record.text, metadata=metadata)
|
doc = Document(page_content=record.text, metadata=metadata)
|
||||||
docs.append(doc)
|
docs.append(doc)
|
||||||
|
|||||||
@ -139,11 +139,12 @@ class PGVector(BaseVector):
|
|||||||
|
|
||||||
with self._get_cursor() as cur:
|
with self._get_cursor() as cur:
|
||||||
cur.execute(
|
cur.execute(
|
||||||
f"SELECT meta, text, embedding <=> %s AS distance FROM {self.table_name} ORDER BY distance LIMIT {top_k}",
|
f"SELECT meta, text, embedding <=> %s AS distance FROM {self.table_name}"
|
||||||
|
f" ORDER BY distance LIMIT {top_k}",
|
||||||
(json.dumps(query_vector),),
|
(json.dumps(query_vector),),
|
||||||
)
|
)
|
||||||
docs = []
|
docs = []
|
||||||
score_threshold = kwargs.get("score_threshold") if kwargs.get("score_threshold") else 0.0
|
score_threshold = kwargs.get("score_threshold", 0.0)
|
||||||
for record in cur:
|
for record in cur:
|
||||||
metadata, text, distance = record
|
metadata, text, distance = record
|
||||||
score = 1 - distance
|
score = 1 - distance
|
||||||
|
|||||||
@ -339,7 +339,7 @@ class QdrantVector(BaseVector):
|
|||||||
for result in results:
|
for result in results:
|
||||||
metadata = result.payload.get(Field.METADATA_KEY.value) or {}
|
metadata = result.payload.get(Field.METADATA_KEY.value) or {}
|
||||||
# duplicate check score threshold
|
# duplicate check score threshold
|
||||||
score_threshold = kwargs.get("score_threshold", 0.0) if kwargs.get("score_threshold", 0.0) else 0.0
|
score_threshold = kwargs.get("score_threshold", 0.0)
|
||||||
if result.score > score_threshold:
|
if result.score > score_threshold:
|
||||||
metadata["score"] = result.score
|
metadata["score"] = result.score
|
||||||
doc = Document(
|
doc = Document(
|
||||||
|
|||||||
@ -127,27 +127,26 @@ class RelytVector(BaseVector):
|
|||||||
)
|
)
|
||||||
|
|
||||||
chunks_table_data = []
|
chunks_table_data = []
|
||||||
with self.client.connect() as conn:
|
with self.client.connect() as conn, conn.begin():
|
||||||
with conn.begin():
|
for document, metadata, chunk_id, embedding in zip(texts, metadatas, ids, embeddings):
|
||||||
for document, metadata, chunk_id, embedding in zip(texts, metadatas, ids, embeddings):
|
chunks_table_data.append(
|
||||||
chunks_table_data.append(
|
{
|
||||||
{
|
"id": chunk_id,
|
||||||
"id": chunk_id,
|
"embedding": embedding,
|
||||||
"embedding": embedding,
|
"document": document,
|
||||||
"document": document,
|
"metadata": metadata,
|
||||||
"metadata": metadata,
|
}
|
||||||
}
|
)
|
||||||
)
|
|
||||||
|
|
||||||
# Execute the batch insert when the batch size is reached
|
# Execute the batch insert when the batch size is reached
|
||||||
if len(chunks_table_data) == 500:
|
if len(chunks_table_data) == 500:
|
||||||
conn.execute(insert(chunks_table).values(chunks_table_data))
|
|
||||||
# Clear the chunks_table_data list for the next batch
|
|
||||||
chunks_table_data.clear()
|
|
||||||
|
|
||||||
# Insert any remaining records that didn't make up a full batch
|
|
||||||
if chunks_table_data:
|
|
||||||
conn.execute(insert(chunks_table).values(chunks_table_data))
|
conn.execute(insert(chunks_table).values(chunks_table_data))
|
||||||
|
# Clear the chunks_table_data list for the next batch
|
||||||
|
chunks_table_data.clear()
|
||||||
|
|
||||||
|
# Insert any remaining records that didn't make up a full batch
|
||||||
|
if chunks_table_data:
|
||||||
|
conn.execute(insert(chunks_table).values(chunks_table_data))
|
||||||
|
|
||||||
return ids
|
return ids
|
||||||
|
|
||||||
@ -186,11 +185,10 @@ class RelytVector(BaseVector):
|
|||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
with self.client.connect() as conn:
|
with self.client.connect() as conn, conn.begin():
|
||||||
with conn.begin():
|
delete_condition = chunks_table.c.id.in_(ids)
|
||||||
delete_condition = chunks_table.c.id.in_(ids)
|
conn.execute(chunks_table.delete().where(delete_condition))
|
||||||
conn.execute(chunks_table.delete().where(delete_condition))
|
return True
|
||||||
return True
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print("Delete operation failed:", str(e))
|
print("Delete operation failed:", str(e))
|
||||||
return False
|
return False
|
||||||
@ -232,7 +230,7 @@ class RelytVector(BaseVector):
|
|||||||
# Organize results.
|
# Organize results.
|
||||||
docs = []
|
docs = []
|
||||||
for document, score in results:
|
for document, score in results:
|
||||||
score_threshold = kwargs.get("score_threshold") if kwargs.get("score_threshold") else 0.0
|
score_threshold = kwargs.get("score_threshold", 0.0)
|
||||||
if 1 - score > score_threshold:
|
if 1 - score > score_threshold:
|
||||||
docs.append(document)
|
docs.append(document)
|
||||||
return docs
|
return docs
|
||||||
|
|||||||
@ -63,10 +63,7 @@ class TencentVector(BaseVector):
|
|||||||
|
|
||||||
def _has_collection(self) -> bool:
|
def _has_collection(self) -> bool:
|
||||||
collections = self._db.list_collections()
|
collections = self._db.list_collections()
|
||||||
for collection in collections:
|
return any(collection.collection_name == self._collection_name for collection in collections)
|
||||||
if collection.collection_name == self._collection_name:
|
|
||||||
return True
|
|
||||||
return False
|
|
||||||
|
|
||||||
def _create_collection(self, dimension: int) -> None:
|
def _create_collection(self, dimension: int) -> None:
|
||||||
lock_name = "vector_indexing_lock_{}".format(self._collection_name)
|
lock_name = "vector_indexing_lock_{}".format(self._collection_name)
|
||||||
@ -156,7 +153,7 @@ class TencentVector(BaseVector):
|
|||||||
limit=kwargs.get("top_k", 4),
|
limit=kwargs.get("top_k", 4),
|
||||||
timeout=self._client_config.timeout,
|
timeout=self._client_config.timeout,
|
||||||
)
|
)
|
||||||
score_threshold = kwargs.get("score_threshold", 0.0) if kwargs.get("score_threshold", 0.0) else 0.0
|
score_threshold = kwargs.get("score_threshold", 0.0)
|
||||||
return self._get_search_res(res, score_threshold)
|
return self._get_search_res(res, score_threshold)
|
||||||
|
|
||||||
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
|
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
|
||||||
|
|||||||
@ -124,20 +124,19 @@ class TiDBVector(BaseVector):
|
|||||||
texts = [d.page_content for d in documents]
|
texts = [d.page_content for d in documents]
|
||||||
|
|
||||||
chunks_table_data = []
|
chunks_table_data = []
|
||||||
with self._engine.connect() as conn:
|
with self._engine.connect() as conn, conn.begin():
|
||||||
with conn.begin():
|
for id, text, meta, embedding in zip(ids, texts, metas, embeddings):
|
||||||
for id, text, meta, embedding in zip(ids, texts, metas, embeddings):
|
chunks_table_data.append({"id": id, "vector": embedding, "text": text, "meta": meta})
|
||||||
chunks_table_data.append({"id": id, "vector": embedding, "text": text, "meta": meta})
|
|
||||||
|
|
||||||
# Execute the batch insert when the batch size is reached
|
# Execute the batch insert when the batch size is reached
|
||||||
if len(chunks_table_data) == 500:
|
if len(chunks_table_data) == 500:
|
||||||
conn.execute(insert(table).values(chunks_table_data))
|
|
||||||
# Clear the chunks_table_data list for the next batch
|
|
||||||
chunks_table_data.clear()
|
|
||||||
|
|
||||||
# Insert any remaining records that didn't make up a full batch
|
|
||||||
if chunks_table_data:
|
|
||||||
conn.execute(insert(table).values(chunks_table_data))
|
conn.execute(insert(table).values(chunks_table_data))
|
||||||
|
# Clear the chunks_table_data list for the next batch
|
||||||
|
chunks_table_data.clear()
|
||||||
|
|
||||||
|
# Insert any remaining records that didn't make up a full batch
|
||||||
|
if chunks_table_data:
|
||||||
|
conn.execute(insert(table).values(chunks_table_data))
|
||||||
return ids
|
return ids
|
||||||
|
|
||||||
def text_exists(self, id: str) -> bool:
|
def text_exists(self, id: str) -> bool:
|
||||||
@ -160,11 +159,10 @@ class TiDBVector(BaseVector):
|
|||||||
raise ValueError("No ids provided to delete.")
|
raise ValueError("No ids provided to delete.")
|
||||||
table = self._table(self._dimension)
|
table = self._table(self._dimension)
|
||||||
try:
|
try:
|
||||||
with self._engine.connect() as conn:
|
with self._engine.connect() as conn, conn.begin():
|
||||||
with conn.begin():
|
delete_condition = table.c.id.in_(ids)
|
||||||
delete_condition = table.c.id.in_(ids)
|
conn.execute(table.delete().where(delete_condition))
|
||||||
conn.execute(table.delete().where(delete_condition))
|
return True
|
||||||
return True
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print("Delete operation failed:", str(e))
|
print("Delete operation failed:", str(e))
|
||||||
return False
|
return False
|
||||||
@ -187,7 +185,7 @@ class TiDBVector(BaseVector):
|
|||||||
|
|
||||||
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
|
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
|
||||||
top_k = kwargs.get("top_k", 5)
|
top_k = kwargs.get("top_k", 5)
|
||||||
score_threshold = kwargs.get("score_threshold") if kwargs.get("score_threshold") else 0.0
|
score_threshold = kwargs.get("score_threshold", 0.0)
|
||||||
filter = kwargs.get("filter")
|
filter = kwargs.get("filter")
|
||||||
distance = 1 - score_threshold
|
distance = 1 - score_threshold
|
||||||
|
|
||||||
|
|||||||
@ -49,7 +49,7 @@ class BaseVector(ABC):
|
|||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
def _filter_duplicate_texts(self, texts: list[Document]) -> list[Document]:
|
def _filter_duplicate_texts(self, texts: list[Document]) -> list[Document]:
|
||||||
for text in texts[:]:
|
for text in texts.copy():
|
||||||
doc_id = text.metadata["doc_id"]
|
doc_id = text.metadata["doc_id"]
|
||||||
exists_duplicate_node = self.text_exists(doc_id)
|
exists_duplicate_node = self.text_exists(doc_id)
|
||||||
if exists_duplicate_node:
|
if exists_duplicate_node:
|
||||||
|
|||||||
@ -153,7 +153,7 @@ class Vector:
|
|||||||
return CacheEmbedding(embedding_model)
|
return CacheEmbedding(embedding_model)
|
||||||
|
|
||||||
def _filter_duplicate_texts(self, texts: list[Document]) -> list[Document]:
|
def _filter_duplicate_texts(self, texts: list[Document]) -> list[Document]:
|
||||||
for text in texts[:]:
|
for text in texts.copy():
|
||||||
doc_id = text.metadata["doc_id"]
|
doc_id = text.metadata["doc_id"]
|
||||||
exists_duplicate_node = self.text_exists(doc_id)
|
exists_duplicate_node = self.text_exists(doc_id)
|
||||||
if exists_duplicate_node:
|
if exists_duplicate_node:
|
||||||
|
|||||||
@ -205,7 +205,7 @@ class WeaviateVector(BaseVector):
|
|||||||
|
|
||||||
docs = []
|
docs = []
|
||||||
for doc, score in docs_and_scores:
|
for doc, score in docs_and_scores:
|
||||||
score_threshold = kwargs.get("score_threshold", 0.0) if kwargs.get("score_threshold", 0.0) else 0.0
|
score_threshold = kwargs.get("score_threshold", 0.0)
|
||||||
# check score threshold
|
# check score threshold
|
||||||
if score > score_threshold:
|
if score > score_threshold:
|
||||||
doc.metadata["score"] = score
|
doc.metadata["score"] = score
|
||||||
|
|||||||
@ -12,7 +12,7 @@ import mimetypes
|
|||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from collections.abc import Generator, Iterable, Mapping
|
from collections.abc import Generator, Iterable, Mapping
|
||||||
from io import BufferedReader, BytesIO
|
from io import BufferedReader, BytesIO
|
||||||
from pathlib import PurePath
|
from pathlib import Path, PurePath
|
||||||
from typing import Any, Optional, Union
|
from typing import Any, Optional, Union
|
||||||
|
|
||||||
from pydantic import BaseModel, ConfigDict, model_validator
|
from pydantic import BaseModel, ConfigDict, model_validator
|
||||||
@ -56,8 +56,7 @@ class Blob(BaseModel):
|
|||||||
def as_string(self) -> str:
|
def as_string(self) -> str:
|
||||||
"""Read data as a string."""
|
"""Read data as a string."""
|
||||||
if self.data is None and self.path:
|
if self.data is None and self.path:
|
||||||
with open(str(self.path), encoding=self.encoding) as f:
|
return Path(str(self.path)).read_text(encoding=self.encoding)
|
||||||
return f.read()
|
|
||||||
elif isinstance(self.data, bytes):
|
elif isinstance(self.data, bytes):
|
||||||
return self.data.decode(self.encoding)
|
return self.data.decode(self.encoding)
|
||||||
elif isinstance(self.data, str):
|
elif isinstance(self.data, str):
|
||||||
@ -72,8 +71,7 @@ class Blob(BaseModel):
|
|||||||
elif isinstance(self.data, str):
|
elif isinstance(self.data, str):
|
||||||
return self.data.encode(self.encoding)
|
return self.data.encode(self.encoding)
|
||||||
elif self.data is None and self.path:
|
elif self.data is None and self.path:
|
||||||
with open(str(self.path), "rb") as f:
|
return Path(str(self.path)).read_bytes()
|
||||||
return f.read()
|
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Unable to get bytes for blob {self}")
|
raise ValueError(f"Unable to get bytes for blob {self}")
|
||||||
|
|
||||||
|
|||||||
@ -30,7 +30,10 @@ from extensions.ext_storage import storage
|
|||||||
from models.model import UploadFile
|
from models.model import UploadFile
|
||||||
|
|
||||||
SUPPORT_URL_CONTENT_TYPES = ["application/pdf", "text/plain", "application/json"]
|
SUPPORT_URL_CONTENT_TYPES = ["application/pdf", "text/plain", "application/json"]
|
||||||
USER_AGENT = "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36"
|
USER_AGENT = (
|
||||||
|
"Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124"
|
||||||
|
" Safari/537.36"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class ExtractProcessor:
|
class ExtractProcessor:
|
||||||
@ -65,8 +68,7 @@ class ExtractProcessor:
|
|||||||
suffix = "." + re.search(r"\.(\w+)$", filename).group(1)
|
suffix = "." + re.search(r"\.(\w+)$", filename).group(1)
|
||||||
|
|
||||||
file_path = f"{temp_dir}/{next(tempfile._get_candidate_names())}{suffix}"
|
file_path = f"{temp_dir}/{next(tempfile._get_candidate_names())}{suffix}"
|
||||||
with open(file_path, "wb") as file:
|
Path(file_path).write_bytes(response.content)
|
||||||
file.write(response.content)
|
|
||||||
extract_setting = ExtractSetting(datasource_type="upload_file", document_model="text_model")
|
extract_setting = ExtractSetting(datasource_type="upload_file", document_model="text_model")
|
||||||
if return_text:
|
if return_text:
|
||||||
delimiter = "\n"
|
delimiter = "\n"
|
||||||
@ -108,7 +110,7 @@ class ExtractProcessor:
|
|||||||
)
|
)
|
||||||
elif file_extension in [".htm", ".html"]:
|
elif file_extension in [".htm", ".html"]:
|
||||||
extractor = HtmlExtractor(file_path)
|
extractor = HtmlExtractor(file_path)
|
||||||
elif file_extension in [".docx"]:
|
elif file_extension == ".docx":
|
||||||
extractor = WordExtractor(file_path, upload_file.tenant_id, upload_file.created_by)
|
extractor = WordExtractor(file_path, upload_file.tenant_id, upload_file.created_by)
|
||||||
elif file_extension == ".csv":
|
elif file_extension == ".csv":
|
||||||
extractor = CSVExtractor(file_path, autodetect_encoding=True)
|
extractor = CSVExtractor(file_path, autodetect_encoding=True)
|
||||||
@ -140,7 +142,7 @@ class ExtractProcessor:
|
|||||||
extractor = MarkdownExtractor(file_path, autodetect_encoding=True)
|
extractor = MarkdownExtractor(file_path, autodetect_encoding=True)
|
||||||
elif file_extension in [".htm", ".html"]:
|
elif file_extension in [".htm", ".html"]:
|
||||||
extractor = HtmlExtractor(file_path)
|
extractor = HtmlExtractor(file_path)
|
||||||
elif file_extension in [".docx"]:
|
elif file_extension == ".docx":
|
||||||
extractor = WordExtractor(file_path, upload_file.tenant_id, upload_file.created_by)
|
extractor = WordExtractor(file_path, upload_file.tenant_id, upload_file.created_by)
|
||||||
elif file_extension == ".csv":
|
elif file_extension == ".csv":
|
||||||
extractor = CSVExtractor(file_path, autodetect_encoding=True)
|
extractor = CSVExtractor(file_path, autodetect_encoding=True)
|
||||||
|
|||||||
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue
Block a user