Browse Source

增强对工具调用的识别,规范key

fulian23 3 months ago
parent
commit
abbfb1eb56
4 changed files with 20 additions and 48 deletions
  1. 2 1
      .gitignore
  2. 5 11
      api/aiRouter.py
  3. 0 13
      demo.py
  4. 13 23
      mcp_/client.py

+ 2 - 1
.gitignore

@@ -5,4 +5,5 @@ __pycache__/
 db_config.py
 db_config.py
 base_config.py
 base_config.py
 
 
-.idea/
+.idea/
+test/

+ 5 - 11
api/aiRouter.py

@@ -1,5 +1,5 @@
 import os, json, time, asyncio
 import os, json, time, asyncio
-from base_config import ai_key, path
+from base_config import ai_key, path, file_summary_app_id, commit_summary_app_id, filter_code_files_app_id, analysis_results_app_id
 from fastapi import APIRouter, BackgroundTasks
 from fastapi import APIRouter, BackgroundTasks
 from pathlib import Path
 from pathlib import Path
 from pydantic import BaseModel
 from pydantic import BaseModel
@@ -40,9 +40,8 @@ def generate_repo_path(uuid, repo_url):
 
 
 async def file_summary(content):
 async def file_summary(content):
     response = Application.call(
     response = Application.call(
-        # 若没有配置环境变量,可用百炼API Key将下行替换为:api_key="sk-xxx"。但不建议在生产环境中直接将API Key硬编码到代码中,以减少API Key泄露风险。
         api_key=ai_key,
         api_key=ai_key,
-        app_id='ef50d70cd4074a899a09875e6a6e36ea',
+        app_id=file_summary_app_id,
         prompt=content)
         prompt=content)
     if response.status_code == HTTPStatus.OK:
     if response.status_code == HTTPStatus.OK:
         try:
         try:
@@ -59,9 +58,8 @@ async def file_summary(content):
 
 
 async def commit_summary(content):
 async def commit_summary(content):
     response = Application.call(
     response = Application.call(
-        # 若没有配置环境变量,可用百炼API Key将下行替换为:api_key="sk-xxx"。但不建议在生产环境中直接将API Key硬编码到代码中,以减少API Key泄露风险。
         api_key=ai_key,
         api_key=ai_key,
-        app_id='88426cc2301b44bea5d28d41d187ebf2',
+        app_id=commit_summary_app_id,
         prompt=content)
         prompt=content)
     if response.status_code == HTTPStatus.OK:
     if response.status_code == HTTPStatus.OK:
         try:
         try:
@@ -78,9 +76,8 @@ async def commit_summary(content):
 
 
 def filter_code_files(prompt):
 def filter_code_files(prompt):
     response = Application.call(
     response = Application.call(
-        # 若没有配置环境变量,可用百炼API Key将下行替换为:api_key="sk-xxx"。但不建议在生产环境中直接将API Key硬编码到代码中,以减少API Key泄露风险。
         api_key=ai_key,
         api_key=ai_key,
-        app_id='b0725a23eafd4422bfa7d5eff278af7c',
+        app_id=filter_code_files_app_id,
         prompt=prompt)
         prompt=prompt)
     if response.status_code == HTTPStatus.OK:
     if response.status_code == HTTPStatus.OK:
         try:
         try:
@@ -102,9 +99,8 @@ def analysis_results(local_path,path):
         for line_num, line in enumerate(f, start=1):
         for line_num, line in enumerate(f, start=1):
             prompt+=f"{line_num}\t{line}"
             prompt+=f"{line_num}\t{line}"
     response = Application.call(
     response = Application.call(
-        # 若没有配置环境变量,可用百炼API Key将下行替换为:api_key="sk-xxx"。但不建议在生产环境中直接将API Key硬编码到代码中,以减少API Key泄露风险。
         api_key=ai_key,
         api_key=ai_key,
-        app_id='b6edb4f5ff1c49f9855af27b14a0e8b4',
+        app_id=analysis_results_app_id,
         prompt=prompt)
         prompt=prompt)
     if response.status_code == HTTPStatus.OK:
     if response.status_code == HTTPStatus.OK:
         try:
         try:
