@@ -1,13 +1,16 @@
-import os, json
+import os, json, time, asyncio
 from base_config import ai_key, path
 from fastapi import APIRouter, BackgroundTasks
 from pathlib import Path
 from pydantic import BaseModel
-from models.gitModels import Users
-from concurrent.futures import ThreadPoolExecutor
+from git import Repo
+
 from http import HTTPStatus
 from dashscope import Application
+from models.aiModels import Scan_Tasks
+from models.gitModels import Repos
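+# Scan_Tasks and Repos are used below like Tortoise-ORM models (awaitable
+# .get/.filter/.create, inferred from usage); git.Repo comes from GitPython
+# and supplies the commit hash in the /scan endpoint.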
+
 airouter = APIRouter()
 class RequestBody(BaseModel):
     uuid: str
@@ -59,11 +62,11 @@ def analysis_results(local_path,path):
         print(f"Request failed: {response.message}")
         json_data = {"summary":None}
     json_data["path"]=file_path
-    print(json_data)
+
     return json_data
 
-def get_filtered_files(folder_path):
+async def get_filtered_files(folder_path):
     base_path = Path(folder_path).resolve()
     if not base_path.is_dir():
         raise ValueError("Invalid directory path")
@@ -76,58 +79,77 @@ def get_filtered_files(folder_path):
             rel_path = abs_path.relative_to(base_path)
             file_list.append(str(rel_path))
     return file_list
-def process_batch1(batch_files):
+async def process_batch1(batch_files):
     """Process a single batch of files."""
     try:
         js = filter_code_files(str(batch_files))
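+        # Changed from js.get("files", []): a missing "files" key now raises
+        # KeyError, which the except below turns into an empty list.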
-        return js.get("files", [])
+        return js["files"]
     except Exception as e:
         print(f"Error while processing batch: {e}")
         return []
-def get_code_files(path):
+async def get_code_files(path):
     file_list = []
-    files = get_filtered_files(path)
+    files = await get_filtered_files(path)
     print(files)
     print(f"Found {len(files)} files")
 
     # Chunk the file list (500 files per chunk)
     chunks = [files[i * 500: (i + 1) * 500]
               for i in range(0, len(files) // 500 + 1)]
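+    # Note: len(files) // 500 + 1 yields one extra, empty chunk whenever
+    # len(files) is an exact multiple of 500; at worst this costs one
+    # needless filter_code_files call.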
-    with ThreadPoolExecutor(max_workers=min(5, len(chunks))) as executor:
-        # Submit all batch tasks
-        futures = [executor.submit(process_batch1, chunk) for chunk in chunks]
-        # Collect results of completed tasks as they finish
-        for future in futures:
-            try:
-                batch_result = future.result()
-                file_list.extend(batch_result)
-            except Exception as e:
-                print(f"Error while collecting results: {e}")
-        print(f"Final merged file count: {len(file_list)}")
+    # Submit all batch tasks
+    tasks = [process_batch1(chunk) for chunk in chunks]
+    futures = await asyncio.gather(*tasks, return_exceptions=True)
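+    # return_exceptions=True hands failures back as values, so one bad batch
+    # cannot cancel the others.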
+    # Collect the results of each completed batch
+    for future in futures:
+        if isinstance(future, Exception):
+            print(f"Error while processing: {future}")
+        else:
+            # each result is a list of file paths, so extend rather than append
+            file_list.extend(future)
+
     return file_list
-def process_batch2(local_path,path):
+async def process_batch2(local_path,path):
     """Analyse a single file and return its summary."""
    try:
         js = analysis_results(local_path,path)
         return js
     except Exception as e:
         print(f"Error while processing batch: {e}")
         return {"summary":None}
-def analysis(local_path):
-    file_list = get_code_files(local_path)
+
+
+async def analysis(local_path, repo_id):
+    file_list = await get_code_files(local_path)
     print(file_list)
-    with ThreadPoolExecutor(max_workers=5) as executor:
-        futures = [executor.submit(process_batch2, local_path, file) for file in file_list]
-        for future in futures:
-            try:
-                batch_result = future.result()
-                file_list.extend(batch_result)
-            except Exception as e:
-                print(f"Error while collecting results: {e}")
+    results = []
+    tasks = [process_batch2(local_path, file) for file in file_list]  # assumes process_batch2 is now async
+    batch_results = await asyncio.gather(*tasks, return_exceptions=True)
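+    # Caveat: process_batch2 never awaits (analysis_results appears to be a
+    # blocking dashscope call), so this gather runs the files sequentially;
+    # wrapping the call in asyncio.to_thread would restore real concurrency.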
+
+    for result in batch_results:
+        if isinstance(result, Exception):
+            print(f"Error while processing: {result}")
+            # record the failure (state 3) and stop, otherwise the success
+            # update below would overwrite the failed state
+            await write_to_db({"results": results}, repo_id, 3)
+            return
+        else:
+            results.append(result)
+
+    await write_to_db({"results": results}, repo_id, 2)
+    print("Scan finished")
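+# Scan_Tasks.state codes, inferred from usage: 1 = scanning, 2 = finished, 3 = failed.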
+async def write_to_db(results_dict, repo_id, state):
+    await Scan_Tasks.filter(repo_id=repo_id).update(state=state, result=results_dict,
+                                                    scan_end_time=int(time.time()))
+
+
 @airouter.post("/scan")
-async def ai(request: RequestBody, background_tasks: BackgroundTasks):
-    local_path, _ = generate_repo_path(request.uuid, request.repo_url)
-    background_tasks.add_task(analysis, local_path)
+async def scan(request: RequestBody, background_tasks: BackgroundTasks):
+    local_path, repo_name = generate_repo_path(request.uuid, request.repo_url)
+    repo_hash = Repo(local_path).head.commit.hexsha[:7]
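+    # short 7-character commit id of the checked-out HEAD, recorded on the task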
+    repo = await Repos.get(name=repo_name)
+    repo_id = repo.id
+    print(f"Starting scan of repository: {repo_name}")
+    await Scan_Tasks.create(repo_id=repo_id, state=1, create_time=int(time.time()),
+                            scan_start_time=int(time.time()), create_user=request.uuid,
+                            repo_hash=repo_hash)
+    background_tasks.add_task(analysis, local_path, repo_id)
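+    # BackgroundTasks awaits the analysis coroutine only after the response has
+    # been sent, so this endpoint returns immediately while the scan runs.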
     return {"code": 200, "msg": "Scan task added successfully"}