refactor: improvements

Carlos Alexandro Becker created

Signed-off-by: Carlos Alexandro Becker <caarlos0@users.noreply.github.com>

Change summary

internal/llm/agent/agent.go        | 32 ++++++++++----
internal/llm/agent/mcp-tools.go    | 70 +++++++++++++++++--------------
internal/tui/components/mcp/mcp.go |  8 +-
3 files changed, 65 insertions(+), 45 deletions(-)

Detailed changes

internal/llm/agent/agent.go 🔗

@@ -1102,32 +1102,44 @@ func (a *agent) setupEvents(ctx context.Context) {
 					return
 				}
 				switch event.Payload.Type {
-				case MCPEventToolsListChanged, MCPEventPromptsListChanged:
+				case MCPEventPromptsListChanged:
 					name := event.Payload.Name
 					c, ok := mcpClients.Get(name)
 					if !ok {
 						slog.Warn("MCP client not found for tools/prompts update", "name", name)
 						continue
 					}
-					cfg := config.Get()
-					tools, err := getTools(ctx, name, a.permissions, c, cfg.WorkingDir())
+					prompts, err := getPrompts(ctx, c)
 					if err != nil {
-						slog.Error("error listing tools", "error", err)
-						updateMCPState(name, MCPStateError, err, nil, 0, 0)
+						slog.Error("error listing prompts", "error", err)
+						updateMCPState(name, MCPStateError, err, nil, MCPCounts{})
 						_ = c.Close()
 						continue
 					}
-					prompts, err := getPrompts(ctx, c)
+					updateMcpPrompts(name, prompts)
+					prevState, _ := mcpStates.Get(name)
+					prevState.Counts.Prompts = len(prompts)
+					updateMCPState(name, MCPStateConnected, nil, c, prevState.Counts)
+				case MCPEventToolsListChanged:
+					name := event.Payload.Name
+					c, ok := mcpClients.Get(name)
+					if !ok {
+						slog.Warn("MCP client not found for tools/prompts update", "name", name)
+						continue
+					}
+					cfg := config.Get()
+					tools, err := getTools(ctx, name, a.permissions, c, cfg.WorkingDir())
 					if err != nil {
-						slog.Error("error listing prompts", "error", err)
-						updateMCPState(name, MCPStateError, err, nil, 0, 0)
+						slog.Error("error listing tools", "error", err)
+						updateMCPState(name, MCPStateError, err, nil, MCPCounts{})
 						_ = c.Close()
 						continue
 					}
 					updateMcpTools(name, tools)
-					updateMcpPrompts(name, prompts)
 					a.mcpTools.Reset(maps.Collect(mcpTools.Seq2()))
-					updateMCPState(name, MCPStateConnected, nil, c, a.mcpTools.Len(), len(prompts))
+					prevState, _ := mcpStates.Get(name)
+					prevState.Counts.Tools = len(tools)
+					updateMCPState(name, MCPStateConnected, nil, c, prevState.Counts)
 				default:
 					continue
 				}

internal/llm/agent/mcp-tools.go 🔗

