@@ -2,10 +2,12 @@ package provider
import (
"context"
+ "encoding/json"
"errors"
"fmt"
"io"
"log/slog"
+ "slices"
"strings"
"time"
@@ -14,6 +16,7 @@ import (
"github.com/charmbracelet/crush/internal/llm/tools"
"github.com/charmbracelet/crush/internal/log"
"github.com/charmbracelet/crush/internal/message"
+ "github.com/google/uuid"
"github.com/openai/openai-go"
"github.com/openai/openai-go/option"
"github.com/openai/openai-go/packages/param"
@@ -70,8 +73,9 @@ func (o *openaiClient) convertMessages(messages []message.Message) (openaiMessag
systemMessage = o.providerOptions.systemPromptPrefix + "\n" + systemMessage
}
- systemTextBlock := openai.ChatCompletionContentPartTextParam{Text: systemMessage}
+ system := openai.SystemMessage(systemMessage)
if isAnthropicModel && !o.providerOptions.disableCache {
+ systemTextBlock := openai.ChatCompletionContentPartTextParam{Text: systemMessage}
systemTextBlock.SetExtraFields(
map[string]any{
"cache_control": map[string]string{
@@ -79,10 +83,10 @@ func (o *openaiClient) convertMessages(messages []message.Message) (openaiMessag
},
},
)
+ var content []openai.ChatCompletionContentPartTextParam
+ content = append(content, systemTextBlock)
+ system = openai.SystemMessage(content)
}
- var content []openai.ChatCompletionContentPartTextParam
- content = append(content, systemTextBlock)
- system := openai.SystemMessage(content)
openaiMessages = append(openaiMessages, system)
for i, msg := range messages {
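
For reference, a minimal sketch of the wire shape the `cache_control` extra field produces for the system block above. The field names are assumed from Anthropic's documented prompt-caching format, not taken from openai-go itself:

```go
package main

import (
	"encoding/json"
	"fmt"
)

// textBlock mirrors the assumed wire shape of the cached system block;
// field names follow Anthropic's prompt-caching docs, not openai-go.
type textBlock struct {
	Type         string            `json:"type"`
	Text         string            `json:"text"`
	CacheControl map[string]string `json:"cache_control,omitempty"`
}

func main() {
	block := textBlock{
		Type:         "text",
		Text:         "You are a helpful coding assistant.",
		CacheControl: map[string]string{"type": "ephemeral"},
	}
	out, _ := json.MarshalIndent(block, "", "  ")
	fmt.Println(string(out)) // the block that SetExtraFields injects above
}
```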
@@ -93,9 +97,12 @@ func (o *openaiClient) convertMessages(messages []message.Message) (openaiMessag
switch msg.Role {
case message.User:
var content []openai.ChatCompletionContentPartUnionParam
+
textBlock := openai.ChatCompletionContentPartTextParam{Text: msg.Content().String()}
content = append(content, openai.ChatCompletionContentPartUnionParam{OfText: &textBlock})
+ hasBinaryContent := false
for _, binaryContent := range msg.BinaryContent() {
+ hasBinaryContent = true
imageURL := openai.ChatCompletionContentPartImageImageURLParam{URL: binaryContent.String(catwalk.InferenceProviderOpenAI)}
imageBlock := openai.ChatCompletionContentPartImageParam{ImageURL: imageURL}
@@ -108,8 +115,11 @@ func (o *openaiClient) convertMessages(messages []message.Message) (openaiMessag
},
})
}
-
- openaiMessages = append(openaiMessages, openai.UserMessage(content))
+ if hasBinaryContent || (isAnthropicModel && !o.providerOptions.disableCache) {
+ openaiMessages = append(openaiMessages, openai.UserMessage(content))
+ } else {
+ openaiMessages = append(openaiMessages, openai.UserMessage(msg.Content().String()))
+ }
case message.Assistant:
assistantMsg := openai.ChatCompletionAssistantMessageParam{
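
The user branch above only switches to the structured parts array when it is actually needed; plain string content stays maximally compatible with strict OpenAI-style providers. A compact sketch of that rule, using a hypothetical helper that is not part of the diff:

```go
package main

import "fmt"

// userMessageNeedsParts mirrors the branch above (hypothetical helper,
// not in the diff): a parts array is only required for image attachments
// or Anthropic cache_control blocks; otherwise a plain string suffices.
func userMessageNeedsParts(hasBinaryContent, isAnthropicModel, disableCache bool) bool {
	return hasBinaryContent || (isAnthropicModel && !disableCache)
}

func main() {
	fmt.Println(userMessageNeedsParts(false, false, false)) // false: plain string
	fmt.Println(userMessageNeedsParts(true, false, false))  // true: image attached
	fmt.Println(userMessageNeedsParts(false, true, false))  // true: cacheable Anthropic
}
```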
@@ -134,13 +144,15 @@ func (o *openaiClient) convertMessages(messages []message.Message) (openaiMessag
},
},
}
+ if !isAnthropicModel {
+ assistantMsg.Content = openai.ChatCompletionAssistantMessageParamContentUnion{
+ OfString: param.NewOpt(msg.Content().String()),
+ }
+ }
}
if len(msg.ToolCalls()) > 0 {
hasContent = true
- assistantMsg.Content = openai.ChatCompletionAssistantMessageParamContentUnion{
- OfString: param.NewOpt(msg.Content().String()),
- }
assistantMsg.ToolCalls = make([]openai.ChatCompletionMessageToolCallParam, len(msg.ToolCalls()))
for i, call := range msg.ToolCalls() {
assistantMsg.ToolCalls[i] = openai.ChatCompletionMessageToolCallParam{
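
A rough sketch of the two assistant-message shapes the branch above now produces, using plain stand-in structs rather than the openai-go union types; omitting the duplicated string content for Anthropic-compatible endpoints is the point of the change:

```go
package main

import (
	"encoding/json"
	"fmt"
)

// assistantWire is an illustrative stand-in for the serialized assistant
// message, not an openai-go type.
type assistantWire struct {
	Role      string `json:"role"`
	Content   string `json:"content,omitempty"`
	ToolCalls []any  `json:"tool_calls,omitempty"`
}

func main() {
	call := map[string]any{"id": "call_1", "type": "function"}
	// Non-Anthropic providers keep the string content next to the tool
	// calls; Anthropic-compatible endpoints now get the tool calls alone.
	for _, msg := range []assistantWire{
		{Role: "assistant", Content: "Let me check.", ToolCalls: []any{call}},
		{Role: "assistant", ToolCalls: []any{call}},
	} {
		out, _ := json.Marshal(msg)
		fmt.Println(string(out))
	}
}
```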
@@ -329,21 +341,26 @@ func (o *openaiClient) stream(ctx context.Context, messages []message.Message, t
acc := openai.ChatCompletionAccumulator{}
currentContent := ""
toolCalls := make([]message.ToolCall, 0)
-
- var currentToolCallID string
- var currentToolCall openai.ChatCompletionMessageToolCall
var msgToolCalls []openai.ChatCompletionMessageToolCall
- currentToolIndex := 0
for openaiStream.Next() {
chunk := openaiStream.Current()
// Kujtim: this is an issue with OpenRouter Qwen; it's sending -1 for the tool index
if len(chunk.Choices) > 0 && len(chunk.Choices[0].Delta.ToolCalls) > 0 && chunk.Choices[0].Delta.ToolCalls[0].Index == -1 {
- chunk.Choices[0].Delta.ToolCalls[0].Index = int64(currentToolIndex)
- currentToolIndex++
+ chunk.Choices[0].Delta.ToolCalls[0].Index = 0
}
acc.AddChunk(chunk)
- // This fixes multiple tool calls for some providers
- for _, choice := range chunk.Choices {
+ for i, choice := range chunk.Choices {
+ reasoning, ok := choice.Delta.JSON.ExtraFields["reasoning"]
+ if ok && reasoning.Raw() != "" {
+ reasoningStr := ""
+ json.Unmarshal([]byte(reasoning.Raw()), &reasoningStr)
+ if reasoningStr != "" {
+ eventChan <- ProviderEvent{
+ Type: EventThinkingDelta,
+ Thinking: reasoningStr,
+ }
+ }
+ }
if choice.Delta.Content != "" {
eventChan <- ProviderEvent{
Type: EventContentDelta,
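
Some OpenAI-compatible providers (OpenRouter among them) attach a non-standard `reasoning` string to each delta, so the raw extra field arrives as a quoted JSON string that must be unquoted before it is emitted as a thinking event. A minimal sketch of that decode, with an illustrative payload:

```go
package main

import (
	"encoding/json"
	"fmt"
)

func main() {
	// Illustrative raw extra field as it appears in the delta JSON: a
	// quoted JSON string, so unmarshaling into a plain string unquotes it.
	raw := `"Checking the failing test first."`
	var reasoning string
	if err := json.Unmarshal([]byte(raw), &reasoning); err == nil && reasoning != "" {
		fmt.Println("thinking delta:", reasoning)
	}
}
```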
@@ -352,63 +369,50 @@ func (o *openaiClient) stream(ctx context.Context, messages []message.Message, t
currentContent += choice.Delta.Content
} else if len(choice.Delta.ToolCalls) > 0 {
toolCall := choice.Delta.ToolCalls[0]
- // Detect tool use start
- if currentToolCallID == "" {
- if toolCall.ID != "" {
- currentToolCallID = toolCall.ID
- eventChan <- ProviderEvent{
- Type: EventToolUseStart,
- ToolCall: &message.ToolCall{
- ID: toolCall.ID,
- Name: toolCall.Function.Name,
- Finished: false,
- },
+ newToolCall := false
+ if len(msgToolCalls)-1 >= int(toolCall.Index) { // tool call exists
+ existingToolCall := msgToolCalls[toolCall.Index]
+ if toolCall.ID != "" && toolCall.ID != existingToolCall.ID {
+ found := false
+ // try to find the tool based on the ID
+ for i, tool := range msgToolCalls {
+ if tool.ID == toolCall.ID {
+ msgToolCalls[i].Function.Arguments += toolCall.Function.Arguments
+ found = true
+ }
+ }
- currentToolCall = openai.ChatCompletionMessageToolCall{
- ID: toolCall.ID,
- Type: "function",
- Function: openai.ChatCompletionMessageToolCallFunction{
- Name: toolCall.Function.Name,
- Arguments: toolCall.Function.Arguments,
- },
+ if !found {
+ newToolCall = true
+ }
- }
- } else {
- // Delta tool use
- if toolCall.ID == "" || toolCall.ID == currentToolCallID {
- currentToolCall.Function.Arguments += toolCall.Function.Arguments
+ } else {
- // Detect new tool use
- if toolCall.ID != currentToolCallID {
- msgToolCalls = append(msgToolCalls, currentToolCall)
- currentToolCallID = toolCall.ID
- eventChan <- ProviderEvent{
- Type: EventToolUseStart,
- ToolCall: &message.ToolCall{
- ID: toolCall.ID,
- Name: toolCall.Function.Name,
- Finished: false,
- },
- }
- currentToolCall = openai.ChatCompletionMessageToolCall{
- ID: toolCall.ID,
- Type: "function",
- Function: openai.ChatCompletionMessageToolCallFunction{
- Name: toolCall.Function.Name,
- Arguments: toolCall.Function.Arguments,
- },
- }
- }
+ msgToolCalls[toolCall.Index].Function.Arguments += toolCall.Function.Arguments
+ }
+ } else {
+ newToolCall = true
+ }
- }
- // Kujtim: some models send finish stop even for tool calls
- if choice.FinishReason == "tool_calls" || (choice.FinishReason == "stop" && currentToolCallID != "") {
- msgToolCalls = append(msgToolCalls, currentToolCall)
- if len(acc.Choices) > 0 {
- acc.Choices[0].Message.ToolCalls = msgToolCalls
+ if newToolCall { // new tool call
+ if toolCall.ID == "" {
+ toolCall.ID = uuid.NewString()
+ }
+ eventChan <- ProviderEvent{
+ Type: EventToolUseStart,
+ ToolCall: &message.ToolCall{
+ ID: toolCall.ID,
+ Name: toolCall.Function.Name,
+ Finished: false,
+ },
+ }
+ msgToolCalls = append(msgToolCalls, openai.ChatCompletionMessageToolCall{
+ ID: toolCall.ID,
+ Type: "function",
+ Function: openai.ChatCompletionMessageToolCallFunction{
+ Name: toolCall.Function.Name,
+ Arguments: toolCall.Function.Arguments,
+ },
+ })
+ }
}
+ acc.Choices[i].Message.ToolCalls = slices.Clone(msgToolCalls)
}
}
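
The rewritten loop above replaces single-current-call tracking with an index/ID merge: a delta extends the call at its index, is re-routed by ID if it names a different known call, and otherwise starts a new call, minting a UUID when the provider omits one. A standalone sketch of the same rule with simplified types:

```go
package main

import (
	"fmt"

	"github.com/google/uuid"
)

// delta is a simplified stand-in for a streamed tool-call fragment.
type delta struct {
	Index int
	ID    string
	Name  string
	Args  string
}

type call struct {
	ID, Name, Args string
}

// merge applies the same rule as the stream loop above: extend by index,
// re-route by ID if it names another known call, otherwise append a new
// call, minting an ID when the provider omitted one.
func merge(calls []call, d delta) []call {
	if d.Index < len(calls) {
		if d.ID != "" && d.ID != calls[d.Index].ID {
			for i := range calls {
				if calls[i].ID == d.ID {
					calls[i].Args += d.Args
					return calls
				}
			}
		} else {
			calls[d.Index].Args += d.Args
			return calls
		}
	}
	if d.ID == "" {
		d.ID = uuid.NewString()
	}
	return append(calls, call{ID: d.ID, Name: d.Name, Args: d.Args})
}

func main() {
	var calls []call
	for _, d := range []delta{
		{Index: 0, ID: "call_a", Name: "ls", Args: `{"path":`},
		{Index: 0, Args: `"/tmp"}`},           // extends call_a by index
		{Index: 0, ID: "call_b", Name: "cat"}, // new ID at an occupied index
	} {
		calls = merge(calls, d)
	}
	fmt.Printf("%+v\n", calls)
}
```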
@@ -3,7 +3,6 @@ package splash
import (
"fmt"
"os"
- "slices"
"strings"
"time"
@@ -103,27 +102,6 @@ func New() Splash {
func (s *splashCmp) SetOnboarding(onboarding bool) {
s.isOnboarding = onboarding
- if onboarding {
- providers, err := config.Providers()
- if err != nil {
- return
- }
- filteredProviders := []catwalk.Provider{}
- simpleProviders := []string{
- "anthropic",
- "openai",
- "gemini",
- "xai",
- "groq",
- "openrouter",
- }
- for _, p := range providers {
- if slices.Contains(simpleProviders, string(p.ID)) {
- filteredProviders = append(filteredProviders, p)
- }
- }
- s.modelList.SetProviders(filteredProviders)
- }
}
func (s *splashCmp) SetProjectInit(needsInit bool) {
@@ -3,6 +3,7 @@ package models
import (
"fmt"
"slices"
+ "strings"
tea "github.com/charmbracelet/bubbletea/v2"
"github.com/charmbracelet/catwalk/pkg/catwalk"
@@ -49,7 +50,15 @@ func (m *ModelListComponent) Init() tea.Cmd {
var cmds []tea.Cmd
if len(m.providers) == 0 {
providers, err := config.Providers()
- m.providers = providers
+ filteredProviders := []catwalk.Provider{}
+ for _, p := range providers {
+ hasAPIKeyEnv := strings.HasPrefix(p.APIKey, "$")
+ if hasAPIKeyEnv && p.ID != catwalk.InferenceProviderAzure {
+ filteredProviders = append(filteredProviders, p)
+ }
+ }
+
+ m.providers = filteredProviders
if err != nil {
cmds = append(cmds, util.ReportError(err))
}
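
This filter replaces the hard-coded provider whitelists removed elsewhere in this diff: catwalk provider configs reference keys as environment placeholders such as `$OPENAI_API_KEY`, so any provider whose APIKey starts with `$` can be configured by exporting one variable, while Azure stays hidden, presumably because it needs more than a key (endpoint, deployment) to work. A tiny self-contained sketch of the rule, with illustrative values:

```go
package main

import (
	"fmt"
	"strings"

	"github.com/charmbracelet/catwalk/pkg/catwalk"
)

func main() {
	// Illustrative values; the real list comes from config.Providers().
	providers := []catwalk.Provider{
		{ID: catwalk.InferenceProviderOpenAI, APIKey: "$OPENAI_API_KEY"},
		{ID: catwalk.InferenceProviderAzure, APIKey: "$AZURE_OPENAI_API_KEY"},
	}
	var simple []catwalk.Provider
	for _, p := range providers {
		// Keep providers whose key is an env placeholder; skip Azure.
		if strings.HasPrefix(p.APIKey, "$") && p.ID != catwalk.InferenceProviderAzure {
			simple = append(simple, p)
		}
	}
	fmt.Println(simple) // only the OpenAI entry survives
}
```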
@@ -242,7 +251,3 @@ func (m *ModelListComponent) GetModelType() int {
func (m *ModelListComponent) SetInputPlaceholder(placeholder string) {
m.list.SetInputPlaceholder(placeholder)
}
-
-func (m *ModelListComponent) SetProviders(providers []catwalk.Provider) {
- m.providers = providers
-}
@@ -2,7 +2,6 @@ package models
import (
"fmt"
- "slices"
"time"
"github.com/charmbracelet/bubbles/v2/help"
@@ -96,24 +95,6 @@ func NewModelDialogCmp() ModelDialog {
}
func (m *modelDialogCmp) Init() tea.Cmd {
- providers, err := config.Providers()
- if err == nil {
- filteredProviders := []catwalk.Provider{}
- simpleProviders := []string{
- "anthropic",
- "openai",
- "gemini",
- "xai",
- "groq",
- "openrouter",
- }
- for _, p := range providers {
- if slices.Contains(simpleProviders, string(p.ID)) {
- filteredProviders = append(filteredProviders, p)
- }
- }
- m.modelList.SetProviders(filteredProviders)
- }
return tea.Batch(m.modelList.Init(), m.apiKeyInput.Init())
}