diff --git a/README.md b/README.md
index b4e1a814071f5189a74f800e9d21cefcc03f4e2a..91c0df07539e98433c31bb00900cb44fe46f8c55 100644
--- a/README.md
+++ b/README.md
@@ -44,6 +44,7 @@ providers.
 | `AZURE_OPENAI_ENDPOINT` | Azure OpenAI models |
 | `AZURE_OPENAI_API_KEY` | Azure OpenAI models (optional when using Entra ID) |
 | `AZURE_OPENAI_API_VERSION` | Azure OpenAI models |
+| `LLAMA_API_KEY` | Llama API |
 
 ## License
 
diff --git a/crush-schema.json b/crush-schema.json
index ea356c0e585b8a243ee1110d68264c0f2301752f..0800aad162c6d9539ab370cac56e3ce772319f16 100644
--- a/crush-schema.json
+++ b/crush-schema.json
@@ -498,7 +498,14 @@
         "mistralai/mistral-7b-instruct-v0.1",
         "openai/gpt-3.5-turbo-16k",
         "openai/gpt-4",
-        "openai/gpt-4-0314"
+        "openai/gpt-4-0314",
+        "Llama-4-Maverick-17B-128E-Instruct-FP8",
+        "Llama-4-Scout-17B-128E-Instruct-FP8",
+        "Llama-3.3-70B-Instruct",
+        "Llama-3.3-8B-Instruct",
+        "Cerebras-Llama-4-Maverick-17B-128E-Instruct",
+        "Cerebras-Llama-4-Scout-17B-16E-Instruct",
+        "Groq-Llama-4-Maverick-17B-128E-Instruct"
       ],
       "title": "Model ID",
       "description": "ID of the preferred model"
@@ -513,7 +520,8 @@
         "bedrock",
         "vertex",
         "xai",
-        "openrouter"
+        "openrouter",
+        "llama"
       ],
       "title": "Provider",
       "description": "Provider for the preferred model"
@@ -569,7 +577,8 @@
         "bedrock",
         "vertex",
         "xai",
-        "openrouter"
+        "openrouter",
+        "llama"
       ],
       "title": "Provider ID",
       "description": "Unique identifier for the provider"
diff --git a/internal/config/config.go b/internal/config/config.go
index 544d3ece6f7b653787d06ebc1ac2ff2d7a48cf3f..65855eb443d34a2585bb86ce461dd36015e776ea 100644
--- a/internal/config/config.go
+++ b/internal/config/config.go
@@ -685,6 +685,12 @@ func providerDefaultConfig(providerID provider.InferenceProvider) ProviderConfig
             ID:           providerID,
             ProviderType: provider.TypeVertexAI,
         }
+    case provider.InferenceProviderLlama:
+        return ProviderConfig{
+            ID:           providerID,
+            ProviderType: provider.TypeLlama,
+            BaseURL:      "https://api.llama.com/compat/v1",
+        }
     default:
         return ProviderConfig{
             ID: providerID,
@@ -1061,6 +1067,7 @@ func (c *Config) validateProviders(errors *ValidationErrors) {
         provider.TypeBedrock,
         provider.TypeVertexAI,
         provider.TypeXAI,
+        provider.TypeLlama,
     }
 
     for providerID, providerConfig := range c.Providers {
diff --git a/internal/config/config_test.go b/internal/config/config_test.go
index de8024bdd126bd46e13eb6ece102c9de69458266..1ff78510a0adeafda1fa985206a1947ba6aa9fbc 100644
--- a/internal/config/config_test.go
+++ b/internal/config/config_test.go
@@ -21,6 +21,7 @@ func reset() {
         "GEMINI_API_KEY",
         "XAI_API_KEY",
         "OPENROUTER_API_KEY",
+        "LLAMA_API_KEY",
 
         // Google Cloud / VertexAI
         "GOOGLE_GENAI_USE_VERTEXAI",
@@ -405,11 +406,12 @@ func TestEnvVars_AllSupportedAPIKeys(t *testing.T) {
     os.Setenv("GEMINI_API_KEY", "test-gemini-key")
     os.Setenv("XAI_API_KEY", "test-xai-key")
     os.Setenv("OPENROUTER_API_KEY", "test-openrouter-key")
+    os.Setenv("LLAMA_API_KEY", "test-llama-key")
 
     cfg, err := Init(cwdDir, false)
     require.NoError(t, err)
 
-    assert.Len(t, cfg.Providers, 5)
+    assert.Len(t, cfg.Providers, 6)
 
     anthropicProvider := cfg.Providers[provider.InferenceProviderAnthropic]
     assert.Equal(t, "test-anthropic-key", anthropicProvider.APIKey)
@@ -431,6 +433,11 @@ func TestEnvVars_AllSupportedAPIKeys(t *testing.T) {
     assert.Equal(t, "test-openrouter-key", openrouterProvider.APIKey)
     assert.Equal(t, provider.TypeOpenAI, openrouterProvider.ProviderType)
     assert.Equal(t, "https://openrouter.ai/api/v1", openrouterProvider.BaseURL)
+
+    llamaProvider := cfg.Providers[provider.InferenceProviderLlama]
+    assert.Equal(t, "test-llama-key", llamaProvider.APIKey)
+    assert.Equal(t, provider.TypeLlama, llamaProvider.ProviderType)
+    assert.Equal(t, "https://api.llama.com/compat/v1", llamaProvider.BaseURL)
 }
 
 func TestEnvVars_PartialEnvironmentVariables(t *testing.T) {
diff --git a/internal/config/provider_mock.go b/internal/config/provider_mock.go
index 801afdd8d6c9891eb47fa53294c047917b031637..e9a236f5b32951c2d79eb4303865faaedcc8b9c3 100644
--- a/internal/config/provider_mock.go
+++ b/internal/config/provider_mock.go
@@ -289,5 +289,40 @@ func MockProviders() []provider.Provider {
                 },
             },
         },
+        {
+            Name:                "Llama API",
+            ID:                  provider.InferenceProviderLlama,
+            APIKey:              "$LLAMA_API_KEY",
+            APIEndpoint:         "https://api.llama.com/compat/v1",
+            Type:                provider.TypeLlama,
+            DefaultLargeModelID: "Llama-4-Maverick-17B-128E-Instruct-FP8",
+            DefaultSmallModelID: "Llama-3.3-8B-Instruct",
+            Models: []provider.Model{
+                {
+                    ID:                 "Llama-4-Maverick-17B-128E-Instruct-FP8",
+                    Name:               "Llama 4 Maverick 17B 128E Instruct FP8",
+                    CostPer1MIn:        2.0,
+                    CostPer1MOut:       8.0,
+                    CostPer1MInCached:  0.0,
+                    CostPer1MOutCached: 0.0,
+                    ContextWindow:      128000,
+                    DefaultMaxTokens:   32000,
+                    CanReason:          true,
+                    SupportsImages:     true,
+                },
+                {
+                    ID:                 "Llama-3.3-8B-Instruct",
+                    Name:               "Llama 3.3 8B Instruct",
+                    CostPer1MIn:        0.5,
+                    CostPer1MOut:       2.0,
+                    CostPer1MInCached:  0.0,
+                    CostPer1MOutCached: 0.0,
+                    ContextWindow:      128000,
+                    DefaultMaxTokens:   16000,
+                    CanReason:          true,
+                    SupportsImages:     false,
+                },
+            },
+        },
     }
 }
diff --git a/internal/fur/provider/provider.go b/internal/fur/provider/provider.go
index e3c0f6209cbe71c239da104b38c3022e090599aa..2aad9c1a0b375d70a03a7a5b8dcd16a7e7a23c52 100644
--- a/internal/fur/provider/provider.go
+++ b/internal/fur/provider/provider.go
@@ -13,6 +13,7 @@ const (
     TypeBedrock  Type = "bedrock"
     TypeVertexAI Type = "vertexai"
     TypeXAI      Type = "xai"
+    TypeLlama    Type = "llama"
 )
 
 // InferenceProvider represents the inference provider identifier.
@@ -28,6 +29,7 @@ const (
     InferenceProviderVertexAI   InferenceProvider = "vertexai"
     InferenceProviderXAI        InferenceProvider = "xai"
     InferenceProviderOpenRouter InferenceProvider = "openrouter"
+    InferenceProviderLlama      InferenceProvider = "llama"
 )
 
 // Provider represents an AI provider configuration.
@@ -69,5 +71,6 @@ func KnownProviders() []InferenceProvider {
         InferenceProviderVertexAI,
         InferenceProviderXAI,
         InferenceProviderOpenRouter,
+        InferenceProviderLlama,
     }
 }
diff --git a/internal/llm/provider/llama.go b/internal/llm/provider/llama.go
new file mode 100644
index 0000000000000000000000000000000000000000..f0f13bab494fbcc52134c60633a4de459e8f5e17
--- /dev/null
+++ b/internal/llm/provider/llama.go
@@ -0,0 +1,279 @@
+package provider
+
+import (
+    "context"
+    "encoding/json"
+
+    "github.com/charmbracelet/crush/internal/config"
+    "github.com/charmbracelet/crush/internal/llm/tools"
+    "github.com/charmbracelet/crush/internal/logging"
+    "github.com/charmbracelet/crush/internal/message"
+    "github.com/openai/openai-go"
+    "github.com/openai/openai-go/option"
+    "github.com/openai/openai-go/shared"
+)
+
+type LlamaClient ProviderClient
+
+func newLlamaClient(opts providerClientOptions) LlamaClient {
+    openaiClientOptions := []option.RequestOption{}
+    if opts.apiKey != "" {
+        openaiClientOptions = append(openaiClientOptions, option.WithAPIKey(opts.apiKey))
+    }
+    openaiClientOptions = append(openaiClientOptions, option.WithBaseURL("https://api.llama.com/compat/v1/"))
+    if opts.extraHeaders != nil {
+        for key, value := range opts.extraHeaders {
+            openaiClientOptions = append(openaiClientOptions, option.WithHeader(key, value))
+        }
+    }
+    return &llamaClient{
+        providerOptions: opts,
+        client:          openai.NewClient(openaiClientOptions...),
+    }
+}
+
+type llamaClient struct {
+    providerOptions providerClientOptions
+    client          openai.Client
+}
+
+func (l *llamaClient) send(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error) {
+    openaiMessages := l.convertMessages(messages)
+    openaiTools := l.convertTools(tools)
+    params := l.preparedParams(openaiMessages, openaiTools)
+    cfg := config.Get()
+    if cfg.Options.Debug {
+        jsonData, _ := json.Marshal(params)
+        logging.Debug("Prepared messages", "messages", string(jsonData))
+    }
+    attempts := 0
+    for {
+        attempts++
+        openaiResponse, err := l.client.Chat.Completions.New(ctx, params)
+        if err != nil {
+            return nil, err
+        }
+        content := ""
+        if openaiResponse.Choices[0].Message.Content != "" {
+            content = openaiResponse.Choices[0].Message.Content
+        }
+        toolCalls := l.toolCalls(*openaiResponse)
+        finishReason := l.finishReason(string(openaiResponse.Choices[0].FinishReason))
+        if len(toolCalls) > 0 {
+            finishReason = message.FinishReasonToolUse
+        }
+        return &ProviderResponse{
+            Content:      content,
+            ToolCalls:    toolCalls,
+            Usage:        l.usage(*openaiResponse),
+            FinishReason: finishReason,
+        }, nil
+    }
+}
+
+func (l *llamaClient) stream(ctx context.Context, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent {
+    openaiMessages := l.convertMessages(messages)
+    openaiTools := l.convertTools(tools)
+    params := l.preparedParams(openaiMessages, openaiTools)
+    params.StreamOptions = openai.ChatCompletionStreamOptionsParam{
+        IncludeUsage: openai.Bool(true),
+    }
+    cfg := config.Get()
+    if cfg.Options.Debug {
+        jsonData, _ := json.Marshal(params)
+        logging.Debug("Prepared messages", "messages", string(jsonData))
+    }
+    eventChan := make(chan ProviderEvent)
+    go func() {
+        attempts := 0
+        acc := openai.ChatCompletionAccumulator{}
+        currentContent := ""
+        toolCalls := make([]message.ToolCall, 0)
+        for {
+            attempts++
+            openaiStream := l.client.Chat.Completions.NewStreaming(ctx, params)
+            for openaiStream.Next() {
+                chunk := openaiStream.Current()
+                acc.AddChunk(chunk)
+                for _, choice := range chunk.Choices {
+                    if choice.Delta.Content != "" {
+                        currentContent += choice.Delta.Content
+                        eventChan <- ProviderEvent{Type: EventContentDelta, Content: choice.Delta.Content}
+                    }
+                }
+            }
+            if err := openaiStream.Err(); err != nil {
+                eventChan <- ProviderEvent{Type: EventError, Error: err}
+                return
+            }
+            toolCalls = l.toolCalls(acc.ChatCompletion)
+            finishReason := l.finishReason(string(acc.ChatCompletion.Choices[0].FinishReason))
+            if len(toolCalls) > 0 {
+                finishReason = message.FinishReasonToolUse
+            }
+            eventChan <- ProviderEvent{
+                Type: EventComplete,
+                Response: &ProviderResponse{
+                    Content:      currentContent,
+                    ToolCalls:    toolCalls,
+                    Usage:        l.usage(acc.ChatCompletion),
+                    FinishReason: finishReason,
+                },
+            }
+            return
+        }
+    }()
+    return eventChan
+}
+
+func (l *llamaClient) convertMessages(messages []message.Message) (openaiMessages []openai.ChatCompletionMessageParamUnion) {
+    // Copied from openaiClient
+    openaiMessages = append(openaiMessages, openai.SystemMessage(l.providerOptions.systemMessage))
+    for _, msg := range messages {
+        switch msg.Role {
+        case message.User:
+            var content []openai.ChatCompletionContentPartUnionParam
+            textBlock := openai.ChatCompletionContentPartTextParam{Text: msg.Content().String()}
+            content = append(content, openai.ChatCompletionContentPartUnionParam{OfText: &textBlock})
+            for _, binaryContent := range msg.BinaryContent() {
+                imageURL := openai.ChatCompletionContentPartImageImageURLParam{URL: binaryContent.String("llama")}
+                imageBlock := openai.ChatCompletionContentPartImageParam{ImageURL: imageURL}
+                content = append(content, openai.ChatCompletionContentPartUnionParam{OfImageURL: &imageBlock})
+            }
+            openaiMessages = append(openaiMessages, openai.UserMessage(content))
+        case message.Assistant:
+            assistantMsg := openai.ChatCompletionAssistantMessageParam{
+                Role: "assistant",
+            }
+            if msg.Content().String() != "" {
+                assistantMsg.Content = openai.ChatCompletionAssistantMessageParamContentUnion{
+                    OfString: openai.String(msg.Content().String()),
+                }
+            }
+            if len(msg.ToolCalls()) > 0 {
+                assistantMsg.ToolCalls = make([]openai.ChatCompletionMessageToolCallParam, len(msg.ToolCalls()))
+                for i, call := range msg.ToolCalls() {
+                    assistantMsg.ToolCalls[i] = openai.ChatCompletionMessageToolCallParam{
+                        ID:   call.ID,
+                        Type: "function",
+                        Function: openai.ChatCompletionMessageToolCallFunctionParam{
+                            Name:      call.Name,
+                            Arguments: call.Input,
+                        },
+                    }
+                }
+            }
+            openaiMessages = append(openaiMessages, openai.ChatCompletionMessageParamUnion{OfAssistant: &assistantMsg})
+        case message.Tool:
+            for _, result := range msg.ToolResults() {
+                openaiMessages = append(openaiMessages, openai.ToolMessage(result.Content, result.ToolCallID))
+            }
+        }
+    }
+    return
+}
+
+func (l *llamaClient) convertTools(tools []tools.BaseTool) []openai.ChatCompletionToolParam {
+    openaiTools := make([]openai.ChatCompletionToolParam, len(tools))
+    for i, tool := range tools {
+        info := tool.Info()
+        openaiTools[i] = openai.ChatCompletionToolParam{
+            Function: openai.FunctionDefinitionParam{
+                Name:        info.Name,
+                Description: openai.String(info.Description),
+                Parameters: openai.FunctionParameters{
+                    "type":       "object",
+                    "properties": info.Parameters,
+                    "required":   info.Required,
+                },
+            },
+        }
+    }
+    return openaiTools
+}
+
+func (l *llamaClient) preparedParams(messages []openai.ChatCompletionMessageParamUnion, tools []openai.ChatCompletionToolParam) openai.ChatCompletionNewParams {
+    model := l.providerOptions.model(l.providerOptions.modelType)
+    cfg := config.Get()
+    modelConfig := cfg.Models.Large
+    if l.providerOptions.modelType == config.SmallModel {
+        modelConfig = cfg.Models.Small
+    }
+    reasoningEffort := model.ReasoningEffort
+    if modelConfig.ReasoningEffort != "" {
+        reasoningEffort = modelConfig.ReasoningEffort
+    }
+    params := openai.ChatCompletionNewParams{
+        Model:    openai.ChatModel(model.ID),
+        Messages: messages,
+        Tools:    tools,
+    }
+    maxTokens := model.DefaultMaxTokens
+    if modelConfig.MaxTokens > 0 {
+        maxTokens = modelConfig.MaxTokens
+    }
+    if l.providerOptions.maxTokens > 0 {
+        maxTokens = l.providerOptions.maxTokens
+    }
+    if model.CanReason {
+        params.MaxCompletionTokens = openai.Int(maxTokens)
+        switch reasoningEffort {
+        case "low":
+            params.ReasoningEffort = shared.ReasoningEffortLow
+        case "medium":
+            params.ReasoningEffort = shared.ReasoningEffortMedium
+        case "high":
+            params.ReasoningEffort = shared.ReasoningEffortHigh
+        default:
+            params.ReasoningEffort = shared.ReasoningEffort(reasoningEffort)
+        }
+    } else {
+        params.MaxTokens = openai.Int(maxTokens)
+    }
+    return params
+}
+
+func (l *llamaClient) toolCalls(completion openai.ChatCompletion) []message.ToolCall {
+    var toolCalls []message.ToolCall
+    if len(completion.Choices) > 0 && len(completion.Choices[0].Message.ToolCalls) > 0 {
+        for _, call := range completion.Choices[0].Message.ToolCalls {
+            toolCall := message.ToolCall{
+                ID:       call.ID,
+                Name:     call.Function.Name,
+                Input:    call.Function.Arguments,
+                Type:     "function",
+                Finished: true,
+            }
+            toolCalls = append(toolCalls, toolCall)
+        }
+    }
+    return toolCalls
+}
+
+func (l *llamaClient) finishReason(reason string) message.FinishReason {
+    switch reason {
+    case "stop":
+        return message.FinishReasonEndTurn
+    case "length":
+        return message.FinishReasonMaxTokens
+    case "tool_calls":
+        return message.FinishReasonToolUse
+    default:
+        return message.FinishReasonUnknown
+    }
+}
+
+func (l *llamaClient) usage(completion openai.ChatCompletion) TokenUsage {
+    cachedTokens := completion.Usage.PromptTokensDetails.CachedTokens
+    inputTokens := completion.Usage.PromptTokens - cachedTokens
+    return TokenUsage{
+        InputTokens:         inputTokens,
+        OutputTokens:        completion.Usage.CompletionTokens,
+        CacheCreationTokens: 0, // OpenAI doesn't provide this directly
+        CacheReadTokens:     cachedTokens,
+    }
+}
+
+func (l *llamaClient) Model() config.Model {
+    return l.providerOptions.model(l.providerOptions.modelType)
+}
diff --git a/internal/llm/provider/provider.go b/internal/llm/provider/provider.go
index 3ffbf86c00c5e3ca27f1b68965f4ff950f1f7454..c0eb3b18933bc0ff596e7fca423de34af79d9476 100644
--- a/internal/llm/provider/provider.go
+++ b/internal/llm/provider/provider.go
@@ -189,6 +189,11 @@ func NewProvider(cfg config.ProviderConfig, opts ...ProviderClientOption) (Provi
             options: clientOptions,
             client:  newOpenAIClient(clientOptions),
         }, nil
+    case provider.TypeLlama:
+        return &baseProvider[LlamaClient]{
+            options: clientOptions,
+            client:  newLlamaClient(clientOptions),
+        }, nil
     }
     return nil, fmt.Errorf("provider not supported: %s", cfg.ProviderType)
 }
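
Reviewer note (not part of the patch): a minimal sketch of how the new wiring might be exercised from inside the module, using only identifiers introduced above (config.ProviderConfig and its ID/ProviderType/BaseURL/APIKey fields, provider.TypeLlama, NewProvider). The test name and the option-free NewProvider call are illustrative assumptions; in practice it may also need an initialized config (config.Init) or extra ProviderClientOption values, so treat this as a sketch rather than a ready-made test.

    package provider_test // hypothetical file placed under internal/llm/provider

    import (
    	"os"
    	"testing"

    	"github.com/charmbracelet/crush/internal/config"
    	furprovider "github.com/charmbracelet/crush/internal/fur/provider"
    	llmprovider "github.com/charmbracelet/crush/internal/llm/provider"
    )

    // Hypothetical smoke test: build a Llama ProviderConfig the same way
    // providerDefaultConfig does, then check that NewProvider takes the new
    // provider.TypeLlama branch instead of returning "provider not supported".
    func TestNewProvider_LlamaSketch(t *testing.T) {
    	cfg := config.ProviderConfig{
    		ID:           furprovider.InferenceProviderLlama,
    		ProviderType: furprovider.TypeLlama,
    		BaseURL:      "https://api.llama.com/compat/v1",
    		APIKey:       os.Getenv("LLAMA_API_KEY"),
    	}
    	if _, err := llmprovider.NewProvider(cfg); err != nil {
    		t.Fatalf("expected TypeLlama to be handled: %v", err)
    	}
    }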