Browse Source

mcp模块

fulian23 3 months ago
parent
commit
789bd18cd7
3 changed files with 559 additions and 0 deletions
  1. 3 0
      api/aiRouter.py
  2. 136 0
      mcp/client.py
  3. 420 0
      mcp/server.py

+ 3 - 0
api/aiRouter.py

@@ -44,6 +44,7 @@ async def file_summary(content):
             print(json_data)
         except json.JSONDecodeError:
             print("返回内容不是有效的 JSON 格式!")
+            print(response.output.text)
             json_data = {"summary": []}
     else:
         print(f"请求失败: {response.message}")
@@ -62,6 +63,7 @@ async def commit_summary(content):
             print(json_data)
         except json.JSONDecodeError:
             print("返回内容不是有效的 JSON 格式!")
+            print(response.output.text)
             json_data = {"null": []}
     else:
         print(f"请求失败: {response.message}")
@@ -80,6 +82,7 @@ def filter_code_files(prompt):
             print(json_data)
         except json.JSONDecodeError:
             print("返回内容不是有效的 JSON 格式!")
+            print(response.output.text)
             json_data={"files":[]}
     else:
         print(f"请求失败: {response.message}")

+ 136 - 0
mcp/client.py

@@ -0,0 +1,136 @@
+import json
+import asyncio
+from typing import Optional
+from contextlib import AsyncExitStack
+from http import HTTPStatus
+
+from mcp import ClientSession, StdioServerParameters
+from mcp.client.stdio import stdio_client
+from dashscope import Application, Generation
+# from dotenv import load_dotenv
+
+# load_dotenv()
+
+SCRIPT_PATH = "server.py"
+
+SYSTEM_PROMPT = """您是一个可执行git操作的AI助手,可用工具:
+{tools_description}
+
+响应规则:
+1. 需要工具时必须返回严格JSON格式:
+{{
+    "tool": "工具名",
+    "arguments": {{参数键值对}}
+}}
+2. 不需要工具时直接回复自然语言
+3. 工具列表:
+{available_tools}"""
+class MCPClient:
+    def __init__(self):
+        self.session: Optional[ClientSession] = None
+        self.exit_stack = AsyncExitStack()
+        self.session_id = None
+        self.api_key = "sk-0164613e1a2143fc808fc4cc2451bef0"
+        self.app_id = "b424f3fa1d4544d68579671d70596f29"
+        self.model = "qwen-max"
+        self.SYSTEM_PROMPT = SYSTEM_PROMPT
+
+    async def connect_to_server(self):
+        server_params = StdioServerParameters(
+            command="python",
+            args=[SCRIPT_PATH]
+        )
+        stdio_transport = await self.exit_stack.enter_async_context(stdio_client(server_params))
+        self.stdio, self.write = stdio_transport
+        self.session = await self.exit_stack.enter_async_context(ClientSession(self.stdio, self.write))
+
+        await self.session.initialize()
+
+
+        response = await self.session.list_tools()
+        tools = response.tools
+        print("\nConnected to server with tools:", [tool.name for tool in tools])
+
+    async def _call_bailian_api(self, prompt: str) -> str:
+
+        try:
+            tool_response = await self.session.list_tools()
+            tools_desc = "\n".join([f"- {t.name}: {t.description}" for t in tool_response.tools])
+
+            full_prompt = self.SYSTEM_PROMPT.format(
+                tools_description=tools_desc,
+                available_tools=[t.name for t in tool_response.tools]
+            )
+
+            response = await asyncio.to_thread(
+                Application.call,
+                api_key=self.api_key,
+                app_id=self.app_id,
+                prompt=full_prompt + "\n用户提问:" + prompt,
+                session_id=self.session_id,
+                stream=False,
+                incremental_output=False,
+                enable_mcp=True,
+                tool_choice="auto"
+            )
+            print(full_prompt + "\n用户提问:" + prompt)
+            if response.output.session_id:
+                self.session_id = response.output.session_id
+
+            if response.status_code == HTTPStatus.OK:
+                return response.output.text
+            return f"API Error: {response.message}"
+
+        except Exception as e:
+            print(e)
+            return f"调用异常: {str(e)}"
+
+    async def process_query(self, query: str) -> str:
+        bailian_response = await self._call_bailian_api(query)
+        try:
+            tool_call = json.loads(bailian_response)
+            print(tool_call)
+            if "tool" in tool_call:
+                result = await self.session.call_tool(
+                    tool_call["tool"],
+                    tool_call["arguments"]
+                )
+                bailian_response = await self._call_bailian_api("tool_response:"+result.content[0].text)
+                return bailian_response
+        except json.JSONDecodeError:
+            return bailian_response
+
+    async def chat_loop(self):
+        print("\nMCP Client Started!")
+        print("Type your queries or 'quit' to exit.")
+
+        while True:
+            try:
+                query = input("\nQuery: ").strip()
+
+                if query.lower() == 'quit':
+                    break
+
+                response = await self.process_query(query)
+                print("\n" + response)
+
+            except Exception as e:
+                print(f"\nError: {str(e)}")
+
+    async def cleanup(self):
+        await self.exit_stack.aclose()
+
+
+async def main():
+
+    client = MCPClient()
+    try:
+        await client.connect_to_server()
+        print("Connected to server!")
+        await client.chat_loop()
+    finally:
+        await client.cleanup()
+
+
+if __name__ == "__main__":
+    asyncio.run(main())

