Merge pull request #505 from charmbracelet/fix-openai-provider

Kujtim Hoxha created

Fix openai compatible api  for certain providers

Change summary

go.mod                                            |   4 
go.sum                                            |   8 
internal/llm/provider/openai.go                   | 138 ++++++++--------
internal/tui/components/chat/messages/messages.go |   3 
internal/tui/components/chat/splash/splash.go     |  22 --
internal/tui/components/dialogs/models/list.go    |  15 +
internal/tui/components/dialogs/models/models.go  |  19 --
7 files changed, 90 insertions(+), 119 deletions(-)

Detailed changes

go.mod 🔗

@@ -14,7 +14,7 @@ require (
 	github.com/charlievieth/fastwalk v1.0.11
 	github.com/charmbracelet/bubbles/v2 v2.0.0-beta.1.0.20250716191546-1e2ffbbcf5c5
 	github.com/charmbracelet/bubbletea/v2 v2.0.0-beta.4.0.20250730165737-56ff7146d52d
-	github.com/charmbracelet/catwalk v0.3.5
+	github.com/charmbracelet/catwalk v0.4.5
 	github.com/charmbracelet/fang v0.3.1-0.20250711140230-d5ebb8c1d674
 	github.com/charmbracelet/glamour/v2 v2.0.0-20250516160903-6f1e2c8f9ebe
 	github.com/charmbracelet/lipgloss/v2 v2.0.0-beta.3.0.20250721205738-ea66aa652ee0
@@ -56,7 +56,7 @@ require (
 	github.com/wk8/go-ordered-map/v2 v2.1.8 // indirect
 	go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.54.0 // indirect
 	golang.org/x/exp v0.0.0-20250305212735-054e65f0b394 // indirect
-	golang.org/x/oauth2 v0.25.0 // indirect
+	golang.org/x/oauth2 v0.30.0 // indirect
 	golang.org/x/time v0.8.0 // indirect
 	google.golang.org/api v0.211.0 // indirect
 )

go.sum 🔗

@@ -78,8 +78,8 @@ github.com/charmbracelet/bubbles/v2 v2.0.0-beta.1.0.20250716191546-1e2ffbbcf5c5
 github.com/charmbracelet/bubbles/v2 v2.0.0-beta.1.0.20250716191546-1e2ffbbcf5c5/go.mod h1:6HamsBKWqEC/FVHuQMHgQL+knPyvHH55HwJDHl/adMw=
 github.com/charmbracelet/bubbletea/v2 v2.0.0-beta.4.0.20250730165737-56ff7146d52d h1:YMXLZHSo8DjytVY/b5dK8LDuyQsVUmBK3ydQMpu2Ui4=
 github.com/charmbracelet/bubbletea/v2 v2.0.0-beta.4.0.20250730165737-56ff7146d52d/go.mod h1:XIQ1qQfRph6Z5o2EikCydjumo0oDInQySRHuPATzbZc=
-github.com/charmbracelet/catwalk v0.3.5 h1:ChMvA5ooTNZhDKFagmGNQgIZvZp8XjpdaJ+cDmhgCgA=
-github.com/charmbracelet/catwalk v0.3.5/go.mod h1:gUUCqqZ8bk4D7ZzGTu3I77k7cC2x4exRuJBN1H2u2pc=
+github.com/charmbracelet/catwalk v0.4.5 h1:Kv3PadDe8IF8gpcYTfAJdCee5Bv4HufvtNT61FXtq5g=
+github.com/charmbracelet/catwalk v0.4.5/go.mod h1:WnKgNPmQHuMyk7GtwAQwl+ezHusfH40IvzML2qwUGwc=
 github.com/charmbracelet/colorprofile v0.3.1 h1:k8dTHMd7fgw4bnFd7jXTLZrSU/CQrKnL3m+AxCzDz40=
 github.com/charmbracelet/colorprofile v0.3.1/go.mod h1:/GkGusxNs8VB/RSOh3fu0TJmQ4ICMMPApIIVn0KszZ0=
 github.com/charmbracelet/fang v0.3.1-0.20250711140230-d5ebb8c1d674 h1:+Cz+VfxD5DO+JT1LlswXWhre0HYLj6l2HW8HVGfMuC0=
@@ -332,8 +332,8 @@ golang.org/x/net v0.24.0/go.mod h1:2Q7sJY5mzlzWjKtYUEXSlBWCdyaioyXzRB2RtU8KVE8=
 golang.org/x/net v0.25.0/go.mod h1:JkAGAh7GEvH74S6FOH42FLoXpXbE/aqXSrIQjXgsiwM=
 golang.org/x/net v0.40.0 h1:79Xs7wF06Gbdcg4kdCCIQArK11Z1hr5POQ6+fIYHNuY=
 golang.org/x/net v0.40.0/go.mod h1:y0hY0exeL2Pku80/zKK7tpntoX23cqL3Oa6njdgRtds=
-golang.org/x/oauth2 v0.25.0 h1:CY4y7XT9v0cRI9oupztF8AgiIu99L/ksR/Xp/6jrZ70=
-golang.org/x/oauth2 v0.25.0/go.mod h1:XYTD2NtWslqkgxebSiOHnXEap4TF09sJSc7H1sXbhtI=
+golang.org/x/oauth2 v0.30.0 h1:dnDm7JmhM45NNpd8FDDeLhK6FwqbOf4MLCM9zb1BOHI=
+golang.org/x/oauth2 v0.30.0/go.mod h1:B++QgG3ZKulg6sRPGD/mqlHQs5rB3Ml9erfeDY7xKlU=
 golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
 golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
 golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=

internal/llm/provider/openai.go 🔗

@@ -2,10 +2,12 @@ package provider
 
 import (
 	"context"
+	"encoding/json"
 	"errors"
 	"fmt"
 	"io"
 	"log/slog"
+	"slices"
 	"strings"
 	"time"
 
@@ -14,6 +16,7 @@ import (
 	"github.com/charmbracelet/crush/internal/llm/tools"
 	"github.com/charmbracelet/crush/internal/log"
 	"github.com/charmbracelet/crush/internal/message"
+	"github.com/google/uuid"
 	"github.com/openai/openai-go"
 	"github.com/openai/openai-go/option"
 	"github.com/openai/openai-go/packages/param"
@@ -70,8 +73,9 @@ func (o *openaiClient) convertMessages(messages []message.Message) (openaiMessag
 		systemMessage = o.providerOptions.systemPromptPrefix + "\n" + systemMessage
 	}
 
-	systemTextBlock := openai.ChatCompletionContentPartTextParam{Text: systemMessage}
+	system := openai.SystemMessage(systemMessage)
 	if isAnthropicModel && !o.providerOptions.disableCache {
+		systemTextBlock := openai.ChatCompletionContentPartTextParam{Text: systemMessage}
 		systemTextBlock.SetExtraFields(
 			map[string]any{
 				"cache_control": map[string]string{
@@ -79,10 +83,10 @@ func (o *openaiClient) convertMessages(messages []message.Message) (openaiMessag
 				},
 			},
 		)
+		var content []openai.ChatCompletionContentPartTextParam
+		content = append(content, systemTextBlock)
+		system = openai.SystemMessage(content)
 	}
-	var content []openai.ChatCompletionContentPartTextParam
-	content = append(content, systemTextBlock)
-	system := openai.SystemMessage(content)
 	openaiMessages = append(openaiMessages, system)
 
 	for i, msg := range messages {
@@ -93,9 +97,12 @@ func (o *openaiClient) convertMessages(messages []message.Message) (openaiMessag
 		switch msg.Role {
 		case message.User:
 			var content []openai.ChatCompletionContentPartUnionParam
+
 			textBlock := openai.ChatCompletionContentPartTextParam{Text: msg.Content().String()}
 			content = append(content, openai.ChatCompletionContentPartUnionParam{OfText: &textBlock})
+			hasBinaryContent := false
 			for _, binaryContent := range msg.BinaryContent() {
+				hasBinaryContent = true
 				imageURL := openai.ChatCompletionContentPartImageImageURLParam{URL: binaryContent.String(catwalk.InferenceProviderOpenAI)}
 				imageBlock := openai.ChatCompletionContentPartImageParam{ImageURL: imageURL}
 
@@ -108,8 +115,11 @@ func (o *openaiClient) convertMessages(messages []message.Message) (openaiMessag
 					},
 				})
 			}
-
-			openaiMessages = append(openaiMessages, openai.UserMessage(content))
+			if hasBinaryContent || (isAnthropicModel && !o.providerOptions.disableCache) {
+				openaiMessages = append(openaiMessages, openai.UserMessage(content))
+			} else {
+				openaiMessages = append(openaiMessages, openai.UserMessage(msg.Content().String()))
+			}
 
 		case message.Assistant:
 			assistantMsg := openai.ChatCompletionAssistantMessageParam{
@@ -134,13 +144,15 @@ func (o *openaiClient) convertMessages(messages []message.Message) (openaiMessag
 						},
 					},
 				}
+				if !isAnthropicModel {
+					assistantMsg.Content = openai.ChatCompletionAssistantMessageParamContentUnion{
+						OfString: param.NewOpt(msg.Content().String()),
+					}
+				}
 			}
 
 			if len(msg.ToolCalls()) > 0 {
 				hasContent = true
-				assistantMsg.Content = openai.ChatCompletionAssistantMessageParamContentUnion{
-					OfString: param.NewOpt(msg.Content().String()),
-				}
 				assistantMsg.ToolCalls = make([]openai.ChatCompletionMessageToolCallParam, len(msg.ToolCalls()))
 				for i, call := range msg.ToolCalls() {
 					assistantMsg.ToolCalls[i] = openai.ChatCompletionMessageToolCallParam{
@@ -329,21 +341,26 @@ func (o *openaiClient) stream(ctx context.Context, messages []message.Message, t
 			acc := openai.ChatCompletionAccumulator{}
 			currentContent := ""
 			toolCalls := make([]message.ToolCall, 0)
-
-			var currentToolCallID string
-			var currentToolCall openai.ChatCompletionMessageToolCall
 			var msgToolCalls []openai.ChatCompletionMessageToolCall
-			currentToolIndex := 0
 			for openaiStream.Next() {
 				chunk := openaiStream.Current()
 				// Kujtim: this is an issue with openrouter qwen, its sending -1 for the tool index
 				if len(chunk.Choices) > 0 && len(chunk.Choices[0].Delta.ToolCalls) > 0 && chunk.Choices[0].Delta.ToolCalls[0].Index == -1 {
-					chunk.Choices[0].Delta.ToolCalls[0].Index = int64(currentToolIndex)
-					currentToolIndex++
+					chunk.Choices[0].Delta.ToolCalls[0].Index = 0
 				}
 				acc.AddChunk(chunk)
-				// This fixes multiple tool calls for some providers
-				for _, choice := range chunk.Choices {
+				for i, choice := range chunk.Choices {
+					reasoning, ok := choice.Delta.JSON.ExtraFields["reasoning"]
+					if ok && reasoning.Raw() != "" {
+						reasoningStr := ""
+						json.Unmarshal([]byte(reasoning.Raw()), &reasoningStr)
+						if reasoningStr != "" {
+							eventChan <- ProviderEvent{
+								Type:     EventThinkingDelta,
+								Thinking: reasoningStr,
+							}
+						}
+					}
 					if choice.Delta.Content != "" {
 						eventChan <- ProviderEvent{
 							Type:    EventContentDelta,
@@ -352,63 +369,50 @@ func (o *openaiClient) stream(ctx context.Context, messages []message.Message, t
 						currentContent += choice.Delta.Content
 					} else if len(choice.Delta.ToolCalls) > 0 {
 						toolCall := choice.Delta.ToolCalls[0]
-						// Detect tool use start
-						if currentToolCallID == "" {
-							if toolCall.ID != "" {
-								currentToolCallID = toolCall.ID
-								eventChan <- ProviderEvent{
-									Type: EventToolUseStart,
-									ToolCall: &message.ToolCall{
-										ID:       toolCall.ID,
-										Name:     toolCall.Function.Name,
-										Finished: false,
-									},
+						newToolCall := false
+						if len(msgToolCalls)-1 >= int(toolCall.Index) { // tool call exists
+							existingToolCall := msgToolCalls[toolCall.Index]
+							if toolCall.ID != "" && toolCall.ID != existingToolCall.ID {
+								found := false
+								// try to find the tool based on the ID
+								for i, tool := range msgToolCalls {
+									if tool.ID == toolCall.ID {
+										msgToolCalls[i].Function.Arguments += toolCall.Function.Arguments
+										found = true
+									}
 								}
-								currentToolCall = openai.ChatCompletionMessageToolCall{
-									ID:   toolCall.ID,
-									Type: "function",
-									Function: openai.ChatCompletionMessageToolCallFunction{
-										Name:      toolCall.Function.Name,
-										Arguments: toolCall.Function.Arguments,
-									},
+								if !found {
+									newToolCall = true
 								}
-							}
-						} else {
-							// Delta tool use
-							if toolCall.ID == "" || toolCall.ID == currentToolCallID {
-								currentToolCall.Function.Arguments += toolCall.Function.Arguments
 							} else {
-								// Detect new tool use
-								if toolCall.ID != currentToolCallID {
-									msgToolCalls = append(msgToolCalls, currentToolCall)
-									currentToolCallID = toolCall.ID
-									eventChan <- ProviderEvent{
-										Type: EventToolUseStart,
-										ToolCall: &message.ToolCall{
-											ID:       toolCall.ID,
-											Name:     toolCall.Function.Name,
-											Finished: false,
-										},
-									}
-									currentToolCall = openai.ChatCompletionMessageToolCall{
-										ID:   toolCall.ID,
-										Type: "function",
-										Function: openai.ChatCompletionMessageToolCallFunction{
-											Name:      toolCall.Function.Name,
-											Arguments: toolCall.Function.Arguments,
-										},
-									}
-								}
+								msgToolCalls[toolCall.Index].Function.Arguments += toolCall.Function.Arguments
 							}
+						} else {
+							newToolCall = true
 						}
-					}
-					// Kujtim: some models send finish stop even for tool calls
-					if choice.FinishReason == "tool_calls" || (choice.FinishReason == "stop" && currentToolCallID != "") {
-						msgToolCalls = append(msgToolCalls, currentToolCall)
-						if len(acc.Choices) > 0 {
-							acc.Choices[0].Message.ToolCalls = msgToolCalls
+						if newToolCall { // new tool call
+							if toolCall.ID == "" {
+								toolCall.ID = uuid.NewString()
+							}
+							eventChan <- ProviderEvent{
+								Type: EventToolUseStart,
+								ToolCall: &message.ToolCall{
+									ID:       toolCall.ID,
+									Name:     toolCall.Function.Name,
+									Finished: false,
+								},
+							}
+							msgToolCalls = append(msgToolCalls, openai.ChatCompletionMessageToolCall{
+								ID:   toolCall.ID,
+								Type: "function",
+								Function: openai.ChatCompletionMessageToolCallFunction{
+									Name:      toolCall.Function.Name,
+									Arguments: toolCall.Function.Arguments,
+								},
+							})
 						}
 					}
+					acc.Choices[i].Message.ToolCalls = slices.Clone(msgToolCalls)
 				}
 			}
 

internal/tui/components/chat/messages/messages.go 🔗

@@ -274,6 +274,9 @@ func (m *messageCmp) renderThinkingContent() string {
 	if reasoningContent.StartedAt > 0 {
 		duration := m.message.ThinkingDuration()
 		if reasoningContent.FinishedAt > 0 {
+			if duration.String() == "0s" {
+				return ""
+			}
 			m.anim.SetLabel("")
 			opts := core.StatusOpts{
 				Title:       "Thought for",

internal/tui/components/chat/splash/splash.go 🔗

@@ -3,7 +3,6 @@ package splash
 import (
 	"fmt"
 	"os"
-	"slices"
 	"strings"
 	"time"
 
@@ -103,27 +102,6 @@ func New() Splash {
 
 func (s *splashCmp) SetOnboarding(onboarding bool) {
 	s.isOnboarding = onboarding
-	if onboarding {
-		providers, err := config.Providers()
-		if err != nil {
-			return
-		}
-		filteredProviders := []catwalk.Provider{}
-		simpleProviders := []string{
-			"anthropic",
-			"openai",
-			"gemini",
-			"xai",
-			"groq",
-			"openrouter",
-		}
-		for _, p := range providers {
-			if slices.Contains(simpleProviders, string(p.ID)) {
-				filteredProviders = append(filteredProviders, p)
-			}
-		}
-		s.modelList.SetProviders(filteredProviders)
-	}
 }
 
 func (s *splashCmp) SetProjectInit(needsInit bool) {

internal/tui/components/dialogs/models/list.go 🔗

@@ -3,6 +3,7 @@ package models
 import (
 	"fmt"
 	"slices"
+	"strings"
 
 	tea "github.com/charmbracelet/bubbletea/v2"
 	"github.com/charmbracelet/catwalk/pkg/catwalk"
@@ -49,7 +50,15 @@ func (m *ModelListComponent) Init() tea.Cmd {
 	var cmds []tea.Cmd
 	if len(m.providers) == 0 {
 		providers, err := config.Providers()
-		m.providers = providers
+		filteredProviders := []catwalk.Provider{}
+		for _, p := range providers {
+			hasAPIKeyEnv := strings.HasPrefix(p.APIKey, "$")
+			if hasAPIKeyEnv && p.ID != catwalk.InferenceProviderAzure {
+				filteredProviders = append(filteredProviders, p)
+			}
+		}
+
+		m.providers = filteredProviders
 		if err != nil {
 			cmds = append(cmds, util.ReportError(err))
 		}
@@ -242,7 +251,3 @@ func (m *ModelListComponent) GetModelType() int {
 func (m *ModelListComponent) SetInputPlaceholder(placeholder string) {
 	m.list.SetInputPlaceholder(placeholder)
 }
-
-func (m *ModelListComponent) SetProviders(providers []catwalk.Provider) {
-	m.providers = providers
-}

internal/tui/components/dialogs/models/models.go 🔗

@@ -2,7 +2,6 @@ package models
 
 import (
 	"fmt"
-	"slices"
 	"time"
 
 	"github.com/charmbracelet/bubbles/v2/help"
@@ -96,24 +95,6 @@ func NewModelDialogCmp() ModelDialog {
 }
 
 func (m *modelDialogCmp) Init() tea.Cmd {
-	providers, err := config.Providers()
-	if err == nil {
-		filteredProviders := []catwalk.Provider{}
-		simpleProviders := []string{
-			"anthropic",
-			"openai",
-			"gemini",
-			"xai",
-			"groq",
-			"openrouter",
-		}
-		for _, p := range providers {
-			if slices.Contains(simpleProviders, string(p.ID)) {
-				filteredProviders = append(filteredProviders, p)
-			}
-		}
-		m.modelList.SetProviders(filteredProviders)
-	}
 	return tea.Batch(m.modelList.Init(), m.apiKeyInput.Init())
 }