1package agent
2
3import (
4 "context"
5 "encoding/json"
6 "fmt"
7 "log"
8
9 "github.com/kujtimiihoxha/termai/internal/config"
10 "github.com/kujtimiihoxha/termai/internal/llm/tools"
11 "github.com/kujtimiihoxha/termai/internal/permission"
12 "github.com/kujtimiihoxha/termai/internal/version"
13
14 "github.com/mark3labs/mcp-go/client"
15 "github.com/mark3labs/mcp-go/mcp"
16)
17
// mcpTool adapts a single tool advertised by an MCP server so it can be used
// as a tools.BaseTool.
type mcpTool struct {
	// mcpName is the name of the MCP server entry; it prefixes the tool name
	// reported by Info.
	mcpName string
	// tool is the tool definition as returned by the server's tool listing.
	tool mcp.Tool
	// mcpConfig is the server configuration used to open a client when the
	// tool is run.
	mcpConfig config.MCPServer
}
23
// MCPClient is the set of MCP client operations this file relies on. Both the
// stdio and the SSE clients created in this file are passed around as an
// MCPClient.
type MCPClient interface {
	// Initialize sends the initialize request to the server.
	Initialize(
		ctx context.Context,
		request mcp.InitializeRequest,
	) (*mcp.InitializeResult, error)
	// ListTools asks the server for the tools it exposes.
	ListTools(ctx context.Context, request mcp.ListToolsRequest) (*mcp.ListToolsResult, error)
	// CallTool invokes a tool on the server.
	CallTool(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error)
	// Close releases the client.
	Close() error
}
33
34func (b *mcpTool) Info() tools.ToolInfo {
35 return tools.ToolInfo{
36 Name: fmt.Sprintf("%s_%s", b.mcpName, b.tool.Name),
37 Description: b.tool.Description,
38 Parameters: b.tool.InputSchema.Properties,
39 Required: b.tool.InputSchema.Required,
40 }
41}
42
43func runTool(ctx context.Context, c MCPClient, toolName string, input string) (tools.ToolResponse, error) {
44 defer c.Close()
45 initRequest := mcp.InitializeRequest{}
46 initRequest.Params.ProtocolVersion = mcp.LATEST_PROTOCOL_VERSION
47 initRequest.Params.ClientInfo = mcp.Implementation{
48 Name: "termai",
49 Version: version.Version,
50 }
51
52 _, err := c.Initialize(ctx, initRequest)
53 if err != nil {
54 return tools.NewTextErrorResponse(err.Error()), nil
55 }
56
57 toolRequest := mcp.CallToolRequest{}
58 toolRequest.Params.Name = toolName
59 var args map[string]any
60 if err = json.Unmarshal([]byte(input), &input); err != nil {
61 return tools.NewTextErrorResponse(fmt.Sprintf("error parsing parameters: %s", err)), nil
62 }
63 toolRequest.Params.Arguments = args
64 result, err := c.CallTool(ctx, toolRequest)
65 if err != nil {
66 return tools.NewTextErrorResponse(err.Error()), nil
67 }
68
69 output := ""
70 for _, v := range result.Content {
71 if v, ok := v.(mcp.TextContent); ok {
72 output = v.Text
73 } else {
74 output = fmt.Sprintf("%v", v)
75 }
76 }
77
78 return tools.NewTextResponse(output), nil
79}
80
81func (b *mcpTool) Run(ctx context.Context, params tools.ToolCall) (tools.ToolResponse, error) {
82 permissionDescription := fmt.Sprintf("execute %s with the following parameters: %s", b.Info().Name, params.Input)
83 p := permission.Default.Request(
84 permission.CreatePermissionRequest{
85 Path: config.WorkingDirectory(),
86 ToolName: b.Info().Name,
87 Action: "execute",
88 Description: permissionDescription,
89 Params: params.Input,
90 },
91 )
92 if !p {
93 return tools.NewTextErrorResponse("permission denied"), nil
94 }
95
96 switch b.mcpConfig.Type {
97 case config.MCPStdio:
98 c, err := client.NewStdioMCPClient(
99 b.mcpConfig.Command,
100 b.mcpConfig.Env,
101 b.mcpConfig.Args...,
102 )
103 if err != nil {
104 return tools.NewTextErrorResponse(err.Error()), nil
105 }
106 return runTool(ctx, c, b.tool.Name, params.Input)
107 case config.MCPSse:
108 c, err := client.NewSSEMCPClient(
109 b.mcpConfig.URL,
110 client.WithHeaders(b.mcpConfig.Headers),
111 )
112 if err != nil {
113 return tools.NewTextErrorResponse(err.Error()), nil
114 }
115 return runTool(ctx, c, b.tool.Name, params.Input)
116 }
117
118 return tools.NewTextErrorResponse("invalid mcp type"), nil
119}
120
121func NewMcpTool(name string, tool mcp.Tool, mcpConfig config.MCPServer) tools.BaseTool {
122 return &mcpTool{
123 mcpName: name,
124 tool: tool,
125 mcpConfig: mcpConfig,
126 }
127}
128
// mcpTools holds the tools collected by GetMcpTools. Once it is non-empty,
// GetMcpTools returns it as-is instead of querying the servers again.
var mcpTools []tools.BaseTool
130
131func getTools(ctx context.Context, name string, m config.MCPServer, c MCPClient) []tools.BaseTool {
132 var stdioTools []tools.BaseTool
133 initRequest := mcp.InitializeRequest{}
134 initRequest.Params.ProtocolVersion = mcp.LATEST_PROTOCOL_VERSION
135 initRequest.Params.ClientInfo = mcp.Implementation{
136 Name: "termai",
137 Version: version.Version,
138 }
139
140 _, err := c.Initialize(ctx, initRequest)
141 if err != nil {
142 log.Printf("error initializing mcp client: %s", err)
143 return stdioTools
144 }
145 toolsRequest := mcp.ListToolsRequest{}
146 tools, err := c.ListTools(ctx, toolsRequest)
147 if err != nil {
148 log.Printf("error listing tools: %s", err)
149 return stdioTools
150 }
151 for _, t := range tools.Tools {
152 stdioTools = append(stdioTools, NewMcpTool(name, t, m))
153 }
154 defer c.Close()
155 return stdioTools
156}
157
158func GetMcpTools(ctx context.Context) []tools.BaseTool {
159 if len(mcpTools) > 0 {
160 return mcpTools
161 }
162 for name, m := range config.Get().MCPServers {
163 switch m.Type {
164 case config.MCPStdio:
165 c, err := client.NewStdioMCPClient(
166 m.Command,
167 m.Env,
168 m.Args...,
169 )
170 if err != nil {
171 log.Printf("error creating mcp client: %s", err)
172 continue
173 }
174
175 mcpTools = append(mcpTools, getTools(ctx, name, m, c)...)
176 case config.MCPSse:
177 c, err := client.NewSSEMCPClient(
178 m.URL,
179 client.WithHeaders(m.Headers),
180 )
181 if err != nil {
182 log.Printf("error creating mcp client: %s", err)
183 continue
184 }
185 mcpTools = append(mcpTools, getTools(ctx, name, m, c)...)
186 }
187 }
188
189 return mcpTools
190}