From 428d71a162480cac31d12db1e0ea54b3e8b5fd59 Mon Sep 17 00:00:00 2001 From: Kujtim Hoxha Date: Fri, 1 Aug 2025 08:26:37 +0200 Subject: [PATCH] fix: fix mcp clients --- internal/app/app.go | 1 + internal/llm/agent/agent.go | 5 ++-- internal/llm/agent/mcp-tools.go | 51 +++++++-------------------------- 3 files changed, 14 insertions(+), 43 deletions(-) diff --git a/internal/app/app.go b/internal/app/app.go index f3362c7276389b6669d6c9977d3565f482a44062..ca48d3e47838b9eefb98e7e497c03f90a21fa7b5 100644 --- a/internal/app/app.go +++ b/internal/app/app.go @@ -256,6 +256,7 @@ func (app *App) InitCoderAgent() error { } var err error app.CoderAgent, err = agent.NewAgent( + app.globalCtx, coderAgentCfg, app.Permissions, app.Sessions, diff --git a/internal/llm/agent/agent.go b/internal/llm/agent/agent.go index af69167b56317e4d5204eb51a4d80089ea316c53..df2f0adf25cd4a0849800e2d9266e0c52cddba3a 100644 --- a/internal/llm/agent/agent.go +++ b/internal/llm/agent/agent.go @@ -67,6 +67,7 @@ type agent struct { agentCfg config.Agent sessions session.Service messages message.Service + mcpTools []McpTool tools *csync.LazySlice[tools.BaseTool] @@ -86,6 +87,7 @@ var agentPromptMap = map[string]prompt.PromptID{ } func NewAgent( + ctx context.Context, agentCfg config.Agent, // These services are needed in the tools permissions permission.Service, @@ -94,7 +96,6 @@ func NewAgent( history history.Service, lspClients map[string]*lsp.Client, ) (Service, error) { - ctx := context.Background() cfg := config.Get() var agentTool tools.BaseTool @@ -103,7 +104,7 @@ func NewAgent( if taskAgentCfg.ID == "" { return nil, fmt.Errorf("task agent not found in config") } - taskAgent, err := NewAgent(taskAgentCfg, permissions, sessions, messages, history, lspClients) + taskAgent, err := NewAgent(ctx, taskAgentCfg, permissions, sessions, messages, history, lspClients) if err != nil { return nil, fmt.Errorf("failed to create task agent: %w", err) } diff --git a/internal/llm/agent/mcp-tools.go b/internal/llm/agent/mcp-tools.go index e17a5527fb46979a8cd056473b3bcd184c014d60..77149d85c82e2a700ede2a0ed8b226b6d952fbaf 100644 --- a/internal/llm/agent/mcp-tools.go +++ b/internal/llm/agent/mcp-tools.go @@ -20,9 +20,10 @@ import ( "github.com/mark3labs/mcp-go/mcp" ) -type mcpTool struct { +type McpTool struct { mcpName string tool mcp.Tool + client MCPClient mcpConfig config.MCPConfig permissions permission.Service workingDir string @@ -38,11 +39,11 @@ type MCPClient interface { Close() error } -func (b *mcpTool) Name() string { +func (b *McpTool) Name() string { return fmt.Sprintf("mcp_%s_%s", b.mcpName, b.tool.Name) } -func (b *mcpTool) Info() tools.ToolInfo { +func (b *McpTool) Info() tools.ToolInfo { required := b.tool.InputSchema.Required if required == nil { required = make([]string, 0) @@ -56,7 +57,6 @@ func (b *mcpTool) Info() tools.ToolInfo { } func runTool(ctx context.Context, c MCPClient, toolName string, input string) (tools.ToolResponse, error) { - defer c.Close() initRequest := mcp.InitializeRequest{} initRequest.Params.ProtocolVersion = mcp.LATEST_PROTOCOL_VERSION initRequest.Params.ClientInfo = mcp.Implementation{ @@ -93,7 +93,7 @@ func runTool(ctx context.Context, c MCPClient, toolName string, input string) (t return tools.NewTextResponse(output), nil } -func (b *mcpTool) Run(ctx context.Context, params tools.ToolCall) (tools.ToolResponse, error) { +func (b *McpTool) Run(ctx context.Context, params tools.ToolCall) (tools.ToolResponse, error) { sessionID, messageID := tools.GetContextValues(ctx) if sessionID == "" || messageID == "" { return tools.ToolResponse{}, fmt.Errorf("session ID and message ID are required for creating a new file") @@ -114,43 +114,13 @@ func (b *mcpTool) Run(ctx context.Context, params tools.ToolCall) (tools.ToolRes return tools.ToolResponse{}, permission.ErrorPermissionDenied } - switch b.mcpConfig.Type { - case config.MCPStdio: - c, err := client.NewStdioMCPClient( - b.mcpConfig.Command, - b.mcpConfig.ResolvedEnv(), - b.mcpConfig.Args..., - ) - if err != nil { - return tools.NewTextErrorResponse(err.Error()), nil - } - return runTool(ctx, c, b.tool.Name, params.Input) - case config.MCPHttp: - c, err := client.NewStreamableHttpClient( - b.mcpConfig.URL, - transport.WithHTTPHeaders(b.mcpConfig.ResolvedHeaders()), - ) - if err != nil { - return tools.NewTextErrorResponse(err.Error()), nil - } - return runTool(ctx, c, b.tool.Name, params.Input) - case config.MCPSse: - c, err := client.NewSSEMCPClient( - b.mcpConfig.URL, - client.WithHeaders(b.mcpConfig.ResolvedHeaders()), - ) - if err != nil { - return tools.NewTextErrorResponse(err.Error()), nil - } - return runTool(ctx, c, b.tool.Name, params.Input) - } - - return tools.NewTextErrorResponse("invalid mcp type"), nil + return runTool(ctx, b.client, b.tool.Name, params.Input) } -func NewMcpTool(name string, tool mcp.Tool, permissions permission.Service, mcpConfig config.MCPConfig, workingDir string) tools.BaseTool { - return &mcpTool{ +func NewMcpTool(name string, c MCPClient, tool mcp.Tool, permissions permission.Service, mcpConfig config.MCPConfig, workingDir string) tools.BaseTool { + return &McpTool{ mcpName: name, + client: c, tool: tool, mcpConfig: mcpConfig, permissions: permissions, @@ -179,9 +149,8 @@ func getTools(ctx context.Context, name string, m config.MCPConfig, permissions return stdioTools } for _, t := range tools.Tools { - stdioTools = append(stdioTools, NewMcpTool(name, t, permissions, m, workingDir)) + stdioTools = append(stdioTools, NewMcpTool(name, c, t, permissions, m, workingDir)) } - defer c.Close() return stdioTools }