1package agent
2
3import (
4 "context"
5 "encoding/json"
6 "fmt"
7
8 "github.com/kujtimiihoxha/termai/internal/config"
9 "github.com/kujtimiihoxha/termai/internal/llm/tools"
10 "github.com/kujtimiihoxha/termai/internal/logging"
11 "github.com/kujtimiihoxha/termai/internal/permission"
12 "github.com/kujtimiihoxha/termai/internal/version"
13
14 "github.com/mark3labs/mcp-go/client"
15 "github.com/mark3labs/mcp-go/mcp"
16)
17
18type mcpTool struct {
19 mcpName string
20 tool mcp.Tool
21 mcpConfig config.MCPServer
22 permissions permission.Service
23}
24
25var logger = logging.Get()
26
27type MCPClient interface {
28 Initialize(
29 ctx context.Context,
30 request mcp.InitializeRequest,
31 ) (*mcp.InitializeResult, error)
32 ListTools(ctx context.Context, request mcp.ListToolsRequest) (*mcp.ListToolsResult, error)
33 CallTool(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error)
34 Close() error
35}
36
37func (b *mcpTool) Info() tools.ToolInfo {
38 return tools.ToolInfo{
39 Name: fmt.Sprintf("%s_%s", b.mcpName, b.tool.Name),
40 Description: b.tool.Description,
41 Parameters: b.tool.InputSchema.Properties,
42 Required: b.tool.InputSchema.Required,
43 }
44}
45
46func runTool(ctx context.Context, c MCPClient, toolName string, input string) (tools.ToolResponse, error) {
47 defer c.Close()
48 initRequest := mcp.InitializeRequest{}
49 initRequest.Params.ProtocolVersion = mcp.LATEST_PROTOCOL_VERSION
50 initRequest.Params.ClientInfo = mcp.Implementation{
51 Name: "termai",
52 Version: version.Version,
53 }
54
55 _, err := c.Initialize(ctx, initRequest)
56 if err != nil {
57 return tools.NewTextErrorResponse(err.Error()), nil
58 }
59
60 toolRequest := mcp.CallToolRequest{}
61 toolRequest.Params.Name = toolName
62 var args map[string]any
63 if err = json.Unmarshal([]byte(input), &input); err != nil {
64 return tools.NewTextErrorResponse(fmt.Sprintf("error parsing parameters: %s", err)), nil
65 }
66 toolRequest.Params.Arguments = args
67 result, err := c.CallTool(ctx, toolRequest)
68 if err != nil {
69 return tools.NewTextErrorResponse(err.Error()), nil
70 }
71
72 output := ""
73 for _, v := range result.Content {
74 if v, ok := v.(mcp.TextContent); ok {
75 output = v.Text
76 } else {
77 output = fmt.Sprintf("%v", v)
78 }
79 }
80
81 return tools.NewTextResponse(output), nil
82}
83
84func (b *mcpTool) Run(ctx context.Context, params tools.ToolCall) (tools.ToolResponse, error) {
85 permissionDescription := fmt.Sprintf("execute %s with the following parameters: %s", b.Info().Name, params.Input)
86 p := b.permissions.Request(
87 permission.CreatePermissionRequest{
88 Path: config.WorkingDirectory(),
89 ToolName: b.Info().Name,
90 Action: "execute",
91 Description: permissionDescription,
92 Params: params.Input,
93 },
94 )
95 if !p {
96 return tools.NewTextErrorResponse("permission denied"), nil
97 }
98
99 switch b.mcpConfig.Type {
100 case config.MCPStdio:
101 c, err := client.NewStdioMCPClient(
102 b.mcpConfig.Command,
103 b.mcpConfig.Env,
104 b.mcpConfig.Args...,
105 )
106 if err != nil {
107 return tools.NewTextErrorResponse(err.Error()), nil
108 }
109 return runTool(ctx, c, b.tool.Name, params.Input)
110 case config.MCPSse:
111 c, err := client.NewSSEMCPClient(
112 b.mcpConfig.URL,
113 client.WithHeaders(b.mcpConfig.Headers),
114 )
115 if err != nil {
116 return tools.NewTextErrorResponse(err.Error()), nil
117 }
118 return runTool(ctx, c, b.tool.Name, params.Input)
119 }
120
121 return tools.NewTextErrorResponse("invalid mcp type"), nil
122}
123
124func NewMcpTool(name string, tool mcp.Tool, permissions permission.Service, mcpConfig config.MCPServer) tools.BaseTool {
125 return &mcpTool{
126 mcpName: name,
127 tool: tool,
128 mcpConfig: mcpConfig,
129 permissions: permissions,
130 }
131}
132
133var mcpTools []tools.BaseTool
134
135func getTools(ctx context.Context, name string, m config.MCPServer, permissions permission.Service, c MCPClient) []tools.BaseTool {
136 var stdioTools []tools.BaseTool
137 initRequest := mcp.InitializeRequest{}
138 initRequest.Params.ProtocolVersion = mcp.LATEST_PROTOCOL_VERSION
139 initRequest.Params.ClientInfo = mcp.Implementation{
140 Name: "termai",
141 Version: version.Version,
142 }
143
144 _, err := c.Initialize(ctx, initRequest)
145 if err != nil {
146 logger.Error("error initializing mcp client", "error", err)
147 return stdioTools
148 }
149 toolsRequest := mcp.ListToolsRequest{}
150 tools, err := c.ListTools(ctx, toolsRequest)
151 if err != nil {
152 logger.Error("error listing tools", "error", err)
153 return stdioTools
154 }
155 for _, t := range tools.Tools {
156 stdioTools = append(stdioTools, NewMcpTool(name, t, permissions, m))
157 }
158 defer c.Close()
159 return stdioTools
160}
161
162func GetMcpTools(ctx context.Context, permissions permission.Service) []tools.BaseTool {
163 if len(mcpTools) > 0 {
164 return mcpTools
165 }
166 for name, m := range config.Get().MCPServers {
167 switch m.Type {
168 case config.MCPStdio:
169 c, err := client.NewStdioMCPClient(
170 m.Command,
171 m.Env,
172 m.Args...,
173 )
174 if err != nil {
175 logger.Error("error creating mcp client", "error", err)
176 continue
177 }
178
179 mcpTools = append(mcpTools, getTools(ctx, name, m, permissions, c)...)
180 case config.MCPSse:
181 c, err := client.NewSSEMCPClient(
182 m.URL,
183 client.WithHeaders(m.Headers),
184 )
185 if err != nil {
186 logger.Error("error creating mcp client", "error", err)
187 continue
188 }
189 mcpTools = append(mcpTools, getTools(ctx, name, m, permissions, c)...)
190 }
191 }
192
193 return mcpTools
194}