chat.go

  1package chat
  2
  3import (
  4	"context"
  5	"time"
  6
  7	"github.com/charmbracelet/bubbles/v2/key"
  8	tea "github.com/charmbracelet/bubbletea/v2"
  9	"github.com/charmbracelet/crush/internal/app"
 10	"github.com/charmbracelet/crush/internal/llm/agent"
 11	"github.com/charmbracelet/crush/internal/message"
 12	"github.com/charmbracelet/crush/internal/pubsub"
 13	"github.com/charmbracelet/crush/internal/session"
 14	"github.com/charmbracelet/crush/internal/tui/components/chat/messages"
 15	"github.com/charmbracelet/crush/internal/tui/components/core/layout"
 16	"github.com/charmbracelet/crush/internal/tui/components/core/list"
 17	"github.com/charmbracelet/crush/internal/tui/util"
 18	"github.com/charmbracelet/lipgloss/v2"
 19)
 20
 21type SendMsg struct {
 22	Text        string
 23	Attachments []message.Attachment
 24}
 25
 26type SessionSelectedMsg = session.Session
 27
 28type SessionClearedMsg struct{}
 29
 30const (
 31	NotFound = -1
 32)
 33
 34// MessageListCmp represents a component that displays a list of chat messages
 35// with support for real-time updates and session management.
 36type MessageListCmp interface {
 37	util.Model
 38	layout.Sizeable
 39	layout.Focusable
 40}
 41
 42// messageListCmp implements MessageListCmp, providing a virtualized list
 43// of chat messages with support for tool calls, real-time updates, and
 44// session switching.
 45type messageListCmp struct {
 46	app              *app.App
 47	width, height    int
 48	session          session.Session
 49	listCmp          list.ListModel
 50	previousSelected int // Last selected item index for restoring focus
 51
 52	lastUserMessageTime int64
 53	defaultListKeyMap   list.KeyMap
 54}
 55
 56// NewMessagesListCmp creates a new message list component with custom keybindings
 57// and reverse ordering (newest messages at bottom).
 58func NewMessagesListCmp(app *app.App) MessageListCmp {
 59	defaultListKeyMap := list.DefaultKeyMap()
 60	listCmp := list.New(
 61		list.WithGapSize(1),
 62		list.WithReverse(true),
 63		list.WithKeyMap(defaultListKeyMap),
 64	)
 65	return &messageListCmp{
 66		app:               app,
 67		listCmp:           listCmp,
 68		previousSelected:  list.NoSelection,
 69		defaultListKeyMap: defaultListKeyMap,
 70	}
 71}
 72
 73// Init initializes the component (no initialization needed).
 74func (m *messageListCmp) Init() tea.Cmd {
 75	return tea.Sequence(m.listCmp.Init(), m.listCmp.Blur())
 76}
 77
 78// Update handles incoming messages and updates the component state.
 79func (m *messageListCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
 80	switch msg := msg.(type) {
 81	case SessionSelectedMsg:
 82		if msg.ID != m.session.ID {
 83			cmd := m.SetSession(msg)
 84			return m, cmd
 85		}
 86		return m, nil
 87	case SessionClearedMsg:
 88		m.session = session.Session{}
 89		return m, m.listCmp.SetItems([]util.Model{})
 90
 91	case pubsub.Event[message.Message]:
 92		cmd := m.handleMessageEvent(msg)
 93		return m, cmd
 94	default:
 95		var cmds []tea.Cmd
 96		u, cmd := m.listCmp.Update(msg)
 97		m.listCmp = u.(list.ListModel)
 98		cmds = append(cmds, cmd)
 99		return m, tea.Batch(cmds...)
100	}
101}
102
103// View renders the message list or an initial screen if empty.
104func (m *messageListCmp) View() tea.View {
105	return tea.NewView(
106		lipgloss.JoinVertical(
107			lipgloss.Left,
108			m.listCmp.View().String(),
109		),
110	)
111}
112
113// handleChildSession handles messages from child sessions (agent tools).
114func (m *messageListCmp) handleChildSession(event pubsub.Event[message.Message]) tea.Cmd {
115	var cmds []tea.Cmd
116	if len(event.Payload.ToolCalls()) == 0 {
117		return nil
118	}
119	items := m.listCmp.Items()
120	toolCallInx := NotFound
121	var toolCall messages.ToolCallCmp
122	for i := len(items) - 1; i >= 0; i-- {
123		if msg, ok := items[i].(messages.ToolCallCmp); ok {
124			if msg.GetToolCall().ID == event.Payload.SessionID {
125				toolCallInx = i
126				toolCall = msg
127			}
128		}
129	}
130	if toolCallInx == NotFound {
131		return nil
132	}
133	nestedToolCalls := toolCall.GetNestedToolCalls()
134	for _, tc := range event.Payload.ToolCalls() {
135		found := false
136		for existingInx, existingTC := range nestedToolCalls {
137			if existingTC.GetToolCall().ID == tc.ID {
138				nestedToolCalls[existingInx].SetToolCall(tc)
139				found = true
140				break
141			}
142		}
143		if !found {
144			nestedCall := messages.NewToolCallCmp(
145				event.Payload.ID,
146				tc,
147				messages.WithToolCallNested(true),
148			)
149			cmds = append(cmds, nestedCall.Init())
150			nestedToolCalls = append(
151				nestedToolCalls,
152				nestedCall,
153			)
154		}
155	}
156	toolCall.SetNestedToolCalls(nestedToolCalls)
157	m.listCmp.UpdateItem(
158		toolCallInx,
159		toolCall,
160	)
161	return tea.Batch(cmds...)
162}
163
164// handleMessageEvent processes different types of message events (created/updated).
165func (m *messageListCmp) handleMessageEvent(event pubsub.Event[message.Message]) tea.Cmd {
166	switch event.Type {
167	case pubsub.CreatedEvent:
168		if event.Payload.SessionID != m.session.ID {
169			return m.handleChildSession(event)
170		}
171		if m.messageExists(event.Payload.ID) {
172			return nil
173		}
174		return m.handleNewMessage(event.Payload)
175	case pubsub.UpdatedEvent:
176		if event.Payload.SessionID != m.session.ID {
177			return m.handleChildSession(event)
178		}
179		return m.handleUpdateAssistantMessage(event.Payload)
180	}
181	return nil
182}
183
184// messageExists checks if a message with the given ID already exists in the list.
185func (m *messageListCmp) messageExists(messageID string) bool {
186	items := m.listCmp.Items()
187	// Search backwards as new messages are more likely to be at the end
188	for i := len(items) - 1; i >= 0; i-- {
189		if msg, ok := items[i].(messages.MessageCmp); ok && msg.GetMessage().ID == messageID {
190			return true
191		}
192	}
193	return false
194}
195
196// handleNewMessage routes new messages to appropriate handlers based on role.
197func (m *messageListCmp) handleNewMessage(msg message.Message) tea.Cmd {
198	switch msg.Role {
199	case message.User:
200		return m.handleNewUserMessage(msg)
201	case message.Assistant:
202		return m.handleNewAssistantMessage(msg)
203	case message.Tool:
204		return m.handleToolMessage(msg)
205	}
206	return nil
207}
208
209// handleNewUserMessage adds a new user message to the list and updates the timestamp.
210func (m *messageListCmp) handleNewUserMessage(msg message.Message) tea.Cmd {
211	m.lastUserMessageTime = msg.CreatedAt
212	return m.listCmp.AppendItem(messages.NewMessageCmp(msg))
213}
214
215// handleToolMessage updates existing tool calls with their results.
216func (m *messageListCmp) handleToolMessage(msg message.Message) tea.Cmd {
217	items := m.listCmp.Items()
218	for _, tr := range msg.ToolResults() {
219		if toolCallIndex := m.findToolCallByID(items, tr.ToolCallID); toolCallIndex != NotFound {
220			toolCall := items[toolCallIndex].(messages.ToolCallCmp)
221			toolCall.SetToolResult(tr)
222			m.listCmp.UpdateItem(toolCallIndex, toolCall)
223		}
224	}
225	return nil
226}
227
228// findToolCallByID searches for a tool call with the specified ID.
229// Returns the index if found, NotFound otherwise.
230func (m *messageListCmp) findToolCallByID(items []util.Model, toolCallID string) int {
231	// Search backwards as tool calls are more likely to be recent
232	for i := len(items) - 1; i >= 0; i-- {
233		if toolCall, ok := items[i].(messages.ToolCallCmp); ok && toolCall.GetToolCall().ID == toolCallID {
234			return i
235		}
236	}
237	return NotFound
238}
239
240// handleUpdateAssistantMessage processes updates to assistant messages,
241// managing both message content and associated tool calls.
242func (m *messageListCmp) handleUpdateAssistantMessage(msg message.Message) tea.Cmd {
243	var cmds []tea.Cmd
244	items := m.listCmp.Items()
245
246	// Find existing assistant message and tool calls for this message
247	assistantIndex, existingToolCalls := m.findAssistantMessageAndToolCalls(items, msg.ID)
248
249	// Handle assistant message content
250	if cmd := m.updateAssistantMessageContent(msg, assistantIndex); cmd != nil {
251		cmds = append(cmds, cmd)
252	}
253
254	// Handle tool calls
255	if cmd := m.updateToolCalls(msg, existingToolCalls); cmd != nil {
256		cmds = append(cmds, cmd)
257	}
258
259	return tea.Batch(cmds...)
260}
261
262// findAssistantMessageAndToolCalls locates the assistant message and its tool calls.
263func (m *messageListCmp) findAssistantMessageAndToolCalls(items []util.Model, messageID string) (int, map[int]messages.ToolCallCmp) {
264	assistantIndex := NotFound
265	toolCalls := make(map[int]messages.ToolCallCmp)
266
267	// Search backwards as messages are more likely to be at the end
268	for i := len(items) - 1; i >= 0; i-- {
269		item := items[i]
270		if asMsg, ok := item.(messages.MessageCmp); ok {
271			if asMsg.GetMessage().ID == messageID {
272				assistantIndex = i
273			}
274		} else if tc, ok := item.(messages.ToolCallCmp); ok {
275			if tc.ParentMessageID() == messageID {
276				toolCalls[i] = tc
277			}
278		}
279	}
280
281	return assistantIndex, toolCalls
282}
283
284// updateAssistantMessageContent updates or removes the assistant message based on content.
285func (m *messageListCmp) updateAssistantMessageContent(msg message.Message, assistantIndex int) tea.Cmd {
286	if assistantIndex == NotFound {
287		return nil
288	}
289
290	shouldShowMessage := m.shouldShowAssistantMessage(msg)
291	hasToolCallsOnly := len(msg.ToolCalls()) > 0 && msg.Content().Text == ""
292
293	if shouldShowMessage {
294		m.listCmp.UpdateItem(
295			assistantIndex,
296			messages.NewMessageCmp(
297				msg,
298			),
299		)
300
301		if msg.FinishPart() != nil && msg.FinishPart().Reason == message.FinishReasonEndTurn {
302			m.listCmp.AppendItem(
303				messages.NewAssistantSection(
304					msg,
305					time.Unix(m.lastUserMessageTime, 0),
306				),
307			)
308		}
309	} else if hasToolCallsOnly {
310		m.listCmp.DeleteItem(assistantIndex)
311	}
312
313	return nil
314}
315
316// shouldShowAssistantMessage determines if an assistant message should be displayed.
317func (m *messageListCmp) shouldShowAssistantMessage(msg message.Message) bool {
318	return len(msg.ToolCalls()) == 0 || msg.Content().Text != "" || msg.IsThinking()
319}
320
321// updateToolCalls handles updates to tool calls, updating existing ones and adding new ones.
322func (m *messageListCmp) updateToolCalls(msg message.Message, existingToolCalls map[int]messages.ToolCallCmp) tea.Cmd {
323	var cmds []tea.Cmd
324
325	for _, tc := range msg.ToolCalls() {
326		if cmd := m.updateOrAddToolCall(tc, existingToolCalls, msg.ID); cmd != nil {
327			cmds = append(cmds, cmd)
328		}
329	}
330
331	return tea.Batch(cmds...)
332}
333
334// updateOrAddToolCall updates an existing tool call or adds a new one.
335func (m *messageListCmp) updateOrAddToolCall(tc message.ToolCall, existingToolCalls map[int]messages.ToolCallCmp, messageID string) tea.Cmd {
336	// Try to find existing tool call
337	for index, existingTC := range existingToolCalls {
338		if tc.ID == existingTC.GetToolCall().ID {
339			existingTC.SetToolCall(tc)
340			m.listCmp.UpdateItem(index, existingTC)
341			return nil
342		}
343	}
344
345	// Add new tool call if not found
346	return m.listCmp.AppendItem(messages.NewToolCallCmp(messageID, tc))
347}
348
349// handleNewAssistantMessage processes new assistant messages and their tool calls.
350func (m *messageListCmp) handleNewAssistantMessage(msg message.Message) tea.Cmd {
351	var cmds []tea.Cmd
352
353	// Add assistant message if it should be displayed
354	if m.shouldShowAssistantMessage(msg) {
355		cmd := m.listCmp.AppendItem(
356			messages.NewMessageCmp(
357				msg,
358			),
359		)
360		cmds = append(cmds, cmd)
361	}
362
363	// Add tool calls
364	for _, tc := range msg.ToolCalls() {
365		cmd := m.listCmp.AppendItem(messages.NewToolCallCmp(msg.ID, tc))
366		cmds = append(cmds, cmd)
367	}
368
369	return tea.Batch(cmds...)
370}
371
372// SetSession loads and displays messages for a new session.
373func (m *messageListCmp) SetSession(session session.Session) tea.Cmd {
374	if m.session.ID == session.ID {
375		return nil
376	}
377
378	m.session = session
379	sessionMessages, err := m.app.Messages.List(context.Background(), session.ID)
380	if err != nil {
381		return util.ReportError(err)
382	}
383
384	if len(sessionMessages) == 0 {
385		return m.listCmp.SetItems([]util.Model{})
386	}
387
388	// Initialize with first message timestamp
389	m.lastUserMessageTime = sessionMessages[0].CreatedAt
390
391	// Build tool result map for efficient lookup
392	toolResultMap := m.buildToolResultMap(sessionMessages)
393
394	// Convert messages to UI components
395	uiMessages := m.convertMessagesToUI(sessionMessages, toolResultMap)
396
397	return m.listCmp.SetItems(uiMessages)
398}
399
400// buildToolResultMap creates a map of tool call ID to tool result for efficient lookup.
401func (m *messageListCmp) buildToolResultMap(messages []message.Message) map[string]message.ToolResult {
402	toolResultMap := make(map[string]message.ToolResult)
403	for _, msg := range messages {
404		for _, tr := range msg.ToolResults() {
405			toolResultMap[tr.ToolCallID] = tr
406		}
407	}
408	return toolResultMap
409}
410
411// convertMessagesToUI converts database messages to UI components.
412func (m *messageListCmp) convertMessagesToUI(sessionMessages []message.Message, toolResultMap map[string]message.ToolResult) []util.Model {
413	uiMessages := make([]util.Model, 0)
414
415	for _, msg := range sessionMessages {
416		switch msg.Role {
417		case message.User:
418			m.lastUserMessageTime = msg.CreatedAt
419			uiMessages = append(uiMessages, messages.NewMessageCmp(msg))
420		case message.Assistant:
421			uiMessages = append(uiMessages, m.convertAssistantMessage(msg, toolResultMap)...)
422			if msg.FinishPart() != nil && msg.FinishPart().Reason == message.FinishReasonEndTurn {
423				uiMessages = append(uiMessages, messages.NewAssistantSection(msg, time.Unix(m.lastUserMessageTime, 0)))
424			}
425		}
426	}
427
428	return uiMessages
429}
430
431// convertAssistantMessage converts an assistant message and its tool calls to UI components.
432func (m *messageListCmp) convertAssistantMessage(msg message.Message, toolResultMap map[string]message.ToolResult) []util.Model {
433	var uiMessages []util.Model
434
435	// Add assistant message if it should be displayed
436	if m.shouldShowAssistantMessage(msg) {
437		uiMessages = append(
438			uiMessages,
439			messages.NewMessageCmp(
440				msg,
441			),
442		)
443	}
444
445	// Add tool calls with their results and status
446	for _, tc := range msg.ToolCalls() {
447		options := m.buildToolCallOptions(tc, msg, toolResultMap)
448		uiMessages = append(uiMessages, messages.NewToolCallCmp(msg.ID, tc, options...))
449		// If this tool call is the agent tool, fetch nested tool calls
450		if tc.Name == agent.AgentToolName {
451			nestedMessages, _ := m.app.Messages.List(context.Background(), tc.ID)
452			nestedUIMessages := m.convertMessagesToUI(nestedMessages, make(map[string]message.ToolResult))
453			nestedToolCalls := make([]messages.ToolCallCmp, 0, len(nestedUIMessages))
454			for _, nestedMsg := range nestedUIMessages {
455				if toolCall, ok := nestedMsg.(messages.ToolCallCmp); ok {
456					toolCall.SetIsNested(true)
457					nestedToolCalls = append(nestedToolCalls, toolCall)
458				}
459			}
460			uiMessages[len(uiMessages)-1].(messages.ToolCallCmp).SetNestedToolCalls(nestedToolCalls)
461		}
462	}
463
464	return uiMessages
465}
466
467// buildToolCallOptions creates options for tool call components based on results and status.
468func (m *messageListCmp) buildToolCallOptions(tc message.ToolCall, msg message.Message, toolResultMap map[string]message.ToolResult) []messages.ToolCallOption {
469	var options []messages.ToolCallOption
470
471	// Add tool result if available
472	if tr, ok := toolResultMap[tc.ID]; ok {
473		options = append(options, messages.WithToolCallResult(tr))
474	}
475
476	// Add cancelled status if applicable
477	if msg.FinishPart() != nil && msg.FinishPart().Reason == message.FinishReasonCanceled {
478		options = append(options, messages.WithToolCallCancelled())
479	}
480
481	return options
482}
483
484// GetSize returns the current width and height of the component.
485func (m *messageListCmp) GetSize() (int, int) {
486	return m.width, m.height
487}
488
489// SetSize updates the component dimensions and propagates to the list component.
490func (m *messageListCmp) SetSize(width int, height int) tea.Cmd {
491	m.width = width
492	m.height = height - 1
493	return m.listCmp.SetSize(width, height-1)
494}
495
496// Blur implements MessageListCmp.
497func (m *messageListCmp) Blur() tea.Cmd {
498	return m.listCmp.Blur()
499}
500
501// Focus implements MessageListCmp.
502func (m *messageListCmp) Focus() tea.Cmd {
503	return m.listCmp.Focus()
504}
505
506// IsFocused implements MessageListCmp.
507func (m *messageListCmp) IsFocused() bool {
508	return m.listCmp.IsFocused()
509}
510
511func (m *messageListCmp) Bindings() []key.Binding {
512	bindings := m.defaultListKeyMap.KeyBindings()
513	return bindings
514}