import asyncio
import json
import os
import logging
from contextlib import AsyncExitStack
from typing import Optional, List, Dict
import nest_asyncio
from dotenv import load_dotenv
from mcp import ClientSession, StdioServerParameters
from mcp.client.stdio import stdio_client
from openai import AsyncOpenAI
# Patch asyncio so event loops may be re-entered (see nest_asyncio docs).
nest_asyncio.apply()
# Populate the environment from a .env file before any os.getenv() lookups,
# so LOG_LEVEL / DEEPSEEK_API_KEY can be supplied that way.
load_dotenv()

# Log level comes from $LOG_LEVEL (case-insensitive), defaulting to WARNING.
_log_level_name = os.getenv("LOG_LEVEL", "WARNING").upper()
logging.basicConfig(
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    level=_log_level_name,
)
logger = logging.getLogger(__name__)
class MCPClient:
    """Chat client that lets a DeepSeek model call tools exposed by MCP servers.

    Each configured MCP server is launched over stdio. Its tools are advertised
    to the model in OpenAI function-calling format, and tool calls requested by
    the model are routed back to the session that registered the tool. The
    conversation history is kept in ``self.messages`` across queries.
    """

    def __init__(self, max_loop: int = 10, model: str = "deepseek-chat"):
        """Create an unconnected client.

        Args:
            max_loop: Maximum number of model round-trips per query (each
                round may execute several tool calls).
            model: Chat-completion model name sent to the API.
        """
        self.max_loop = max_loop
        self.model = model
        self.messages: List[dict] = []  # chat history, reused across queries
        # Owns every stdio transport and ClientSession; closed by cleanup().
        self.exit_stack = AsyncExitStack()
        self.sessions: List[ClientSession] = []
        self.available_tools: List[dict] = []  # tool schemas sent to the model
        self.tool_session_mapping: Dict[str, ClientSession] = {}  # tool name -> owning session
        self.client = AsyncOpenAI(api_key=os.getenv("DEEPSEEK_API_KEY"), base_url="https://api.deepseek.com")

    async def connect_to_server(self, server_name: str, server_config: dict):
        """Launch one MCP server over stdio and register its tools.

        Args:
            server_name: Name of the server (used for logging).
            server_config: Mapping with a required ``"command"`` key and
                optional ``"args"`` (list) and ``"env"`` (mapping) keys.

        Raises:
            Exception: any error while launching the server, initializing the
                session or listing its tools is logged and re-raised.
        """
        try:
            server_params = StdioServerParameters(
                command=server_config["command"],
                args=server_config.get("args", []),
                env=server_config.get("env"),
            )
            # Both contexts are entered on self.exit_stack so cleanup() closes them.
            stdio, write = await self.exit_stack.enter_async_context(stdio_client(server_params))
            session = await self.exit_stack.enter_async_context(ClientSession(stdio, write))
            self.sessions.append(session)
            await session.initialize()
            logger.debug(f"Connected to server: {server_name}")
            response = await session.list_tools()
            logger.debug(f"Tool Response: {response}")
            for tool in response.tools:
                self._register_tool(server_name, session, tool)
        except Exception as e:
            logger.error(f"Error connecting to server {server_name}: {str(e)}")
            raise

    def _register_tool(self, server_name: str, session, tool) -> None:
        """Advertise one MCP tool to the model and remember which session owns it."""
        if tool.name in self.tool_session_mapping:
            # A second same-named tool would be shown to the model twice while
            # calls can only be routed to one session; keep the first so the
            # advertised schema and the routing always agree.
            logger.warning(
                f"Tool '{tool.name}' from server {server_name} duplicates an "
                f"already registered tool; ignoring it"
            )
            return
        self.tool_session_mapping[tool.name] = session
        self.available_tools.append({
            "type": "function",
            "function": {
                "name": tool.name,
                "description": tool.description or "",
                "parameters": tool.inputSchema,
            },
        })

    async def connect_to_servers(self, config_path: str = "servers.json"):
        """Connect to every server listed under ``"mcpServers"`` in a JSON file.

        Args:
            config_path: Path of the JSON config; the default is resolved
                relative to the current working directory.

        Raises:
            FileNotFoundError: if the config file does not exist.
            Exception: parse or connection errors are logged and re-raised.
        """
        try:
            with open(config_path, "r", encoding="utf-8") as f:
                data = json.load(f)
            servers = data.get("mcpServers", {})
            for server_name, server_config in servers.items():
                await self.connect_to_server(server_name, server_config)
        except FileNotFoundError:
            logger.error(f"No {config_path} file found")
            raise
        except Exception as e:
            logger.error(f"Error connecting to servers: {str(e)}")
            raise

    def _completion_kwargs(self) -> dict:
        """Build the keyword arguments for one chat-completion request."""
        kwargs = {"model": self.model, "messages": self.messages}
        # OpenAI-style endpoints reject an empty `tools` array, so only send
        # tools (and tool_choice) when at least one tool is registered.
        if self.available_tools:
            kwargs["tools"] = self.available_tools
            kwargs["tool_choice"] = "auto"
        return kwargs

    @staticmethod
    def _serialize_tool_call(tool_call) -> dict:
        """Convert an SDK tool-call object into a plain message-history dict."""
        return {
            "id": tool_call.id,
            "type": "function",
            "function": {
                "name": tool_call.function.name,
                "arguments": tool_call.function.arguments,
            },
        }

    @staticmethod
    def _format_tool_result(mcp_result) -> str:
        """Flatten an MCP tool result into text for the model.

        List content is joined with newlines, using each item's ``text``
        attribute when it is a string and ``str(item)`` otherwise. Non-list
        content, or a result without ``content``, falls back to ``str()``.
        """
        if not hasattr(mcp_result, "content"):
            return str(mcp_result)
        content = mcp_result.content
        if isinstance(content, list):
            return "\n".join(
                item.text if isinstance(getattr(item, "text", None), str) else str(item)
                for item in content
            )
        return str(content)

    async def _call_tool(self, tool_call) -> str:
        """Run one model-requested tool call and return the text to send back.

        Problems with the call itself (unknown tool, errors raised by the MCP
        session) are logged and returned as an ``"Error: ..."`` string so the
        model can see what went wrong; they are not raised.
        """
        tool_name = tool_call.function.name
        try:
            arguments = json.loads(tool_call.function.arguments)
        except (json.JSONDecodeError, TypeError):
            # Some models send "" or nothing for argument-less tools.
            logger.warning(
                f"Could not parse arguments for tool [{tool_name}]: "
                f"{tool_call.function.arguments!r}; using {{}}"
            )
            arguments = {}
        print(f"Calling tool: [{tool_name}] with arguments: {arguments}")

        session = self.tool_session_mapping.get(tool_name)
        if session is None:
            error_msg = f"Error: unknown tool '{tool_name}'"
            logger.error(error_msg)
            return error_msg
        try:
            mcp_result = await session.call_tool(name=tool_name, arguments=arguments)
            logger.debug(f"MCP Tool call Result: {mcp_result}")
            return self._format_tool_result(mcp_result)
        except Exception as e:
            error_msg = f"Error: {str(e)}"
            logger.error(error_msg)
            return error_msg

    async def process_query(self, query: str, system_prompt: Optional[str] = None) -> None:
        """Answer a query with DeepSeek, executing MCP tool calls as requested.

        The query (and the system prompt, if given) is appended to the shared
        history. The model is then called repeatedly, running every requested
        tool call and feeding the results back. This stops when the model
        replies without tool calls (the reply is printed), or after
        ``self.max_loop`` rounds (a warning is logged).

        Args:
            query: The user's message.
            system_prompt: Optional system message, appended at the current
                position in the history.

        Raises:
            Exception: errors from the chat-completion API are logged and re-raised.
        """
        if system_prompt:
            self.messages.append({"role": "system", "content": system_prompt})
        self.messages.append({"role": "user", "content": query})

        for _ in range(self.max_loop):
            try:
                response = await self.client.chat.completions.create(**self._completion_kwargs())
                logger.debug(f"Chat Completion Response: {response}")
            except Exception as e:
                logger.error(f"Error processing query: {str(e)}")
                raise

            response_message = response.choices[0].message
            tool_calls = response_message.tool_calls or []
            assistant_message = {"role": "assistant", "content": response_message.content}
            if tool_calls:
                # Store plain dicts so the history is JSON-friendly and never
                # carries a `tool_calls: null` entry.
                assistant_message["tool_calls"] = [self._serialize_tool_call(tc) for tc in tool_calls]
            self.messages.append(assistant_message)

            if not tool_calls:
                print("AI: ", response_message.content)
                return

            for tool_call in tool_calls:
                result_content = await self._call_tool(tool_call)
                self.messages.append({
                    "role": "tool",
                    "name": tool_call.function.name,
                    "content": result_content,
                    "tool_call_id": tool_call.id,
                })

        logger.warning(f"Stopped after {self.max_loop} model rounds without a final answer")

    async def chat_loop(self):
        """Read queries from stdin and answer them until the user quits.

        Typing "exit" or "quit" (any case) or closing stdin ends the loop.
        Errors from a single query are logged and the loop continues.
        """
        print("\nMCP Client Started!")
        print("Type 'exit' or 'quit' to quit.\n")
        while True:
            try:
                query = input("\nYou: ").strip()
                if query.lower() in ("exit", "quit"):
                    break
                if not query:
                    continue
                if self.messages:
                    logger.debug(f"Messages[{len(self.messages)}]:")
                    for message in self.messages:
                        logger.debug(message)
                await self.process_query(query)
            except EOFError:
                # stdin is closed (Ctrl+D / end of piped input): every later
                # input() would fail the same way, so stop instead of spinning.
                print()
                break
            except Exception as e:
                logger.error(f"Error: {str(e)}")

    async def cleanup(self):
        """Close every server session and stdio transport opened by this client."""
        await self.exit_stack.aclose()
async def main():
    """Connect to the configured MCP servers, then run the interactive chat.

    Ctrl+C prints a farewell instead of a traceback, and the client's sessions
    are closed on every exit path.
    """
    mcp_client = MCPClient()
    try:
        try:
            await mcp_client.connect_to_servers()
            await mcp_client.chat_loop()
        except KeyboardInterrupt:
            print("Goodbye!")
    finally:
        await mcp_client.cleanup()


if __name__ == "__main__":
    # Start the client only when executed as a script, not on import.
    asyncio.run(main())