initial tool call stream

Kujtim Hoxha created

Change summary

internal/llm/agent/agent.go             |  22 +++++
internal/llm/provider/anthropic.go      |  60 ++++++++++---
internal/llm/provider/openai.go         |   9 +
internal/llm/provider/provider.go       |   7 +
internal/message/content.go             |  43 +++++++++
internal/message/message.go             |   2 
internal/pubsub/broker.go               |   7 -
internal/tui/components/chat/list.go    | 117 ++++++---------------------
internal/tui/components/chat/message.go |  92 +++++++++++++++++---
internal/tui/layout/split.go            |  28 ++++++
internal/tui/page/chat.go               |  10 ++
11 files changed, 261 insertions(+), 136 deletions(-)

Detailed changes

internal/llm/agent/agent.go 🔗

@@ -380,6 +380,21 @@ func (a *agent) processEvent(ctx context.Context, sessionID string, assistantMsg
 	case provider.EventContentDelta:
 		assistantMsg.AppendContent(event.Content)
 		return a.messages.Update(ctx, *assistantMsg)
+	case provider.EventToolUseStart:
+		assistantMsg.AddToolCall(*event.ToolCall)
+		return a.messages.Update(ctx, *assistantMsg)
+	// TODO: see how to handle this
+	// case provider.EventToolUseDelta:
+	// 	tm := time.Unix(assistantMsg.UpdatedAt, 0)
+	// 	assistantMsg.AppendToolCallInput(event.ToolCall.ID, event.ToolCall.Input)
+	// 	if time.Since(tm) > 1000*time.Millisecond {
+	// 		err := a.messages.Update(ctx, *assistantMsg)
+	// 		assistantMsg.UpdatedAt = time.Now().Unix()
+	// 		return err
+	// 	}
+	case provider.EventToolUseStop:
+		assistantMsg.FinishToolCall(event.ToolCall.ID)
+		return a.messages.Update(ctx, *assistantMsg)
 	case provider.EventError:
 		if errors.Is(event.Error, context.Canceled) {
 			logging.InfoPersist(fmt.Sprintf("Event processing canceled for session: %s", sessionID))
@@ -456,6 +471,13 @@ func createAgentProvider(agentName config.AgentName) (provider.Provider, error)
 				provider.WithReasoningEffort(agentConfig.ReasoningEffort),
 			),
 		)
