fix: update Llama provider configuration and enhance error handling

Daniel Merja created

Change summary

internal/config/provider_mock.go |  42 ++++++-
internal/llm/provider/llama.go   | 175 +++++++++++++++++++++++++++------
2 files changed, 177 insertions(+), 40 deletions(-)

Detailed changes

internal/config/provider_mock.go 🔗

@@ -295,31 +295,55 @@ func MockProviders() []provider.Provider {
 			APIKey:              "$LLAMA_API_KEY",
 			APIEndpoint:         "https://api.llama.com/compat/v1",
 			Type:                provider.TypeLlama,
-			DefaultLargeModelID: "Llama-4-Maverick-17B-128E-Instruct-FP8",
-			DefaultSmallModelID: "Llama-3.3-8B-Instruct",
+			DefaultLargeModelID: "Llama-3.3-70B-Instruct",
+			DefaultSmallModelID: "Llama-4-Scout-17B-16E-Instruct-FP8",
 			Models: []provider.Model{
 				{
 					ID:                 "Llama-4-Maverick-17B-128E-Instruct-FP8",
 					Name:               "Llama 4 Maverick 17B 128E Instruct FP8",
-					CostPer1MIn:        2.0,
-					CostPer1MOut:       8.0,
+					CostPer1MIn:        0.0,
+					CostPer1MOut:       0.0,
 					CostPer1MInCached:  0.0,
 					CostPer1MOutCached: 0.0,
 					ContextWindow:      128000,
 					DefaultMaxTokens:   32000,
-					CanReason:          true,
+					CanReason:          false,
+					SupportsImages:     true,
+				},
+				{
+					ID:                 "Llama-4-Scout-17B-16E-Instruct-FP8",
+					Name:               "Llama 4 Scout 17B 16E Instruct FP8",
+					CostPer1MIn:        0.0,
+					CostPer1MOut:       0.0,
+					CostPer1MInCached:  0.0,
+					CostPer1MOutCached: 0.0,
+					ContextWindow:      128000,
+					DefaultMaxTokens:   16000,
+					CanReason:          false,
 					SupportsImages:     true,
 				},
+				{
+					ID:                 "Llama-3.3-70B-Instruct",
+					Name:               "Llama 3.3 70B Instruct",
+					CostPer1MIn:        0.0,
+					CostPer1MOut:       0.0,
+					CostPer1MInCached:  0.0,
+					CostPer1MOutCached: 0.0,
+					ContextWindow:      128000,
+					DefaultMaxTokens:   8000,
+					CanReason:          false,
+					SupportsImages:     false,
+				},
 				{
 					ID:                 "Llama-3.3-8B-Instruct",
 					Name:               "Llama 3.3 8B Instruct",
-					CostPer1MIn:        0.5,
-					CostPer1MOut:       2.0,
+					CostPer1MIn:        0.0,
+					CostPer1MOut:       0.0,
 					CostPer1MInCached:  0.0,
 					CostPer1MOutCached: 0.0,
 					ContextWindow:      128000,
-					DefaultMaxTokens:   16000,
-					CanReason:          true,
+					DefaultMaxTokens:   8000,
+					CanReason:          false,
 					SupportsImages:     false,
 				},
 			},

internal/llm/provider/llama.go 🔗

@@ -3,8 +3,13 @@ package provider
 import (
 	"context"
 	"encoding/json"
+	"errors"
+	"fmt"
+	"io"
+	"time"
 
 	"github.com/charmbracelet/crush/internal/config"
+	"github.com/charmbracelet/crush/internal/fur/provider"
 	"github.com/charmbracelet/crush/internal/llm/tools"
 	"github.com/charmbracelet/crush/internal/logging"
 	"github.com/charmbracelet/crush/internal/message"
@@ -16,25 +21,35 @@ import (
 type LlamaClient ProviderClient
 
 func newLlamaClient(opts providerClientOptions) LlamaClient {
+	return &llamaClient{
+		providerOptions: opts,
+		client:          createLlamaClient(opts),
+	}
+}
+
+type llamaClient struct {
+	providerOptions providerClientOptions
+	client          openai.Client
+}
+
+func createLlamaClient(opts providerClientOptions) openai.Client {
 	openaiClientOptions := []option.RequestOption{}
 	if opts.apiKey != "" {
 		openaiClientOptions = append(openaiClientOptions, option.WithAPIKey(opts.apiKey))
 	}
-	openaiClientOptions = append(openaiClientOptions, option.WithBaseURL("https://api.llama.com/compat/v1/"))
+
+	baseURL := "https://api.llama.com/compat/v1/"
+	if opts.baseURL != "" {
+		baseURL = opts.baseURL
+	}
+	openaiClientOptions = append(openaiClientOptions, option.WithBaseURL(baseURL))
+
 	if opts.extraHeaders != nil {
 		for key, value := range opts.extraHeaders {
 			openaiClientOptions = append(openaiClientOptions, option.WithHeader(key, value))
 		}
 	}
-	return &llamaClient{
-		providerOptions: opts,
-		client:          openai.NewClient(openaiClientOptions...),
-	}
-}
-
-type llamaClient struct {
-	providerOptions providerClientOptions
-	client          openai.Client
+	return openai.NewClient(openaiClientOptions...)
 }
 
 func (l *llamaClient) send(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error) {
@@ -50,9 +65,24 @@ func (l *llamaClient) send(ctx context.Context, messages []message.Message, tool
 	for {
 		attempts++
 		openaiResponse, err := l.client.Chat.Completions.New(ctx, params)
+		// If there is an error we are going to see if we can retry the call
 		if err != nil {
-			return nil, err
+			retry, after, retryErr := l.shouldRetry(attempts, err)
+			if retryErr != nil {
+				return nil, retryErr
+			}
+			if retry {
+				logging.WarnPersist(fmt.Sprintf("Retrying due to rate limit... attempt %d of %d", attempts, maxRetries), logging.PersistTimeArg, time.Millisecond*time.Duration(after+100))
+				select {
+				case <-ctx.Done():
+					return nil, ctx.Err()
+				case <-time.After(time.Duration(after) * time.Millisecond):
+					continue
+				}
+			}
+			return nil, retryErr
 		}
+
 		content := ""
 		if openaiResponse.Choices[0].Message.Content != "" {
 			content = openaiResponse.Choices[0].Message.Content
@@ -86,48 +116,131 @@ func (l *llamaClient) stream(ctx context.Context, messages []message.Message, to
 	eventChan := make(chan ProviderEvent)
 	go func() {
 		attempts := 0
-		acc := openai.ChatCompletionAccumulator{}
-		currentContent := ""
-		toolCalls := make([]message.ToolCall, 0)
 		for {
 			attempts++
 			openaiStream := l.client.Chat.Completions.NewStreaming(ctx, params)
+
+			acc := openai.ChatCompletionAccumulator{}
+			currentContent := ""
+			toolCalls := make([]message.ToolCall, 0)
+
 			for openaiStream.Next() {
 				chunk := openaiStream.Current()
 				acc.AddChunk(chunk)
 				for _, choice := range chunk.Choices {
 					if choice.Delta.Content != "" {
+						eventChan <- ProviderEvent{
+							Type:    EventContentDelta,
+							Content: choice.Delta.Content,
+						}
 						currentContent += choice.Delta.Content
 					}
 				}
-				eventChan <- ProviderEvent{Type: EventContentDelta, Content: currentContent}
 			}
-			if err := openaiStream.Err(); err != nil {
-				eventChan <- ProviderEvent{Type: EventError, Error: err}
+
+			err := openaiStream.Err()
+			if err == nil || errors.Is(err, io.EOF) {
+				if cfg.Options.Debug {
+					jsonData, _ := json.Marshal(acc.ChatCompletion)
+					logging.Debug("Response", "messages", string(jsonData))
+				}
+				resultFinishReason := acc.ChatCompletion.Choices[0].FinishReason
+				if resultFinishReason == "" {
+					// If the finish reason is empty, we assume it was a successful completion
+					resultFinishReason = "stop"
+				}
+				// Stream completed successfully
+				finishReason := l.finishReason(resultFinishReason)
+				if len(acc.Choices[0].Message.ToolCalls) > 0 {
+					toolCalls = append(toolCalls, l.toolCalls(acc.ChatCompletion)...)
+				}
+				if len(toolCalls) > 0 {
+					finishReason = message.FinishReasonToolUse
+				}
+
+				eventChan <- ProviderEvent{
+					Type: EventComplete,
+					Response: &ProviderResponse{
+						Content:      currentContent,
+						ToolCalls:    toolCalls,
+						Usage:        l.usage(acc.ChatCompletion),
+						FinishReason: finishReason,
+					},
+				}
+				close(eventChan)
 				return
 			}
-			toolCalls = l.toolCalls(acc.ChatCompletion)
-			finishReason := l.finishReason(string(acc.ChatCompletion.Choices[0].FinishReason))
-			if len(toolCalls) > 0 {
-				finishReason = message.FinishReasonToolUse
+
+			// If there is an error we are going to see if we can retry the call
+			retry, after, retryErr := l.shouldRetry(attempts, err)
+			if retryErr != nil {
+				eventChan <- ProviderEvent{Type: EventError, Error: retryErr}
+				close(eventChan)
+				return
 			}
-			eventChan <- ProviderEvent{
-				Type: EventComplete,
-				Response: &ProviderResponse{
-					Content:      currentContent,
-					ToolCalls:    toolCalls,
-					Usage:        l.usage(acc.ChatCompletion),
-					FinishReason: finishReason,
-				},
+			if retry {
+				logging.WarnPersist(fmt.Sprintf("Retrying due to rate limit... attempt %d of %d", attempts, maxRetries), logging.PersistTimeArg, time.Millisecond*time.Duration(after+100))
+				select {
+				case <-ctx.Done():
+					// context cancelled
+					if ctx.Err() != nil {
+						eventChan <- ProviderEvent{Type: EventError, Error: ctx.Err()}
+					}
+					close(eventChan)
+					return
+				case <-time.After(time.Duration(after) * time.Millisecond):
+					continue
+				}
 			}
+			eventChan <- ProviderEvent{Type: EventError, Error: retryErr}
+			close(eventChan)
 			return
 		}
 	}()
 	return eventChan
 }
 
+func (l *llamaClient) shouldRetry(attempts int, err error) (bool, int64, error) {
+	var apiErr *openai.Error
+	if !errors.As(err, &apiErr) {
+		return false, 0, err
+	}
+
+	if attempts > maxRetries {
+		return false, 0, fmt.Errorf("maximum retry attempts reached for rate limit: %d retries", maxRetries)
+	}
+
+	// Check for token expiration (401 Unauthorized)
+	if apiErr.StatusCode == 401 {
+		var err error
+		l.providerOptions.apiKey, err = config.ResolveAPIKey(l.providerOptions.config.APIKey)
+		if err != nil {
+			return false, 0, fmt.Errorf("failed to resolve API key: %w", err)
+		}
+		l.client = createLlamaClient(l.providerOptions)
+		return true, 0, nil
+	}
+
+	if apiErr.StatusCode != 429 && apiErr.StatusCode != 500 {
+		return false, 0, err
+	}
+
+	retryMs := 0
+	retryAfterValues := apiErr.Response.Header.Values("Retry-After")
+
+	backoffMs := 2000 * (1 << (attempts - 1))
+	jitterMs := int(float64(backoffMs) * 0.2)
+	retryMs = backoffMs + jitterMs
+	if len(retryAfterValues) > 0 {
+		if _, err := fmt.Sscanf(retryAfterValues[0], "%d", &retryMs); err == nil {
+			retryMs = retryMs * 1000
+		}
+	}
+	return true, int64(retryMs), nil
+}
+
 func (l *llamaClient) convertMessages(messages []message.Message) (openaiMessages []openai.ChatCompletionMessageParamUnion) {
-	// Copied from openaiClient
+	// Add system message first
 	openaiMessages = append(openaiMessages, openai.SystemMessage(l.providerOptions.systemMessage))
 	for _, msg := range messages {
 		switch msg.Role {
@@ -136,7 +249,7 @@ func (l *llamaClient) convertMessages(messages []message.Message) (openaiMessage
 			textBlock := openai.ChatCompletionContentPartTextParam{Text: msg.Content().String()}
 			content = append(content, openai.ChatCompletionContentPartUnionParam{OfText: &textBlock})
 			for _, binaryContent := range msg.BinaryContent() {
-				imageURL := openai.ChatCompletionContentPartImageImageURLParam{URL: binaryContent.String("llama")}
+				imageURL := openai.ChatCompletionContentPartImageImageURLParam{URL: binaryContent.String(provider.InferenceProviderLlama)}
 				imageBlock := openai.ChatCompletionContentPartImageParam{ImageURL: imageURL}
 				content = append(content, openai.ChatCompletionContentPartUnionParam{OfImageURL: &imageBlock})
 			}