diff --git a/.github/dependabot.yml b/.github/dependabot.yml new file mode 100644 index 00000000000..a6a922a221b --- /dev/null +++ b/.github/dependabot.yml @@ -0,0 +1,17 @@ +version: 2 +updates: + - package-ecosystem: "pip" + directory: "/" + schedule: + interval: "weekly" + timezone: "Asia/Shanghai" + day: "friday" + target-branch: "v2" + groups: + python-dependencies: + patterns: + - "*" +# ignore: +# - dependency-name: "pymupdf" +# versions: ["*"] + diff --git a/.github/workflows/build-and-push-python-pg.yml b/.github/workflows/build-and-push-python-pg.yml index bc4dc3f2c77..4640f5edbd0 100644 --- a/.github/workflows/build-and-push-python-pg.yml +++ b/.github/workflows/build-and-push-python-pg.yml @@ -33,13 +33,13 @@ jobs: - name: Checkout uses: actions/checkout@v4 with: - ref: main + ref: v1 - name: Prepare id: prepare run: | DOCKER_IMAGE=ghcr.io/1panel-dev/maxkb-python-pg DOCKER_PLATFORMS=${{ github.event.inputs.architecture }} - TAG_NAME=python3.11-pg15.8 + TAG_NAME=python3.11-pg15.14 DOCKER_IMAGE_TAGS="--tag ${DOCKER_IMAGE}:${TAG_NAME} --tag ${DOCKER_IMAGE}:latest" echo ::set-output name=docker_image::${DOCKER_IMAGE} echo ::set-output name=version::${TAG_NAME} diff --git a/.github/workflows/build-and-push.yml b/.github/workflows/build-and-push.yml index 26d2b86d297..1e1daf2696c 100644 --- a/.github/workflows/build-and-push.yml +++ b/.github/workflows/build-and-push.yml @@ -7,7 +7,7 @@ on: inputs: dockerImageTag: description: 'Image Tag' - default: 'v1.10.3-dev' + default: 'v1.10.7-dev' required: true dockerImageTagWithLatest: description: '是否发布latest tag(正式发版时选择,测试版本切勿选择)' @@ -36,7 +36,7 @@ on: jobs: build-and-push-to-fit2cloud-registry: if: ${{ contains(github.event.inputs.registry, 'fit2cloud') }} - runs-on: ubuntu-22.04 + runs-on: ubuntu-latest steps: - name: Check Disk Space run: df -h @@ -52,10 +52,6 @@ jobs: swap-storage: true - name: Check Disk Space run: df -h - - name: Set Swap Space - uses: pierotofy/set-swap-space@master - with: - swap-size-gb: 8 - name: Checkout uses: actions/checkout@v4 with: @@ -68,24 +64,17 @@ jobs: TAG_NAME=${{ github.event.inputs.dockerImageTag }} TAG_NAME_WITH_LATEST=${{ github.event.inputs.dockerImageTagWithLatest }} if [[ ${TAG_NAME_WITH_LATEST} == 'true' ]]; then - DOCKER_IMAGE_TAGS="--tag ${DOCKER_IMAGE}:${TAG_NAME} --tag ${DOCKER_IMAGE}:latest" + DOCKER_IMAGE_TAGS="--tag ${DOCKER_IMAGE}:${TAG_NAME} --tag ${DOCKER_IMAGE}:${TAG_NAME%%.*}" else DOCKER_IMAGE_TAGS="--tag ${DOCKER_IMAGE}:${TAG_NAME}" fi echo ::set-output name=buildx_args::--platform ${DOCKER_PLATFORMS} --memory-swap -1 \ - --build-arg DOCKER_IMAGE_TAG=${{ github.event.inputs.dockerImageTag }} --build-arg BUILD_AT=$(TZ=Asia/Shanghai date +'%Y-%m-%dT%H:%M') --build-arg GITHUB_COMMIT=${GITHUB_SHA::8} --no-cache \ + --build-arg DOCKER_IMAGE_TAG=${{ github.event.inputs.dockerImageTag }} --build-arg BUILD_AT=$(TZ=Asia/Shanghai date +'%Y-%m-%dT%H:%M') --build-arg GITHUB_COMMIT=`git rev-parse --short HEAD` --no-cache \ ${DOCKER_IMAGE_TAGS} . - name: Set up QEMU uses: docker/setup-qemu-action@v3 - with: - # Until https://github.com/tonistiigi/binfmt/issues/215 - image: tonistiigi/binfmt:qemu-v7.0.0-28 - name: Set up Docker Buildx uses: docker/setup-buildx-action@v3 - with: - buildkitd-config-inline: | - [worker.oci] - max-parallelism = 1 - name: Login to GitHub Container Registry uses: docker/login-action@v3 with: @@ -100,11 +89,12 @@ jobs: password: ${{ secrets.FIT2CLOUD_REGISTRY_PASSWORD }} - name: Docker Buildx (build-and-push) run: | + sudo sync && echo 3 | sudo tee /proc/sys/vm/drop_caches && free -m docker buildx build --output "type=image,push=true" ${{ steps.prepare.outputs.buildx_args }} -f installer/Dockerfile build-and-push-to-dockerhub: if: ${{ contains(github.event.inputs.registry, 'dockerhub') }} - runs-on: ubuntu-22.04 + runs-on: ubuntu-latest steps: - name: Check Disk Space run: df -h @@ -120,10 +110,6 @@ jobs: swap-storage: true - name: Check Disk Space run: df -h - - name: Set Swap Space - uses: pierotofy/set-swap-space@master - with: - swap-size-gb: 8 - name: Checkout uses: actions/checkout@v4 with: @@ -136,24 +122,17 @@ jobs: TAG_NAME=${{ github.event.inputs.dockerImageTag }} TAG_NAME_WITH_LATEST=${{ github.event.inputs.dockerImageTagWithLatest }} if [[ ${TAG_NAME_WITH_LATEST} == 'true' ]]; then - DOCKER_IMAGE_TAGS="--tag ${DOCKER_IMAGE}:${TAG_NAME} --tag ${DOCKER_IMAGE}:latest" + DOCKER_IMAGE_TAGS="--tag ${DOCKER_IMAGE}:${TAG_NAME} --tag ${DOCKER_IMAGE}:${TAG_NAME%%.*}" else DOCKER_IMAGE_TAGS="--tag ${DOCKER_IMAGE}:${TAG_NAME}" fi echo ::set-output name=buildx_args::--platform ${DOCKER_PLATFORMS} --memory-swap -1 \ - --build-arg DOCKER_IMAGE_TAG=${{ github.event.inputs.dockerImageTag }} --build-arg BUILD_AT=$(TZ=Asia/Shanghai date +'%Y-%m-%dT%H:%M') --build-arg GITHUB_COMMIT=${GITHUB_SHA::8} --no-cache \ + --build-arg DOCKER_IMAGE_TAG=${{ github.event.inputs.dockerImageTag }} --build-arg BUILD_AT=$(TZ=Asia/Shanghai date +'%Y-%m-%dT%H:%M') --build-arg GITHUB_COMMIT=`git rev-parse --short HEAD` --no-cache \ ${DOCKER_IMAGE_TAGS} . - name: Set up QEMU uses: docker/setup-qemu-action@v3 - with: - # Until https://github.com/tonistiigi/binfmt/issues/215 - image: tonistiigi/binfmt:qemu-v7.0.0-28 - name: Set up Docker Buildx uses: docker/setup-buildx-action@v3 - with: - buildkitd-config-inline: | - [worker.oci] - max-parallelism = 1 - name: Login to GitHub Container Registry uses: docker/login-action@v3 with: @@ -167,4 +146,5 @@ jobs: password: ${{ secrets.DOCKERHUB_TOKEN }} - name: Docker Buildx (build-and-push) run: | + sudo sync && echo 3 | sudo tee /proc/sys/vm/drop_caches && free -m docker buildx build --output "type=image,push=true" ${{ steps.prepare.outputs.buildx_args }} -f installer/Dockerfile diff --git a/README.md b/README.md index cfe819e56ff..b4a925edb64 100644 --- a/README.md +++ b/README.md @@ -1,5 +1,6 @@

MaxKB

-

Ready-to-use AI Chatbot

+

Open-source platform for building enterprise-grade agents

+

强大易用的企业级智能体平台

1Panel-dev%2FMaxKB | Trendshift

License: GPL v3 @@ -10,10 +11,10 @@


