From 38df9e139903e40ab79717496a22a4f5ba593b18 Mon Sep 17 00:00:00 2001
From: Daniel Merja
Date: Sun, 6 Jul 2025 20:54:06 -0700
Subject: [PATCH] fix: update Llama provider configuration and enhance error handling

---
 internal/config/provider_mock.go |  42 ++++++--
 internal/llm/provider/llama.go   | 175 +++++++++++++++++++++++++------
 2 files changed, 177 insertions(+), 40 deletions(-)

diff --git a/internal/config/provider_mock.go b/internal/config/provider_mock.go
index e9a236f5b32951c2d79eb4303865faaedcc8b9c3..506522e2923d2a5aa7483bceb57d1a67fd1ae4d6 100644
--- a/internal/config/provider_mock.go
+++ b/internal/config/provider_mock.go
@@ -295,31 +295,55 @@ func MockProviders() []provider.Provider {
 			APIKey:              "$LLAMA_API_KEY",
 			APIEndpoint:         "https://api.llama.com/compat/v1",
 			Type:                provider.TypeLlama,
-			DefaultLargeModelID: "Llama-4-Maverick-17B-128E-Instruct-FP8",
-			DefaultSmallModelID: "Llama-3.3-8B-Instruct",
+			DefaultLargeModelID: "Llama-3.3-70B-Instruct",
+			DefaultSmallModelID: "Llama-4-Scout-17B-16E-Instruct-FP8",
 			Models: []provider.Model{
 				{
 					ID:                 "Llama-4-Maverick-17B-128E-Instruct-FP8",
 					Name:               "Llama 4 Maverick 17B 128E Instruct FP8",
-					CostPer1MIn:        2.0,
-					CostPer1MOut:       8.0,
+					CostPer1MIn:        0.0,
+					CostPer1MOut:       0.0,
 					CostPer1MInCached:  0.0,
 					CostPer1MOutCached: 0.0,
 					ContextWindow:      128000,
 					DefaultMaxTokens:   32000,
-					CanReason:          true,
+					CanReason:          false,
+					SupportsImages:     true,
+				},
+				{
+					ID:                 "Llama-4-Scout-17B-16E-Instruct-FP8",
+					Name:               "Llama 4 Scout 17B 16E Instruct FP8",
+					CostPer1MIn:        0.0,
+					CostPer1MOut:       0.0,
+					CostPer1MInCached:  0.0,
+					CostPer1MOutCached: 0.0,
+					ContextWindow:      128000,
+					DefaultMaxTokens:   16000,
+					CanReason:          false,
 					SupportsImages:     true,
 				},
+				{
+					ID:                 "Llama-3.3-70B-Instruct",
+					Name:               "Llama 3.3 70B Instruct",
+					CostPer1MIn:        0.0,
+					CostPer1MOut:       0.0,
+					CostPer1MInCached:  0.0,
+					CostPer1MOutCached: 0.0,
+					ContextWindow:      128000,
+					DefaultMaxTokens:   8000,
+					CanReason:          false,
+					SupportsImages:     false,
+				},
 				{
 					ID:                 "Llama-3.3-8B-Instruct",
 					Name:               "Llama 3.3 8B Instruct",
-					CostPer1MIn:        0.5,
-					CostPer1MOut:       2.0,
+					CostPer1MIn:        0.0,
+					CostPer1MOut:       0.0,
 					CostPer1MInCached:  0.0,
 					CostPer1MOutCached: 0.0,
 					ContextWindow:      128000,
-					DefaultMaxTokens:   16000,
-					CanReason:          true,
+					DefaultMaxTokens:   8000,
+					CanReason:          false,
 					SupportsImages:     false,
 				},
 			},
diff --git a/internal/llm/provider/llama.go b/internal/llm/provider/llama.go
index f0f13bab494fbcc52134c60633a4de459e8f5e17..c04d795710499a1dc87e35bbe57d49eb7e7ce75b 100644
--- a/internal/llm/provider/llama.go
+++ b/internal/llm/provider/llama.go
@@ -3,8 +3,13 @@ package provider
 import (
 	"context"
 	"encoding/json"
+	"errors"
+	"fmt"
+	"io"
+	"time"
 
 	"github.com/charmbracelet/crush/internal/config"
+	"github.com/charmbracelet/crush/internal/fur/provider"
 	"github.com/charmbracelet/crush/internal/llm/tools"
 	"github.com/charmbracelet/crush/internal/logging"
 	"github.com/charmbracelet/crush/internal/message"
@@ -16,25 +21,35 @@ import (
 type LlamaClient ProviderClient
 
 func newLlamaClient(opts providerClientOptions) LlamaClient {
+	return &llamaClient{
+		providerOptions: opts,
+		client:          createLlamaClient(opts),
+	}
+}
+
+type llamaClient struct {
+	providerOptions providerClientOptions
+	client          openai.Client
+}
+
+func createLlamaClient(opts providerClientOptions) openai.Client {
 	openaiClientOptions := []option.RequestOption{}
 	if opts.apiKey != "" {
 		openaiClientOptions = append(openaiClientOptions, option.WithAPIKey(opts.apiKey))
 	}
-	openaiClientOptions = append(openaiClientOptions, option.WithBaseURL("https://api.llama.com/compat/v1/"))
+
+	baseURL := "https://api.llama.com/compat/v1/"
+	if opts.baseURL != "" {
+		baseURL = opts.baseURL
+	}
+	openaiClientOptions = append(openaiClientOptions, option.WithBaseURL(baseURL))
+
 	if opts.extraHeaders != nil {
 		for key, value := range opts.extraHeaders {
 			openaiClientOptions = append(openaiClientOptions, option.WithHeader(key, value))
 		}
 	}
-	return &llamaClient{
-		providerOptions: opts,
-		client:          openai.NewClient(openaiClientOptions...),
-	}
-}
-
-type llamaClient struct {
-	providerOptions providerClientOptions
-	client          openai.Client
+	return openai.NewClient(openaiClientOptions...)
 }
 
 func (l *llamaClient) send(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error) {
@@ -50,9 +65,24 @@ func (l *llamaClient) send(ctx context.Context, messages []message.Message, tool
 	for {
 		attempts++
 		openaiResponse, err := l.client.Chat.Completions.New(ctx, params)
+		// If there is an error we are going to see if we can retry the call
 		if err != nil {
-			return nil, err
+			retry, after, retryErr := l.shouldRetry(attempts, err)
+			if retryErr != nil {
+				return nil, retryErr
+			}
+			if retry {
+				logging.WarnPersist(fmt.Sprintf("Retrying due to rate limit... attempt %d of %d", attempts, maxRetries), logging.PersistTimeArg, time.Millisecond*time.Duration(after+100))
+				select {
+				case <-ctx.Done():
+					return nil, ctx.Err()
+				case <-time.After(time.Duration(after) * time.Millisecond):
+					continue
+				}
+			}
+			return nil, retryErr
 		}
+
 		content := ""
 		if openaiResponse.Choices[0].Message.Content != "" {
 			content = openaiResponse.Choices[0].Message.Content
@@ -86,48 +116,131 @@ func (l *llamaClient) stream(ctx context.Context, messages []message.Message, to
 	eventChan := make(chan ProviderEvent)
 
 	go func() {
 		attempts := 0
-		acc := openai.ChatCompletionAccumulator{}
-		currentContent := ""
-		toolCalls := make([]message.ToolCall, 0)
 		for {
 			attempts++
 			openaiStream := l.client.Chat.Completions.NewStreaming(ctx, params)
+
+			acc := openai.ChatCompletionAccumulator{}
+			currentContent := ""
+			toolCalls := make([]message.ToolCall, 0)
+
 			for openaiStream.Next() {
 				chunk := openaiStream.Current()
 				acc.AddChunk(chunk)
 				for _, choice := range chunk.Choices {
 					if choice.Delta.Content != "" {
+						eventChan <- ProviderEvent{
+							Type:    EventContentDelta,
+							Content: choice.Delta.Content,
+						}
 						currentContent += choice.Delta.Content
 					}
 				}
-				eventChan <- ProviderEvent{Type: EventContentDelta, Content: currentContent}
 			}
-			if err := openaiStream.Err(); err != nil {
-				eventChan <- ProviderEvent{Type: EventError, Error: err}
+
+			err := openaiStream.Err()
+			if err == nil || errors.Is(err, io.EOF) {
+				if cfg.Options.Debug {
+					jsonData, _ := json.Marshal(acc.ChatCompletion)
+					logging.Debug("Response", "messages", string(jsonData))
+				}
+				resultFinishReason := acc.ChatCompletion.Choices[0].FinishReason
+				if resultFinishReason == "" {
+					// If the finish reason is empty, we assume it was a successful completion
+					resultFinishReason = "stop"
+				}
+				// Stream completed successfully
+				finishReason := l.finishReason(resultFinishReason)
+				if len(acc.Choices[0].Message.ToolCalls) > 0 {
+					toolCalls = append(toolCalls, l.toolCalls(acc.ChatCompletion)...)
+				}
+				if len(toolCalls) > 0 {
+					finishReason = message.FinishReasonToolUse
+				}
+
+				eventChan <- ProviderEvent{
+					Type: EventComplete,
+					Response: &ProviderResponse{
+						Content:      currentContent,
+						ToolCalls:    toolCalls,
+						Usage:        l.usage(acc.ChatCompletion),
+						FinishReason: finishReason,
+					},
+				}
+				close(eventChan)
 				return
 			}
-			toolCalls = l.toolCalls(acc.ChatCompletion)
-			finishReason := l.finishReason(string(acc.ChatCompletion.Choices[0].FinishReason))
-			if len(toolCalls) > 0 {
-				finishReason = message.FinishReasonToolUse
+
+			// If there is an error we are going to see if we can retry the call
+			retry, after, retryErr := l.shouldRetry(attempts, err)
+			if retryErr != nil {
+				eventChan <- ProviderEvent{Type: EventError, Error: retryErr}
+				close(eventChan)
+				return
 			}
-			eventChan <- ProviderEvent{
-				Type: EventComplete,
-				Response: &ProviderResponse{
-					Content:      currentContent,
-					ToolCalls:    toolCalls,
-					Usage:        l.usage(acc.ChatCompletion),
-					FinishReason: finishReason,
-				},
+			if retry {
+				logging.WarnPersist(fmt.Sprintf("Retrying due to rate limit... attempt %d of %d", attempts, maxRetries), logging.PersistTimeArg, time.Millisecond*time.Duration(after+100))
+				select {
+				case <-ctx.Done():
+					// context cancelled
+					if ctx.Err() != nil {
+						eventChan <- ProviderEvent{Type: EventError, Error: ctx.Err()}
+					}
+					close(eventChan)
+					return
+				case <-time.After(time.Duration(after) * time.Millisecond):
+					continue
+				}
 			}
+			eventChan <- ProviderEvent{Type: EventError, Error: retryErr}
+			close(eventChan)
 			return
 		}
 	}()
 
 	return eventChan
 }
 
+func (l *llamaClient) shouldRetry(attempts int, err error) (bool, int64, error) {
+	var apiErr *openai.Error
+	if !errors.As(err, &apiErr) {
+		return false, 0, err
+	}
+
+	if attempts > maxRetries {
+		return false, 0, fmt.Errorf("maximum retry attempts reached for rate limit: %d retries", maxRetries)
+	}
+
+	// Check for token expiration (401 Unauthorized)
+	if apiErr.StatusCode == 401 {
+		var err error
+		l.providerOptions.apiKey, err = config.ResolveAPIKey(l.providerOptions.config.APIKey)
+		if err != nil {
+			return false, 0, fmt.Errorf("failed to resolve API key: %w", err)
+		}
+		l.client = createLlamaClient(l.providerOptions)
+		return true, 0, nil
+	}
+
+	if apiErr.StatusCode != 429 && apiErr.StatusCode != 500 {
+		return false, 0, err
+	}
+
+	retryMs := 0
+	retryAfterValues := apiErr.Response.Header.Values("Retry-After")
+
+	backoffMs := 2000 * (1 << (attempts - 1))
+	jitterMs := int(float64(backoffMs) * 0.2)
+	retryMs = backoffMs + jitterMs
+	if len(retryAfterValues) > 0 {
+		if _, err := fmt.Sscanf(retryAfterValues[0], "%d", &retryMs); err == nil {
+			retryMs = retryMs * 1000
+		}
+	}
+	return true, int64(retryMs), nil
+}
+
 func (l *llamaClient) convertMessages(messages []message.Message) (openaiMessages []openai.ChatCompletionMessageParamUnion) {
-	// Copied from openaiClient
+	// Add system message first
 	openaiMessages = append(openaiMessages, openai.SystemMessage(l.providerOptions.systemMessage))
 	for _, msg := range messages {
 		switch msg.Role {
@@ -136,7 +249,7 @@ func (l *llamaClient) convertMessages(messages []message.Message) (openaiMessage
 			textBlock := openai.ChatCompletionContentPartTextParam{Text: msg.Content().String()}
 			content = append(content, openai.ChatCompletionContentPartUnionParam{OfText: &textBlock})
 			for _, binaryContent := range msg.BinaryContent() {
-				imageURL := openai.ChatCompletionContentPartImageImageURLParam{URL: binaryContent.String("llama")}
+				imageURL := openai.ChatCompletionContentPartImageImageURLParam{URL: binaryContent.String(provider.InferenceProviderLlama)}
 				imageBlock := openai.ChatCompletionContentPartImageParam{ImageURL: imageURL}
 				content = append(content, openai.ChatCompletionContentPartUnionParam{OfImageURL: &imageBlock})
 			}
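
Reviewer note (not part of the patch): the new shouldRetry derives the wait time from exponential backoff (2s base, doubled per attempt) plus a flat 20% jitter, and lets a Retry-After header, read as whole seconds, override that value. The standalone Go sketch below reproduces only that arithmetic so the schedule can be eyeballed; retryDelay and the sample values are illustrative assumptions, not part of the Crush codebase.

    package main

    import (
        "fmt"
        "time"
    )

    // retryDelay mirrors the wait-time math from the patch's shouldRetry in
    // isolation: exponential backoff (2000ms base, doubled per attempt) plus
    // a flat 20% jitter, overridden by a Retry-After header value (interpreted
    // as seconds) when the server provides one. Illustrative sketch only.
    func retryDelay(attempt int, retryAfterHeader string) time.Duration {
        backoffMs := 2000 * (1 << (attempt - 1))
        retryMs := backoffMs + int(float64(backoffMs)*0.2)

        if retryAfterHeader != "" {
            var seconds int
            if _, err := fmt.Sscanf(retryAfterHeader, "%d", &seconds); err == nil {
                retryMs = seconds * 1000
            }
        }
        return time.Duration(retryMs) * time.Millisecond
    }

    func main() {
        fmt.Println(retryDelay(1, ""))  // 2.4s
        fmt.Println(retryDelay(3, ""))  // 9.6s
        fmt.Println(retryDelay(2, "7")) // 7s (header wins over the 4.8s backoff)
    }

Under this schedule the first retries wait roughly 2.4s, 4.8s, and 9.6s before the maxRetries guard in shouldRetry gives up.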