Merge remote-tracking branch 'origin/main' into feat/mcp_notifications

Created by Carlos Alexandro Becker

Signed-off-by: Carlos Alexandro Becker <caarlos0@users.noreply.github.com>

Change summary

.github/cla-signatures.json          |  8 +++
Taskfile.yaml                        |  3 
go.mod                               |  2 
go.sum                               |  4 
internal/event/logger.go             |  9 +-
internal/llm/agent/agent.go          | 18 +++---
internal/llm/agent/mcp-tools.go      | 24 +++++---
internal/llm/provider/gemini.go      | 14 +++-
internal/llm/provider/openai.go      | 12 +++
internal/llm/provider/openai_test.go | 76 ++++++++++++++++++++++++++++++
10 files changed, 136 insertions(+), 34 deletions(-)

Detailed changes

.github/cla-signatures.json 🔗

@@ -671,6 +671,14 @@
       "created_at": "2025-09-27T13:09:22Z",
       "repoId": 987670088,
       "pullRequestNo": 1141
+    },
+    {
+      "name": "Wangch29",
+      "id": 115294077,
+      "comment_id": 3344526018,
+      "created_at": "2025-09-29T01:19:40Z",
+      "repoId": 987670088,
+      "pullRequestNo": 1148
     }
   ]
 }

Taskfile.yaml 🔗

@@ -97,7 +97,8 @@ tasks:
       - sh: "[ $(git status --porcelain=2 | wc -l) = 0 ]"
         msg: "Git is dirty"
     cmds:
+      - git commit --allow-empty -m "{{.NEXT}}"
       - git tag -d nightly
       - git tag --sign {{.NEXT}} {{.CLI_ARGS}}
-      - echo "pushing {{.NEXT}}..."
+      - echo "Pushing {{.NEXT}}..."
       - git push origin --tags

go.mod 🔗

