1package agent
2
3import (
4 "context"
5 "encoding/json"
6 "fmt"
7
8 "github.com/kujtimiihoxha/termai/internal/config"
9 "github.com/kujtimiihoxha/termai/internal/llm/tools"
10 "github.com/kujtimiihoxha/termai/internal/logging"
11 "github.com/kujtimiihoxha/termai/internal/permission"
12 "github.com/kujtimiihoxha/termai/internal/version"
13
14 "github.com/mark3labs/mcp-go/client"
15 "github.com/mark3labs/mcp-go/mcp"
16)
17
// mcpTool adapts a single tool exposed by an MCP server to the tools.BaseTool
// interface. It keeps the server configuration so that each Run can open its
// own connection to the server.
type mcpTool struct {
	// mcpName is the name of the MCP server this tool belongs to; Info uses it
	// as a prefix of the tool name.
	mcpName string
	// tool is the tool definition as reported by the server's ListTools.
	tool mcp.Tool
	// mcpConfig describes how to reach the server (stdio command or SSE URL).
	mcpConfig config.MCPServer
	// permissions is asked for approval before every execution.
	permissions permission.Service
}
24
25type MCPClient interface {
26 Initialize(
27 ctx context.Context,
28 request mcp.InitializeRequest,
29 ) (*mcp.InitializeResult, error)
30 ListTools(ctx context.Context, request mcp.ListToolsRequest) (*mcp.ListToolsResult, error)
31 CallTool(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error)
32 Close() error
33}
34
35func (b *mcpTool) Info() tools.ToolInfo {
36 return tools.ToolInfo{
37 Name: fmt.Sprintf("%s_%s", b.mcpName, b.tool.Name),
38 Description: b.tool.Description,
39 Parameters: b.tool.InputSchema.Properties,
40 Required: b.tool.InputSchema.Required,
41 }
42}
43
44func runTool(ctx context.Context, c MCPClient, toolName string, input string) (tools.ToolResponse, error) {
45 defer c.Close()
46 initRequest := mcp.InitializeRequest{}
47 initRequest.Params.ProtocolVersion = mcp.LATEST_PROTOCOL_VERSION
48 initRequest.Params.ClientInfo = mcp.Implementation{
49 Name: "termai",
50 Version: version.Version,
51 }
52
53 _, err := c.Initialize(ctx, initRequest)
54 if err != nil {
55 return tools.NewTextErrorResponse(err.Error()), nil
56 }
57
58 toolRequest := mcp.CallToolRequest{}
59 toolRequest.Params.Name = toolName
60 var args map[string]any
61 if err = json.Unmarshal([]byte(input), &input); err != nil {
62 return tools.NewTextErrorResponse(fmt.Sprintf("error parsing parameters: %s", err)), nil
63 }
64 toolRequest.Params.Arguments = args
65 result, err := c.CallTool(ctx, toolRequest)
66 if err != nil {
67 return tools.NewTextErrorResponse(err.Error()), nil
68 }
69
70 output := ""
71 for _, v := range result.Content {
72 if v, ok := v.(mcp.TextContent); ok {
73 output = v.Text
74 } else {
75 output = fmt.Sprintf("%v", v)
76 }
77 }
78
79 return tools.NewTextResponse(output), nil
80}
81
82func (b *mcpTool) Run(ctx context.Context, params tools.ToolCall) (tools.ToolResponse, error) {
83 permissionDescription := fmt.Sprintf("execute %s with the following parameters: %s", b.Info().Name, params.Input)
84 p := b.permissions.Request(
85 permission.CreatePermissionRequest{
86 Path: config.WorkingDirectory(),
87 ToolName: b.Info().Name,
88 Action: "execute",
89 Description: permissionDescription,
90 Params: params.Input,
91 },
92 )
93 if !p {
94 return tools.NewTextErrorResponse("permission denied"), nil
95 }
96
97 switch b.mcpConfig.Type {
98 case config.MCPStdio:
99 c, err := client.NewStdioMCPClient(
100 b.mcpConfig.Command,
101 b.mcpConfig.Env,
102 b.mcpConfig.Args...,
103 )
104 if err != nil {
105 return tools.NewTextErrorResponse(err.Error()), nil
106 }
107 return runTool(ctx, c, b.tool.Name, params.Input)
108 case config.MCPSse:
109 c, err := client.NewSSEMCPClient(
110 b.mcpConfig.URL,
111 client.WithHeaders(b.mcpConfig.Headers),
112 )
113 if err != nil {
114 return tools.NewTextErrorResponse(err.Error()), nil
115 }
116 return runTool(ctx, c, b.tool.Name, params.Input)
117 }
118
119 return tools.NewTextErrorResponse("invalid mcp type"), nil
120}
121
122func NewMcpTool(name string, tool mcp.Tool, permissions permission.Service, mcpConfig config.MCPServer) tools.BaseTool {
123 return &mcpTool{
124 mcpName: name,
125 tool: tool,
126 mcpConfig: mcpConfig,
127 permissions: permissions,
128 }
129}
130
// mcpTools caches the tools discovered by GetMcpTools. GetMcpTools returns it
// directly once it is non-empty; it is not guarded by a lock.
var mcpTools []tools.BaseTool
132
133func getTools(ctx context.Context, name string, m config.MCPServer, permissions permission.Service, c MCPClient) []tools.BaseTool {
134 var stdioTools []tools.BaseTool
135 initRequest := mcp.InitializeRequest{}
136 initRequest.Params.ProtocolVersion = mcp.LATEST_PROTOCOL_VERSION
137 initRequest.Params.ClientInfo = mcp.Implementation{
138 Name: "termai",
139 Version: version.Version,
140 }
141
142 _, err := c.Initialize(ctx, initRequest)
143 if err != nil {
144 logging.Error("error initializing mcp client", "error", err)
145 return stdioTools
146 }
147 toolsRequest := mcp.ListToolsRequest{}
148 tools, err := c.ListTools(ctx, toolsRequest)
149 if err != nil {
150 logging.Error("error listing tools", "error", err)
151 return stdioTools
152 }
153 for _, t := range tools.Tools {
154 stdioTools = append(stdioTools, NewMcpTool(name, t, permissions, m))
155 }
156 defer c.Close()
157 return stdioTools
158}
159
160func GetMcpTools(ctx context.Context, permissions permission.Service) []tools.BaseTool {
161 if len(mcpTools) > 0 {
162 return mcpTools
163 }
164 for name, m := range config.Get().MCPServers {
165 switch m.Type {
166 case config.MCPStdio:
167 c, err := client.NewStdioMCPClient(
168 m.Command,
169 m.Env,
170 m.Args...,
171 )
172 if err != nil {
173 logging.Error("error creating mcp client", "error", err)
174 continue
175 }
176
177 mcpTools = append(mcpTools, getTools(ctx, name, m, permissions, c)...)
178 case config.MCPSse:
179 c, err := client.NewSSEMCPClient(
180 m.URL,
181 client.WithHeaders(m.Headers),
182 )
183 if err != nil {
184 logging.Error("error creating mcp client", "error", err)
185 continue
186 }
187 mcpTools = append(mcpTools, getTools(ctx, name, m, permissions, c)...)
188 }
189 }
190
191 return mcpTools
192}