+	} else if model.Provider == models.ProviderAnthropic && model.CanReason {
+		opts = append(
+			opts,
+			provider.WithAnthropicOptions(
+				provider.WithAnthropicShouldThinkFn(provider.DefaultShouldThinkFn),
+			),
+		)
 	}
 	agentProvider, err := provider.NewProvider(
 		model.Provider,

internal/llm/provider/anthropic.go 🔗

@@ -93,8 +93,7 @@ func (a *anthropicClient) convertMessages(messages []message.Message) (anthropic
 			}
 
 			if len(blocks) == 0 {
-				logging.Warn("There is a message without content, investigate")
-				// This should never happend but we log this because we might have a bug in our cleanup method
+				logging.Warn("There is a message without content, investigate, this should not happen")
 				continue
 			}
 			anthropicMessages = append(anthropicMessages, anthropic.NewAssistantMessage(blocks...))
@@ -196,8 +195,8 @@ func (a *anthropicClient) send(ctx context.Context, messages []message.Message,
 	preparedMessages := a.preparedMessages(a.convertMessages(messages), a.convertTools(tools))
 	cfg := config.Get()
 	if cfg.Debug {
-		jsonData, _ := json.Marshal(preparedMessages)
-		logging.Debug("Prepared messages", "messages", string(jsonData))
+		// jsonData, _ := json.Marshal(preparedMessages)
+		// logging.Debug("Prepared messages", "messages", string(jsonData))
 	}
 	attempts := 0
 	for {
@@ -243,8 +242,8 @@ func (a *anthropicClient) stream(ctx context.Context, messages []message.Message
 	preparedMessages := a.preparedMessages(a.convertMessages(messages), a.convertTools(tools))
 	cfg := config.Get()
 	if cfg.Debug {
-		jsonData, _ := json.Marshal(preparedMessages)
-		logging.Debug("Prepared messages", "messages", string(jsonData))
+		// jsonData, _ := json.Marshal(preparedMessages)
+		// logging.Debug("Prepared messages", "messages", string(jsonData))
 	}
 	attempts := 0
 	eventChan := make(chan ProviderEvent)
@@ -257,6 +256,7 @@ func (a *anthropicClient) stream(ctx context.Context, messages []message.Message
 			)
 			accumulatedMessage := anthropic.Message{}
 
+			currentToolCallID := ""
 			for anthropicStream.Next() {
 				event := anthropicStream.Current()
 				err := accumulatedMessage.Accumulate(event)
@@ -267,7 +267,19 @@ func (a *anthropicClient) stream(ctx context.Context, messages []message.Message
 
 				switch event := event.AsAny().(type) {
 				case anthropic.ContentBlockStartEvent:
-					eventChan <- ProviderEvent{Type: EventContentStart}
+					if event.ContentBlock.Type == "text" {
+						eventChan <- ProviderEvent{Type: EventContentStart}
+					} else if event.ContentBlock.Type == "tool_use" {
+						currentToolCallID = event.ContentBlock.ID
+						eventChan <- ProviderEvent{
+							Type: EventToolUseStart,
+							ToolCall: &message.ToolCall{
+								ID:       event.ContentBlock.ID,
+								Name:     event.ContentBlock.Name,
+								Finished: false,
+							},
+						}
+					}
 
 				case anthropic.ContentBlockDeltaEvent:
 					if event.Delta.Type == "thinking_delta" && event.Delta.Thinking != "" {
@@ -280,11 +292,30 @@ func (a *anthropicClient) stream(ctx context.Context, messages []message.Message
 							Type:    EventContentDelta,
 							Content: event.Delta.Text,
 						}
+					} else if event.Delta.Type == "input_json_delta" {
+						if currentToolCallID != "" {
+							eventChan <- ProviderEvent{
+								Type: EventToolUseDelta,
+								ToolCall: &message.ToolCall{
+									ID:       currentToolCallID,
+									Finished: false,
+									Input:    event.Delta.JSON.PartialJSON.Raw(),
+								},
+							}
+						}
 					}
-				// TODO: check if we can somehow stream tool calls
-
 				case anthropic.ContentBlockStopEvent:
-					eventChan <- ProviderEvent{Type: EventContentStop}
+					if currentToolCallID != "" {
+						eventChan <- ProviderEvent{
+							Type: EventToolUseStop,
+							ToolCall: &message.ToolCall{
+								ID: currentToolCallID,
+							},
+						}
+						currentToolCallID = ""
+					} else {
+						eventChan <- ProviderEvent{Type: EventContentStop}
+					}
 
 				case anthropic.MessageStopEvent:
 					content := ""
@@ -378,10 +409,11 @@ func (a *anthropicClient) toolCalls(msg anthropic.Message) []message.ToolCall {
 		switch variant := block.AsAny().(type) {
 		case anthropic.ToolUseBlock:
 			toolCall := message.ToolCall{
-				ID:    variant.ID,
-				Name:  variant.Name,
-				Input: string(variant.Input),
-				Type:  string(variant.Type),
+				ID:       variant.ID,
+				Name:     variant.Name,
+				Input:    string(variant.Input),
+				Type:     string(variant.Type),
+				Finished: true,
 			}
 			toolCalls = append(toolCalls, toolCall)
 		}

internal/llm/provider/openai.go 🔗

@@ -344,10 +344,11 @@ func (o *openaiClient) toolCalls(completion openai.ChatCompletion) []message.Too
 	if len(completion.Choices) > 0 && len(completion.Choices[0].Message.ToolCalls) > 0 {
 		for _, call := range completion.Choices[0].Message.ToolCalls {
 			toolCall := message.ToolCall{
-				ID:    call.ID,
-				Name:  call.Function.Name,
-				Input: call.Function.Arguments,
-				Type:  "function",
+				ID:       call.ID,
+				Name:     call.Function.Name,
+				Input:    call.Function.Arguments,
+				Type:     "function",
+				Finished: true,
 			}
 			toolCalls = append(toolCalls, toolCall)
 		}

internal/llm/provider/provider.go 🔗

@@ -15,6 +15,9 @@ const maxRetries = 8
 
 const (
 	EventContentStart  EventType = "content_start"
+	EventToolUseStart  EventType = "tool_use_start"
+	EventToolUseDelta  EventType = "tool_use_delta"
+	EventToolUseStop   EventType = "tool_use_stop"
 	EventContentDelta  EventType = "content_delta"
 	EventThinkingDelta EventType = "thinking_delta"
 	EventContentStop   EventType = "content_stop"
@@ -43,8 +46,8 @@ type ProviderEvent struct {
 	Content  string
 	Thinking string
 	Response *ProviderResponse
-
-	Error error
+	ToolCall *message.ToolCall
+	Error    error
 }
 type Provider interface {
 	SendMessages(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error)

internal/message/content.go 🔗

@@ -233,6 +233,40 @@ func (m *Message) AppendReasoningContent(delta string) {
 	}
 }
 
+func (m *Message) FinishToolCall(toolCallID string) {
+	for i, part := range m.Parts {
+		if c, ok := part.(ToolCall); ok {
+			if c.ID == toolCallID {
+				m.Parts[i] = ToolCall{
+					ID:       c.ID,
+					Name:     c.Name,
+					Input:    c.Input,
+					Type:     c.Type,
+					Finished: true,
+				}
+				return
+			}
+		}
+	}
+}
+
+func (m *Message) AppendToolCallInput(toolCallID string, inputDelta string) {
+	for i, part := range m.Parts {
+		if c, ok := part.(ToolCall); ok {
+			if c.ID == toolCallID {
+				m.Parts[i] = ToolCall{
+					ID:       c.ID,
+					Name:     c.Name,
+					Input:    c.Input + inputDelta,
+					Type:     c.Type,
+					Finished: c.Finished,
+				}
+				return
+			}
+		}
+	}
+}
+
 func (m *Message) AddToolCall(tc ToolCall) {
 	for i, part := range m.Parts {
 		if c, ok := part.(ToolCall); ok {
@@ -246,6 +280,15 @@ func (m *Message) AddToolCall(tc ToolCall) {
 }
 
 func (m *Message) SetToolCalls(tc []ToolCall) {
+	// remove any existing tool call part it could have multiple
+	parts := make([]ContentPart, 0)
+	for _, part := range m.Parts {
+		if _, ok := part.(ToolCall); ok {
+			continue
+		}
+		parts = append(parts, part)
+	}
+	m.Parts = parts
 	for _, toolCall := range tc {
 		m.Parts = append(m.Parts, toolCall)
 	}

internal/message/message.go 🔗

@@ -5,6 +5,7 @@ import (
 	"database/sql"
 	"encoding/json"
 	"fmt"
+	"time"
 
 	"github.com/google/uuid"
 	"github.com/kujtimiihoxha/opencode/internal/db"
@@ -116,6 +117,7 @@ func (s *service) Update(ctx context.Context, message Message) error {
 	if err != nil {
 		return err
 	}
+	message.UpdatedAt = time.Now().Unix()
 	s.Publish(pubsub.UpdatedEvent, message)
 	return nil
 }

internal/pubsub/broker.go 🔗

@@ -7,13 +7,6 @@ import (
 
 const bufferSize = 1024
 
-type Logger interface {
-	Debug(msg string, args ...any)
-	Info(msg string, args ...any)
-	Warn(msg string, args ...any)
-	Error(msg string, args ...any)
-}
-
 // Broker allows clients to publish events and subscribe to events
 type Broker[T any] struct {
 	subs map[chan Event[T]]struct{} // subscriptions

internal/tui/components/chat/list.go 🔗

@@ -4,8 +4,6 @@ import (
 	"context"
 	"fmt"
 	"math"
-	"sync"
-	"time"
 
 	"github.com/charmbracelet/bubbles/key"
 	"github.com/charmbracelet/bubbles/spinner"
@@ -13,7 +11,6 @@ import (
 	tea "github.com/charmbracelet/bubbletea"
 	"github.com/charmbracelet/lipgloss"
 	"github.com/kujtimiihoxha/opencode/internal/app"
-	"github.com/kujtimiihoxha/opencode/internal/logging"
 	"github.com/kujtimiihoxha/opencode/internal/message"
 	"github.com/kujtimiihoxha/opencode/internal/pubsub"
 	"github.com/kujtimiihoxha/opencode/internal/session"
@@ -35,89 +32,14 @@ type messagesCmp struct {
 	messages      []message.Message
 	uiMessages    []uiMessage
 	currentMsgID  string
-	mutex         sync.Mutex
 	cachedContent map[string]cacheItem
 	spinner       spinner.Model
-	lastUpdate    time.Time
 	rendering     bool
 }
 type renderFinishedMsg struct{}
 
 func (m *messagesCmp) Init() tea.Cmd {
-	return tea.Batch(m.viewport.Init())
-}
-
-func (m *messagesCmp) preloadSessions() tea.Cmd {
-	return func() tea.Msg {
-		m.mutex.Lock()
-		defer m.mutex.Unlock()
-		sessions, err := m.app.Sessions.List(context.Background())
-		if err != nil {
-			return util.ReportError(err)()
-		}
-		if len(sessions) == 0 {
-			return nil
-		}
-		if len(sessions) > 20 {
-			sessions = sessions[:20]
-		}
-		for _, s := range sessions {
-			messages, err := m.app.Messages.List(context.Background(), s.ID)
-			if err != nil {
-				return util.ReportError(err)()
-			}
-			if len(messages) == 0 {
-				continue
-			}
-			m.cacheSessionMessages(messages, m.width)
-
-		}
-		logging.Debug("preloaded sessions")
-
-		return func() tea.Msg {
-			return renderFinishedMsg{}
-		}
-	}
-}
-
-func (m *messagesCmp) cacheSessionMessages(messages []message.Message, width int) {
-	pos := 0
-	if m.width == 0 {
-		return
-	}
-	for inx, msg := range messages {
-		switch msg.Role {
-		case message.User:
-			userMsg := renderUserMessage(
-				msg,
-				false,
-				width,
-				pos,
-			)
-			m.cachedContent[msg.ID] = cacheItem{
-				width:   width,
-				content: []uiMessage{userMsg},
-			}
-			pos += userMsg.height + 1 // + 1 for spacing
-		case message.Assistant:
-			assistantMessages := renderAssistantMessage(
-				msg,
-				inx,
-				messages,
-				m.app.Messages,
-				"",
-				width,
-				pos,
-			)
-			for _, msg := range assistantMessages {
-				pos += msg.height + 1 // + 1 for spacing
-			}
-			m.cachedContent[msg.ID] = cacheItem{
-				width:   width,
-				content: assistantMessages,
-			}
-		}
-	}
+	return tea.Batch(m.viewport.Init(), m.spinner.Tick)
 }
 
 func (m *messagesCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
@@ -360,21 +282,35 @@ func hasToolsWithoutResponse(messages []message.Message) bool {
 				break
 			}
 		}
-		if !found {
+		if !found && v.Finished {
 			return true
 		}
 	}
+	return false
+}
 
+func hasUnfinishedToolCalls(messages []message.Message) bool {
+	toolCalls := make([]message.ToolCall, 0)
+	for _, m := range messages {
+		toolCalls = append(toolCalls, m.ToolCalls()...)
+	}
+	for _, v := range toolCalls {
+		if !v.Finished {
+			return true
+		}
+	}
 	return false
 }
 
 func (m *messagesCmp) working() string {
 	text := ""
-	if m.IsAgentWorking() {
+	if m.IsAgentWorking() && len(m.messages) > 0 {
 		task := "Thinking..."
 		lastMessage := m.messages[len(m.messages)-1]
 		if hasToolsWithoutResponse(m.messages) {
 			task = "Waiting for tool response..."
+		} else if hasUnfinishedToolCalls(m.messages) {
+			task = "Building tool call..."
 		} else if !lastMessage.IsFinished() {
 			task = "Generating..."
 		}
@@ -434,8 +370,7 @@ func (m *messagesCmp) SetSize(width, height int) tea.Cmd {
 		delete(m.cachedContent, msg.ID)
 	}
 	m.uiMessages = make([]uiMessage, 0)
-	m.renderView()
-	return m.preloadSessions()
+	return nil
 }
 
 func (m *messagesCmp) GetSize() (int, int) {
@@ -446,16 +381,16 @@ func (m *messagesCmp) SetSession(session session.Session) tea.Cmd {
 	if m.session.ID == session.ID {
 		return nil
 	}
+	m.session = session
+	messages, err := m.app.Messages.List(context.Background(), session.ID)
+	if err != nil {
+		return util.ReportError(err)
+	}
+	m.messages = messages
+	m.currentMsgID = m.messages[len(m.messages)-1].ID
+	delete(m.cachedContent, m.currentMsgID)
 	m.rendering = true
 	return func() tea.Msg {
-		m.session = session
-		messages, err := m.app.Messages.List(context.Background(), session.ID)
-		if err != nil {
-			return util.ReportError(err)
-		}
-		m.messages = messages
-		m.currentMsgID = m.messages[len(m.messages)-1].ID
-		delete(m.cachedContent, m.currentMsgID)
 		m.renderView()
 		return renderFinishedMsg{}
 	}

internal/tui/components/chat/message.go 🔗

@@ -113,18 +113,10 @@ func renderAssistantMessage(
 	width int,
 	position int,
 ) []uiMessage {
-	// find the user message that is before this assistant message
-	var userMsg message.Message
-	for i := msgIndex - 1; i >= 0; i-- {
-		msg := allMessages[i]
-		if msg.Role == message.User {
-			userMsg = allMessages[i]
-			break
-		}
-	}
-
 	messages := []uiMessage{}
 	content := msg.Content().String()
+	thinking := msg.IsThinking()
+	thinkingContent := msg.ReasoningContent().Thinking
 	finished := msg.IsFinished()
 	finishData := msg.FinishPart()
 	info := []string{}
@@ -133,7 +125,7 @@ func renderAssistantMessage(
 	if finished {
 		switch finishData.Reason {
 		case message.FinishReasonEndTurn:
-			took := formatTimeDifference(userMsg.CreatedAt, finishData.Time)
+			took := formatTimeDifference(msg.CreatedAt, finishData.Time)
 			info = append(info, styles.BaseStyle.Width(width-1).Foreground(styles.ForgroundDim).Render(
 				fmt.Sprintf(" %s (%s)", models.SupportedModels[msg.Model].Name, took),
 			))
@@ -166,6 +158,9 @@ func renderAssistantMessage(
 		})
 		position += messages[0].height
 		position++ // for the space
+	} else if thinking && thinkingContent != "" {
+		// Render the thinking content
+		content = renderMessage(thinkingContent, false, msg.ID == focusedUIMessageId, width)
 	}
 
 	for i, toolCall := range msg.ToolCalls() {
@@ -218,10 +213,40 @@ func toolName(name string) string {
 		return "View"
 	case tools.WriteToolName:
 		return "Write"
+	case tools.PatchToolName:
+		return "Patch"
 	}
 	return name
 }
 
+func getToolAction(name string) string {
+	switch name {
+	case agent.AgentToolName:
+		return "Preparing prompt..."
+	case tools.BashToolName:
+		return "Building command..."
+	case tools.EditToolName:
+		return "Preparing edit..."
+	case tools.FetchToolName:
+		return "Writing fetch..."
+	case tools.GlobToolName:
+		return "Finding files..."
+	case tools.GrepToolName:
+		return "Searching content..."
+	case tools.LSToolName:
+		return "Listing directory..."
+	case tools.SourcegraphToolName:
+		return "Searching code..."
+	case tools.ViewToolName:
+		return "Reading file..."
+	case tools.WriteToolName:
+		return "Preparing write..."
+	case tools.PatchToolName:
+		return "Preparing patch..."
+	}
+	return "Working..."
+}
+
 // renders params, params[0] (params[1]=params[2] ....)
 func renderParams(paramsWidth int, params ...string) string {
 	if len(params) == 0 {
@@ -490,8 +515,47 @@ func renderToolMessage(
 	if nested {
 		width = width - 3
 	}
+	style := styles.BaseStyle.
+		Width(width - 1).
+		BorderLeft(true).
+		BorderStyle(lipgloss.ThickBorder()).
+		PaddingLeft(1).
+		BorderForeground(styles.ForgroundDim)
+
 	response := findToolResponse(toolCall.ID, allMessages)
 	toolName := styles.BaseStyle.Foreground(styles.ForgroundDim).Render(fmt.Sprintf("%s: ", toolName(toolCall.Name)))
+
+	if !toolCall.Finished {
+		// Get a brief description of what the tool is doing
+		toolAction := getToolAction(toolCall.Name)
+
+		// toolInput := strings.ReplaceAll(toolCall.Input, "\n", " ")
+		// truncatedInput := toolInput
+		// if len(truncatedInput) > 10 {
+		// 	truncatedInput = truncatedInput[len(truncatedInput)-10:]
+		// }
+		//
+		// truncatedInput = styles.BaseStyle.
+		// 	Italic(true).
+		// 	Width(width - 2 - lipgloss.Width(toolName)).
+		// 	Background(styles.BackgroundDim).
+		// 	Foreground(styles.ForgroundMid).
+		// 	Render(truncatedInput)
+
+		progressText := styles.BaseStyle.
+			Width(width - 2 - lipgloss.Width(toolName)).
+			Foreground(styles.ForgroundDim).
+			Render(fmt.Sprintf("%s", toolAction))
+
+		content := style.Render(lipgloss.JoinHorizontal(lipgloss.Left, toolName, progressText))
+		toolMsg := uiMessage{
+			messageType: toolMessageType,
+			position:    position,
+			height:      lipgloss.Height(content),
+			content:     content,
+		}
+		return toolMsg
+	}
 	params := renderToolParams(width-2-lipgloss.Width(toolName), toolCall)
 	responseContent := ""
 	if response != nil {
@@ -504,12 +568,6 @@ func renderToolMessage(
 			Foreground(styles.ForgroundDim).
 			Render("Waiting for response...")
 	}
-	style := styles.BaseStyle.
-		Width(width - 1).
-		BorderLeft(true).
-		BorderStyle(lipgloss.ThickBorder()).
-		PaddingLeft(1).
-		BorderForeground(styles.ForgroundDim)
 
 	parts := []string{}
 	if !nested {

internal/tui/layout/split.go 🔗

@@ -14,6 +14,10 @@ type SplitPaneLayout interface {
 	SetLeftPanel(panel Container) tea.Cmd
 	SetRightPanel(panel Container) tea.Cmd
 	SetBottomPanel(panel Container) tea.Cmd
+
+	ClearLeftPanel() tea.Cmd
+	ClearRightPanel() tea.Cmd
+	ClearBottomPanel() tea.Cmd
 }
 
 type splitPaneLayout struct {
@@ -192,6 +196,30 @@ func (s *splitPaneLayout) SetBottomPanel(panel Container) tea.Cmd {
 	return nil
 }
 
+func (s *splitPaneLayout) ClearLeftPanel() tea.Cmd {
+	s.leftPanel = nil
+	if s.width > 0 && s.height > 0 {
+		return s.SetSize(s.width, s.height)
+	}
+	return nil
+}
+
+func (s *splitPaneLayout) ClearRightPanel() tea.Cmd {
+	s.rightPanel = nil
+	if s.width > 0 && s.height > 0 {
+		return s.SetSize(s.width, s.height)
+	}
+	return nil
+}
+
+func (s *splitPaneLayout) ClearBottomPanel() tea.Cmd {
+	s.bottomPanel = nil
+	if s.width > 0 && s.height > 0 {
+		return s.SetSize(s.width, s.height)
+	}
+	return nil
+}
+
 func (s *splitPaneLayout) BindingKeys() []key.Binding {
 	keys := []key.Binding{}
 	if s.leftPanel != nil {

internal/tui/page/chat.go 🔗

@@ -57,6 +57,14 @@ func (p *chatPage) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
 		if cmd != nil {
 			return p, cmd
 		}
+	case chat.SessionSelectedMsg:
+		if p.session.ID == "" {
+			cmd := p.setSidebar()
+			if cmd != nil {
+				cmds = append(cmds, cmd)
+			}
+		}
+		p.session = msg
 	case chat.EditorFocusMsg:
 		p.editingMode = bool(msg)
 	case tea.KeyMsg:
@@ -91,7 +99,7 @@ func (p *chatPage) setSidebar() tea.Cmd {
 }
 
 func (p *chatPage) clearSidebar() tea.Cmd {
-	return p.layout.SetRightPanel(nil)
+	return p.layout.ClearRightPanel()
 }
 
 func (p *chatPage) sendMessage(text string) tea.Cmd {