chat.go

  1package chat
  2
  3import (
  4	"context"
  5	"time"
  6
  7	tea "github.com/charmbracelet/bubbletea/v2"
  8	"github.com/charmbracelet/crush/internal/app"
  9	"github.com/charmbracelet/crush/internal/llm/agent"
 10	"github.com/charmbracelet/crush/internal/message"
 11	"github.com/charmbracelet/crush/internal/pubsub"
 12	"github.com/charmbracelet/crush/internal/session"
 13	"github.com/charmbracelet/crush/internal/tui/components/chat/messages"
 14	"github.com/charmbracelet/crush/internal/tui/components/core/list"
 15	"github.com/charmbracelet/crush/internal/tui/layout"
 16	"github.com/charmbracelet/crush/internal/tui/util"
 17	"github.com/charmbracelet/lipgloss/v2"
 18)
 19
 20type SendMsg struct {
 21	Text        string
 22	Attachments []message.Attachment
 23}
 24
 25type SessionSelectedMsg = session.Session
 26
 27type SessionClearedMsg struct{}
 28
 29const (
 30	NotFound = -1
 31)
 32
 33// MessageListCmp represents a component that displays a list of chat messages
 34// with support for real-time updates and session management.
 35type MessageListCmp interface {
 36	util.Model
 37	layout.Sizeable
 38	layout.Focusable
 39}
 40
 41// messageListCmp implements MessageListCmp, providing a virtualized list
 42// of chat messages with support for tool calls, real-time updates, and
 43// session switching.
 44type messageListCmp struct {
 45	app              *app.App
 46	width, height    int
 47	session          session.Session
 48	listCmp          list.ListModel
 49	previousSelected int // Last selected item index for restoring focus
 50
 51	lastUserMessageTime int64
 52}
 53
 54// NewMessagesListCmp creates a new message list component with custom keybindings
 55// and reverse ordering (newest messages at bottom).
 56func NewMessagesListCmp(app *app.App) MessageListCmp {
 57	defaultKeymaps := list.DefaultKeyMap()
 58	listCmp := list.New(
 59		list.WithGapSize(1),
 60		list.WithReverse(true),
 61		list.WithKeyMap(defaultKeymaps),
 62	)
 63	return &messageListCmp{
 64		app:              app,
 65		listCmp:          listCmp,
 66		previousSelected: list.NoSelection,
 67	}
 68}
 69
 70// Init initializes the component (no initialization needed).
 71func (m *messageListCmp) Init() tea.Cmd {
 72	return tea.Sequence(m.listCmp.Init(), m.listCmp.Blur())
 73}
 74
 75// Update handles incoming messages and updates the component state.
 76func (m *messageListCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
 77	switch msg := msg.(type) {
 78	case SessionSelectedMsg:
 79		if msg.ID != m.session.ID {
 80			cmd := m.SetSession(msg)
 81			return m, cmd
 82		}
 83		return m, nil
 84	case SessionClearedMsg:
 85		m.session = session.Session{}
 86		return m, m.listCmp.SetItems([]util.Model{})
 87
 88	case pubsub.Event[message.Message]:
 89		cmd := m.handleMessageEvent(msg)
 90		return m, cmd
 91	default:
 92		var cmds []tea.Cmd
 93		u, cmd := m.listCmp.Update(msg)
 94		m.listCmp = u.(list.ListModel)
 95		cmds = append(cmds, cmd)
 96		return m, tea.Batch(cmds...)
 97	}
 98}
 99
