aiRouter.py 9.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268
  1. import os, json, time, asyncio
  2. from base_config import ai_key, path, file_summary_app_id, commit_summary_app_id, filter_code_files_app_id, analysis_results_app_id
  3. from fastapi import APIRouter, BackgroundTasks
  4. from pathlib import Path
  5. from pydantic import BaseModel
  6. from git import Repo
  7. from http import HTTPStatus
  8. from dashscope import Application
  9. from models.aiModels import Scan_Tasks, Commit_Summary_Tasks, File_Summary_Tasks
  10. from mcp_.client import MCPClient
# Router that the /scan, /summaryCommit, /summaryFile and /chat handlers below register on.
airouter = APIRouter()
  12. class RequestCommit(BaseModel):
  13. task_id: str
  14. uuid: str
  15. repo_url: str
  16. class RequestScan(BaseModel):
  17. task_id: str
  18. uuid: str
  19. repo_url: str
  20. class RequestFile(BaseModel):
  21. task_id: str
  22. uuid: str
  23. repo_url: str
  24. file_path: str
  25. class RequestChat(BaseModel):
  26. uuid: str
  27. message: str
  28. def generate_repo_path(uuid, repo_url):
  29. repo_name = repo_url.split("/")[-1].replace(".git", "")
  30. base_path = os.path.join(path, uuid)
  31. return os.path.join(base_path, repo_name), repo_name
  32. async def file_summary(content):
  33. response = Application.call(
  34. api_key=ai_key,
  35. app_id=file_summary_app_id,
  36. prompt=content)
  37. if response.status_code == HTTPStatus.OK:
  38. try:
  39. json_data = json.loads(response.output.text)
  40. print(json_data)
  41. except json.JSONDecodeError:
  42. print("返回内容不是有效的 JSON 格式!")
  43. print(response.output.text)
  44. json_data = {"summary": []}
  45. else:
  46. print(f"请求失败: {response.message}")
  47. json_data = {"summary": []}
  48. return json_data
  49. async def commit_summary(content):
  50. response = Application.call(
  51. api_key=ai_key,
  52. app_id=commit_summary_app_id,
  53. prompt=content)
  54. if response.status_code == HTTPStatus.OK:
  55. try:
  56. json_data = json.loads(response.output.text)
  57. print(json_data)
  58. except json.JSONDecodeError:
  59. print("返回内容不是有效的 JSON 格式!")
  60. print(response.output.text)
  61. json_data = {"null": []}
  62. else:
  63. print(f"请求失败: {response.message}")
  64. json_data = {"null": []}
  65. return json_data
  66. def filter_code_files(prompt):
  67. response = Application.call(
  68. api_key=ai_key,
  69. app_id=filter_code_files_app_id,
  70. prompt=prompt)
  71. if response.status_code == HTTPStatus.OK:
  72. try:
  73. json_data = json.loads(response.output.text)
  74. print(json_data)
  75. except json.JSONDecodeError:
  76. print("返回内容不是有效的 JSON 格式!")
  77. print(response.output.text)
  78. json_data={"files":[]}
  79. else:
  80. print(f"请求失败: {response.message}")
  81. json_data = {"files": []}
  82. return json_data
  83. def analysis_results(local_path,path):
  84. prompt=""
  85. file_path=os.path.join(local_path,path)
  86. with open(file_path, 'r',encoding="utf8") as f:
  87. for line_num, line in enumerate(f, start=1):
  88. prompt+=f"{line_num}\t{line}"
  89. response = Application.call(
  90. api_key=ai_key,
  91. app_id=analysis_results_app_id,
  92. prompt=prompt)
  93. if response.status_code == HTTPStatus.OK:
  94. try:
  95. json_data = json.loads(response.output.text)
  96. except json.JSONDecodeError:
  97. print("返回内容不是有效的 JSON 格式!")
  98. print(response.output.text)
  99. json_data={"summary":None}
  100. else:
  101. print(f"请求失败: {response.message}")
  102. json_data = {"summary":None}
  103. json_data["path"]=path
  104. return json_data
  105. async def get_filtered_files(folder_path):
  106. base_path = Path(folder_path).resolve()
  107. if not base_path.is_dir():
  108. raise ValueError("无效的目录路径")
  109. file_list = []
  110. for root, dirs, files in os.walk(base_path):
  111. dirs[:] = [d for d in dirs if not d.startswith('.')]
  112. files = [f for f in files if not f.startswith('.')]
  113. for file in files:
  114. abs_path = Path(root) / file
  115. rel_path = abs_path.relative_to(base_path)
  116. file_list.append(str(rel_path))
  117. return file_list
  118. async def process_batch1(batch_files):
  119. """多线程处理单个文件批次的函数"""
  120. try:
  121. js = filter_code_files(str(batch_files))
  122. return js["files"]
  123. except Exception as e:
  124. print(f"处理批次时出错: {e}")
  125. return []
  126. async def get_code_files(path):
  127. file_list = []
  128. files = await get_filtered_files(path)
  129. print(files)
  130. print(f"找到 {len(files)} 个文件")
  131. # 将文件列表分块(每500个一组)
  132. chunks = [files[i * 500: (i + 1) * 500]
  133. for i in range(0, len(files) // 500 + 1)]
  134. # 提交所有批次任务
  135. tasks = [process_batch1(chunk) for chunk in chunks]
  136. futures = await asyncio.gather(*tasks, return_exceptions=True)
  137. # 实时获取已完成任务的结果
  138. for future in futures[0]:
  139. if isinstance(future, Exception):
  140. print(f"处理出错: {future}")
  141. else:
  142. file_list.append(future)
  143. return file_list
  144. async def process_batch2(local_path,path):
  145. """多线程处理单个文件批次的函数"""
  146. try:
  147. # print(local_path, path)
  148. js = analysis_results(local_path,path)
  149. return js
  150. except Exception as e:
  151. print(11111)
  152. print(f"处理批次时出错: {e}")
  153. return {"summary":None}
  154. async def analysis(local_path, task_id):
  155. file_list = await get_code_files(local_path)
  156. all_extensions = [
  157. "adoc", "asm", "awk", "bas", "bat", "bib", "c", "cbl", "cls", "clj",
  158. "cljc", "cljs", "cmd", "conf", "cpp", "cr", "cs", "css", "cxx", "dart",
  159. "dockerfile", "edn", "el", "env", "erl", "ex", "exs", "f", "f90", "f95",
  160. "fs", "fsscript", "fsi", "fsx", "g4", "gd", "gql", "graphql", "groovy",
  161. "gsh", "gvy", "h", "hbs", "hcl", "hh", "hl", "hpp", "hrl", "hs", "htm",
  162. "html", "hx", "ini", "jad", "jade", "java", "jl", "js", "json", "json5",
  163. "jsx", "kt", "kts", "less", "lfe", "lgt", "lhs", "log", "ltx", "lua",
  164. "m", "mjs", "ml", "mli", "mm", "nim", "nims", "nlogo", "pas", "php",
  165. "pl", "plantuml", "pro", "ps1", "pug", "puml", "py", "qml", "r", "rb",
  166. "re", "rei", "res", "resi", "rkt", "rs", "rst", "s", "sass", "scala",
  167. "scm", "scss", "sed", "sh", "sol", "sql", "ss", "st", "squeak", "swift",
  168. "tcl", "tex", "tf", "tfvars", "toml", "ts", "tsx", "txt", "v", "vb",
  169. "vbs", "vh", "vhd", "vhdl", "vim", "vue", "xml", "yaml", "yang", "yml"]
  170. file_list = [file for file in file_list if file.split(".")[-1] in all_extensions]
  171. results = []
  172. tasks = [process_batch2(local_path, file) for file in file_list]
  173. batch_results = await asyncio.gather(*tasks, return_exceptions=True)
  174. for result in batch_results:
  175. if isinstance(result, Exception):
  176. print(f"处理出错: {result}")
  177. await Scan_Tasks.filter(id=task_id).update(state=3, result={"results": results},
  178. scan_end_time=int(time.time() * 1000))
  179. else:
  180. results.append(result)
  181. await Scan_Tasks.filter(id=task_id).update(state=2, result={"results": results},
  182. scan_end_time=int(time.time() * 1000))
  183. print("扫描完成")
  184. async def commit_task(content,task_id):
  185. commit_result=await asyncio.gather(commit_summary(content), return_exceptions=True)
  186. if isinstance(commit_result, Exception):
  187. print(f"提交出错: {commit_result}")
  188. else:
  189. print("提交成功")
  190. await Commit_Summary_Tasks.filter(id=task_id).update(result=commit_result[0],end_time=int(time.time() * 1000))
  191. async def file_task(file_path,task_id):
  192. with open(file_path, 'r', encoding="utf8") as f:
  193. content = f.read()
  194. file_result = await asyncio.gather(file_summary(content), return_exceptions=True)
  195. if isinstance(file_result, Exception):
  196. print(f"提交出错: {file_result}")
  197. else:
  198. print("提交成功")
  199. await File_Summary_Tasks.filter(id=task_id).update(result=file_result[0], end_time=int(time.time() * 1000))
  200. @airouter.post("/scan")
  201. async def scan(request: RequestScan, background_tasks: BackgroundTasks):
  202. local_path, repo_name = generate_repo_path(request.uuid, request.repo_url)
  203. print(f"开始扫描仓库: {repo_name}")
  204. await Scan_Tasks.filter(id=request.task_id).update(state=1, scan_start_time=int(time.time() * 1000))
  205. background_tasks.add_task(analysis, local_path, request.task_id)
  206. return {"code": 200, "msg": "添加扫描任务成功"}
  207. @airouter.post("/summaryCommit")
  208. async def summaryCommit(request: RequestCommit, background_tasks: BackgroundTasks):
  209. local_path, repo_name = generate_repo_path(request.uuid, request.repo_url)
  210. repo_commit=await Commit_Summary_Tasks.get(id=request.task_id)
  211. repo_commit_hash=repo_commit.repo_hash
  212. print(f"开始提交仓库: {repo_name}")
  213. await Commit_Summary_Tasks.filter(id=request.task_id).update(start_time=int(time.time() * 1000))
  214. commit_content = Repo(local_path).git.diff(f"{repo_commit_hash}^", repo_commit_hash)
  215. background_tasks.add_task(commit_task,commit_content, request.task_id)
  216. return {"code": 200, "msg": "添加提交任务成功"}
  217. @airouter.post("/summaryFile")
  218. async def summaryFile(request: RequestFile,background_tasks: BackgroundTasks):
  219. await File_Summary_Tasks.filter(id=request.task_id).update(start_time=int(time.time() * 1000))
  220. background_tasks.add_task(file_task, request.file_path, request.task_id)
  221. return {"code": 200, "msg": "添加提交任务成功"}
  222. @airouter.post("/chat")
  223. async def chat(request: RequestChat):
  224. client = MCPClient(request.uuid)
  225. await client.connect_to_server()
  226. response = await client.process_query(request.message)
  227. await client.cleanup()
  228. return {"code": 200, "msg": response}