add xai support (#135)

Created by Kujtim Hoxha

Change summary

internal/config/config.go         | 16 +++++++-
internal/llm/models/models.go     |  1 +
internal/llm/models/xai.go        | 61 +++++++++++++++++++++++++++++++++
internal/llm/provider/openai.go   | 13 +-----
internal/llm/provider/provider.go |  9 ++++
internal/tui/tui.go               |  4 +-
6 files changed, 90 insertions(+), 14 deletions(-)
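
In short: xAI is wired in as a new provider that reuses the OpenAI-compatible client against https://api.x.ai/v1, four Grok 3 beta models are registered with pricing and context-window metadata, XAI_API_KEY now participates in provider and default-model detection, the Azure apiKey default moves into the provider-defaults block, tool-call collection in the OpenAI streaming path is reworked, and the model-selection keybinding moves from ctrl+m to ctrl+o.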

Detailed changes

internal/config/config.go

@@ -242,6 +242,13 @@ func setProviderDefaults() {
 	if apiKey := os.Getenv("OPENROUTER_API_KEY"); apiKey != "" {
 		viper.SetDefault("providers.openrouter.apiKey", apiKey)
 	}
+	if apiKey := os.Getenv("XAI_API_KEY"); apiKey != "" {
+		viper.SetDefault("providers.xai.apiKey", apiKey)
+	}
+	if apiKey := os.Getenv("AZURE_OPENAI_ENDPOINT"); apiKey != "" {
+		// api-key may be empty when using Entra ID credentials – that's okay
+		viper.SetDefault("providers.azure.apiKey", os.Getenv("AZURE_OPENAI_API_KEY"))
+	}
 
 	// Use this order to set the default models
 	// 1. Anthropic
@@ -292,6 +299,13 @@ func setProviderDefaults() {
 		return
 	}
 
+	if viper.Get("providers.xai.apiKey") != "" {
+		viper.SetDefault("agents.coder.model", models.XAIGrok3Beta)
+		viper.SetDefault("agents.task.model", models.XAIGrok3Beta)
+		viper.SetDefault("agents.title.model", models.XAiGrok3MiniFastBeta)
+		return
+	}
+
 	// AWS Bedrock configuration
 	if hasAWSCredentials() {
 		viper.SetDefault("agents.coder.model", models.BedrockClaude37Sonnet)
@@ -301,8 +315,6 @@ func setProviderDefaults() {
 	}
 
 	if os.Getenv("AZURE_OPENAI_ENDPOINT") != "" {
-		// api-key may be empty when using Entra ID credentials – that's okay
-		viper.SetDefault("providers.azure.apiKey", os.Getenv("AZURE_OPENAI_API_KEY"))
 		viper.SetDefault("agents.coder.model", models.AzureGPT41)
 		viper.SetDefault("agents.task.model", models.AzureGPT41Mini)
 		viper.SetDefault("agents.title.model", models.AzureGPT41Mini)

internal/llm/models/models.go

@@ -89,4 +89,5 @@ func init() {
 	maps.Copy(SupportedModels, GroqModels)
 	maps.Copy(SupportedModels, AzureModels)
 	maps.Copy(SupportedModels, OpenRouterModels)
+	maps.Copy(SupportedModels, XAIModels)
 }
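
Once copied into SupportedModels, the Grok entries resolve through the same registry lookup as every other provider; a quick sketch:

	// Grok models now resolve like any other registered model.
	if m, ok := models.SupportedModels[models.XAIGrok3Beta]; ok {
		fmt.Println(m.Name, m.ContextWindow) // Grok3 Beta 131072
	}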

internal/llm/models/xai.go

@@ -0,0 +1,61 @@
+package models
+
+const (
+	ProviderXAI ModelProvider = "xai"
+
+	XAIGrok3Beta         ModelID = "grok-3-beta"
+	XAIGrok3MiniBeta     ModelID = "grok-3-mini-beta"
+	XAIGrok3FastBeta     ModelID = "grok-3-fast-beta"
+	XAiGrok3MiniFastBeta ModelID = "grok-3-mini-fast-beta"
+)
+
+var XAIModels = map[ModelID]Model{
+	XAIGrok3Beta: {
+		ID:                 XAIGrok3Beta,
+		Name:               "Grok3 Beta",
+		Provider:           ProviderXAI,
+		APIModel:           "grok-3-beta",
+		CostPer1MIn:        3.0,
+		CostPer1MInCached:  0,
+		CostPer1MOut:       15,
+		CostPer1MOutCached: 0,
+		ContextWindow:      131_072,
+		DefaultMaxTokens:   20_000,
+	},
+	XAIGrok3MiniBeta: {
+		ID:                 XAIGrok3MiniBeta,
+		Name:               "Grok3 Mini Beta",
+		Provider:           ProviderXAI,
+		APIModel:           "grok-3-mini-beta",
+		CostPer1MIn:        0.3,
+		CostPer1MInCached:  0,
+		CostPer1MOut:       0.5,
+		CostPer1MOutCached: 0,
+		ContextWindow:      131_072,
+		DefaultMaxTokens:   20_000,
+	},
+	XAIGrok3FastBeta: {
+		ID:                 XAIGrok3FastBeta,
+		Name:               "Grok3 Fast Beta",
+		Provider:           ProviderXAI,
+		APIModel:           "grok-3-fast-beta",
+		CostPer1MIn:        5,
+		CostPer1MInCached:  0,
+		CostPer1MOut:       25,
+		CostPer1MOutCached: 0,
+		ContextWindow:      131_072,
+		DefaultMaxTokens:   20_000,
+	},
+	XAiGrok3MiniFastBeta: {
+		ID:                 XAiGrok3MiniFastBeta,
+		Name:               "Grok3 Mini Fast Beta",
+		Provider:           ProviderXAI,
+		APIModel:           "grok-3-mini-fast-beta",
+		CostPer1MIn:        0.6,
+		CostPer1MInCached:  0,
+		CostPer1MOut:       4.0,
+		CostPer1MOutCached: 0,
+		ContextWindow:      131_072,
+		DefaultMaxTokens:   20_000,
+	},
+}
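
The Cost fields are dollars per million tokens, so a back-of-the-envelope request cost falls straight out of the table. A hypothetical sketch, assuming fmt is imported and the cost fields are float64:

	// Hypothetical estimate: 12k input + 2k output tokens on grok-3-beta.
	m := models.XAIModels[models.XAIGrok3Beta]
	cost := 12_000/1e6*m.CostPer1MIn + 2_000/1e6*m.CostPer1MOut
	fmt.Printf("$%.3f\n", cost) // $0.066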

internal/llm/provider/openai.go

@@ -258,15 +258,6 @@ func (o *openaiClient) stream(ctx context.Context, messages []message.Message, t
 				chunk := openaiStream.Current()
 				acc.AddChunk(chunk)
 
-				if tool, ok := acc.JustFinishedToolCall(); ok {
-					toolCalls = append(toolCalls, message.ToolCall{
-						ID:    tool.Id,
-						Name:  tool.Name,
-						Input: tool.Arguments,
-						Type:  "function",
-					})
-				}
-
 				for _, choice := range chunk.Choices {
 					if choice.Delta.Content != "" {
 						eventChan <- ProviderEvent{
@@ -282,7 +273,9 @@ func (o *openaiClient) stream(ctx context.Context, messages []message.Message, t
 			if err == nil || errors.Is(err, io.EOF) {
 				// Stream completed successfully
 				finishReason := o.finishReason(string(acc.ChatCompletion.Choices[0].FinishReason))
-
+				if len(acc.ChatCompletion.Choices[0].Message.ToolCalls) > 0 {
+					toolCalls = append(toolCalls, o.toolCalls(acc.ChatCompletion)...)
+				}
 				if len(toolCalls) > 0 {
 					finishReason = message.FinishReasonToolUse
 				}
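
The per-chunk acc.JustFinishedToolCall() check is dropped in favor of reading tool calls once from the accumulated completion after the stream ends, presumably because some OpenAI-compatible backends (xAI included) do not trigger the per-chunk check reliably. Sketch of the resulting flow, using only names from the hunks above:

	// Accumulate every chunk, then collect tool calls once from the
	// finished completion instead of per chunk.
	for openaiStream.Next() {
		chunk := openaiStream.Current()
		acc.AddChunk(chunk)
		// ... content deltas are still forwarded per chunk ...
	}
	if len(acc.ChatCompletion.Choices[0].Message.ToolCalls) > 0 {
		toolCalls = append(toolCalls, o.toolCalls(acc.ChatCompletion)...)
	}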

internal/llm/provider/provider.go

@@ -132,6 +132,15 @@ func NewProvider(providerName models.ModelProvider, opts ...ProviderClientOption
 			options: clientOptions,
 			client:  newOpenAIClient(clientOptions),
 		}, nil
+	case models.ProviderXAI:
+		clientOptions.openaiOptions = append(clientOptions.openaiOptions,
+			WithOpenAIBaseURL("https://api.x.ai/v1"),
+		)
+		return &baseProvider[OpenAIClient]{
+			options: clientOptions,
+			client:  newOpenAIClient(clientOptions),
+		}, nil
+
 	case models.ProviderMock:
 		// TODO: implement mock client for test
 		panic("not implemented")

internal/tui/tui.go

@@ -56,8 +56,8 @@ var keys = keyMap{
 	),
 
 	Models: key.NewBinding(
-		key.WithKeys("ctrl+m"),
-		key.WithHelp("ctrl+m", "model selection"),
+		key.WithKeys("ctrl+o"),
+		key.WithHelp("ctrl+o", "model selection"),
 	),
 
 	SwitchTheme: key.NewBinding(
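
Though the diff does not state the motivation for the rebind, ctrl+m is indistinguishable from Enter in most terminals (both arrive as a carriage return, 0x0D), so the old binding plausibly collided with message submission; ctrl+o has no such collision.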