100// View renders the message list or an initial screen if empty.
101func (m *messageListCmp) View() tea.View {
102	return tea.NewView(
103		lipgloss.JoinVertical(
104			lipgloss.Left,
105			m.listCmp.View().String(),
106		),
107	)
108}
109
110// handleChildSession handles messages from child sessions (agent tools).
111func (m *messageListCmp) handleChildSession(event pubsub.Event[message.Message]) tea.Cmd {
112	var cmds []tea.Cmd
113	if len(event.Payload.ToolCalls()) == 0 {
114		return nil
115	}
116	items := m.listCmp.Items()
117	toolCallInx := NotFound
118	var toolCall messages.ToolCallCmp
119	for i := len(items) - 1; i >= 0; i-- {
120		if msg, ok := items[i].(messages.ToolCallCmp); ok {
121			if msg.GetToolCall().ID == event.Payload.SessionID {
122				toolCallInx = i
123				toolCall = msg
124			}
125		}
126	}
127	if toolCallInx == NotFound {
128		return nil
129	}
130	nestedToolCalls := toolCall.GetNestedToolCalls()
131	for _, tc := range event.Payload.ToolCalls() {
132		found := false
133		for existingInx, existingTC := range nestedToolCalls {
134			if existingTC.GetToolCall().ID == tc.ID {
135				nestedToolCalls[existingInx].SetToolCall(tc)
136				found = true
137				break
138			}
139		}
140		if !found {
141			nestedCall := messages.NewToolCallCmp(
142				event.Payload.ID,
143				tc,
144				messages.WithToolCallNested(true),
145			)
146			cmds = append(cmds, nestedCall.Init())
147			nestedToolCalls = append(
148				nestedToolCalls,
149				nestedCall,
150			)
151		}
152	}
153	toolCall.SetNestedToolCalls(nestedToolCalls)
154	m.listCmp.UpdateItem(
155		toolCallInx,
156		toolCall,
157	)
158	return tea.Batch(cmds...)
159}
160
161// handleMessageEvent processes different types of message events (created/updated).
162func (m *messageListCmp) handleMessageEvent(event pubsub.Event[message.Message]) tea.Cmd {
163	switch event.Type {
164	case pubsub.CreatedEvent:
165		if event.Payload.SessionID != m.session.ID {
166			return m.handleChildSession(event)
167		}
168		if m.messageExists(event.Payload.ID) {
169			return nil
170		}
171		return m.handleNewMessage(event.Payload)
172	case pubsub.UpdatedEvent:
173		if event.Payload.SessionID != m.session.ID {
174			return m.handleChildSession(event)
175		}
176		return m.handleUpdateAssistantMessage(event.Payload)
177	}
178	return nil
179}
180
181// messageExists checks if a message with the given ID already exists in the list.
182func (m *messageListCmp) messageExists(messageID string) bool {
183	items := m.listCmp.Items()
184	// Search backwards as new messages are more likely to be at the end
185	for i := len(items) - 1; i >= 0; i-- {
186		if msg, ok := items[i].(messages.MessageCmp); ok && msg.GetMessage().ID == messageID {
187			return true
188		}
189	}
190	return false
191}
192
193// handleNewMessage routes new messages to appropriate handlers based on role.
194func (m *messageListCmp) handleNewMessage(msg message.Message) tea.Cmd {
195	switch msg.Role {
196	case message.User:
197		return m.handleNewUserMessage(msg)
198	case message.Assistant:
199		return m.handleNewAssistantMessage(msg)
200	case message.Tool:
201		return m.handleToolMessage(msg)
202	}
203	return nil
204}
205
206// handleNewUserMessage adds a new user message to the list and updates the timestamp.
207func (m *messageListCmp) handleNewUserMessage(msg message.Message) tea.Cmd {
208	m.lastUserMessageTime = msg.CreatedAt
209	return m.listCmp.AppendItem(messages.NewMessageCmp(msg))
210}
211
212// handleToolMessage updates existing tool calls with their results.
213func (m *messageListCmp) handleToolMessage(msg message.Message) tea.Cmd {
214	items := m.listCmp.Items()
215	for _, tr := range msg.ToolResults() {
216		if toolCallIndex := m.findToolCallByID(items, tr.ToolCallID); toolCallIndex != NotFound {
217			toolCall := items[toolCallIndex].(messages.ToolCallCmp)
218			toolCall.SetToolResult(tr)
219			m.listCmp.UpdateItem(toolCallIndex, toolCall)
220		}
221	}
222	return nil
223}
224
225// findToolCallByID searches for a tool call with the specified ID.
226// Returns the index if found, NotFound otherwise.
227func (m *messageListCmp) findToolCallByID(items []util.Model, toolCallID string) int {
228	// Search backwards as tool calls are more likely to be recent
229	for i := len(items) - 1; i >= 0; i-- {
230		if toolCall, ok := items[i].(messages.ToolCallCmp); ok && toolCall.GetToolCall().ID == toolCallID {
231			return i
232		}
233	}
234	return NotFound
235}
236
237// handleUpdateAssistantMessage processes updates to assistant messages,
238// managing both message content and associated tool calls.
239func (m *messageListCmp) handleUpdateAssistantMessage(msg message.Message) tea.Cmd {
240	var cmds []tea.Cmd
241	items := m.listCmp.Items()
242
243	// Find existing assistant message and tool calls for this message
244	assistantIndex, existingToolCalls := m.findAssistantMessageAndToolCalls(items, msg.ID)
245
246	// Handle assistant message content
247	if cmd := m.updateAssistantMessageContent(msg, assistantIndex); cmd != nil {
248		cmds = append(cmds, cmd)
249	}
250
251	// Handle tool calls
252	if cmd := m.updateToolCalls(msg, existingToolCalls); cmd != nil {
253		cmds = append(cmds, cmd)
254	}
255
256	return tea.Batch(cmds...)
257}
258
259// findAssistantMessageAndToolCalls locates the assistant message and its tool calls.
260func (m *messageListCmp) findAssistantMessageAndToolCalls(items []util.Model, messageID string) (int, map[int]messages.ToolCallCmp) {
261	assistantIndex := NotFound
262	toolCalls := make(map[int]messages.ToolCallCmp)
263
264	// Search backwards as messages are more likely to be at the end
265	for i := len(items) - 1; i >= 0; i-- {
266		item := items[i]
267		if asMsg, ok := item.(messages.MessageCmp); ok {
268			if asMsg.GetMessage().ID == messageID {
269				assistantIndex = i
270			}
271		} else if tc, ok := item.(messages.ToolCallCmp); ok {
272			if tc.ParentMessageId() == messageID {
273				toolCalls[i] = tc
274			}
275		}
276	}
277
278	return assistantIndex, toolCalls
279}
280
281// updateAssistantMessageContent updates or removes the assistant message based on content.
282func (m *messageListCmp) updateAssistantMessageContent(msg message.Message, assistantIndex int) tea.Cmd {
283	if assistantIndex == NotFound {
284		return nil
285	}
286
287	shouldShowMessage := m.shouldShowAssistantMessage(msg)
288	hasToolCallsOnly := len(msg.ToolCalls()) > 0 && msg.Content().Text == ""
289
290	if shouldShowMessage {
291		m.listCmp.UpdateItem(
292			assistantIndex,
293			messages.NewMessageCmp(
294				msg,
295				messages.WithLastUserMessageTime(time.Unix(m.lastUserMessageTime, 0)),
296			),
297		)
298	} else if hasToolCallsOnly {
299		m.listCmp.DeleteItem(assistantIndex)
300	}
301
302	return nil
303}
304
305// shouldShowAssistantMessage determines if an assistant message should be displayed.
306func (m *messageListCmp) shouldShowAssistantMessage(msg message.Message) bool {
307	return len(msg.ToolCalls()) == 0 || msg.Content().Text != "" || msg.IsThinking()
308}
309
310// updateToolCalls handles updates to tool calls, updating existing ones and adding new ones.
311func (m *messageListCmp) updateToolCalls(msg message.Message, existingToolCalls map[int]messages.ToolCallCmp) tea.Cmd {
312	var cmds []tea.Cmd
313
314	for _, tc := range msg.ToolCalls() {
315		if cmd := m.updateOrAddToolCall(tc, existingToolCalls, msg.ID); cmd != nil {
316			cmds = append(cmds, cmd)
317		}
318	}
319
320	return tea.Batch(cmds...)
321}
322
323// updateOrAddToolCall updates an existing tool call or adds a new one.
324func (m *messageListCmp) updateOrAddToolCall(tc message.ToolCall, existingToolCalls map[int]messages.ToolCallCmp, messageID string) tea.Cmd {
325	// Try to find existing tool call
326	for index, existingTC := range existingToolCalls {
327		if tc.ID == existingTC.GetToolCall().ID {
328			existingTC.SetToolCall(tc)
329			m.listCmp.UpdateItem(index, existingTC)
330			return nil
331		}
332	}
333
334	// Add new tool call if not found
335	return m.listCmp.AppendItem(messages.NewToolCallCmp(messageID, tc))
336}
337
338// handleNewAssistantMessage processes new assistant messages and their tool calls.
339func (m *messageListCmp) handleNewAssistantMessage(msg message.Message) tea.Cmd {
340	var cmds []tea.Cmd
341
342	// Add assistant message if it should be displayed
343	if m.shouldShowAssistantMessage(msg) {
344		cmd := m.listCmp.AppendItem(
345			messages.NewMessageCmp(
346				msg,
347				messages.WithLastUserMessageTime(time.Unix(m.lastUserMessageTime, 0)),
348			),
349		)
350		cmds = append(cmds, cmd)
351	}
352
353	// Add tool calls
354	for _, tc := range msg.ToolCalls() {
355		cmd := m.listCmp.AppendItem(messages.NewToolCallCmp(msg.ID, tc))
356		cmds = append(cmds, cmd)
357	}
358
359	return tea.Batch(cmds...)
360}
361
362// SetSession loads and displays messages for a new session.
363func (m *messageListCmp) SetSession(session session.Session) tea.Cmd {
364	if m.session.ID == session.ID {
365		return nil
366	}
367
368	m.session = session
369	sessionMessages, err := m.app.Messages.List(context.Background(), session.ID)
370	if err != nil {
371		return util.ReportError(err)
372	}
373
374	if len(sessionMessages) == 0 {
375		return m.listCmp.SetItems([]util.Model{})
376	}
377
378	// Initialize with first message timestamp
379	m.lastUserMessageTime = sessionMessages[0].CreatedAt
380
381	// Build tool result map for efficient lookup
382	toolResultMap := m.buildToolResultMap(sessionMessages)
383
384	// Convert messages to UI components
385	uiMessages := m.convertMessagesToUI(sessionMessages, toolResultMap)
386
387	return m.listCmp.SetItems(uiMessages)
388}
389
390// buildToolResultMap creates a map of tool call ID to tool result for efficient lookup.
391func (m *messageListCmp) buildToolResultMap(messages []message.Message) map[string]message.ToolResult {
392	toolResultMap := make(map[string]message.ToolResult)
393	for _, msg := range messages {
394		for _, tr := range msg.ToolResults() {
395			toolResultMap[tr.ToolCallID] = tr
396		}
397	}
398	return toolResultMap
399}
400
401// convertMessagesToUI converts database messages to UI components.
402func (m *messageListCmp) convertMessagesToUI(sessionMessages []message.Message, toolResultMap map[string]message.ToolResult) []util.Model {
403	uiMessages := make([]util.Model, 0)
404
405	for _, msg := range sessionMessages {
406		switch msg.Role {
407		case message.User:
408			m.lastUserMessageTime = msg.CreatedAt
409			uiMessages = append(uiMessages, messages.NewMessageCmp(msg))
410		case message.Assistant:
411			uiMessages = append(uiMessages, m.convertAssistantMessage(msg, toolResultMap)...)
412		}
413	}
414
415	return uiMessages
416}
417
418// convertAssistantMessage converts an assistant message and its tool calls to UI components.
419func (m *messageListCmp) convertAssistantMessage(msg message.Message, toolResultMap map[string]message.ToolResult) []util.Model {
420	var uiMessages []util.Model
421
422	// Add assistant message if it should be displayed
423	if m.shouldShowAssistantMessage(msg) {
424		uiMessages = append(
425			uiMessages,
426			messages.NewMessageCmp(
427				msg,
428				messages.WithLastUserMessageTime(time.Unix(m.lastUserMessageTime, 0)),
429			),
430		)
431	}
432
433	// Add tool calls with their results and status
434	for _, tc := range msg.ToolCalls() {
435		options := m.buildToolCallOptions(tc, msg, toolResultMap)
436		uiMessages = append(uiMessages, messages.NewToolCallCmp(msg.ID, tc, options...))
437		// If this tool call is the agent tool, fetch nested tool calls
438		if tc.Name == agent.AgentToolName {
439			nestedMessages, _ := m.app.Messages.List(context.Background(), tc.ID)
440			nestedUIMessages := m.convertMessagesToUI(nestedMessages, make(map[string]message.ToolResult))
441			nestedToolCalls := make([]messages.ToolCallCmp, 0, len(nestedUIMessages))
442			for _, nestedMsg := range nestedUIMessages {
443				if toolCall, ok := nestedMsg.(messages.ToolCallCmp); ok {
444					toolCall.SetIsNested(true)
445					nestedToolCalls = append(nestedToolCalls, toolCall)
446				}
447			}
448			uiMessages[len(uiMessages)-1].(messages.ToolCallCmp).SetNestedToolCalls(nestedToolCalls)
449		}
450	}
451
452	return uiMessages
453}
454
455// buildToolCallOptions creates options for tool call components based on results and status.
456func (m *messageListCmp) buildToolCallOptions(tc message.ToolCall, msg message.Message, toolResultMap map[string]message.ToolResult) []messages.ToolCallOption {
457	var options []messages.ToolCallOption
458
459	// Add tool result if available
460	if tr, ok := toolResultMap[tc.ID]; ok {
461		options = append(options, messages.WithToolCallResult(tr))
462	}
463
464	// Add cancelled status if applicable
465	if msg.FinishPart() != nil && msg.FinishPart().Reason == message.FinishReasonCanceled {
466		options = append(options, messages.WithToolCallCancelled())
467	}
468
469	return options
470}
471
472// GetSize returns the current width and height of the component.
473func (m *messageListCmp) GetSize() (int, int) {
474	return m.width, m.height
475}
476
477// SetSize updates the component dimensions and propagates to the list component.
478func (m *messageListCmp) SetSize(width int, height int) tea.Cmd {
479	m.width = width
480	m.height = height - 1
481	return m.listCmp.SetSize(width, height-1)
482}
483
484// Blur implements MessageListCmp.
485func (m *messageListCmp) Blur() tea.Cmd {
486	return m.listCmp.Blur()
487}
488
489// Focus implements MessageListCmp.
490func (m *messageListCmp) Focus() tea.Cmd {
491	return m.listCmp.Focus()
492}
493
494// IsFocused implements MessageListCmp.
495func (m *messageListCmp) IsFocused() bool {
496	return m.listCmp.IsFocused()
497}