# client.py (4.2 KB)

import asyncio
import json
import os
from contextlib import AsyncExitStack
from http import HTTPStatus
from typing import Optional

from dashscope import Application, Generation
from mcp import ClientSession, StdioServerParameters
from mcp.client.stdio import stdio_client

# from dotenv import load_dotenv
# load_dotenv()
# Server script that MCPClient.connect_to_server() runs with `python` over stdio.
SCRIPT_PATH = "server.py"
# System prompt prepended to every request MCPClient._call_bailian_api() sends.
# Placeholders filled via str.format():
#   {tools_description} - one "- <name>: <description>" line per MCP tool
#   {available_tools}   - the tool names
# Doubled braces ({{ / }}) come out of str.format() as literal braces, so the
# model sees the JSON template {"tool": ..., "arguments": {...}}.
# English gist of the (Chinese) prompt: "You are an AI assistant that can
# perform git operations. Available tools: ... Response rules: 1. When a tool
# is needed, return strict JSON in this format ... 2. When no tool is needed,
# reply directly in natural language. 3. Tool list: ..."
SYSTEM_PROMPT = """您是一个可执行git操作的AI助手,可用工具:
{tools_description}
响应规则:
1. 需要工具时必须返回严格JSON格式:
{{
"tool": "工具名",
"arguments": {{参数键值对}}
}}
2. 不需要工具时直接回复自然语言
3. 工具列表:
{available_tools}"""
  23. class MCPClient:
  24. def __init__(self):
  25. self.session: Optional[ClientSession] = None
  26. self.exit_stack = AsyncExitStack()
  27. self.session_id = None
  28. self.api_key = "sk-0164613e1a2143fc808fc4cc2451bef0"
  29. self.app_id = "b424f3fa1d4544d68579671d70596f29"
  30. self.model = "qwen-max"
  31. self.SYSTEM_PROMPT = SYSTEM_PROMPT
  32. async def connect_to_server(self):
  33. server_params = StdioServerParameters(
  34. command="python",
  35. args=[SCRIPT_PATH]
  36. )
  37. stdio_transport = await self.exit_stack.enter_async_context(stdio_client(server_params))
  38. self.stdio, self.write = stdio_transport
  39. self.session = await self.exit_stack.enter_async_context(ClientSession(self.stdio, self.write))
  40. await self.session.initialize()
  41. response = await self.session.list_tools()
  42. tools = response.tools
  43. print("\nConnected to server with tools:", [tool.name for tool in tools])
  44. async def _call_bailian_api(self, prompt: str) -> str:
  45. try:
  46. tool_response = await self.session.list_tools()
  47. tools_desc = "\n".join([f"- {t.name}: {t.description}" for t in tool_response.tools])
  48. full_prompt = self.SYSTEM_PROMPT.format(
  49. tools_description=tools_desc,
  50. available_tools=[t.name for t in tool_response.tools]
  51. )
  52. response = await asyncio.to_thread(
  53. Application.call,
  54. api_key=self.api_key,
  55. app_id=self.app_id,
  56. prompt=full_prompt + "\n用户提问:" + prompt,
  57. session_id=self.session_id,
  58. stream=False,
  59. incremental_output=False,
  60. enable_mcp=True,
  61. tool_choice="auto"
  62. )
  63. print(full_prompt + "\n用户提问:" + prompt)
  64. if response.output.session_id:
  65. self.session_id = response.output.session_id
  66. if response.status_code == HTTPStatus.OK:
  67. return response.output.text
  68. return f"API Error: {response.message}"
  69. except Exception as e:
  70. print(e)
  71. return f"调用异常: {str(e)}"
  72. async def process_query(self, query: str) -> str:
  73. bailian_response = await self._call_bailian_api(query)
  74. try:
  75. tool_call = json.loads(bailian_response)
  76. print(tool_call)
  77. if "tool" in tool_call:
  78. result = await self.session.call_tool(
  79. tool_call["tool"],
  80. tool_call["arguments"]
  81. )
  82. bailian_response = await self._call_bailian_api("tool_response:"+result.content[0].text)
  83. return bailian_response
  84. except json.JSONDecodeError:
  85. return bailian_response
  86. async def chat_loop(self):
  87. print("\nMCP Client Started!")
  88. print("Type your queries or 'quit' to exit.")
  89. while True:
  90. try:
  91. query = input("\nQuery: ").strip()
  92. if query.lower() == 'quit':
  93. break
  94. response = await self.process_query(query)
  95. print("\n" + response)
  96. except Exception as e:
  97. print(f"\nError: {str(e)}")
  98. async def cleanup(self):
  99. await self.exit_stack.aclose()
  100. async def main():
  101. client = MCPClient()
  102. try:
  103. await client.connect_to_server()
  104. print("Connected to server!")
  105. await client.chat_loop()
  106. finally:
  107. await client.cleanup()
  108. if __name__ == "__main__":
  109. asyncio.run(main())