@@ -295,31 +295,55 @@ func MockProviders() []provider.Provider {
APIKey: "$LLAMA_API_KEY",
APIEndpoint: "https://api.llama.com/compat/v1",
Type: provider.TypeLlama,
- DefaultLargeModelID: "Llama-4-Maverick-17B-128E-Instruct-FP8",
- DefaultSmallModelID: "Llama-3.3-8B-Instruct",
+ DefaultLargeModelID: "Llama-3.3-70B-Instruct",
+ DefaultSmallModelID: "Llama-4-Scout-17B-16E-Instruct-FP8",
Models: []provider.Model{
{
ID: "Llama-4-Maverick-17B-128E-Instruct-FP8",
Name: "Llama 4 Maverick 17B 128E Instruct FP8",
- CostPer1MIn: 2.0,
- CostPer1MOut: 8.0,
+ CostPer1MIn: 0.0,
+ CostPer1MOut: 0.0,
CostPer1MInCached: 0.0,
CostPer1MOutCached: 0.0,
ContextWindow: 128000,
DefaultMaxTokens: 32000,
- CanReason: true,
+ CanReason: false,
+ SupportsImages: true,
+ },
+ {
+ ID: "Llama-4-Scout-17B-16E-Instruct-FP8",
+ Name: "Llama 4 Scout 17B 16E Instruct FP8",
+ CostPer1MIn: 0.0,
+ CostPer1MOut: 0.0,
+ CostPer1MInCached: 0.0,
+ CostPer1MOutCached: 0.0,
+ ContextWindow: 128000,
+ DefaultMaxTokens: 16000,
+ CanReason: false,
SupportsImages: true,
},
+ {
+ ID: "Llama-3.3-70B-Instruct",
+ Name: "Llama 3.3 70B Instruct",
+ CostPer1MIn: 0.0,
+ CostPer1MOut: 0.0,
+ CostPer1MInCached: 0.0,
+ CostPer1MOutCached: 0.0,
+ ContextWindow: 128000,
+ DefaultMaxTokens: 8000,
+ CanReason: false,
+ SupportsImages: false,
+ },
{
ID: "Llama-3.3-8B-Instruct",
Name: "Llama 3.3 8B Instruct",
- CostPer1MIn: 0.5,
- CostPer1MOut: 2.0,
+ CostPer1MIn: 0.0,
+ CostPer1MOut: 0.0,
CostPer1MInCached: 0.0,
CostPer1MOutCached: 0.0,
ContextWindow: 128000,
- DefaultMaxTokens: 16000,
- CanReason: true,
+ DefaultMaxTokens: 8000,
+ CanReason: false,
SupportsImages: false,
},
},
@@ -3,8 +3,13 @@ package provider
import (
"context"
"encoding/json"
+ "errors"
+ "fmt"
+ "io"
+ "time"
+
"github.com/charmbracelet/crush/internal/config"
+ "github.com/charmbracelet/crush/internal/fur/provider"
"github.com/charmbracelet/crush/internal/llm/tools"
"github.com/charmbracelet/crush/internal/logging"
"github.com/charmbracelet/crush/internal/message"
@@ -16,25 +21,35 @@ import (
type LlamaClient ProviderClient
func newLlamaClient(opts providerClientOptions) LlamaClient {
+ return &llamaClient{
+ providerOptions: opts,
+ client: createLlamaClient(opts),
+ }
+}
+
+type llamaClient struct {
+ providerOptions providerClientOptions
+ client openai.Client
+}
+
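+// createLlamaClient builds the OpenAI-compatible client used to talk to the
+// Llama endpoint; shouldRetry also calls it to rebuild the client after an
+// API-key refresh.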
+func createLlamaClient(opts providerClientOptions) openai.Client {
openaiClientOptions := []option.RequestOption{}
if opts.apiKey != "" {
openaiClientOptions = append(openaiClientOptions, option.WithAPIKey(opts.apiKey))
}
- openaiClientOptions = append(openaiClientOptions, option.WithBaseURL("https://api.llama.com/compat/v1/"))
+
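+	// Allow a configured base URL to override the default Llama endpoint.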
+ baseURL := "https://api.llama.com/compat/v1/"
+ if opts.baseURL != "" {
+ baseURL = opts.baseURL
+ }
+ openaiClientOptions = append(openaiClientOptions, option.WithBaseURL(baseURL))
+
if opts.extraHeaders != nil {
for key, value := range opts.extraHeaders {
openaiClientOptions = append(openaiClientOptions, option.WithHeader(key, value))
}
}
- return &llamaClient{
- providerOptions: opts,
- client: openai.NewClient(openaiClientOptions...),
- }
-}
-
-type llamaClient struct {
- providerOptions providerClientOptions
- client openai.Client
+ return openai.NewClient(openaiClientOptions...)
}
func (l *llamaClient) send(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error) {
@@ -50,9 +65,24 @@ func (l *llamaClient) send(ctx context.Context, messages []message.Message, tool
for {
attempts++
openaiResponse, err := l.client.Chat.Completions.New(ctx, params)
+		// If the request failed, check whether the call can be retried.
if err != nil {
- return nil, err
+ retry, after, retryErr := l.shouldRetry(attempts, err)
+ if retryErr != nil {
+ return nil, retryErr
+ }
+ if retry {
+ logging.WarnPersist(fmt.Sprintf("Retrying due to rate limit... attempt %d of %d", attempts, maxRetries), logging.PersistTimeArg, time.Millisecond*time.Duration(after+100))
+ select {
+ case <-ctx.Done():
+ return nil, ctx.Err()
+ case <-time.After(time.Duration(after) * time.Millisecond):
+ continue
+ }
+ }
+			return nil, err
}
+
content := ""
if openaiResponse.Choices[0].Message.Content != "" {
content = openaiResponse.Choices[0].Message.Content
@@ -86,48 +116,131 @@ func (l *llamaClient) stream(ctx context.Context, messages []message.Message, to
eventChan := make(chan ProviderEvent)
go func() {
attempts := 0
- acc := openai.ChatCompletionAccumulator{}
- currentContent := ""
- toolCalls := make([]message.ToolCall, 0)
for {
attempts++
openaiStream := l.client.Chat.Completions.NewStreaming(ctx, params)
+
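+			// Reset accumulated state on each attempt so a retried stream starts clean.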
+ acc := openai.ChatCompletionAccumulator{}
+ currentContent := ""
+ toolCalls := make([]message.ToolCall, 0)
+
for openaiStream.Next() {
chunk := openaiStream.Current()
acc.AddChunk(chunk)
for _, choice := range chunk.Choices {
if choice.Delta.Content != "" {
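+						// Emit each delta as it arrives rather than re-sending the accumulated content.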
+ eventChan <- ProviderEvent{
+ Type: EventContentDelta,
+ Content: choice.Delta.Content,
+ }
currentContent += choice.Delta.Content
}
}
- eventChan <- ProviderEvent{Type: EventContentDelta, Content: currentContent}
}
- if err := openaiStream.Err(); err != nil {
- eventChan <- ProviderEvent{Type: EventError, Error: err}
+
+ err := openaiStream.Err()
+ if err == nil || errors.Is(err, io.EOF) {
+ if cfg.Options.Debug {
+ jsonData, _ := json.Marshal(acc.ChatCompletion)
+ logging.Debug("Response", "messages", string(jsonData))
+ }
+				// Stream completed successfully.
+				resultFinishReason := ""
+				if len(acc.ChatCompletion.Choices) > 0 {
+					resultFinishReason = acc.ChatCompletion.Choices[0].FinishReason
+				}
+				if resultFinishReason == "" {
+					// An empty finish reason is treated as a normal stop.
+					resultFinishReason = "stop"
+				}
+				finishReason := l.finishReason(resultFinishReason)
+				if len(acc.ChatCompletion.Choices) > 0 && len(acc.ChatCompletion.Choices[0].Message.ToolCalls) > 0 {
+					toolCalls = append(toolCalls, l.toolCalls(acc.ChatCompletion)...)
+				}
+ if len(toolCalls) > 0 {
+ finishReason = message.FinishReasonToolUse
+ }
+
+ eventChan <- ProviderEvent{
+ Type: EventComplete,
+ Response: &ProviderResponse{
+ Content: currentContent,
+ ToolCalls: toolCalls,
+ Usage: l.usage(acc.ChatCompletion),
+ FinishReason: finishReason,
+ },
+ }
+ close(eventChan)
return
}
- toolCalls = l.toolCalls(acc.ChatCompletion)
- finishReason := l.finishReason(string(acc.ChatCompletion.Choices[0].FinishReason))
- if len(toolCalls) > 0 {
- finishReason = message.FinishReasonToolUse
+
+			// The stream ended with an error; check whether the request can be retried.
+ retry, after, retryErr := l.shouldRetry(attempts, err)
+ if retryErr != nil {
+ eventChan <- ProviderEvent{Type: EventError, Error: retryErr}
+ close(eventChan)
+ return
}
- eventChan <- ProviderEvent{
- Type: EventComplete,
- Response: &ProviderResponse{
- Content: currentContent,
- ToolCalls: toolCalls,
- Usage: l.usage(acc.ChatCompletion),
- FinishReason: finishReason,
- },
+ if retry {
+ logging.WarnPersist(fmt.Sprintf("Retrying due to rate limit... attempt %d of %d", attempts, maxRetries), logging.PersistTimeArg, time.Millisecond*time.Duration(after+100))
+ select {
+ case <-ctx.Done():
+					// Context was cancelled; report it and stop retrying.
+ if ctx.Err() != nil {
+ eventChan <- ProviderEvent{Type: EventError, Error: ctx.Err()}
+ }
+ close(eventChan)
+ return
+ case <-time.After(time.Duration(after) * time.Millisecond):
+ continue
+ }
}
+			eventChan <- ProviderEvent{Type: EventError, Error: err}
+ close(eventChan)
return
}
}()
return eventChan
}
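+// shouldRetry reports whether a failed request should be retried and, if so,
+// how many milliseconds to wait. A 401 refreshes the API key and rebuilds the
+// client; 429 and 500 responses back off exponentially, honoring Retry-After.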
+func (l *llamaClient) shouldRetry(attempts int, err error) (bool, int64, error) {
+ var apiErr *openai.Error
+ if !errors.As(err, &apiErr) {
+ return false, 0, err
+ }
+
+ if attempts > maxRetries {
+ return false, 0, fmt.Errorf("maximum retry attempts reached for rate limit: %d retries", maxRetries)
+ }
+
+ // Check for token expiration (401 Unauthorized)
+ if apiErr.StatusCode == 401 {
+ var err error
+ l.providerOptions.apiKey, err = config.ResolveAPIKey(l.providerOptions.config.APIKey)
+ if err != nil {
+ return false, 0, fmt.Errorf("failed to resolve API key: %w", err)
+ }
+ l.client = createLlamaClient(l.providerOptions)
+ return true, 0, nil
+ }
+
+ if apiErr.StatusCode != 429 && apiErr.StatusCode != 500 {
+ return false, 0, err
+ }
+
+ retryMs := 0
+ retryAfterValues := apiErr.Response.Header.Values("Retry-After")
+
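+	// Default to exponential backoff (2s base, 20% jitter); a parseable
+	// Retry-After header overrides it below.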
+ backoffMs := 2000 * (1 << (attempts - 1))
+ jitterMs := int(float64(backoffMs) * 0.2)
+ retryMs = backoffMs + jitterMs
+ if len(retryAfterValues) > 0 {
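+		// The header value is parsed as seconds and converted to milliseconds.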
+ if _, err := fmt.Sscanf(retryAfterValues[0], "%d", &retryMs); err == nil {
+ retryMs = retryMs * 1000
+ }
+ }
+ return true, int64(retryMs), nil
+}
+
func (l *llamaClient) convertMessages(messages []message.Message) (openaiMessages []openai.ChatCompletionMessageParamUnion) {
- // Copied from openaiClient
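+	// The system message always leads the conversation.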
+ // Add system message first
openaiMessages = append(openaiMessages, openai.SystemMessage(l.providerOptions.systemMessage))
for _, msg := range messages {
switch msg.Role {
@@ -136,7 +249,7 @@ func (l *llamaClient) convertMessages(messages []message.Message) (openaiMessage
textBlock := openai.ChatCompletionContentPartTextParam{Text: msg.Content().String()}
content = append(content, openai.ChatCompletionContentPartUnionParam{OfText: &textBlock})
for _, binaryContent := range msg.BinaryContent() {
- imageURL := openai.ChatCompletionContentPartImageImageURLParam{URL: binaryContent.String("llama")}
+ imageURL := openai.ChatCompletionContentPartImageImageURLParam{URL: binaryContent.String(provider.InferenceProviderLlama)}
imageBlock := openai.ChatCompletionContentPartImageParam{ImageURL: imageURL}
content = append(content, openai.ChatCompletionContentPartUnionParam{OfImageURL: &imageBlock})
}