diff --git a/.github/cla-signatures.json b/.github/cla-signatures.json index e92b62b75054734fd16450e5ec5e3eb56ee3cf57..4ecfd86887d5d072491a8fa764628e8935e4ebfe 100644 --- a/.github/cla-signatures.json +++ b/.github/cla-signatures.json @@ -671,6 +671,14 @@ "created_at": "2025-09-27T13:09:22Z", "repoId": 987670088, "pullRequestNo": 1141 + }, + { + "name": "Wangch29", + "id": 115294077, + "comment_id": 3344526018, + "created_at": "2025-09-29T01:19:40Z", + "repoId": 987670088, + "pullRequestNo": 1148 } ] } \ No newline at end of file diff --git a/Taskfile.yaml b/Taskfile.yaml index 80d6bd86d1070e2f4e900660a7cab060ebdfbcea..54b50a68217b6ff66ddf1de9a28a8f45d224fefc 100644 --- a/Taskfile.yaml +++ b/Taskfile.yaml @@ -97,7 +97,8 @@ tasks: - sh: "[ $(git status --porcelain=2 | wc -l) = 0 ]" msg: "Git is dirty" cmds: + - git commit --allow-empty -m "{{.NEXT}}" - git tag -d nightly - git tag --sign {{.NEXT}} {{.CLI_ARGS}} - - echo "pushing {{.NEXT}}..." + - echo "Pushing {{.NEXT}}..." - git push origin --tags diff --git a/go.mod b/go.mod index d3e668320cfdec39160d618b189f1470bf07d028..699233cdd52fe59165e8f9c44a85d1413f1bc4b6 100644 --- a/go.mod +++ b/go.mod @@ -26,7 +26,7 @@ require ( github.com/google/uuid v1.6.0 github.com/invopop/jsonschema v0.13.0 github.com/joho/godotenv v1.5.1 - github.com/mark3labs/mcp-go v0.40.0 + github.com/mark3labs/mcp-go v0.41.0 github.com/muesli/termenv v0.16.0 github.com/ncruces/go-sqlite3 v0.29.0 github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646 diff --git a/go.sum b/go.sum index f69217e9d4e9831abc8e1b47b80e23a19dcfcffa..f54651f8f6b5fa0e6f9f4a3ee53a61d0eec0970c 100644 --- a/go.sum +++ b/go.sum @@ -194,8 +194,8 @@ github.com/lucasb-eyer/go-colorful v1.3.0 h1:2/yBRLdWBZKrf7gB40FoiKfAWYQ0lqNcbuQ github.com/lucasb-eyer/go-colorful v1.3.0/go.mod h1:R4dSotOR9KMtayYi1e77YzuveK+i7ruzyGqttikkLy0= github.com/mailru/easyjson v0.7.7 h1:UGYAvKxe3sBsEDzO8ZeWOSlIQfWFlxbzLZe7hwFURr0= github.com/mailru/easyjson v0.7.7/go.mod h1:xzfreul335JAWq5oZzymOObrkdz5UnU4kGfJJLY9Nlc= -github.com/mark3labs/mcp-go v0.40.0 h1:M0oqK412OHBKut9JwXSsj4KanSmEKpzoW8TcxoPOkAU= -github.com/mark3labs/mcp-go v0.40.0/go.mod h1:T7tUa2jO6MavG+3P25Oy/jR7iCeJPHImCZHRymCn39g= +github.com/mark3labs/mcp-go v0.41.0 h1:IFfJaovCet65F3av00bE1HzSnmHpMRWM1kz96R98I70= +github.com/mark3labs/mcp-go v0.41.0/go.mod h1:T7tUa2jO6MavG+3P25Oy/jR7iCeJPHImCZHRymCn39g= github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= github.com/mattn/go-runewidth v0.0.16 h1:E5ScNMtiwvlvB5paMFdw9p4kSQzbXFikJ5SQO6TULQc= diff --git a/internal/event/logger.go b/internal/event/logger.go index 7648ae2c2cca91ed20535c0d65a677cd4db84500..7581676b018f5ac6001827db851a132792d21985 100644 --- a/internal/event/logger.go +++ b/internal/event/logger.go @@ -1,6 +1,7 @@ package event import ( + "fmt" "log/slog" "github.com/posthog/posthog-go" @@ -11,17 +12,17 @@ var _ posthog.Logger = logger{} type logger struct{} func (logger) Debugf(format string, args ...any) { - slog.Debug(format, args...) + slog.Debug(fmt.Sprintf(format, args...)) } func (logger) Logf(format string, args ...any) { - slog.Info(format, args...) + slog.Info(fmt.Sprintf(format, args...)) } func (logger) Warnf(format string, args ...any) { - slog.Warn(format, args...) + slog.Warn(fmt.Sprintf(format, args...)) } func (logger) Errorf(format string, args ...any) { - slog.Error(format, args...) + slog.Error(fmt.Sprintf(format, args...)) } diff --git a/internal/llm/agent/agent.go b/internal/llm/agent/agent.go index c8d19cf6da0a312475ca1610371ec462c287d04f..ff907c1caeea3b266bda767eb37cb408cccb2ff5 100644 --- a/internal/llm/agent/agent.go +++ b/internal/llm/agent/agent.go @@ -27,8 +27,6 @@ import ( "github.com/charmbracelet/crush/internal/shell" ) -const streamChunkTimeout = 80 * time.Second - type AgentEventType string const ( @@ -577,7 +575,6 @@ func (a *agent) streamAndHandleEvents(ctx context.Context, sessionID string, msg // Add the session and message ID into the context if needed by tools. ctx = context.WithValue(ctx, tools.MessageIDContextKey, assistantMsg.ID) - // Process each event in the stream. loop: for { select { @@ -593,9 +590,6 @@ loop: } return assistantMsg, nil, processErr } - case <-time.After(streamChunkTimeout): - a.finishMessage(ctx, &assistantMsg, message.FinishReasonError, "Stream timeout", "No chunk received within timeout") - return assistantMsg, nil, fmt.Errorf("stream chunk timeout") case <-ctx.Done(): a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "") return assistantMsg, nil, ctx.Err() @@ -1129,7 +1123,13 @@ func (a *agent) setupEvents(ctx context.Context) { continue } cfg := config.Get() - tools := getTools(ctx, name, a.permissions, c, cfg.WorkingDir()) + tools, err := getTools(ctx, name, a.permissions, c, cfg.WorkingDir()) + if err != nil { + slog.Error("error listing tools", "error", err) + updateMCPState(name, MCPStateError, err, nil, 0) + _ = c.Close() + continue + } updateMcpTools(name, tools) // Update the lazy map with the new tools a.mcpTools = csync.NewMapFrom(maps.Collect(mcpTools.Seq2())) @@ -1144,7 +1144,5 @@ func (a *agent) setupEvents(ctx context.Context) { } }() - a.cleanupFuncs = append(a.cleanupFuncs, func() { - cancel() - }) + a.cleanupFuncs = append(a.cleanupFuncs, cancel) } diff --git a/internal/llm/agent/mcp-tools.go b/internal/llm/agent/mcp-tools.go index fb6168ea7819822eaf882940d46cf1b96cc1428d..baeac3f2546c9668d6d34dec49f9fef5cde4bf8e 100644 --- a/internal/llm/agent/mcp-tools.go +++ b/internal/llm/agent/mcp-tools.go @@ -197,13 +197,10 @@ func (b *McpTool) Run(ctx context.Context, params tools.ToolCall) (tools.ToolRes return runTool(ctx, b.mcpName, b.tool.Name, params.Input) } -func getTools(ctx context.Context, name string, permissions permission.Service, c *client.Client, workingDir string) []tools.BaseTool { +func getTools(ctx context.Context, name string, permissions permission.Service, c *client.Client, workingDir string) ([]tools.BaseTool, error) { result, err := c.ListTools(ctx, mcp.ListToolsRequest{}) if err != nil { - slog.Error("error listing tools", "error", err) - updateMCPState(name, MCPStateError, err, nil, 0) - c.Close() - return nil + return nil, err } mcpTools := make([]tools.BaseTool, 0, len(result.Tools)) for _, tool := range result.Tools { @@ -214,7 +211,7 @@ func getTools(ctx context.Context, name string, permissions permission.Service, workingDir: workingDir, }) } - return mcpTools + return mcpTools, nil } // SubscribeMCPEvents returns a channel for MCP events @@ -324,6 +321,7 @@ func doGetMCPTools(ctx context.Context, permissions permission.Service, cfg *con ctx, cancel := context.WithTimeout(ctx, mcpTimeout(m)) defer cancel() + c, err := createAndInitializeClient(ctx, name, m, cfg.Resolver()) if err != nil { return @@ -331,8 +329,16 @@ func doGetMCPTools(ctx context.Context, permissions permission.Service, cfg *con mcpClients.Set(name, c) - tools := getTools(ctx, name, permissions, c, cfg.WorkingDir()) + tools, err := getTools(ctx, name, permissions, c, cfg.WorkingDir()) + if err != nil { + slog.Error("error listing tools", "error", err) + updateMCPState(name, MCPStateError, err, nil, 0) + c.Close() + return + } + updateMcpTools(name, tools) + mcpClients.Set(name, c) updateMCPState(name, MCPStateConnected, nil, c, len(tools)) }(name, m) } @@ -375,8 +381,8 @@ func createAndInitializeClient(ctx context.Context, name string, m config.MCPCon initCtx, cancel := context.WithTimeout(ctx, timeout) defer cancel() - if err := c.Start(ctx); err != nil { - updateMCPState(name, MCPStateError, err, nil, 0) + if err := c.Start(initCtx); err != nil { + updateMCPState(name, MCPStateError, maybeTimeoutErr(err, timeout), nil, 0) slog.Error("error starting mcp client", "error", err, "name", name) _ = c.Close() return nil, err diff --git a/internal/llm/provider/gemini.go b/internal/llm/provider/gemini.go index 3987deb7ebcc6330c9d3bcb4a52aeeb292eab43f..a846d8d582524bb6bf9c8ed31e3796ec8d94b419 100644 --- a/internal/llm/provider/gemini.go +++ b/internal/llm/provider/gemini.go @@ -43,9 +43,14 @@ func createGeminiClient(opts providerClientOptions) (*genai.Client, error) { cc := &genai.ClientConfig{ APIKey: opts.apiKey, Backend: genai.BackendGeminiAPI, - HTTPOptions: genai.HTTPOptions{ - BaseURL: opts.baseURL, - }, + } + if opts.baseURL != "" { + resolvedBaseURL, err := config.Get().Resolve(opts.baseURL) + if err == nil && resolvedBaseURL != "" { + cc.HTTPOptions = genai.HTTPOptions{ + BaseURL: resolvedBaseURL, + } + } } if config.Get().Options.Debug { cc.HTTPClient = log.NewHTTPClient() @@ -65,9 +70,8 @@ func (g *geminiClient) convertMessages(messages []message.Message) []*genai.Cont var parts []*genai.Part parts = append(parts, &genai.Part{Text: msg.Content().String()}) for _, binaryContent := range msg.BinaryContent() { - imageFormat := strings.Split(binaryContent.MIMEType, "/") parts = append(parts, &genai.Part{InlineData: &genai.Blob{ - MIMEType: imageFormat[1], + MIMEType: binaryContent.MIMEType, Data: binaryContent.Data, }}) } diff --git a/internal/llm/provider/openai.go b/internal/llm/provider/openai.go index 8ec366caff4156fbf4baae76fc24ce5c30d4a91d..3e92e077b3156ddccc186e0b104b7db174290c18 100644 --- a/internal/llm/provider/openai.go +++ b/internal/llm/provider/openai.go @@ -529,11 +529,19 @@ func (o *openaiClient) shouldRetry(attempts int, err error) (bool, int64, error) return true, 0, nil } - if apiErr.StatusCode != http.StatusTooManyRequests && apiErr.StatusCode != http.StatusInternalServerError { + if apiErr.StatusCode == http.StatusTooManyRequests { + // Check if this is an insufficient quota error (permanent) + if apiErr.Type == "insufficient_quota" || apiErr.Code == "insufficient_quota" { + return false, 0, fmt.Errorf("OpenAI quota exceeded: %s. Please check your plan and billing details", apiErr.Message) + } + // Other 429 errors (rate limiting) can be retried + } else if apiErr.StatusCode != http.StatusInternalServerError { return false, 0, err } - retryAfterValues = apiErr.Response.Header.Values("Retry-After") + if apiErr.Response != nil { + retryAfterValues = apiErr.Response.Header.Values("Retry-After") + } } if apiErr != nil { diff --git a/internal/llm/provider/openai_test.go b/internal/llm/provider/openai_test.go index 8088ba22b4cd49b26130cd3812e8705e8dfe1cba..52b0a20c9316d67ba987ccc5051aa2f6d321aff4 100644 --- a/internal/llm/provider/openai_test.go +++ b/internal/llm/provider/openai_test.go @@ -6,6 +6,7 @@ import ( "net/http" "net/http/httptest" "os" + "strings" "testing" "time" @@ -88,3 +89,78 @@ func TestOpenAIClientStreamChoices(t *testing.T) { } } } + +func TestOpenAIClient429InsufficientQuotaError(t *testing.T) { + client := &openaiClient{ + providerOptions: providerClientOptions{ + modelType: config.SelectedModelTypeLarge, + apiKey: "test-key", + systemMessage: "test", + config: config.ProviderConfig{ + ID: "test-openai", + APIKey: "test-key", + }, + model: func(config.SelectedModelType) catwalk.Model { + return catwalk.Model{ + ID: "test-model", + Name: "test-model", + } + }, + }, + } + + // Test insufficient_quota error should not retry + apiErr := &openai.Error{ + StatusCode: 429, + Message: "You exceeded your current quota, please check your plan and billing details. For more information on this error, read the docs: https://platform.openai.com/docs/guides/error-codes/api-errors.", + Type: "insufficient_quota", + Code: "insufficient_quota", + } + + retry, _, err := client.shouldRetry(1, apiErr) + if retry { + t.Error("Expected shouldRetry to return false for insufficient_quota error, but got true") + } + if err == nil { + t.Error("Expected shouldRetry to return an error for insufficient_quota, but got nil") + } + if err != nil && !strings.Contains(err.Error(), "quota") { + t.Errorf("Expected error message to mention quota, got: %v", err) + } +} + +func TestOpenAIClient429RateLimitError(t *testing.T) { + client := &openaiClient{ + providerOptions: providerClientOptions{ + modelType: config.SelectedModelTypeLarge, + apiKey: "test-key", + systemMessage: "test", + config: config.ProviderConfig{ + ID: "test-openai", + APIKey: "test-key", + }, + model: func(config.SelectedModelType) catwalk.Model { + return catwalk.Model{ + ID: "test-model", + Name: "test-model", + } + }, + }, + } + + // Test regular rate limit error should retry + apiErr := &openai.Error{ + StatusCode: 429, + Message: "Rate limit reached for requests", + Type: "rate_limit_exceeded", + Code: "rate_limit_exceeded", + } + + retry, _, err := client.shouldRetry(1, apiErr) + if !retry { + t.Error("Expected shouldRetry to return true for rate_limit_exceeded error, but got false") + } + if err != nil { + t.Errorf("Expected shouldRetry to return nil error for rate_limit_exceeded, but got: %v", err) + } +}