From 6c15eaf0d4ef5b9048975b98974c141894d98336 Mon Sep 17 00:00:00 2001 From: Kujtim Hoxha Date: Tue, 15 Jul 2025 19:01:15 +0200 Subject: [PATCH] feat: implement thinking mode for anthropic models --- internal/llm/agent/agent.go | 8 +- internal/llm/prompt/coder.go | 7 +- internal/llm/provider/anthropic.go | 22 +++- internal/llm/provider/provider.go | 32 +++--- internal/message/content.go | 62 +++++++++- internal/tui/components/chat/chat.go | 2 +- .../tui/components/chat/messages/messages.go | 107 +++++++++++++----- .../components/dialogs/commands/commands.go | 27 +++++ internal/tui/page/chat/chat.go | 31 +++++ 9 files changed, 240 insertions(+), 58 deletions(-) diff --git a/internal/llm/agent/agent.go b/internal/llm/agent/agent.go index ad63f659bdbeb2b00f4827030a8075d6696992bc..adb1975e734903ac4051f69356b20c5544687401 100644 --- a/internal/llm/agent/agent.go +++ b/internal/llm/agent/agent.go @@ -600,12 +600,17 @@ func (a *agent) processEvent(ctx context.Context, sessionID string, assistantMsg switch event.Type { case provider.EventThinkingDelta: - assistantMsg.AppendReasoningContent(event.Content) + assistantMsg.AppendReasoningContent(event.Thinking) + return a.messages.Update(ctx, *assistantMsg) + case provider.EventSignatureDelta: + assistantMsg.AppendReasoningSignature(event.Signature) return a.messages.Update(ctx, *assistantMsg) case provider.EventContentDelta: + assistantMsg.FinishThinking() assistantMsg.AppendContent(event.Content) return a.messages.Update(ctx, *assistantMsg) case provider.EventToolUseStart: + assistantMsg.FinishThinking() slog.Info("Tool call started", "toolCall", event.ToolCall) assistantMsg.AddToolCall(*event.ToolCall) return a.messages.Update(ctx, *assistantMsg) @@ -619,6 +624,7 @@ func (a *agent) processEvent(ctx context.Context, sessionID string, assistantMsg case provider.EventError: return event.Error case provider.EventComplete: + assistantMsg.FinishThinking() assistantMsg.SetToolCalls(event.Response.ToolCalls) assistantMsg.AddFinish(event.Response.FinishReason, "", "") if err := a.messages.Update(ctx, *assistantMsg); err != nil { diff --git a/internal/llm/prompt/coder.go b/internal/llm/prompt/coder.go index dfe2068cd45edf515291b2d759fac4e133912980..f4284faccee052e82e8ed82a820b16af58ccc64c 100644 --- a/internal/llm/prompt/coder.go +++ b/internal/llm/prompt/coder.go @@ -74,7 +74,7 @@ When making changes to files, first understand the file's code conventions. Mimi - Always follow security best practices. Never introduce code that exposes or logs secrets and keys. Never commit secrets or keys to the repository. # Code style -- Do not add comments to the code you write, unless the user asks you to, or the code is complex and requires additional context. +- IMPORTANT: DO NOT ADD ***ANY*** COMMENTS unless asked - If completing the user's task requires writing or modifying files: - Your code and final answer should follow these _CODING GUIDELINES_: @@ -204,7 +204,7 @@ When making changes to files, first understand the file's code conventions. Mimi - Always follow security best practices. Never introduce code that exposes or logs secrets and keys. Never commit secrets or keys to the repository. # Code style -- Do not add comments to the code you write, unless the user asks you to, or the code is complex and requires additional context. +- IMPORTANT: DO NOT ADD ***ANY*** COMMENTS unless asked # Doing tasks The user will primarily request you perform software engineering tasks. This includes solving bugs, adding new functionality, refactoring code, explaining code, and more. For these tasks the following steps are recommended: @@ -249,6 +249,9 @@ When you spend time searching for commands to typecheck, lint, build, or test, y - **Explaining Changes:** After completing a code modification or file operation *do not* provide summaries unless asked. - **Do Not revert changes:** Do not revert changes to the codebase unless asked to do so by the user. Only revert changes made by you if they have resulted in an error or if the user has explicitly asked you to revert the changes. +# Code style +- IMPORTANT: DO NOT ADD ***ANY*** COMMENTS unless asked + # Primary Workflows ## Software Engineering Tasks diff --git a/internal/llm/provider/anthropic.go b/internal/llm/provider/anthropic.go index 8e8e3237f55d58fa995d15baf60400a485ec95a2..ace2c22e2b03b4fbc80b1aeedaf79ced8a0eff48 100644 --- a/internal/llm/provider/anthropic.go +++ b/internal/llm/provider/anthropic.go @@ -72,6 +72,13 @@ func (a *anthropicClient) convertMessages(messages []message.Message) (anthropic case message.Assistant: blocks := []anthropic.ContentBlockParamUnion{} + + // Add thinking blocks first if present (required when thinking is enabled with tool use) + if reasoningContent := msg.ReasoningContent(); reasoningContent.Thinking != "" { + thinkingBlock := anthropic.NewThinkingBlock(reasoningContent.Signature, reasoningContent.Thinking) + blocks = append(blocks, thinkingBlock) + } + if msg.Content().String() != "" { content := anthropic.NewTextBlock(msg.Content().String()) if cache && !a.providerOptions.disableCache { @@ -159,16 +166,14 @@ func (a *anthropicClient) preparedMessages(messages []anthropic.MessageParam, to } temperature := anthropic.Float(0) - if a.Model().CanReason && modelConfig.Think { - thinkingParam = anthropic.ThinkingConfigParamOfEnabled(int64(float64(a.providerOptions.maxTokens) * 0.8)) - temperature = anthropic.Float(1) - } - maxTokens := model.DefaultMaxTokens if modelConfig.MaxTokens > 0 { maxTokens = modelConfig.MaxTokens } - + if a.Model().CanReason && modelConfig.Think { + thinkingParam = anthropic.ThinkingConfigParamOfEnabled(int64(float64(maxTokens) * 0.8)) + temperature = anthropic.Float(1) + } // Override max tokens if set in provider options if a.providerOptions.maxTokens > 0 { maxTokens = a.providerOptions.maxTokens @@ -300,6 +305,11 @@ func (a *anthropicClient) stream(ctx context.Context, messages []message.Message Type: EventThinkingDelta, Thinking: event.Delta.Thinking, } + } else if event.Delta.Type == "signature_delta" && event.Delta.Signature != "" { + eventChan <- ProviderEvent{ + Type: EventSignatureDelta, + Signature: event.Delta.Signature, + } } else if event.Delta.Type == "text_delta" && event.Delta.Text != "" { eventChan <- ProviderEvent{ Type: EventContentDelta, diff --git a/internal/llm/provider/provider.go b/internal/llm/provider/provider.go index 193affc2a2b5a6dcdecee596a839882c40f70a42..12dd09392942b0c00e7caa975deefffa994b47b8 100644 --- a/internal/llm/provider/provider.go +++ b/internal/llm/provider/provider.go @@ -15,16 +15,17 @@ type EventType string const maxRetries = 8 const ( - EventContentStart EventType = "content_start" - EventToolUseStart EventType = "tool_use_start" - EventToolUseDelta EventType = "tool_use_delta" - EventToolUseStop EventType = "tool_use_stop" - EventContentDelta EventType = "content_delta" - EventThinkingDelta EventType = "thinking_delta" - EventContentStop EventType = "content_stop" - EventComplete EventType = "complete" - EventError EventType = "error" - EventWarning EventType = "warning" + EventContentStart EventType = "content_start" + EventToolUseStart EventType = "tool_use_start" + EventToolUseDelta EventType = "tool_use_delta" + EventToolUseStop EventType = "tool_use_stop" + EventContentDelta EventType = "content_delta" + EventThinkingDelta EventType = "thinking_delta" + EventSignatureDelta EventType = "signature_delta" + EventContentStop EventType = "content_stop" + EventComplete EventType = "complete" + EventError EventType = "error" + EventWarning EventType = "warning" ) type TokenUsage struct { @@ -44,11 +45,12 @@ type ProviderResponse struct { type ProviderEvent struct { Type EventType - Content string - Thinking string - Response *ProviderResponse - ToolCall *message.ToolCall - Error error + Content string + Thinking string + Signature string + Response *ProviderResponse + ToolCall *message.ToolCall + Error error } type Provider interface { SendMessages(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error) diff --git a/internal/message/content.go b/internal/message/content.go index b8d2c1aa370559977f4c8eb80803ab5fbfe83cf9..bdaf1577e34a4667bdb5c8cd2683865ec5cd08ac 100644 --- a/internal/message/content.go +++ b/internal/message/content.go @@ -36,7 +36,10 @@ type ContentPart interface { } type ReasoningContent struct { - Thinking string `json:"thinking"` + Thinking string `json:"thinking"` + Signature string `json:"signature"` + StartedAt int64 `json:"started_at,omitempty"` + FinishedAt int64 `json:"finished_at,omitempty"` } func (tc ReasoningContent) String() string { @@ -230,15 +233,68 @@ func (m *Message) AppendReasoningContent(delta string) { found := false for i, part := range m.Parts { if c, ok := part.(ReasoningContent); ok { - m.Parts[i] = ReasoningContent{Thinking: c.Thinking + delta} + m.Parts[i] = ReasoningContent{ + Thinking: c.Thinking + delta, + Signature: c.Signature, + StartedAt: c.StartedAt, + FinishedAt: c.FinishedAt, + } found = true } } if !found { - m.Parts = append(m.Parts, ReasoningContent{Thinking: delta}) + m.Parts = append(m.Parts, ReasoningContent{ + Thinking: delta, + StartedAt: time.Now().Unix(), + }) + } +} + +func (m *Message) AppendReasoningSignature(signature string) { + for i, part := range m.Parts { + if c, ok := part.(ReasoningContent); ok { + m.Parts[i] = ReasoningContent{ + Thinking: c.Thinking, + Signature: c.Signature + signature, + StartedAt: c.StartedAt, + FinishedAt: c.FinishedAt, + } + return + } + } + m.Parts = append(m.Parts, ReasoningContent{Signature: signature}) +} + +func (m *Message) FinishThinking() { + for i, part := range m.Parts { + if c, ok := part.(ReasoningContent); ok { + if c.FinishedAt == 0 { + m.Parts[i] = ReasoningContent{ + Thinking: c.Thinking, + Signature: c.Signature, + StartedAt: c.StartedAt, + FinishedAt: time.Now().Unix(), + } + } + return + } } } +func (m *Message) ThinkingDuration() time.Duration { + reasoning := m.ReasoningContent() + if reasoning.StartedAt == 0 { + return 0 + } + + endTime := reasoning.FinishedAt + if endTime == 0 { + endTime = time.Now().Unix() + } + + return time.Duration(endTime-reasoning.StartedAt) * time.Second +} + func (m *Message) FinishToolCall(toolCallID string) { for i, part := range m.Parts { if c, ok := part.(ToolCall); ok { diff --git a/internal/tui/components/chat/chat.go b/internal/tui/components/chat/chat.go index 71f6e1e66ed7d6d1ad80486c1017d02af14b11f4..8601182e2e46bad8ee90aac25ff763fa6bd5f752 100644 --- a/internal/tui/components/chat/chat.go +++ b/internal/tui/components/chat/chat.go @@ -329,7 +329,7 @@ func (m *messageListCmp) updateAssistantMessageContent(msg message.Message, assi // shouldShowAssistantMessage determines if an assistant message should be displayed. func (m *messageListCmp) shouldShowAssistantMessage(msg message.Message) bool { - return len(msg.ToolCalls()) == 0 || msg.Content().Text != "" || msg.IsThinking() + return len(msg.ToolCalls()) == 0 || msg.Content().Text != "" || msg.ReasoningContent().Thinking != "" || msg.IsThinking() } // updateToolCalls handles updates to tool calls, updating existing ones and adding new ones. diff --git a/internal/tui/components/chat/messages/messages.go b/internal/tui/components/chat/messages/messages.go index bfb8af47b6bd13eb2e1e9fb844b1935a6fccbd4d..b2d34966fe8a4d035a1fe8cda7c2d2a3d459293b 100644 --- a/internal/tui/components/chat/messages/messages.go +++ b/internal/tui/components/chat/messages/messages.go @@ -6,6 +6,7 @@ import ( "strings" "time" + "github.com/charmbracelet/bubbles/v2/viewport" tea "github.com/charmbracelet/bubbletea/v2" "github.com/charmbracelet/lipgloss/v2" "github.com/charmbracelet/x/ansi" @@ -42,6 +43,9 @@ type messageCmp struct { message message.Message // The underlying message content spinning bool // Whether to show loading animation anim util.Model // Animation component for loading states + + // Thinking viewport for displaying reasoning content + thinkingViewport viewport.Model } var focusedMessageBorder = lipgloss.Border{ @@ -51,6 +55,11 @@ var focusedMessageBorder = lipgloss.Border{ // NewMessageCmp creates a new message component with the given message and options func NewMessageCmp(msg message.Message) MessageCmp { t := styles.CurrentTheme() + + thinkingViewport := viewport.New() + thinkingViewport.SetHeight(1) + thinkingViewport.KeyMap = viewport.KeyMap{} + m := &messageCmp{ message: msg, anim: anim.New(anim.Settings{ @@ -59,6 +68,7 @@ func NewMessageCmp(msg message.Message) MessageCmp { GradColorB: t.Secondary, CycleColors: true, }), + thinkingViewport: thinkingViewport, } return m } @@ -139,8 +149,38 @@ func (msg *messageCmp) style() lipgloss.Style { // renderAssistantMessage renders assistant messages with optional footer information. // Shows model name, response time, and finish reason when the message is complete. func (m *messageCmp) renderAssistantMessage() string { - parts := []string{ - m.markdownContent(), + t := styles.CurrentTheme() + parts := []string{} + content := m.message.Content().String() + thinking := m.message.IsThinking() + finished := m.message.IsFinished() + finishedData := m.message.FinishPart() + thinkingContent := "" + + if thinking || m.message.ReasoningContent().Thinking != "" { + thinkingContent = m.renderThinkingContent() + } else if finished && content == "" && finishedData.Reason == message.FinishReasonEndTurn { + content = "" + } else if finished && content == "" && finishedData.Reason == message.FinishReasonCanceled { + content = "*Canceled*" + } else if finished && content == "" && finishedData.Reason == message.FinishReasonError { + errTag := t.S().Base.Padding(0, 1).Background(t.Red).Foreground(t.White).Render("ERROR") + truncated := ansi.Truncate(finishedData.Message, m.textWidth()-2-lipgloss.Width(errTag), "...") + title := fmt.Sprintf("%s %s", errTag, t.S().Base.Foreground(t.FgHalfMuted).Render(truncated)) + details := t.S().Base.Foreground(t.FgSubtle).Width(m.textWidth() - 2).Render(finishedData.Details) + // Handle error messages differently + return fmt.Sprintf("%s\n\n%s", title, details) + } + + if thinkingContent != "" { + parts = append(parts, thinkingContent) + } + + if content != "" { + if thinkingContent != "" { + parts = append(parts, "") + } + parts = append(parts, m.toMarkdown(content)) } joined := lipgloss.JoinVertical(lipgloss.Left, parts...) @@ -152,7 +192,7 @@ func (m *messageCmp) renderAssistantMessage() string { func (m *messageCmp) renderUserMessage() string { t := styles.CurrentTheme() parts := []string{ - m.markdownContent(), + m.toMarkdown(m.message.Content().String()), } attachmentStyles := t.S().Text. MarginLeft(1). @@ -182,34 +222,41 @@ func (m *messageCmp) toMarkdown(content string) string { return strings.TrimSuffix(rendered, "\n") } -// markdownContent processes the message content and handles special states. -// Returns appropriate content for thinking, finished, and error states. -func (m *messageCmp) markdownContent() string { +func (m *messageCmp) renderThinkingContent() string { t := styles.CurrentTheme() - content := m.message.Content().String() - if m.message.Role == message.Assistant { - thinking := m.message.IsThinking() - finished := m.message.IsFinished() - finishedData := m.message.FinishPart() - if thinking { - // Handle the thinking state - // TODO: maybe add the thinking content if available later. - content = fmt.Sprintf("**%s %s**", styles.LoadingIcon, "Thinking...") - } else if finished && content == "" && finishedData.Reason == message.FinishReasonEndTurn { - // Sometimes the LLMs respond with no content when they think the previous tool result - // provides the requested question - content = "" - } else if finished && content == "" && finishedData.Reason == message.FinishReasonCanceled { - content = "*Canceled*" - } else if finished && content == "" && finishedData.Reason == message.FinishReasonError { - errTag := t.S().Base.Padding(0, 1).Background(t.Red).Foreground(t.White).Render("ERROR") - truncated := ansi.Truncate(finishedData.Message, m.textWidth()-2-lipgloss.Width(errTag), "...") - title := fmt.Sprintf("%s %s", errTag, t.S().Base.Foreground(t.FgHalfMuted).Render(truncated)) - details := t.S().Base.Foreground(t.FgSubtle).Width(m.textWidth() - 2).Render(finishedData.Details) - return fmt.Sprintf("%s\n\n%s", title, details) + reasoningContent := m.message.ReasoningContent() + if reasoningContent.Thinking == "" { + return "" + } + lines := strings.Split(reasoningContent.Thinking, "\n") + var content strings.Builder + lineStyle := t.S().Muted.Background(t.BgBaseLighter) + for _, line := range lines { + if line == "" { + continue + } + content.WriteString(lineStyle.Width(m.textWidth()-2).Render(line) + "\n") + } + fullContent := content.String() + height := util.Clamp(lipgloss.Height(fullContent), 1, 10) + m.thinkingViewport.SetHeight(height) + m.thinkingViewport.SetWidth(m.textWidth()) + m.thinkingViewport.SetContent(fullContent) + m.thinkingViewport.GotoBottom() + var footer string + if reasoningContent.StartedAt > 0 { + duration := m.message.ThinkingDuration() + opts := core.StatusOpts{ + Title: "Thinking...", + Description: duration.String(), + } + if reasoningContent.FinishedAt > 0 { + opts.NoIcon = true + opts.Title = "Thought for" } + footer = t.S().Base.PaddingLeft(1).Render(core.Status(opts, m.textWidth()-1)) } - return m.toMarkdown(content) + return lineStyle.Width(m.textWidth()).Padding(0, 1).Render(m.thinkingViewport.View()) + "\n\n" + footer } // shouldSpin determines whether the message should show a loading animation. @@ -257,8 +304,8 @@ func (m *messageCmp) GetSize() (int, int) { // SetSize updates the width of the message component for text wrapping func (m *messageCmp) SetSize(width int, height int) tea.Cmd { - // For better readability, we limit the width to a maximum of 120 characters - m.width = min(width, 120) + m.width = util.Clamp(width, 1, 120) + m.thinkingViewport.SetWidth(m.width - 4) return nil } diff --git a/internal/tui/components/dialogs/commands/commands.go b/internal/tui/components/dialogs/commands/commands.go index 10cdbbd539f06836550b7da6a857d35db3becd74..a14138ff51ecf8164cf0fc595c758b0247aa3277 100644 --- a/internal/tui/components/dialogs/commands/commands.go +++ b/internal/tui/components/dialogs/commands/commands.go @@ -6,6 +6,8 @@ import ( tea "github.com/charmbracelet/bubbletea/v2" "github.com/charmbracelet/lipgloss/v2" + "github.com/charmbracelet/crush/internal/config" + "github.com/charmbracelet/crush/internal/fur/provider" "github.com/charmbracelet/crush/internal/llm/prompt" "github.com/charmbracelet/crush/internal/tui/components/chat" "github.com/charmbracelet/crush/internal/tui/components/completions" @@ -58,6 +60,7 @@ type ( SwitchSessionsMsg struct{} SwitchModelMsg struct{} ToggleCompactModeMsg struct{} + ToggleThinkingMsg struct{} CompactMsg struct { SessionID string } @@ -260,6 +263,30 @@ func (c *commandDialogCmp) defaultCommands() []Command { }, }) } + + // Only show thinking toggle for Anthropic models that can reason + cfg := config.Get() + if agentCfg, ok := cfg.Agents["coder"]; ok { + providerCfg := cfg.GetProviderForModel(agentCfg.Model) + model := cfg.GetModelByType(agentCfg.Model) + if providerCfg != nil && model != nil && + providerCfg.Type == provider.TypeAnthropic && model.CanReason { + selectedModel := cfg.Models[agentCfg.Model] + status := "Enable" + if selectedModel.Think { + status = "Disable" + } + commands = append(commands, Command{ + ID: "toggle_thinking", + Title: status + " Thinking Mode", + Description: "Toggle model thinking for reasoning-capable models", + Handler: func(cmd Command) tea.Cmd { + return util.CmdHandler(ToggleThinkingMsg{}) + }, + }) + } + } + // Only show toggle compact mode command if window width is larger than compact breakpoint (90) if c.wWidth > 120 && c.sessionID != "" { commands = append(commands, Command{ diff --git a/internal/tui/page/chat/chat.go b/internal/tui/page/chat/chat.go index 07fac7133a6003ad951962c8dd5ad55c52bcb67f..cd6605c902d4e5d38196083e4f8556461a0aeca1 100644 --- a/internal/tui/page/chat/chat.go +++ b/internal/tui/page/chat/chat.go @@ -183,6 +183,8 @@ func (p *chatPage) Update(msg tea.Msg) (tea.Model, tea.Cmd) { cmd = p.updateCompactConfig(false) } return p, tea.Batch(p.SetSize(p.width, p.height), cmd) + case commands.ToggleThinkingMsg: + return p, p.toggleThinking() case pubsub.Event[session.Session]: u, cmd := p.header.Update(msg) p.header = u.(header.Header) @@ -409,6 +411,35 @@ func (p *chatPage) updateCompactConfig(compact bool) tea.Cmd { } } +func (p *chatPage) toggleThinking() tea.Cmd { + return func() tea.Msg { + cfg := config.Get() + agentCfg := cfg.Agents["coder"] + currentModel := cfg.Models[agentCfg.Model] + + // Toggle the thinking mode + currentModel.Think = !currentModel.Think + cfg.Models[agentCfg.Model] = currentModel + + // Update the agent with the new configuration + if err := p.app.UpdateAgentModel(); err != nil { + return util.InfoMsg{ + Type: util.InfoTypeError, + Msg: "Failed to update thinking mode: " + err.Error(), + } + } + + status := "disabled" + if currentModel.Think { + status = "enabled" + } + return util.InfoMsg{ + Type: util.InfoTypeInfo, + Msg: "Thinking mode " + status, + } + } +} + func (p *chatPage) setCompactMode(compact bool) { if p.compact == compact { return