messages.go

  1package repl
  2
  3import (
  4	"context"
  5	"encoding/json"
  6	"fmt"
  7	"sort"
  8	"strings"
  9
 10	"github.com/charmbracelet/bubbles/key"
 11	"github.com/charmbracelet/bubbles/viewport"
 12	tea "github.com/charmbracelet/bubbletea"
 13	"github.com/charmbracelet/glamour"
 14	"github.com/charmbracelet/lipgloss"
 15	"github.com/kujtimiihoxha/termai/internal/app"
 16	"github.com/kujtimiihoxha/termai/internal/llm/agent"
 17	"github.com/kujtimiihoxha/termai/internal/lsp/protocol"
 18	"github.com/kujtimiihoxha/termai/internal/message"
 19	"github.com/kujtimiihoxha/termai/internal/pubsub"
 20	"github.com/kujtimiihoxha/termai/internal/session"
 21	"github.com/kujtimiihoxha/termai/internal/tui/layout"
 22	"github.com/kujtimiihoxha/termai/internal/tui/styles"
 23)
 24
 25type MessagesCmp interface {
 26	tea.Model
 27	layout.Focusable
 28	layout.Bordered
 29	layout.Sizeable
 30	layout.Bindings
 31}
 32
 33type messagesCmp struct {
 34	app            *app.App
 35	messages       []message.Message
 36	selectedMsgIdx int // Index of the selected message
 37	session        session.Session
 38	viewport       viewport.Model
 39	mdRenderer     *glamour.TermRenderer
 40	width          int
 41	height         int
 42	focused        bool
 43	cachedView     string
 44}
 45
 46func (m *messagesCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
 47	switch msg := msg.(type) {
 48	case pubsub.Event[message.Message]:
 49		if msg.Type == pubsub.CreatedEvent {
 50			if msg.Payload.SessionID == m.session.ID {
 51				m.messages = append(m.messages, msg.Payload)
 52				m.renderView()
 53				m.viewport.GotoBottom()
 54			}
 55			for _, v := range m.messages {
 56				for _, c := range v.ToolCalls() {
 57					// the message is being added to the session of a tool called
 58					if c.ID == msg.Payload.SessionID {
 59						m.renderView()
 60						m.viewport.GotoBottom()
 61					}
 62				}
 63			}
 64		} else if msg.Type == pubsub.UpdatedEvent && msg.Payload.SessionID == m.session.ID {
 65			for i, v := range m.messages {
 66				if v.ID == msg.Payload.ID {
 67					m.messages[i] = msg.Payload
 68					m.renderView()
 69					if i == len(m.messages)-1 {
 70						m.viewport.GotoBottom()
 71					}
 72					break
 73				}
 74			}
 75		}
 76	case pubsub.Event[session.Session]:
 77		if msg.Type == pubsub.UpdatedEvent && m.session.ID == msg.Payload.ID {
 78			m.session = msg.Payload
 79		}
 80	case SelectedSessionMsg:
 81		m.session, _ = m.app.Sessions.Get(context.Background(), msg.SessionID)
 82		m.messages, _ = m.app.Messages.List(context.Background(), m.session.ID)
 83		m.renderView()
 84		m.viewport.GotoBottom()
 85	}
 86	if m.focused {
 87		u, cmd := m.viewport.Update(msg)
 88		m.viewport = u
 89		return m, cmd
 90	}
 91	return m, nil
 92}
 93
 94func borderColor(role message.MessageRole) lipgloss.TerminalColor {
 95	switch role {
 96	case message.Assistant:
 97		return styles.Mauve
 98	case message.User:
 99		return styles.Rosewater
100	}
101	return styles.Blue
102}
103
104func borderText(msgRole message.MessageRole, currentMessage int) map[layout.BorderPosition]string {
105	role := ""
106	icon := ""
107	switch msgRole {
108	case message.Assistant:
109		role = "Assistant"
110		icon = styles.BotIcon
111	case message.User:
112		role = "User"
113		icon = styles.UserIcon
114	}
115	return map[layout.BorderPosition]string{
116		layout.TopLeftBorder: lipgloss.NewStyle().
117			Padding(0, 1).
118			Bold(true).
119			Foreground(styles.Crust).
120			Background(borderColor(msgRole)).
121			Render(fmt.Sprintf("%s %s ", role, icon)),
122		layout.TopRightBorder: lipgloss.NewStyle().
123			Padding(0, 1).
124			Bold(true).
125			Foreground(styles.Crust).
126			Background(borderColor(msgRole)).
127			Render(fmt.Sprintf("#%d ", currentMessage)),
128	}
129}
130
131func hasUnfinishedMessages(messages []message.Message) bool {
132	if len(messages) == 0 {
133		return false
134	}
135	for _, msg := range messages {
136		if !msg.IsFinished() {
137			return true
138		}
139	}
140	return false
141}
142
// renderMessageWithToolCall appends a "Tool Calls" section below an already
// rendered message.
//
// content is the rendered message, tools are its tool calls, and
// futureMessages are the messages that follow it in the session. For every
// tool call it renders the call (name plus parameters sorted by name) and then:
//   - its result, if a Tool message in futureMessages carries one;
//   - otherwise, for the agent tool, a "Running..." indicator followed by the
//     tool calls made so far in the sub-agent session (whose session ID is
//     the tool call ID);
//   - otherwise a plain "Running..." indicator.
//
// Tool calls of the following messages are appended as well, up to the first
// one that has text content or was canceled.
func (m *messagesCmp) renderMessageWithToolCall(content string, tools []message.ToolCall, futureMessages []message.Message) string {
	allParts := []string{content}

	leftPaddingValue := 4
	connectorStyle := lipgloss.NewStyle().
		Foreground(styles.Peach).
		Bold(true)

	toolCallStyle := lipgloss.NewStyle().
		Border(lipgloss.RoundedBorder()).
		BorderForeground(styles.Peach).
		Width(m.width-leftPaddingValue-5).
		Padding(0, 1)

	toolResultStyle := lipgloss.NewStyle().
		Border(lipgloss.RoundedBorder()).
		BorderForeground(styles.Green).
		Width(m.width-leftPaddingValue-5).
		Padding(0, 1)

	leftPadding := lipgloss.NewStyle().Padding(0, 0, 0, leftPaddingValue)
	nestedPadding := lipgloss.NewStyle().Padding(0, 0, 0, leftPaddingValue*2)

	runningStyle := lipgloss.NewStyle().
		Foreground(styles.Peach).
		Bold(true)
	runningIndicator := runningStyle.Render(fmt.Sprintf("%s Running...", styles.SpinnerIcon))

	truncatedMarker := lipgloss.NewStyle().Foreground(styles.Blue).Render("... (truncated)")

	// truncateText shortens s to at most limit runes and appends the marker
	// when anything was cut. Counting runes (not bytes) keeps multi-byte
	// characters intact, and a negative limit (pane narrower than the fixed
	// chrome, or not sized yet) is treated as zero instead of panicking on a
	// negative slice bound.
	truncateText := func(s string, limit int) string {
		if limit < 0 {
			limit = 0
		}
		runes := []rune(s)
		if len(runes) <= limit {
			return s
		}
		return string(runes[:limit]) + truncatedMarker
	}

	// indent shifts every line of a multi-line block right by four spaces.
	indent := func(block string) string {
		return "    " + strings.ReplaceAll(block, "\n", "\n    ")
	}

	// renderToolCallContent renders a tool call's header plus one
	// "name: value" line per JSON input parameter, sorted by name so the
	// layout is stable across re-renders. Values are cut to maxValueLen runes.
	renderToolCallContent := func(toolCall message.ToolCall, maxValueLen int) string {
		toolHeader := lipgloss.NewStyle().
			Bold(true).
			Foreground(styles.Blue).
			Render(fmt.Sprintf("%s %s", styles.ToolIcon, toolCall.Name))

		var args map[string]interface{}
		// Best effort: input that does not decode as a JSON object is shown
		// without parameters rather than failing the whole render.
		_ = json.Unmarshal([]byte(toolCall.Input), &args)

		paramNames := make([]string, 0, len(args))
		for name := range args {
			paramNames = append(paramNames, name)
		}
		sort.Strings(paramNames)

		paramLines := make([]string, 0, len(paramNames))
		for _, name := range paramNames {
			paramName := lipgloss.NewStyle().
				Foreground(styles.Peach).
				Bold(true).
				Render(name)
			paramValue := truncateText(fmt.Sprintf("%v", args[name]), maxValueLen)
			paramLines = append(paramLines, fmt.Sprintf("  %s: %s", paramName, paramValue))
		}

		paramBlock := lipgloss.JoinVertical(lipgloss.Left, paramLines...)
		return lipgloss.JoinVertical(lipgloss.Left, toolHeader, paramBlock)
	}

	// renderTool renders a top-level tool call in a peach-bordered box.
	renderTool := func(toolCall message.ToolCall) string {
		return toolCallStyle.Render(renderToolCallContent(toolCall, m.width-leftPaddingValue*2-10))
	}

	// renderNestedTool renders a tool call made inside a sub-agent session:
	// teal border, shorter values, indented one level deeper than its parent.
	renderNestedTool := func(toolCall message.ToolCall) string {
		box := toolCallStyle.
			BorderForeground(styles.Teal).
			MaxWidth(m.width - leftPaddingValue*2 - 2).
			Render(renderToolCallContent(toolCall, 50))
		return nestedPadding.Render(box)
	}

	// findToolResult returns the result for toolCallID carried by a Tool
	// message in messages, or nil when none has arrived yet.
	findToolResult := func(toolCallID string, messages []message.Message) *message.ToolResult {
		for _, msg := range messages {
			if msg.Role != message.Tool {
				continue
			}
			for _, result := range msg.ToolResults() {
				if result.ToolCallID == toolCallID {
					return &result
				}
			}
		}
		return nil
	}

	// renderToolResult renders a result box: green "Result" header, or red
	// "Error" header and border when the result is an error.
	renderToolResult := func(result message.ToolResult) string {
		resultHeader := lipgloss.NewStyle().
			Bold(true).
			Foreground(styles.Green).
			Render(fmt.Sprintf("%s Result", styles.CheckIcon))

		// Use the same color for both header and border if it's an error.
		borderColor := styles.Green
		if result.IsError {
			resultHeader = lipgloss.NewStyle().
				Bold(true).
				Foreground(styles.Red).
				Render(fmt.Sprintf("%s Error", styles.ErrorIcon))
			borderColor = styles.Red
		}

		resultContent := lipgloss.JoinVertical(lipgloss.Left, resultHeader, truncateText(result.Content, 200))
		return toolResultStyle.BorderForeground(borderColor).Render(resultContent)
	}

	allParts = append(allParts, connectorStyle.Render("└─> Tool Calls:"))

	for _, toolCall := range tools {
		allParts = append(allParts, leftPadding.Render(renderTool(toolCall)))

		if result := findToolResult(toolCall.ID, futureMessages); result != nil {
			allParts = append(allParts, leftPadding.Render(renderToolResult(*result)))
			continue
		}

		if toolCall.Name != agent.AgentToolName {
			allParts = append(allParts, "    "+runningIndicator)
			continue
		}

		// The agent tool runs a sub-agent in a session whose ID is the tool
		// call ID; show the tool calls it has made so far. The listing error
		// is ignored (best effort).
		allParts = append(allParts, leftPadding.Render(runningIndicator))
		taskSessionMessages, _ := m.app.Messages.List(context.Background(), toolCall.ID)
		for _, taskMsg := range taskSessionMessages {
			if taskMsg.Role != message.Assistant {
				continue
			}
			for _, nestedCall := range taskMsg.ToolCalls() {
				allParts = append(allParts, renderNestedTool(nestedCall))
			}
		}
	}

	// Follow-up messages without text (e.g. further tool-call rounds) are not
	// rendered on their own by renderView, so their calls are shown here.
	for _, msg := range futureMessages {
		if msg.Content().String() != "" || msg.FinishReason() == "canceled" {
			break
		}

		for _, toolCall := range msg.ToolCalls() {
			allParts = append(allParts, indent(renderTool(toolCall)))

			if result := findToolResult(toolCall.ID, futureMessages); result != nil {
				allParts = append(allParts, indent(renderToolResult(*result)))
			} else {
				allParts = append(allParts, "    "+runningIndicator)
			}
		}
	}

	return lipgloss.JoinVertical(lipgloss.Left, allParts...)
}
334
335func (m *messagesCmp) renderView() {
336	stringMessages := make([]string, 0)
337	r, _ := glamour.NewTermRenderer(
338		glamour.WithStyles(styles.CatppuccinMarkdownStyle()),
339		glamour.WithWordWrap(m.width-20),
340		glamour.WithEmoji(),
341	)
342	textStyle := lipgloss.NewStyle().Width(m.width - 4)
343	currentMessage := 1
344	displayedMsgCount := 0 // Track the actual displayed messages count
345
346	prevMessageWasUser := false
347	for inx, msg := range m.messages {
348		content := msg.Content().String()
349		if content != "" || prevMessageWasUser || msg.FinishReason() == "canceled" {
350			if msg.ReasoningContent().String() != "" && content == "" {
351				content = msg.ReasoningContent().String()
352			} else if content == "" {
353				content = "..."
354			}
355			if msg.FinishReason() == "canceled" {
356				content, _ = r.Render(content)
357				content += lipgloss.NewStyle().Padding(1, 0, 0, 1).Foreground(styles.Error).Render(styles.ErrorIcon + " Canceled")
358			} else {
359				content, _ = r.Render(content)
360			}
361
362			isSelected := inx == m.selectedMsgIdx
363
364			border := lipgloss.DoubleBorder()
365			activeColor := borderColor(msg.Role)
366
367			if isSelected {
368				activeColor = styles.Primary // Use primary color for selected message
369			}
370
371			content = layout.Borderize(
372				textStyle.Render(content),
373				layout.BorderOptions{
374					InactiveBorder: border,
375					ActiveBorder:   border,
376					ActiveColor:    activeColor,
377					InactiveColor:  borderColor(msg.Role),
378					EmbeddedText:   borderText(msg.Role, currentMessage),
379				},
380			)
381			if len(msg.ToolCalls()) > 0 {
382				content = m.renderMessageWithToolCall(content, msg.ToolCalls(), m.messages[inx+1:])
383			}
384			stringMessages = append(stringMessages, content)
385			currentMessage++
386			displayedMsgCount++
387		}
388		if msg.Role == message.User && msg.Content().String() != "" {
389			prevMessageWasUser = true
390		} else {
391			prevMessageWasUser = false
392		}
393	}
394	m.viewport.SetContent(lipgloss.JoinVertical(lipgloss.Top, stringMessages...))
395}
396
397func (m *messagesCmp) View() string {
398	return lipgloss.NewStyle().Padding(1).Render(m.viewport.View())
399}
400
401func (m *messagesCmp) BindingKeys() []key.Binding {
402	keys := layout.KeyMapToSlice(m.viewport.KeyMap)
403
404	return keys
405}
406
407func (m *messagesCmp) Blur() tea.Cmd {
408	m.focused = false
409	return nil
410}
411
412func (m *messagesCmp) projectDiagnostics() string {
413	errorDiagnostics := []protocol.Diagnostic{}
414	warnDiagnostics := []protocol.Diagnostic{}
415	hintDiagnostics := []protocol.Diagnostic{}
416	infoDiagnostics := []protocol.Diagnostic{}
417	for _, client := range m.app.LSPClients {
418		for _, d := range client.GetDiagnostics() {
419			for _, diag := range d {
420				switch diag.Severity {
421				case protocol.SeverityError:
422					errorDiagnostics = append(errorDiagnostics, diag)
423				case protocol.SeverityWarning:
424					warnDiagnostics = append(warnDiagnostics, diag)
425				case protocol.SeverityHint:
426					hintDiagnostics = append(hintDiagnostics, diag)
427				case protocol.SeverityInformation:
428					infoDiagnostics = append(infoDiagnostics, diag)
429				}
430			}
431		}
432	}
433
434	if len(errorDiagnostics) == 0 && len(warnDiagnostics) == 0 && len(hintDiagnostics) == 0 && len(infoDiagnostics) == 0 {
435		return "No diagnostics"
436	}
437
438	diagnostics := []string{}
439
440	if len(errorDiagnostics) > 0 {
441		errStr := lipgloss.NewStyle().Foreground(styles.Error).Render(fmt.Sprintf("%s %d", styles.ErrorIcon, len(errorDiagnostics)))
442		diagnostics = append(diagnostics, errStr)
443	}
444	if len(warnDiagnostics) > 0 {
445		warnStr := lipgloss.NewStyle().Foreground(styles.Warning).Render(fmt.Sprintf("%s %d", styles.WarningIcon, len(warnDiagnostics)))
446		diagnostics = append(diagnostics, warnStr)
447	}
448	if len(hintDiagnostics) > 0 {
449		hintStr := lipgloss.NewStyle().Foreground(styles.Text).Render(fmt.Sprintf("%s %d", styles.HintIcon, len(hintDiagnostics)))
450		diagnostics = append(diagnostics, hintStr)
451	}
452	if len(infoDiagnostics) > 0 {
453		infoStr := lipgloss.NewStyle().Foreground(styles.Peach).Render(fmt.Sprintf("%s %d", styles.InfoIcon, len(infoDiagnostics)))
454		diagnostics = append(diagnostics, infoStr)
455	}
456
457	return strings.Join(diagnostics, " ")
458}
459
460func (m *messagesCmp) BorderText() map[layout.BorderPosition]string {
461	title := m.session.Title
462	titleWidth := m.width / 2
463	if len(title) > titleWidth {
464		title = title[:titleWidth] + "..."
465	}
466	if m.focused {
467		title = lipgloss.NewStyle().Foreground(styles.Primary).Render(title)
468	}
469	borderTest := map[layout.BorderPosition]string{
470		layout.TopLeftBorder:     title,
471		layout.BottomRightBorder: m.projectDiagnostics(),
472	}
473	if hasUnfinishedMessages(m.messages) {
474		borderTest[layout.BottomLeftBorder] = lipgloss.NewStyle().Foreground(styles.Peach).Render("Thinking...")
475	} else {
476		borderTest[layout.BottomLeftBorder] = lipgloss.NewStyle().Foreground(styles.Text).Render("Sleeping " + styles.SleepIcon + " ")
477	}
478
479	return borderTest
480}
481
482func (m *messagesCmp) Focus() tea.Cmd {
483	m.focused = true
484	return nil
485}
486
487func (m *messagesCmp) GetSize() (int, int) {
488	return m.width, m.height
489}
490
491func (m *messagesCmp) IsFocused() bool {
492	return m.focused
493}
494
495func (m *messagesCmp) SetSize(width int, height int) {
496	m.width = width
497	m.height = height
498	m.viewport.Width = width - 2   // padding
499	m.viewport.Height = height - 2 // padding
500	m.renderView()
501}
502
503func (m *messagesCmp) Init() tea.Cmd {
504	return nil
505}
506
507func NewMessagesCmp(app *app.App) MessagesCmp {
508	return &messagesCmp{
509		app:      app,
510		messages: []message.Message{},
511		viewport: viewport.New(0, 0),
512	}
513}