@@ -26,7 +26,7 @@ require (
 	github.com/google/uuid v1.6.0
 	github.com/invopop/jsonschema v0.13.0
 	github.com/joho/godotenv v1.5.1
-	github.com/mark3labs/mcp-go v0.40.0
+	github.com/mark3labs/mcp-go v0.41.0
 	github.com/muesli/termenv v0.16.0
 	github.com/ncruces/go-sqlite3 v0.29.0
 	github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646

go.sum 🔗

@@ -194,8 +194,8 @@ github.com/lucasb-eyer/go-colorful v1.3.0 h1:2/yBRLdWBZKrf7gB40FoiKfAWYQ0lqNcbuQ
 github.com/lucasb-eyer/go-colorful v1.3.0/go.mod h1:R4dSotOR9KMtayYi1e77YzuveK+i7ruzyGqttikkLy0=
 github.com/mailru/easyjson v0.7.7 h1:UGYAvKxe3sBsEDzO8ZeWOSlIQfWFlxbzLZe7hwFURr0=
 github.com/mailru/easyjson v0.7.7/go.mod h1:xzfreul335JAWq5oZzymOObrkdz5UnU4kGfJJLY9Nlc=
-github.com/mark3labs/mcp-go v0.40.0 h1:M0oqK412OHBKut9JwXSsj4KanSmEKpzoW8TcxoPOkAU=
-github.com/mark3labs/mcp-go v0.40.0/go.mod h1:T7tUa2jO6MavG+3P25Oy/jR7iCeJPHImCZHRymCn39g=
+github.com/mark3labs/mcp-go v0.41.0 h1:IFfJaovCet65F3av00bE1HzSnmHpMRWM1kz96R98I70=
+github.com/mark3labs/mcp-go v0.41.0/go.mod h1:T7tUa2jO6MavG+3P25Oy/jR7iCeJPHImCZHRymCn39g=
 github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
 github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
 github.com/mattn/go-runewidth v0.0.16 h1:E5ScNMtiwvlvB5paMFdw9p4kSQzbXFikJ5SQO6TULQc=

internal/event/logger.go 🔗

@@ -1,6 +1,7 @@
 package event
 
 import (
+	"fmt"
 	"log/slog"
 
 	"github.com/posthog/posthog-go"
@@ -11,17 +12,17 @@ var _ posthog.Logger = logger{}
 type logger struct{}
 
 func (logger) Debugf(format string, args ...any) {
-	slog.Debug(format, args...)
+	slog.Debug(fmt.Sprintf(format, args...))
 }
 
 func (logger) Logf(format string, args ...any) {
-	slog.Info(format, args...)
+	slog.Info(fmt.Sprintf(format, args...))
 }
 
 func (logger) Warnf(format string, args ...any) {
-	slog.Warn(format, args...)
+	slog.Warn(fmt.Sprintf(format, args...))
 }
 
 func (logger) Errorf(format string, args ...any) {
-	slog.Error(format, args...)
+	slog.Error(fmt.Sprintf(format, args...))
 }

internal/llm/agent/agent.go 🔗

@@ -27,8 +27,6 @@ import (
 	"github.com/charmbracelet/crush/internal/shell"
 )
 
-const streamChunkTimeout = 80 * time.Second
-
 type AgentEventType string
 
 const (
@@ -577,7 +575,6 @@ func (a *agent) streamAndHandleEvents(ctx context.Context, sessionID string, msg
 	// Add the session and message ID into the context if needed by tools.
 	ctx = context.WithValue(ctx, tools.MessageIDContextKey, assistantMsg.ID)
 
-	// Process each event in the stream.
 loop:
 	for {
 		select {
@@ -593,9 +590,6 @@ loop:
 				}
 				return assistantMsg, nil, processErr
 			}
-		case <-time.After(streamChunkTimeout):
-			a.finishMessage(ctx, &assistantMsg, message.FinishReasonError, "Stream timeout", "No chunk received within timeout")
-			return assistantMsg, nil, fmt.Errorf("stream chunk timeout")
 		case <-ctx.Done():
 			a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
 			return assistantMsg, nil, ctx.Err()
@@ -1129,7 +1123,13 @@ func (a *agent) setupEvents(ctx context.Context) {
 						continue
 					}
 					cfg := config.Get()
-					tools := getTools(ctx, name, a.permissions, c, cfg.WorkingDir())
+					tools, err := getTools(ctx, name, a.permissions, c, cfg.WorkingDir())
+					if err != nil {
+						slog.Error("error listing tools", "error", err)
+						updateMCPState(name, MCPStateError, err, nil, 0)
+						_ = c.Close()
+						continue
+					}
 					updateMcpTools(name, tools)
 					// Update the lazy map with the new tools
 					a.mcpTools = csync.NewMapFrom(maps.Collect(mcpTools.Seq2()))
@@ -1144,7 +1144,5 @@ func (a *agent) setupEvents(ctx context.Context) {
 		}
 	}()
 
-	a.cleanupFuncs = append(a.cleanupFuncs, func() {
-		cancel()
-	})
+	a.cleanupFuncs = append(a.cleanupFuncs, cancel)
 }

internal/llm/agent/mcp-tools.go 🔗

@@ -197,13 +197,10 @@ func (b *McpTool) Run(ctx context.Context, params tools.ToolCall) (tools.ToolRes
 	return runTool(ctx, b.mcpName, b.tool.Name, params.Input)
 }
 
-func getTools(ctx context.Context, name string, permissions permission.Service, c *client.Client, workingDir string) []tools.BaseTool {
+func getTools(ctx context.Context, name string, permissions permission.Service, c *client.Client, workingDir string) ([]tools.BaseTool, error) {
 	result, err := c.ListTools(ctx, mcp.ListToolsRequest{})
 	if err != nil {
-		slog.Error("error listing tools", "error", err)
-		updateMCPState(name, MCPStateError, err, nil, 0)
-		c.Close()
-		return nil
+		return nil, err
 	}
 	mcpTools := make([]tools.BaseTool, 0, len(result.Tools))
 	for _, tool := range result.Tools {
@@ -214,7 +211,7 @@ func getTools(ctx context.Context, name string, permissions permission.Service,
 			workingDir:  workingDir,
 		})
 	}
-	return mcpTools
+	return mcpTools, nil
 }
 
 // SubscribeMCPEvents returns a channel for MCP events
@@ -324,6 +321,7 @@ func doGetMCPTools(ctx context.Context, permissions permission.Service, cfg *con
 
 			ctx, cancel := context.WithTimeout(ctx, mcpTimeout(m))
 			defer cancel()
+
 			c, err := createAndInitializeClient(ctx, name, m, cfg.Resolver())
 			if err != nil {
 				return
@@ -331,8 +329,16 @@ func doGetMCPTools(ctx context.Context, permissions permission.Service, cfg *con
 
 			mcpClients.Set(name, c)
 
-			tools := getTools(ctx, name, permissions, c, cfg.WorkingDir())
+			tools, err := getTools(ctx, name, permissions, c, cfg.WorkingDir())
+			if err != nil {
+				slog.Error("error listing tools", "error", err)
+				updateMCPState(name, MCPStateError, err, nil, 0)
+				c.Close()
+				return
+			}
+
 			updateMcpTools(name, tools)
+			mcpClients.Set(name, c)
 			updateMCPState(name, MCPStateConnected, nil, c, len(tools))
 		}(name, m)
 	}
@@ -375,8 +381,8 @@ func createAndInitializeClient(ctx context.Context, name string, m config.MCPCon
 	initCtx, cancel := context.WithTimeout(ctx, timeout)
 	defer cancel()
 
-	if err := c.Start(ctx); err != nil {
-		updateMCPState(name, MCPStateError, err, nil, 0)
+	if err := c.Start(initCtx); err != nil {
+		updateMCPState(name, MCPStateError, maybeTimeoutErr(err, timeout), nil, 0)
 		slog.Error("error starting mcp client", "error", err, "name", name)
 		_ = c.Close()
 		return nil, err

internal/llm/provider/gemini.go 🔗

@@ -43,9 +43,14 @@ func createGeminiClient(opts providerClientOptions) (*genai.Client, error) {
 	cc := &genai.ClientConfig{
 		APIKey:  opts.apiKey,
 		Backend: genai.BackendGeminiAPI,
-		HTTPOptions: genai.HTTPOptions{
-			BaseURL: opts.baseURL,
-		},
+	}
+	if opts.baseURL != "" {
+		resolvedBaseURL, err := config.Get().Resolve(opts.baseURL)
+		if err == nil && resolvedBaseURL != "" {
+			cc.HTTPOptions = genai.HTTPOptions{
+				BaseURL: resolvedBaseURL,
+			}
+		}
 	}
 	if config.Get().Options.Debug {
 		cc.HTTPClient = log.NewHTTPClient()
@@ -65,9 +70,8 @@ func (g *geminiClient) convertMessages(messages []message.Message) []*genai.Cont
 			var parts []*genai.Part
 			parts = append(parts, &genai.Part{Text: msg.Content().String()})
 			for _, binaryContent := range msg.BinaryContent() {
-				imageFormat := strings.Split(binaryContent.MIMEType, "/")
 				parts = append(parts, &genai.Part{InlineData: &genai.Blob{
-					MIMEType: imageFormat[1],
+					MIMEType: binaryContent.MIMEType,
 					Data:     binaryContent.Data,
 				}})
 			}

internal/llm/provider/openai.go 🔗

@@ -529,11 +529,19 @@ func (o *openaiClient) shouldRetry(attempts int, err error) (bool, int64, error)
 			return true, 0, nil
 		}
 
-		if apiErr.StatusCode != http.StatusTooManyRequests && apiErr.StatusCode != http.StatusInternalServerError {
+		if apiErr.StatusCode == http.StatusTooManyRequests {
+			// Check if this is an insufficient quota error (permanent)
+			if apiErr.Type == "insufficient_quota" || apiErr.Code == "insufficient_quota" {
+				return false, 0, fmt.Errorf("OpenAI quota exceeded: %s. Please check your plan and billing details", apiErr.Message)
+			}
+			// Other 429 errors (rate limiting) can be retried
+		} else if apiErr.StatusCode != http.StatusInternalServerError {
 			return false, 0, err
 		}
 
-		retryAfterValues = apiErr.Response.Header.Values("Retry-After")
+		if apiErr.Response != nil {
+			retryAfterValues = apiErr.Response.Header.Values("Retry-After")
+		}
 	}
 
 	if apiErr != nil {

internal/llm/provider/openai_test.go 🔗

@@ -6,6 +6,7 @@ import (
 	"net/http"
 	"net/http/httptest"
 	"os"
+	"strings"
 	"testing"
 	"time"
 
@@ -88,3 +89,78 @@ func TestOpenAIClientStreamChoices(t *testing.T) {
 		}
 	}
 }
+
+func TestOpenAIClient429InsufficientQuotaError(t *testing.T) {
+	client := &openaiClient{
+		providerOptions: providerClientOptions{
+			modelType:     config.SelectedModelTypeLarge,
+			apiKey:        "test-key",
+			systemMessage: "test",
+			config: config.ProviderConfig{
+				ID:     "test-openai",
+				APIKey: "test-key",
+			},
+			model: func(config.SelectedModelType) catwalk.Model {
+				return catwalk.Model{
+					ID:   "test-model",
+					Name: "test-model",
+				}
+			},
+		},
+	}
+
+	// Test insufficient_quota error should not retry
+	apiErr := &openai.Error{
+		StatusCode: 429,
+		Message:    "You exceeded your current quota, please check your plan and billing details. For more information on this error, read the docs: https://platform.openai.com/docs/guides/error-codes/api-errors.",
+		Type:       "insufficient_quota",
+		Code:       "insufficient_quota",
+	}
+
+	retry, _, err := client.shouldRetry(1, apiErr)
+	if retry {
+		t.Error("Expected shouldRetry to return false for insufficient_quota error, but got true")
+	}
+	if err == nil {
+		t.Error("Expected shouldRetry to return an error for insufficient_quota, but got nil")
+	}
+	if err != nil && !strings.Contains(err.Error(), "quota") {
+		t.Errorf("Expected error message to mention quota, got: %v", err)
+	}
+}
+
+func TestOpenAIClient429RateLimitError(t *testing.T) {
+	client := &openaiClient{
+		providerOptions: providerClientOptions{
+			modelType:     config.SelectedModelTypeLarge,
+			apiKey:        "test-key",
+			systemMessage: "test",
+			config: config.ProviderConfig{
+				ID:     "test-openai",
+				APIKey: "test-key",
+			},
+			model: func(config.SelectedModelType) catwalk.Model {
+				return catwalk.Model{
+					ID:   "test-model",
+					Name: "test-model",
+				}
+			},
+		},
+	}
+
+	// Test regular rate limit error should retry
+	apiErr := &openai.Error{
+		StatusCode: 429,
+		Message:    "Rate limit reached for requests",
+		Type:       "rate_limit_exceeded",
+		Code:       "rate_limit_exceeded",
+	}
+
+	retry, _, err := client.shouldRetry(1, apiErr)
+	if !retry {
+		t.Error("Expected shouldRetry to return true for rate_limit_exceeded error, but got false")
+	}
+	if err != nil {
+		t.Errorf("Expected shouldRetry to return nil error for rate_limit_exceeded, but got: %v", err)
+	}
+}