@@ -60,12 +60,17 @@ const (
 
 // MCPEvent represents an event in the MCP system
 type MCPEvent struct {
-	Type        MCPEventType
-	Name        string
-	State       MCPState
-	Error       error
-	ToolCount   int
-	PromptCount int
+	Type   MCPEventType
+	Name   string
+	State  MCPState
+	Error  error
+	Counts MCPCounts
+}
+
+// MCPCounts number of available tools, prompts, etc.
+type MCPCounts struct {
+	Tools   int
+	Prompts int
 }
 
 // MCPClientInfo holds information about an MCP client's state
@@ -74,8 +79,7 @@ type MCPClientInfo struct {
 	State       MCPState
 	Error       error
 	Client      *mcp.ClientSession
-	ToolCount   int
-	PromptCount int
+	Counts      MCPCounts
 	ConnectedAt time.Time
 }
 
@@ -159,14 +163,14 @@ func getOrRenewClient(ctx context.Context, name string) (*mcp.ClientSession, err
 	if err == nil {
 		return sess, nil
 	}
-	updateMCPState(name, MCPStateError, maybeTimeoutErr(err, timeout), nil, state.ToolCount, state.PromptCount)
+	updateMCPState(name, MCPStateError, maybeTimeoutErr(err, timeout), nil, state.Counts)
 
 	sess, err = createMCPSession(ctx, name, m, cfg.Resolver())
 	if err != nil {
 		return nil, err
 	}
 
-	updateMCPState(name, MCPStateConnected, nil, sess, state.ToolCount, state.PromptCount)
+	updateMCPState(name, MCPStateConnected, nil, sess, state.Counts)
 	mcpClients.Set(name, sess)
 	return sess, nil
 }
@@ -231,14 +235,13 @@ func GetMCPState(name string) (MCPClientInfo, bool) {
 }
 
 // updateMCPState updates the state of an MCP client and publishes an event
-func updateMCPState(name string, state MCPState, err error, client *mcp.ClientSession, toolCount, promptCount int) {
+func updateMCPState(name string, state MCPState, err error, client *mcp.ClientSession, counts MCPCounts) {
 	info := MCPClientInfo{
-		Name:        name,
-		State:       state,
-		Error:       err,
-		Client:      client,
-		ToolCount:   toolCount,
-		PromptCount: promptCount,
+		Name:   name,
+		State:  state,
+		Error:  err,
+		Client: client,
+		Counts: counts,
 	}
 	switch state {
 	case MCPStateConnected:
@@ -252,12 +255,11 @@ func updateMCPState(name string, state MCPState, err error, client *mcp.ClientSe
 
 	// Publish state change event
 	mcpBroker.Publish(pubsub.UpdatedEvent, MCPEvent{
-		Type:        MCPEventStateChanged,
-		Name:        name,
-		State:       state,
-		Error:       err,
-		ToolCount:   toolCount,
-		PromptCount: promptCount,
+		Type:   MCPEventStateChanged,
+		Name:   name,
+		State:  state,
+		Error:  err,
+		Counts: counts,
 	})
 }
 
@@ -278,13 +280,13 @@ func doGetMCPTools(ctx context.Context, permissions permission.Service, cfg *con
 	// Initialize states for all configured MCPs
 	for name, m := range cfg.MCP {
 		if m.Disabled {
-			updateMCPState(name, MCPStateDisabled, nil, nil, 0, 0)
+			updateMCPState(name, MCPStateDisabled, nil, nil, MCPCounts{})
 			slog.Debug("skipping disabled mcp", "name", name)
 			continue
 		}
 
 		// Set initial starting state
-		updateMCPState(name, MCPStateStarting, nil, nil, 0, 0)
+		updateMCPState(name, MCPStateStarting, nil, nil, MCPCounts{})
 
 		wg.Add(1)
 		go func(name string, m config.MCPConfig) {
@@ -300,7 +302,7 @@ func doGetMCPTools(ctx context.Context, permissions permission.Service, cfg *con
 					default:
 						err = fmt.Errorf("panic: %v", v)
 					}
-					updateMCPState(name, MCPStateError, err, nil, 0, 0)
+					updateMCPState(name, MCPStateError, err, nil, MCPCounts{})
 					slog.Error("panic in mcp client initialization", "error", err, "name", name)
 				}
 			}()
@@ -318,7 +320,7 @@ func doGetMCPTools(ctx context.Context, permissions permission.Service, cfg *con
 			tools, err := getTools(ctx, name, permissions, c, cfg.WorkingDir())
 			if err != nil {
 				slog.Error("error listing tools", "error", err)
-				updateMCPState(name, MCPStateError, err, nil, 0, 0)
+				updateMCPState(name, MCPStateError, err, nil, MCPCounts{})
 				c.Close()
 				return
 			}
@@ -326,7 +328,7 @@ func doGetMCPTools(ctx context.Context, permissions permission.Service, cfg *con
 			prompts, err := getPrompts(ctx, c)
 			if err != nil {
 				slog.Error("error listing prompts", "error", err)
-				updateMCPState(name, MCPStateError, err, nil, 0, 0)
+				updateMCPState(name, MCPStateError, err, nil, MCPCounts{})
 				c.Close()
 				return
 			}
@@ -334,7 +336,13 @@ func doGetMCPTools(ctx context.Context, permissions permission.Service, cfg *con
 			updateMcpTools(name, tools)
 			updateMcpPrompts(name, prompts)
 			mcpClients.Set(name, c)
-			updateMCPState(name, MCPStateConnected, nil, c, len(tools), len(prompts))
+			updateMCPState(
+				name, MCPStateConnected, nil, c,
+				MCPCounts{
+					Tools:   len(tools),
+					Prompts: len(prompts),
+				},
+			)
 		}(name, m)
 	}
 	wg.Wait()
@@ -361,7 +369,7 @@ func createMCPSession(ctx context.Context, name string, m config.MCPConfig, reso
 
 	transport, err := createMCPTransport(mcpCtx, m, resolver)
 	if err != nil {
-		updateMCPState(name, MCPStateError, err, nil, 0, 0)
+		updateMCPState(name, MCPStateError, err, nil, MCPCounts{})
 		slog.Error("error creating mcp client", "error", err, "name", name)
 		return nil, err
 	}
@@ -391,7 +399,7 @@ func createMCPSession(ctx context.Context, name string, m config.MCPConfig, reso
 
 	session, err := client.Connect(mcpCtx, transport, nil)
 	if err != nil {
-		updateMCPState(name, MCPStateError, maybeTimeoutErr(err, timeout), nil, 0, 0)
+		updateMCPState(name, MCPStateError, maybeTimeoutErr(err, timeout), nil, MCPCounts{})
 		slog.Error("error starting mcp client", "error", err, "name", name)
 		_ = session.Close()
 		cancel()

internal/tui/components/mcp/mcp.go 🔗

@@ -68,11 +68,11 @@ func RenderMCPList(opts RenderOptions) []string {
 				description = t.S().Subtle.Render("starting...")
 			case agent.MCPStateConnected:
 				icon = t.ItemOnlineIcon
-				if state.ToolCount > 0 {
-					extraContent = append(extraContent, t.S().Subtle.Render(fmt.Sprintf("%d tools", state.ToolCount)))
+				if count := state.Counts.Tools; count > 0 {
+					extraContent = append(extraContent, t.S().Subtle.Render(fmt.Sprintf("%d tools", count)))
 				}
-				if state.PromptCount > 0 {
-					extraContent = append(extraContent, t.S().Subtle.Render(fmt.Sprintf("%d prompts", state.PromptCount)))
+				if count := state.Counts.Prompts; count > 0 {
+					extraContent = append(extraContent, t.S().Subtle.Render(fmt.Sprintf("%d prompts", count)))
 				}
 			case agent.MCPStateError:
 				icon = t.ItemErrorIcon