mcp-tools.go

  1package agent
  2
  3import (
  4	"context"
  5	"encoding/json"
  6	"fmt"
  7	"log"
  8
  9	"github.com/kujtimiihoxha/termai/internal/config"
 10	"github.com/kujtimiihoxha/termai/internal/llm/tools"
 11	"github.com/kujtimiihoxha/termai/internal/permission"
 12	"github.com/kujtimiihoxha/termai/internal/version"
 13
 14	"github.com/mark3labs/mcp-go/client"
 15	"github.com/mark3labs/mcp-go/mcp"
 16)
 17
 18type mcpTool struct {
 19	mcpName   string
 20	tool      mcp.Tool
 21	mcpConfig config.MCPServer
 22}
 23
 24type MCPClient interface {
 25	Initialize(
 26		ctx context.Context,
 27		request mcp.InitializeRequest,
 28	) (*mcp.InitializeResult, error)
 29	ListTools(ctx context.Context, request mcp.ListToolsRequest) (*mcp.ListToolsResult, error)
 30	CallTool(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error)
 31	Close() error
 32}
 33
 34func (b *mcpTool) Info() tools.ToolInfo {
 35	return tools.ToolInfo{
 36		Name:        fmt.Sprintf("%s_%s", b.mcpName, b.tool.Name),
 37		Description: b.tool.Description,
 38		Parameters:  b.tool.InputSchema.Properties,
 39		Required:    b.tool.InputSchema.Required,
 40	}
 41}
 42
 43func runTool(ctx context.Context, c MCPClient, toolName string, input string) (tools.ToolResponse, error) {
 44	defer c.Close()
 45	initRequest := mcp.InitializeRequest{}
 46	initRequest.Params.ProtocolVersion = mcp.LATEST_PROTOCOL_VERSION
 47	initRequest.Params.ClientInfo = mcp.Implementation{
 48		Name:    "termai",
 49		Version: version.Version,
 50	}
 51
 52	_, err := c.Initialize(ctx, initRequest)
 53	if err != nil {
 54		return tools.NewTextErrorResponse(err.Error()), nil
 55	}
 56
 57	toolRequest := mcp.CallToolRequest{}
 58	toolRequest.Params.Name = toolName
 59	var args map[string]any
 60	if err = json.Unmarshal([]byte(input), &input); err != nil {
 61		return tools.NewTextErrorResponse(fmt.Sprintf("error parsing parameters: %s", err)), nil
 62	}
 63	toolRequest.Params.Arguments = args
 64	result, err := c.CallTool(ctx, toolRequest)
 65	if err != nil {
 66		return tools.NewTextErrorResponse(err.Error()), nil
 67	}
 68
 69	output := ""
 70	for _, v := range result.Content {
 71		if v, ok := v.(mcp.TextContent); ok {
 72			output = v.Text
 73		} else {
 74			output = fmt.Sprintf("%v", v)
 75		}
 76	}
 77
 78	return tools.NewTextResponse(output), nil
 79}
 80
 81func (b *mcpTool) Run(ctx context.Context, params tools.ToolCall) (tools.ToolResponse, error) {
 82	permissionDescription := fmt.Sprintf("execute %s with the following parameters: %s", b.Info().Name, params.Input)
 83	p := permission.Default.Request(
 84		permission.CreatePermissionRequest{
 85			Path:        config.WorkingDirectory(),
 86			ToolName:    b.Info().Name,
 87			Action:      "execute",
 88			Description: permissionDescription,
 89			Params:      params.Input,
 90		},
 91	)
 92	if !p {
 93		return tools.NewTextErrorResponse("permission denied"), nil
 94	}
 95
 96	switch b.mcpConfig.Type {
 97	case config.MCPStdio:
 98		c, err := client.NewStdioMCPClient(
 99			b.mcpConfig.Command,
100			b.mcpConfig.Env,
101			b.mcpConfig.Args...,
102		)
103		if err != nil {
104			return tools.NewTextErrorResponse(err.Error()), nil
105		}
106		return runTool(ctx, c, b.tool.Name, params.Input)
107	case config.MCPSse:
108		c, err := client.NewSSEMCPClient(
109			b.mcpConfig.URL,
110			client.WithHeaders(b.mcpConfig.Headers),
111		)
112		if err != nil {
113			return tools.NewTextErrorResponse(err.Error()), nil
114		}
115		return runTool(ctx, c, b.tool.Name, params.Input)
116	}
117
118	return tools.NewTextErrorResponse("invalid mcp type"), nil
119}
120
121func NewMcpTool(name string, tool mcp.Tool, mcpConfig config.MCPServer) tools.BaseTool {
122	return &mcpTool{
123		mcpName:   name,
124		tool:      tool,
125		mcpConfig: mcpConfig,
126	}
127}
128
129var mcpTools []tools.BaseTool
130
131func getTools(ctx context.Context, name string, m config.MCPServer, c MCPClient) []tools.BaseTool {
132	var stdioTools []tools.BaseTool
133	initRequest := mcp.InitializeRequest{}
134	initRequest.Params.ProtocolVersion = mcp.LATEST_PROTOCOL_VERSION
135	initRequest.Params.ClientInfo = mcp.Implementation{
136		Name:    "termai",
137		Version: version.Version,
138	}
139
140	_, err := c.Initialize(ctx, initRequest)
141	if err != nil {
142		log.Printf("error initializing mcp client: %s", err)
143		return stdioTools
144	}
145	toolsRequest := mcp.ListToolsRequest{}
146	tools, err := c.ListTools(ctx, toolsRequest)
147	if err != nil {
148		log.Printf("error listing tools: %s", err)
149		return stdioTools
150	}
151	for _, t := range tools.Tools {
152		stdioTools = append(stdioTools, NewMcpTool(name, t, m))
153	}
154	defer c.Close()
155	return stdioTools
156}
157
158func GetMcpTools(ctx context.Context) []tools.BaseTool {
159	if len(mcpTools) > 0 {
160		return mcpTools
161	}
162	for name, m := range config.Get().MCPServers {
163		switch m.Type {
164		case config.MCPStdio:
165			c, err := client.NewStdioMCPClient(
166				m.Command,
167				m.Env,
168				m.Args...,
169			)
170			if err != nil {
171				log.Printf("error creating mcp client: %s", err)
172				continue
173			}
174
175			mcpTools = append(mcpTools, getTools(ctx, name, m, c)...)
176		case config.MCPSse:
177			c, err := client.NewSSEMCPClient(
178				m.URL,
179				client.WithHeaders(m.Headers),
180			)
181			if err != nil {
182				log.Printf("error creating mcp client: %s", err)
183				continue
184			}
185			mcpTools = append(mcpTools, getTools(ctx, name, m, c)...)
186		}
187	}
188
189	return mcpTools
190}