Browse Source

扫描结果提交数据库

fulian23 4 months ago
parent
commit
ecd9fbaf38
8 changed files with 91 additions and 58 deletions
  1. api/aiRouter.py (+56 −34)
  2. api/gitRouter.py (+2 −0)
  3. demo.py (+10 −1)
  4. git_log.json (+0 −0)
  5. models/AIModels.py (+0 −5)
  6. models/aiModels.py (+16 −0)
  7. models/gitModels.py (+1 −0)
  8. test/aitest.py (+6 −18)

+ 56 - 34
api/aiRouter.py

@@ -1,13 +1,16 @@
-import os, json
+import os, json, time, asyncio
 from base_config import ai_key, path
 from base_config import ai_key, path
 from fastapi import APIRouter, BackgroundTasks
 from fastapi import APIRouter, BackgroundTasks
 from pathlib import Path
 from pathlib import Path
 from pydantic import BaseModel
 from pydantic import BaseModel
-from models.gitModels import Users
-from concurrent.futures import ThreadPoolExecutor
+from git import Repo
+
 from http import HTTPStatus
 from http import HTTPStatus
 from dashscope import Application
 from dashscope import Application
 
 
+from models.aiModels import Scan_Tasks
+from models.gitModels import Repos
+
 airouter = APIRouter()
 airouter = APIRouter()
 class RequestBody(BaseModel):
 class RequestBody(BaseModel):
     uuid: str
     uuid: str
@@ -59,11 +62,11 @@ def analysis_results(local_path,path):
         print(f"请求失败: {response.message}")
         print(f"请求失败: {response.message}")
         json_data = {"summary":None}
         json_data = {"summary":None}
     json_data["path"]=file_path
     json_data["path"]=file_path
-    print(json_data)
+
     return json_data
     return json_data
 
 
 
 
-def get_filtered_files(folder_path):
+async def get_filtered_files(folder_path):
     base_path = Path(folder_path).resolve()
     base_path = Path(folder_path).resolve()
     if not base_path.is_dir():
     if not base_path.is_dir():
         raise ValueError("无效的目录路径")
         raise ValueError("无效的目录路径")
@@ -76,58 +79,77 @@ def get_filtered_files(folder_path):
             rel_path = abs_path.relative_to(base_path)
             rel_path = abs_path.relative_to(base_path)
             file_list.append(str(rel_path))
             file_list.append(str(rel_path))
     return file_list
     return file_list
-def process_batch1(batch_files):
+async def process_batch1(batch_files):
     """多线程处理单个文件批次的函数"""
     """多线程处理单个文件批次的函数"""
     try:
     try:
         js = filter_code_files(str(batch_files))
         js = filter_code_files(str(batch_files))
-        return js.get("files", [])
+        return js["files"]
     except Exception as e:
     except Exception as e:
         print(f"处理批次时出错: {e}")
         print(f"处理批次时出错: {e}")
         return []
         return []
-def get_code_files(path):
+async def get_code_files(path):
     file_list = []
     file_list = []
