mcp-tools.go

  1package agent
  2
import (
	"context"
	"encoding/json"
	"fmt"
	"sort"
	"strings"

	"github.com/kujtimiihoxha/termai/internal/config"
	"github.com/kujtimiihoxha/termai/internal/llm/tools"
	"github.com/kujtimiihoxha/termai/internal/logging"
	"github.com/kujtimiihoxha/termai/internal/permission"
	"github.com/kujtimiihoxha/termai/internal/version"

	"github.com/mark3labs/mcp-go/client"
	"github.com/mark3labs/mcp-go/mcp"
)
 17
 18type mcpTool struct {
 19	mcpName     string
 20	tool        mcp.Tool
 21	mcpConfig   config.MCPServer
 22	permissions permission.Service
 23}
 24
 25type MCPClient interface {
 26	Initialize(
 27		ctx context.Context,
 28		request mcp.InitializeRequest,
 29	) (*mcp.InitializeResult, error)
 30	ListTools(ctx context.Context, request mcp.ListToolsRequest) (*mcp.ListToolsResult, error)
 31	CallTool(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error)
 32	Close() error
 33}
 34
 35func (b *mcpTool) Info() tools.ToolInfo {
 36	return tools.ToolInfo{
 37		Name:        fmt.Sprintf("%s_%s", b.mcpName, b.tool.Name),
 38		Description: b.tool.Description,
 39		Parameters:  b.tool.InputSchema.Properties,
 40		Required:    b.tool.InputSchema.Required,
 41	}
 42}
 43
 44func runTool(ctx context.Context, c MCPClient, toolName string, input string) (tools.ToolResponse, error) {
 45	defer c.Close()
 46	initRequest := mcp.InitializeRequest{}
 47	initRequest.Params.ProtocolVersion = mcp.LATEST_PROTOCOL_VERSION
 48	initRequest.Params.ClientInfo = mcp.Implementation{
 49		Name:    "termai",
 50		Version: version.Version,
 51	}
 52
 53	_, err := c.Initialize(ctx, initRequest)
 54	if err != nil {
 55		return tools.NewTextErrorResponse(err.Error()), nil
 56	}
 57
 58	toolRequest := mcp.CallToolRequest{}
 59	toolRequest.Params.Name = toolName
 60	var args map[string]any
 61	if err = json.Unmarshal([]byte(input), &input); err != nil {
 62		return tools.NewTextErrorResponse(fmt.Sprintf("error parsing parameters: %s", err)), nil
 63	}
 64	toolRequest.Params.Arguments = args
 65	result, err := c.CallTool(ctx, toolRequest)
 66	if err != nil {
 67		return tools.NewTextErrorResponse(err.Error()), nil
 68	}
 69
 70	output := ""
 71	for _, v := range result.Content {
 72		if v, ok := v.(mcp.TextContent); ok {
 73			output = v.Text
 74		} else {
 75			output = fmt.Sprintf("%v", v)
 76		}
 77	}
 78
 79	return tools.NewTextResponse(output), nil
 80}
 81
 82func (b *mcpTool) Run(ctx context.Context, params tools.ToolCall) (tools.ToolResponse, error) {
 83	permissionDescription := fmt.Sprintf("execute %s with the following parameters: %s", b.Info().Name, params.Input)
 84	p := b.permissions.Request(
 85		permission.CreatePermissionRequest{
 86			Path:        config.WorkingDirectory(),
 87			ToolName:    b.Info().Name,
 88			Action:      "execute",
 89			Description: permissionDescription,
 90			Params:      params.Input,
 91		},
 92	)
 93	if !p {
 94		return tools.NewTextErrorResponse("permission denied"), nil
 95	}
 96
 97	switch b.mcpConfig.Type {
 98	case config.MCPStdio:
 99		c, err := client.NewStdioMCPClient(
100			b.mcpConfig.Command,
101			b.mcpConfig.Env,
102			b.mcpConfig.Args...,
103		)
104		if err != nil {
105			return tools.NewTextErrorResponse(err.Error()), nil
106		}
107		return runTool(ctx, c, b.tool.Name, params.Input)
108	case config.MCPSse:
109		c, err := client.NewSSEMCPClient(
110			b.mcpConfig.URL,
111			client.WithHeaders(b.mcpConfig.Headers),
112		)
113		if err != nil {
114			return tools.NewTextErrorResponse(err.Error()), nil
115		}
116		return runTool(ctx, c, b.tool.Name, params.Input)
117	}
118
119	return tools.NewTextErrorResponse("invalid mcp type"), nil
120}
121
122func NewMcpTool(name string, tool mcp.Tool, permissions permission.Service, mcpConfig config.MCPServer) tools.BaseTool {
123	return &mcpTool{
124		mcpName:     name,
125		tool:        tool,
126		mcpConfig:   mcpConfig,
127		permissions: permissions,
128	}
129}
130
131var mcpTools []tools.BaseTool
132
133func getTools(ctx context.Context, name string, m config.MCPServer, permissions permission.Service, c MCPClient) []tools.BaseTool {
134	var stdioTools []tools.BaseTool
135	initRequest := mcp.InitializeRequest{}
136	initRequest.Params.ProtocolVersion = mcp.LATEST_PROTOCOL_VERSION
137	initRequest.Params.ClientInfo = mcp.Implementation{
138		Name:    "termai",
139		Version: version.Version,
140	}
141
142	_, err := c.Initialize(ctx, initRequest)
143	if err != nil {
144		logging.Error("error initializing mcp client", "error", err)
145		return stdioTools
146	}
147	toolsRequest := mcp.ListToolsRequest{}
148	tools, err := c.ListTools(ctx, toolsRequest)
149	if err != nil {
150		logging.Error("error listing tools", "error", err)
151		return stdioTools
152	}
153	for _, t := range tools.Tools {
154		stdioTools = append(stdioTools, NewMcpTool(name, t, permissions, m))
155	}
156	defer c.Close()
157	return stdioTools
158}
159
160func GetMcpTools(ctx context.Context, permissions permission.Service) []tools.BaseTool {
161	if len(mcpTools) > 0 {
162		return mcpTools
163	}
164	for name, m := range config.Get().MCPServers {
165		switch m.Type {
166		case config.MCPStdio:
167			c, err := client.NewStdioMCPClient(
168				m.Command,
169				m.Env,
170				m.Args...,
171			)
172			if err != nil {
173				logging.Error("error creating mcp client", "error", err)
174				continue
175			}
176
177			mcpTools = append(mcpTools, getTools(ctx, name, m, permissions, c)...)
178		case config.MCPSse:
179			c, err := client.NewSSEMCPClient(
180				m.URL,
181				client.WithHeaders(m.Headers),
182			)
183			if err != nil {
184				logging.Error("error creating mcp client", "error", err)
185				continue
186			}
187			mcpTools = append(mcpTools, getTools(ctx, name, m, permissions, c)...)
188		}
189	}
190
191	return mcpTools
192}