From 603a3e3c71de8d71cfc2ea308be2fe342bb6d25c Mon Sep 17 00:00:00 2001
From: Kujtim Hoxha
Date: Thu, 1 May 2025 14:17:33 +0200
Subject: [PATCH] add xai support (#135)

---
 internal/config/config.go         | 16 +++++++-
 internal/llm/models/models.go     |  1 +
 internal/llm/models/xai.go        | 61 +++++++++++++++++++++++++++++++
 internal/llm/provider/openai.go   | 13 ++-----
 internal/llm/provider/provider.go |  9 +++++
 internal/tui/tui.go               |  4 +-
 6 files changed, 90 insertions(+), 14 deletions(-)
 create mode 100644 internal/llm/models/xai.go

diff --git a/internal/config/config.go b/internal/config/config.go
index 737487bfca5da8dc41a85c0558485c3e8fc2b792..5a74320d6d7da6b3e5e723dc2f57010659ffee45 100644
--- a/internal/config/config.go
+++ b/internal/config/config.go
@@ -242,6 +242,13 @@ func setProviderDefaults() {
 	if apiKey := os.Getenv("OPENROUTER_API_KEY"); apiKey != "" {
 		viper.SetDefault("providers.openrouter.apiKey", apiKey)
 	}
+	if apiKey := os.Getenv("XAI_API_KEY"); apiKey != "" {
+		viper.SetDefault("providers.xai.apiKey", apiKey)
+	}
+	if apiKey := os.Getenv("AZURE_OPENAI_ENDPOINT"); apiKey != "" {
+		// api-key may be empty when using Entra ID credentials – that's okay
+		viper.SetDefault("providers.azure.apiKey", os.Getenv("AZURE_OPENAI_API_KEY"))
+	}
 
 	// Use this order to set the default models
 	// 1. Anthropic
@@ -292,6 +299,13 @@ func setProviderDefaults() {
 		return
 	}
 
+	if viper.Get("providers.xai.apiKey") != "" {
+		viper.SetDefault("agents.coder.model", models.XAIGrok3Beta)
+		viper.SetDefault("agents.task.model", models.XAIGrok3Beta)
+		viper.SetDefault("agents.title.model", models.XAiGrok3MiniFastBeta)
+		return
+	}
+
 	// AWS Bedrock configuration
 	if hasAWSCredentials() {
 		viper.SetDefault("agents.coder.model", models.BedrockClaude37Sonnet)
@@ -301,8 +315,6 @@ func setProviderDefaults() {
 	}
 
 	if os.Getenv("AZURE_OPENAI_ENDPOINT") != "" {
-		// api-key may be empty when using Entra ID credentials – that's okay
-		viper.SetDefault("providers.azure.apiKey", os.Getenv("AZURE_OPENAI_API_KEY"))
 		viper.SetDefault("agents.coder.model", models.AzureGPT41)
 		viper.SetDefault("agents.task.model", models.AzureGPT41Mini)
 		viper.SetDefault("agents.title.model", models.AzureGPT41Mini)
diff --git a/internal/llm/models/models.go b/internal/llm/models/models.go
index fd0a2b41ba6229fd614baa4440de40b8a6017ad1..1e1cbde50bfa45d5bc44f72af68d4b86b5d3db27 100644
--- a/internal/llm/models/models.go
+++ b/internal/llm/models/models.go
@@ -89,4 +89,5 @@ func init() {
 	maps.Copy(SupportedModels, GroqModels)
 	maps.Copy(SupportedModels, AzureModels)
 	maps.Copy(SupportedModels, OpenRouterModels)
+	maps.Copy(SupportedModels, XAIModels)
 }
diff --git a/internal/llm/models/xai.go b/internal/llm/models/xai.go
new file mode 100644
index 0000000000000000000000000000000000000000..00caf3b89750c0789f75f6273d49e38a4cdf6282
--- /dev/null
+++ b/internal/llm/models/xai.go
@@ -0,0 +1,61 @@
+package models
+
+const (
+	ProviderXAI ModelProvider = "xai"
+
+	XAIGrok3Beta         ModelID = "grok-3-beta"
+	XAIGrok3MiniBeta     ModelID = "grok-3-mini-beta"
+	XAIGrok3FastBeta     ModelID = "grok-3-fast-beta"
+	XAiGrok3MiniFastBeta ModelID = "grok-3-mini-fast-beta"
+)
+
+var XAIModels = map[ModelID]Model{
+	XAIGrok3Beta: {
+		ID:                 XAIGrok3Beta,
+		Name:               "Grok3 Beta",
+		Provider:           ProviderXAI,
+		APIModel:           "grok-3-beta",
+		CostPer1MIn:        3.0,
+		CostPer1MInCached:  0,
+		CostPer1MOut:       15,
+		CostPer1MOutCached: 0,
+		ContextWindow:      131_072,
+		DefaultMaxTokens:   20_000,
+	},
+	XAIGrok3MiniBeta: {
+		ID:                 XAIGrok3MiniBeta,
+		Name:               "Grok3 Mini Beta",
+		Provider:           ProviderXAI,
+		APIModel:           "grok-3-mini-beta",
+		CostPer1MIn:        0.3,
+		CostPer1MInCached:  0,
+		CostPer1MOut:       0.5,
+		CostPer1MOutCached: 0,
+		ContextWindow:      131_072,
+		DefaultMaxTokens:   20_000,
+	},
+	XAIGrok3FastBeta: {
+		ID:                 XAIGrok3FastBeta,
+		Name:               "Grok3 Fast Beta",
+		Provider:           ProviderXAI,
+		APIModel:           "grok-3-fast-beta",
+		CostPer1MIn:        5,
+		CostPer1MInCached:  0,
+		CostPer1MOut:       25,
+		CostPer1MOutCached: 0,
+		ContextWindow:      131_072,
+		DefaultMaxTokens:   20_000,
+	},
+	XAiGrok3MiniFastBeta: {
+		ID:                 XAiGrok3MiniFastBeta,
+		Name:               "Grok3 Mini Fast Beta",
+		Provider:           ProviderXAI,
+		APIModel:           "grok-3-mini-fast-beta",
+		CostPer1MIn:        0.6,
+		CostPer1MInCached:  0,
+		CostPer1MOut:       4.0,
+		CostPer1MOutCached: 0,
+		ContextWindow:      131_072,
+		DefaultMaxTokens:   20_000,
+	},
+}
diff --git a/internal/llm/provider/openai.go b/internal/llm/provider/openai.go
index b557df5358dc682e1acc580dd4fe22378efd1c7a..d68cfbc2d2ab6628f50631164b201ff8c178f38c 100644
--- a/internal/llm/provider/openai.go
+++ b/internal/llm/provider/openai.go
@@ -258,15 +258,6 @@ func (o *openaiClient) stream(ctx context.Context, messages []message.Message, t
 		chunk := openaiStream.Current()
 		acc.AddChunk(chunk)
 
-		if tool, ok := acc.JustFinishedToolCall(); ok {
-			toolCalls = append(toolCalls, message.ToolCall{
-				ID:    tool.Id,
-				Name:  tool.Name,
-				Input: tool.Arguments,
-				Type:  "function",
-			})
-		}
-
 		for _, choice := range chunk.Choices {
 			if choice.Delta.Content != "" {
 				eventChan <- ProviderEvent{
@@ -282,7 +273,9 @@ func (o *openaiClient) stream(ctx context.Context, messages []message.Message, t
 		if err == nil || errors.Is(err, io.EOF) {
 			// Stream completed successfully
 			finishReason := o.finishReason(string(acc.ChatCompletion.Choices[0].FinishReason))
-
+			if len(acc.ChatCompletion.Choices[0].Message.ToolCalls) > 0 {
+				toolCalls = append(toolCalls, o.toolCalls(acc.ChatCompletion)...)
+			}
 			if len(toolCalls) > 0 {
 				finishReason = message.FinishReasonToolUse
 			}
diff --git a/internal/llm/provider/provider.go b/internal/llm/provider/provider.go
index 1545bc27ac12af64f83e4ae5ea86fc25ee9d49fc..cad11eeb395022641dc3f4dec5601b47bbbb5b50 100644
--- a/internal/llm/provider/provider.go
+++ b/internal/llm/provider/provider.go
@@ -132,6 +132,15 @@ func NewProvider(providerName models.ModelProvider, opts ...ProviderClientOption
 			options: clientOptions,
 			client:  newOpenAIClient(clientOptions),
 		}, nil
+	case models.ProviderXAI:
+		clientOptions.openaiOptions = append(clientOptions.openaiOptions,
+			WithOpenAIBaseURL("https://api.x.ai/v1"),
+		)
+		return &baseProvider[OpenAIClient]{
+			options: clientOptions,
+			client:  newOpenAIClient(clientOptions),
+		}, nil
+
 	case models.ProviderMock:
 		// TODO: implement mock client for test
 		panic("not implemented")
diff --git a/internal/tui/tui.go b/internal/tui/tui.go
index d68aaa2ee2a2e0a4508b1e3c151186d5c5e14803..b9297dbee5c54726e72b204ba36017581ad8a1eb 100644
--- a/internal/tui/tui.go
+++ b/internal/tui/tui.go
@@ -56,8 +56,8 @@ var keys = keyMap{
 	),
 
 	Models: key.NewBinding(
-		key.WithKeys("ctrl+m"),
-		key.WithHelp("ctrl+m", "model selection"),
+		key.WithKeys("ctrl+o"),
+		key.WithHelp("ctrl+o", "model selection"),
 	),
 
 	SwitchTheme: key.NewBinding(
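
Not part of the patch: below is a minimal, hypothetical sketch of how code living inside this repository could exercise the new xai provider once the change is applied. It relies only on identifiers visible in the hunks above (models.ProviderXAI, models.XAIGrok3Beta, models.SupportedModels, provider.NewProvider, and the XAI_API_KEY variable); the module import path prefix and the zero-option call to NewProvider are assumptions, and real callers would normally pass additional client options (API key, model) that are outside the scope of this diff.

// sketch.go – illustrative only, not part of commit 603a3e3.
// Assumes it lives inside this module so the internal/... packages are
// importable; the module path prefix below is a guess and may differ.
package main

import (
	"fmt"
	"log"
	"os"

	"github.com/opencode-ai/opencode/internal/llm/models"   // assumed module path
	"github.com/opencode-ai/opencode/internal/llm/provider" // assumed module path
)

func main() {
	// setProviderDefaults() only enables the xai provider when this is set.
	if os.Getenv("XAI_API_KEY") == "" {
		log.Fatal("XAI_API_KEY is not set; the xai provider will not be configured")
	}

	// The Grok models are registered in the shared model table via init()
	// in internal/llm/models (see the models.go hunk).
	grok := models.SupportedModels[models.XAIGrok3Beta]
	fmt.Printf("selected %s: %d-token context window\n", grok.Name, grok.ContextWindow)

	// ProviderXAI is routed to the OpenAI-compatible client with the
	// https://api.x.ai/v1 base URL, per the provider.go hunk.
	p, err := provider.NewProvider(models.ProviderXAI)
	if err != nil {
		log.Fatal(err)
	}
	_ = p // in the app this would be handed to the agents configured above
}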