-    files = get_filtered_files(path)
+    files = await get_filtered_files(path)
     print(files)
     print(files)
     print(f"找到 {len(files)} 个文件")
     print(f"找到 {len(files)} 个文件")
 
 
     # 将文件列表分块(每500个一组)
     # 将文件列表分块(每500个一组)
     chunks = [files[i * 500: (i + 1) * 500]
     chunks = [files[i * 500: (i + 1) * 500]
               for i in range(0, len(files) // 500 + 1)]
               for i in range(0, len(files) // 500 + 1)]
-    with ThreadPoolExecutor(max_workers=min(5, len(chunks))) as executor:
-        # 提交所有批次任务
-        futures = [executor.submit(process_batch1, chunk) for chunk in chunks]
-        # 实时获取已完成任务的结果
-        for future in futures:
-            try:
-                batch_result = future.result()
-                file_list.extend(batch_result)
-            except Exception as e:
-                print(f"获取结果时出错: {e}")
-    print(f"最终合并文件数: {len(file_list)}")
+    # 提交所有批次任务
+    # futures = [executor.submit(process_batch1, chunk) for chunk in chunks]
+    tasks = [process_batch1(chunk) for chunk in chunks]
+    futures = await asyncio.gather(*tasks, return_exceptions=True)
+    # 实时获取已完成任务的结果
+    for future in futures[0]:
+        if isinstance(future, Exception):
+            print(f"处理出错: {future}")
+        else:
+            file_list.append(future)
+
     return file_list
     return file_list
-def process_batch2(local_path,path):
+async def process_batch2(local_path,path):
     """多线程处理单个文件批次的函数"""
     """多线程处理单个文件批次的函数"""
     try:
     try:
+        # print(local_path, path)
         js = analysis_results(local_path,path)
         js = analysis_results(local_path,path)
         return js
         return js
     except Exception as e:
     except Exception as e:
+        print(11111)
         print(f"处理批次时出错: {e}")
         print(f"处理批次时出错: {e}")
         return {"summary":None}
         return {"summary":None}
-def analysis(local_path):
-    file_list = get_code_files(local_path)
+
+
+async def analysis(local_path, repo_id):
+    file_list = await get_code_files(local_path)
     print(file_list)
     print(file_list)
-    with ThreadPoolExecutor(max_workers=5) as executor:
-        futures = [executor.submit(process_batch2, local_path, file) for file in file_list]
-        for future in futures:
-            try:
-                batch_result = future.result()
-                file_list.extend(batch_result)
-            except Exception as e:
-                print(f"获取结果时出错: {e}")
+    results = []
+    tasks = [process_batch2(local_path, file) for file in file_list]  # 假设process_batch2已改为异步函数
+    batch_results = await asyncio.gather(*tasks, return_exceptions=True)
+
+    for result in batch_results:
+        if isinstance(result, Exception):
+            print(f"处理出错: {result}")
+            await write_to_db({"results": results}, repo_id, 3)
+        else:
+            results.append(result)
+
+    await write_to_db({"results": results}, repo_id, 2)
+    print("扫描完成")
+async def write_to_db(results_dict,repo_id,state):
+    await Scan_Tasks.filter(repo_id=repo_id).update(state=state, result=results_dict,scan_end_time=int(time.time()))
+
+
 
 
 @airouter.post("/scan")
 @airouter.post("/scan")
-async def ai(request: RequestBody, background_tasks: BackgroundTasks):
-    local_path, _ = generate_repo_path(request.uuid, request.repo_url)
-    background_tasks.add_task(analysis, local_path)
+async def scan(request: RequestBody, background_tasks: BackgroundTasks):
+    local_path, repo_name = generate_repo_path(request.uuid, request.repo_url)
+    repo_hash = Repo(local_path).head.commit.hexsha[:7]
+    repo = await Repos.get(name=repo_name)
+    repo_id = repo.id
+    print(f"开始扫描仓库: {repo_name}")
+    await Scan_Tasks.create(repo_id=repo_id, state=1, create_time=int(time.time()),scan_start_time=int(time.time())
+                            , create_user=request.uuid, repo_hash=repo_hash)
+    background_tasks.add_task(analysis, local_path, repo_id)
     return {"code": 200, "meg": "添加扫描任务成功"}
     return {"code": 200, "meg": "添加扫描任务成功"}
 
 

+ 2 - 0
api/gitRouter.py

@@ -58,7 +58,9 @@ async def clone_task(repo_url, local_path,uuid,repo_name):
         loop = asyncio.get_event_loop()
         loop = asyncio.get_event_loop()
         await loop.run_in_executor(None, Repo.clone_from, repo_url, local_path)
         await loop.run_in_executor(None, Repo.clone_from, repo_url, local_path)
         await Repos.filter(create_user=uuid,name=repo_name).update(path=local_path, state=1, update_time=current_time)
         await Repos.filter(create_user=uuid,name=repo_name).update(path=local_path, state=1, update_time=current_time)
+        print(f"克隆仓库成功: {repo_url}")
     except:
     except:
+        print(f"克隆仓库失败: {repo_url}")
         await Repos.filter(create_user=uuid,name=repo_name).update(path=local_path, state=0, update_time=current_time)
         await Repos.filter(create_user=uuid,name=repo_name).update(path=local_path, state=0, update_time=current_time)
         shutil.rmtree(local_path)
         shutil.rmtree(local_path)
 
 

+ 10 - 1
demo.py

@@ -1,7 +1,9 @@
 from fastapi_cdn_host import monkey_patch_for_docs_ui
 from fastapi_cdn_host import monkey_patch_for_docs_ui
 from fastapi import FastAPI
 from fastapi import FastAPI
 from uvicorn import run
 from uvicorn import run
-from models.gitModels import *
+
+from models.gitModels import Users
+from models.aiModels import Scan_Tasks
 
 
 from api.gitRouter import gitrouter
 from api.gitRouter import gitrouter
 from api.testapi import testapi
 from api.testapi import testapi
@@ -19,6 +21,13 @@ async def test(id: int):
     user= await Users.get(id=id)
     user= await Users.get(id=id)
     print(type(user))
     print(type(user))
     return user
     return user
+
+@app.get("/task/{id}")
+async def test(id: int):
+    task = await Scan_Tasks.create(repo_id=1, state=1,result={"a":1}, create_time=1234567890,scan_start_time=1234567890,scan_end_time=1234567890,create_user="admin",repo_hash="1234567890")
+    print(type(task))
+    return task
+
 app.include_router(gitrouter,prefix="/git")
 app.include_router(gitrouter,prefix="/git")
 app.include_router(testapi,prefix="/test")
 app.include_router(testapi,prefix="/test")
 
 

+ 0 - 0
git_log.json


+ 0 - 5
models/AIModels.py

@@ -1,5 +0,0 @@
-from tortoise.models import Model
-from tortoise import fields
-
-class File_Summary_Tasks(Model):
-    pass

+ 16 - 0
models/aiModels.py

@@ -0,0 +1,16 @@
+from tortoise.models import Model
+from tortoise import fields
+
+class Scan_Tasks(Model):
+    id = fields.IntField(pk=True)
+    repo_id = fields.IntField()
+    state = fields.IntField()
+    result = fields.JSONField()
+    create_time = fields.BigIntField()
+    scan_start_time = fields.BigIntField()
+    scan_end_time = fields.BigIntField()
+    create_user = fields.CharField(max_length=36)
+    repo_hash = fields.CharField(max_length=40)
+
+
+

+ 1 - 0
models/gitModels.py

@@ -13,6 +13,7 @@ class Users(Model):
     email = fields.CharField(max_length=50, null=True)
     email = fields.CharField(max_length=50, null=True)
     registTime = fields.BigIntField()
     registTime = fields.BigIntField()
 class Repos(Model):
 class Repos(Model):
+    id = fields.IntField(pk=True)
     name = fields.CharField(max_length=50)
     name = fields.CharField(max_length=50)
     state = fields.IntField()
     state = fields.IntField()
     create_user = fields.CharField(max_length=36)
     create_user = fields.CharField(max_length=36)

+ 6 - 18
test/aitest.py

@@ -1,21 +1,9 @@
-import json
-from http import HTTPStatus
-from dashscope import Application
 
 
-with open("output.txt", 'r', encoding="utf8") as f:
-    prompt = f.read()
+from models.AIModels import Scan_Tasks
 
 
-response = Application.call(
-    # 若没有配置环境变量,可用百炼API Key将下行替换为:api_key="sk-xxx"。但不建议在生产环境中直接将API Key硬编码到代码中,以减少API Key泄露风险。
-    api_key="sk-0164613e1a2143fc808fc4cc2451bef0",
-    app_id='2f288f146e2d492abb3fe22695e70635',  # 替换为实际的应用 ID
-    prompt=prompt)
+async def xxx():
+    s = await Scan_Tasks.get(id=1)
+    print(s)
 
 
-if response.status_code == HTTPStatus.OK:
-    try:
-        json_data = json.loads(response.output.text)
-        print(json_data)
-    except json.JSONDecodeError:
-        print("返回内容不是有效的 JSON 格式!")
-else:
-    print(f"请求失败: {response.message}")
+if __name__ == "__main__":
+    asyncio.run(xxx())