feat: add Llama API provider support

danielmerja created

Change summary

README.md                         |   3 
crush-schema.json                 |  15 +
internal/config/config.go         |   7 
internal/config/config_test.go    |   9 
internal/config/provider_mock.go  |  35 ++++
internal/fur/provider/provider.go |   3 
internal/llm/provider/llama.go    | 279 +++++++++++++++++++++++++++++++++
internal/llm/provider/provider.go |   5 
8 files changed, 351 insertions(+), 5 deletions(-)

Detailed changes

README.md πŸ”—

@@ -1,4 +1,4 @@
-> [!WARNING]
+> [!WARNING]
 > 🚧 This is a pre-release under heavy, active development. Things are still in flux but we’re excited to share early progress.
 
 # Crush
@@ -44,6 +44,7 @@ providers.
 | `AZURE_OPENAI_ENDPOINT`    | Azure OpenAI models                                |
 | `AZURE_OPENAI_API_KEY`     | Azure OpenAI models (optional when using Entra ID) |
 | `AZURE_OPENAI_API_VERSION` | Azure OpenAI models                                |
+| `LLAMA_API_KEY`            | Llama API                                          |
 
 ## License
 

crush-schema.json πŸ”—

@@ -498,7 +498,14 @@
             "mistralai/mistral-7b-instruct-v0.1",
             "openai/gpt-3.5-turbo-16k",
             "openai/gpt-4",
-            "openai/gpt-4-0314"
+            "openai/gpt-4-0314",
+            "Llama-4-Maverick-17B-128E-Instruct-FP8",
+            "Llama-4-Scout-17B-128E-Instruct-FP8",
+            "Llama-3.3-70B-Instruct",
+            "Llama-3.3-8B-Instruct",
+            "Cerebras-Llama-4-Maverick-17B-128E-Instruct",
+            "Cerebras-Llama-4-Scout-17B-16E-Instruct",
+            "Groq-Llama-4-Maverick-17B-128E-Instruct"
           ],
           "title": "Model ID",
           "description": "ID of the preferred model"
@@ -513,7 +520,8 @@
             "bedrock",
             "vertex",
             "xai",
-            "openrouter"
+            "openrouter",
+            "llama"
           ],
           "title": "Provider",
           "description": "Provider for the preferred model"
@@ -569,7 +577,8 @@
             "bedrock",
             "vertex",
             "xai",
-            "openrouter"
+            "openrouter",
+            "llama"
           ],
           "title": "Provider ID",
           "description": "Unique identifier for the provider"

internal/config/config.go πŸ”—

@@ -685,6 +685,12 @@ func providerDefaultConfig(providerID provider.InferenceProvider) ProviderConfig
 			ID:           providerID,
 			ProviderType: provider.TypeVertexAI,
 		}
+	case provider.InferenceProviderLlama:
+		return ProviderConfig{
+			ID:           providerID,
+			ProviderType: provider.TypeLlama,
+			BaseURL:      "https://api.llama.com/compat/v1",
+		}
 	default:
 		return ProviderConfig{
 			ID:           providerID,
@@ -1061,6 +1067,7 @@ func (c *Config) validateProviders(errors *ValidationErrors) {
 		provider.TypeBedrock,
 		provider.TypeVertexAI,
 		provider.TypeXAI,
+		provider.TypeLlama,
 	}
 
 	for providerID, providerConfig := range c.Providers {

internal/config/config_test.go πŸ”—

@@ -21,6 +21,7 @@ func reset() {
 		"GEMINI_API_KEY",
 		"XAI_API_KEY",
 		"OPENROUTER_API_KEY",
+		"LLAMA_API_KEY",
 
 		// Google Cloud / VertexAI
 		"GOOGLE_GENAI_USE_VERTEXAI",
@@ -405,11 +406,12 @@ func TestEnvVars_AllSupportedAPIKeys(t *testing.T) {
 	os.Setenv("GEMINI_API_KEY", "test-gemini-key")
 	os.Setenv("XAI_API_KEY", "test-xai-key")
 	os.Setenv("OPENROUTER_API_KEY", "test-openrouter-key")
+	os.Setenv("LLAMA_API_KEY", "test-llama-key")
 
 	cfg, err := Init(cwdDir, false)
 
 	require.NoError(t, err)
-	assert.Len(t, cfg.Providers, 5)
+	assert.Len(t, cfg.Providers, 6)
 
 	anthropicProvider := cfg.Providers[provider.InferenceProviderAnthropic]
 	assert.Equal(t, "test-anthropic-key", anthropicProvider.APIKey)
@@ -431,6 +433,11 @@ func TestEnvVars_AllSupportedAPIKeys(t *testing.T) {
 	assert.Equal(t, "test-openrouter-key", openrouterProvider.APIKey)
 	assert.Equal(t, provider.TypeOpenAI, openrouterProvider.ProviderType)
 	assert.Equal(t, "https://openrouter.ai/api/v1", openrouterProvider.BaseURL)
+
+	llamaProvider := cfg.Providers[provider.InferenceProviderLlama]
+	assert.Equal(t, "test-llama-key", llamaProvider.APIKey)
+	assert.Equal(t, provider.TypeLlama, llamaProvider.ProviderType)
+	assert.Equal(t, "https://api.llama.com/compat/v1", llamaProvider.BaseURL)
 }
 
 func TestEnvVars_PartialEnvironmentVariables(t *testing.T) {
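
The end-to-end test above goes through `Init`; the new default can also be checked directly against `providerDefaultConfig`. A minimal companion sketch, not part of this change, using only identifiers that appear elsewhere in this diff:

```go
// Hypothetical test (not included in this PR); it only exercises code shown in this diff.
func TestProviderDefaultConfig_Llama(t *testing.T) {
	cfg := providerDefaultConfig(provider.InferenceProviderLlama)

	assert.Equal(t, provider.InferenceProviderLlama, cfg.ID)
	assert.Equal(t, provider.TypeLlama, cfg.ProviderType)
	assert.Equal(t, "https://api.llama.com/compat/v1", cfg.BaseURL)
}
```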

internal/config/provider_mock.go πŸ”—

@@ -289,5 +289,40 @@ func MockProviders() []provider.Provider {
 				},
 			},
 		},
+		{
+			Name:                "Llama API",
+			ID:                  provider.InferenceProviderLlama,
+			APIKey:              "$LLAMA_API_KEY",
+			APIEndpoint:         "https://api.llama.com/compat/v1",
+			Type:                provider.TypeLlama,
+			DefaultLargeModelID: "Llama-4-Maverick-17B-128E-Instruct-FP8",
+			DefaultSmallModelID: "Llama-3.3-8B-Instruct",
+			Models: []provider.Model{
+				{
+					ID:                 "Llama-4-Maverick-17B-128E-Instruct-FP8",
+					Name:               "Llama 4 Maverick 17B 128E Instruct FP8",
+					CostPer1MIn:        2.0,
+					CostPer1MOut:       8.0,
+					CostPer1MInCached:  0.0,
+					CostPer1MOutCached: 0.0,
+					ContextWindow:      128000,
+					DefaultMaxTokens:   32000,
+					CanReason:          true,
+					SupportsImages:     true,
+				},
+				{
+					ID:                 "Llama-3.3-8B-Instruct",
+					Name:               "Llama 3.3 8B Instruct",
+					CostPer1MIn:        0.5,
+					CostPer1MOut:       2.0,
+					CostPer1MInCached:  0.0,
+					CostPer1MOutCached: 0.0,
+					ContextWindow:      128000,
+					DefaultMaxTokens:   16000,
+					CanReason:          true,
+					SupportsImages:     false,
+				},
+			},
+		},
 	}
 }
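
The mock entry above is what tests and offline tooling see for the new provider. A quick sanity check of the advertised defaults might look like this sketch; only `MockProviders` and the fields shown above are taken from this diff, the wrapper function is illustrative:

```go
// Sketch: print the default model IDs advertised by the Llama mock entry.
func printLlamaMockDefaults() {
	for _, p := range MockProviders() {
		if p.ID != provider.InferenceProviderLlama {
			continue
		}
		fmt.Println(p.DefaultLargeModelID) // Llama-4-Maverick-17B-128E-Instruct-FP8
		fmt.Println(p.DefaultSmallModelID) // Llama-3.3-8B-Instruct
	}
}
```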

internal/fur/provider/provider.go πŸ”—

@@ -13,6 +13,7 @@ const (
 	TypeBedrock   Type = "bedrock"
 	TypeVertexAI  Type = "vertexai"
 	TypeXAI       Type = "xai"
+	TypeLlama     Type = "llama"
 )
 
 // InferenceProvider represents the inference provider identifier.
@@ -28,6 +29,7 @@ const (
 	InferenceProviderVertexAI   InferenceProvider = "vertexai"
 	InferenceProviderXAI        InferenceProvider = "xai"
 	InferenceProviderOpenRouter InferenceProvider = "openrouter"
+	InferenceProviderLlama      InferenceProvider = "llama"
 )
 
 // Provider represents an AI provider configuration.
@@ -69,5 +71,6 @@ func KnownProviders() []InferenceProvider {
 		InferenceProviderVertexAI,
 		InferenceProviderXAI,
 		InferenceProviderOpenRouter,
+		InferenceProviderLlama,
 	}
 }
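
Callers that validate provider identifiers via `KnownProviders` pick up the new entry automatically. A minimal membership-check sketch; only `KnownProviders` and `InferenceProviderLlama` come from this file, the helper itself is hypothetical:

```go
// Sketch: report whether an identifier is in the known-provider list.
func isKnownProvider(id InferenceProvider) bool {
	for _, known := range KnownProviders() {
		if known == id {
			return true
		}
	}
	return false
}

// isKnownProvider(InferenceProviderLlama) now returns true.
```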

internal/llm/provider/llama.go πŸ”—

@@ -0,0 +1,279 @@
+package provider
+
+import (
+	"context"
+	"encoding/json"
+
+	"github.com/charmbracelet/crush/internal/config"
+	"github.com/charmbracelet/crush/internal/llm/tools"
+	"github.com/charmbracelet/crush/internal/logging"
+	"github.com/charmbracelet/crush/internal/message"
+	"github.com/openai/openai-go"
+	"github.com/openai/openai-go/option"
+	"github.com/openai/openai-go/shared"
+)
+
+type LlamaClient ProviderClient
+
+func newLlamaClient(opts providerClientOptions) LlamaClient {
+	openaiClientOptions := []option.RequestOption{}
+	if opts.apiKey != "" {
+		openaiClientOptions = append(openaiClientOptions, option.WithAPIKey(opts.apiKey))
+	}
+	openaiClientOptions = append(openaiClientOptions, option.WithBaseURL("https://api.llama.com/compat/v1/"))
+	if opts.extraHeaders != nil {
+		for key, value := range opts.extraHeaders {
+			openaiClientOptions = append(openaiClientOptions, option.WithHeader(key, value))
+		}
+	}
+	return &llamaClient{
+		providerOptions: opts,
+		client:          openai.NewClient(openaiClientOptions...),
+	}
+}
+
+type llamaClient struct {
+	providerOptions providerClientOptions
+	client          openai.Client
+}
+
+func (l *llamaClient) send(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error) {
+	openaiMessages := l.convertMessages(messages)
+	openaiTools := l.convertTools(tools)
+	params := l.preparedParams(openaiMessages, openaiTools)
+	cfg := config.Get()
+	if cfg.Options.Debug {
+		jsonData, _ := json.Marshal(params)
+		logging.Debug("Prepared messages", "messages", string(jsonData))
+	}
+	attempts := 0
+	for {
+		attempts++
+		openaiResponse, err := l.client.Chat.Completions.New(ctx, params)
+		if err != nil {
+			return nil, err
+		}
+		content := ""
+		if openaiResponse.Choices[0].Message.Content != "" {
+			content = openaiResponse.Choices[0].Message.Content
+		}
+		toolCalls := l.toolCalls(*openaiResponse)
+		finishReason := l.finishReason(string(openaiResponse.Choices[0].FinishReason))
+		if len(toolCalls) > 0 {
+			finishReason = message.FinishReasonToolUse
+		}
+		return &ProviderResponse{
+			Content:      content,
+			ToolCalls:    toolCalls,
+			Usage:        l.usage(*openaiResponse),
+			FinishReason: finishReason,
+		}, nil
+	}
+}
+
+func (l *llamaClient) stream(ctx context.Context, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent {
+	openaiMessages := l.convertMessages(messages)
+	openaiTools := l.convertTools(tools)
+	params := l.preparedParams(openaiMessages, openaiTools)
+	params.StreamOptions = openai.ChatCompletionStreamOptionsParam{
+		IncludeUsage: openai.Bool(true),
+	}
+	cfg := config.Get()
+	if cfg.Options.Debug {
+		jsonData, _ := json.Marshal(params)
+		logging.Debug("Prepared messages", "messages", string(jsonData))
+	}
+	eventChan := make(chan ProviderEvent)
+	go func() {
+		attempts := 0
+		acc := openai.ChatCompletionAccumulator{}
+		currentContent := ""
+		toolCalls := make([]message.ToolCall, 0)
+		for {
+			attempts++
+			openaiStream := l.client.Chat.Completions.NewStreaming(ctx, params)
+			for openaiStream.Next() {
+				chunk := openaiStream.Current()
+				acc.AddChunk(chunk)
+				for _, choice := range chunk.Choices {
+					if choice.Delta.Content != "" {
+						currentContent += choice.Delta.Content
+						// Emit only the new text from this chunk; sending the full
+						// accumulated content would duplicate earlier deltas downstream.
+						eventChan <- ProviderEvent{Type: EventContentDelta, Content: choice.Delta.Content}
+					}
+				}
+			}
+			if err := openaiStream.Err(); err != nil {
+				eventChan <- ProviderEvent{Type: EventError, Error: err}
+				return
+			}
+			toolCalls = l.toolCalls(acc.ChatCompletion)
+			finishReason := message.FinishReasonUnknown
+			if len(acc.ChatCompletion.Choices) > 0 {
+				finishReason = l.finishReason(string(acc.ChatCompletion.Choices[0].FinishReason))
+			}
+			if len(toolCalls) > 0 {
+				finishReason = message.FinishReasonToolUse
+			}
+			eventChan <- ProviderEvent{
+				Type: EventComplete,
+				Response: &ProviderResponse{
+					Content:      currentContent,
+					ToolCalls:    toolCalls,
+					Usage:        l.usage(acc.ChatCompletion),
+					FinishReason: finishReason,
+				},
+			}
+			return
+		}
+	}()
+	return eventChan
+}
+
+func (l *llamaClient) convertMessages(messages []message.Message) (openaiMessages []openai.ChatCompletionMessageParamUnion) {
+	// Copied from openaiClient
+	openaiMessages = append(openaiMessages, openai.SystemMessage(l.providerOptions.systemMessage))
+	for _, msg := range messages {
+		switch msg.Role {
+		case message.User:
+			var content []openai.ChatCompletionContentPartUnionParam
+			textBlock := openai.ChatCompletionContentPartTextParam{Text: msg.Content().String()}
+			content = append(content, openai.ChatCompletionContentPartUnionParam{OfText: &textBlock})
+			for _, binaryContent := range msg.BinaryContent() {
+				imageURL := openai.ChatCompletionContentPartImageImageURLParam{URL: binaryContent.String("llama")}
+				imageBlock := openai.ChatCompletionContentPartImageParam{ImageURL: imageURL}
+				content = append(content, openai.ChatCompletionContentPartUnionParam{OfImageURL: &imageBlock})
+			}
+			openaiMessages = append(openaiMessages, openai.UserMessage(content))
+		case message.Assistant:
+			assistantMsg := openai.ChatCompletionAssistantMessageParam{
+				Role: "assistant",
+			}
+			if msg.Content().String() != "" {
+				assistantMsg.Content = openai.ChatCompletionAssistantMessageParamContentUnion{
+					OfString: openai.String(msg.Content().String()),
+				}
+			}
+			if len(msg.ToolCalls()) > 0 {
+				assistantMsg.ToolCalls = make([]openai.ChatCompletionMessageToolCallParam, len(msg.ToolCalls()))
+				for i, call := range msg.ToolCalls() {
+					assistantMsg.ToolCalls[i] = openai.ChatCompletionMessageToolCallParam{
+						ID:   call.ID,
+						Type: "function",
+						Function: openai.ChatCompletionMessageToolCallFunctionParam{
+							Name:      call.Name,
+							Arguments: call.Input,
+						},
+					}
+				}
+			}
+			openaiMessages = append(openaiMessages, openai.ChatCompletionMessageParamUnion{OfAssistant: &assistantMsg})
+		case message.Tool:
+			for _, result := range msg.ToolResults() {
+				openaiMessages = append(openaiMessages, openai.ToolMessage(result.Content, result.ToolCallID))
+			}
+		}
+	}
+	return
+}
+
+func (l *llamaClient) convertTools(tools []tools.BaseTool) []openai.ChatCompletionToolParam {
+	openaiTools := make([]openai.ChatCompletionToolParam, len(tools))
+	for i, tool := range tools {
+		info := tool.Info()
+		openaiTools[i] = openai.ChatCompletionToolParam{
+			Function: openai.FunctionDefinitionParam{
+				Name:        info.Name,
+				Description: openai.String(info.Description),
+				Parameters: openai.FunctionParameters{
+					"type":       "object",
+					"properties": info.Parameters,
+					"required":   info.Required,
+				},
+			},
+		}
+	}
+	return openaiTools
+}
+
+func (l *llamaClient) preparedParams(messages []openai.ChatCompletionMessageParamUnion, tools []openai.ChatCompletionToolParam) openai.ChatCompletionNewParams {
+	model := l.providerOptions.model(l.providerOptions.modelType)
+	cfg := config.Get()
+	modelConfig := cfg.Models.Large
+	if l.providerOptions.modelType == config.SmallModel {
+		modelConfig = cfg.Models.Small
+	}
+	reasoningEffort := model.ReasoningEffort
+	if modelConfig.ReasoningEffort != "" {
+		reasoningEffort = modelConfig.ReasoningEffort
+	}
+	params := openai.ChatCompletionNewParams{
+		Model:    openai.ChatModel(model.ID),
+		Messages: messages,
+		Tools:    tools,
+	}
+	maxTokens := model.DefaultMaxTokens
+	if modelConfig.MaxTokens > 0 {
+		maxTokens = modelConfig.MaxTokens
+	}
+	if l.providerOptions.maxTokens > 0 {
+		maxTokens = l.providerOptions.maxTokens
+	}
+	if model.CanReason {
+		params.MaxCompletionTokens = openai.Int(maxTokens)
+		switch reasoningEffort {
+		case "low":
+			params.ReasoningEffort = shared.ReasoningEffortLow
+		case "medium":
+			params.ReasoningEffort = shared.ReasoningEffortMedium
+		case "high":
+			params.ReasoningEffort = shared.ReasoningEffortHigh
+		default:
+			params.ReasoningEffort = shared.ReasoningEffort(reasoningEffort)
+		}
+	} else {
+		params.MaxTokens = openai.Int(maxTokens)
+	}
+	return params
+}
+
+func (l *llamaClient) toolCalls(completion openai.ChatCompletion) []message.ToolCall {
+	var toolCalls []message.ToolCall
+	if len(completion.Choices) > 0 && len(completion.Choices[0].Message.ToolCalls) > 0 {
+		for _, call := range completion.Choices[0].Message.ToolCalls {
+			toolCall := message.ToolCall{
+				ID:       call.ID,
+				Name:     call.Function.Name,
+				Input:    call.Function.Arguments,
+				Type:     "function",
+				Finished: true,
+			}
+			toolCalls = append(toolCalls, toolCall)
+		}
+	}
+	return toolCalls
+}
+
+func (l *llamaClient) finishReason(reason string) message.FinishReason {
+	switch reason {
+	case "stop":
+		return message.FinishReasonEndTurn
+	case "length":
+		return message.FinishReasonMaxTokens
+	case "tool_calls":
+		return message.FinishReasonToolUse
+	default:
+		return message.FinishReasonUnknown
+	}
+}
+
+func (l *llamaClient) usage(completion openai.ChatCompletion) TokenUsage {
+	cachedTokens := completion.Usage.PromptTokensDetails.CachedTokens
+	inputTokens := completion.Usage.PromptTokens - cachedTokens
+	return TokenUsage{
+		InputTokens:         inputTokens,
+		OutputTokens:        completion.Usage.CompletionTokens,
+		CacheCreationTokens: 0, // OpenAI doesn't provide this directly
+		CacheReadTokens:     cachedTokens,
+	}
+}
+
+func (l *llamaClient) Model() config.Model {
+	return l.providerOptions.model(l.providerOptions.modelType)
+}
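
Since the client reuses the OpenAI-compatible event flow, streaming is consumed the same way as for the existing providers. A hedged sketch of draining the event channel from inside this package; the event names and `ProviderEvent` fields come from the code above, while the wrapper function and its arguments are illustrative:

```go
// Illustrative only: collect streamed text until the client reports completion or an error.
func collectLlamaStream(ctx context.Context, l *llamaClient, msgs []message.Message) (string, error) {
	content := ""
	for event := range l.stream(ctx, msgs, nil) {
		switch event.Type {
		case EventContentDelta:
			content += event.Content
		case EventError:
			return content, event.Error
		case EventComplete:
			return event.Response.Content, nil
		}
	}
	return content, nil
}
```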

internal/llm/provider/provider.go πŸ”—

@@ -189,6 +189,11 @@ func NewProvider(cfg config.ProviderConfig, opts ...ProviderClientOption) (Provi
 			options: clientOptions,
 			client:  newOpenAIClient(clientOptions),
 		}, nil
+	case provider.TypeLlama:
+		return &baseProvider[LlamaClient]{
+			options: clientOptions,
+			client:  newLlamaClient(clientOptions),
+		}, nil
 	}
 	return nil, fmt.Errorf("provider not supported: %s", cfg.ProviderType)
 }
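
With the factory case in place, a Llama-backed provider is constructed like any other OpenAI-compatible one. A sketch of the wiring: the `ProviderConfig` fields appear earlier in this diff, the `Provider` return type mirrors `NewProvider`'s signature, and everything after construction follows the existing call sites rather than anything new in this PR:

```go
// Sketch: build a provider for the Llama API via the new factory case.
func newLlamaProviderFromEnv() (Provider, error) {
	return NewProvider(config.ProviderConfig{
		ID:           provider.InferenceProviderLlama,
		ProviderType: provider.TypeLlama,
		APIKey:       os.Getenv("LLAMA_API_KEY"),
		BaseURL:      "https://api.llama.com/compat/v1",
	})
}
```

The returned value is then used exactly like the existing OpenAI-backed providers.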