Detailed changes
@@ -1,4 +1,4 @@
-> [!WARNING]
+> [!WARNING]
> 🚧 This is a pre-release under heavy, active development. Things are still in flux but we're excited to share early progress.
# Crush
@@ -44,6 +44,7 @@ providers.
| `AZURE_OPENAI_ENDPOINT` | Azure OpenAI models |
| `AZURE_OPENAI_API_KEY` | Azure OpenAI models (optional when using Entra ID) |
| `AZURE_OPENAI_API_VERSION` | Azure OpenAI models |
+| `LLAMA_API_KEY` | Llama API |
## License
@@ -498,7 +498,14 @@
"mistralai/mistral-7b-instruct-v0.1",
"openai/gpt-3.5-turbo-16k",
"openai/gpt-4",
- "openai/gpt-4-0314"
+ "openai/gpt-4-0314",
+ "Llama-4-Maverick-17B-128E-Instruct-FP8",
+ "Llama-4-Scout-17B-128E-Instruct-FP8",
+ "Llama-3.3-70B-Instruct",
+ "Llama-3.3-8B-Instruct",
+ "Cerebras-Llama-4-Maverick-17B-128E-Instruct",
+ "Cerebras-Llama-4-Scout-17B-16E-Instruct",
+ "Groq-Llama-4-Maverick-17B-128E-Instruct"
],
"title": "Model ID",
"description": "ID of the preferred model"
@@ -513,7 +520,8 @@
"bedrock",
"vertex",
"xai",
- "openrouter"
+ "openrouter",
+ "llama"
],
"title": "Provider",
"description": "Provider for the preferred model"
@@ -569,7 +577,8 @@
"bedrock",
"vertex",
"xai",
- "openrouter"
+ "openrouter",
+ "llama"
],
"title": "Provider ID",
"description": "Unique identifier for the provider"
@@ -685,6 +685,12 @@ func providerDefaultConfig(providerID provider.InferenceProvider) ProviderConfig
ID: providerID,
ProviderType: provider.TypeVertexAI,
}
+ case provider.InferenceProviderLlama:
+ return ProviderConfig{
+ ID: providerID,
+ ProviderType: provider.TypeLlama,
+ BaseURL: "https://api.llama.com/compat/v1",
+ }
default:
return ProviderConfig{
ID: providerID,
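
To pin down what the new branch produces, here is a minimal test-style sketch of the expected default for the Llama provider, based only on the case shown above. The provider package import path is an assumption; adjust it to wherever `provider.InferenceProviderLlama` actually lives in the repo.

```go
package config

import (
	"testing"

	// Assumed import path for the provider ID/type package used throughout this file.
	"github.com/charmbracelet/crush/internal/fur/provider"
)

// Sketch: the default config the new llama case is expected to return.
func TestProviderDefaultConfig_Llama(t *testing.T) {
	cfg := providerDefaultConfig(provider.InferenceProviderLlama)
	if cfg.ProviderType != provider.TypeLlama {
		t.Fatalf("want llama provider type, got %q", cfg.ProviderType)
	}
	if cfg.BaseURL != "https://api.llama.com/compat/v1" {
		t.Fatalf("unexpected base URL: %q", cfg.BaseURL)
	}
}
```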
@@ -1061,6 +1067,7 @@ func (c *Config) validateProviders(errors *ValidationErrors) {
provider.TypeBedrock,
provider.TypeVertexAI,
provider.TypeXAI,
+ provider.TypeLlama,
}
for providerID, providerConfig := range c.Providers {
@@ -21,6 +21,7 @@ func reset() {
"GEMINI_API_KEY",
"XAI_API_KEY",
"OPENROUTER_API_KEY",
+ "LLAMA_API_KEY",
// Google Cloud / VertexAI
"GOOGLE_GENAI_USE_VERTEXAI",
@@ -405,11 +406,12 @@ func TestEnvVars_AllSupportedAPIKeys(t *testing.T) {
os.Setenv("GEMINI_API_KEY", "test-gemini-key")
os.Setenv("XAI_API_KEY", "test-xai-key")
os.Setenv("OPENROUTER_API_KEY", "test-openrouter-key")
+ os.Setenv("LLAMA_API_KEY", "test-llama-key")
cfg, err := Init(cwdDir, false)
require.NoError(t, err)
- assert.Len(t, cfg.Providers, 5)
+ assert.Len(t, cfg.Providers, 6)
anthropicProvider := cfg.Providers[provider.InferenceProviderAnthropic]
assert.Equal(t, "test-anthropic-key", anthropicProvider.APIKey)
@@ -431,6 +433,11 @@ func TestEnvVars_AllSupportedAPIKeys(t *testing.T) {
assert.Equal(t, "test-openrouter-key", openrouterProvider.APIKey)
assert.Equal(t, provider.TypeOpenAI, openrouterProvider.ProviderType)
assert.Equal(t, "https://openrouter.ai/api/v1", openrouterProvider.BaseURL)
+
+ llamaProvider := cfg.Providers[provider.InferenceProviderLlama]
+ assert.Equal(t, "test-llama-key", llamaProvider.APIKey)
+ assert.Equal(t, provider.TypeLlama, llamaProvider.ProviderType)
+ assert.Equal(t, "https://api.llama.com/compat/v1", llamaProvider.BaseURL)
}
func TestEnvVars_PartialEnvironmentVariables(t *testing.T) {
@@ -289,5 +289,40 @@ func MockProviders() []provider.Provider {
},
},
},
+ {
+ Name: "Llama API",
+ ID: provider.InferenceProviderLlama,
+ APIKey: "$LLAMA_API_KEY",
+ APIEndpoint: "https://api.llama.com/compat/v1",
+ Type: provider.TypeLlama,
+ DefaultLargeModelID: "Llama-4-Maverick-17B-128E-Instruct-FP8",
+ DefaultSmallModelID: "Llama-3.3-8B-Instruct",
+ Models: []provider.Model{
+ {
+ ID: "Llama-4-Maverick-17B-128E-Instruct-FP8",
+ Name: "Llama 4 Maverick 17B 128E Instruct FP8",
+ CostPer1MIn: 2.0,
+ CostPer1MOut: 8.0,
+ CostPer1MInCached: 0.0,
+ CostPer1MOutCached: 0.0,
+ ContextWindow: 128000,
+ DefaultMaxTokens: 32000,
+ CanReason: true,
+ SupportsImages: true,
+ },
+ {
+ ID: "Llama-3.3-8B-Instruct",
+ Name: "Llama 3.3 8B Instruct",
+ CostPer1MIn: 0.5,
+ CostPer1MOut: 2.0,
+ CostPer1MInCached: 0.0,
+ CostPer1MOutCached: 0.0,
+ ContextWindow: 128000,
+ DefaultMaxTokens: 16000,
+ CanReason: true,
+ SupportsImages: false,
+ },
+ },
+ },
}
}
@@ -13,6 +13,7 @@ const (
TypeBedrock Type = "bedrock"
TypeVertexAI Type = "vertexai"
TypeXAI Type = "xai"
+ TypeLlama Type = "llama"
)
// InferenceProvider represents the inference provider identifier.
@@ -28,6 +29,7 @@ const (
InferenceProviderVertexAI InferenceProvider = "vertexai"
InferenceProviderXAI InferenceProvider = "xai"
InferenceProviderOpenRouter InferenceProvider = "openrouter"
+ InferenceProviderLlama InferenceProvider = "llama"
)
// Provider represents an AI provider configuration.
@@ -69,5 +71,6 @@ func KnownProviders() []InferenceProvider {
InferenceProviderVertexAI,
InferenceProviderXAI,
InferenceProviderOpenRouter,
+ InferenceProviderLlama,
}
}
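
A companion check for the registration above: a short test sketch (same package as `KnownProviders`; it uses nothing beyond identifiers shown in the diff) confirming the new ID is reported.

```go
package provider

import "testing"

// Sketch: guard against the llama entry being dropped from the known-provider list.
func TestKnownProviders_IncludesLlama(t *testing.T) {
	for _, p := range KnownProviders() {
		if p == InferenceProviderLlama {
			return
		}
	}
	t.Fatal("llama missing from KnownProviders()")
}
```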
@@ -0,0 +1,279 @@
+package provider
+
+import (
+ "context"
+ "encoding/json"
+
+ "github.com/charmbracelet/crush/internal/config"
+ "github.com/charmbracelet/crush/internal/llm/tools"
+ "github.com/charmbracelet/crush/internal/logging"
+ "github.com/charmbracelet/crush/internal/message"
+ "github.com/openai/openai-go"
+ "github.com/openai/openai-go/option"
+ "github.com/openai/openai-go/shared"
+)
+
+type LlamaClient ProviderClient
+
+func newLlamaClient(opts providerClientOptions) LlamaClient {
+ openaiClientOptions := []option.RequestOption{}
+ if opts.apiKey != "" {
+ openaiClientOptions = append(openaiClientOptions, option.WithAPIKey(opts.apiKey))
+ }
+ openaiClientOptions = append(openaiClientOptions, option.WithBaseURL("https://api.llama.com/compat/v1/"))
+ if opts.extraHeaders != nil {
+ for key, value := range opts.extraHeaders {
+ openaiClientOptions = append(openaiClientOptions, option.WithHeader(key, value))
+ }
+ }
+ return &llamaClient{
+ providerOptions: opts,
+ client: openai.NewClient(openaiClientOptions...),
+ }
+}
+
+type llamaClient struct {
+ providerOptions providerClientOptions
+ client openai.Client
+}
+
+func (l *llamaClient) send(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error) {
+ openaiMessages := l.convertMessages(messages)
+ openaiTools := l.convertTools(tools)
+ params := l.preparedParams(openaiMessages, openaiTools)
+ cfg := config.Get()
+ if cfg.Options.Debug {
+ jsonData, _ := json.Marshal(params)
+ logging.Debug("Prepared messages", "messages", string(jsonData))
+ }
+ attempts := 0
+ for {
+ attempts++
+ openaiResponse, err := l.client.Chat.Completions.New(ctx, params)
+ if err != nil {
+ return nil, err
+ }
+ content := ""
+ if openaiResponse.Choices[0].Message.Content != "" {
+ content = openaiResponse.Choices[0].Message.Content
+ }
+ toolCalls := l.toolCalls(*openaiResponse)
+ finishReason := l.finishReason(string(openaiResponse.Choices[0].FinishReason))
+ if len(toolCalls) > 0 {
+ finishReason = message.FinishReasonToolUse
+ }
+ return &ProviderResponse{
+ Content: content,
+ ToolCalls: toolCalls,
+ Usage: l.usage(*openaiResponse),
+ FinishReason: finishReason,
+ }, nil
+ }
+}
+
+func (l *llamaClient) stream(ctx context.Context, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent {
+ openaiMessages := l.convertMessages(messages)
+ openaiTools := l.convertTools(tools)
+ params := l.preparedParams(openaiMessages, openaiTools)
+ params.StreamOptions = openai.ChatCompletionStreamOptionsParam{
+ IncludeUsage: openai.Bool(true),
+ }
+ cfg := config.Get()
+ if cfg.Options.Debug {
+ jsonData, _ := json.Marshal(params)
+ logging.Debug("Prepared messages", "messages", string(jsonData))
+ }
+ eventChan := make(chan ProviderEvent)
+ go func() {
+ attempts := 0
+ acc := openai.ChatCompletionAccumulator{}
+ currentContent := ""
+ toolCalls := make([]message.ToolCall, 0)
+ for {
+ attempts++
+ openaiStream := l.client.Chat.Completions.NewStreaming(ctx, params)
+ for openaiStream.Next() {
+ chunk := openaiStream.Current()
+ acc.AddChunk(chunk)
+ for _, choice := range chunk.Choices {
+ if choice.Delta.Content != "" {
+ currentContent += choice.Delta.Content
+ // Send only the new delta; emitting the accumulated text would duplicate content downstream.
+ eventChan <- ProviderEvent{Type: EventContentDelta, Content: choice.Delta.Content}
+ }
+ }
+ }
+ if err := openaiStream.Err(); err != nil {
+ eventChan <- ProviderEvent{Type: EventError, Error: err}
+ return
+ }
+ toolCalls = l.toolCalls(acc.ChatCompletion)
+ finishReason := l.finishReason(string(acc.ChatCompletion.Choices[0].FinishReason))
+ if len(toolCalls) > 0 {
+ finishReason = message.FinishReasonToolUse
+ }
+ eventChan <- ProviderEvent{
+ Type: EventComplete,
+ Response: &ProviderResponse{
+ Content: currentContent,
+ ToolCalls: toolCalls,
+ Usage: l.usage(acc.ChatCompletion),
+ FinishReason: finishReason,
+ },
+ }
+ return
+ }
+ }()
+ return eventChan
+}
+
+func (l *llamaClient) convertMessages(messages []message.Message) (openaiMessages []openai.ChatCompletionMessageParamUnion) {
+ // Copied from openaiClient
+ openaiMessages = append(openaiMessages, openai.SystemMessage(l.providerOptions.systemMessage))
+ for _, msg := range messages {
+ switch msg.Role {
+ case message.User:
+ var content []openai.ChatCompletionContentPartUnionParam
+ textBlock := openai.ChatCompletionContentPartTextParam{Text: msg.Content().String()}
+ content = append(content, openai.ChatCompletionContentPartUnionParam{OfText: &textBlock})
+ for _, binaryContent := range msg.BinaryContent() {
+ imageURL := openai.ChatCompletionContentPartImageImageURLParam{URL: binaryContent.String("llama")}
+ imageBlock := openai.ChatCompletionContentPartImageParam{ImageURL: imageURL}
+ content = append(content, openai.ChatCompletionContentPartUnionParam{OfImageURL: &imageBlock})
+ }
+ openaiMessages = append(openaiMessages, openai.UserMessage(content))
+ case message.Assistant:
+ assistantMsg := openai.ChatCompletionAssistantMessageParam{
+ Role: "assistant",
+ }
+ if msg.Content().String() != "" {
+ assistantMsg.Content = openai.ChatCompletionAssistantMessageParamContentUnion{
+ OfString: openai.String(msg.Content().String()),
+ }
+ }
+ if len(msg.ToolCalls()) > 0 {
+ assistantMsg.ToolCalls = make([]openai.ChatCompletionMessageToolCallParam, len(msg.ToolCalls()))
+ for i, call := range msg.ToolCalls() {
+ assistantMsg.ToolCalls[i] = openai.ChatCompletionMessageToolCallParam{
+ ID: call.ID,
+ Type: "function",
+ Function: openai.ChatCompletionMessageToolCallFunctionParam{
+ Name: call.Name,
+ Arguments: call.Input,
+ },
+ }
+ }
+ }
+ openaiMessages = append(openaiMessages, openai.ChatCompletionMessageParamUnion{OfAssistant: &assistantMsg})
+ case message.Tool:
+ for _, result := range msg.ToolResults() {
+ openaiMessages = append(openaiMessages, openai.ToolMessage(result.Content, result.ToolCallID))
+ }
+ }
+ }
+ return
+}
+
+func (l *llamaClient) convertTools(tools []tools.BaseTool) []openai.ChatCompletionToolParam {
+ openaiTools := make([]openai.ChatCompletionToolParam, len(tools))
+ for i, tool := range tools {
+ info := tool.Info()
+ openaiTools[i] = openai.ChatCompletionToolParam{
+ Function: openai.FunctionDefinitionParam{
+ Name: info.Name,
+ Description: openai.String(info.Description),
+ Parameters: openai.FunctionParameters{
+ "type": "object",
+ "properties": info.Parameters,
+ "required": info.Required,
+ },
+ },
+ }
+ }
+ return openaiTools
+}
+
+func (l *llamaClient) preparedParams(messages []openai.ChatCompletionMessageParamUnion, tools []openai.ChatCompletionToolParam) openai.ChatCompletionNewParams {
+ model := l.providerOptions.model(l.providerOptions.modelType)
+ cfg := config.Get()
+ modelConfig := cfg.Models.Large
+ if l.providerOptions.modelType == config.SmallModel {
+ modelConfig = cfg.Models.Small
+ }
+ reasoningEffort := model.ReasoningEffort
+ if modelConfig.ReasoningEffort != "" {
+ reasoningEffort = modelConfig.ReasoningEffort
+ }
+ params := openai.ChatCompletionNewParams{
+ Model: openai.ChatModel(model.ID),
+ Messages: messages,
+ Tools: tools,
+ }
+ maxTokens := model.DefaultMaxTokens
+ if modelConfig.MaxTokens > 0 {
+ maxTokens = modelConfig.MaxTokens
+ }
+ if l.providerOptions.maxTokens > 0 {
+ maxTokens = l.providerOptions.maxTokens
+ }
+ if model.CanReason {
+ params.MaxCompletionTokens = openai.Int(maxTokens)
+ switch reasoningEffort {
+ case "low":
+ params.ReasoningEffort = shared.ReasoningEffortLow
+ case "medium":
+ params.ReasoningEffort = shared.ReasoningEffortMedium
+ case "high":
+ params.ReasoningEffort = shared.ReasoningEffortHigh
+ default:
+ params.ReasoningEffort = shared.ReasoningEffort(reasoningEffort)
+ }
+ } else {
+ params.MaxTokens = openai.Int(maxTokens)
+ }
+ return params
+}
+
+func (l *llamaClient) toolCalls(completion openai.ChatCompletion) []message.ToolCall {
+ var toolCalls []message.ToolCall
+ if len(completion.Choices) > 0 && len(completion.Choices[0].Message.ToolCalls) > 0 {
+ for _, call := range completion.Choices[0].Message.ToolCalls {
+ toolCall := message.ToolCall{
+ ID: call.ID,
+ Name: call.Function.Name,
+ Input: call.Function.Arguments,
+ Type: "function",
+ Finished: true,
+ }
+ toolCalls = append(toolCalls, toolCall)
+ }
+ }
+ return toolCalls
+}
+
+func (l *llamaClient) finishReason(reason string) message.FinishReason {
+ switch reason {
+ case "stop":
+ return message.FinishReasonEndTurn
+ case "length":
+ return message.FinishReasonMaxTokens
+ case "tool_calls":
+ return message.FinishReasonToolUse
+ default:
+ return message.FinishReasonUnknown
+ }
+}
+
+func (l *llamaClient) usage(completion openai.ChatCompletion) TokenUsage {
+ cachedTokens := completion.Usage.PromptTokensDetails.CachedTokens
+ inputTokens := completion.Usage.PromptTokens - cachedTokens
+ return TokenUsage{
+ InputTokens: inputTokens,
+ OutputTokens: completion.Usage.CompletionTokens,
+ CacheCreationTokens: 0, // OpenAI doesn't provide this directly
+ CacheReadTokens: cachedTokens,
+ }
+}
+
+func (l *llamaClient) Model() config.Model {
+ return l.providerOptions.model(l.providerOptions.modelType)
+}
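
Since `stream` reports progress only through the returned channel, callers are expected to drain it until a terminal event arrives. A minimal consumption sketch (it would sit in the same package; the helper name is illustrative):

```go
// Sketch: collect a streamed Llama response. Event and field names match
// llamaClient.stream above; drainLlamaStream itself is only for illustration.
func drainLlamaStream(events <-chan ProviderEvent) (*ProviderResponse, error) {
	var text string
	for ev := range events {
		switch ev.Type {
		case EventContentDelta:
			text += ev.Content // incremental text as it arrives
		case EventError:
			return nil, ev.Error
		case EventComplete:
			return ev.Response, nil // full content, tool calls, usage, finish reason
		}
	}
	// Channel closed without a terminal event; return whatever text was seen.
	return &ProviderResponse{Content: text}, nil
}
```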
@@ -189,6 +189,11 @@ func NewProvider(cfg config.ProviderConfig, opts ...ProviderClientOption) (Provi
options: clientOptions,
client: newOpenAIClient(clientOptions),
}, nil
+ case provider.TypeLlama:
+ return &baseProvider[LlamaClient]{
+ options: clientOptions,
+ client: newLlamaClient(clientOptions),
+ }, nil
}
return nil, fmt.Errorf("provider not supported: %s", cfg.ProviderType)
}
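
End to end, the new case means a Llama provider is constructed like any other OpenAI-compatible one. A rough usage sketch, assuming `NewProvider` returns `(Provider, error)` as the truncated hunk header suggests; the `internal/llm/provider` import path is also an assumption, while `internal/config` is taken from llama.go's imports.

```go
package main

import (
	"fmt"
	"os"

	"github.com/charmbracelet/crush/internal/config" // path confirmed by llama.go's imports
	// Assumed location of the package that defines NewProvider and the llama client:
	llm "github.com/charmbracelet/crush/internal/llm/provider"
)

func main() {
	cfg := config.ProviderConfig{
		ID:           "llama", // provider.InferenceProviderLlama
		ProviderType: "llama", // provider.TypeLlama
		APIKey:       os.Getenv("LLAMA_API_KEY"),
		BaseURL:      "https://api.llama.com/compat/v1",
	}
	p, err := llm.NewProvider(cfg)
	if err != nil {
		fmt.Fprintln(os.Stderr, "llama provider unavailable:", err)
		os.Exit(1)
	}
	_ = p // send or stream messages through p exactly as with the existing OpenAI provider
}
```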