@@ -153,7 +149,6 @@ async def get_code_files(path):
     chunks = [files[i * 500: (i + 1) * 500]
     chunks = [files[i * 500: (i + 1) * 500]
               for i in range(0, len(files) // 500 + 1)]
               for i in range(0, len(files) // 500 + 1)]
     # 提交所有批次任务
     # 提交所有批次任务
-    # futures = [executor.submit(process_batch1, chunk) for chunk in chunks]
     tasks = [process_batch1(chunk) for chunk in chunks]
     tasks = [process_batch1(chunk) for chunk in chunks]
     futures = await asyncio.gather(*tasks, return_exceptions=True)
     futures = await asyncio.gather(*tasks, return_exceptions=True)
     # 实时获取已完成任务的结果
     # 实时获取已完成任务的结果
@@ -244,7 +239,6 @@ async def summaryCommit(request: RequestCommit, background_tasks: BackgroundTask
     repo_commit_hash=repo_commit.repo_hash
     repo_commit_hash=repo_commit.repo_hash
     print(f"开始提交仓库: {repo_name}")
     print(f"开始提交仓库: {repo_name}")
     await Commit_Summary_Tasks.filter(id=request.task_id).update(start_time=int(time.time() * 1000))
     await Commit_Summary_Tasks.filter(id=request.task_id).update(start_time=int(time.time() * 1000))
-    # commit_content = Repo(local_path).git.log('-1', '-p', '--pretty=format:%h %s')
     commit_content = Repo(local_path).git.diff(f"{repo_commit_hash}^", repo_commit_hash)
     commit_content = Repo(local_path).git.diff(f"{repo_commit_hash}^", repo_commit_hash)
     background_tasks.add_task(commit_task,commit_content, request.task_id)
     background_tasks.add_task(commit_task,commit_content, request.task_id)
     return {"code": 200, "msg": "添加提交任务成功"}
     return {"code": 200, "msg": "添加提交任务成功"}

+ 0 - 13
demo.py

@@ -15,19 +15,6 @@ from db_config import TORTOISE_ORM
 app = FastAPI()
 app = FastAPI()
 monkey_patch_for_docs_ui(app)
 monkey_patch_for_docs_ui(app)
 register_tortoise(app=app, config=TORTOISE_ORM)
 register_tortoise(app=app, config=TORTOISE_ORM)
-
-@app.get("/user/{id}")
-async def test(id: int):
-    user= await Users.get(id=id)
-    print(type(user))
-    return user
-
-@app.get("/task/{id}")
-async def test(id: int):
-    task = await Scan_Tasks.create(repo_id=1, state=1,result={"a":1}, create_time=1234567890,scan_start_time=1234567890,scan_end_time=1234567890,create_user="admin",repo_hash="1234567890")
-    print(type(task))
-    return task
-
 app.include_router(gitrouter,prefix="/git")
 app.include_router(gitrouter,prefix="/git")
 app.include_router(testapi,prefix="/test")
 app.include_router(testapi,prefix="/test")
 
 

+ 13 - 23
mcp_/client.py

@@ -1,12 +1,14 @@
 import json
 import json
 import asyncio
 import asyncio
+import re
 from typing import Optional
 from typing import Optional
 from contextlib import AsyncExitStack
 from contextlib import AsyncExitStack
 from http import HTTPStatus
 from http import HTTPStatus
 
 
 from mcp import ClientSession, StdioServerParameters
 from mcp import ClientSession, StdioServerParameters
 from mcp.client.stdio import stdio_client
 from mcp.client.stdio import stdio_client
-from dashscope import Application, Generation
+from dashscope import Application
+from base_config import mcp_key, mcp_app_id
 # from dotenv import load_dotenv
 # from dotenv import load_dotenv
 
 
 # load_dotenv()
 # load_dotenv()
@@ -19,8 +21,8 @@ class MCPClient:
         self.session: Optional[ClientSession] = None
         self.session: Optional[ClientSession] = None
         self.exit_stack = AsyncExitStack()
         self.exit_stack = AsyncExitStack()
         self.session_id = session_id
         self.session_id = session_id
-        self.api_key = "sk-0164613e1a2143fc808fc4cc2451bef0"
-        self.app_id = "b424f3fa1d4544d68579671d70596f29"
+        self.api_key = mcp_key
+        self.app_id = mcp_app_id
         self.model = "qwen-max"
         self.model = "qwen-max"
         self.SCRIPT_PATH = "mcp_/server.py"
         self.SCRIPT_PATH = "mcp_/server.py"
         self.SYSTEM_PROMPT = """
         self.SYSTEM_PROMPT = """
@@ -88,17 +90,19 @@ class MCPClient:
 
 
     async def process_query(self, query: str) -> str:
     async def process_query(self, query: str) -> str:
         bailian_response = await self._call_bailian_api(query)
         bailian_response = await self._call_bailian_api(query)
-        try:
-            tool_call = json.loads(bailian_response)
-            print(tool_call)
+        # print("respone:"+bailian_response)
+        json_str = re.findall( r'\{.*\}', bailian_response, re.S)
+        # print(json_str)
+        if json_str:
+            tool_call=json.loads(json_str[0])
             if "tool" in tool_call:
             if "tool" in tool_call:
                 result = await self.session.call_tool(
                 result = await self.session.call_tool(
                     tool_call["tool"],
                     tool_call["tool"],
                     tool_call["arguments"]
                     tool_call["arguments"]
                 )
                 )
-                bailian_response = await self._call_bailian_api("tool_response:"+result.content[0].text)
-                return bailian_response
-        except json.JSONDecodeError:
+                final_answer = await self._call_bailian_api("tool_response:"+result.content[0].text)
+                return final_answer
+        else:
             return bailian_response
             return bailian_response
 
 
     async def chat_loop(self):
     async def chat_loop(self):
@@ -114,17 +118,3 @@ class MCPClient:
     async def cleanup(self):
     async def cleanup(self):
         await self.exit_stack.aclose()
         await self.exit_stack.aclose()
 
 
-
-async def main():
-
-    client = MCPClient()
-    try:
-        await client.connect_to_server()
-        print("Connected to server!")
-        await client.chat_loop()
-    finally:
-        await client.cleanup()
-
-
-if __name__ == "__main__":
-    asyncio.run(main())