+ 420 - 0
mcp/server.py

@@ -0,0 +1,420 @@
+import logging
+from pathlib import Path
+from typing import Optional
+from enum import Enum
+
+import git
+from mcp.types import TextContent
+from mcp.server.fastmcp import FastMCP
+
+
+# --- Pydantic Models are REMOVED ---
+
+# --- GitTools Enum (remains the same) ---
+class GitTools(str, Enum):
+    STATUS = "git_status"
+    DIFF_UNSTAGED = "git_diff_unstaged"
+    DIFF_STAGED = "git_diff_staged"
+    DIFF = "git_diff"
+    COMMIT = "git_commit"
+    ADD = "git_add"
+    RESET = "git_reset"
+    LOG = "git_log"
+    CREATE_BRANCH = "git_create_branch"
+    CHECKOUT = "git_checkout"
+    SHOW = "git_show"
+    INIT = "git_init"
+
+# --- Low-level Git Functions (remain the same, using gitpython.Repo) ---
+# Note: Type hint updated to gitpython.Repo
+def _get_repo(repo_path: str | Path) -> git.Repo:
+    """Helper to get Repo object and handle errors."""
+    try:
+        return git.Repo(str(repo_path))
+    except git.InvalidGitRepositoryError:
+        raise ValueError(f"'{repo_path}' is not a valid Git repository.")
+    except git.NoSuchPathError:
+         raise ValueError(f"Repository path '{repo_path}' does not exist.")
+    except Exception as e:
+        raise ValueError(f"Error accessing repository at '{repo_path}': {e}")
+
+def git_status(repo: git.Repo) -> str:
+    return repo.git.status()
+
+def git_diff_unstaged(repo: git.Repo) -> str:
+    return repo.git.diff()
+
+def git_diff_staged(repo: git.Repo) -> str:
+    return repo.git.diff("--cached")
+
+def git_diff(repo: git.Repo, target: str) -> str:
+    return repo.git.diff(target)
+
+def git_commit(repo: git.Repo, message: str) -> str:
+    try:
+        # Check if there's anything staged to commit *before* attempting
+        # This prevents errors if the index matches HEAD but there are unstaged changes
+        if not repo.index.diff("HEAD"):
+             return "No changes added to commit (working tree clean or changes not staged)."
+        commit = repo.index.commit(message)
+        return f"Changes committed successfully with hash {commit.hexsha}"
+    except Exception as e:
+        # Catch potential errors during commit check or commit itself
+        return f"Error committing changes: {str(e)}"
+
+def git_add(repo: git.Repo, files: list[str]) -> str:
+    try:
+        repo.index.add(files)
+        return f"Files staged successfully: {', '.join(files)}"
+    except FileNotFoundError as e:
+         return f"Error staging files: File not found - {e.filename}"
+    except Exception as e:
+        return f"Error staging files: {str(e)}"
+
+
+def git_reset(repo: git.Repo) -> str:
+    try:
+        # Resetting the index to HEAD (unstaging all)
+        repo.index.reset()
+        return "All staged changes reset (unstaged)"
+    except Exception as e:
+        return f"Error resetting staged changes: {str(e)}"
+
+def git_log(repo: git.Repo, max_count: int = 10) -> list[str]:
+    commits = list(repo.iter_commits(max_count=max_count))
+    log = []
+    for commit in commits:
+        log.append(
+            f"Commit: {commit.hexsha}\n"
+            f"Author: {commit.author}\n"
+            f"Date: {commit.authored_datetime}\n"
+            f"Message: {commit.message.strip()}\n" # Use strip() for cleaner output
+        )
+    return log
+
+def git_create_branch(repo: git.Repo, branch_name: str, base_branch: Optional[str] = None) -> str:
+    try:
+        if base_branch:
+            try:
+                base = repo.refs[base_branch]
+            except IndexError:
+                try:
+                    base = repo.commit(base_branch)
+                except git.BadName:
+                    return f"Error: Base reference '{base_branch}' not found (neither branch nor commit)."
+                except Exception as e:
+                     return f"Error resolving base reference '{base_branch}': {str(e)}"
+        else:
+            base = repo.head.commit
+        if branch_name in repo.heads:
+            return f"Error: Branch '{branch_name}' already exists."
+
+        new_branch = repo.create_head(branch_name, base)
+        base_ref_name = getattr(base, 'name', base.hexsha) # Get branch name if possible, else hash
+        return f"Created branch '{new_branch.name}' based on '{base_ref_name}'"
+    except git.GitCommandError as e:
+        # Catch specific git errors if possible
+        return f"Error creating branch '{branch_name}': {e.stderr or e.stdout}"
+    except Exception as e:
+        return f"Error creating branch '{branch_name}': {str(e)}"
+
+
+def git_checkout(repo: git.Repo, branch_name: str) -> str:
+    try:
+        repo.git.checkout(branch_name)
+        return f"Switched to branch '{branch_name}'"
+    except git.GitCommandError as e:
+        # Provide more specific feedback
+        if "did not match any file(s) known to git" in (e.stderr or ""):
+             return f"Error: Branch or pathspec '{branch_name}' not found."
+        elif "Please commit your changes or stash them before you switch branches" in (e.stderr or ""):
+             return f"Error: Cannot checkout branch '{branch_name}'. You have unstaged changes. Please commit or stash them first."
+        else:
+            return f"Error checking out branch '{branch_name}': {e.stderr or e.stdout}"
+    except Exception as e:
+         return f"An unexpected error occurred during checkout: {str(e)}"
+
+
+def git_init(repo_path: str) -> str:
+    try:
+        # Check if it already exists and is a repo
+        target_path = Path(repo_path)
+        if target_path.exists() and target_path.joinpath(".git").is_dir():
+             # Check if it's actually a valid repo
+             try:
+                 existing_repo = git.Repo(repo_path)
+                 return f"Repository already exists at {existing_repo.git_dir}"
+             except git.InvalidGitRepositoryError:
+                 # Path exists, .git dir exists, but it's invalid. Allow re-init? Or error?
+                 # Let's error for safety. User can delete .git if they want re-init.
+                 return f"Error: An invalid Git repository structure already exists at '{repo_path}'. Remove '.git' folder to reinitialize."
+             except Exception as e:
+                 return f"Error checking existing repository at '{repo_path}': {e}"
+
+        # Initialize (mkdir=True handles non-existent parent dirs)
+        repo = git.Repo.init(path=repo_path, mkdir=True)
+        return f"Initialized empty Git repository in {repo.git_dir}"
+
+    except Exception as e:
+        return f"Error initializing repository at '{repo_path}': {str(e)}"
+
+def git_show(repo: git.Repo, revision: str) -> str:
+    try:
+        commit = repo.commit(revision)
+    except git.BadName:
+         return f"Error: Revision '{revision}' not found."
+    except Exception as e:
+        return f"Error finding revision '{revision}': {str(e)}"
+
+    output = [
+        f"Commit: {commit.hexsha}\n"
+        f"Author: {commit.author}\n"
+        f"Date: {commit.authored_datetime}\n"
+        f"Message:\n{commit.message.strip()}\n"
+    ]
+    try:
+        # Show diff against first parent, or initial commit diff
+        diffs = commit.diff(commit.parents[0] if commit.parents else git.NULL_TREE, create_patch=True)
+        if diffs:
+             output.append("\nChanges:\n")
+             for d in diffs:
+                # Use a safer way to decode, ignoring errors
+                diff_text = d.diff.decode('utf-8', errors='ignore') if d.diff else ""
+                a_path = d.a_path or (d.a_blob.path if d.a_blob else 'unknown')
+                b_path = d.b_path or (d.b_blob.path if d.b_blob else 'unknown')
+                output.append(f"--- a/{a_path}\n+++ b/{b_path}\n")
+                output.append(diff_text)
+        else:
+             # Check if it's the initial commit
+             if not commit.parents:
+                 output.append("\nChanges: (Initial commit)\n")
+                 # Show the initial tree content or a message indicating it's the first commit
+                 # For brevity, just stating it's initial might be enough.
+                 # Or iterate through tree: for item in commit.tree.traverse(): output.append(f"+ {item.path}\n")
+             else:
+                 output.append("\nNo changes in this commit compared to its parent.")
+
+
+    except Exception as e:
+        output.append(f"\nWarning: Could not generate diff for commit {commit.hexsha}: {str(e)}")
+
+    return "".join(output)
+
+
+# --- MCPFast Server Setup ---
+logger = logging.getLogger(__name__)
+logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
+
+# Initialize MCPFast
+mcp = FastMCP("git")
+
+# --- Tool Definitions using @mcp.tool() ---
+
+@mcp.tool()
+async def init_tool(repo_path: str) -> list[TextContent]:
+    """Initializes a Git repository at the specified path.
+
+    Args:
+        repo_path: The file system path where the Git repository should be initialized. Parent directories will be created if they don't exist.
+    """
+    logger.info(f"Executing tool: {GitTools.INIT} for path {repo_path}")
+    result = git_init(repo_path)
+    return [TextContent(type="text", text=result)]
+
+@mcp.tool()
+async def status_tool(repo_path: str) -> list[TextContent]:
+    """Gets the status of a Git repository (shows staged, unstaged, and untracked files).
+
+    Args:
+        repo_path: The file system path to the Git repository.
+    """
+    logger.info(f"Executing tool: {GitTools.STATUS} for repo {repo_path}")
+    try:
+        repo = _get_repo(repo_path)
+        status = git_status(repo)
+        result = f"Repository status for '{repo_path}':\n{status}"
+    except ValueError as e:
+        result = str(e)
+    return [TextContent(type="text", text=result)]
+
+@mcp.tool()
+async def diff_unstaged_tool(repo_path: str) -> list[TextContent]:
+    """Shows changes in the working directory that are not staged for commit.
+
+    Args:
+        repo_path: The file system path to the Git repository.
+    """
+    logger.info(f"Executing tool: {GitTools.DIFF_UNSTAGED} for repo {repo_path}")
+    try:
+        repo = _get_repo(repo_path)
+        diff = git_diff_unstaged(repo)
+        result = f"Unstaged changes (working directory vs index) in '{repo_path}':\n{diff or 'No unstaged changes.'}"
+    except ValueError as e:
+        result = str(e)
+    return [TextContent(type="text", text=result)]
+
+@mcp.tool()
+async def diff_staged_tool(repo_path: str) -> list[TextContent]:
+    """Shows changes that are staged for the next commit (compared to HEAD).
+
+    Args:
+        repo_path: The file system path to the Git repository.
+    """
+    logger.info(f"Executing tool: {GitTools.DIFF_STAGED} for repo {repo_path}")
+    try:
+        repo = _get_repo(repo_path)
+        diff = git_diff_staged(repo)
+        result = f"Staged changes (index vs HEAD) in '{repo_path}':\n{diff or 'No staged changes.'}"
+    except ValueError as e:
+        result = str(e)
+    return [TextContent(type="text", text=result)]
+
+@mcp.tool()
+async def diff_tool(repo_path: str, target: str) -> list[TextContent]:
+    """Shows differences between the current HEAD and a specified target (e.g., a branch, tag, or commit hash).
+
+    Args:
+        repo_path: The file system path to the Git repository.
+        target: The branch, tag, commit hash, or other refspec to compare HEAD against.
+    """
+    logger.info(f"Executing tool: {GitTools.DIFF} for repo {repo_path} against {target}")
+    try:
+        repo = _get_repo(repo_path)
+        diff = git_diff(repo, target)
+        result = f"Diff between HEAD and '{target}' in '{repo_path}':\n{diff or 'No differences found.'}"
+    except ValueError as e:
+        result = str(e)
+    except git.GitCommandError as e:
+        result = f"Error running git diff against '{target}': {e.stderr or e.stdout}"
+    return [TextContent(type="text", text=result)]
+
+@mcp.tool()
+async def add_tool(repo_path: str, files: list[str]) -> list[TextContent]:
+    """Adds specified file contents to the staging area (index) for the next commit.
+
+    Args:
+        repo_path: The file system path to the Git repository.
+        files: A list of file paths (relative to the repository root) to stage. Use '.' to stage all changes.
+    """
+    logger.info(f"Executing tool: {GitTools.ADD} for repo {repo_path}, files: {files}")
+    if not files:
+        return [TextContent(type="text", text="Error: No files specified to add.")]
+    try:
+        repo = _get_repo(repo_path)
+        result = git_add(repo, files)
+    except ValueError as e: # Catches errors from _get_repo
+        result = str(e)
+    except Exception as e: # Catch other potential errors during add
+        result = f"An unexpected error occurred during add: {str(e)}"
+    return [TextContent(type="text", text=result)]
+
+@mcp.tool()
+async def commit_tool(repo_path: str, message: str) -> list[TextContent]:
+    """Records changes staged in the index to the repository history.
+
+    Args:
+        repo_path: The file system path to the Git repository.
+        message: The commit message describing the changes.
+    """
+    logger.info(f"Executing tool: {GitTools.COMMIT} for repo {repo_path}")
+    try:
+        repo = _get_repo(repo_path)
+        # The check is now inside git_commit for cleaner tool function
+        result = git_commit(repo, message)
+    except ValueError as e: # Catches errors from _get_repo
+        result = str(e)
+    except Exception as e: # Catch other potential errors
+        result = f"An unexpected error occurred during commit: {str(e)}"
+    return [TextContent(type="text", text=result)]
+
+@mcp.tool()
+async def reset_tool(repo_path: str) -> list[TextContent]:
+    """Resets the staging area (index) to match the current HEAD commit, effectively unstaging all changes. Does not modify the working directory.
+
+    Args:
+        repo_path: The file system path to the Git repository.
+    """
+    logger.info(f"Executing tool: {GitTools.RESET} for repo {repo_path}")
+    try:
+        repo = _get_repo(repo_path)
+        result = git_reset(repo)
+    except ValueError as e:
+        result = str(e)
+    return [TextContent(type="text", text=result)]
+
+@mcp.tool()
+async def log_tool(repo_path: str, max_count: int = 10) -> list[TextContent]:
+    """Shows the commit history log for the current branch.
+
+    Args:
+        repo_path: The file system path to the Git repository.
+        max_count: The maximum number of commits to display (default: 10).
+    """
+    logger.info(f"Executing tool: {GitTools.LOG} for repo {repo_path}, max_count={max_count}")
+    try:
+        repo = _get_repo(repo_path)
+        logs = git_log(repo, max_count)
+        if not logs:
+            result = f"No commit history found for '{repo_path}' (possibly an empty repository)."
+        else:
+            result = f"Commit history for '{repo_path}' (last {len(logs)} commits):\n\n" + "\n\n".join(logs) # Add double newline for readability
+    except ValueError as e:
+        result = str(e)
+    return [TextContent(type="text", text=result)]
+
+@mcp.tool()
+async def create_branch_tool(repo_path: str, branch_name: str, base_branch: Optional[str] = None) -> list[TextContent]:
+    """Creates a new branch.
+
+    Args:
+        repo_path: The file system path to the Git repository.
+        branch_name: The name for the new branch.
+        base_branch: Optional. The existing branch or commit hash to base the new branch on. If omitted, defaults to the current HEAD.
+    """
+    logger.info(f"Executing tool: {GitTools.CREATE_BRANCH} for repo {repo_path}, branch: {branch_name}, base: {base_branch}")
+    try:
+        repo = _get_repo(repo_path)
+        result = git_create_branch(repo, branch_name, base_branch)
+    except ValueError as e:
+        result = str(e)
+    return [TextContent(type="text", text=result)]
+
+@mcp.tool()
+async def checkout_tool(repo_path: str, branch_name: str) -> list[TextContent]:
+    """Switches the working directory to a different branch.
+
+    Args:
+        repo_path: The file system path to the Git repository.
+        branch_name: The name of the branch to switch to.
+    """
+    logger.info(f"Executing tool: {GitTools.CHECKOUT} for repo {repo_path}, branch: {branch_name}")
+    try:
+        repo = _get_repo(repo_path)
+        result = git_checkout(repo, branch_name)
+    except ValueError as e:
+        result = str(e)
+    return [TextContent(type="text", text=result)]
+
+@mcp.tool()
+async def show_tool(repo_path: str, revision: str) -> list[TextContent]:
+    """Shows details (metadata and content changes) for a specific commit or object.
+
+    Args:
+        repo_path: The file system path to the Git repository.
+        revision: The commit hash, tag, or branch name to show details for (e.g., 'HEAD', 'main', 'v1.0', 'abcdef123').
+    """
+    logger.info(f"Executing tool: {GitTools.SHOW} for repo {repo_path}, revision: {revision}")
+    try:
+        repo = _get_repo(repo_path)
+        result = git_show(repo, revision)
+    except ValueError as e:
+        result = str(e)
+    return [TextContent(type="text", text=result)]
+
+
+if __name__ == "__main__":
+    logger.info("Starting Git Tool Server using MCPFast stdio transport...")
+    # Run using stdio transport. Ensure the environment calling this script
+    # is set up to communicate via stdin/stdout as expected by MCPFast stdio.
+    mcp.run(transport='stdio')
+    logger.info("MCPFast stdio transport finished.")