-MaxKB = Max Knowledge Base, it is a ready-to-use AI chatbot that integrates Retrieval-Augmented Generation (RAG) pipelines, supports robust workflows, and provides advanced MCP tool-use capabilities. MaxKB is widely applied in scenarios such as intelligent customer service, corporate internal knowledge bases, academic research, and education. +MaxKB = Max Knowledge Brain, it is an open-source platform for building enterprise-grade agents. MaxKB integrates Retrieval-Augmented Generation (RAG) pipelines, supports robust workflows, and provides advanced MCP tool-use capabilities. MaxKB is widely applied in scenarios such as intelligent customer service, corporate internal knowledge bases, academic research, and education. -- **RAG Pipeline**: Supports direct uploading of documents / automatic crawling of online documents, with features for automatic text splitting, vectorization, and RAG (Retrieval-Augmented Generation). This effectively reduces hallucinations in large models, providing a superior smart Q&A interaction experience. -- **Flexible Orchestration**: Equipped with a powerful workflow engine, function library and MCP tool-use, enabling the orchestration of AI processes to meet the needs of complex business scenarios. +- **RAG Pipeline**: Supports direct uploading of documents / automatic crawling of online documents, with features for automatic text splitting, vectorization. This effectively reduces hallucinations in large models, providing a superior smart Q&A interaction experience. +- **Agentic Workflow**: Equipped with a powerful workflow engine, function library and MCP tool-use, enabling the orchestration of AI processes to meet the needs of complex business scenarios. - **Seamless Integration**: Facilitates zero-coding rapid integration into third-party business systems, quickly equipping existing systems with intelligent Q&A capabilities to enhance user satisfaction. - **Model-Agnostic**: Supports various large models, including private models (such as DeepSeek, Llama, Qwen, etc.) and public models (like OpenAI, Claude, Gemini, etc.). - **Multi Modal**: Native support for input and output text, image, audio and video. @@ -23,7 +24,7 @@ MaxKB = Max Knowledge Base, it is a ready-to-use AI chatbot that integrates Retr Execute the script below to start a MaxKB container using Docker: ```bash -docker run -d --name=maxkb --restart=always -p 8080:8080 -v ~/.maxkb:/var/lib/postgresql/data -v ~/.python-packages:/opt/maxkb/app/sandbox/python-packages 1panel/maxkb +docker run -d --name=maxkb --restart=always -p 8080:8080 -v ~/.maxkb:/var/lib/postgresql/data -v ~/.python-packages:/opt/maxkb/app/sandbox/python-packages 1panel/maxkb:v1 ``` Access MaxKB web interface at `http://your_server_ip:8080` with default admin credentials: @@ -31,7 +32,7 @@ Access MaxKB web interface at `http://your_server_ip:8080` with default admin cr - username: admin - password: MaxKB@123.. -中国用户如遇到 Docker 镜像 Pull 失败问题,请参照该 [离线安装文档](https://maxkb.cn/docs/installation/offline_installtion/) 进行安装。 +中国用户如遇到 Docker 镜像 Pull 失败问题,请参照该 [离线安装文档](https://maxkb.cn/docs/v1/installation/offline_installtion/) 进行安装。 ## Screenshots @@ -55,8 +56,6 @@ Access MaxKB web interface at `http://your_server_ip:8080` with default admin cr ## Feature Comparison -MaxKB is positioned as an Ready-to-use RAG (Retrieval-Augmented Generation) intelligent Q&A application, rather than a middleware platform for building large model applications. The following table is merely a comparison from a functional perspective. - diff --git a/README_CN.md b/README_CN.md index e55150902ea..07fa00ea4e6 100644 --- a/README_CN.md +++ b/README_CN.md @@ -1,25 +1,25 @@

MaxKB

-

基于大模型和 RAG 的知识库问答系统

-

Ready-to-use, flexible RAG Chatbot

+

强大易用的企业级智能体平台

1Panel-dev%2FMaxKB | Trendshift - 1Panel-dev%2FMaxKB | Aliyun

English README - License: GPL v3 + License: GPL v3 Latest release - Stars - Download + Stars + Download + Gitee Stars + GitCode Stars


-MaxKB = Max Knowledge Base,是一款开箱即用的 RAG Chatbot,具备强大的工作流和 MCP 工具调用能力。它支持对接各种主流大语言模型(LLMs),广泛应用于智能客服、企业内部知识库、学术研究与教育等场景。 +MaxKB = Max Knowledge Brain,是一个强大易用的企业级智能体平台,致力于解决企业 AI 落地面临的技术门槛高、部署成本高、迭代周期长等问题,助力企业在人工智能时代赢得先机。秉承“开箱即用,伴随成长”的设计理念,MaxKB 支持企业快速接入主流大模型,高效构建专属知识库,并提供从基础问答(RAG)、复杂流程自动化(工作流)到智能体(Agent)的渐进式升级路径,全面赋能智能客服、智能办公助手等多种应用场景。 -- **开箱即用**:支持直接上传文档 / 自动爬取在线文档,支持文本自动拆分、向量化和 RAG(检索增强生成),有效减少大模型幻觉,智能问答交互体验好; -- **模型中立**:支持对接各种大模型,包括本地私有大模型(DeepSeek R1 / Llama 3 / Qwen 2 等)、国内公共大模型(通义千问 / 腾讯混元 / 字节豆包 / 百度千帆 / 智谱 AI / Kimi 等)和国外公共大模型(OpenAI / Claude / Gemini 等); +- **RAG 检索增强生成**:高效搭建本地 AI 知识库,支持直接上传文档 / 自动爬取在线文档,支持文本自动拆分、向量化,有效减少大模型幻觉,提升问答效果; - **灵活编排**:内置强大的工作流引擎、函数库和 MCP 工具调用能力,支持编排 AI 工作过程,满足复杂业务场景下的需求; -- **无缝嵌入**:支持零编码快速嵌入到第三方业务系统,让已有系统快速拥有智能问答能力,提高用户满意度。 +- **无缝嵌入**:支持零编码快速嵌入到第三方业务系统,让已有系统快速拥有智能问答能力,提高用户满意度; +- **模型中立**:支持对接各种大模型,包括本地私有大模型(DeepSeek R1 / Qwen 3 等)、国内公共大模型(通义千问 / 腾讯混元 / 字节豆包 / 百度千帆 / 智谱 AI / Kimi 等)和国外公共大模型(OpenAI / Claude / Gemini 等)。 MaxKB 三分钟视频介绍:https://www.bilibili.com/video/BV18JypYeEkj/ @@ -27,10 +27,10 @@ MaxKB 三分钟视频介绍:https://www.bilibili.com/video/BV18JypYeEkj/ ``` # Linux 机器 -docker run -d --name=maxkb --restart=always -p 8080:8080 -v ~/.maxkb:/var/lib/postgresql/data -v ~/.python-packages:/opt/maxkb/app/sandbox/python-packages registry.fit2cloud.com/maxkb/maxkb +docker run -d --name=maxkb --restart=always -p 8080:8080 -v ~/.maxkb:/var/lib/postgresql/data -v ~/.python-packages:/opt/maxkb/app/sandbox/python-packages registry.fit2cloud.com/maxkb/maxkb:v1 # Windows 机器 -docker run -d --name=maxkb --restart=always -p 8080:8080 -v C:/maxkb:/var/lib/postgresql/data -v C:/python-packages:/opt/maxkb/app/sandbox/python-packages registry.fit2cloud.com/maxkb/maxkb +docker run -d --name=maxkb --restart=always -p 8080:8080 -v C:/maxkb:/var/lib/postgresql/data -v C:/python-packages:/opt/maxkb/app/sandbox/python-packages registry.fit2cloud.com/maxkb/maxkb:v1 # 用户名: admin # 密码: MaxKB@123.. @@ -38,8 +38,8 @@ docker run -d --name=maxkb --restart=always -p 8080:8080 -v C:/maxkb:/var/lib/po - 你也可以通过 [1Panel 应用商店](https://apps.fit2cloud.com/1panel) 快速部署 MaxKB; - 如果是内网环境,推荐使用 [离线安装包](https://community.fit2cloud.com/#/products/maxkb/downloads) 进行安装部署; -- MaxKB 产品版本分为社区版和专业版,详情请参见:[MaxKB 产品版本对比](https://maxkb.cn/pricing.html); -- 如果您需要向团队介绍 MaxKB,可以使用这个 [官方 PPT 材料](https://maxkb.cn/download/introduce-maxkb_202503.pdf)。 +- MaxKB 不同产品产品版本的对比请参见:[MaxKB 产品版本对比](https://maxkb.cn/price); +- 如果您需要向团队介绍 MaxKB,可以使用这个 [官方 PPT 材料](https://fit2cloud.com/maxkb/download/introduce-maxkb_202507.pdf)。 如你有更多问题,可以查看使用手册,或者通过论坛与我们交流。 diff --git a/apps/application/flow/step_node/ai_chat_step_node/impl/base_chat_node.py b/apps/application/flow/step_node/ai_chat_step_node/impl/base_chat_node.py index c5a0de1a152..05ab5009c0a 100644 --- a/apps/application/flow/step_node/ai_chat_step_node/impl/base_chat_node.py +++ b/apps/application/flow/step_node/ai_chat_step_node/impl/base_chat_node.py @@ -11,7 +11,6 @@ import re import time from functools import reduce -from types import AsyncGeneratorType from typing import List, Dict from django.db.models import QuerySet @@ -33,13 +32,26 @@ Called MCP Tool: %s -```json %s -``` + """ +tool_message_json_template = """ +```json +%s +``` +""" + + +def generate_tool_message_template(name, context): + if '```' in context: + return tool_message_template % (name, context) + else: + return tool_message_template % (name, tool_message_json_template % (context)) + + def _write_context(node_variable: Dict, workflow_variable: Dict, node: INode, workflow, answer: str, reasoning_content: str): chat_model = node_variable.get('chat_model') @@ -102,19 +114,19 @@ def write_context_stream(node_variable: Dict, workflow_variable: Dict, node: INo _write_context(node_variable, workflow_variable, node, workflow, answer, reasoning_content) - async def _yield_mcp_response(chat_model, message_list, mcp_servers): async with MultiServerMCPClient(json.loads(mcp_servers)) as client: agent = create_react_agent(chat_model, client.get_tools()) response = agent.astream({"messages": message_list}, stream_mode='messages') async for chunk in response: if isinstance(chunk[0], ToolMessage): - content = tool_message_template % (chunk[0].name, chunk[0].content) + content = generate_tool_message_template(chunk[0].name, chunk[0].content) chunk[0].content = content yield chunk[0] if isinstance(chunk[0], AIMessageChunk): yield chunk[0] + def mcp_response_generator(chat_model, message_list, mcp_servers): loop = asyncio.new_event_loop() try: @@ -130,6 +142,7 @@ def mcp_response_generator(chat_model, message_list, mcp_servers): finally: loop.close() + async def anext_async(agen): return await agen.__anext__() @@ -186,7 +199,9 @@ def save_context(self, details, workflow_manage): self.context['answer'] = details.get('answer') self.context['question'] = details.get('question') self.context['reasoning_content'] = details.get('reasoning_content') - self.answer_text = details.get('answer') + self.context['model_setting'] = details.get('model_setting') + if self.node_params.get('is_result', False): + self.answer_text = details.get('answer') def execute(self, model_id, system, prompt, dialogue_number, history_chat_record, stream, chat_id, chat_record_id, model_params_setting=None, @@ -216,7 +231,7 @@ def execute(self, model_id, system, prompt, dialogue_number, history_chat_record message_list = self.generate_message_list(system, prompt, history_message) self.context['message_list'] = message_list - if mcp_enable and mcp_servers is not None: + if mcp_enable and mcp_servers is not None and '"stdio"' not in mcp_servers: r = mcp_response_generator(chat_model, message_list, mcp_servers) return NodeResult( {'result': r, 'chat_model': chat_model, 'message_list': message_list, @@ -271,6 +286,7 @@ def get_details(self, index: int, **kwargs): "index": index, 'run_time': self.context.get('run_time'), 'system': self.context.get('system'), + 'model_setting': self.context.get('model_setting'), 'history_message': [{'content': message.content, 'role': message.type} for message in (self.context.get('history_message') if self.context.get( 'history_message') is not None else [])], diff --git a/apps/application/flow/step_node/application_node/impl/base_application_node.py b/apps/application/flow/step_node/application_node/impl/base_application_node.py index d962f7163bb..95445f45612 100644 --- a/apps/application/flow/step_node/application_node/impl/base_application_node.py +++ b/apps/application/flow/step_node/application_node/impl/base_application_node.py @@ -168,7 +168,8 @@ def save_context(self, details, workflow_manage): self.context['question'] = details.get('question') self.context['type'] = details.get('type') self.context['reasoning_content'] = details.get('reasoning_content') - self.answer_text = details.get('answer') + if self.node_params.get('is_result', False): + self.answer_text = details.get('answer') def execute(self, application_id, message, chat_id, chat_record_id, stream, re_chat, client_id, client_type, app_document_list=None, app_image_list=None, app_audio_list=None, child_node=None, node_data=None, @@ -178,7 +179,8 @@ def execute(self, application_id, message, chat_id, chat_record_id, stream, re_c current_chat_id = string_to_uuid(chat_id + application_id) Chat.objects.get_or_create(id=current_chat_id, defaults={ 'application_id': application_id, - 'abstract': message[0:1024] + 'abstract': message[0:1024], + 'client_id': client_id, }) if app_document_list is None: app_document_list = [] diff --git a/apps/application/flow/step_node/condition_node/compare/contain_compare.py b/apps/application/flow/step_node/condition_node/compare/contain_compare.py index 6073131a54d..044999ed918 100644 --- a/apps/application/flow/step_node/condition_node/compare/contain_compare.py +++ b/apps/application/flow/step_node/condition_node/compare/contain_compare.py @@ -20,4 +20,7 @@ def support(self, node_id, fields: List[str], source_value, compare, target_valu def compare(self, source_value, compare, target_value): if isinstance(source_value, str): return str(target_value) in source_value - return any([str(item) == str(target_value) for item in source_value]) + elif isinstance(source_value, list): + return any([str(item) == str(target_value) for item in source_value]) + else: + return str(target_value) in str(source_value) diff --git a/apps/application/flow/step_node/direct_reply_node/impl/base_reply_node.py b/apps/application/flow/step_node/direct_reply_node/impl/base_reply_node.py index 6a51edd6bae..1d3115e4c67 100644 --- a/apps/application/flow/step_node/direct_reply_node/impl/base_reply_node.py +++ b/apps/application/flow/step_node/direct_reply_node/impl/base_reply_node.py @@ -15,7 +15,9 @@ class BaseReplyNode(IReplyNode): def save_context(self, details, workflow_manage): self.context['answer'] = details.get('answer') - self.answer_text = details.get('answer') + if self.node_params.get('is_result', False): + self.answer_text = details.get('answer') + def execute(self, reply_type, stream, fields=None, content=None, **kwargs) -> NodeResult: if reply_type == 'referencing': result = self.get_reference_content(fields) diff --git a/apps/application/flow/step_node/document_extract_node/impl/base_document_extract_node.py b/apps/application/flow/step_node/document_extract_node/impl/base_document_extract_node.py index 6ddcb6e2fca..0c4d09bce5c 100644 --- a/apps/application/flow/step_node/document_extract_node/impl/base_document_extract_node.py +++ b/apps/application/flow/step_node/document_extract_node/impl/base_document_extract_node.py @@ -66,7 +66,7 @@ def save_image(image_list): for doc in document: file = QuerySet(File).filter(id=doc['file_id']).first() - buffer = io.BytesIO(file.get_byte().tobytes()) + buffer = io.BytesIO(file.get_byte()) buffer.name = doc['name'] # this is the important line for split_handle in (parse_table_handle_list + split_handles): diff --git a/apps/application/flow/step_node/form_node/impl/base_form_node.py b/apps/application/flow/step_node/form_node/impl/base_form_node.py index 7cbbe9cc1d4..dcf35dd3cfd 100644 --- a/apps/application/flow/step_node/form_node/impl/base_form_node.py +++ b/apps/application/flow/step_node/form_node/impl/base_form_node.py @@ -38,7 +38,8 @@ def save_context(self, details, workflow_manage): self.context['start_time'] = details.get('start_time') self.context['form_data'] = form_data self.context['is_submit'] = details.get('is_submit') - self.answer_text = details.get('result') + if self.node_params.get('is_result', False): + self.answer_text = details.get('result') if form_data is not None: for key in form_data: self.context[key] = form_data[key] @@ -70,7 +71,7 @@ def get_answer_list(self) -> List[Answer] | None: "chat_record_id": self.flow_params_serializer.data.get("chat_record_id"), 'form_data': self.context.get('form_data', {}), "is_submit": self.context.get("is_submit", False)} - form = f'{json.dumps(form_setting,ensure_ascii=False)}' + form = f'{json.dumps(form_setting, ensure_ascii=False)}' context = self.workflow_manage.get_workflow_content() form_content_format = self.workflow_manage.reset_prompt(form_content_format) prompt_template = PromptTemplate.from_template(form_content_format, template_format='jinja2') @@ -85,7 +86,7 @@ def get_details(self, index: int, **kwargs): "chat_record_id": self.flow_params_serializer.data.get("chat_record_id"), 'form_data': self.context.get('form_data', {}), "is_submit": self.context.get("is_submit", False)} - form = f'{json.dumps(form_setting,ensure_ascii=False)}' + form = f'{json.dumps(form_setting, ensure_ascii=False)}' context = self.workflow_manage.get_workflow_content() form_content_format = self.workflow_manage.reset_prompt(form_content_format) prompt_template = PromptTemplate.from_template(form_content_format, template_format='jinja2') diff --git a/apps/application/flow/step_node/function_lib_node/impl/base_function_lib_node.py b/apps/application/flow/step_node/function_lib_node/impl/base_function_lib_node.py index d21424f750d..0678b81243c 100644 --- a/apps/application/flow/step_node/function_lib_node/impl/base_function_lib_node.py +++ b/apps/application/flow/step_node/function_lib_node/impl/base_function_lib_node.py @@ -45,6 +45,8 @@ def get_field_value(debug_field_list, name, is_required): def valid_reference_value(_type, value, name): + if value is None: + return if _type == 'int': instance_type = int | float elif _type == 'float': @@ -65,15 +67,22 @@ def valid_reference_value(_type, value, name): def convert_value(name: str, value, _type, is_required, source, node): - if not is_required and value is None: + if not is_required and (value is None or (isinstance(value, str) and len(value) == 0)): return None if not is_required and source == 'reference' and (value is None or len(value) == 0): return None if source == 'reference': + if value and isinstance(value, list) and len(value) == 0: + if not is_required: + return None + else: + raise Exception(f"字段:{name}类型:{_type}值:{value}必填参数") value = node.workflow_manage.get_reference_field( value[0], value[1:]) valid_reference_value(_type, value, name) + if value is None: + return None if _type == 'int': return int(value) if _type == 'float': @@ -113,7 +122,8 @@ def valid_function(function_lib, user_id): class BaseFunctionLibNodeNode(IFunctionLibNode): def save_context(self, details, workflow_manage): self.context['result'] = details.get('result') - self.answer_text = str(details.get('result')) + if self.node_params.get('is_result'): + self.answer_text = str(details.get('result')) def execute(self, function_lib_id, input_field_list, **kwargs) -> NodeResult: function_lib = QuerySet(FunctionLib).filter(id=function_lib_id).first() diff --git a/apps/application/flow/step_node/function_node/impl/base_function_node.py b/apps/application/flow/step_node/function_node/impl/base_function_node.py index 4a5c75c8132..f6127e55550 100644 --- a/apps/application/flow/step_node/function_node/impl/base_function_node.py +++ b/apps/application/flow/step_node/function_node/impl/base_function_node.py @@ -32,6 +32,8 @@ def write_context(step_variable: Dict, global_variable: Dict, node, workflow): def valid_reference_value(_type, value, name): + if value is None: + return if _type == 'int': instance_type = int | float elif _type == 'float': @@ -49,13 +51,20 @@ def valid_reference_value(_type, value, name): def convert_value(name: str, value, _type, is_required, source, node): - if not is_required and value is None: + if not is_required and (value is None or (isinstance(value, str) and len(value) == 0)): return None if source == 'reference': + if value and isinstance(value, list) and len(value) == 0: + if not is_required: + return None + else: + raise Exception(f"字段:{name}类型:{_type}值:{value}必填参数") value = node.workflow_manage.get_reference_field( value[0], value[1:]) valid_reference_value(_type, value, name) + if value is None: + return None if _type == 'int': return int(value) if _type == 'float': @@ -84,7 +93,8 @@ def convert_value(name: str, value, _type, is_required, source, node): class BaseFunctionNodeNode(IFunctionNode): def save_context(self, details, workflow_manage): self.context['result'] = details.get('result') - self.answer_text = str(details.get('result')) + if self.node_params.get('is_result', False): + self.answer_text = str(details.get('result')) def execute(self, input_field_list, code, **kwargs) -> NodeResult: params = {field.get('name'): convert_value(field.get('name'), field.get('value'), field.get('type'), diff --git a/apps/application/flow/step_node/image_generate_step_node/impl/base_image_generate_node.py b/apps/application/flow/step_node/image_generate_step_node/impl/base_image_generate_node.py index d5cc2c5a211..16423eafd61 100644 --- a/apps/application/flow/step_node/image_generate_step_node/impl/base_image_generate_node.py +++ b/apps/application/flow/step_node/image_generate_step_node/impl/base_image_generate_node.py @@ -16,7 +16,8 @@ class BaseImageGenerateNode(IImageGenerateNode): def save_context(self, details, workflow_manage): self.context['answer'] = details.get('answer') self.context['question'] = details.get('question') - self.answer_text = details.get('answer') + if self.node_params.get('is_result', False): + self.answer_text = details.get('answer') def execute(self, model_id, prompt, negative_prompt, dialogue_number, dialogue_type, history_chat_record, chat_id, model_params_setting, @@ -24,7 +25,8 @@ def execute(self, model_id, prompt, negative_prompt, dialogue_number, dialogue_t **kwargs) -> NodeResult: print(model_params_setting) application = self.workflow_manage.work_flow_post_handler.chat_info.application - tti_model = get_model_instance_by_model_user_id(model_id, self.flow_params_serializer.data.get('user_id'), **model_params_setting) + tti_model = get_model_instance_by_model_user_id(model_id, self.flow_params_serializer.data.get('user_id'), + **model_params_setting) history_message = self.get_history_message(history_chat_record, dialogue_number) self.context['history_message'] = history_message question = self.generate_prompt_question(prompt) diff --git a/apps/application/flow/step_node/image_understand_step_node/impl/base_image_understand_node.py b/apps/application/flow/step_node/image_understand_step_node/impl/base_image_understand_node.py index 3b96f15cd6f..0b405619dde 100644 --- a/apps/application/flow/step_node/image_understand_step_node/impl/base_image_understand_node.py +++ b/apps/application/flow/step_node/image_understand_step_node/impl/base_image_understand_node.py @@ -62,14 +62,15 @@ def file_id_to_base64(file_id: str): file = QuerySet(File).filter(id=file_id).first() file_bytes = file.get_byte() base64_image = base64.b64encode(file_bytes).decode("utf-8") - return [base64_image, what(None, file_bytes.tobytes())] + return [base64_image, what(None, file_bytes)] class BaseImageUnderstandNode(IImageUnderstandNode): def save_context(self, details, workflow_manage): self.context['answer'] = details.get('answer') self.context['question'] = details.get('question') - self.answer_text = details.get('answer') + if self.node_params.get('is_result', False): + self.answer_text = details.get('answer') def execute(self, model_id, system, prompt, dialogue_number, dialogue_type, history_chat_record, stream, chat_id, model_params_setting, @@ -171,7 +172,7 @@ def generate_message_list(self, image_model, system: str, prompt: str, history_m file = QuerySet(File).filter(id=file_id).first() image_bytes = file.get_byte() base64_image = base64.b64encode(image_bytes).decode("utf-8") - image_format = what(None, image_bytes.tobytes()) + image_format = what(None, image_bytes) images.append({'type': 'image_url', 'image_url': {'url': f'data:image/{image_format};base64,{base64_image}'}}) messages = [HumanMessage( content=[ diff --git a/apps/application/flow/step_node/mcp_node/impl/base_mcp_node.py b/apps/application/flow/step_node/mcp_node/impl/base_mcp_node.py index 6c9fe97fc69..d5197e9ad11 100644 --- a/apps/application/flow/step_node/mcp_node/impl/base_mcp_node.py +++ b/apps/application/flow/step_node/mcp_node/impl/base_mcp_node.py @@ -14,7 +14,6 @@ def save_context(self, details, workflow_manage): self.context['result'] = details.get('result') self.context['tool_params'] = details.get('tool_params') self.context['mcp_tool'] = details.get('mcp_tool') - self.answer_text = details.get('result') def execute(self, mcp_servers, mcp_server, mcp_tool, tool_params, **kwargs) -> NodeResult: servers = json.loads(mcp_servers) @@ -27,7 +26,8 @@ async def call_tool(s, session, t, a): return s res = asyncio.run(call_tool(servers, mcp_server, mcp_tool, params)) - return NodeResult({'result': [content.text for content in res.content], 'tool_params': params, 'mcp_tool': mcp_tool}, {}) + return NodeResult( + {'result': [content.text for content in res.content], 'tool_params': params, 'mcp_tool': mcp_tool}, {}) def handle_variables(self, tool_params): # 处理参数中的变量 diff --git a/apps/application/flow/step_node/question_node/impl/base_question_node.py b/apps/application/flow/step_node/question_node/impl/base_question_node.py index 48a2639b782..e1fd5b86069 100644 --- a/apps/application/flow/step_node/question_node/impl/base_question_node.py +++ b/apps/application/flow/step_node/question_node/impl/base_question_node.py @@ -80,7 +80,8 @@ def save_context(self, details, workflow_manage): self.context['answer'] = details.get('answer') self.context['message_tokens'] = details.get('message_tokens') self.context['answer_tokens'] = details.get('answer_tokens') - self.answer_text = details.get('answer') + if self.node_params.get('is_result', False): + self.answer_text = details.get('answer') def execute(self, model_id, system, prompt, dialogue_number, history_chat_record, stream, chat_id, chat_record_id, model_params_setting=None, diff --git a/apps/application/flow/step_node/speech_to_text_step_node/impl/base_speech_to_text_node.py b/apps/application/flow/step_node/speech_to_text_step_node/impl/base_speech_to_text_node.py index c85588cd4d2..8f48823f00c 100644 --- a/apps/application/flow/step_node/speech_to_text_step_node/impl/base_speech_to_text_node.py +++ b/apps/application/flow/step_node/speech_to_text_step_node/impl/base_speech_to_text_node.py @@ -18,7 +18,9 @@ class BaseSpeechToTextNode(ISpeechToTextNode): def save_context(self, details, workflow_manage): self.context['answer'] = details.get('answer') - self.answer_text = details.get('answer') + self.context['result'] = details.get('answer') + if self.node_params.get('is_result', False): + self.answer_text = details.get('answer') def execute(self, stt_model_id, chat_id, audio, **kwargs) -> NodeResult: stt_model = get_model_instance_by_model_user_id(stt_model_id, self.flow_params_serializer.data.get('user_id')) @@ -30,7 +32,7 @@ def process_audio_item(audio_item, model): # 根据file_name 吧文件转成mp3格式 file_format = file.file_name.split('.')[-1] with tempfile.NamedTemporaryFile(delete=False, suffix=f'.{file_format}') as temp_file: - temp_file.write(file.get_byte().tobytes()) + temp_file.write(file.get_byte()) temp_file_path = temp_file.name with tempfile.NamedTemporaryFile(delete=False, suffix='.mp3') as temp_amr_file: temp_mp3_path = temp_amr_file.name diff --git a/apps/application/flow/step_node/start_node/impl/base_start_node.py b/apps/application/flow/step_node/start_node/impl/base_start_node.py index bf5203274eb..24b9684714e 100644 --- a/apps/application/flow/step_node/start_node/impl/base_start_node.py +++ b/apps/application/flow/step_node/start_node/impl/base_start_node.py @@ -40,10 +40,13 @@ def save_context(self, details, workflow_manage): self.context['document'] = details.get('document_list') self.context['image'] = details.get('image_list') self.context['audio'] = details.get('audio_list') + self.context['other'] = details.get('other_list') self.status = details.get('status') self.err_message = details.get('err_message') for key, value in workflow_variable.items(): workflow_manage.context[key] = value + for item in details.get('global_fields', []): + workflow_manage.context[item.get('key')] = item.get('value') def get_node_params_serializer_class(self) -> Type[serializers.Serializer]: pass @@ -59,7 +62,8 @@ def execute(self, question, **kwargs) -> NodeResult: 'question': question, 'image': self.workflow_manage.image_list, 'document': self.workflow_manage.document_list, - 'audio': self.workflow_manage.audio_list + 'audio': self.workflow_manage.audio_list, + 'other': self.workflow_manage.other_list, } return NodeResult(node_variable, workflow_variable) @@ -83,5 +87,6 @@ def get_details(self, index: int, **kwargs): 'image_list': self.context.get('image'), 'document_list': self.context.get('document'), 'audio_list': self.context.get('audio'), + 'other_list': self.context.get('other'), 'global_fields': global_fields } diff --git a/apps/application/flow/step_node/text_to_speech_step_node/impl/base_text_to_speech_node.py b/apps/application/flow/step_node/text_to_speech_step_node/impl/base_text_to_speech_node.py index 72c4d3be514..330dc5f5804 100644 --- a/apps/application/flow/step_node/text_to_speech_step_node/impl/base_text_to_speech_node.py +++ b/apps/application/flow/step_node/text_to_speech_step_node/impl/base_text_to_speech_node.py @@ -37,7 +37,9 @@ def bytes_to_uploaded_file(file_bytes, file_name="generated_audio.mp3"): class BaseTextToSpeechNode(ITextToSpeechNode): def save_context(self, details, workflow_manage): self.context['answer'] = details.get('answer') - self.answer_text = details.get('answer') + self.context['result'] = details.get('result') + if self.node_params.get('is_result', False): + self.answer_text = details.get('answer') def execute(self, tts_model_id, chat_id, content, model_params_setting=None, @@ -72,4 +74,5 @@ def get_details(self, index: int, **kwargs): 'content': self.context.get('content'), 'err_message': self.err_message, 'answer': self.context.get('answer'), + 'result': self.context.get('result') } diff --git a/apps/application/flow/workflow_manage.py b/apps/application/flow/workflow_manage.py index be91f69be9e..554b0b75f47 100644 --- a/apps/application/flow/workflow_manage.py +++ b/apps/application/flow/workflow_manage.py @@ -14,7 +14,7 @@ from functools import reduce from typing import List, Dict -from django.db import close_old_connections +from django.db import close_old_connections, connection from django.db.models import QuerySet from django.utils import translation from django.utils.translation import get_language @@ -238,6 +238,7 @@ def __init__(self, flow: Flow, params, work_flow_post_handler: WorkFlowPostHandl base_to_response: BaseToResponse = SystemToResponse(), form_data=None, image_list=None, document_list=None, audio_list=None, + other_list=None, start_node_id=None, start_node_data=None, chat_record=None, child_node=None): if form_data is None: @@ -248,12 +249,15 @@ def __init__(self, flow: Flow, params, work_flow_post_handler: WorkFlowPostHandl document_list = [] if audio_list is None: audio_list = [] + if other_list is None: + other_list = [] self.start_node_id = start_node_id self.start_node = None self.form_data = form_data self.image_list = image_list self.document_list = document_list self.audio_list = audio_list + self.other_list = other_list self.params = params self.flow = flow self.context = {} @@ -294,8 +298,8 @@ def init_fields(self): if global_fields is not None: for global_field in global_fields: global_field_list.append({**global_field, 'node_id': node_id, 'node_name': node_name}) - field_list.sort(key=lambda f: len(f.get('node_name')), reverse=True) - global_field_list.sort(key=lambda f: len(f.get('node_name')), reverse=True) + field_list.sort(key=lambda f: len(f.get('node_name') + f.get('value')), reverse=True) + global_field_list.sort(key=lambda f: len(f.get('node_name') + f.get('value')), reverse=True) self.field_list = field_list self.global_field_list = global_field_list @@ -565,6 +569,8 @@ def hand_event_node_result(self, current_node, node_result_future): return None finally: current_node.node_chunk.end() + # 归还链接 + connection.close() def run_node_async(self, node): future = executor.submit(self.run_node, node) @@ -674,10 +680,16 @@ def get_next_node(self): return None @staticmethod - def dependent_node(up_node_id, node): + def dependent_node(edge, node): + up_node_id = edge.sourceNodeId if not node.node_chunk.is_end(): return False if node.id == up_node_id: + if node.context.get('branch_id', None): + if edge.sourceAnchorId == f"{node.id}_{node.context.get('branch_id', None)}_right": + return True + else: + return False if node.type == 'form-node': if node.context.get('form_data', None) is not None: return True @@ -690,9 +702,11 @@ def dependent_node_been_executed(self, node_id): @param node_id: 需要判断的节点id @return: """ - up_node_id_list = [edge.sourceNodeId for edge in self.flow.edges if edge.targetNodeId == node_id] - return all([any([self.dependent_node(up_node_id, node) for node in self.node_context]) for up_node_id in - up_node_id_list]) + up_edge_list = [edge for edge in self.flow.edges if edge.targetNodeId == node_id] + return all( + [any([self.dependent_node(edge, node) for node in self.node_context if node.id == edge.sourceNodeId]) for + edge in + up_edge_list]) def get_up_node_id_list(self, node_id): up_node_id_list = [edge.sourceNodeId for edge in self.flow.edges if edge.targetNodeId == node_id] @@ -751,7 +765,10 @@ def get_reference_field(self, node_id: str, fields: List[str]): if node_id == 'global': return INode.get_field(self.context, fields) else: - return self.get_node_by_id(node_id).get_reference_field(fields) + node = self.get_node_by_id(node_id) + if node: + return node.get_reference_field(fields) + return None def get_workflow_content(self): context = { diff --git a/apps/application/migrations/0015_re_database_index.py b/apps/application/migrations/0015_re_database_index.py index 740a2a2d241..cafe14e209c 100644 --- a/apps/application/migrations/0015_re_database_index.py +++ b/apps/application/migrations/0015_re_database_index.py @@ -1,9 +1,8 @@ # Generated by Django 4.2.15 on 2024-09-18 16:14 import logging -import psycopg2 +import psycopg from django.db import migrations -from psycopg2 import extensions from smartdoc.const import CONFIG @@ -17,7 +16,7 @@ def get_connect(db_name): "port": CONFIG.get('DB_PORT') } # 建立连接 - connect = psycopg2.connect(**conn_params) + connect = psycopg.connect(**conn_params) return connect @@ -28,7 +27,7 @@ def sql_execute(conn, reindex_sql: str, alter_database_sql: str): @param conn: @param alter_database_sql: """ - conn.set_isolation_level(extensions.ISOLATION_LEVEL_AUTOCOMMIT) + conn.autocommit = True with conn.cursor() as cursor: cursor.execute(reindex_sql, []) cursor.execute(alter_database_sql, []) diff --git a/apps/application/migrations/0027_add_index.py b/apps/application/migrations/0027_add_index.py new file mode 100644 index 00000000000..bb7efaca2a6 --- /dev/null +++ b/apps/application/migrations/0027_add_index.py @@ -0,0 +1,24 @@ +# Generated by Django 4.2.18 on 2025-03-18 06:05 + +import application.models.application +import common.encoder.encoder +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ('application', '0026_chat_asker'), + ] + + operations = [ + migrations.AddIndex( + model_name='chat', + index=models.Index(fields=['create_time'], name='application_chat_create_time_idx'), + ), + migrations.AddIndex( + model_name='chatrecord', + index=models.Index(fields=['create_time'], name='application_chat_record_update_time_idx'), + ), + ] + diff --git a/apps/application/models/application.py b/apps/application/models/application.py index dfe9534e82b..0032271a70b 100644 --- a/apps/application/models/application.py +++ b/apps/application/models/application.py @@ -11,7 +11,7 @@ from django.contrib.postgres.fields import ArrayField from django.db import models from langchain.schema import HumanMessage, AIMessage - +from django.utils.translation import gettext as _ from common.encoder.encoder import SystemEncoder from common.mixins.app_model_mixin import AppModelMixin from dataset.models.data_set import DataSet @@ -167,7 +167,11 @@ def get_human_message(self): return HumanMessage(content=self.problem_text) def get_ai_message(self): - return AIMessage(content=self.answer_text) + answer_text = self.answer_text + if answer_text is None or len(str(answer_text).strip()) == 0: + answer_text = _( + 'Sorry, no relevant content was found. Please re-describe your problem or provide more information. ') + return AIMessage(content=answer_text) def get_node_details_runtime_node_id(self, runtime_node_id): return self.details.get(runtime_node_id, None) diff --git a/apps/application/serializers/application_serializers.py b/apps/application/serializers/application_serializers.py index 3792076be7c..b898100160a 100644 --- a/apps/application/serializers/application_serializers.py +++ b/apps/application/serializers/application_serializers.py @@ -16,6 +16,7 @@ import uuid from functools import reduce from typing import Dict, List + from django.contrib.postgres.fields import ArrayField from django.core import cache, validators from django.core import signing @@ -24,8 +25,8 @@ from django.db.models.expressions import RawSQL from django.http import HttpResponse from django.template import Template, Context +from django.utils.translation import gettext_lazy as _, get_language, to_locale from langchain_mcp_adapters.client import MultiServerMCPClient -from mcp.client.sse import sse_client from rest_framework import serializers, status from rest_framework.utils.formatting import lazy_format @@ -38,7 +39,7 @@ from common.constants.authentication_type import AuthenticationType from common.db.search import get_dynamics_model, native_search, native_page_search from common.db.sql_execute import select_list -from common.exception.app_exception import AppApiException, NotFound404, AppUnauthorizedFailed, ChatException +from common.exception.app_exception import AppApiException, NotFound404, AppUnauthorizedFailed from common.field.common import UploadedImageField, UploadedFileField from common.models.db_model_manage import DBModelManage from common.response import result @@ -57,7 +58,6 @@ from setting.serializers.provider_serializers import ModelSerializer from smartdoc.conf import PROJECT_DIR from users.models import User -from django.utils.translation import gettext_lazy as _, get_language, to_locale chat_cache = cache.caches['chat_cache'] @@ -148,10 +148,12 @@ class ModelSettingSerializer(serializers.Serializer): error_messages=ErrMessage.char(_("Thinking process switch"))) reasoning_content_start = serializers.CharField(required=False, allow_null=True, default="", allow_blank=True, max_length=256, + trim_whitespace=False, error_messages=ErrMessage.char( _("The thinking process begins to mark"))) reasoning_content_end = serializers.CharField(required=False, allow_null=True, allow_blank=True, default="", max_length=256, + trim_whitespace=False, error_messages=ErrMessage.char(_("End of thinking process marker"))) @@ -162,7 +164,7 @@ class ApplicationWorkflowSerializer(serializers.Serializer): max_length=256, min_length=1, error_messages=ErrMessage.char(_("Application Description"))) work_flow = serializers.DictField(required=False, error_messages=ErrMessage.dict(_("Workflow Objects"))) - prologue = serializers.CharField(required=False, allow_null=True, allow_blank=True, max_length=4096, + prologue = serializers.CharField(required=False, allow_null=True, allow_blank=True, max_length=102400, error_messages=ErrMessage.char(_("Opening remarks"))) @staticmethod @@ -225,7 +227,7 @@ class ApplicationSerializer(serializers.Serializer): min_value=0, max_value=1024, error_messages=ErrMessage.integer(_("Historical chat records"))) - prologue = serializers.CharField(required=False, allow_null=True, allow_blank=True, max_length=4096, + prologue = serializers.CharField(required=False, allow_null=True, allow_blank=True, max_length=102400, error_messages=ErrMessage.char(_("Opening remarks"))) dataset_id_list = serializers.ListSerializer(required=False, child=serializers.UUIDField(required=True), allow_null=True, @@ -320,6 +322,7 @@ def get_embed(self, with_valid=True, params=None): def get_query_api_input(self, application, params): query = '' + is_asker = False if application.work_flow is not None: work_flow = application.work_flow if work_flow is not None: @@ -331,8 +334,10 @@ def get_query_api_input(self, application, params): if input_field_list is not None: for field in input_field_list: if field['assignment_method'] == 'api_input' and field['variable'] in params: + if field['variable'] == 'asker': + is_asker = True query += f"&{field['variable']}={params[field['variable']]}" - if 'asker' in params: + if 'asker' in params and not is_asker: query += f"&asker={params.get('asker')}" return query @@ -493,7 +498,7 @@ class Edit(serializers.Serializer): min_value=0, max_value=1024, error_messages=ErrMessage.integer(_("Historical chat records"))) - prologue = serializers.CharField(required=False, allow_null=True, allow_blank=True, max_length=4096, + prologue = serializers.CharField(required=False, allow_null=True, allow_blank=True, max_length=102400, error_messages=ErrMessage.char(_("Opening remarks"))) dataset_id_list = serializers.ListSerializer(required=False, child=serializers.UUIDField(required=True), error_messages=ErrMessage.list(_("Related Knowledge Base")) @@ -1010,7 +1015,8 @@ def profile(self, with_valid=True): 'stt_autosend': application.stt_autosend, 'file_upload_enable': application.file_upload_enable, 'file_upload_setting': application.file_upload_setting, - 'work_flow': application.work_flow, + 'work_flow': {'nodes': [node for node in ((application.work_flow or {}).get('nodes', []) or []) if + node.get('id') == 'base-node']}, 'show_source': application_access_token.show_source, 'language': application_access_token.language, **application_setting_dict}) @@ -1071,6 +1077,7 @@ def edit(self, instance: Dict, with_valid=True): for update_key in update_keys: if update_key in instance and instance.get(update_key) is not None: application.__setattr__(update_key, instance.get(update_key)) + print(application.name) application.save() if 'dataset_id_list' in instance: @@ -1089,6 +1096,7 @@ def edit(self, instance: Dict, with_valid=True): chat_cache.clear_by_application_id(application_id) application_access_token = QuerySet(ApplicationAccessToken).filter(application_id=application_id).first() # 更新缓存数据 + print(application.name) get_application_access_token(application_access_token.access_token, False) return self.one(with_valid=False) @@ -1141,6 +1149,8 @@ def get_work_flow_model(instance): instance['file_upload_enable'] = node_data['file_upload_enable'] if 'file_upload_setting' in node_data: instance['file_upload_setting'] = node_data['file_upload_setting'] + if 'name' in node_data: + instance['name'] = node_data['name'] break def speech_to_text(self, file, with_valid=True): @@ -1318,7 +1328,12 @@ class McpServers(serializers.Serializer): def get_mcp_servers(self, with_valid=True): if with_valid: self.is_valid(raise_exception=True) + if '"stdio"' in self.data.get('mcp_servers'): + raise AppApiException(500, _('stdio is not supported')) servers = json.loads(self.data.get('mcp_servers')) + for server, config in servers.items(): + if config.get('transport') not in ['sse', 'streamable_http']: + raise AppApiException(500, _('Only support transport=sse or transport=streamable_http')) async def get_mcp_tools(servers): async with MultiServerMCPClient(servers) as client: diff --git a/apps/application/serializers/chat_message_serializers.py b/apps/application/serializers/chat_message_serializers.py index 2194028e6dd..e0ea7e9f555 100644 --- a/apps/application/serializers/chat_message_serializers.py +++ b/apps/application/serializers/chat_message_serializers.py @@ -213,12 +213,21 @@ def get_message(instance): return instance.get('messages')[-1].get('content') @staticmethod - def generate_chat(chat_id, application_id, message, client_id): + def generate_chat(chat_id, application_id, message, client_id, asker=None): if chat_id is None: chat_id = str(uuid.uuid1()) chat = QuerySet(Chat).filter(id=chat_id).first() if chat is None: - Chat(id=chat_id, application_id=application_id, abstract=message[0:1024], client_id=client_id).save() + asker_dict = {'user_name': '游客'} + if asker is not None: + if isinstance(asker, str): + asker_dict = { + 'user_name': asker + } + elif isinstance(asker, dict): + asker_dict = asker + Chat(id=chat_id, application_id=application_id, abstract=message[0:1024], client_id=client_id, + asker=asker_dict).save() return chat_id def chat(self, instance: Dict, with_valid=True): @@ -232,7 +241,8 @@ def chat(self, instance: Dict, with_valid=True): application_id = self.data.get('application_id') client_id = self.data.get('client_id') client_type = self.data.get('client_type') - chat_id = self.generate_chat(chat_id, application_id, message, client_id) + chat_id = self.generate_chat(chat_id, application_id, message, client_id, + asker=instance.get('form_data', {}).get("asker")) return ChatMessageSerializer( data={ 'chat_id': chat_id, 'message': message, @@ -245,6 +255,7 @@ def chat(self, instance: Dict, with_valid=True): 'image_list': instance.get('image_list', []), 'document_list': instance.get('document_list', []), 'audio_list': instance.get('audio_list', []), + 'other_list': instance.get('other_list', []), } ).chat(base_to_response=OpenaiToResponse()) @@ -274,6 +285,7 @@ class ChatMessageSerializer(serializers.Serializer): image_list = serializers.ListField(required=False, error_messages=ErrMessage.list(_("picture"))) document_list = serializers.ListField(required=False, error_messages=ErrMessage.list(_("document"))) audio_list = serializers.ListField(required=False, error_messages=ErrMessage.list(_("Audio"))) + other_list = serializers.ListField(required=False, error_messages=ErrMessage.list(_("Other"))) child_node = serializers.DictField(required=False, allow_null=True, error_messages=ErrMessage.dict(_("Child Nodes"))) @@ -372,6 +384,7 @@ def chat_work_flow(self, chat_info: ChatInfo, base_to_response): image_list = self.data.get('image_list') document_list = self.data.get('document_list') audio_list = self.data.get('audio_list') + other_list = self.data.get('other_list') user_id = chat_info.application.user_id chat_record_id = self.data.get('chat_record_id') chat_record = None @@ -382,13 +395,14 @@ def chat_work_flow(self, chat_info: ChatInfo, base_to_response): work_flow_manage = WorkflowManage(Flow.new_instance(chat_info.work_flow_version.work_flow), {'history_chat_record': history_chat_record, 'question': message, 'chat_id': chat_info.chat_id, 'chat_record_id': str( - uuid.uuid1()) if chat_record is None else chat_record.id, + uuid.uuid1()) if chat_record is None else str(chat_record.id), 'stream': stream, 're_chat': re_chat, 'client_id': client_id, 'client_type': client_type, 'user_id': user_id}, WorkFlowPostHandler(chat_info, client_id, client_type), base_to_response, form_data, image_list, document_list, audio_list, + other_list, self.data.get('runtime_node_id'), self.data.get('node_data'), chat_record, self.data.get('child_node')) r = work_flow_manage.run() diff --git a/apps/application/serializers/chat_serializers.py b/apps/application/serializers/chat_serializers.py index b90194d5ae2..14237434c1e 100644 --- a/apps/application/serializers/chat_serializers.py +++ b/apps/application/serializers/chat_serializers.py @@ -13,8 +13,9 @@ from functools import reduce from io import BytesIO from typing import Dict -import pytz + import openpyxl +import pytz from django.core import validators from django.core.cache import caches from django.db import transaction, models @@ -33,8 +34,8 @@ ModelSettingSerializer from application.serializers.chat_message_serializers import ChatInfo from common.constants.permission_constants import RoleConstants -from common.db.search import native_search, native_page_search, page_search, get_dynamics_model -from common.exception.app_exception import AppApiException +from common.db.search import native_search, native_page_search, page_search, get_dynamics_model, native_page_handler +from common.exception.app_exception import AppApiException, AppUnauthorizedFailed from common.util.common import post from common.util.field_message import ErrMessage from common.util.file_util import get_file_content @@ -144,7 +145,8 @@ def get_query_set(self, select_ids=None): 'trample_num': models.IntegerField(), 'comparer': models.CharField(), 'application_chat.update_time': models.DateTimeField(), - 'application_chat.id': models.UUIDField(), })) + 'application_chat.id': models.UUIDField(), + 'application_chat_record_temp.id': models.UUIDField()})) base_query_dict = {'application_chat.application_id': self.data.get("application_id"), 'application_chat.update_time__gte': start_time, @@ -174,7 +176,14 @@ def get_query_set(self, select_ids=None): condition = base_condition & min_trample_query else: condition = base_condition - return query_set.filter(condition).order_by("-application_chat.update_time") + inner_queryset = QuerySet(Chat).filter(application_id=self.data.get("application_id")) + if 'abstract' in self.data and self.data.get('abstract') is not None: + inner_queryset = inner_queryset.filter(abstract__icontains=self.data.get('abstract')) + + return { + 'inner_queryset': inner_queryset, + 'default_queryset': query_set.filter(condition).order_by("-application_chat.update_time") + } def list(self, with_valid=True): if with_valid: @@ -192,31 +201,40 @@ def paragraph_list_to_string(paragraph_list): @staticmethod def to_row(row: Dict): details = row.get('details') - padding_problem_text = ' '.join(node.get("answer", "") for key, node in details.items() if - node.get("type") == 'question-node') + if not details: + details = {} + padding_problem_text_list = [ + node.get("answer") or "" + for key, node in details.items() + if node.get("type") == 'question-node' + ] + padding_problem_text = ' '.join(padding_problem_text_list) + search_dataset_node_list = [(key, node) for key, node in details.items() if node.get("type") == 'search-dataset-node' or node.get( "step_type") == 'search_step'] reference_paragraph_len = '\n'.join([str(len(node.get('paragraph_list', - []))) if key == 'search_step' else node.get( - 'name') + ':' + str( + []))) if key == 'search_step' else (node.get( + 'name') or '') + ':' + str( len(node.get('paragraph_list', [])) if node.get('paragraph_list', []) is not None else '0') for key, node in search_dataset_node_list]) reference_paragraph = '\n----------\n'.join( [ChatSerializers.Query.paragraph_list_to_string(node.get('paragraph_list', - [])) if key == 'search_step' else node.get( - 'name') + ':\n' + ChatSerializers.Query.paragraph_list_to_string(node.get('paragraph_list', - [])) for + [])) if key == 'search_step' else (node.get( + 'name') or '') + ':\n' + ChatSerializers.Query.paragraph_list_to_string(node.get('paragraph_list', + [])) for key, node in search_dataset_node_list]) improve_paragraph_list = row.get('improve_paragraph_list') vote_status_map = {'-1': '未投票', '0': '赞同', '1': '反对'} + asker = row.get('asker') or {} return [str(row.get('chat_id')), row.get('abstract'), row.get('problem_text'), padding_problem_text, row.get('answer_text'), vote_status_map.get(row.get('vote_status')), reference_paragraph_len, reference_paragraph, "\n".join([ f"{improve_paragraph_list[index].get('title')}\n{improve_paragraph_list[index].get('content')}" - for index in range(len(improve_paragraph_list))]), - row.get('asker').get('user_name'), + for index in range(len(improve_paragraph_list)) + ]) if improve_paragraph_list is not None else "", + asker.get('user_name', ''), row.get('message_tokens') + row.get('answer_tokens'), row.get('run_time'), str(row.get('create_time').astimezone(pytz.timezone(TIME_ZONE)).strftime('%Y-%m-%d %H:%M:%S') )] @@ -225,55 +243,90 @@ def export(self, data, with_valid=True): if with_valid: self.is_valid(raise_exception=True) - data_list = native_search(self.get_query_set(data.get('select_ids')), - select_string=get_file_content( - os.path.join(PROJECT_DIR, "apps", "application", 'sql', - 'export_application_chat.sql')), - with_table_name=False) + batch_size = 2000 - batch_size = 500 + select_sql = get_file_content( + os.path.join( + PROJECT_DIR, + "apps", + "application", + "sql", + "export_application_chat.sql" + ) + ) def stream_response(): - workbook = openpyxl.Workbook() - worksheet = workbook.active - worksheet.title = 'Sheet1' - - headers = [gettext('Conversation ID'), gettext('summary'), gettext('User Questions'), - gettext('Problem after optimization'), - gettext('answer'), gettext('User feedback'), - gettext('Reference segment number'), - gettext('Section title + content'), - gettext('Annotation'), gettext('USER'), gettext('Consuming tokens'), - gettext('Time consumed (s)'), - gettext('Question Time')] - for col_idx, header in enumerate(headers, 1): - cell = worksheet.cell(row=1, column=col_idx) - cell.value = header - - for i in range(0, len(data_list), batch_size): - batch_data = data_list[i:i + batch_size] - - for row_idx, row in enumerate(batch_data, start=i + 2): - for col_idx, value in enumerate(self.to_row(row), 1): - cell = worksheet.cell(row=row_idx, column=col_idx) - if isinstance(value, str): - value = re.sub(ILLEGAL_CHARACTERS_RE, '', value) - if isinstance(value, datetime.datetime): - eastern = pytz.timezone(TIME_ZONE) - c = datetime.timezone(eastern._utcoffset) - value = value.astimezone(c) - cell.value = value - - output = BytesIO() - workbook.save(output) - output.seek(0) - yield output.getvalue() - output.close() - workbook.close() - - response = StreamingHttpResponse(stream_response(), - content_type='application/vnd.open.xmlformats-officedocument.spreadsheetml.sheet') + import tempfile + + headers = [ + gettext('Conversation ID'), + gettext('summary'), + gettext('User Questions'), + gettext('Problem after optimization'), + gettext('answer'), + gettext('User feedback'), + gettext('Reference segment number'), + gettext('Section title + content'), + gettext('Annotation'), + gettext('USER'), + gettext('Consuming tokens'), + gettext('Time consumed (s)'), + gettext('Question Time') + ] + + with tempfile.NamedTemporaryFile(suffix=".xlsx") as tmp: + + workbook = openpyxl.Workbook(write_only=True) + worksheet = workbook.create_sheet(title="Sheet1") + + # 写表头 + worksheet.append(headers) + + for data_list in native_page_handler( + batch_size, + self.get_query_set(data.get('select_ids')), + primary_key='application_chat_record_temp.id', + primary_queryset='default_queryset', + get_primary_value=lambda item: item.get('id'), + select_string=select_sql, + with_table_name=False + ): + + for row in data_list: + + row_values = [] + for value in self.to_row(row): + + if isinstance(value, str): + value = re.sub(ILLEGAL_CHARACTERS_RE, '', value) + + elif isinstance(value, datetime.datetime): + eastern = pytz.timezone(TIME_ZONE) + c = datetime.timezone(eastern._utcoffset) + value = value.astimezone(c) + + row_values.append(value) + + worksheet.append(row_values) + + workbook.save(tmp.name) + workbook.close() + + # 分块返回文件 + with open(tmp.name, "rb") as f: + while True: + chunk = f.read(8192) + if not chunk: + break + yield chunk + + response = StreamingHttpResponse( + stream_response(), + content_type="application/vnd.openxmlformats-officedocument.spreadsheetml.sheet" + ) + response['Content-Disposition'] = 'attachment; filename="data.xlsx"' + return response def page(self, current_page: int, page_size: int, with_valid=True): @@ -476,6 +529,13 @@ class Query(serializers.Serializer): chat_id = serializers.UUIDField(required=True) order_asc = serializers.BooleanField(required=False, allow_null=True) + def is_valid(self, *, raise_exception=False): + super().is_valid(raise_exception=True) + exist = QuerySet(Chat).filter(id=self.data.get("chat_id"), + application_id=self.data.get("application_id")).exists() + if not exist: + raise AppUnauthorizedFailed(403, _('No permission to access')) + def list(self, with_valid=True): if with_valid: self.is_valid(raise_exception=True) diff --git a/apps/application/sql/export_application_chat.sql b/apps/application/sql/export_application_chat.sql index bb265ea5b02..503e4a91d15 100644 --- a/apps/application/sql/export_application_chat.sql +++ b/apps/application/sql/export_application_chat.sql @@ -1,38 +1,38 @@ -SELECT - application_chat."id" as chat_id, - application_chat.abstract as abstract, - application_chat_record_temp.problem_text as problem_text, - application_chat_record_temp.answer_text as answer_text, - application_chat_record_temp.message_tokens as message_tokens, - application_chat_record_temp.answer_tokens as answer_tokens, - application_chat_record_temp.run_time as run_time, - application_chat_record_temp.details::JSON as details, - application_chat_record_temp."index" as "index", - application_chat_record_temp.improve_paragraph_list as improve_paragraph_list, - application_chat_record_temp.vote_status as vote_status, - application_chat_record_temp.create_time as create_time, - to_json(application_chat.asker) as asker -FROM - application_chat application_chat - LEFT JOIN ( - SELECT COUNT - ( "id" ) AS chat_record_count, - SUM ( CASE WHEN "vote_status" = '0' THEN 1 ELSE 0 END ) AS star_num, - SUM ( CASE WHEN "vote_status" = '1' THEN 1 ELSE 0 END ) AS trample_num, - SUM ( CASE WHEN array_length( application_chat_record.improve_paragraph_id_list, 1 ) IS NULL THEN 0 ELSE array_length( application_chat_record.improve_paragraph_id_list, 1 ) END ) AS mark_sum, - chat_id - FROM - application_chat_record - GROUP BY - application_chat_record.chat_id - ) chat_record_temp ON application_chat."id" = chat_record_temp.chat_id - LEFT JOIN ( - SELECT - *, - CASE - WHEN array_length( application_chat_record.improve_paragraph_id_list, 1 ) IS NULL THEN - '{}' ELSE ( SELECT ARRAY_AGG ( row_to_json ( paragraph ) ) FROM paragraph WHERE "id" = ANY ( application_chat_record.improve_paragraph_id_list ) ) - END as improve_paragraph_list - FROM - application_chat_record application_chat_record - ) application_chat_record_temp ON application_chat_record_temp.chat_id = application_chat."id" \ No newline at end of file +SELECT application_chat_record_temp.id AS id, + application_chat."id" as chat_id, + application_chat.abstract as abstract, + application_chat_record_temp.problem_text as problem_text, + application_chat_record_temp.answer_text as answer_text, + application_chat_record_temp.message_tokens as message_tokens, + application_chat_record_temp.answer_tokens as answer_tokens, + application_chat_record_temp.run_time as run_time, + application_chat_record_temp.details::JSON as details, application_chat_record_temp."index" as "index", + application_chat_record_temp.improve_paragraph_list as improve_paragraph_list, + application_chat_record_temp.vote_status as vote_status, + application_chat_record_temp.create_time as create_time, + to_json(application_chat.asker) as asker +FROM application_chat application_chat + + LEFT JOIN (SELECT COUNT(acr."id") AS chat_record_count, + SUM((acr."vote_status" = '0')::int) AS star_num, + SUM((acr."vote_status" = '1')::int) AS trample_num, + SUM(COALESCE(array_length(acr.improve_paragraph_id_list, 1), 0)) AS mark_sum, + acr.chat_id + FROM application_chat_record acr + WHERE EXISTS (SELECT 1 + FROM application_chat ac2 + ${inner_queryset} + AND ac2.id = acr.chat_id) + GROUP BY acr.chat_id) chat_record_temp + ON application_chat."id" = chat_record_temp.chat_id + + LEFT JOIN (SELECT acr.*, + COALESCE(p.paragraph_list, '{}') as improve_paragraph_list + FROM application_chat_record acr + LEFT JOIN LATERAL ( + SELECT ARRAY_AGG(row_to_json(paragraph)) as paragraph_list + FROM paragraph + WHERE paragraph."id" = ANY (acr.improve_paragraph_id_list) + ) p ON TRUE) application_chat_record_temp + ON application_chat_record_temp.chat_id = application_chat."id" + ${default_queryset} \ No newline at end of file diff --git a/apps/application/sql/list_application_chat.sql b/apps/application/sql/list_application_chat.sql index 7f3e1680c99..6c807a6c83c 100644 --- a/apps/application/sql/list_application_chat.sql +++ b/apps/application/sql/list_application_chat.sql @@ -1,16 +1,21 @@ SELECT - *,to_json(asker) as asker -FROM - application_chat application_chat - LEFT JOIN ( - SELECT COUNT - ( "id" ) AS chat_record_count, - SUM ( CASE WHEN "vote_status" = '0' THEN 1 ELSE 0 END ) AS star_num, - SUM ( CASE WHEN "vote_status" = '1' THEN 1 ELSE 0 END ) AS trample_num, - SUM ( CASE WHEN array_length( application_chat_record.improve_paragraph_id_list, 1 ) IS NULL THEN 0 ELSE array_length( application_chat_record.improve_paragraph_id_list, 1 ) END ) AS mark_sum, - chat_id - FROM - application_chat_record - GROUP BY - application_chat_record.chat_id - ) chat_record_temp ON application_chat."id" = chat_record_temp.chat_id \ No newline at end of file + application_chat.*, + to_json(application_chat.asker) AS asker, + chat_record_temp.chat_record_count, + chat_record_temp.star_num, + chat_record_temp.trample_num, + chat_record_temp.mark_sum +FROM application_chat +LEFT JOIN ( + SELECT + application_chat_record.chat_id, + COUNT(application_chat_record.id) AS chat_record_count, + SUM((application_chat_record.vote_status = '0')::int) AS star_num, + SUM((application_chat_record.vote_status = '1')::int) AS trample_num, + SUM(COALESCE(array_length(application_chat_record.improve_paragraph_id_list, 1), 0)) AS mark_sum + FROM application_chat_record + JOIN application_chat application_chat ON application_chat.id = application_chat_record.chat_id + ${inner_queryset} + GROUP BY application_chat_record.chat_id +) chat_record_temp ON application_chat.id = chat_record_temp.chat_id +${default_queryset} \ No newline at end of file diff --git a/apps/application/swagger_api/application_api.py b/apps/application/swagger_api/application_api.py index 2c9cbd86bf4..024279832b1 100644 --- a/apps/application/swagger_api/application_api.py +++ b/apps/application/swagger_api/application_api.py @@ -61,8 +61,6 @@ def get_response_body_api(): 'user_id': openapi.Schema(type=openapi.TYPE_STRING, title=_("Affiliation user"), description=_("Affiliation user")), - 'status': openapi.Schema(type=openapi.TYPE_BOOLEAN, title=_("Is publish"), description=_('Is publish')), - 'create_time': openapi.Schema(type=openapi.TYPE_STRING, title=_("Creation time"), description=_('Creation time')), @@ -302,7 +300,19 @@ def get_request_body_api(): 'no_references_prompt': openapi.Schema(type=openapi.TYPE_STRING, title=_("No citation segmentation prompt"), default="{question}", - description=_("No citation segmentation prompt")) + description=_("No citation segmentation prompt")), + 'reasoning_content_enable': openapi.Schema(type=openapi.TYPE_BOOLEAN, + title=_("Reasoning enable"), + default=False, + description=_("Reasoning enable")), + 'reasoning_content_end': openapi.Schema(type=openapi.TYPE_STRING, + title=_("Reasoning end tag"), + default="", + description=_("Reasoning end tag")), + "reasoning_content_start": openapi.Schema(type=openapi.TYPE_STRING, + title=_("Reasoning start tag"), + default="", + description=_("Reasoning start tag")) } ) diff --git a/apps/application/swagger_api/chat_api.py b/apps/application/swagger_api/chat_api.py index 54b5678f747..f27a19c200e 100644 --- a/apps/application/swagger_api/chat_api.py +++ b/apps/application/swagger_api/chat_api.py @@ -326,11 +326,6 @@ def get_request_params_api(): type=openapi.TYPE_STRING, required=True, description=_('Application ID')), - openapi.Parameter(name='history_day', - in_=openapi.IN_QUERY, - type=openapi.TYPE_NUMBER, - required=True, - description=_('Historical days')), openapi.Parameter(name='abstract', in_=openapi.IN_QUERY, type=openapi.TYPE_STRING, required=False, description=_("abstract")), openapi.Parameter(name='min_star', in_=openapi.IN_QUERY, type=openapi.TYPE_INTEGER, required=False, diff --git a/apps/application/views/application_version_views.py b/apps/application/views/application_version_views.py index de900936268..1cd42a643a0 100644 --- a/apps/application/views/application_version_views.py +++ b/apps/application/views/application_version_views.py @@ -48,7 +48,11 @@ class Page(APIView): ApplicationVersionApi.Query.get_request_params_api()), responses=result.get_page_api_response(ApplicationVersionApi.get_response_body_api()), tags=[_('Application/Version')]) - @has_permissions(PermissionConstants.APPLICATION_READ, compare=CompareConstants.AND) + @has_permissions(PermissionConstants.APPLICATION_READ, + ViewPermission([RoleConstants.ADMIN, RoleConstants.USER], + [lambda r, keywords: Permission(group=Group.APPLICATION, operate=Operate.USE, + dynamic_tag=keywords.get('application_id'))], + compare=CompareConstants.AND), compare=CompareConstants.AND) def get(self, request: Request, application_id: str, current_page: int, page_size: int): return result.success( ApplicationVersionSerializer.Query( @@ -65,7 +69,14 @@ class Operate(APIView): manual_parameters=ApplicationVersionApi.Operate.get_request_params_api(), responses=result.get_api_response(ApplicationVersionApi.get_response_body_api()), tags=[_('Application/Version')]) - @has_permissions(PermissionConstants.APPLICATION_READ, compare=CompareConstants.AND) + @has_permissions(PermissionConstants.APPLICATION_READ, ViewPermission([RoleConstants.ADMIN, RoleConstants.USER], + [lambda r, keywords: Permission( + group=Group.APPLICATION, + operate=Operate.USE, + dynamic_tag=keywords.get( + 'application_id'))], + compare=CompareConstants.AND), + compare=CompareConstants.AND) def get(self, request: Request, application_id: str, work_flow_version_id: str): return result.success( ApplicationVersionSerializer.Operate( diff --git a/apps/application/views/application_views.py b/apps/application/views/application_views.py index f16041d1de3..8c3e8059bcb 100644 --- a/apps/application/views/application_views.py +++ b/apps/application/views/application_views.py @@ -7,16 +7,6 @@ @desc: """ -from django.core import cache -from django.http import HttpResponse -from django.utils.translation import gettext_lazy as _, gettext -from drf_yasg.utils import swagger_auto_schema -from langchain_core.prompts import PromptTemplate -from rest_framework.decorators import action -from rest_framework.parsers import MultiPartParser -from rest_framework.request import Request -from rest_framework.views import APIView - from application.serializers.application_serializers import ApplicationSerializer from application.serializers.application_statistics_serializers import ApplicationStatisticsSerializer from application.swagger_api.application_api import ApplicationApi @@ -31,6 +21,14 @@ from common.swagger_api.common_api import CommonApi from common.util.common import query_params_to_single_dict from dataset.serializers.dataset_serializers import DataSetSerializers +from django.core import cache +from django.http import HttpResponse +from django.utils.translation import gettext_lazy as _ +from drf_yasg.utils import swagger_auto_schema +from rest_framework.decorators import action +from rest_framework.parsers import MultiPartParser +from rest_framework.request import Request +from rest_framework.views import APIView chat_cache = cache.caches['chat_cache'] @@ -494,7 +492,7 @@ def get(self, request: Request): class HitTest(APIView): authentication_classes = [TokenAuth] - @action(methods="GET", detail=False) + @action(methods="PUT", detail=False) @swagger_auto_schema(operation_summary=_("Hit Test List"), operation_id=_("Hit Test List"), manual_parameters=CommonApi.HitTestApi.get_request_params_api(), responses=result.get_api_array_response(CommonApi.HitTestApi.get_response_body_api()), @@ -505,15 +503,15 @@ class HitTest(APIView): [lambda r, keywords: Permission(group=Group.APPLICATION, operate=Operate.USE, dynamic_tag=keywords.get('application_id'))], compare=CompareConstants.AND)) - def get(self, request: Request, application_id: str): - return result.success( - ApplicationSerializer.HitTest(data={'id': application_id, 'user_id': request.user.id, - "query_text": request.query_params.get("query_text"), - "top_number": request.query_params.get("top_number"), - 'similarity': request.query_params.get('similarity'), - 'search_mode': request.query_params.get( - 'search_mode')}).hit_test( - )) + def put(self, request: Request, application_id: str): + return result.success(ApplicationSerializer.HitTest(data={ + 'id': application_id, + 'user_id': request.user.id, + "query_text": request.data.get("query_text"), + "top_number": request.data.get("top_number"), + 'similarity': request.data.get('similarity'), + 'search_mode': request.data.get('search_mode')} + ).hit_test()) class Publish(APIView): authentication_classes = [TokenAuth] diff --git a/apps/application/views/chat_views.py b/apps/application/views/chat_views.py index 0415f8208dc..30d54fa65a4 100644 --- a/apps/application/views/chat_views.py +++ b/apps/application/views/chat_views.py @@ -59,7 +59,8 @@ class Export(APIView): @has_permissions( ViewPermission([RoleConstants.ADMIN, RoleConstants.USER, RoleConstants.APPLICATION_KEY], [lambda r, keywords: Permission(group=Group.APPLICATION, operate=Operate.USE, - dynamic_tag=keywords.get('application_id'))]) + dynamic_tag=keywords.get('application_id'))], + compare=CompareConstants.AND) ) @log(menu='Conversation Log', operate="Export conversation", get_operation_object=lambda r, k: get_application_operation_object(k.get('application_id'))) @@ -144,6 +145,8 @@ def post(self, request: Request, chat_id: str): 'document_list') if 'document_list' in request.data else [], 'audio_list': request.data.get( 'audio_list') if 'audio_list' in request.data else [], + 'other_list': request.data.get( + 'other_list') if 'other_list' in request.data else [], 'client_type': request.auth.client_type, 'node_id': request.data.get('node_id', None), 'runtime_node_id': request.data.get('runtime_node_id', None), @@ -162,7 +165,9 @@ def post(self, request: Request, chat_id: str): @has_permissions( ViewPermission([RoleConstants.ADMIN, RoleConstants.USER, RoleConstants.APPLICATION_KEY], [lambda r, keywords: Permission(group=Group.APPLICATION, operate=Operate.USE, - dynamic_tag=keywords.get('application_id'))]) + dynamic_tag=keywords.get('application_id'))], + compare=CompareConstants.AND + ) ) def get(self, request: Request, application_id: str): return result.success(ChatSerializers.Query( @@ -180,8 +185,7 @@ class Operate(APIView): [RoleConstants.ADMIN, RoleConstants.USER], [lambda r, keywords: Permission(group=Group.APPLICATION, operate=Operate.MANAGE, dynamic_tag=keywords.get('application_id'))], - compare=CompareConstants.AND), - compare=CompareConstants.AND) + compare=CompareConstants.AND)) @log(menu='Conversation Log', operate="Delete a conversation", get_operation_object=lambda r, k: get_application_operation_object(k.get('application_id'))) def delete(self, request: Request, application_id: str, chat_id: str): @@ -204,7 +208,8 @@ class ClientChatHistoryPage(APIView): @has_permissions( ViewPermission([RoleConstants.APPLICATION_ACCESS_TOKEN], [lambda r, keywords: Permission(group=Group.APPLICATION, operate=Operate.USE, - dynamic_tag=keywords.get('application_id'))]) + dynamic_tag=keywords.get('application_id'))], + compare=CompareConstants.AND) ) def get(self, request: Request, application_id: str, current_page: int, page_size: int): return result.success(ChatSerializers.ClientChatHistory( @@ -239,7 +244,7 @@ def delete(self, request: Request, application_id: str, chat_id: str): request_body=ChatClientHistoryApi.Operate.ReAbstract.get_request_body_api(), tags=[_("Application/Conversation Log")]) @has_permissions(ViewPermission( - [RoleConstants.APPLICATION_ACCESS_TOKEN], + [RoleConstants.APPLICATION_ACCESS_TOKEN, RoleConstants.ADMIN, RoleConstants.USER], [lambda r, keywords: Permission(group=Group.APPLICATION, operate=Operate.USE, dynamic_tag=keywords.get('application_id'))], compare=CompareConstants.AND), @@ -265,7 +270,8 @@ class Page(APIView): @has_permissions( ViewPermission([RoleConstants.ADMIN, RoleConstants.USER, RoleConstants.APPLICATION_KEY], [lambda r, keywords: Permission(group=Group.APPLICATION, operate=Operate.USE, - dynamic_tag=keywords.get('application_id'))]) + dynamic_tag=keywords.get('application_id'))], + compare=CompareConstants.AND) ) def get(self, request: Request, application_id: str, current_page: int, page_size: int): return result.success(ChatSerializers.Query( @@ -290,7 +296,8 @@ class Operate(APIView): ViewPermission([RoleConstants.ADMIN, RoleConstants.USER, RoleConstants.APPLICATION_KEY, RoleConstants.APPLICATION_ACCESS_TOKEN], [lambda r, keywords: Permission(group=Group.APPLICATION, operate=Operate.USE, - dynamic_tag=keywords.get('application_id'))]) + dynamic_tag=keywords.get('application_id'))], + compare=CompareConstants.AND) ) def get(self, request: Request, application_id: str, chat_id: str, chat_record_id: str): return result.success(ChatRecordSerializer.Operate( @@ -308,7 +315,8 @@ def get(self, request: Request, application_id: str, chat_id: str, chat_record_i @has_permissions( ViewPermission([RoleConstants.ADMIN, RoleConstants.USER, RoleConstants.APPLICATION_KEY], [lambda r, keywords: Permission(group=Group.APPLICATION, operate=Operate.USE, - dynamic_tag=keywords.get('application_id'))]) + dynamic_tag=keywords.get('application_id'))], + compare=CompareConstants.AND) ) def get(self, request: Request, application_id: str, chat_id: str): return result.success(ChatRecordSerializer.Query( @@ -327,9 +335,11 @@ class Page(APIView): tags=[_("Application/Conversation Log")] ) @has_permissions( - ViewPermission([RoleConstants.ADMIN, RoleConstants.USER, RoleConstants.APPLICATION_KEY], + ViewPermission([RoleConstants.ADMIN, RoleConstants.USER, RoleConstants.APPLICATION_KEY, + RoleConstants.APPLICATION_ACCESS_TOKEN], [lambda r, keywords: Permission(group=Group.APPLICATION, operate=Operate.USE, - dynamic_tag=keywords.get('application_id'))]) + dynamic_tag=keywords.get('application_id'))], + compare=CompareConstants.AND) ) def get(self, request: Request, application_id: str, chat_id: str, current_page: int, page_size: int): return result.success(ChatRecordSerializer.Query( @@ -352,7 +362,8 @@ class Vote(APIView): ViewPermission([RoleConstants.ADMIN, RoleConstants.USER, RoleConstants.APPLICATION_KEY, RoleConstants.APPLICATION_ACCESS_TOKEN], [lambda r, keywords: Permission(group=Group.APPLICATION, operate=Operate.USE, - dynamic_tag=keywords.get('application_id'))]) + dynamic_tag=keywords.get('application_id'))], + compare=CompareConstants.AND) ) @log(menu='Conversation Log', operate="Like, Dislike", get_operation_object=lambda r, k: get_application_operation_object(k.get('application_id'))) @@ -375,7 +386,7 @@ class ChatRecordImprove(APIView): ViewPermission([RoleConstants.ADMIN, RoleConstants.USER], [lambda r, keywords: Permission(group=Group.APPLICATION, operate=Operate.USE, dynamic_tag=keywords.get('application_id'))] - )) + , compare=CompareConstants.AND)) def get(self, request: Request, application_id: str, chat_id: str, chat_record_id: str): return result.success(ChatRecordSerializer.ChatRecordImprove( data={'chat_id': chat_id, 'chat_record_id': chat_record_id}).get()) @@ -395,7 +406,7 @@ class Improve(APIView): ViewPermission([RoleConstants.ADMIN, RoleConstants.USER], [lambda r, keywords: Permission(group=Group.APPLICATION, operate=Operate.USE, dynamic_tag=keywords.get('application_id'))], - + compare=CompareConstants.AND ), ViewPermission([RoleConstants.ADMIN, RoleConstants.USER], [lambda r, keywords: Permission(group=Group.DATASET, operate=Operate.MANAGE, @@ -422,6 +433,7 @@ def put(self, request: Request, application_id: str, chat_id: str, chat_record_i ViewPermission([RoleConstants.ADMIN, RoleConstants.USER], [lambda r, keywords: Permission(group=Group.APPLICATION, operate=Operate.USE, dynamic_tag=keywords.get('application_id'))], + compare=CompareConstants.AND ), ViewPermission([RoleConstants.ADMIN, RoleConstants.USER], [lambda r, keywords: Permission(group=Group.DATASET, @@ -449,6 +461,7 @@ class Operate(APIView): ViewPermission([RoleConstants.ADMIN, RoleConstants.USER], [lambda r, keywords: Permission(group=Group.APPLICATION, operate=Operate.USE, dynamic_tag=keywords.get('application_id'))], + compare=CompareConstants.AND ), ViewPermission([RoleConstants.ADMIN, RoleConstants.USER], [lambda r, keywords: Permission(group=Group.DATASET, @@ -497,7 +510,8 @@ class UploadFile(APIView): ViewPermission([RoleConstants.ADMIN, RoleConstants.USER, RoleConstants.APPLICATION_KEY, RoleConstants.APPLICATION_ACCESS_TOKEN], [lambda r, keywords: Permission(group=Group.APPLICATION, operate=Operate.USE, - dynamic_tag=keywords.get('application_id'))]) + dynamic_tag=keywords.get('application_id'))] + , compare=CompareConstants.AND) ) def post(self, request: Request, application_id: str, chat_id: str): files = request.FILES.getlist('file') diff --git a/apps/common/auth/handle/impl/user_token.py b/apps/common/auth/handle/impl/user_token.py index dbb6bd2b51a..bdb041f9f79 100644 --- a/apps/common/auth/handle/impl/user_token.py +++ b/apps/common/auth/handle/impl/user_token.py @@ -6,18 +6,18 @@ @date:2024/3/14 03:02 @desc: 用户认证 """ +from django.core import cache from django.db.models import QuerySet +from django.utils.translation import gettext_lazy as _ from common.auth.handle.auth_base_handle import AuthBaseHandle from common.constants.authentication_type import AuthenticationType from common.constants.permission_constants import RoleConstants, get_permission_list_by_role, Auth from common.exception.app_exception import AppAuthenticationFailed -from smartdoc.settings import JWT_AUTH +from smartdoc.const import CONFIG from users.models import User -from django.core import cache - from users.models.user import get_user_dynamics_permission -from django.utils.translation import gettext_lazy as _ + token_cache = cache.caches['token_cache'] @@ -35,7 +35,7 @@ def handle(self, request, token: str, get_token_details): auth_details = get_token_details() user = QuerySet(User).get(id=auth_details['id']) # 续期 - token_cache.touch(token, timeout=JWT_AUTH['JWT_EXPIRATION_DELTA'].total_seconds()) + token_cache.touch(token, timeout=CONFIG.get_session_timeout()) rule = RoleConstants[user.role] permission_list = get_permission_list_by_role(RoleConstants[user.role]) # 获取用户的应用和知识库的权限 diff --git a/apps/common/config/embedding_config.py b/apps/common/config/embedding_config.py index a6e9ab9aa9b..69081be055d 100644 --- a/apps/common/config/embedding_config.py +++ b/apps/common/config/embedding_config.py @@ -11,35 +11,50 @@ from common.cache.mem_cache import MemCache -lock = threading.Lock() +_lock = threading.Lock() +locks = {} class ModelManage: cache = MemCache('model', {}) up_clear_time = time.time() + @staticmethod + def _get_lock(_id): + lock = locks.get(_id) + if lock is None: + with _lock: + lock = locks.get(_id) + if lock is None: + lock = threading.Lock() + locks[_id] = lock + + return lock + @staticmethod def get_model(_id, get_model): - # 获取锁 - lock.acquire() - try: - model_instance = ModelManage.cache.get(_id) - if model_instance is None or not model_instance.is_cache_model(): + model_instance = ModelManage.cache.get(_id) + if model_instance is None: + lock = ModelManage._get_lock(_id) + with lock: + model_instance = ModelManage.cache.get(_id) + if model_instance is None: + model_instance = get_model(_id) + ModelManage.cache.set(_id, model_instance, timeout=60 * 60 * 8) + else: + if model_instance.is_cache_model(): + ModelManage.cache.touch(_id, timeout=60 * 60 * 8) + else: model_instance = get_model(_id) - ModelManage.cache.set(_id, model_instance, timeout=60 * 30) - return model_instance - # 续期 - ModelManage.cache.touch(_id, timeout=60 * 30) - ModelManage.clear_timeout_cache() - return model_instance - finally: - # 释放锁 - lock.release() + ModelManage.cache.set(_id, model_instance, timeout=60 * 60 * 8) + ModelManage.clear_timeout_cache() + return model_instance @staticmethod def clear_timeout_cache(): - if time.time() - ModelManage.up_clear_time > 60: - ModelManage.cache.clear_timeout_data() + if time.time() - ModelManage.up_clear_time > 60 * 60: + threading.Thread(target=lambda: ModelManage.cache.clear_timeout_data()).start() + ModelManage.up_clear_time = time.time() @staticmethod def delete_key(_id): diff --git a/apps/common/db/search.py b/apps/common/db/search.py index bef42a1414a..07ecd1b0262 100644 --- a/apps/common/db/search.py +++ b/apps/common/db/search.py @@ -170,6 +170,51 @@ def native_page_search(current_page: int, page_size: int, queryset: QuerySet | D return Page(total.get("count"), list(map(post_records_handler, result)), current_page, page_size) +def native_page_handler(page_size: int, + queryset: QuerySet | Dict[str, QuerySet], + select_string: str, + field_replace_dict=None, + with_table_name=False, + primary_key=None, + get_primary_value=None, + primary_queryset: str = None, + ): + if isinstance(queryset, Dict): + exec_sql, exec_params = generate_sql_by_query_dict({**queryset, + primary_queryset: queryset[primary_queryset].order_by( + primary_key)}, select_string, field_replace_dict, with_table_name) + else: + exec_sql, exec_params = generate_sql_by_query(queryset.order_by( + primary_key), select_string, field_replace_dict, with_table_name) + total_sql = "SELECT \"count\"(*) FROM (%s) temp" % exec_sql + total = select_one(total_sql, exec_params) + processed_count = 0 + last_id = None + while processed_count < total.get("count"): + if last_id is not None: + if isinstance(queryset, Dict): + exec_sql, exec_params = generate_sql_by_query_dict({**queryset, + primary_queryset: queryset[primary_queryset].filter( + **{f"{primary_key}__gt": last_id}).order_by( + primary_key)}, + select_string, field_replace_dict, + with_table_name) + else: + exec_sql, exec_params = generate_sql_by_query( + queryset.filter(**{f"{primary_key}__gt": last_id}).order_by( + primary_key), + select_string, field_replace_dict, + with_table_name) + limit_sql = connections[DEFAULT_DB_ALIAS].ops.limit_offset_sql( + 0, page_size + ) + page_sql = exec_sql + " " + limit_sql + result = select_list(page_sql, exec_params) + yield result + processed_count += page_size + last_id = get_primary_value(result[-1]) + + def get_field_replace_dict(queryset: QuerySet): """ 获取需要替换的字段 默认 “xxx.xxx”需要被替换成 “xxx”."xxx" diff --git a/apps/common/event/listener_manage.py b/apps/common/event/listener_manage.py index 72d16ebb523..6899c31f33e 100644 --- a/apps/common/event/listener_manage.py +++ b/apps/common/event/listener_manage.py @@ -24,6 +24,7 @@ from common.util.lock import try_lock, un_lock from common.util.page_utils import page_desc from dataset.models import Paragraph, Status, Document, ProblemParagraphMapping, TaskType, State +from dataset.serializers.common_serializers import create_dataset_index from embedding.models import SourceType, SearchMode from smartdoc.conf import PROJECT_DIR from django.utils.translation import gettext_lazy as _ @@ -238,11 +239,8 @@ def update_status(query_set: QuerySet, taskType: TaskType, state: State): for key in params_dict: _value_ = params_dict[key] exec_sql = exec_sql.replace(key, str(_value_)) - lock.acquire() - try: + with lock: native_update(query_set, exec_sql) - finally: - lock.release() @staticmethod def embedding_by_document(document_id, embedding_model: Embeddings, state_list=None): @@ -272,7 +270,6 @@ def is_the_task_interrupted(): ListenerManagement.update_status(QuerySet(Document).filter(id=document_id), TaskType.EMBEDDING, State.STARTED) - # 根据段落进行向量化处理 page_desc(QuerySet(Paragraph) .annotate( @@ -285,6 +282,8 @@ def is_the_task_interrupted(): ListenerManagement.get_aggregation_document_status( document_id)), is_the_task_interrupted) + # 检查是否存在索引 + create_dataset_index(document_id=document_id) except Exception as e: max_kb_error.error(_('Vectorized document: {document_id} error {error} {traceback}').format( document_id=document_id, error=str(e), traceback=traceback.format_exc())) diff --git a/apps/common/forms/__init__.py b/apps/common/forms/__init__.py index 6095421935b..251f01df092 100644 --- a/apps/common/forms/__init__.py +++ b/apps/common/forms/__init__.py @@ -22,3 +22,4 @@ from .radio_card_field import * from .label import * from .slider_field import * +from .switch_field import * diff --git a/apps/common/forms/switch_field.py b/apps/common/forms/switch_field.py index 9fa176beea0..ea119c3ecfb 100644 --- a/apps/common/forms/switch_field.py +++ b/apps/common/forms/switch_field.py @@ -28,6 +28,6 @@ def __init__(self, label: str or BaseLabel, @param props_info: """ - super().__init__('Switch', label, required, default_value, relation_show_field_dict, + super().__init__('SwitchInput', label, required, default_value, relation_show_field_dict, {}, TriggerType.OPTION_LIST, attrs, props_info) diff --git a/apps/common/handle/impl/doc_split_handle.py b/apps/common/handle/impl/doc_split_handle.py index 1df7b6a66e0..4161f13a19d 100644 --- a/apps/common/handle/impl/doc_split_handle.py +++ b/apps/common/handle/impl/doc_split_handle.py @@ -112,11 +112,7 @@ def get_image_id(image_id): title_font_list = [ [36, 100], - [26, 36], - [24, 26], - [22, 24], - [18, 22], - [16, 18] + [30, 36] ] @@ -130,7 +126,7 @@ def get_title_level(paragraph: Paragraph): if len(paragraph.runs) == 1: font_size = paragraph.runs[0].font.size pt = font_size.pt - if pt >= 16: + if pt >= 30: for _value, index in zip(title_font_list, range(len(title_font_list))): if pt >= _value[0] and pt < _value[1]: return index + 1 diff --git a/apps/common/handle/impl/table/xls_parse_table_handle.py b/apps/common/handle/impl/table/xls_parse_table_handle.py index 5609e3e8835..897e347e8a8 100644 --- a/apps/common/handle/impl/table/xls_parse_table_handle.py +++ b/apps/common/handle/impl/table/xls_parse_table_handle.py @@ -82,7 +82,10 @@ def get_content(self, file, save_image): for row in data: # 将每个单元格中的内容替换换行符为
以保留原始格式 md_table += '| ' + ' | '.join( - [str(cell).replace('\n', '
') if cell else '' for cell in row]) + ' |\n' + [str(cell) + .replace('\r\n', '
') + .replace('\n', '
') + if cell else '' for cell in row]) + ' |\n' md_tables += md_table + '\n\n' return md_tables diff --git a/apps/common/handle/impl/table/xlsx_parse_table_handle.py b/apps/common/handle/impl/table/xlsx_parse_table_handle.py index abaec05769a..a68eb14f1a1 100644 --- a/apps/common/handle/impl/table/xlsx_parse_table_handle.py +++ b/apps/common/handle/impl/table/xlsx_parse_table_handle.py @@ -19,36 +19,24 @@ def support(self, file, get_buffer): def fill_merged_cells(self, sheet, image_dict): data = [] - - # 获取第一行作为标题行 - headers = [] - for idx, cell in enumerate(sheet[1]): - if cell.value is None: - headers.append(' ' * (idx + 1)) - else: - headers.append(cell.value) - # 从第二行开始遍历每一行 - for row in sheet.iter_rows(min_row=2, values_only=False): - row_data = {} + for row in sheet.iter_rows(values_only=False): + row_data = [] for col_idx, cell in enumerate(row): cell_value = cell.value - - # 如果单元格为空,并且该单元格在合并单元格内,获取合并单元格的值 - if cell_value is None: - for merged_range in sheet.merged_cells.ranges: - if cell.coordinate in merged_range: - cell_value = sheet[merged_range.min_row][merged_range.min_col - 1].value - break - image = image_dict.get(cell_value, None) if image is not None: cell_value = f'![](/api/image/{image.id})' # 使用标题作为键,单元格的值作为值存入字典 - row_data[headers[col_idx]] = cell_value + row_data.insert(col_idx, cell_value) data.append(row_data) + for merged_range in sheet.merged_cells.ranges: + cell_value = data[merged_range.min_row - 1][merged_range.min_col - 1] + for row_index in range(merged_range.min_row, merged_range.max_row + 1): + for col_index in range(merged_range.min_col, merged_range.max_col + 1): + data[row_index - 1][col_index - 1] = cell_value return data def handle(self, file, get_buffer, save_image): @@ -65,11 +53,13 @@ def handle(self, file, get_buffer, save_image): paragraphs = [] ws = wb[sheetname] data = self.fill_merged_cells(ws, image_dict) - - for row in data: - row_output = "; ".join([f"{key}: {value}" for key, value in row.items()]) - # print(row_output) - paragraphs.append({'title': '', 'content': row_output}) + if len(data) >= 2: + head_list = data[0] + for row_index in range(1, len(data)): + row_output = "; ".join( + [f"{head_list[col_index]}: {data[row_index][col_index]}" for col_index in + range(0, len(data[row_index]))]) + paragraphs.append({'title': '', 'content': row_output}) result.append({'name': sheetname, 'paragraphs': paragraphs}) @@ -78,7 +68,6 @@ def handle(self, file, get_buffer, save_image): return [{'name': file.name, 'paragraphs': []}] return result - def get_content(self, file, save_image): try: # 加载 Excel 文件 @@ -94,18 +83,18 @@ def get_content(self, file, save_image): # 如果未指定 sheet_name,则使用第一个工作表 for sheetname in workbook.sheetnames: sheet = workbook[sheetname] if sheetname else workbook.active - rows = self.fill_merged_cells(sheet, image_dict) - if len(rows) == 0: + data = self.fill_merged_cells(sheet, image_dict) + if len(data) == 0: continue # 提取表头和内容 - headers = [f"{key}" for key, value in rows[0].items()] + headers = [f"{value}" for value in data[0]] # 构建 Markdown 表格 md_table = '| ' + ' | '.join(headers) + ' |\n' md_table += '| ' + ' | '.join(['---'] * len(headers)) + ' |\n' - for row in rows: - r = [f'{value}' for key, value in row.items()] + for row_index in range(1, len(data)): + r = [f'{value}' for value in data[row_index]] md_table += '| ' + ' | '.join( [str(cell).replace('\n', '
') if cell is not None else '' for cell in r]) + ' |\n' diff --git a/apps/common/handle/impl/xls_split_handle.py b/apps/common/handle/impl/xls_split_handle.py index 3d8afdf62de..dbdcc95506d 100644 --- a/apps/common/handle/impl/xls_split_handle.py +++ b/apps/common/handle/impl/xls_split_handle.py @@ -14,7 +14,7 @@ def post_cell(cell_value): - return cell_value.replace('\n', '
').replace('|', '|') + return cell_value.replace('\r\n', '
').replace('\n', '
').replace('|', '|') def row_to_md(row): diff --git a/apps/common/init/init_template.py b/apps/common/init/init_template.py new file mode 100644 index 00000000000..f77a86ec2a8 --- /dev/null +++ b/apps/common/init/init_template.py @@ -0,0 +1,54 @@ +# coding=utf-8 +""" + @project: MaxKB + @Author:虎虎 + @file: init_jinja.py + @date:2025/12/1 17:16 + @desc: +""" +from typing import Any + +from jinja2.sandbox import SandboxedEnvironment +from langchain_core.prompts.string import DEFAULT_FORMATTER_MAPPING, _HAS_JINJA2 + + +def jinja2_formatter(template: str, /, **kwargs: Any) -> str: + """Format a template using jinja2. + + *Security warning*: + As of LangChain 0.0.329, this method uses Jinja2's + SandboxedEnvironment by default. However, this sand-boxing should + be treated as a best-effort approach rather than a guarantee of security. + Do not accept jinja2 templates from untrusted sources as they may lead + to arbitrary Python code execution. + + https://jinja.palletsprojects.com/en/3.1.x/sandbox/ + + Args: + template: The template string. + **kwargs: The variables to format the template with. + + Returns: + The formatted string. + + Raises: + ImportError: If jinja2 is not installed. + """ + if not _HAS_JINJA2: + msg = ( + "jinja2 not installed, which is needed to use the jinja2_formatter. " + "Please install it with `pip install jinja2`." + "Please be cautious when using jinja2 templates. " + "Do not expand jinja2 templates using unverified or user-controlled " + "inputs as that can result in arbitrary Python code execution." + ) + raise ImportError(msg) + + # Use a restricted sandbox that blocks ALL attribute/method access + # Only simple variable lookups like {{variable}} are allowed + # Attribute access like {{variable.attr}} or {{variable.method()}} is blocked + return SandboxedEnvironment().from_string(template).render(**kwargs) + + +def run(): + DEFAULT_FORMATTER_MAPPING['jinja2'] = jinja2_formatter diff --git a/apps/common/management/commands/services/services/gunicorn.py b/apps/common/management/commands/services/services/gunicorn.py index cc42c4f7cb3..a32220ab881 100644 --- a/apps/common/management/commands/services/services/gunicorn.py +++ b/apps/common/management/commands/services/services/gunicorn.py @@ -16,13 +16,14 @@ def cmd(self): log_format = '%(h)s %(t)s %(L)ss "%(r)s" %(s)s %(b)s ' bind = f'{HTTP_HOST}:{HTTP_PORT}' + max_requests = 10240 if int(self.worker) > 1 else 0 cmd = [ 'gunicorn', 'smartdoc.wsgi:application', '-b', bind, '-k', 'gthread', '--threads', '200', '-w', str(self.worker), - '--max-requests', '10240', + '--max-requests', str(max_requests), '--max-requests-jitter', '2048', '--access-logformat', log_format, '--access-logfile', '-' diff --git a/apps/common/management/commands/services/services/local_model.py b/apps/common/management/commands/services/services/local_model.py index 4511f8f5fee..db11d2d404f 100644 --- a/apps/common/management/commands/services/services/local_model.py +++ b/apps/common/management/commands/services/services/local_model.py @@ -24,13 +24,15 @@ def cmd(self): os.environ.setdefault('SERVER_NAME', 'local_model') log_format = '%(h)s %(t)s %(L)ss "%(r)s" %(s)s %(b)s ' bind = f'{CONFIG.get("LOCAL_MODEL_HOST")}:{CONFIG.get("LOCAL_MODEL_PORT")}' + worker = CONFIG.get("LOCAL_MODEL_HOST_WORKER", 1) + max_requests = 10240 if int(worker) > 1 else 0 cmd = [ 'gunicorn', 'smartdoc.wsgi:application', '-b', bind, '-k', 'gthread', '--threads', '200', - '-w', "1", - '--max-requests', '10240', + '-w', str(worker), + '--max-requests', str(max_requests), '--max-requests-jitter', '2048', '--access-logformat', log_format, '--access-logfile', '-' diff --git a/apps/common/middleware/doc_headers_middleware.py b/apps/common/middleware/doc_headers_middleware.py index d818b842ca5..83419b19fb0 100644 --- a/apps/common/middleware/doc_headers_middleware.py +++ b/apps/common/middleware/doc_headers_middleware.py @@ -9,43 +9,102 @@ from django.http import HttpResponse from django.utils.deprecation import MiddlewareMixin +from common.auth import handles, TokenDetails + content = """ - + Document + + + + - - + """ @@ -54,9 +113,18 @@ class DocHeadersMiddleware(MiddlewareMixin): def process_response(self, request, response): if request.path.startswith('/doc/') or request.path.startswith('/doc/chat/'): - HTTP_REFERER = request.META.get('HTTP_REFERER') - if HTTP_REFERER is None: + auth = request.COOKIES.get('Authorization') + if auth is None: return HttpResponse(content) - if HTTP_REFERER == request._current_scheme_host + request.path: - return response + else: + try: + token = auth + token_details = TokenDetails(token) + for handle in handles: + if handle.support(request, token, token_details.get_token_details): + handle.handle(request, token, token_details.get_token_details) + return response + return HttpResponse(content) + except Exception as e: + return HttpResponse(content) return response diff --git a/apps/common/swagger_api/common_api.py b/apps/common/swagger_api/common_api.py index 3134db0d083..9e7d1976298 100644 --- a/apps/common/swagger_api/common_api.py +++ b/apps/common/swagger_api/common_api.py @@ -15,33 +15,21 @@ class CommonApi: class HitTestApi(ApiMixin): @staticmethod - def get_request_params_api(): - return [ - openapi.Parameter(name='query_text', - in_=openapi.IN_QUERY, - type=openapi.TYPE_STRING, - required=True, - description=_('query text')), - openapi.Parameter(name='top_number', - in_=openapi.IN_QUERY, - type=openapi.TYPE_NUMBER, - default=10, - required=True, - description='topN'), - openapi.Parameter(name='similarity', - in_=openapi.IN_QUERY, - type=openapi.TYPE_NUMBER, - default=0.6, - required=True, - description=_('similarity')), - openapi.Parameter(name='search_mode', - in_=openapi.IN_QUERY, - type=openapi.TYPE_STRING, - default="embedding", - required=True, - description=_('Retrieval pattern embedding|keywords|blend') - ) - ] + def get_request_body_api(): + return openapi.Schema( + type=openapi.TYPE_OBJECT, + required=['query_text', 'top_number', 'similarity', 'search_mode'], + properties={ + 'query_text': openapi.Schema(type=openapi.TYPE_STRING, title=_('query text'), + description=_('query text')), + 'top_number': openapi.Schema(type=openapi.TYPE_NUMBER, title=_('top number'), + description=_('top number')), + 'similarity': openapi.Schema(type=openapi.TYPE_NUMBER, title=_('similarity'), + description=_('similarity')), + 'search_mode': openapi.Schema(type=openapi.TYPE_STRING, title=_('search mode'), + description=_('search mode')) + } + ) @staticmethod def get_response_body_api(): diff --git a/apps/common/util/common.py b/apps/common/util/common.py index b0111029af9..8583a1c989f 100644 --- a/apps/common/util/common.py +++ b/apps/common/util/common.py @@ -11,6 +11,7 @@ import io import mimetypes import pickle +import random import re import shutil from functools import reduce @@ -297,3 +298,14 @@ def markdown_to_plain_text(md: str) -> str: # 去除首尾空格 text = text.strip() return text + + +SAFE_CHAR_SET = ( + [chr(i) for i in range(65, 91) if chr(i) not in {'I', 'O'}] + # 大写字母 A-H, J-N, P-Z + [chr(i) for i in range(97, 123) if chr(i) not in {'i', 'l', 'o'}] + # 小写字母 a-h, j-n, p-z + [str(i) for i in range(10) if str(i) not in {'0', '1', '7'}] # 数字 2-6, 8-9 +) + + +def get_random_chars(number=4): + return ''.join(random.choices(SAFE_CHAR_SET, k=number)) diff --git a/apps/common/util/fork.py b/apps/common/util/fork.py index 4405b9b76e4..dc27ccf1982 100644 --- a/apps/common/util/fork.py +++ b/apps/common/util/fork.py @@ -3,6 +3,7 @@ import re import traceback from functools import reduce +from pathlib import Path from typing import List, Set from urllib.parse import urljoin, urlparse, ParseResult, urlsplit, urlunparse @@ -52,6 +53,28 @@ def remove_fragment(url: str) -> str: return urlunparse(modified_url) +def remove_last_path_robust(url): + """健壮地删除URL的最后一个路径部分""" + parsed = urlparse(url) + + # 分割路径并过滤空字符串 + paths = [p for p in parsed.path.split('/') if p] + + if paths: + paths.pop() # 移除最后一个路径 + + # 重建路径 + new_path = '/' + '/'.join(paths) if paths else '/' + + # 重建URL + return urlunparse(( + parsed.scheme, + parsed.netloc, + new_path, + parsed.params, + parsed.query, + parsed.fragment + )) class Fork: class Response: def __init__(self, content: str, child_link_list: List[ChildLink], status, message: str): @@ -70,6 +93,8 @@ def error(message: str): def __init__(self, base_fork_url: str, selector_list: List[str]): base_fork_url = remove_fragment(base_fork_url) + if any([True for end_str in ['index.html', '.htm', '.html'] if base_fork_url.endswith(end_str)]): + base_fork_url =remove_last_path_robust(base_fork_url) self.base_fork_url = urljoin(base_fork_url if base_fork_url.endswith("/") else base_fork_url + '/', '.') parsed = urlsplit(base_fork_url) query = parsed.query @@ -137,18 +162,30 @@ def get_beautiful_soup(response): html_content = response.content.decode(encoding) beautiful_soup = BeautifulSoup(html_content, "html.parser") meta_list = beautiful_soup.find_all('meta') - charset_list = [meta.attrs.get('charset') for meta in meta_list if - meta.attrs is not None and 'charset' in meta.attrs] + charset_list = Fork.get_charset_list(meta_list) if len(charset_list) > 0: charset = charset_list[0] if charset != encoding: try: - html_content = response.content.decode(charset) + html_content = response.content.decode(charset, errors='replace') except Exception as e: - logging.getLogger("max_kb").error(f'{e}') + logging.getLogger("max_kb").error(f'{e}: {traceback.format_exc()}') return BeautifulSoup(html_content, "html.parser") return beautiful_soup + @staticmethod + def get_charset_list(meta_list): + charset_list = [] + for meta in meta_list: + if meta.attrs is not None: + if 'charset' in meta.attrs: + charset_list.append(meta.attrs.get('charset')) + elif meta.attrs.get('http-equiv', '').lower() == 'content-type' and 'content' in meta.attrs: + match = re.search(r'charset=([^\s;]+)', meta.attrs['content'], re.I) + if match: + charset_list.append(match.group(1)) + return charset_list + def fork(self): try: @@ -175,4 +212,4 @@ def fork(self): def handler(base_url, response: Fork.Response): print(base_url.url, base_url.tag.text if base_url.tag else None, response.content) -# ForkManage('https://bbs.fit2cloud.com/c/de/6', ['.md-content']).fork(3, set(), handler) +# ForkManage('https://hzqcgc.htc.edu.cn/jxky.htm', ['.md-content']).fork(3, set(), handler) diff --git a/apps/common/util/function_code.py b/apps/common/util/function_code.py index 30ce3a33d20..3a877a62367 100644 --- a/apps/common/util/function_code.py +++ b/apps/common/util/function_code.py @@ -7,13 +7,12 @@ @desc: """ import os +import pickle import subprocess import sys import uuid from textwrap import dedent -from diskcache import Cache - from smartdoc.const import BASE_DIR from smartdoc.const import PROJECT_DIR @@ -37,6 +36,8 @@ def _createdir(self): old_mask = os.umask(0o077) try: os.makedirs(self.sandbox_path, 0o700, exist_ok=True) + os.makedirs(os.path.join(self.sandbox_path, 'execute'), 0o700, exist_ok=True) + os.makedirs(os.path.join(self.sandbox_path, 'result'), 0o700, exist_ok=True) finally: os.umask(old_mask) @@ -44,10 +45,11 @@ def exec_code(self, code_str, keywords): _id = str(uuid.uuid1()) success = '{"code":200,"msg":"成功","data":exec_result}' err = '{"code":500,"msg":str(e),"data":None}' - path = r'' + self.sandbox_path + '' + result_path = f'{self.sandbox_path}/result/{_id}.result' _exec_code = f""" try: import os + import pickle env = dict(os.environ) for key in list(env.keys()): if key in os.environ and (key.startswith('MAXKB') or key.startswith('POSTGRES') or key.startswith('PG')): @@ -60,13 +62,11 @@ def exec_code(self, code_str, keywords): for local in locals_v: globals_v[local] = locals_v[local] exec_result=f(**keywords) - from diskcache import Cache - cache = Cache({path!a}) - cache.set({_id!a},{success}) + with open({result_path!a}, 'wb') as file: + file.write(pickle.dumps({success})) except Exception as e: - from diskcache import Cache - cache = Cache({path!a}) - cache.set({_id!a},{err}) + with open({result_path!a}, 'wb') as file: + file.write(pickle.dumps({err})) """ if self.sandbox: subprocess_result = self._exec_sandbox(_exec_code, _id) @@ -74,18 +74,18 @@ def exec_code(self, code_str, keywords): subprocess_result = self._exec(_exec_code) if subprocess_result.returncode == 1: raise Exception(subprocess_result.stderr) - cache = Cache(self.sandbox_path) - result = cache.get(_id) - cache.delete(_id) + with open(result_path, 'rb') as file: + result = pickle.loads(file.read()) + os.remove(result_path) if result.get('code') == 200: return result.get('data') raise Exception(result.get('msg')) def _exec_sandbox(self, _code, _id): - exec_python_file = f'{self.sandbox_path}/{_id}.py' + exec_python_file = f'{self.sandbox_path}/execute/{_id}.py' with open(exec_python_file, 'w') as file: file.write(_code) - os.system(f"chown {self.user}:{self.user} {exec_python_file}") + os.system(f"chown {self.user}:root {exec_python_file}") kwargs = {'cwd': BASE_DIR} subprocess_result = subprocess.run( ['su', '-s', python_directory, '-c', "exec(open('" + exec_python_file + "').read())", self.user], diff --git a/apps/common/util/rsa_util.py b/apps/common/util/rsa_util.py index 00301867208..452ca678d9e 100644 --- a/apps/common/util/rsa_util.py +++ b/apps/common/util/rsa_util.py @@ -40,15 +40,12 @@ def generate(): def get_key_pair(): rsa_value = rsa_cache.get(cache_key) if rsa_value is None: - lock.acquire() - rsa_value = rsa_cache.get(cache_key) - if rsa_value is not None: - return rsa_value - try: + with lock: + rsa_value = rsa_cache.get(cache_key) + if rsa_value is not None: + return rsa_value rsa_value = get_key_pair_by_sql() rsa_cache.set(cache_key, rsa_value) - finally: - lock.release() return rsa_value diff --git a/apps/dataset/serializers/common_serializers.py b/apps/dataset/serializers/common_serializers.py index 856f3da1584..edf064236b2 100644 --- a/apps/dataset/serializers/common_serializers.py +++ b/apps/dataset/serializers/common_serializers.py @@ -18,13 +18,13 @@ from common.config.embedding_config import ModelManage from common.db.search import native_search -from common.db.sql_execute import update_execute +from common.db.sql_execute import update_execute, sql_execute from common.exception.app_exception import AppApiException from common.mixins.api_mixin import ApiMixin from common.util.field_message import ErrMessage from common.util.file_util import get_file_content from common.util.fork import Fork -from dataset.models import Paragraph, Problem, ProblemParagraphMapping, DataSet, File, Image +from dataset.models import Paragraph, Problem, ProblemParagraphMapping, DataSet, File, Image, Document from setting.models_provider import get_model from smartdoc.conf import PROJECT_DIR from django.utils.translation import gettext_lazy as _ @@ -224,6 +224,46 @@ def get_embedding_model_id_by_dataset_id_list(dataset_id_list: List): return str(dataset_list[0].embedding_mode_id) + +def create_dataset_index(dataset_id=None, document_id=None): + if dataset_id is None and document_id is None: + raise AppApiException(500, _('Dataset ID or Document ID must be provided')) + + if dataset_id is not None: + k_id = dataset_id + else: + document = QuerySet(Document).filter(id=document_id).first() + k_id = document.dataset_id + + sql = f"SELECT indexname, indexdef FROM pg_indexes WHERE tablename = 'embedding' AND indexname = 'embedding_hnsw_idx_{k_id}'" + index = sql_execute(sql, []) + if not index: + sql = f"SELECT vector_dims(embedding) AS dims FROM embedding WHERE dataset_id = '{k_id}' LIMIT 1" + result = sql_execute(sql, []) + if len(result) == 0: + return + dims = result[0]['dims'] + sql = f"""CREATE INDEX "embedding_hnsw_idx_{k_id}" ON embedding USING hnsw ((embedding::vector({dims})) vector_cosine_ops) WHERE dataset_id = '{k_id}'""" + update_execute(sql, []) + + +def drop_dataset_index(dataset_id=None, document_id=None): + if dataset_id is None and document_id is None: + raise AppApiException(500, _('Dataset ID or Document ID must be provided')) + + if dataset_id is not None: + k_id = dataset_id + else: + document = QuerySet(Document).filter(id=document_id).first() + k_id = document.dataset_id + + sql = f"SELECT indexname, indexdef FROM pg_indexes WHERE tablename = 'embedding' AND indexname = 'embedding_hnsw_idx_{k_id}'" + index = sql_execute(sql, []) + if index: + sql = f'DROP INDEX "embedding_hnsw_idx_{k_id}"' + update_execute(sql, []) + + class GenerateRelatedSerializer(ApiMixin, serializers.Serializer): model_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid(_('Model id'))) prompt = serializers.CharField(required=True, error_messages=ErrMessage.uuid(_('Prompt word'))) diff --git a/apps/dataset/serializers/dataset_serializers.py b/apps/dataset/serializers/dataset_serializers.py index 895443d997f..b92ecb33a54 100644 --- a/apps/dataset/serializers/dataset_serializers.py +++ b/apps/dataset/serializers/dataset_serializers.py @@ -44,7 +44,7 @@ State, File, Image from dataset.serializers.common_serializers import list_paragraph, MetaSerializer, ProblemParagraphManage, \ get_embedding_model_by_dataset_id, get_embedding_model_id_by_dataset_id, write_image, zip_dir, \ - GenerateRelatedSerializer + GenerateRelatedSerializer, drop_dataset_index from dataset.serializers.document_serializers import DocumentSerializers, DocumentInstanceSerializer from dataset.task import sync_web_dataset, sync_replace_web_dataset, generate_related_by_dataset_id from embedding.models import SearchMode @@ -526,7 +526,7 @@ def get_response_body_api(): def get_request_body_api(): return openapi.Schema( type=openapi.TYPE_OBJECT, - required=['name', 'desc'], + required=['name', 'desc', 'embedding_mode_id'], properties={ 'name': openapi.Schema(type=openapi.TYPE_STRING, title=_('dataset name'), description=_('dataset name')), @@ -788,6 +788,7 @@ def delete(self): QuerySet(ProblemParagraphMapping).filter(dataset=dataset).delete() QuerySet(Paragraph).filter(dataset=dataset).delete() QuerySet(Problem).filter(dataset=dataset).delete() + drop_dataset_index(dataset_id=dataset.id) dataset.delete() delete_embedding_by_dataset(self.data.get('id')) return True diff --git a/apps/dataset/serializers/document_serializers.py b/apps/dataset/serializers/document_serializers.py index 5915877fc7c..94f1b9db6ea 100644 --- a/apps/dataset/serializers/document_serializers.py +++ b/apps/dataset/serializers/document_serializers.py @@ -23,6 +23,8 @@ from django.db.models import QuerySet, Count from django.db.models.functions import Substr, Reverse from django.http import HttpResponse +from django.utils.translation import get_language +from django.utils.translation import gettext_lazy as _, gettext, to_locale from drf_yasg import openapi from openpyxl.cell.cell import ILLEGAL_CHARACTERS_RE from rest_framework import serializers @@ -64,8 +66,6 @@ embedding_by_document_list from setting.models import Model from smartdoc.conf import PROJECT_DIR -from django.utils.translation import gettext_lazy as _, gettext, to_locale -from django.utils.translation import get_language parse_qa_handle_list = [XlsParseQAHandle(), CsvParseQAHandle(), XlsxParseQAHandle(), ZipParseQAHandle()] parse_table_handle_list = [CsvSplitTableHandle(), XlsSplitTableHandle(), XlsxSplitTableHandle()] @@ -141,7 +141,8 @@ def is_valid(self, *, document: Document = None): if 'meta' in self.data and self.data.get('meta') is not None: dataset_meta_valid_map = self.get_meta_valid_map() valid_class = dataset_meta_valid_map.get(document.type) - valid_class(data=self.data.get('meta')).is_valid(raise_exception=True) + if valid_class is not None: + valid_class(data=self.data.get('meta')).is_valid(raise_exception=True) class DocumentWebInstanceSerializer(ApiMixin, serializers.Serializer): @@ -661,6 +662,8 @@ def get_workbook(data_dict, document_dict): cell = worksheet.cell(row=row_idx + 1, column=col_idx + 1) if isinstance(col, str): col = re.sub(ILLEGAL_CHARACTERS_RE, '', col) + if col.startswith(('=', '+', '-', '@')): + col = '\ufeff' + col cell.value = col # 创建HttpResponse对象返回Excel文件 return workbook @@ -806,27 +809,40 @@ def delete(self): def get_response_body_api(): return openapi.Schema( type=openapi.TYPE_OBJECT, - required=['id', 'name', 'char_length', 'user_id', 'paragraph_count', 'is_active' - 'update_time', 'create_time'], + required=['create_time', 'update_time', 'id', 'name', 'char_length', 'status', 'is_active', + 'type', 'meta', 'dataset_id', 'hit_handling_method', 'directly_return_similarity', + 'status_meta', 'paragraph_count'], properties={ + 'create_time': openapi.Schema(type=openapi.TYPE_STRING, title=_('create time'), + description=_('create time'), + default="1970-01-01 00:00:00"), + 'update_time': openapi.Schema(type=openapi.TYPE_STRING, title=_('update time'), + description=_('update time'), + default="1970-01-01 00:00:00"), 'id': openapi.Schema(type=openapi.TYPE_STRING, title="id", description="id", default="xx"), 'name': openapi.Schema(type=openapi.TYPE_STRING, title=_('name'), description=_('name'), default="xx"), 'char_length': openapi.Schema(type=openapi.TYPE_INTEGER, title=_('char length'), description=_('char length'), default=10), - 'user_id': openapi.Schema(type=openapi.TYPE_STRING, title=_('user id'), description=_('user id')), - 'paragraph_count': openapi.Schema(type=openapi.TYPE_INTEGER, title="_('document count')", - description="_('document count')", default=1), + 'status':openapi.Schema(type=openapi.TYPE_STRING, title=_('status'), + description=_('status'), default="xx"), 'is_active': openapi.Schema(type=openapi.TYPE_BOOLEAN, title=_('Is active'), description=_('Is active'), default=True), - 'update_time': openapi.Schema(type=openapi.TYPE_STRING, title=_('update time'), - description=_('update time'), - default="1970-01-01 00:00:00"), - 'create_time': openapi.Schema(type=openapi.TYPE_STRING, title=_('create time'), - description=_('create time'), - default="1970-01-01 00:00:00" - ) + 'type': openapi.Schema(type=openapi.TYPE_STRING, title=_('type'), + description=_('type'), default="xx"), + 'meta': openapi.Schema(type=openapi.TYPE_OBJECT, title=_('meta'), + description=_('meta'), default="{}"), + 'dataset_id': openapi.Schema(type=openapi.TYPE_STRING, title=_('dataset_id'), + description=_('dataset_id'), default="xx"), + 'hit_handling_method': openapi.Schema(type=openapi.TYPE_STRING, title=_('hit_handling_method'), + description=_('hit_handling_method'), default="xx"), + 'directly_return_similarity': openapi.Schema(type=openapi.TYPE_NUMBER, title=_('directly_return_similarity'), + description=_('directly_return_similarity'), default="xx"), + 'status_meta': openapi.Schema(type=openapi.TYPE_OBJECT, title=_('status_meta'), + description=_('status_meta'), default="{}"), + 'paragraph_count': openapi.Schema(type=openapi.TYPE_INTEGER, title="_('document count')", + description="_('document count')", default=1), } ) @@ -853,7 +869,7 @@ def get_request_body_api(): class Create(ApiMixin, serializers.Serializer): dataset_id = serializers.UUIDField(required=True, error_messages=ErrMessage.char( - _('document id'))) + _('dataset id'))) def is_valid(self, *, raise_exception=False): super().is_valid(raise_exception=True) @@ -981,7 +997,7 @@ def get_request_params_api(): in_=openapi.IN_PATH, type=openapi.TYPE_STRING, required=True, - description=_('document id')) + description=_('dataset id')) ] class Split(ApiMixin, serializers.Serializer): diff --git a/apps/dataset/serializers/file_serializers.py b/apps/dataset/serializers/file_serializers.py index 37f72fc8429..899c8a088de 100644 --- a/apps/dataset/serializers/file_serializers.py +++ b/apps/dataset/serializers/file_serializers.py @@ -28,6 +28,9 @@ "woff2": "font/woff2", "jar": "application/java-archive", "war": "application/java-archive", "ear": "application/java-archive", "json": "application/json", "hqx": "application/mac-binhex40", "doc": "application/msword", "pdf": "application/pdf", "ps": "application/postscript", + "docx": "application/vnd.openxmlformats-officedocument.wordprocessingml.document", + "xlsx": "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet", + "pptx": "application/vnd.openxmlformats-officedocument.presentationml.presentation", "eps": "application/postscript", "ai": "application/postscript", "rtf": "application/rtf", "m3u8": "application/vnd.apple.mpegurl", "kml": "application/vnd.google-earth.kml+xml", "kmz": "application/vnd.google-earth.kmz", "xls": "application/vnd.ms-excel", @@ -87,4 +90,4 @@ def get(self, with_valid=True): 'Content-Disposition': 'attachment; filename="{}"'.format( file.file_name)}) return HttpResponse(file.get_byte(), status=200, - headers={'Content-Type': mime_types.get(file.file_name.split(".")[-1], 'text/plain')}) + headers={'Content-Type': mime_types.get(file_type, 'text/plain')}) diff --git a/apps/dataset/serializers/paragraph_serializers.py b/apps/dataset/serializers/paragraph_serializers.py index 3a63fd95cd0..9b6e096ba00 100644 --- a/apps/dataset/serializers/paragraph_serializers.py +++ b/apps/dataset/serializers/paragraph_serializers.py @@ -226,6 +226,14 @@ def is_valid(self, *, raise_exception=True): def association(self, with_valid=True, with_embedding=True): if with_valid: self.is_valid(raise_exception=True) + # 已关联则直接返回 + if QuerySet(ProblemParagraphMapping).filter( + dataset_id=self.data.get('dataset_id'), + document_id=self.data.get('document_id'), + paragraph_id=self.data.get('paragraph_id'), + problem_id=self.data.get('problem_id') + ).exists(): + return True problem = QuerySet(Problem).filter(id=self.data.get("problem_id")).first() problem_paragraph_mapping = ProblemParagraphMapping(id=uuid.uuid1(), document_id=self.data.get('document_id'), diff --git a/apps/dataset/sql/update_document_char_length.sql b/apps/dataset/sql/update_document_char_length.sql index 4a4060cd9d4..2781809b23d 100644 --- a/apps/dataset/sql/update_document_char_length.sql +++ b/apps/dataset/sql/update_document_char_length.sql @@ -2,6 +2,7 @@ UPDATE "document" SET "char_length" = ( SELECT CASE WHEN "sum" ( "char_length" ( "content" ) ) IS NULL THEN 0 ELSE "sum" ( "char_length" ( "content" ) ) - END FROM paragraph WHERE "document_id" = %s ) + END FROM paragraph WHERE "document_id" = %s ), + "update_time" = CURRENT_TIMESTAMP WHERE "id" = %s \ No newline at end of file diff --git a/apps/dataset/views/dataset.py b/apps/dataset/views/dataset.py index bbb9e033980..aeb1af28932 100644 --- a/apps/dataset/views/dataset.py +++ b/apps/dataset/views/dataset.py @@ -7,13 +7,13 @@ @desc: """ +from django.utils.translation import gettext_lazy as _ from drf_yasg.utils import swagger_auto_schema from rest_framework.decorators import action from rest_framework.parsers import MultiPartParser from rest_framework.views import APIView from rest_framework.views import Request -import dataset.models from common.auth import TokenAuth, has_permissions from common.constants.permission_constants import PermissionConstants, CompareConstants, Permission, Group, Operate, \ ViewPermission, RoleConstants @@ -25,7 +25,6 @@ from dataset.serializers.dataset_serializers import DataSetSerializers from dataset.views.common import get_dataset_operation_object from setting.serializers.provider_serializers import ModelSerializer -from django.utils.translation import gettext_lazy as _ class Dataset(APIView): @@ -141,21 +140,22 @@ def post(self, request: Request): class HitTest(APIView): authentication_classes = [TokenAuth] - @action(methods="GET", detail=False) + @action(methods="PUT", detail=False) @swagger_auto_schema(operation_summary=_('Hit test list'), operation_id=_('Hit test list'), - manual_parameters=CommonApi.HitTestApi.get_request_params_api(), + request_body=CommonApi.HitTestApi.get_request_body_api(), responses=result.get_api_array_response(CommonApi.HitTestApi.get_response_body_api()), tags=[_('Knowledge Base')]) @has_permissions(lambda r, keywords: Permission(group=Group.DATASET, operate=Operate.USE, dynamic_tag=keywords.get('dataset_id'))) - def get(self, request: Request, dataset_id: str): - return result.success( - DataSetSerializers.HitTest(data={'id': dataset_id, 'user_id': request.user.id, - "query_text": request.query_params.get("query_text"), - "top_number": request.query_params.get("top_number"), - 'similarity': request.query_params.get('similarity'), - 'search_mode': request.query_params.get('search_mode')}).hit_test( - )) + def put(self, request: Request, dataset_id: str): + return result.success(DataSetSerializers.HitTest(data={ + 'id': dataset_id, + 'user_id': request.user.id, + "query_text": request.data.get("query_text"), + "top_number": request.data.get("top_number"), + 'similarity': request.data.get('similarity'), + 'search_mode': request.data.get('search_mode')} + ).hit_test()) class Embedding(APIView): authentication_classes = [TokenAuth] diff --git a/apps/embedding/sql/blend_search.sql b/apps/embedding/sql/blend_search.sql index afb1f0040d1..c70e66464ee 100644 --- a/apps/embedding/sql/blend_search.sql +++ b/apps/embedding/sql/blend_search.sql @@ -5,15 +5,17 @@ SELECT FROM ( SELECT DISTINCT ON - ( "paragraph_id" ) ( similarity ),* , - similarity AS comprehensive_score + ( "paragraph_id" ) ( 1 - distince + ts_similarity ) as similarity, *, + (1 - distince + ts_similarity) AS comprehensive_score FROM ( SELECT *, - (( 1 - ( embedding.embedding <=> %s ) )+ts_rank_cd( embedding.search_vector, websearch_to_tsquery('simple', %s ), 32 )) AS similarity + (embedding.embedding::vector(%s) <=> %s) as distince, + (ts_rank_cd( embedding.search_vector, websearch_to_tsquery('simple', %s ), 32 )) AS ts_similarity FROM embedding ${embedding_query} + ORDER BY distince ) TEMP ORDER BY paragraph_id, diff --git a/apps/embedding/sql/embedding_search.sql b/apps/embedding/sql/embedding_search.sql index ce3d4a580d5..1b5689959b8 100644 --- a/apps/embedding/sql/embedding_search.sql +++ b/apps/embedding/sql/embedding_search.sql @@ -5,12 +5,12 @@ SELECT FROM ( SELECT DISTINCT ON - ("paragraph_id") ( similarity ),* ,similarity AS comprehensive_score + ("paragraph_id") ( 1 - distince ),* ,(1 - distince) AS comprehensive_score FROM - ( SELECT *, ( 1 - ( embedding.embedding <=> %s ) ) AS similarity FROM embedding ${embedding_query}) TEMP + ( SELECT *, ( embedding.embedding::vector(%s) <=> %s ) AS distince FROM embedding ${embedding_query} ORDER BY distince) TEMP ORDER BY paragraph_id, - similarity DESC + distince ) DISTINCT_TEMP WHERE comprehensive_score>%s ORDER BY comprehensive_score DESC diff --git a/apps/embedding/task/embedding.py b/apps/embedding/task/embedding.py index 3b26bd7a1db..48846750006 100644 --- a/apps/embedding/task/embedding.py +++ b/apps/embedding/task/embedding.py @@ -17,6 +17,7 @@ from common.event import ListenerManagement, UpdateProblemArgs, UpdateEmbeddingDatasetIdArgs, \ UpdateEmbeddingDocumentIdArgs from dataset.models import Document, TaskType, State +from dataset.serializers.common_serializers import drop_dataset_index from ops import celery_app from setting.models import Model from setting.models_provider import get_model @@ -110,6 +111,7 @@ def embedding_by_dataset(dataset_id, model_id): max_kb.info(_('Start--->Vectorized dataset: {dataset_id}').format(dataset_id=dataset_id)) try: ListenerManagement.delete_embedding_by_dataset(dataset_id) + drop_dataset_index(dataset_id=dataset_id) document_list = QuerySet(Document).filter(dataset_id=dataset_id) max_kb.info(_('Dataset documentation: {document_names}').format( document_names=", ".join([d.name for d in document_list]))) diff --git a/apps/embedding/vector/pg_vector.py b/apps/embedding/vector/pg_vector.py index 7929685a37c..af9ff7e4ca3 100644 --- a/apps/embedding/vector/pg_vector.py +++ b/apps/embedding/vector/pg_vector.py @@ -12,7 +12,6 @@ from abc import ABC, abstractmethod from typing import Dict, List -import jieba from django.contrib.postgres.search import SearchVector from django.db.models import QuerySet, Value from langchain_core.embeddings import Embeddings @@ -169,8 +168,13 @@ def handle(self, os.path.join(PROJECT_DIR, "apps", "embedding", 'sql', 'embedding_search.sql')), with_table_name=True) - embedding_model = select_list(exec_sql, - [json.dumps(query_embedding), *exec_params, similarity, top_number]) + embedding_model = select_list(exec_sql, [ + len(query_embedding), + json.dumps(query_embedding), + *exec_params, + similarity, + top_number + ]) return embedding_model def support(self, search_mode: SearchMode): @@ -190,8 +194,12 @@ def handle(self, os.path.join(PROJECT_DIR, "apps", "embedding", 'sql', 'keywords_search.sql')), with_table_name=True) - embedding_model = select_list(exec_sql, - [to_query(query_text), *exec_params, similarity, top_number]) + embedding_model = select_list(exec_sql, [ + to_query(query_text), + *exec_params, + similarity, + top_number + ]) return embedding_model def support(self, search_mode: SearchMode): @@ -211,9 +219,14 @@ def handle(self, os.path.join(PROJECT_DIR, "apps", "embedding", 'sql', 'blend_search.sql')), with_table_name=True) - embedding_model = select_list(exec_sql, - [json.dumps(query_embedding), to_query(query_text), *exec_params, similarity, - top_number]) + embedding_model = select_list(exec_sql, [ + len(query_embedding), + json.dumps(query_embedding), + to_query(query_text), + *exec_params, + similarity, + top_number + ]) return embedding_model def support(self, search_mode: SearchMode): diff --git a/apps/function_lib/migrations/0004_functionlib_decimal_date.py b/apps/function_lib/migrations/0004_functionlib_decimal_date.py new file mode 100644 index 00000000000..82e4a6d029a --- /dev/null +++ b/apps/function_lib/migrations/0004_functionlib_decimal_date.py @@ -0,0 +1,127 @@ +# Generated by Django 4.2.15 on 2025-03-13 07:21 + +from django.db import migrations +from django.db.models import Q + +mysql_template = """ +def query_mysql(host,port, user, password, database, sql): + import pymysql + import json + from pymysql.cursors import DictCursor + from datetime import datetime, date + + def default_serializer(obj): + from decimal import Decimal + if isinstance(obj, (datetime, date)): + return obj.isoformat() # 将 datetime/date 转换为 ISO 格式字符串 + elif isinstance(obj, Decimal): + return float(obj) # 将 Decimal 转换为 float + raise TypeError(f"Type {type(obj)} not serializable") + + try: + # 创建连接 + db = pymysql.connect( + host=host, + port=int(port), + user=user, + password=password, + database=database, + cursorclass=DictCursor # 使用字典游标 + ) + + # 使用 cursor() 方法创建一个游标对象 cursor + cursor = db.cursor() + + # 使用 execute() 方法执行 SQL 查询 + cursor.execute(sql) + + # 使用 fetchall() 方法获取所有数据 + data = cursor.fetchall() + + # 处理 bytes 类型的数据 + for row in data: + for key, value in row.items(): + if isinstance(value, bytes): + row[key] = value.decode("utf-8") # 转换为字符串 + + # 将数据序列化为 JSON + json_data = json.dumps(data, default=default_serializer, ensure_ascii=False) + return json_data + + # 关闭数据库连接 + db.close() + + except Exception as e: + print(f"Error while connecting to MySQL: {e}") + raise e +""" + +pgsql_template = """ +def queryPgSQL(database, user, password, host, port, query): + import psycopg2 + import json + from datetime import datetime + + # 自定义 JSON 序列化函数 + def default_serializer(obj): + from decimal import Decimal + if isinstance(obj, datetime): + return obj.isoformat() # 将 datetime 转换为 ISO 格式字符串 + elif isinstance(obj, Decimal): + return float(obj) # 将 Decimal 转换为 float + raise TypeError(f"Type {type(obj)} not serializable") + + # 数据库连接信息 + conn_params = { + "dbname": database, + "user": user, + "password": password, + "host": host, + "port": port + } + try: + # 建立连接 + conn = psycopg2.connect(**conn_params) + print("连接成功!") + # 创建游标对象 + cursor = conn.cursor() + # 执行查询语句 + cursor.execute(query) + # 获取查询结果 + rows = cursor.fetchall() + # 处理 bytes 类型的数据 + columns = [desc[0] for desc in cursor.description] + result = [dict(zip(columns, row)) for row in rows] + # 转换为 JSON 格式 + json_result = json.dumps(result, default=default_serializer, ensure_ascii=False) + return json_result + except Exception as e: + print(f"发生错误:{e}") + raise e + finally: + # 关闭游标和连接 + if cursor: + cursor.close() + if conn: + conn.close() +""" + + +def fix_type(apps, schema_editor): + FunctionLib = apps.get_model('function_lib', 'FunctionLib') + FunctionLib.objects.filter( + Q(id='22c21b76-0308-11f0-9694-5618c4394482') | Q(template_id='22c21b76-0308-11f0-9694-5618c4394482') + ).update(code=mysql_template) + FunctionLib.objects.filter( + Q(id='bd1e8b88-0302-11f0-87bb-5618c4394482') | Q(template_id='bd1e8b88-0302-11f0-87bb-5618c4394482') + ).update(code=pgsql_template) + + +class Migration(migrations.Migration): + dependencies = [ + ('function_lib', '0003_functionlib_function_type_functionlib_icon_and_more'), + ] + + operations = [ + migrations.RunPython(fix_type) + ] diff --git a/apps/function_lib/serializers/function_lib_serializer.py b/apps/function_lib/serializers/function_lib_serializer.py index 440eb22c786..ad7ff3cce61 100644 --- a/apps/function_lib/serializers/function_lib_serializer.py +++ b/apps/function_lib/serializers/function_lib_serializer.py @@ -33,11 +33,13 @@ function_executor = FunctionExecutor(CONFIG.get('SANDBOX')) + class FlibInstance: def __init__(self, function_lib: dict, version: str): self.function_lib = function_lib self.version = version + def encryption(message: str): """ 加密敏感字段数据 加密方式是 如果密码是 1234567890 那么给前端则是 123******890 @@ -68,7 +70,8 @@ def encryption(message: str): class FunctionLibModelSerializer(serializers.ModelSerializer): class Meta: model = FunctionLib - fields = ['id', 'name', 'icon', 'desc', 'code', 'input_field_list','init_field_list', 'init_params', 'permission_type', 'is_active', 'user_id', 'template_id', + fields = ['id', 'name', 'icon', 'desc', 'code', 'input_field_list', 'init_field_list', 'init_params', + 'permission_type', 'is_active', 'user_id', 'template_id', 'create_time', 'update_time'] @@ -148,7 +151,6 @@ class Query(serializers.Serializer): select_user_id = serializers.CharField(required=False, allow_null=True, allow_blank=True) function_type = serializers.CharField(required=False, allow_null=True, allow_blank=True) - def get_query_set(self): query_set = QuerySet(FunctionLib).filter( (Q(user_id=self.data.get('user_id')) | Q(permission_type='PUBLIC'))) @@ -269,7 +271,7 @@ class Operate(serializers.Serializer): def is_valid(self, *, raise_exception=False): super().is_valid(raise_exception=True) - if not QuerySet(FunctionLib).filter(id=self.data.get('id')).exists(): + if not QuerySet(FunctionLib).filter(user_id=self.data.get('user_id'), id=self.data.get('id')).exists(): raise AppApiException(500, _('Function does not exist')) def delete(self, with_valid=True): @@ -285,7 +287,8 @@ def edit(self, instance, with_valid=True): if with_valid: self.is_valid(raise_exception=True) EditFunctionLib(data=instance).is_valid(raise_exception=True) - edit_field_list = ['name', 'desc', 'code', 'icon', 'input_field_list', 'init_field_list', 'init_params', 'permission_type', 'is_active'] + edit_field_list = ['name', 'desc', 'code', 'icon', 'input_field_list', 'init_field_list', 'init_params', + 'permission_type', 'is_active'] edit_dict = {field: instance.get(field) for field in edit_field_list if ( field in instance and instance.get(field) is not None)} @@ -317,7 +320,8 @@ def one(self, with_valid=True): if function_lib.init_params: function_lib.init_params = json.loads(rsa_long_decrypt(function_lib.init_params)) if function_lib.init_field_list: - password_fields = [i["field"] for i in function_lib.init_field_list if i.get("input_type") == "PasswordInput"] + password_fields = [i["field"] for i in function_lib.init_field_list if + i.get("input_type") == "PasswordInput"] if function_lib.init_params: for k in function_lib.init_params: if k in password_fields and function_lib.init_params[k]: diff --git a/apps/locales/en_US/LC_MESSAGES/django.po b/apps/locales/en_US/LC_MESSAGES/django.po index d13912928b9..9b83be9686d 100644 --- a/apps/locales/en_US/LC_MESSAGES/django.po +++ b/apps/locales/en_US/LC_MESSAGES/django.po @@ -7238,7 +7238,7 @@ msgstr "" msgid "" "The confirmation password must be 6-20 characters long and must be a " "combination of letters, numbers, and special characters." -msgstr "" +msgstr "The confirmation password must be 6-20 characters long and must be a combination of letters, numbers, and special characters.(Special character support:_、!、@、#、$、(、) ……)" #: community/apps/users/serializers/user_serializers.py:380 #, python-brace-format @@ -7490,4 +7490,22 @@ msgid "Field: {name} No value set" msgstr "" msgid "Generate related" +msgstr "" + +msgid "Obtain graphical captcha" +msgstr "" + +msgid "Captcha code error or expiration" +msgstr "" + +msgid "captcha" +msgstr "" + +msgid "Reasoning enable" +msgstr "" + +msgid "Reasoning start tag" +msgstr "" + +msgid "Reasoning end tag" msgstr "" \ No newline at end of file diff --git a/apps/locales/zh_CN/LC_MESSAGES/django.po b/apps/locales/zh_CN/LC_MESSAGES/django.po index b0ab7871bf6..9500103c702 100644 --- a/apps/locales/zh_CN/LC_MESSAGES/django.po +++ b/apps/locales/zh_CN/LC_MESSAGES/django.po @@ -4536,7 +4536,7 @@ msgstr "修改知识库信息" #: community/apps/dataset/views/document.py:463 #: community/apps/dataset/views/document.py:464 msgid "Get the knowledge base paginated list" -msgstr "获取知识库分页列表" +msgstr "获取知识库文档分页列表" #: community/apps/dataset/views/document.py:31 #: community/apps/dataset/views/document.py:32 @@ -7395,7 +7395,7 @@ msgstr "语言只支持:" msgid "" "The confirmation password must be 6-20 characters long and must be a " "combination of letters, numbers, and special characters." -msgstr "确认密码长度6-20个字符,必须字母、数字、特殊字符组合" +msgstr "确认密码长度6-20个字符,必须字母、数字、特殊字符组合(特殊字符支持:_、!、@、#、$、(、) ……)" #: community/apps/users/serializers/user_serializers.py:380 #, python-brace-format @@ -7653,4 +7653,22 @@ msgid "Field: {name} No value set" msgstr "字段: {name} 未设置值" msgid "Generate related" -msgstr "生成问题" \ No newline at end of file +msgstr "生成问题" + +msgid "Obtain graphical captcha" +msgstr "获取图形验证码" + +msgid "Captcha code error or expiration" +msgstr "验证码错误或过期" + +msgid "captcha" +msgstr "验证码" + +msgid "Reasoning enable" +msgstr "开启思考过程" + +msgid "Reasoning start tag" +msgstr "思考过程开始标签" + +msgid "Reasoning end tag" +msgstr "思考过程结束标签" \ No newline at end of file diff --git a/apps/locales/zh_Hant/LC_MESSAGES/django.po b/apps/locales/zh_Hant/LC_MESSAGES/django.po index dab1d176c26..ab471689f6b 100644 --- a/apps/locales/zh_Hant/LC_MESSAGES/django.po +++ b/apps/locales/zh_Hant/LC_MESSAGES/django.po @@ -4545,7 +4545,7 @@ msgstr "修改知識庫信息" #: community/apps/dataset/views/document.py:463 #: community/apps/dataset/views/document.py:464 msgid "Get the knowledge base paginated list" -msgstr "獲取知識庫分頁列表" +msgstr "獲取知識庫文档分頁列表" #: community/apps/dataset/views/document.py:31 #: community/apps/dataset/views/document.py:32 @@ -5054,7 +5054,7 @@ msgstr "語音合成" #: community/apps/setting/models_provider/base_model_provider.py:150 msgid "Vision Model" -msgstr "圖片理解" +msgstr "視覺模型" #: community/apps/setting/models_provider/base_model_provider.py:151 msgid "Image Generation" @@ -7405,7 +7405,7 @@ msgstr "語言只支持:" msgid "" "The confirmation password must be 6-20 characters long and must be a " "combination of letters, numbers, and special characters." -msgstr "確認密碼長度6-20個字符,必須字母、數字、特殊字符組合" +msgstr "確認密碼長度6-20個字符,必須字母、數字、特殊字符組合(特殊字元支持:_、!、@、#、$、(、) ……)" #: community/apps/users/serializers/user_serializers.py:380 #, python-brace-format @@ -7663,4 +7663,22 @@ msgid "Field: {name} No value set" msgstr "欄位: {name} 未設定值" msgid "Generate related" -msgstr "生成問題" \ No newline at end of file +msgstr "生成問題" + +msgid "Obtain graphical captcha" +msgstr "獲取圖形驗證碼" + +msgid "Captcha code error or expiration" +msgstr "驗證碼錯誤或過期" + +msgid "captcha" +msgstr "驗證碼" + +msgid "Reasoning enable" +msgstr "開啟思考過程" + +msgid "Reasoning start tag" +msgstr "思考過程開始標籤" + +msgid "Reasoning end tag" +msgstr "思考過程結束標籤" \ No newline at end of file diff --git a/apps/ops/celery/signal_handler.py b/apps/ops/celery/signal_handler.py index 46671a0d8fa..bfded3e4027 100644 --- a/apps/ops/celery/signal_handler.py +++ b/apps/ops/celery/signal_handler.py @@ -2,7 +2,7 @@ # import logging import os - +from common.init import init_template from celery import subtask from celery.signals import ( worker_ready, worker_shutdown, after_setup_logger, task_revoked, task_prerun @@ -31,6 +31,7 @@ def on_app_ready(sender=None, headers=None, **kwargs): logger.debug("Periodic task [{}] is disabled!".format(task)) continue subtask(task).delay() + init_template.run() def delete_files(directory): diff --git a/apps/setting/migrations/0011_refresh_collation_reindex.py b/apps/setting/migrations/0011_refresh_collation_reindex.py new file mode 100644 index 00000000000..0f93d4ad481 --- /dev/null +++ b/apps/setting/migrations/0011_refresh_collation_reindex.py @@ -0,0 +1,61 @@ +import logging + +import psycopg +from django.db import migrations + +from smartdoc.const import CONFIG + + +def get_connect(db_name): + conn_params = { + "dbname": db_name, + "user": CONFIG.get('DB_USER'), + "password": CONFIG.get('DB_PASSWORD'), + "host": CONFIG.get('DB_HOST'), + "port": CONFIG.get('DB_PORT') + } + # 建立连接 + connect = psycopg.connect(**conn_params) + return connect + + +def sql_execute(conn, reindex_sql: str, alter_database_sql: str): + """ + 执行一条sql + @param reindex_sql: + @param conn: + @param alter_database_sql: + """ + conn.autocommit = True + with conn.cursor() as cursor: + cursor.execute(reindex_sql, []) + cursor.execute(alter_database_sql, []) + cursor.close() + +def re_index(apps, schema_editor): + app_db_name = CONFIG.get('DB_NAME') + try: + re_index_database(app_db_name) + except Exception as e: + logging.error(f'reindex database {app_db_name}发送错误:{str(e)}') + try: + re_index_database('root') + except Exception as e: + logging.error(f'reindex database root 发送错误:{str(e)}') + + +def re_index_database(db_name): + db_conn = get_connect(db_name) + sql_execute(db_conn, f'REINDEX DATABASE "{db_name}";', f'ALTER DATABASE "{db_name}" REFRESH COLLATION VERSION;') + db_conn.close() + + +class Migration(migrations.Migration): + + dependencies = [ + ('setting', '0010_log'), + ] + + operations = [ + migrations.RunPython(re_index, atomic=False) + ] diff --git a/apps/setting/models_provider/base_model_provider.py b/apps/setting/models_provider/base_model_provider.py index 622be703dad..2b02bdc1fb1 100644 --- a/apps/setting/models_provider/base_model_provider.py +++ b/apps/setting/models_provider/base_model_provider.py @@ -106,7 +106,10 @@ def filter_optional_params(model_kwargs): optional_params = {} for key, value in model_kwargs.items(): if key not in ['model_id', 'use_local', 'streaming', 'show_ref_label']: - optional_params[key] = value + if key == 'extra_body' and isinstance(value, dict): + optional_params = {**optional_params, **value} + else: + optional_params[key] = value return optional_params diff --git a/apps/setting/models_provider/constants/model_provider_constants.py b/apps/setting/models_provider/constants/model_provider_constants.py index e6bf698b01a..e68b9361f0b 100644 --- a/apps/setting/models_provider/constants/model_provider_constants.py +++ b/apps/setting/models_provider/constants/model_provider_constants.py @@ -19,6 +19,8 @@ from setting.models_provider.impl.ollama_model_provider.ollama_model_provider import OllamaModelProvider from setting.models_provider.impl.openai_model_provider.openai_model_provider import OpenAIModelProvider from setting.models_provider.impl.qwen_model_provider.qwen_model_provider import QwenModelProvider +from setting.models_provider.impl.regolo_model_provider.regolo_model_provider import \ + RegoloModelProvider from setting.models_provider.impl.siliconCloud_model_provider.siliconCloud_model_provider import \ SiliconCloudModelProvider from setting.models_provider.impl.tencent_cloud_model_provider.tencent_cloud_model_provider import \ @@ -55,3 +57,4 @@ class ModelProvideConstants(Enum): aliyun_bai_lian_model_provider = AliyunBaiLianModelProvider() model_anthropic_provider = AnthropicModelProvider() model_siliconCloud_provider = SiliconCloudModelProvider() + model_regolo_provider = RegoloModelProvider() diff --git a/apps/setting/models_provider/impl/aliyun_bai_lian_model_provider/aliyun_bai_lian_model_provider.py b/apps/setting/models_provider/impl/aliyun_bai_lian_model_provider/aliyun_bai_lian_model_provider.py index 8c5031f08f2..b1d72f0869a 100644 --- a/apps/setting/models_provider/impl/aliyun_bai_lian_model_provider/aliyun_bai_lian_model_provider.py +++ b/apps/setting/models_provider/impl/aliyun_bai_lian_model_provider/aliyun_bai_lian_model_provider.py @@ -51,6 +51,23 @@ _("Universal text vector is Tongyi Lab's multi-language text unified vector model based on the LLM base. It provides high-level vector services for multiple mainstream languages around the world and helps developers quickly convert text data into high-quality vector data."), ModelTypeConst.EMBEDDING, aliyun_bai_lian_embedding_model_credential, AliyunBaiLianEmbedding), + ModelInfo('qwen3-0.6b', '', ModelTypeConst.LLM, aliyun_bai_lian_llm_model_credential, + BaiLianChatModel), + ModelInfo('qwen3-1.7b', '', ModelTypeConst.LLM, aliyun_bai_lian_llm_model_credential, + BaiLianChatModel), + ModelInfo('qwen3-4b', '', ModelTypeConst.LLM, aliyun_bai_lian_llm_model_credential, + BaiLianChatModel), + ModelInfo('qwen3-8b', '', ModelTypeConst.LLM, aliyun_bai_lian_llm_model_credential, + BaiLianChatModel), + ModelInfo('qwen3-14b', '', ModelTypeConst.LLM, aliyun_bai_lian_llm_model_credential, + BaiLianChatModel), + ModelInfo('qwen3-32b', '', ModelTypeConst.LLM, aliyun_bai_lian_llm_model_credential, + BaiLianChatModel), + ModelInfo('qwen3-30b-a3b', '', ModelTypeConst.LLM, aliyun_bai_lian_llm_model_credential, + BaiLianChatModel), + ModelInfo('qwen3-235b-a22b', '', ModelTypeConst.LLM, aliyun_bai_lian_llm_model_credential, + BaiLianChatModel), + ModelInfo('qwen-turbo', '', ModelTypeConst.LLM, aliyun_bai_lian_llm_model_credential, BaiLianChatModel), ModelInfo('qwen-plus', '', ModelTypeConst.LLM, aliyun_bai_lian_llm_model_credential, diff --git a/apps/setting/models_provider/impl/aliyun_bai_lian_model_provider/credential/llm.py b/apps/setting/models_provider/impl/aliyun_bai_lian_model_provider/credential/llm.py index f316a0c6d1c..9da30b72796 100644 --- a/apps/setting/models_provider/impl/aliyun_bai_lian_model_provider/credential/llm.py +++ b/apps/setting/models_provider/impl/aliyun_bai_lian_model_provider/credential/llm.py @@ -30,6 +30,29 @@ class BaiLianLLMModelParams(BaseForm): precision=0) +class BaiLianLLMStreamModelParams(BaseForm): + temperature = forms.SliderField(TooltipLabel(_('Temperature'), + _('Higher values make the output more random, while lower values make it more focused and deterministic')), + required=True, default_value=0.7, + _min=0.1, + _max=1.0, + _step=0.01, + precision=2) + + max_tokens = forms.SliderField( + TooltipLabel(_('Output the maximum Tokens'), + _('Specify the maximum number of tokens that the model can generate')), + required=True, default_value=800, + _min=1, + _max=100000, + _step=1, + precision=0) + + stream = forms.SwitchField(label=TooltipLabel(_('Is the answer in streaming mode'), + _('Is the answer in streaming mode')), + required=True, default_value=True) + + class BaiLianLLMModelCredential(BaseForm, BaseModelCredential): def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], model_params, provider, @@ -47,7 +70,11 @@ def is_valid(self, model_type: str, model_name, model_credential: Dict[str, obje return False try: model = provider.get_model(model_type, model_name, model_credential, **model_params) - model.invoke([HumanMessage(content=gettext('Hello'))]) + if model_params.get('stream'): + for res in model.stream([HumanMessage(content=gettext('Hello'))]): + pass + else: + model.invoke([HumanMessage(content=gettext('Hello'))]) except Exception as e: traceback.print_exc() if isinstance(e, AppApiException): @@ -68,4 +95,6 @@ def encryption_dict(self, model: Dict[str, object]): api_key = forms.PasswordInputField('API Key', required=True) def get_model_params_setting_form(self, model_name): + if 'qwen3' in model_name: + return BaiLianLLMStreamModelParams() return BaiLianLLMModelParams() diff --git a/apps/setting/models_provider/impl/aliyun_bai_lian_model_provider/model/image.py b/apps/setting/models_provider/impl/aliyun_bai_lian_model_provider/model/image.py index 2b1fe31f228..7cda97f2388 100644 --- a/apps/setting/models_provider/impl/aliyun_bai_lian_model_provider/model/image.py +++ b/apps/setting/models_provider/impl/aliyun_bai_lian_model_provider/model/image.py @@ -15,9 +15,8 @@ def new_instance(model_type, model_name, model_credential: Dict[str, object], ** model_name=model_name, openai_api_key=model_credential.get('api_key'), openai_api_base='https://dashscope.aliyuncs.com/compatible-mode/v1', - # stream_options={"include_usage": True}, streaming=True, stream_usage=True, - **optional_params, + extra_body=optional_params ) return chat_tong_yi diff --git a/apps/setting/models_provider/impl/aliyun_bai_lian_model_provider/model/llm.py b/apps/setting/models_provider/impl/aliyun_bai_lian_model_provider/model/llm.py index d914f7c8ad6..ee3ee6488c2 100644 --- a/apps/setting/models_provider/impl/aliyun_bai_lian_model_provider/model/llm.py +++ b/apps/setting/models_provider/impl/aliyun_bai_lian_model_provider/model/llm.py @@ -20,5 +20,5 @@ def new_instance(model_type, model_name, model_credential: Dict[str, object], ** model=model_name, openai_api_base=model_credential.get('api_base'), openai_api_key=model_credential.get('api_key'), - **optional_params + extra_body=optional_params ) diff --git a/apps/setting/models_provider/impl/aws_bedrock_model_provider/model/llm.py b/apps/setting/models_provider/impl/aws_bedrock_model_provider/model/llm.py index ef1c133378e..7b0088a4ab4 100644 --- a/apps/setting/models_provider/impl/aws_bedrock_model_provider/model/llm.py +++ b/apps/setting/models_provider/impl/aws_bedrock_model_provider/model/llm.py @@ -1,10 +1,12 @@ import os import re -from typing import Dict +from typing import Dict, List from botocore.config import Config from langchain_community.chat_models import BedrockChat +from langchain_core.messages import BaseMessage, get_buffer_string +from common.config.tokenizer_manage_config import TokenizerManage from setting.models_provider.base_model_provider import MaxKBBaseModel @@ -72,6 +74,20 @@ def new_instance(cls, model_type: str, model_name: str, model_credential: Dict[s config=config ) + def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int: + try: + return super().get_num_tokens_from_messages(messages) + except Exception as e: + tokenizer = TokenizerManage.get_tokenizer() + return sum([len(tokenizer.encode(get_buffer_string([m]))) for m in messages]) + + def get_num_tokens(self, text: str) -> int: + try: + return super().get_num_tokens(text) + except Exception as e: + tokenizer = TokenizerManage.get_tokenizer() + return len(tokenizer.encode(text)) + def _update_aws_credentials(profile_name, access_key_id, secret_access_key): credentials_path = os.path.join(os.path.expanduser("~"), ".aws", "credentials") diff --git a/apps/setting/models_provider/impl/base_chat_open_ai.py b/apps/setting/models_provider/impl/base_chat_open_ai.py index 54076b7efda..626a751f740 100644 --- a/apps/setting/models_provider/impl/base_chat_open_ai.py +++ b/apps/setting/models_provider/impl/base_chat_open_ai.py @@ -1,15 +1,16 @@ # coding=utf-8 -import warnings -from typing import List, Dict, Optional, Any, Iterator, cast, Type, Union +from typing import Dict, Optional, Any, Iterator, cast, Union, Sequence, Callable, Mapping -import openai -from langchain_core.callbacks import CallbackManagerForLLMRun from langchain_core.language_models import LanguageModelInput -from langchain_core.messages import BaseMessage, get_buffer_string, BaseMessageChunk, AIMessageChunk -from langchain_core.outputs import ChatGenerationChunk, ChatGeneration +from langchain_core.messages import BaseMessage, get_buffer_string, BaseMessageChunk, HumanMessageChunk, AIMessageChunk, \ + SystemMessageChunk, FunctionMessageChunk, ChatMessageChunk +from langchain_core.messages.ai import UsageMetadata +from langchain_core.messages.tool import tool_call_chunk, ToolMessageChunk +from langchain_core.outputs import ChatGenerationChunk from langchain_core.runnables import RunnableConfig, ensure_config -from langchain_core.utils.pydantic import is_basemodel_subclass +from langchain_core.tools import BaseTool from langchain_openai import ChatOpenAI +from langchain_openai.chat_models.base import _create_usage_metadata from common.config.tokenizer_manage_config import TokenizerManage @@ -19,6 +20,65 @@ def custom_get_token_ids(text: str): return tokenizer.encode(text) +def _convert_delta_to_message_chunk( + _dict: Mapping[str, Any], default_class: type[BaseMessageChunk] +) -> BaseMessageChunk: + id_ = _dict.get("id") + role = cast(str, _dict.get("role")) + content = cast(str, _dict.get("content") or "") + additional_kwargs: dict = {} + if 'reasoning_content' in _dict: + additional_kwargs['reasoning_content'] = _dict.get('reasoning_content') + if _dict.get("function_call"): + function_call = dict(_dict["function_call"]) + if "name" in function_call and function_call["name"] is None: + function_call["name"] = "" + additional_kwargs["function_call"] = function_call + tool_call_chunks = [] + if raw_tool_calls := _dict.get("tool_calls"): + additional_kwargs["tool_calls"] = raw_tool_calls + try: + tool_call_chunks = [ + tool_call_chunk( + name=rtc["function"].get("name"), + args=rtc["function"].get("arguments"), + id=rtc.get("id"), + index=rtc["index"], + ) + for rtc in raw_tool_calls + ] + except KeyError: + pass + + if role == "user" or default_class == HumanMessageChunk: + return HumanMessageChunk(content=content, id=id_) + elif role == "assistant" or default_class == AIMessageChunk: + return AIMessageChunk( + content=content, + additional_kwargs=additional_kwargs, + id=id_, + tool_call_chunks=tool_call_chunks, # type: ignore[arg-type] + ) + elif role in ("system", "developer") or default_class == SystemMessageChunk: + if role == "developer": + additional_kwargs = {"__openai_role__": "developer"} + else: + additional_kwargs = {} + return SystemMessageChunk( + content=content, id=id_, additional_kwargs=additional_kwargs + ) + elif role == "function" or default_class == FunctionMessageChunk: + return FunctionMessageChunk(content=content, name=_dict["name"], id=id_) + elif role == "tool" or default_class == ToolMessageChunk: + return ToolMessageChunk( + content=content, tool_call_id=_dict["tool_call_id"], id=id_ + ) + elif role or default_class == ChatMessageChunk: + return ChatMessageChunk(content=content, role=role, id=id_) + else: + return default_class(content=content, id=id_) # type: ignore + + class BaseChatOpenAI(ChatOpenAI): usage_metadata: dict = {} custom_get_token_ids = custom_get_token_ids @@ -26,14 +86,20 @@ class BaseChatOpenAI(ChatOpenAI): def get_last_generation_info(self) -> Optional[Dict[str, Any]]: return self.usage_metadata - def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int: + def get_num_tokens_from_messages( + self, + messages: list[BaseMessage], + tools: Optional[ + Sequence[Union[dict[str, Any], type, Callable, BaseTool]] + ] = None, + ) -> int: if self.usage_metadata is None or self.usage_metadata == {}: try: return super().get_num_tokens_from_messages(messages) except Exception as e: tokenizer = TokenizerManage.get_tokenizer() return sum([len(tokenizer.encode(get_buffer_string([m]))) for m in messages]) - return self.usage_metadata.get('input_tokens', 0) + return self.usage_metadata.get('input_tokens', self.usage_metadata.get('prompt_tokens', 0)) def get_num_tokens(self, text: str) -> int: if self.usage_metadata is None or self.usage_metadata == {}: @@ -42,116 +108,80 @@ def get_num_tokens(self, text: str) -> int: except Exception as e: tokenizer = TokenizerManage.get_tokenizer() return len(tokenizer.encode(text)) - return self.get_last_generation_info().get('output_tokens', 0) + return self.get_last_generation_info().get('output_tokens', + self.get_last_generation_info().get('completion_tokens', 0)) + + def _stream(self, *args: Any, **kwargs: Any) -> Iterator[ChatGenerationChunk]: + kwargs['stream_usage'] = True + for chunk in super()._stream(*args, **kwargs): + if chunk.message.usage_metadata is not None: + self.usage_metadata = chunk.message.usage_metadata + yield chunk - def _stream( + def _convert_chunk_to_generation_chunk( self, - messages: List[BaseMessage], - stop: Optional[List[str]] = None, - run_manager: Optional[CallbackManagerForLLMRun] = None, - **kwargs: Any, - ) -> Iterator[ChatGenerationChunk]: - kwargs["stream"] = True - kwargs["stream_options"] = {"include_usage": True} - """Set default stream_options.""" - stream_usage = self._should_stream_usage(kwargs.get('stream_usage'), **kwargs) - # Note: stream_options is not a valid parameter for Azure OpenAI. - # To support users proxying Azure through ChatOpenAI, here we only specify - # stream_options if include_usage is set to True. - # See https://learn.microsoft.com/en-us/azure/ai-services/openai/whats-new - # for release notes. - if stream_usage: - kwargs["stream_options"] = {"include_usage": stream_usage} - - payload = self._get_request_payload(messages, stop=stop, **kwargs) - default_chunk_class: Type[BaseMessageChunk] = AIMessageChunk - base_generation_info = {} - - if "response_format" in payload and is_basemodel_subclass( - payload["response_format"] - ): - # TODO: Add support for streaming with Pydantic response_format. - warnings.warn("Streaming with Pydantic response_format not yet supported.") - chat_result = self._generate( - messages, stop, run_manager=run_manager, **kwargs - ) - msg = chat_result.generations[0].message - yield ChatGenerationChunk( - message=AIMessageChunk( - **msg.dict(exclude={"type", "additional_kwargs"}), - # preserve the "parsed" Pydantic object without converting to dict - additional_kwargs=msg.additional_kwargs, - ), - generation_info=chat_result.generations[0].generation_info, + chunk: dict, + default_chunk_class: type, + base_generation_info: Optional[dict], + ) -> Optional[ChatGenerationChunk]: + if chunk.get("type") == "content.delta": # from beta.chat.completions.stream + return None + token_usage = chunk.get("usage") + choices = ( + chunk.get("choices", []) + # from beta.chat.completions.stream + or chunk.get("chunk", {}).get("choices", []) + ) + + usage_metadata: Optional[UsageMetadata] = ( + _create_usage_metadata(token_usage) if token_usage and token_usage.get("prompt_tokens") else None + ) + if len(choices) == 0: + # logprobs is implicitly None + generation_chunk = ChatGenerationChunk( + message=default_chunk_class(content="", usage_metadata=usage_metadata) ) - return - if self.include_response_headers: - raw_response = self.client.with_raw_response.create(**payload) - response = raw_response.parse() - base_generation_info = {"headers": dict(raw_response.headers)} - else: - response = self.client.create(**payload) - with response: - is_first_chunk = True - for chunk in response: - if not isinstance(chunk, dict): - chunk = chunk.model_dump() - - generation_chunk = super()._convert_chunk_to_generation_chunk( - chunk, - default_chunk_class, - base_generation_info if is_first_chunk else {}, - ) - if generation_chunk is None: - continue - - # custom code - if len(chunk['choices']) > 0 and 'reasoning_content' in chunk['choices'][0]['delta']: - generation_chunk.message.additional_kwargs["reasoning_content"] = chunk['choices'][0]['delta'][ - 'reasoning_content'] - - default_chunk_class = generation_chunk.message.__class__ - logprobs = (generation_chunk.generation_info or {}).get("logprobs") - if run_manager: - run_manager.on_llm_new_token( - generation_chunk.text, chunk=generation_chunk, logprobs=logprobs - ) - is_first_chunk = False - # custom code - if generation_chunk.message.usage_metadata is not None: - self.usage_metadata = generation_chunk.message.usage_metadata - yield generation_chunk - - def _create_chat_result(self, - response: Union[dict, openai.BaseModel], - generation_info: Optional[Dict] = None): - result = super()._create_chat_result(response, generation_info) - try: - reasoning_content = '' - reasoning_content_enable = False - for res in response.choices: - if 'reasoning_content' in res.message.model_extra: - reasoning_content_enable = True - _reasoning_content = res.message.model_extra.get('reasoning_content') - if _reasoning_content is not None: - reasoning_content += _reasoning_content - if reasoning_content_enable: - result.llm_output['reasoning_content'] = reasoning_content - except Exception as e: - pass - return result + return generation_chunk + + choice = choices[0] + if choice["delta"] is None: + return None + + message_chunk = _convert_delta_to_message_chunk( + choice["delta"], default_chunk_class + ) + generation_info = {**base_generation_info} if base_generation_info else {} + + if finish_reason := choice.get("finish_reason"): + generation_info["finish_reason"] = finish_reason + if model_name := chunk.get("model"): + generation_info["model_name"] = model_name + if system_fingerprint := chunk.get("system_fingerprint"): + generation_info["system_fingerprint"] = system_fingerprint + + logprobs = choice.get("logprobs") + if logprobs: + generation_info["logprobs"] = logprobs + + if usage_metadata and isinstance(message_chunk, AIMessageChunk): + message_chunk.usage_metadata = usage_metadata + + generation_chunk = ChatGenerationChunk( + message=message_chunk, generation_info=generation_info or None + ) + return generation_chunk def invoke( self, input: LanguageModelInput, config: Optional[RunnableConfig] = None, *, - stop: Optional[List[str]] = None, + stop: Optional[list[str]] = None, **kwargs: Any, ) -> BaseMessage: config = ensure_config(config) chat_result = cast( - ChatGeneration, + "ChatGeneration", self.generate_prompt( [self._convert_input(input)], stop=stop, @@ -162,7 +192,9 @@ def invoke( run_id=config.pop("run_id", None), **kwargs, ).generations[0][0], + ).message + self.usage_metadata = chat_result.response_metadata[ 'token_usage'] if 'token_usage' in chat_result.response_metadata else chat_result.usage_metadata return chat_result diff --git a/apps/setting/models_provider/impl/deepseek_model_provider/model/llm.py b/apps/setting/models_provider/impl/deepseek_model_provider/model/llm.py index 9db4faca7cc..081d648a716 100644 --- a/apps/setting/models_provider/impl/deepseek_model_provider/model/llm.py +++ b/apps/setting/models_provider/impl/deepseek_model_provider/model/llm.py @@ -26,6 +26,6 @@ def new_instance(model_type, model_name, model_credential: Dict[str, object], ** model=model_name, openai_api_base='https://api.deepseek.com', openai_api_key=model_credential.get('api_key'), - **optional_params + extra_body=optional_params ) return deepseek_chat_open_ai diff --git a/apps/setting/models_provider/impl/gemini_model_provider/model/llm.py b/apps/setting/models_provider/impl/gemini_model_provider/model/llm.py index 4106cc1d6e3..af23d0341a4 100644 --- a/apps/setting/models_provider/impl/gemini_model_provider/model/llm.py +++ b/apps/setting/models_provider/impl/gemini_model_provider/model/llm.py @@ -13,7 +13,7 @@ Tool as GoogleTool, ) from langchain_core.callbacks import CallbackManagerForLLMRun -from langchain_core.messages import BaseMessage +from langchain_core.messages import BaseMessage, get_buffer_string from langchain_core.outputs import ChatGenerationChunk from langchain_google_genai import ChatGoogleGenerativeAI from langchain_google_genai._function_utils import _ToolConfigDict, _ToolDict @@ -22,6 +22,8 @@ from langchain_google_genai._common import ( SafetySettingDict, ) + +from common.config.tokenizer_manage_config import TokenizerManage from setting.models_provider.base_model_provider import MaxKBBaseModel @@ -46,10 +48,18 @@ def get_last_generation_info(self) -> Optional[Dict[str, Any]]: return self.__dict__.get('_last_generation_info') def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int: - return self.get_last_generation_info().get('input_tokens', 0) + try: + return self.get_last_generation_info().get('input_tokens', 0) + except Exception as e: + tokenizer = TokenizerManage.get_tokenizer() + return sum([len(tokenizer.encode(get_buffer_string([m]))) for m in messages]) def get_num_tokens(self, text: str) -> int: - return self.get_last_generation_info().get('output_tokens', 0) + try: + return self.get_last_generation_info().get('output_tokens', 0) + except Exception as e: + tokenizer = TokenizerManage.get_tokenizer() + return len(tokenizer.encode(text)) def _stream( self, diff --git a/apps/setting/models_provider/impl/kimi_model_provider/model/llm.py b/apps/setting/models_provider/impl/kimi_model_provider/model/llm.py index c389c177e4e..c0ce2ec029a 100644 --- a/apps/setting/models_provider/impl/kimi_model_provider/model/llm.py +++ b/apps/setting/models_provider/impl/kimi_model_provider/model/llm.py @@ -21,11 +21,10 @@ def is_cache_model(): @staticmethod def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs): optional_params = MaxKBBaseModel.filter_optional_params(model_kwargs) - kimi_chat_open_ai = KimiChatModel( openai_api_base=model_credential['api_base'], openai_api_key=model_credential['api_key'], model_name=model_name, - **optional_params + extra_body=optional_params, ) return kimi_chat_open_ai diff --git a/apps/setting/models_provider/impl/ollama_model_provider/credential/llm.py b/apps/setting/models_provider/impl/ollama_model_provider/credential/llm.py index 0194d1f0d27..add06621937 100644 --- a/apps/setting/models_provider/impl/ollama_model_provider/credential/llm.py +++ b/apps/setting/models_provider/impl/ollama_model_provider/credential/llm.py @@ -25,7 +25,7 @@ class OllamaLLMModelParams(BaseForm): _step=0.01, precision=2) - max_tokens = forms.SliderField( + num_predict = forms.SliderField( TooltipLabel(_('Output the maximum Tokens'), _('Specify the maximum number of tokens that the model can generate')), required=True, default_value=1024, diff --git a/apps/setting/models_provider/impl/ollama_model_provider/model/image.py b/apps/setting/models_provider/impl/ollama_model_provider/model/image.py index 4cf0f1d56fc..215ce0130d7 100644 --- a/apps/setting/models_provider/impl/ollama_model_provider/model/image.py +++ b/apps/setting/models_provider/impl/ollama_model_provider/model/image.py @@ -28,5 +28,5 @@ def new_instance(model_type, model_name, model_credential: Dict[str, object], ** # stream_options={"include_usage": True}, streaming=True, stream_usage=True, - **optional_params, + extra_body=optional_params ) diff --git a/apps/setting/models_provider/impl/openai_model_provider/model/image.py b/apps/setting/models_provider/impl/openai_model_provider/model/image.py index 731f476c45f..7ac0906a786 100644 --- a/apps/setting/models_provider/impl/openai_model_provider/model/image.py +++ b/apps/setting/models_provider/impl/openai_model_provider/model/image.py @@ -16,5 +16,5 @@ def new_instance(model_type, model_name, model_credential: Dict[str, object], ** # stream_options={"include_usage": True}, streaming=True, stream_usage=True, - **optional_params, + extra_body=optional_params ) diff --git a/apps/setting/models_provider/impl/openai_model_provider/model/llm.py b/apps/setting/models_provider/impl/openai_model_provider/model/llm.py index 2e6dd89ac93..1893852100b 100644 --- a/apps/setting/models_provider/impl/openai_model_provider/model/llm.py +++ b/apps/setting/models_provider/impl/openai_model_provider/model/llm.py @@ -9,7 +9,6 @@ from typing import List, Dict from langchain_core.messages import BaseMessage, get_buffer_string -from langchain_openai.chat_models import ChatOpenAI from common.config.tokenizer_manage_config import TokenizerManage from setting.models_provider.base_model_provider import MaxKBBaseModel @@ -35,9 +34,9 @@ def new_instance(model_type, model_name, model_credential: Dict[str, object], ** streaming = False azure_chat_open_ai = OpenAIChatModel( model=model_name, - openai_api_base=model_credential.get('api_base'), - openai_api_key=model_credential.get('api_key'), - **optional_params, + base_url=model_credential.get('api_base'), + api_key=model_credential.get('api_key'), + extra_body=optional_params, streaming=streaming, custom_get_token_ids=custom_get_token_ids ) diff --git a/apps/setting/models_provider/impl/qwen_model_provider/model/image.py b/apps/setting/models_provider/impl/qwen_model_provider/model/image.py index 97166757e67..bf3af0e3484 100644 --- a/apps/setting/models_provider/impl/qwen_model_provider/model/image.py +++ b/apps/setting/models_provider/impl/qwen_model_provider/model/image.py @@ -18,9 +18,8 @@ def new_instance(model_type, model_name, model_credential: Dict[str, object], ** model_name=model_name, openai_api_key=model_credential.get('api_key'), openai_api_base='https://dashscope.aliyuncs.com/compatible-mode/v1', - # stream_options={"include_usage": True}, streaming=True, stream_usage=True, - **optional_params, + extra_body=optional_params ) return chat_tong_yi diff --git a/apps/setting/models_provider/impl/qwen_model_provider/model/llm.py b/apps/setting/models_provider/impl/qwen_model_provider/model/llm.py index 3b66ddfd62a..c4df28af9bb 100644 --- a/apps/setting/models_provider/impl/qwen_model_provider/model/llm.py +++ b/apps/setting/models_provider/impl/qwen_model_provider/model/llm.py @@ -26,6 +26,6 @@ def new_instance(model_type, model_name, model_credential: Dict[str, object], ** openai_api_base='https://dashscope.aliyuncs.com/compatible-mode/v1', streaming=True, stream_usage=True, - **optional_params, + extra_body=optional_params ) return chat_tong_yi diff --git a/apps/setting/models_provider/impl/regolo_model_provider/__init__.py b/apps/setting/models_provider/impl/regolo_model_provider/__init__.py new file mode 100644 index 00000000000..2dc4ab10db4 --- /dev/null +++ b/apps/setting/models_provider/impl/regolo_model_provider/__init__.py @@ -0,0 +1,8 @@ +# coding=utf-8 +""" + @project: maxkb + @Author:虎 + @file: __init__.py.py + @date:2024/3/28 16:25 + @desc: +""" diff --git a/apps/setting/models_provider/impl/regolo_model_provider/credential/embedding.py b/apps/setting/models_provider/impl/regolo_model_provider/credential/embedding.py new file mode 100644 index 00000000000..ddea7fed52d --- /dev/null +++ b/apps/setting/models_provider/impl/regolo_model_provider/credential/embedding.py @@ -0,0 +1,52 @@ +# coding=utf-8 +""" + @project: MaxKB + @Author:虎 + @file: embedding.py + @date:2024/7/12 16:45 + @desc: +""" +import traceback +from typing import Dict + +from django.utils.translation import gettext as _ + +from common import forms +from common.exception.app_exception import AppApiException +from common.forms import BaseForm +from setting.models_provider.base_model_provider import BaseModelCredential, ValidCode + + +class RegoloEmbeddingCredential(BaseForm, BaseModelCredential): + def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], model_params, provider, + raise_exception=True): + model_type_list = provider.get_model_type_list() + if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))): + raise AppApiException(ValidCode.valid_error.value, + _('{model_type} Model type is not supported').format(model_type=model_type)) + + for key in ['api_key']: + if key not in model_credential: + if raise_exception: + raise AppApiException(ValidCode.valid_error.value, _('{key} is required').format(key=key)) + else: + return False + try: + model = provider.get_model(model_type, model_name, model_credential) + model.embed_query(_('Hello')) + except Exception as e: + traceback.print_exc() + if isinstance(e, AppApiException): + raise e + if raise_exception: + raise AppApiException(ValidCode.valid_error.value, + _('Verification failed, please check whether the parameters are correct: {error}').format( + error=str(e))) + else: + return False + return True + + def encryption_dict(self, model: Dict[str, object]): + return {**model, 'api_key': super().encryption(model.get('api_key', ''))} + + api_key = forms.PasswordInputField('API Key', required=True) diff --git a/apps/setting/models_provider/impl/regolo_model_provider/credential/image.py b/apps/setting/models_provider/impl/regolo_model_provider/credential/image.py new file mode 100644 index 00000000000..5975c774806 --- /dev/null +++ b/apps/setting/models_provider/impl/regolo_model_provider/credential/image.py @@ -0,0 +1,74 @@ +# coding=utf-8 +import base64 +import os +import traceback +from typing import Dict + +from langchain_core.messages import HumanMessage + +from common import forms +from common.exception.app_exception import AppApiException +from common.forms import BaseForm, TooltipLabel +from setting.models_provider.base_model_provider import BaseModelCredential, ValidCode +from django.utils.translation import gettext_lazy as _, gettext + + +class RegoloImageModelParams(BaseForm): + temperature = forms.SliderField(TooltipLabel(_('Temperature'), + _('Higher values make the output more random, while lower values make it more focused and deterministic')), + required=True, default_value=0.7, + _min=0.1, + _max=1.0, + _step=0.01, + precision=2) + + max_tokens = forms.SliderField( + TooltipLabel(_('Output the maximum Tokens'), + _('Specify the maximum number of tokens that the model can generate')), + required=True, default_value=800, + _min=1, + _max=100000, + _step=1, + precision=0) + + +class RegoloImageModelCredential(BaseForm, BaseModelCredential): + api_base = forms.TextInputField('API URL', required=True) + api_key = forms.PasswordInputField('API Key', required=True) + + def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], model_params, provider, + raise_exception=False): + model_type_list = provider.get_model_type_list() + if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))): + raise AppApiException(ValidCode.valid_error.value, + gettext('{model_type} Model type is not supported').format(model_type=model_type)) + + for key in ['api_key']: + if key not in model_credential: + if raise_exception: + raise AppApiException(ValidCode.valid_error.value, gettext('{key} is required').format(key=key)) + else: + return False + try: + model = provider.get_model(model_type, model_name, model_credential, **model_params) + res = model.stream([HumanMessage(content=[{"type": "text", "text": gettext('Hello')}])]) + for chunk in res: + print(chunk) + except Exception as e: + traceback.print_exc() + if isinstance(e, AppApiException): + raise e + if raise_exception: + raise AppApiException(ValidCode.valid_error.value, + gettext( + 'Verification failed, please check whether the parameters are correct: {error}').format( + error=str(e))) + else: + return False + return True + + def encryption_dict(self, model: Dict[str, object]): + return {**model, 'api_key': super().encryption(model.get('api_key', ''))} + + def get_model_params_setting_form(self, model_name): + return RegoloImageModelParams() diff --git a/apps/setting/models_provider/impl/regolo_model_provider/credential/llm.py b/apps/setting/models_provider/impl/regolo_model_provider/credential/llm.py new file mode 100644 index 00000000000..60eb4ff0abf --- /dev/null +++ b/apps/setting/models_provider/impl/regolo_model_provider/credential/llm.py @@ -0,0 +1,78 @@ +# coding=utf-8 +""" + @project: MaxKB + @Author:虎 + @file: llm.py + @date:2024/7/11 18:32 + @desc: +""" +import traceback +from typing import Dict + +from django.utils.translation import gettext_lazy as _, gettext +from langchain_core.messages import HumanMessage + +from common import forms +from common.exception.app_exception import AppApiException +from common.forms import BaseForm, TooltipLabel +from setting.models_provider.base_model_provider import BaseModelCredential, ValidCode + + +class RegoloLLMModelParams(BaseForm): + temperature = forms.SliderField(TooltipLabel(_('Temperature'), + _('Higher values make the output more random, while lower values make it more focused and deterministic')), + required=True, default_value=0.7, + _min=0.1, + _max=1.0, + _step=0.01, + precision=2) + + max_tokens = forms.SliderField( + TooltipLabel(_('Output the maximum Tokens'), + _('Specify the maximum number of tokens that the model can generate')), + required=True, default_value=800, + _min=1, + _max=100000, + _step=1, + precision=0) + + +class RegoloLLMModelCredential(BaseForm, BaseModelCredential): + + def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], model_params, provider, + raise_exception=False): + model_type_list = provider.get_model_type_list() + if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))): + raise AppApiException(ValidCode.valid_error.value, + gettext('{model_type} Model type is not supported').format(model_type=model_type)) + + for key in ['api_key']: + if key not in model_credential: + if raise_exception: + raise AppApiException(ValidCode.valid_error.value, gettext('{key} is required').format(key=key)) + else: + return False + try: + + model = provider.get_model(model_type, model_name, model_credential, **model_params) + model.invoke([HumanMessage(content=gettext('Hello'))]) + except Exception as e: + traceback.print_exc() + if isinstance(e, AppApiException): + raise e + if raise_exception: + raise AppApiException(ValidCode.valid_error.value, + gettext( + 'Verification failed, please check whether the parameters are correct: {error}').format( + error=str(e))) + else: + return False + return True + + def encryption_dict(self, model: Dict[str, object]): + return {**model, 'api_key': super().encryption(model.get('api_key', ''))} + + api_key = forms.PasswordInputField('API Key', required=True) + + def get_model_params_setting_form(self, model_name): + return RegoloLLMModelParams() diff --git a/apps/setting/models_provider/impl/regolo_model_provider/credential/tti.py b/apps/setting/models_provider/impl/regolo_model_provider/credential/tti.py new file mode 100644 index 00000000000..88f46ce4143 --- /dev/null +++ b/apps/setting/models_provider/impl/regolo_model_provider/credential/tti.py @@ -0,0 +1,89 @@ +# coding=utf-8 +import traceback +from typing import Dict + +from django.utils.translation import gettext_lazy as _, gettext + +from common import forms +from common.exception.app_exception import AppApiException +from common.forms import BaseForm, TooltipLabel +from setting.models_provider.base_model_provider import BaseModelCredential, ValidCode + + +class RegoloTTIModelParams(BaseForm): + size = forms.SingleSelect( + TooltipLabel(_('Image size'), + _('The image generation endpoint allows you to create raw images based on text prompts. ')), + required=True, + default_value='1024x1024', + option_list=[ + {'value': '1024x1024', 'label': '1024x1024'}, + {'value': '1024x1792', 'label': '1024x1792'}, + {'value': '1792x1024', 'label': '1792x1024'}, + ], + text_field='label', + value_field='value' + ) + + quality = forms.SingleSelect( + TooltipLabel(_('Picture quality'), _(''' +By default, images are produced in standard quality. + ''')), + required=True, + default_value='standard', + option_list=[ + {'value': 'standard', 'label': 'standard'}, + {'value': 'hd', 'label': 'hd'}, + ], + text_field='label', + value_field='value' + ) + + n = forms.SliderField( + TooltipLabel(_('Number of pictures'), + _('1 as default')), + required=True, default_value=1, + _min=1, + _max=10, + _step=1, + precision=0) + + +class RegoloTextToImageModelCredential(BaseForm, BaseModelCredential): + api_key = forms.PasswordInputField('API Key', required=True) + + def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], model_params, provider, + raise_exception=False): + model_type_list = provider.get_model_type_list() + if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))): + raise AppApiException(ValidCode.valid_error.value, + gettext('{model_type} Model type is not supported').format(model_type=model_type)) + + for key in ['api_key']: + if key not in model_credential: + if raise_exception: + raise AppApiException(ValidCode.valid_error.value, gettext('{key} is required').format(key=key)) + else: + return False + try: + model = provider.get_model(model_type, model_name, model_credential, **model_params) + res = model.check_auth() + print(res) + except Exception as e: + traceback.print_exc() + if isinstance(e, AppApiException): + raise e + if raise_exception: + raise AppApiException(ValidCode.valid_error.value, + gettext( + 'Verification failed, please check whether the parameters are correct: {error}').format( + error=str(e))) + else: + return False + return True + + def encryption_dict(self, model: Dict[str, object]): + return {**model, 'api_key': super().encryption(model.get('api_key', ''))} + + def get_model_params_setting_form(self, model_name): + return RegoloTTIModelParams() diff --git a/apps/setting/models_provider/impl/regolo_model_provider/icon/regolo_icon_svg b/apps/setting/models_provider/impl/regolo_model_provider/icon/regolo_icon_svg new file mode 100644 index 00000000000..b69154451ad --- /dev/null +++ b/apps/setting/models_provider/impl/regolo_model_provider/icon/regolo_icon_svg @@ -0,0 +1,64 @@ + + + + + + + + + + + + + + diff --git a/apps/setting/models_provider/impl/regolo_model_provider/model/embedding.py b/apps/setting/models_provider/impl/regolo_model_provider/model/embedding.py new file mode 100644 index 00000000000..b067b8eff29 --- /dev/null +++ b/apps/setting/models_provider/impl/regolo_model_provider/model/embedding.py @@ -0,0 +1,23 @@ +# coding=utf-8 +""" + @project: MaxKB + @Author:虎 + @file: embedding.py + @date:2024/7/12 17:44 + @desc: +""" +from typing import Dict + +from langchain_community.embeddings import OpenAIEmbeddings + +from setting.models_provider.base_model_provider import MaxKBBaseModel + + +class RegoloEmbeddingModel(MaxKBBaseModel, OpenAIEmbeddings): + @staticmethod + def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs): + return RegoloEmbeddingModel( + api_key=model_credential.get('api_key'), + model=model_name, + openai_api_base="https://api.regolo.ai/v1", + ) diff --git a/apps/setting/models_provider/impl/regolo_model_provider/model/image.py b/apps/setting/models_provider/impl/regolo_model_provider/model/image.py new file mode 100644 index 00000000000..f16768fad1e --- /dev/null +++ b/apps/setting/models_provider/impl/regolo_model_provider/model/image.py @@ -0,0 +1,19 @@ +from typing import Dict + +from setting.models_provider.base_model_provider import MaxKBBaseModel +from setting.models_provider.impl.base_chat_open_ai import BaseChatOpenAI + + +class RegoloImage(MaxKBBaseModel, BaseChatOpenAI): + + @staticmethod + def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs): + optional_params = MaxKBBaseModel.filter_optional_params(model_kwargs) + return RegoloImage( + model_name=model_name, + openai_api_base="https://api.regolo.ai/v1", + openai_api_key=model_credential.get('api_key'), + streaming=True, + stream_usage=True, + extra_body=optional_params + ) diff --git a/apps/setting/models_provider/impl/regolo_model_provider/model/llm.py b/apps/setting/models_provider/impl/regolo_model_provider/model/llm.py new file mode 100644 index 00000000000..126a756a20d --- /dev/null +++ b/apps/setting/models_provider/impl/regolo_model_provider/model/llm.py @@ -0,0 +1,38 @@ +# coding=utf-8 +""" + @project: maxkb + @Author:虎 + @file: llm.py + @date:2024/4/18 15:28 + @desc: +""" +from typing import List, Dict + +from langchain_core.messages import BaseMessage, get_buffer_string +from langchain_openai.chat_models import ChatOpenAI + +from common.config.tokenizer_manage_config import TokenizerManage +from setting.models_provider.base_model_provider import MaxKBBaseModel +from setting.models_provider.impl.base_chat_open_ai import BaseChatOpenAI + + +def custom_get_token_ids(text: str): + tokenizer = TokenizerManage.get_tokenizer() + return tokenizer.encode(text) + + +class RegoloChatModel(MaxKBBaseModel, BaseChatOpenAI): + + @staticmethod + def is_cache_model(): + return False + + @staticmethod + def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs): + optional_params = MaxKBBaseModel.filter_optional_params(model_kwargs) + return RegoloChatModel( + model=model_name, + openai_api_base="https://api.regolo.ai/v1", + openai_api_key=model_credential.get('api_key'), + extra_body=optional_params + ) diff --git a/apps/setting/models_provider/impl/regolo_model_provider/model/tti.py b/apps/setting/models_provider/impl/regolo_model_provider/model/tti.py new file mode 100644 index 00000000000..a92527295ac --- /dev/null +++ b/apps/setting/models_provider/impl/regolo_model_provider/model/tti.py @@ -0,0 +1,58 @@ +from typing import Dict + +from openai import OpenAI + +from common.config.tokenizer_manage_config import TokenizerManage +from setting.models_provider.base_model_provider import MaxKBBaseModel +from setting.models_provider.impl.base_tti import BaseTextToImage + + +def custom_get_token_ids(text: str): + tokenizer = TokenizerManage.get_tokenizer() + return tokenizer.encode(text) + + +class RegoloTextToImage(MaxKBBaseModel, BaseTextToImage): + api_base: str + api_key: str + model: str + params: dict + + def __init__(self, **kwargs): + super().__init__(**kwargs) + self.api_key = kwargs.get('api_key') + self.api_base = "https://api.regolo.ai/v1" + self.model = kwargs.get('model') + self.params = kwargs.get('params') + + @staticmethod + def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs): + optional_params = {'params': {'size': '1024x1024', 'quality': 'standard', 'n': 1}} + for key, value in model_kwargs.items(): + if key not in ['model_id', 'use_local', 'streaming']: + optional_params['params'][key] = value + return RegoloTextToImage( + model=model_name, + api_base="https://api.regolo.ai/v1", + api_key=model_credential.get('api_key'), + **optional_params, + ) + + def is_cache_model(self): + return False + + def check_auth(self): + chat = OpenAI(api_key=self.api_key, base_url=self.api_base) + response_list = chat.models.with_raw_response.list() + + # self.generate_image('生成一个小猫图片') + + def generate_image(self, prompt: str, negative_prompt: str = None): + chat = OpenAI(api_key=self.api_key, base_url=self.api_base) + res = chat.images.generate(model=self.model, prompt=prompt, **self.params) + file_urls = [] + for content in res.data: + url = content.url + file_urls.append(url) + + return file_urls diff --git a/apps/setting/models_provider/impl/regolo_model_provider/regolo_model_provider.py b/apps/setting/models_provider/impl/regolo_model_provider/regolo_model_provider.py new file mode 100644 index 00000000000..a5e7dc36550 --- /dev/null +++ b/apps/setting/models_provider/impl/regolo_model_provider/regolo_model_provider.py @@ -0,0 +1,89 @@ +# coding=utf-8 +""" + @project: maxkb + @Author:虎 + @file: openai_model_provider.py + @date:2024/3/28 16:26 + @desc: +""" +import os + +from common.util.file_util import get_file_content +from setting.models_provider.base_model_provider import IModelProvider, ModelProvideInfo, ModelInfo, \ + ModelTypeConst, ModelInfoManage +from setting.models_provider.impl.regolo_model_provider.credential.embedding import \ + RegoloEmbeddingCredential +from setting.models_provider.impl.regolo_model_provider.credential.llm import RegoloLLMModelCredential +from setting.models_provider.impl.regolo_model_provider.credential.tti import \ + RegoloTextToImageModelCredential +from setting.models_provider.impl.regolo_model_provider.model.embedding import RegoloEmbeddingModel +from setting.models_provider.impl.regolo_model_provider.model.llm import RegoloChatModel +from setting.models_provider.impl.regolo_model_provider.model.tti import RegoloTextToImage +from smartdoc.conf import PROJECT_DIR +from django.utils.translation import gettext as _ + +openai_llm_model_credential = RegoloLLMModelCredential() +openai_tti_model_credential = RegoloTextToImageModelCredential() +model_info_list = [ + ModelInfo('Phi-4', '', ModelTypeConst.LLM, + openai_llm_model_credential, RegoloChatModel + ), + ModelInfo('DeepSeek-R1-Distill-Qwen-32B', '', ModelTypeConst.LLM, + openai_llm_model_credential, + RegoloChatModel), + ModelInfo('maestrale-chat-v0.4-beta', '', + ModelTypeConst.LLM, openai_llm_model_credential, + RegoloChatModel), + ModelInfo('Llama-3.3-70B-Instruct', + '', + ModelTypeConst.LLM, openai_llm_model_credential, + RegoloChatModel), + ModelInfo('Llama-3.1-8B-Instruct', + '', + ModelTypeConst.LLM, openai_llm_model_credential, + RegoloChatModel), + ModelInfo('DeepSeek-Coder-6.7B-Instruct', '', + ModelTypeConst.LLM, openai_llm_model_credential, + RegoloChatModel) +] +open_ai_embedding_credential = RegoloEmbeddingCredential() +model_info_embedding_list = [ + ModelInfo('gte-Qwen2', '', + ModelTypeConst.EMBEDDING, open_ai_embedding_credential, + RegoloEmbeddingModel), +] + +model_info_tti_list = [ + ModelInfo('FLUX.1-dev', '', + ModelTypeConst.TTI, openai_tti_model_credential, + RegoloTextToImage), + ModelInfo('sdxl-turbo', '', + ModelTypeConst.TTI, openai_tti_model_credential, + RegoloTextToImage), +] +model_info_manage = ( + ModelInfoManage.builder() + .append_model_info_list(model_info_list) + .append_default_model_info( + ModelInfo('gpt-3.5-turbo', _('The latest gpt-3.5-turbo, updated with OpenAI adjustments'), ModelTypeConst.LLM, + openai_llm_model_credential, RegoloChatModel + )) + .append_model_info_list(model_info_embedding_list) + .append_default_model_info(model_info_embedding_list[0]) + .append_model_info_list(model_info_tti_list) + .append_default_model_info(model_info_tti_list[0]) + + .build() +) + + +class RegoloModelProvider(IModelProvider): + + def get_model_info_manage(self): + return model_info_manage + + def get_model_provide_info(self): + return ModelProvideInfo(provider='model_regolo_provider', name='Regolo', icon=get_file_content( + os.path.join(PROJECT_DIR, "apps", "setting", 'models_provider', 'impl', 'regolo_model_provider', + 'icon', + 'regolo_icon_svg'))) diff --git a/apps/setting/models_provider/impl/siliconCloud_model_provider/model/image.py b/apps/setting/models_provider/impl/siliconCloud_model_provider/model/image.py index bb840f8c6dc..2ec0689d4d2 100644 --- a/apps/setting/models_provider/impl/siliconCloud_model_provider/model/image.py +++ b/apps/setting/models_provider/impl/siliconCloud_model_provider/model/image.py @@ -16,5 +16,5 @@ def new_instance(model_type, model_name, model_credential: Dict[str, object], ** # stream_options={"include_usage": True}, streaming=True, stream_usage=True, - **optional_params, + extra_body=optional_params ) diff --git a/apps/setting/models_provider/impl/siliconCloud_model_provider/model/llm.py b/apps/setting/models_provider/impl/siliconCloud_model_provider/model/llm.py index 9d79c6e0761..6fb0c7816fa 100644 --- a/apps/setting/models_provider/impl/siliconCloud_model_provider/model/llm.py +++ b/apps/setting/models_provider/impl/siliconCloud_model_provider/model/llm.py @@ -34,5 +34,5 @@ def new_instance(model_type, model_name, model_credential: Dict[str, object], ** model=model_name, openai_api_base=model_credential.get('api_base'), openai_api_key=model_credential.get('api_key'), - **optional_params + extra_body=optional_params ) diff --git a/apps/setting/models_provider/impl/tencent_cloud_model_provider/model/llm.py b/apps/setting/models_provider/impl/tencent_cloud_model_provider/model/llm.py index 7653cfc2f1f..cfcdf7aca21 100644 --- a/apps/setting/models_provider/impl/tencent_cloud_model_provider/model/llm.py +++ b/apps/setting/models_provider/impl/tencent_cloud_model_provider/model/llm.py @@ -33,21 +33,7 @@ def new_instance(model_type, model_name, model_credential: Dict[str, object], ** model=model_name, openai_api_base=model_credential.get('api_base'), openai_api_key=model_credential.get('api_key'), - **optional_params, + extra_body=optional_params, custom_get_token_ids=custom_get_token_ids ) return azure_chat_open_ai - - def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int: - try: - return super().get_num_tokens_from_messages(messages) - except Exception as e: - tokenizer = TokenizerManage.get_tokenizer() - return sum([len(tokenizer.encode(get_buffer_string([m]))) for m in messages]) - - def get_num_tokens(self, text: str) -> int: - try: - return super().get_num_tokens(text) - except Exception as e: - tokenizer = TokenizerManage.get_tokenizer() - return len(tokenizer.encode(text)) diff --git a/apps/setting/models_provider/impl/tencent_model_provider/model/image.py b/apps/setting/models_provider/impl/tencent_model_provider/model/image.py index 1b66ab6d23f..6800cdd567c 100644 --- a/apps/setting/models_provider/impl/tencent_model_provider/model/image.py +++ b/apps/setting/models_provider/impl/tencent_model_provider/model/image.py @@ -16,5 +16,5 @@ def new_instance(model_type, model_name, model_credential: Dict[str, object], ** # stream_options={"include_usage": True}, streaming=True, stream_usage=True, - **optional_params, + extra_body=optional_params ) diff --git a/apps/setting/models_provider/impl/vllm_model_provider/model/image.py b/apps/setting/models_provider/impl/vllm_model_provider/model/image.py index 4d5dda29dd7..c8cb0a84db9 100644 --- a/apps/setting/models_provider/impl/vllm_model_provider/model/image.py +++ b/apps/setting/models_provider/impl/vllm_model_provider/model/image.py @@ -19,7 +19,7 @@ def new_instance(model_type, model_name, model_credential: Dict[str, object], ** # stream_options={"include_usage": True}, streaming=True, stream_usage=True, - **optional_params, + extra_body=optional_params ) def is_cache_model(self): diff --git a/apps/setting/models_provider/impl/vllm_model_provider/model/llm.py b/apps/setting/models_provider/impl/vllm_model_provider/model/llm.py index 7d2a63acd08..4662a616965 100644 --- a/apps/setting/models_provider/impl/vllm_model_provider/model/llm.py +++ b/apps/setting/models_provider/impl/vllm_model_provider/model/llm.py @@ -1,9 +1,10 @@ # coding=utf-8 -from typing import Dict, List +from typing import Dict, Optional, Sequence, Union, Any, Callable from urllib.parse import urlparse, ParseResult from langchain_core.messages import BaseMessage, get_buffer_string +from langchain_core.tools import BaseTool from common.config.tokenizer_manage_config import TokenizerManage from setting.models_provider.base_model_provider import MaxKBBaseModel @@ -31,13 +32,19 @@ def new_instance(model_type, model_name, model_credential: Dict[str, object], ** model=model_name, openai_api_base=model_credential.get('api_base'), openai_api_key=model_credential.get('api_key'), - **optional_params, streaming=True, stream_usage=True, + extra_body=optional_params ) return vllm_chat_open_ai - def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int: + def get_num_tokens_from_messages( + self, + messages: list[BaseMessage], + tools: Optional[ + Sequence[Union[dict[str, Any], type, Callable, BaseTool]] + ] = None, + ) -> int: if self.usage_metadata is None or self.usage_metadata == {}: tokenizer = TokenizerManage.get_tokenizer() return sum([len(tokenizer.encode(get_buffer_string([m]))) for m in messages]) diff --git a/apps/setting/models_provider/impl/volcanic_engine_model_provider/credential/tti.py b/apps/setting/models_provider/impl/volcanic_engine_model_provider/credential/tti.py index 98c119e21cb..bb0b6881dea 100644 --- a/apps/setting/models_provider/impl/volcanic_engine_model_provider/credential/tti.py +++ b/apps/setting/models_provider/impl/volcanic_engine_model_provider/credential/tti.py @@ -15,23 +15,27 @@ class VolcanicEngineTTIModelGeneralParams(BaseForm): TooltipLabel(_('Image size'), _('If the gap between width, height and 512 is too large, the picture rendering effect will be poor and the probability of excessive delay will increase significantly. Recommended ratio and corresponding width and height before super score: width*height')), required=True, - default_value='512*512', + default_value='512x512', option_list=[ - {'value': '512*512', 'label': '512*512'}, - {'value': '512*384', 'label': '512*384'}, - {'value': '384*512', 'label': '384*512'}, - {'value': '512*341', 'label': '512*341'}, - {'value': '341*512', 'label': '341*512'}, - {'value': '512*288', 'label': '512*288'}, - {'value': '288*512', 'label': '288*512'}, + {'label': '512x512', 'value': '512x512'}, + {'label': '1024x1024', 'value': '1024x1024'}, + {'label': '864x1152', 'value': '864x1152'}, + {'label': '1152x864', 'value': '1152x864'}, + {'label': '1280x720', 'value': '1280x720'}, + {'label': '720x1280', 'value': '720x1280'}, + {'label': '832x1248', 'value': '832x1248'}, + {'label': '1248x832', 'value': '1248x832'}, + {'label': '1512x648', 'value': '1512x648'}, + ], text_field='label', value_field='value') class VolcanicEngineTTIModelCredential(BaseForm, BaseModelCredential): - access_key = forms.PasswordInputField('Access Key ID', required=True) - secret_key = forms.PasswordInputField('Secret Access Key', required=True) + volcanic_api_url = forms.TextInputField('API URL', required=True, + default_value='https://ark.cn-beijing.volces.com/api/v3') + api_key = forms.PasswordInputField('Api key', required=True) def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], model_params, provider, raise_exception=False): @@ -40,7 +44,7 @@ def is_valid(self, model_type: str, model_name, model_credential: Dict[str, obje raise AppApiException(ValidCode.valid_error.value, gettext('{model_type} Model type is not supported').format(model_type=model_type)) - for key in ['access_key', 'secret_key']: + for key in ['api_key']: if key not in model_credential: if raise_exception: raise AppApiException(ValidCode.valid_error.value, gettext('{key} is required').format(key=key)) @@ -62,7 +66,7 @@ def is_valid(self, model_type: str, model_name, model_credential: Dict[str, obje return True def encryption_dict(self, model: Dict[str, object]): - return {**model, 'secret_key': super().encryption(model.get('secret_key', ''))} + return {**model, 'api_key': super().encryption(model.get('api_key', ''))} def get_model_params_setting_form(self, model_name): return VolcanicEngineTTIModelGeneralParams() diff --git a/apps/setting/models_provider/impl/volcanic_engine_model_provider/model/image.py b/apps/setting/models_provider/impl/volcanic_engine_model_provider/model/image.py index 39446b4e19c..6e2517bd4ad 100644 --- a/apps/setting/models_provider/impl/volcanic_engine_model_provider/model/image.py +++ b/apps/setting/models_provider/impl/volcanic_engine_model_provider/model/image.py @@ -16,5 +16,5 @@ def new_instance(model_type, model_name, model_credential: Dict[str, object], ** # stream_options={"include_usage": True}, streaming=True, stream_usage=True, - **optional_params, + extra_body=optional_params ) diff --git a/apps/setting/models_provider/impl/volcanic_engine_model_provider/model/llm.py b/apps/setting/models_provider/impl/volcanic_engine_model_provider/model/llm.py index 181ad2971db..8f089f26988 100644 --- a/apps/setting/models_provider/impl/volcanic_engine_model_provider/model/llm.py +++ b/apps/setting/models_provider/impl/volcanic_engine_model_provider/model/llm.py @@ -17,5 +17,5 @@ def new_instance(model_type, model_name, model_credential: Dict[str, object], ** model=model_name, openai_api_base=model_credential.get('api_base'), openai_api_key=model_credential.get('api_key'), - **optional_params + extra_body=optional_params ) diff --git a/apps/setting/models_provider/impl/volcanic_engine_model_provider/model/tti.py b/apps/setting/models_provider/impl/volcanic_engine_model_provider/model/tti.py index dd021c64320..5caed19fac8 100644 --- a/apps/setting/models_provider/impl/volcanic_engine_model_provider/model/tti.py +++ b/apps/setting/models_provider/impl/volcanic_engine_model_provider/model/tti.py @@ -7,130 +7,32 @@ pip install websockets ''' - -import datetime -import hashlib -import hmac -import json -import sys from typing import Dict -import requests + +from volcenginesdkarkruntime import Ark from setting.models_provider.base_model_provider import MaxKBBaseModel from setting.models_provider.impl.base_tti import BaseTextToImage -method = 'POST' -host = 'visual.volcengineapi.com' -region = 'cn-north-1' -endpoint = 'https://visual.volcengineapi.com' -service = 'cv' - -req_key_dict = { - 'general_v1.4': 'high_aes_general_v14', - 'general_v2.0': 'high_aes_general_v20', - 'general_v2.0_L': 'high_aes_general_v20_L', - 'anime_v1.3': 'high_aes', - 'anime_v1.3.1': 'high_aes', -} - - -def sign(key, msg): - return hmac.new(key, msg.encode('utf-8'), hashlib.sha256).digest() - - -def getSignatureKey(key, dateStamp, regionName, serviceName): - kDate = sign(key.encode('utf-8'), dateStamp) - kRegion = sign(kDate, regionName) - kService = sign(kRegion, serviceName) - kSigning = sign(kService, 'request') - return kSigning - - -def formatQuery(parameters): - request_parameters_init = '' - for key in sorted(parameters): - request_parameters_init += key + '=' + parameters[key] + '&' - request_parameters = request_parameters_init[:-1] - return request_parameters - - -def signV4Request(access_key, secret_key, service, req_query, req_body): - if access_key is None or secret_key is None: - print('No access key is available.') - sys.exit() - - t = datetime.datetime.utcnow() - current_date = t.strftime('%Y%m%dT%H%M%SZ') - # current_date = '20210818T095729Z' - datestamp = t.strftime('%Y%m%d') # Date w/o time, used in credential scope - canonical_uri = '/' - canonical_querystring = req_query - signed_headers = 'content-type;host;x-content-sha256;x-date' - payload_hash = hashlib.sha256(req_body.encode('utf-8')).hexdigest() - content_type = 'application/json' - canonical_headers = 'content-type:' + content_type + '\n' + 'host:' + host + \ - '\n' + 'x-content-sha256:' + payload_hash + \ - '\n' + 'x-date:' + current_date + '\n' - canonical_request = method + '\n' + canonical_uri + '\n' + canonical_querystring + \ - '\n' + canonical_headers + '\n' + signed_headers + '\n' + payload_hash - # print(canonical_request) - algorithm = 'HMAC-SHA256' - credential_scope = datestamp + '/' + region + '/' + service + '/' + 'request' - string_to_sign = algorithm + '\n' + current_date + '\n' + credential_scope + '\n' + hashlib.sha256( - canonical_request.encode('utf-8')).hexdigest() - # print(string_to_sign) - signing_key = getSignatureKey(secret_key, datestamp, region, service) - # print(signing_key) - signature = hmac.new(signing_key, (string_to_sign).encode( - 'utf-8'), hashlib.sha256).hexdigest() - # print(signature) - - authorization_header = algorithm + ' ' + 'Credential=' + access_key + '/' + \ - credential_scope + ', ' + 'SignedHeaders=' + \ - signed_headers + ', ' + 'Signature=' + signature - # print(authorization_header) - headers = {'X-Date': current_date, - 'Authorization': authorization_header, - 'X-Content-Sha256': payload_hash, - 'Content-Type': content_type - } - # print(headers) - - # ************* SEND THE REQUEST ************* - request_url = endpoint + '?' + canonical_querystring - - print('\nBEGIN REQUEST++++++++++++++++++++++++++++++++++++') - print('Request URL = ' + request_url) - try: - r = requests.post(request_url, headers=headers, data=req_body) - except Exception as err: - print(f'error occurred: {err}') - raise - else: - print('\nRESPONSE++++++++++++++++++++++++++++++++++++') - print(f'Response code: {r.status_code}\n') - # 使用 replace 方法将 \u0026 替换为 & - resp_str = r.text.replace("\\u0026", "&") - if r.status_code != 200: - raise Exception(f'Error: {resp_str}') - print(f'Response body: {resp_str}\n') - return json.loads(resp_str)['data']['image_urls'] - class VolcanicEngineTextToImage(MaxKBBaseModel, BaseTextToImage): - access_key: str - secret_key: str + api_key: str + api_base: str model_version: str params: dict def __init__(self, **kwargs): super().__init__(**kwargs) - self.access_key = kwargs.get('access_key') - self.secret_key = kwargs.get('secret_key') + self.api_key = kwargs.get('api_key') + self.api_base = kwargs.get('api_base') self.model_version = kwargs.get('model_version') self.params = kwargs.get('params') + @staticmethod + def is_cache_model(): + return False + @staticmethod def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs): optional_params = {'params': {}} @@ -139,34 +41,29 @@ def new_instance(model_type, model_name, model_credential: Dict[str, object], ** optional_params['params'][key] = value return VolcanicEngineTextToImage( model_version=model_name, - access_key=model_credential.get('access_key'), - secret_key=model_credential.get('secret_key'), + api_key=model_credential.get('api_key'), + api_base=model_credential.get('volcanic_api_url') or 'https://ark-api.volcengine.com', **optional_params ) def check_auth(self): - res = self.generate_image('生成一张小猫图片') - print(res) + return True def generate_image(self, prompt: str, negative_prompt: str = None): - # 请求Query,按照接口文档中填入即可 - query_params = { - 'Action': 'CVProcess', - 'Version': '2022-08-31', - } - formatted_query = formatQuery(query_params) - size = self.params.pop('size', '512*512').split('*') - body_params = { - "req_key": req_key_dict[self.model_version], - "prompt": prompt, - "model_version": self.model_version, - "return_url": True, - "width": int(size[0]), - "height": int(size[1]), + client = Ark( + # 此为默认路径,您可根据业务所在地域进行配置 + base_url=self.api_base, + # 从环境变量中获取您的 API Key。此为默认方式,您可根据需要进行修改 + api_key=self.api_key, + ) + file_urls = [] + imagesResponse = client.images.generate( + model=self.model_version, + prompt=prompt, **self.params - } - formatted_body = json.dumps(body_params) - return signV4Request(self.access_key, self.secret_key, service, formatted_query, formatted_body) - - def is_cache_model(self): - return False + ) + if imagesResponse.data[0].url: + file_urls.append(imagesResponse.data[0].url) + elif imagesResponse.data[0].b64_json: + file_urls.append(imagesResponse.data[0].b64_json) + return file_urls diff --git a/apps/setting/models_provider/impl/volcanic_engine_model_provider/volcanic_engine_model_provider.py b/apps/setting/models_provider/impl/volcanic_engine_model_provider/volcanic_engine_model_provider.py index d963a144625..c9aaf06e0a1 100644 --- a/apps/setting/models_provider/impl/volcanic_engine_model_provider/volcanic_engine_model_provider.py +++ b/apps/setting/models_provider/impl/volcanic_engine_model_provider/volcanic_engine_model_provider.py @@ -98,6 +98,7 @@ .append_default_model_info(model_info_list[2]) .append_default_model_info(model_info_list[3]) .append_default_model_info(model_info_list[4]) + .append_default_model_info(model_info_list[5]) .append_model_info_list(model_info_embedding_list) .append_default_model_info(model_info_embedding_list[0]) .build() diff --git a/apps/setting/models_provider/impl/wenxin_model_provider/credential/llm.py b/apps/setting/models_provider/impl/wenxin_model_provider/credential/llm.py index 06ec94aae34..d4d379db3d5 100644 --- a/apps/setting/models_provider/impl/wenxin_model_provider/credential/llm.py +++ b/apps/setting/models_provider/impl/wenxin_model_provider/credential/llm.py @@ -27,7 +27,7 @@ class WenxinLLMModelParams(BaseForm): _step=0.01, precision=2) - max_tokens = forms.SliderField( + max_output_tokens = forms.SliderField( TooltipLabel(_('Output the maximum Tokens'), _('Specify the maximum number of tokens that the model can generate')), required=True, default_value=1024, diff --git a/apps/setting/models_provider/impl/xinference_model_provider/model/image.py b/apps/setting/models_provider/impl/xinference_model_provider/model/image.py index a195b86491b..66a766ba8c0 100644 --- a/apps/setting/models_provider/impl/xinference_model_provider/model/image.py +++ b/apps/setting/models_provider/impl/xinference_model_provider/model/image.py @@ -19,7 +19,7 @@ def new_instance(model_type, model_name, model_credential: Dict[str, object], ** # stream_options={"include_usage": True}, streaming=True, stream_usage=True, - **optional_params, + extra_body=optional_params ) def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int: diff --git a/apps/setting/models_provider/impl/xinference_model_provider/model/llm.py b/apps/setting/models_provider/impl/xinference_model_provider/model/llm.py index d76979bd3a3..9c0316ad20a 100644 --- a/apps/setting/models_provider/impl/xinference_model_provider/model/llm.py +++ b/apps/setting/models_provider/impl/xinference_model_provider/model/llm.py @@ -34,7 +34,7 @@ def new_instance(model_type, model_name, model_credential: Dict[str, object], ** model=model_name, openai_api_base=base_url, openai_api_key=model_credential.get('api_key'), - **optional_params + extra_body=optional_params ) def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int: diff --git a/apps/setting/models_provider/impl/xinference_model_provider/model/reranker.py b/apps/setting/models_provider/impl/xinference_model_provider/model/reranker.py index 8820a198607..28c8d267839 100644 --- a/apps/setting/models_provider/impl/xinference_model_provider/model/reranker.py +++ b/apps/setting/models_provider/impl/xinference_model_provider/model/reranker.py @@ -22,6 +22,9 @@ class XInferenceReranker(MaxKBBaseModel, BaseDocumentCompressor): """UID of the launched model""" api_key: Optional[str] + @staticmethod + def is_cache_model(): + return False @staticmethod def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs): return XInferenceReranker(server_url=model_credential.get('server_url'), model_uid=model_name, diff --git a/apps/setting/models_provider/impl/zhipu_model_provider/model/image.py b/apps/setting/models_provider/impl/zhipu_model_provider/model/image.py index f13c7153803..6ac7830d8ff 100644 --- a/apps/setting/models_provider/impl/zhipu_model_provider/model/image.py +++ b/apps/setting/models_provider/impl/zhipu_model_provider/model/image.py @@ -16,5 +16,5 @@ def new_instance(model_type, model_name, model_credential: Dict[str, object], ** # stream_options={"include_usage": True}, streaming=True, stream_usage=True, - **optional_params, + extra_body=optional_params ) diff --git a/apps/smartdoc/conf.py b/apps/smartdoc/conf.py index de61cb8e339..630f32cb152 100644 --- a/apps/smartdoc/conf.py +++ b/apps/smartdoc/conf.py @@ -7,6 +7,7 @@ 2. 程序需要, 用户不需要更改的写到settings中 3. 程序需要, 用户需要更改的写到本config中 """ +import datetime import errno import logging import os @@ -93,7 +94,8 @@ class Config(dict): 'SANDBOX': False, 'LOCAL_MODEL_HOST': '127.0.0.1', 'LOCAL_MODEL_PORT': '11636', - 'LOCAL_MODEL_PROTOCOL': "http" + 'LOCAL_MODEL_PROTOCOL': "http", + 'LOCAL_MODEL_HOST_WORKER': 1 } @@ -111,12 +113,19 @@ def get_db_setting(self) -> dict: "USER": self.get('DB_USER'), "PASSWORD": self.get('DB_PASSWORD'), "ENGINE": self.get('DB_ENGINE'), + "CONN_MAX_AGE": 0, "POOL_OPTIONS": { "POOL_SIZE": 20, - "MAX_OVERFLOW": int(self.get('DB_MAX_OVERFLOW')) + "MAX_OVERFLOW": int(self.get('DB_MAX_OVERFLOW')), + "RECYCLE": 1800, + "TIMEOUT": 30, + 'PRE_PING': True } } + def get_session_timeout(self): + return datetime.timedelta(seconds=int(self.get('SESSION_TIMEOUT', 60 * 60 * 2))) + def get_language_code(self): return self.get('LANGUAGE_CODE', 'zh-CN') diff --git a/apps/smartdoc/settings/base.py b/apps/smartdoc/settings/base.py index edf4586629d..de81420798a 100644 --- a/apps/smartdoc/settings/base.py +++ b/apps/smartdoc/settings/base.py @@ -126,6 +126,10 @@ "token_cache": { 'BACKEND': 'common.cache.file_cache.FileCache', 'LOCATION': os.path.join(PROJECT_DIR, 'data', 'cache', "token_cache") # 文件夹路径 + }, + 'captcha_cache': { + 'BACKEND': 'common.cache.file_cache.FileCache', + 'LOCATION': os.path.join(PROJECT_DIR, 'data', 'cache', "captcha_cache") # 文件夹路径 } } diff --git a/apps/smartdoc/urls.py b/apps/smartdoc/urls.py index b243809cc77..996330471e2 100644 --- a/apps/smartdoc/urls.py +++ b/apps/smartdoc/urls.py @@ -23,10 +23,9 @@ from application.urls import urlpatterns as application_urlpatterns from common.cache_data.static_resource_cache import get_index_html -from common.constants.cache_code_constants import CacheCodeConstants +from common.init import init_template from common.init.init_doc import init_doc from common.response.result import Result -from common.util.cache_util import get_cache from smartdoc import settings from smartdoc.conf import PROJECT_DIR @@ -72,3 +71,4 @@ def page_not_found(request, exception): handler404 = page_not_found init_doc(urlpatterns, application_urlpatterns) +init_template.run() diff --git a/apps/users/serializers/user_serializers.py b/apps/users/serializers/user_serializers.py index 6093819a46a..cea2919f31d 100644 --- a/apps/users/serializers/user_serializers.py +++ b/apps/users/serializers/user_serializers.py @@ -6,18 +6,23 @@ @date:2023/9/5 16:32 @desc: """ +import base64 import datetime +import json import os import random import re import uuid +from captcha.image import ImageCaptcha from django.conf import settings from django.core import validators, signing, cache from django.core.mail import send_mail from django.core.mail.backends.smtp import EmailBackend from django.db import transaction from django.db.models import Q, QuerySet, Prefetch +from django.utils.translation import get_language +from django.utils.translation import gettext_lazy as _, to_locale from drf_yasg import openapi from rest_framework import serializers @@ -30,18 +35,39 @@ from common.mixins.api_mixin import ApiMixin from common.models.db_model_manage import DBModelManage from common.response.result import get_api_response -from common.util.common import valid_license +from common.util.common import valid_license, get_random_chars from common.util.field_message import ErrMessage from common.util.lock import lock +from common.util.rsa_util import decrypt, get_key_pair_by_sql from dataset.models import DataSet, Document, Paragraph, Problem, ProblemParagraphMapping from embedding.task import delete_embedding_by_dataset_id_list from function_lib.models.function import FunctionLib from setting.models import Team, SystemSetting, SettingType, Model, TeamMember, TeamMemberPermission from smartdoc.conf import PROJECT_DIR from users.models.user import User, password_encrypt, get_user_dynamics_permission -from django.utils.translation import gettext_lazy as _, gettext, to_locale -from django.utils.translation import get_language + user_cache = cache.caches['user_cache'] +captcha_cache = cache.caches['captcha_cache'] + + +class CaptchaSerializer(ApiMixin, serializers.Serializer): + @staticmethod + def get_response_body_api(): + return get_api_response(openapi.Schema( + type=openapi.TYPE_STRING, + title="captcha", + default="xxxx", + description="captcha" + )) + + @staticmethod + def generate(): + chars = get_random_chars() + image = ImageCaptcha() + data = image.generate(chars) + captcha = base64.b64encode(data.getbuffer()) + captcha_cache.set(f"LOGIN:{chars.lower()}", chars, timeout=5 * 60) + return 'data:image/png;base64,' + captcha.decode() class SystemSerializer(ApiMixin, serializers.Serializer): @@ -51,7 +77,8 @@ def get_profile(): xpack_cache = DBModelManage.get_model('xpack_cache') return {'version': version, 'IS_XPACK': hasattr(settings, 'IS_XPACK'), 'XPACK_LICENSE_IS_VALID': False if xpack_cache is None else xpack_cache.get('XPACK_LICENSE_IS_VALID', - False)} + False), + 'ras': get_key_pair_by_sql().get('key')} @staticmethod def get_response_body_api(): @@ -71,30 +98,14 @@ class LoginSerializer(ApiMixin, serializers.Serializer): password = serializers.CharField(required=True, error_messages=ErrMessage.char(_("Password"))) - def is_valid(self, *, raise_exception=False): - """ - 校验参数 - :param raise_exception: Whether to throw an exception can only be True - :return: User information - """ - super().is_valid(raise_exception=True) - username = self.data.get("username") - password = password_encrypt(self.data.get("password")) - user = QuerySet(User).filter(Q(username=username, - password=password) | Q(email=username, - password=password)).first() - if user is None: - raise ExceptionCodeConstants.INCORRECT_USERNAME_AND_PASSWORD.value.to_app_api_exception() - if not user.is_active: - raise AppApiException(1005, _("The user has been disabled, please contact the administrator!")) - return user + captcha = serializers.CharField(required=True, error_messages=ErrMessage.char(_("captcha"))) + encryptedData = serializers.CharField(required=False, label=_('encryptedData'), allow_null=True, + allow_blank=True) - def get_user_token(self): + def get_user_token(self, user): """ - Get user token :return: User Token (authentication information) """ - user = self.is_valid() token = signing.dumps({'username': user.username, 'id': str(user.id), 'email': user.email, 'type': AuthenticationType.USER.value}) return token @@ -106,10 +117,13 @@ class Meta: def get_request_body_api(self): return openapi.Schema( type=openapi.TYPE_OBJECT, - required=['username', 'password'], + required=['username', 'encryptedData'], properties={ 'username': openapi.Schema(type=openapi.TYPE_STRING, title=_("Username"), description=_("Username")), - 'password': openapi.Schema(type=openapi.TYPE_STRING, title=_("Password"), description=_("Password")) + 'password': openapi.Schema(type=openapi.TYPE_STRING, title=_("Password"), description=_("Password")), + 'captcha': openapi.Schema(type=openapi.TYPE_STRING, title=_("captcha"), description=_("captcha")), + 'encryptedData': openapi.Schema(type=openapi.TYPE_STRING, title=_("encryptedData"), + description=_("encryptedData")) } ) @@ -121,6 +135,29 @@ def get_response_body_api(self): description="认证token" )) + @staticmethod + def login(instance): + username = instance.get("username", "") + encryptedData = instance.get("encryptedData", "") + if encryptedData: + json_data = json.loads(decrypt(encryptedData)) + instance.update(json_data) + LoginSerializer(data=instance).is_valid(raise_exception=True) + password = instance.get("password") + captcha = instance.get("captcha", "") + captcha_value = captcha_cache.get(f"LOGIN:{captcha.lower()}") + if captcha_value is None: + raise AppApiException(1005, _("Captcha code error or expiration")) + user = QuerySet(User).filter(Q(username=username, + password=password_encrypt(password)) | Q(email=username, + password=password_encrypt( + password))).first() + if user is None: + raise ExceptionCodeConstants.INCORRECT_USERNAME_AND_PASSWORD.value.to_app_api_exception() + if not user.is_active: + raise AppApiException(1005, _("The user has been disabled, please contact the administrator!")) + return user + class RegisterSerializer(ApiMixin, serializers.Serializer): """ diff --git a/apps/users/urls.py b/apps/users/urls.py index e5e2fe0dfb2..a9d1e134c90 100644 --- a/apps/users/urls.py +++ b/apps/users/urls.py @@ -6,6 +6,7 @@ urlpatterns = [ path('profile', views.Profile.as_view()), path('user', views.User.as_view(), name="profile"), + path('user/captcha', views.CaptchaView.as_view(), name='captcha'), path('user/language', views.SwitchUserLanguageView.as_view(), name='language'), path('user/list', views.User.Query.as_view()), path('user/login', views.Login.as_view(), name='login'), diff --git a/apps/users/views/user.py b/apps/users/views/user.py index 55d4b6b9ad9..c77dce5bbd1 100644 --- a/apps/users/views/user.py +++ b/apps/users/views/user.py @@ -22,11 +22,11 @@ from common.log.log import log from common.response import result from common.util.common import encryption -from smartdoc.settings import JWT_AUTH +from smartdoc.const import CONFIG from users.serializers.user_serializers import RegisterSerializer, LoginSerializer, CheckCodeSerializer, \ RePasswordSerializer, \ SendEmailSerializer, UserProfile, UserSerializer, UserManageSerializer, UserInstanceSerializer, SystemSerializer, \ - SwitchLanguageSerializer + SwitchLanguageSerializer, CaptchaSerializer from users.views.common import get_user_operation_object, get_re_password_details user_cache = cache.caches['user_cache'] @@ -84,7 +84,7 @@ class SwitchUserLanguageView(APIView): description=_("language")), } ), - responses=RePasswordSerializer().get_response_body_api(), + responses=result.get_default_response(), tags=[_("User management")]) @log(menu='User management', operate='Switch Language', get_operation_object=lambda r, k: {'name': r.user.username}) @@ -111,7 +111,7 @@ class ResetCurrentUserPasswordView(APIView): description=_("Password")) } ), - responses=RePasswordSerializer().get_response_body_api(), + responses=result.get_default_response(), tags=[_("User management")]) @log(menu='User management', operate='Modify current user password', get_operation_object=lambda r, k: {'name': r.user.username}, @@ -170,6 +170,18 @@ def _get_details(request): } +class CaptchaView(APIView): + + @action(methods=['GET'], detail=False) + @swagger_auto_schema(operation_summary=_("Obtain graphical captcha"), + operation_id=_("Obtain graphical captcha"), + responses=CaptchaSerializer().get_response_body_api(), + security=[], + tags=[_("User management")]) + def get(self, request: Request): + return result.success(CaptchaSerializer().generate()) + + class Login(APIView): @action(methods=['POST'], detail=False) @@ -183,11 +195,9 @@ class Login(APIView): get_details=_get_details, get_operation_object=lambda r, k: {'name': r.data.get('username')}) def post(self, request: Request): - login_request = LoginSerializer(data=request.data) - # 校验请求参数 - user = login_request.is_valid(raise_exception=True) - token = login_request.get_user_token() - token_cache.set(token, user, timeout=JWT_AUTH['JWT_EXPIRATION_DELTA']) + user = LoginSerializer().login(request.data) + token = LoginSerializer().get_user_token(user) + token_cache.set(token, user, timeout=CONFIG.get_session_timeout()) return result.success(token) diff --git a/installer/Dockerfile b/installer/Dockerfile index d2c1eefb6fa..81db7241543 100644 --- a/installer/Dockerfile +++ b/installer/Dockerfile @@ -5,7 +5,7 @@ RUN cd ui && \ npm install && \ npm run build && \ rm -rf ./node_modules -FROM ghcr.io/1panel-dev/maxkb-python-pg:python3.11-pg15.8 AS stage-build +FROM ghcr.io/1panel-dev/maxkb-python-pg:python3.11-pg15.14 AS stage-build ARG DEPENDENCIES=" \ python3-pip" @@ -25,11 +25,11 @@ RUN python3 -m venv /opt/py3 && \ pip install poetry==1.8.5 --break-system-packages && \ poetry config virtualenvs.create false && \ . /opt/py3/bin/activate && \ - if [ "$(uname -m)" = "x86_64" ]; then sed -i 's/^torch.*/torch = {version = "^2.6.0+cpu", source = "pytorch"}/g' pyproject.toml; fi && \ + if [ "$(uname -m)" = "x86_64" ]; then sed -i 's/^torch.*/torch = {version = "2.6.0+cpu", source = "pytorch"}/g' pyproject.toml; fi && \ poetry install && \ export MAXKB_CONFIG_TYPE=ENV && python3 /opt/maxkb/app/apps/manage.py compilemessages -FROM ghcr.io/1panel-dev/maxkb-python-pg:python3.11-pg15.8 +FROM ghcr.io/1panel-dev/maxkb-python-pg:python3.11-pg15.14 ARG DOCKER_IMAGE_TAG=dev \ BUILD_AT \ GITHUB_COMMIT @@ -70,7 +70,9 @@ RUN chmod 755 /opt/maxkb/app/installer/run-maxkb.sh && \ useradd --no-create-home --home /opt/maxkb/app/sandbox sandbox -g root && \ chown -R sandbox:root /opt/maxkb/app/sandbox && \ chmod g-x /usr/local/bin/* /usr/bin/* /bin/* /usr/sbin/* /sbin/* /usr/lib/postgresql/15/bin/* && \ - chmod g+x /usr/local/bin/python* + chmod g+xr /usr/bin/ld.so && \ + chmod g+x /usr/local/bin/python* && \ + find /etc/ -type f ! -path '/etc/resolv.conf' ! -path '/etc/hosts' | xargs chmod g-rx EXPOSE 8080 diff --git a/installer/Dockerfile-python-pg b/installer/Dockerfile-python-pg index f871ac4ef4f..eb52eec17fe 100644 --- a/installer/Dockerfile-python-pg +++ b/installer/Dockerfile-python-pg @@ -1,5 +1,5 @@ -FROM python:3.11-slim-bullseye AS python-stage -FROM postgres:15.8-bullseye +FROM python:3.11-slim-trixie AS python-stage +FROM postgres:15.14-trixie ARG DEPENDENCIES=" \ libexpat1-dev \ diff --git a/pyproject.toml b/pyproject.toml index 35d74a52e95..747a15e59e8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -8,70 +8,75 @@ package-mode = false [tool.poetry.dependencies] python = ">=3.11,<3.12" -django = "4.2.18" -djangorestframework = "^3.15.2" +django = "4.2.20" +djangorestframework = "3.16.0" drf-yasg = "1.21.7" django-filter = "23.2" -langchain-openai = "^0.3.0" -langchain-anthropic = "^0.3.0" -langchain-community = "^0.3.0" -langchain-deepseek = "^0.1.0" -langchain-google-genai = "^2.0.9" -langchain-mcp-adapters = "^0.0.5" -langchain-huggingface = "^0.1.2" -langchain-ollama = "^0.3.0" -langgraph = "^0.3.0" -mcp = "^1.4.1" -psycopg2-binary = "2.9.10" -jieba = "^0.42.1" -diskcache = "^5.6.3" -pillow = "^10.2.0" -filetype = "^1.2.0" +langchain = "0.3.23" +langchain-openai = "0.3.12" +langchain-anthropic = "0.3.12" +langchain-community = "0.3.21" +langchain-deepseek = "0.1.3" +langchain-google-genai = "2.1.2" +langchain-mcp-adapters = "0.0.11" +langchain-huggingface = "0.1.2" +langchain-ollama = "0.3.2" +langgraph = "0.3.27" +mcp = "1.8.0" +psycopg = { extras = ["binary"], version = "3.2.9" } +jieba = "0.42.1" +diskcache = "5.6.3" +pillow = "10.4.0" +filetype = "1.2.0" torch = "2.6.0" -sentence-transformers = "^4.0.2" -openai = "^1.13.3" -tiktoken = "^0.7.0" -qianfan = "^0.3.6.1" -pycryptodome = "^3.19.0" -beautifulsoup4 = "^4.12.2" -html2text = "^2024.2.26" -django-ipware = "^6.0.4" -django-apscheduler = "^0.6.2" +sentence-transformers = "4.0.2" +openai = "1.72.0" +tiktoken = "0.7.0" +qianfan = "0.3.18" +pycryptodome = "3.22.0" +beautifulsoup4 = "4.13.3" +html2text = "2024.2.26" +django-ipware = "6.0.5" +django-apscheduler = "0.6.2" pymupdf = "1.24.9" -pypdf = "4.3.1" +pypdf = "6.0.0" rapidocr-onnxruntime = "1.3.24" -python-docx = "^1.1.0" -xlwt = "^1.3.0" -dashscope = "^1.17.0" -zhipuai = "^2.0.1" -httpx = "^0.27.0" -httpx-sse = "^0.4.0" -websockets = "^13.0" -openpyxl = "^3.1.2" -xlrd = "^2.0.1" -gunicorn = "^23.0.0" +python-docx = "1.1.2" +xlwt = "1.3.0" +dashscope = "1.23.1" +zhipuai = "2.1.5.20250410" +httpx = "0.27.2" +httpx-sse = "0.4.0" +websockets = "13.1" +openpyxl = "3.1.5" +xlrd = "2.0.1" +gunicorn = "23.0.0" python-daemon = "3.0.1" -boto3 = "^1.34.160" -tencentcloud-sdk-python = "^3.0.1209" -xinference-client = "^1.3.0" -psutil = "^6.0.0" -celery = { extras = ["sqlalchemy"], version = "^5.4.0" } -django-celery-beat = "^2.6.0" -celery-once = "^3.0.1" -anthropic = "^0.49.0" -pylint = "3.1.0" -pydub = "^0.25.1" -cffi = "^1.17.1" -pysilk = "^0.0.1" -django-db-connection-pool = "^1.2.5" -opencv-python-headless = "^4.11.0.86" -pymysql = "^1.1.1" -accelerate = "^1.6.0" +boto3 = "1.37.31" +tencentcloud-sdk-python = "3.0.1357" +xinference-client = "1.4.1" +psutil = "6.1.1" +celery = { extras = ["sqlalchemy"], version = "5.5.1" } +django-celery-beat = "2.7.0" +celery-once = "3.0.1" +anthropic = "0.49.0" +pylint = "3.3.6" +pydub = "0.25.1" +cffi = "1.17.1" +pysilk = "0.0.1" +django-db-connection-pool = "1.2.6" +opencv-python-headless = "4.11.0.86" +pymysql = "1.1.1" +accelerate = "1.6.0" +captcha = "0.7.1" +setuptools = "^75.0.0" +volcengine-python-sdk = {extras = ["ark"], version = "4.0.5"} [build-system] requires = ["poetry-core"] build-backend = "poetry.core.masonry.api" + [[tool.poetry.source]] name = "pytorch" url = "https://download.pytorch.org/whl/cpu" -priority = "explicit" \ No newline at end of file +priority = "explicit" diff --git a/ui/package.json b/ui/package.json index cee7a41c8fd..32adfd27108 100644 --- a/ui/package.json +++ b/ui/package.json @@ -27,17 +27,19 @@ "cropperjs": "^1.6.2", "dingtalk-jsapi": "^2.15.6", "echarts": "^5.5.0", - "element-plus": "^2.9.1", + "element-plus": "^2.13.5", "file-saver": "^2.0.5", "highlight.js": "^11.9.0", "install": "^0.13.0", "katex": "^0.16.10", "lodash": "^4.17.21", "marked": "^12.0.2", - "md-editor-v3": "^4.16.7", + "md-editor-v3": "^5.8.2", "mermaid": "^10.9.0", "mitt": "^3.0.0", "moment": "^2.30.1", + "nanoid": "^5.1.5", + "node-forge": "^1.3.1", "npm": "^10.2.4", "nprogress": "^0.2.0", "pinia": "^2.1.6", @@ -53,8 +55,7 @@ "vue-draggable-plus": "^0.6.0", "vue-i18n": "^9.13.1", "vue-router": "^4.2.4", - "vue3-menus": "^1.1.2", - "vuedraggable": "^4.1.0" + "vue3-menus": "^1.1.2" }, "devDependencies": { "@rushstack/eslint-patch": "^1.3.2", @@ -62,6 +63,7 @@ "@types/file-saver": "^2.0.7", "@types/jsdom": "^21.1.1", "@types/node": "^18.17.5", + "@types/node-forge": "^1.3.14", "@types/nprogress": "^0.2.0", "@vitejs/plugin-vue": "^4.3.1", "@vue/eslint-config-prettier": "^8.0.0", diff --git a/ui/src/api/application.ts b/ui/src/api/application.ts index efd4a4985a8..bc903c957eb 100644 --- a/ui/src/api/application.ts +++ b/ui/src/api/application.ts @@ -227,7 +227,7 @@ const getApplicationHitTest: ( data: any, loading?: Ref ) => Promise>> = (application_id, data, loading) => { - return get(`${prefix}/${application_id}/hit_test`, data, loading) + return put(`${prefix}/${application_id}/hit_test`, data, undefined, loading) } /** diff --git a/ui/src/api/dataset.ts b/ui/src/api/dataset.ts index a5a663b03c7..83de865b3bc 100644 --- a/ui/src/api/dataset.ts +++ b/ui/src/api/dataset.ts @@ -186,7 +186,7 @@ const getDatasetHitTest: ( data: any, loading?: Ref ) => Promise>> = (dataset_id, data, loading) => { - return get(`${prefix}/${dataset_id}/hit_test`, data, loading) + return put(`${prefix}/${dataset_id}/hit_test`, data, undefined, loading) } /** diff --git a/ui/src/api/team.ts b/ui/src/api/team.ts index 82e8f986e46..462534b0eba 100644 --- a/ui/src/api/team.ts +++ b/ui/src/api/team.ts @@ -36,7 +36,7 @@ const getMemberPermissions: (member_id: String) => Promise> = (membe } /** - * 获取成员权限 + * 修改成员权限 * @param 参数 member_id * @param 参数 { "team_member_permission_list": [ diff --git a/ui/src/api/type/application.ts b/ui/src/api/type/application.ts index 077e230973e..c423f11105a 100644 --- a/ui/src/api/type/application.ts +++ b/ui/src/api/type/application.ts @@ -72,6 +72,7 @@ interface chatType { document_list: Array image_list: Array audio_list: Array + other_list: Array } } diff --git a/ui/src/api/type/user.ts b/ui/src/api/type/user.ts index a452673546a..197dba888c7 100644 --- a/ui/src/api/type/user.ts +++ b/ui/src/api/type/user.ts @@ -37,6 +37,11 @@ interface LoginRequest { * 密码 */ password: string + /** + * 验证码 + */ + captcha: string + encryptedData?: string } interface RegisterRequest { diff --git a/ui/src/api/user.ts b/ui/src/api/user.ts index eb12fd2ebf8..0d669705442 100644 --- a/ui/src/api/user.ts +++ b/ui/src/api/user.ts @@ -10,22 +10,26 @@ import type { } from '@/api/type/user' import type { Ref } from 'vue' + +const login: (request: LoginRequest, loading?: Ref) => Promise> = ( + request, + loading +) => { + return post('/user/login', request, undefined, loading) +} + +const ldapLogin: (request: LoginRequest, loading?: Ref) => Promise> = ( + request, + loading +) => { + return post('/LDAP/login', request, undefined, loading) +} /** - * 登录 - * @param auth_type - * @param request 登录接口请求表单 - * @param loading 接口加载器 - * @returns 认证数据 + * 获取图形验证码 + * @returns */ -const login: ( - auth_type: string, - request: LoginRequest, - loading?: Ref -) => Promise> = (auth_type, request, loading) => { - if (auth_type !== '') { - return post(`/${auth_type}/login`, request, undefined, loading) - } - return post('/user/login', request, undefined, loading) +const getCaptcha: () => Promise> = () => { + return get('user/captcha') } /** * 登出 @@ -226,5 +230,7 @@ export default { postLanguage, getDingOauth2Callback, getlarkCallback, - getQrSource + getQrSource, + getCaptcha, + ldapLogin } diff --git a/ui/src/components/ai-chat/ExecutionDetailDialog.vue b/ui/src/components/ai-chat/ExecutionDetailDialog.vue index 0f2296439ae..98ec1e6fa3e 100644 --- a/ui/src/components/ai-chat/ExecutionDetailDialog.vue +++ b/ui/src/components/ai-chat/ExecutionDetailDialog.vue @@ -125,6 +125,28 @@ +
+

+ {{ $t('common.fileUpload.document') }}: +

+ + + + +
diff --git a/ui/src/components/ai-chat/component/answer-content/index.vue b/ui/src/components/ai-chat/component/answer-content/index.vue index 7f09fa04c68..26cd8a0d06f 100644 --- a/ui/src/components/ai-chat/component/answer-content/index.vue +++ b/ui/src/components/ai-chat/component/answer-content/index.vue @@ -80,7 +80,7 @@ const props = defineProps<{ chatRecord: chatType application: any loading: boolean - sendMessage: (question: string, other_params_data?: any, chat?: chatType) => void + sendMessage: (question: string, other_params_data?: any, chat?: chatType) => Promise chatManagement: any type: 'log' | 'ai-chat' | 'debug-ai-chat' }>() @@ -98,9 +98,10 @@ const showUserAvatar = computed(() => { const chatMessage = (question: string, type: 'old' | 'new', other_params_data?: any) => { if (type === 'old') { add_answer_text_list(props.chatRecord.answer_text_list) - props.sendMessage(question, other_params_data, props.chatRecord) - props.chatManagement.open(props.chatRecord.id) - props.chatManagement.write(props.chatRecord.id) + props.sendMessage(question, other_params_data, props.chatRecord).then(() => { + props.chatManagement.open(props.chatRecord.id) + props.chatManagement.write(props.chatRecord.id) + }) } else { props.sendMessage(question, other_params_data) } diff --git a/ui/src/components/ai-chat/component/chat-input-operate/index.vue b/ui/src/components/ai-chat/component/chat-input-operate/index.vue index acf3085ed97..a2f95812365 100644 --- a/ui/src/components/ai-chat/component/chat-input-operate/index.vue +++ b/ui/src/components/ai-chat/component/chat-input-operate/index.vue @@ -10,7 +10,8 @@ uploadDocumentList.length || uploadImageList.length || uploadAudioList.length || - uploadVideoList.length + uploadVideoList.length || + uploadOtherList.length " > @@ -30,22 +31,62 @@ class="file cursor" >
+
+ +
+ {{ item && item?.name }} +
+
- +
- -
- {{ item && item?.name }} +
+ + + + +
+
+ +
+ {{ item && item?.name }} +
+
+
+ + +
@@ -63,23 +104,25 @@ >
+
+ +
+ {{ item && item?.name }} +
+
- +
- -
- {{ item && item?.name }} -
@@ -87,7 +130,7 @@ - + @@ -221,11 +268,11 @@ - - + +
@@ -241,7 +288,7 @@ diff --git a/ui/src/components/model-select/index.vue b/ui/src/components/model-select/index.vue index 116824e3c63..6a3b63acec1 100644 --- a/ui/src/components/model-select/index.vue +++ b/ui/src/components/model-select/index.vue @@ -72,7 +72,7 @@ @@ -82,8 +82,6 @@ import type { Provider } from '@/api/type/model' import { relatedObject } from '@/utils/utils' import CreateModelDialog from '@/views/template/component/CreateModelDialog.vue' import SelectProviderDialog from '@/views/template/component/SelectProviderDialog.vue' - -import { t } from '@/locales' import useStore from '@/stores' defineOptions({ name: 'ModelSelect' }) diff --git a/ui/src/layout/components/breadcrumb/index.vue b/ui/src/layout/components/breadcrumb/index.vue index 9140e8209a8..ae5011318d6 100644 --- a/ui/src/layout/components/breadcrumb/index.vue +++ b/ui/src/layout/components/breadcrumb/index.vue @@ -228,6 +228,11 @@ function getApplication() { } function refresh() { common.saveBreadcrumb(null) + if (isDataset.value) { + getDataset() + } else if (isApplication.value) { + getApplication() + } } onMounted(() => { if (!breadcrumbData.value) { diff --git a/ui/src/locales/lang/en-US/ai-chat.ts b/ui/src/locales/lang/en-US/ai-chat.ts index 3a52270977c..a857836869b 100644 --- a/ui/src/locales/lang/en-US/ai-chat.ts +++ b/ui/src/locales/lang/en-US/ai-chat.ts @@ -63,6 +63,9 @@ export default { limitMessage2: 'files', sizeLimit: 'Each file must not exceed', imageMessage: 'Please process the image content', + documentMessage: 'Please understand the content of the document', + audioMessage: 'Please understand the audio content', + otherMessage: 'Please understand the file content', errorMessage: 'Upload Failed' }, executionDetails: { diff --git a/ui/src/locales/lang/en-US/common.ts b/ui/src/locales/lang/en-US/common.ts index 96afd9916da..2fd0a30b32d 100644 --- a/ui/src/locales/lang/en-US/common.ts +++ b/ui/src/locales/lang/en-US/common.ts @@ -45,7 +45,10 @@ export default { document: 'Documents', image: 'Image', audio: 'Audio', - video: 'Video' + video: 'Video', + other: 'Other', + addExtensions: 'Add suffix', + existingExtensionsTip: 'File suffix already exists', }, status: { label: 'Status', @@ -55,7 +58,7 @@ export default { param: { outputParam: 'Output Parameters', inputParam: 'Input Parameters', - initParam: 'Startup Parameters', + initParam: 'Startup Parameters' }, inputPlaceholder: 'Please input', diff --git a/ui/src/locales/lang/en-US/views/application-workflow.ts b/ui/src/locales/lang/en-US/views/application-workflow.ts index e4385ea3791..e1fd8009e68 100644 --- a/ui/src/locales/lang/en-US/views/application-workflow.ts +++ b/ui/src/locales/lang/en-US/views/application-workflow.ts @@ -104,7 +104,8 @@ export default { label: 'File types allowed for upload', documentText: 'Requires "Document Content Extraction" node to parse document content', imageText: 'Requires "Image Understanding" node to parse image content', - audioText: 'Requires "Speech-to-Text" node to parse audio content' + audioText: 'Requires "Speech-to-Text" node to parse audio content', + otherText: 'Need to parse this type of file by yourself' } } }, @@ -222,14 +223,14 @@ export default { }, mcpNode: { label: 'MCP Server', - text: 'Call MCP Tools through SSE', + text: 'Call MCP Tools through SSE/Streamable HTTP', getToolsSuccess: 'Get Tools Successfully', getTool: 'Get Tools', tool: 'Tool', toolParam: 'Tool Params', mcpServerTip: 'Please enter the JSON format of the MCP server config', mcpToolTip: 'Please select a tool', - configLabel: 'MCP Server Config (Only supports SSE call method)' + configLabel: 'MCP Server Config (Only supports SSE/Streamable HTTP call method)' }, imageGenerateNode: { label: 'Image Generation', diff --git a/ui/src/locales/lang/en-US/views/application.ts b/ui/src/locales/lang/en-US/views/application.ts index b69ede6d890..8247e48ec6b 100644 --- a/ui/src/locales/lang/en-US/views/application.ts +++ b/ui/src/locales/lang/en-US/views/application.ts @@ -139,7 +139,7 @@ Response requirements: hybridSearch: 'Hybrid Search', hybridSearchTooltip: 'Hybrid search is a retrieval method based on both vector and text similarity, suitable for medium data volumes in the knowledge.', - similarityThreshold: 'Similarity higher than', + similarityThreshold: 'Similarity not lower than', similarityTooltip: 'The higher the similarity, the stronger the correlation.', topReferences: 'Top N Segments', maxCharacters: 'Maximum Characters per Reference', diff --git a/ui/src/locales/lang/en-US/views/document.ts b/ui/src/locales/lang/en-US/views/document.ts index 9a3f1da7387..ff17a61340c 100644 --- a/ui/src/locales/lang/en-US/views/document.ts +++ b/ui/src/locales/lang/en-US/views/document.ts @@ -149,7 +149,7 @@ export default { tooltip: 'When user asks a question, handle matched segments according to the set method.' }, similarity: { - label: 'Similarity Higher Than', + label: 'Similarity not lower than', placeholder: 'Directly return segment content', requiredMessage: 'Please enter similarity value' } diff --git a/ui/src/locales/lang/en-US/views/system.ts b/ui/src/locales/lang/en-US/views/system.ts index 303d1175dcf..8d3e50ad74b 100644 --- a/ui/src/locales/lang/en-US/views/system.ts +++ b/ui/src/locales/lang/en-US/views/system.ts @@ -1,5 +1,6 @@ export default { title: 'System', + subTitle: 'Setting', test: 'Test Connection', testSuccess: 'Successful', testFailed: 'Test connection failed', @@ -76,8 +77,8 @@ export default { dingtalk: 'DingTalk', lark: 'Lark', effective: 'Effective', - alreadyTurnedOn: 'Turned On', - notEnabled: 'Not Enabled', + alreadyTurnedOn: 'Enabled', + notEnabled: 'Disabled', validate: 'Validate', validateSuccess: 'Successful', validateFailed: 'Validation failed', @@ -122,7 +123,7 @@ export default { websiteSlogan: 'Welcome Slogan', websiteSloganPlaceholder: 'Please enter the welcome slogan', websiteSloganTip: 'The welcome slogan below the product logo', - defaultSlogan: 'Ready-to-use, flexible RAG Chatbot', + defaultSlogan: 'Ready-to-use open-source AI assistant', defaultTip: 'The default is the MaxKB platform interface, supports custom settings', logoDefaultTip: 'The default is the MaxKB login interface, supports custom settings', platformSetting: 'Platform Settings', diff --git a/ui/src/locales/lang/en-US/views/user.ts b/ui/src/locales/lang/en-US/views/user.ts index ae41fd564c0..2bbc1404363 100644 --- a/ui/src/locales/lang/en-US/views/user.ts +++ b/ui/src/locales/lang/en-US/views/user.ts @@ -28,6 +28,10 @@ export default { requiredMessage: 'Please enter username', lengthMessage: 'Length must be between 6 and 20 words' }, + captcha: { + label: 'captcha', + placeholder: 'Please enter the captcha' + }, nick_name: { label: 'Name', placeholder: 'Please enter name' diff --git a/ui/src/locales/lang/zh-CN/ai-chat.ts b/ui/src/locales/lang/zh-CN/ai-chat.ts index 76bb53d4f53..de702d2347c 100644 --- a/ui/src/locales/lang/zh-CN/ai-chat.ts +++ b/ui/src/locales/lang/zh-CN/ai-chat.ts @@ -61,6 +61,9 @@ export default { limitMessage2: '个文件', sizeLimit: '单个文件大小不能超过', imageMessage: '请解析图片内容', + documentMessage: '请理解文档内容', + audioMessage: '请理解音频内容', + otherMessage: '请理解文件内容', errorMessage: '上传失败' }, executionDetails: { diff --git a/ui/src/locales/lang/zh-CN/common.ts b/ui/src/locales/lang/zh-CN/common.ts index 97e25b704cf..db1e7e7318a 100644 --- a/ui/src/locales/lang/zh-CN/common.ts +++ b/ui/src/locales/lang/zh-CN/common.ts @@ -45,7 +45,10 @@ export default { document: '文档', image: '图片', audio: '音频', - video: '视频' + video: '视频', + other: '其他文件', + addExtensions: '添加后缀名', + existingExtensionsTip: '文件后缀已存在', }, status: { label: '状态', diff --git a/ui/src/locales/lang/zh-CN/views/application-workflow.ts b/ui/src/locales/lang/zh-CN/views/application-workflow.ts index 4c5a19d7686..c7c6038cc5f 100644 --- a/ui/src/locales/lang/zh-CN/views/application-workflow.ts +++ b/ui/src/locales/lang/zh-CN/views/application-workflow.ts @@ -105,8 +105,10 @@ export default { label: '上传的文件类型', documentText: '需要使用“文档内容提取”节点解析文档内容', imageText: '需要使用“视觉模型”节点解析图片内容', - audioText: '需要使用“语音转文本”节点解析音频内容' - } + audioText: '需要使用“语音转文本”节点解析音频内容', + otherText: '需要自行解析该类型文件' + }, + } }, aiChatNode: { @@ -222,14 +224,14 @@ export default { }, mcpNode: { label: 'MCP 调用', - text: '通过SSE方式执行MCP服务中的工具', + text: '通过SSE/Streamable HTTP方式执行MCP服务中的工具', getToolsSuccess: '获取工具成功', getTool: '获取工具', tool: '工具', toolParam: '工具参数', mcpServerTip: '请输入JSON格式的MCP服务器配置', mcpToolTip: '请选择工具', - configLabel: 'MCP Server Config (仅支持SSE调用方式)' + configLabel: 'MCP Server Config (仅支持SSE/Streamable HTTP调用方式)' }, imageGenerateNode: { label: '图片生成', @@ -264,7 +266,7 @@ export default { label: '文本转语音', text: '将文本通过语音合成模型转换为音频', tts_model: { - label: '语音识别模型' + label: '语音合成模型' }, content: { label: '选择文本内容' diff --git a/ui/src/locales/lang/zh-CN/views/application.ts b/ui/src/locales/lang/zh-CN/views/application.ts index dc9b16216bc..99db3f1f5cd 100644 --- a/ui/src/locales/lang/zh-CN/views/application.ts +++ b/ui/src/locales/lang/zh-CN/views/application.ts @@ -130,7 +130,7 @@ export default { hybridSearch: '混合检索', hybridSearchTooltip: '混合检索是一种基于向量和文本相似度的检索方式,适用于知识库中的中等数据量场景。', - similarityThreshold: '相似度高于', + similarityThreshold: '相似度不低于', similarityTooltip: '相似度越高相关性越强。', topReferences: '引用分段数 TOP', maxCharacters: '最多引用字符数', diff --git a/ui/src/locales/lang/zh-CN/views/document.ts b/ui/src/locales/lang/zh-CN/views/document.ts index bfdf2907ea4..0f5b03ba8b3 100644 --- a/ui/src/locales/lang/zh-CN/views/document.ts +++ b/ui/src/locales/lang/zh-CN/views/document.ts @@ -147,7 +147,7 @@ export default { tooltip: '用户提问时,命中文档下的分段时按照设置的方式进行处理。' }, similarity: { - label: '相似度高于', + label: '相似度不低于', placeholder: '直接返回分段内容', requiredMessage: '请输入相似度' } diff --git a/ui/src/locales/lang/zh-CN/views/system.ts b/ui/src/locales/lang/zh-CN/views/system.ts index 9ce23d90d86..72624d26a48 100644 --- a/ui/src/locales/lang/zh-CN/views/system.ts +++ b/ui/src/locales/lang/zh-CN/views/system.ts @@ -1,5 +1,6 @@ export default { - title: '系统设置', + title: '系统管理', + subTitle: '系统设置', test: '测试连接', testSuccess: '测试连接成功', testFailed: '测试连接失败', @@ -120,7 +121,7 @@ export default { websiteSlogan: '欢迎语', websiteSloganPlaceholder: '请输入欢迎语', websiteSloganTip: '产品 Logo 下的欢迎语', - defaultSlogan: '欢迎使用 MaxKB 智能知识库问答系统', + defaultSlogan: '欢迎使用 MaxKB 开源 AI 助手', logoDefaultTip: '默认为 MaxKB 登录界面,支持自定义设置', defaultTip: '默认为 MaxKB 平台界面,支持自定义设置', platformSetting: '平台设置', diff --git a/ui/src/locales/lang/zh-CN/views/user.ts b/ui/src/locales/lang/zh-CN/views/user.ts index 4e2a8760f92..191074c0c06 100644 --- a/ui/src/locales/lang/zh-CN/views/user.ts +++ b/ui/src/locales/lang/zh-CN/views/user.ts @@ -25,6 +25,10 @@ export default { requiredMessage: '请输入用户名', lengthMessage: '长度在 6 到 20 个字符' }, + captcha: { + label: '验证码', + placeholder: '请输入验证码' + }, nick_name: { label: '姓名', placeholder: '请输入姓名' @@ -33,7 +37,7 @@ export default { label: '邮箱', placeholder: '请输入邮箱', requiredMessage: '请输入邮箱', - validatorEmail: '请输入有效邮箱格式!', + validatorEmail: '请输入有效邮箱格式!' }, phone: { label: '手机号', @@ -48,13 +52,13 @@ export default { new_password: { label: '新密码', placeholder: '请输入新密码', - requiredMessage: '请输入新密码', + requiredMessage: '请输入新密码' }, re_password: { label: '确认密码', placeholder: '请输入确认密码', requiredMessage: '请输入确认密码', - validatorMessage: '密码不一致', + validatorMessage: '密码不一致' } } }, diff --git a/ui/src/locales/lang/zh-Hant/ai-chat.ts b/ui/src/locales/lang/zh-Hant/ai-chat.ts index 75f9949a6dc..1c717335574 100644 --- a/ui/src/locales/lang/zh-Hant/ai-chat.ts +++ b/ui/src/locales/lang/zh-Hant/ai-chat.ts @@ -61,6 +61,9 @@ export default { limitMessage2: '個文件', sizeLimit: '單個文件大小不能超過', imageMessage: '請解析圖片內容', + documentMessage: '請理解檔案內容', + audioMessage: '請理解音訊內容', + otherMessage: '請理解檔案內容', errorMessage: '上傳失敗' }, executionDetails: { diff --git a/ui/src/locales/lang/zh-Hant/common.ts b/ui/src/locales/lang/zh-Hant/common.ts index 0ccbb5c1159..8e6293076c9 100644 --- a/ui/src/locales/lang/zh-Hant/common.ts +++ b/ui/src/locales/lang/zh-Hant/common.ts @@ -45,7 +45,10 @@ export default { document: '文檔', image: '圖片', audio: '音頻', - video: '視頻' + video: '視頻', + other: '其他文件', + addExtensions: '添加後綴名', + existingExtensionsTip: '文件後綴已存在', }, status: { label: '狀態', diff --git a/ui/src/locales/lang/zh-Hant/views/application-workflow.ts b/ui/src/locales/lang/zh-Hant/views/application-workflow.ts index 60269c021b2..f3a0a0c1e7d 100644 --- a/ui/src/locales/lang/zh-Hant/views/application-workflow.ts +++ b/ui/src/locales/lang/zh-Hant/views/application-workflow.ts @@ -105,7 +105,8 @@ export default { label: '上傳的文件類型', documentText: '需要使用「文檔內容提取」節點解析文檔內容', imageText: '需要使用「圖片理解」節點解析圖片內容', - audioText: '需要使用「語音轉文本」節點解析音頻內容' + audioText: '需要使用「語音轉文本」節點解析音頻內容', + otherText: '需要自行解析該類型文件' } } }, @@ -207,8 +208,8 @@ export default { text: '識別出圖片中的物件、場景等信息回答用戶問題', answer: 'AI 回答內容', model: { - label: '圖片理解模型', - requiredMessage: '請選擇圖片理解模型' + label: '視覺模型', + requiredMessage: '請選擇視覺模型' }, image: { label: '選擇圖片', @@ -222,14 +223,14 @@ export default { }, mcpNode: { label: 'MCP 調用', - text: '透過SSE方式執行MCP服務中的工具', + text: '透過SSE/Streamable HTTP方式執行MCP服務中的工具', getToolsSuccess: '獲取工具成功', getTool: '獲取工具', tool: '工具', toolParam: '工具變數', mcpServerTip: '請輸入JSON格式的MCP服務器配置', mcpToolTip: '請選擇工具', - configLabel: 'MCP Server Config (僅支持SSE調用方式)' + configLabel: 'MCP Server Config (僅支持SSE/Streamable HTTP調用方式)' }, imageGenerateNode: { label: '圖片生成', diff --git a/ui/src/locales/lang/zh-Hant/views/application.ts b/ui/src/locales/lang/zh-Hant/views/application.ts index 3b6f1756ed7..d0df9b6b906 100644 --- a/ui/src/locales/lang/zh-Hant/views/application.ts +++ b/ui/src/locales/lang/zh-Hant/views/application.ts @@ -129,7 +129,7 @@ export default { hybridSearch: '混合檢索', hybridSearchTooltip: '混合檢索是一種基於向量和文本相似度的檢索方式,適用於知識庫中的中等數據量場景。', - similarityThreshold: '相似度高於', + similarityThreshold: '相似度不低於', similarityTooltip: '相似度越高相關性越強。', topReferences: '引用分段數 TOP', maxCharacters: '最多引用字元數', diff --git a/ui/src/locales/lang/zh-Hant/views/document.ts b/ui/src/locales/lang/zh-Hant/views/document.ts index adfc8cc463b..d8406908e6a 100644 --- a/ui/src/locales/lang/zh-Hant/views/document.ts +++ b/ui/src/locales/lang/zh-Hant/views/document.ts @@ -146,7 +146,7 @@ export default { tooltip: '用戶提問時,命中文檔下的分段時按照設置的方式進行處理。' }, similarity: { - label: '相似度高于', + label: '相似度不低於', placeholder: '直接返回分段内容', requiredMessage: '请输入相似度' } diff --git a/ui/src/locales/lang/zh-Hant/views/system.ts b/ui/src/locales/lang/zh-Hant/views/system.ts index 10259390be1..1e33f22fb33 100644 --- a/ui/src/locales/lang/zh-Hant/views/system.ts +++ b/ui/src/locales/lang/zh-Hant/views/system.ts @@ -1,5 +1,6 @@ export default { - title: '系統設置', + title: '系統管理', + subTitle: '系統設置', test: '測試連線', testSuccess: '測試連線成功', testFailed: '測試連線失敗', @@ -122,7 +123,7 @@ export default { websiteSloganPlaceholder: '請輸入歡迎語', websiteSloganTip: '產品 Logo 下的歡迎語', logoDefaultTip: '默认为 MaxKB 登錄界面,支持自定义设置', - defaultSlogan: '歡迎使用 MaxKB 智能知識庫問答系統', + defaultSlogan: '歡迎使用 MaxKB 開源 AI 助手', defaultTip: '默認為 MaxKB 平台界面,支持自定義設置', platformSetting: '平台設置', showUserManual: '顯示用戶手冊', diff --git a/ui/src/locales/lang/zh-Hant/views/template.ts b/ui/src/locales/lang/zh-Hant/views/template.ts index 241f9d8c516..05f24fed575 100644 --- a/ui/src/locales/lang/zh-Hant/views/template.ts +++ b/ui/src/locales/lang/zh-Hant/views/template.ts @@ -30,7 +30,7 @@ export default { RERANKER: '重排模型', STT: '語音辨識', TTS: '語音合成', - IMAGE: '圖片理解', + IMAGE: '視覺模型', TTI: '圖片生成' }, templateForm: { diff --git a/ui/src/locales/lang/zh-Hant/views/user.ts b/ui/src/locales/lang/zh-Hant/views/user.ts index 18ea3326acf..7b8f1a88000 100644 --- a/ui/src/locales/lang/zh-Hant/views/user.ts +++ b/ui/src/locales/lang/zh-Hant/views/user.ts @@ -26,6 +26,10 @@ export default { requiredMessage: '請輸入使用者名稱', lengthMessage: '長度須介於 6 到 20 個字元之間' }, + captcha: { + label: '驗證碼', + placeholder: '請輸入驗證碼' + }, nick_name: { label: '姓名', placeholder: '請輸入姓名' diff --git a/ui/src/request/index.ts b/ui/src/request/index.ts index 72588d2c6f2..a9f490149bf 100644 --- a/ui/src/request/index.ts +++ b/ui/src/request/index.ts @@ -11,7 +11,7 @@ import { ref, type WritableComputedRef } from 'vue' const axiosConfig = { baseURL: '/api', withCredentials: false, - timeout: 600000, + timeout: 1800000, headers: {} } diff --git a/ui/src/router/modules/setting.ts b/ui/src/router/modules/setting.ts index e97a658b02b..eaedb6a5f50 100644 --- a/ui/src/router/modules/setting.ts +++ b/ui/src/router/modules/setting.ts @@ -59,7 +59,7 @@ const settingRouter = { meta: { icon: 'app-setting', iconActive: 'app-setting-active', - title: 'common.setting', + title: 'views.system.subTitle', activeMenu: '/setting', parentPath: '/setting', parentName: 'setting', diff --git a/ui/src/stores/modules/user.ts b/ui/src/stores/modules/user.ts index b065d7596a5..c805715a662 100644 --- a/ui/src/stores/modules/user.ts +++ b/ui/src/stores/modules/user.ts @@ -8,6 +8,7 @@ import { useElementPlusTheme } from 'use-element-plus-theme' import { defaultPlatformSetting } from '@/utils/theme' import { useLocalStorage } from '@vueuse/core' import { localeConfigKey, getBrowserLang } from '@/locales/index' + export interface userStateTypes { userType: number // 1 系统操作者 2 对话用户 userInfo: User | null @@ -17,6 +18,7 @@ export interface userStateTypes { XPACK_LICENSE_IS_VALID: false isXPack: false themeInfo: any + rasKey: string } const useUserStore = defineStore({ @@ -29,7 +31,8 @@ const useUserStore = defineStore({ userAccessToken: '', XPACK_LICENSE_IS_VALID: false, isXPack: false, - themeInfo: null + themeInfo: null, + rasKey: '' }), actions: { getLanguage() { @@ -65,7 +68,7 @@ const useUserStore = defineStore({ if (token) { return token } - const local_token = localStorage.getItem(`${token}-accessToken`) + const local_token = localStorage.getItem(`${this.userAccessToken}-accessToken`) if (local_token) { return local_token } @@ -100,6 +103,7 @@ const useUserStore = defineStore({ this.version = ok.data?.version || '-' this.isXPack = ok.data?.IS_XPACK this.XPACK_LICENSE_IS_VALID = ok.data?.XPACK_LICENSE_IS_VALID + this.rasKey = ok.data?.ras || '' if (this.isEnterprise()) { await this.theme() @@ -135,8 +139,15 @@ const useUserStore = defineStore({ }) }, - async login(auth_type: string, username: string, password: string) { - return UserApi.login(auth_type, { username, password }).then((ok) => { + async login(data: any, loading?: Ref) { + return UserApi.login(data).then((ok) => { + this.token = ok.data + localStorage.setItem('token', ok.data) + return this.profile() + }) + }, + async asyncLdapLogin(data: any, loading?: Ref) { + return UserApi.ldapLogin(data).then((ok) => { this.token = ok.data localStorage.setItem('token', ok.data) return this.profile() diff --git a/ui/src/styles/app.scss b/ui/src/styles/app.scss index 8646670b07a..6d22ca6b575 100644 --- a/ui/src/styles/app.scss +++ b/ui/src/styles/app.scss @@ -713,6 +713,7 @@ h5 { border-radius: var(--el-border-radius-base) !important; padding: 5px 8px; font-weight: 400; + outline: none !important; } .el-radio-button__original-radio:checked + .el-radio-button__inner { color: var(--el-color-primary) !important; diff --git a/ui/src/styles/element-plus.scss b/ui/src/styles/element-plus.scss index d1f067b18fd..3314a43860f 100644 --- a/ui/src/styles/element-plus.scss +++ b/ui/src/styles/element-plus.scss @@ -62,7 +62,7 @@ } .el-form-item__label { font-weight: 400; - width: 100%; + width: 100% !important; } .el-form-item__error { @@ -145,6 +145,10 @@ .el-card { --el-card-padding: calc(var(--app-base-px) * 2); color: var(--el-text-color-regular); + overflow: visible; + .el-card__body { + overflow: inherit; + } } .el-dropdown { color: var(--app-text-color); diff --git a/ui/src/styles/md-editor.scss b/ui/src/styles/md-editor.scss index 6b117711412..c60f51f4e96 100644 --- a/ui/src/styles/md-editor.scss +++ b/ui/src/styles/md-editor.scss @@ -6,7 +6,8 @@ padding: 0; margin: 0; font-size: inherit; - table{ + word-break: break-word; + table { display: block; } p { diff --git a/ui/src/utils/utils.ts b/ui/src/utils/utils.ts index 44e68895c7f..7f76a93da3c 100644 --- a/ui/src/utils/utils.ts +++ b/ui/src/utils/utils.ts @@ -1,5 +1,5 @@ import { MsgError } from '@/utils/message' - +import { nanoid } from 'nanoid' export function toThousands(num: any) { return num?.toString().replace(/\d+/, function (n: any) { return n.replace(/(\d)(?=(?:\d{3})+$)/g, '$1,') @@ -25,7 +25,7 @@ export function filesize(size: number) { 随机id */ export const randomId = function () { - return Math.floor(Math.random() * 10000) + '' + return nanoid() } /* @@ -48,7 +48,9 @@ const typeList: any = { export function getImgUrl(name: string) { const list = Object.values(typeList).flat() - const type = list.includes(fileType(name).toLowerCase()) ? fileType(name).toLowerCase() : 'unknown' + const type = list.includes(fileType(name).toLowerCase()) + ? fileType(name).toLowerCase() + : 'unknown' return new URL(`../assets/fileType/${type}-icon.svg`, import.meta.url).href } // 是否是白名单后缀 diff --git a/ui/src/views/application-workflow/index.vue b/ui/src/views/application-workflow/index.vue index f9a30983943..e6a95cbf010 100644 --- a/ui/src/views/application-workflow/index.vue +++ b/ui/src/views/application-workflow/index.vue @@ -3,7 +3,7 @@
-

{{ detail?.name }}

+

{{ detail?.name }}

{{ $t('views.applicationWorkflow.info.previewVersion') }} @@ -101,7 +101,7 @@ />
-

+

{{ detail?.name || $t('views.application.applicationForm.form.appName.label') }}

@@ -279,7 +279,6 @@ async function publicHandle() { return } applicationApi.putPublishApplication(id as String, obj, loading).then(() => { - application.asyncGetApplicationDetail(id, loading).then((res: any) => { detail.value.name = res.data.name MsgSuccess(t('views.applicationWorkflow.tip.publicSuccess')) diff --git a/ui/src/views/application/ApplicationAccess.vue b/ui/src/views/application/ApplicationAccess.vue index ce2fe6aab82..8e1bf03b7e6 100644 --- a/ui/src/views/application/ApplicationAccess.vue +++ b/ui/src/views/application/ApplicationAccess.vue @@ -135,51 +135,4 @@ onMounted(() => { }) - + diff --git a/ui/src/views/application/component/CreateApplicationDialog.vue b/ui/src/views/application/component/CreateApplicationDialog.vue index 438bfe211a9..7415753c1af 100644 --- a/ui/src/views/application/component/CreateApplicationDialog.vue +++ b/ui/src/views/application/component/CreateApplicationDialog.vue @@ -242,6 +242,7 @@ const submitHandle = async (formEl: FormInstance | undefined) => { } applicationApi.postApplication(applicationForm.value, loading).then((res) => { MsgSuccess(t('common.createSuccess')) + emit('refresh') if (isWorkFlow(applicationForm.value.type)) { router.push({ path: `/application/${res.data.id}/workflow` }) } else { diff --git a/ui/src/views/application/component/ParamSettingDialog.vue b/ui/src/views/application/component/ParamSettingDialog.vue index cdae5bf6e85..bd0cb5545c3 100644 --- a/ui/src/views/application/component/ParamSettingDialog.vue +++ b/ui/src/views/application/component/ParamSettingDialog.vue @@ -11,7 +11,7 @@ >
- + { if (!bool) { - form.value = { - dataset_setting: { - search_mode: 'embedding', - top_n: 3, - similarity: 0.6, - max_paragraph_char_number: 5000, - no_references_setting: { - status: 'ai_questioning', - value: '{question}' - } - }, - problem_optimization: false, - problem_optimization_prompt: '' - } + // form.value = { + // dataset_setting: { + // search_mode: 'embedding', + // top_n: 3, + // similarity: 0.6, + // max_paragraph_char_number: 5000, + // no_references_setting: { + // status: 'ai_questioning', + // value: '{question}' + // } + // }, + // problem_optimization: false, + // problem_optimization_prompt: '' + // } noReferencesform.value = { ai_questioning: defaultValue['ai_questioning'], designated_answer: defaultValue['designated_answer'] diff --git a/ui/src/views/authentication/component/OIDC.vue b/ui/src/views/authentication/component/OIDC.vue index 2666bc6479d..d71158b9a8e 100644 --- a/ui/src/views/authentication/component/OIDC.vue +++ b/ui/src/views/authentication/component/OIDC.vue @@ -61,6 +61,15 @@ show-password /> + + + ({ state: '', clientId: '', clientSecret: '', + fieldMapping: '{"username": "preferred_username", "email": "email"}', redirectUrl: '' }, is_active: true @@ -156,6 +166,13 @@ const rules = reactive>({ trigger: 'blur' } ], + 'config_data.fieldMapping': [ + { + required: true, + message: t('views.system.authentication.oauth2.filedMappingPlaceholder'), + trigger: 'blur' + } + ], 'config_data.redirectUrl': [ { required: true, @@ -187,6 +204,12 @@ function getDetail() { authApi.getAuthSetting(form.value.auth_type, loading).then((res: any) => { if (res.data && JSON.stringify(res.data) !== '{}') { form.value = res.data + if ( + form.value.config_data.fieldMapping === '' || + form.value.config_data.fieldMapping === undefined + ) { + form.value.config_data.fieldMapping = '{"username": "preferred_username", "email": "email"}' + } } }) } diff --git a/ui/src/views/chat/base/index.vue b/ui/src/views/chat/base/index.vue index 27be286f25a..7156f7d894a 100644 --- a/ui/src/views/chat/base/index.vue +++ b/ui/src/views/chat/base/index.vue @@ -42,7 +42,6 @@
-
diff --git a/ui/src/views/login/index.vue b/ui/src/views/login/index.vue index 714c439c6bb..d4264406a65 100644 --- a/ui/src/views/login/index.vue +++ b/ui/src/views/login/index.vue @@ -34,6 +34,27 @@ +
+ +
+ + + + +
+
+
(false) const { user } = useStore() const router = useRouter() +import forge from 'node-forge' + const loginForm = ref({ username: '', - password: '' + password: '', + captcha: '', + encryptedData: '' }) +const identifyCode = ref('') + +function makeCode() { + useApi.getCaptcha().then((res: any) => { + identifyCode.value = res.data + }) +} const rules = ref>({ username: [ @@ -137,6 +170,13 @@ const rules = ref>({ message: t('views.user.userForm.form.password.requiredMessage'), trigger: 'blur' } + ], + captcha: [ + { + required: true, + message: t('views.user.userForm.form.captcha.placeholder'), + trigger: 'blur' + } ] }) const loginFormRef = ref() @@ -222,22 +262,43 @@ function changeMode(val: string) { showQrCodeTab.value = false loginForm.value = { username: '', - password: '' + password: '', + captcha: '' } redirectAuth(val) loginFormRef.value?.clearValidate() } const login = () => { - loginFormRef.value?.validate().then(() => { - loading.value = true - user - .login(loginMode.value, loginForm.value.username, loginForm.value.password) - .then(() => { - locale.value = localStorage.getItem('MaxKB-locale') || getBrowserLang() || 'en-US' - router.push({ name: 'home' }) - }) - .finally(() => (loading.value = false)) + if (!loginFormRef.value) { + return + } + loginFormRef.value?.validate((valid) => { + if (valid) { + loading.value = true + if (loginMode.value === 'LDAP') { + user + .asyncLdapLogin(loginForm.value) + .then(() => { + locale.value = localStorage.getItem('MaxKB-locale') || getBrowserLang() || 'en-US' + router.push({ name: 'home' }) + }) + .finally(() => (loading.value = false)) + } else { + const publicKey = forge.pki.publicKeyFromPem(user.rasKey) + const jsonData = JSON.stringify(loginForm.value) + const utf8Bytes = forge.util.encodeUtf8(jsonData) + const encrypted = publicKey.encrypt(utf8Bytes, 'RSAES-PKCS1-V1_5') + const encryptedBase64 = forge.util.encode64(encrypted) + user + .login({ encryptedData: encryptedBase64, username: loginForm.value.username }) + .then(() => { + locale.value = localStorage.getItem('MaxKB-locale') || getBrowserLang() || 'en-US' + router.push({ name: 'home' }) + }) + .finally(() => (loading.value = false)) + } + } }) } @@ -285,6 +346,7 @@ onBeforeMount(() => { declare const window: any onMounted(() => { + makeCode() const route = useRoute() const currentUrl = ref(route.fullPath) const params = new URLSearchParams(currentUrl.value.split('?')[1]) diff --git a/ui/src/views/login/reset-password/index.vue b/ui/src/views/login/reset-password/index.vue index 2c2ff02576e..876afde1470 100644 --- a/ui/src/views/login/reset-password/index.vue +++ b/ui/src/views/login/reset-password/index.vue @@ -1,6 +1,10 @@ - + diff --git a/ui/src/workflow/common/NodeContainer.vue b/ui/src/workflow/common/NodeContainer.vue index b61fe25b61a..990eb96c143 100644 --- a/ui/src/workflow/common/NodeContainer.vue +++ b/ui/src/workflow/common/NodeContainer.vue @@ -301,7 +301,8 @@ function clickNodes(item: any) { type: 'app-edge', sourceNodeId: props.nodeModel.id, sourceAnchorId: anchorData.value?.id, - targetNodeId: nodeModel.id + targetNodeId: nodeModel.id, + targetAnchorId: nodeModel.id + '_left' }) closeNodeMenu() diff --git a/ui/src/workflow/nodes/application-node/index.vue b/ui/src/workflow/nodes/application-node/index.vue index 77bff4ac0ca..4fc9fba5483 100644 --- a/ui/src/workflow/nodes/application-node/index.vue +++ b/ui/src/workflow/nodes/application-node/index.vue @@ -238,7 +238,8 @@ const update_field = () => { const new_user_input_field_list = cloneDeep( ok.data.work_flow.nodes[0].properties.user_input_field_list ) - const merge_api_input_field_list = new_api_input_field_list.map((item: any) => { + + const merge_api_input_field_list = (new_api_input_field_list || []).map((item: any) => { const find_field = old_api_input_field_list.find( (old_item: any) => old_item.variable == item.variable ) @@ -258,7 +259,7 @@ const update_field = () => { 'api_input_field_list', merge_api_input_field_list ) - const merge_user_input_field_list = new_user_input_field_list.map((item: any) => { + const merge_user_input_field_list = (new_user_input_field_list || []).map((item: any) => { const find_field = old_user_input_field_list.find( (old_item: any) => old_item.field == item.field ) @@ -294,6 +295,7 @@ const update_field = () => { } }) .catch((err) => { + console.log(err) set(props.nodeModel.properties, 'status', 500) }) } diff --git a/ui/src/workflow/nodes/base-node/component/ApiInputFieldTable.vue b/ui/src/workflow/nodes/base-node/component/ApiInputFieldTable.vue index c81ebc94f72..b7bed17fe86 100644 --- a/ui/src/workflow/nodes/base-node/component/ApiInputFieldTable.vue +++ b/ui/src/workflow/nodes/base-node/component/ApiInputFieldTable.vue @@ -13,7 +13,7 @@ :data="props.nodeModel.properties.api_input_field_list" class="mb-16" ref="tableRef" - row-key="field" + row-key="variable" >
Feature