messages.go

  1package repl
  2
  3import (
  4	"encoding/json"
  5	"fmt"
  6	"sort"
  7	"strings"
  8	"time"
  9
 10	"github.com/charmbracelet/bubbles/key"
 11	"github.com/charmbracelet/bubbles/viewport"
 12	tea "github.com/charmbracelet/bubbletea"
 13	"github.com/charmbracelet/glamour"
 14	"github.com/charmbracelet/lipgloss"
 15	"github.com/kujtimiihoxha/termai/internal/app"
 16	"github.com/kujtimiihoxha/termai/internal/llm/agent"
 17	"github.com/kujtimiihoxha/termai/internal/lsp/protocol"
 18	"github.com/kujtimiihoxha/termai/internal/message"
 19	"github.com/kujtimiihoxha/termai/internal/pubsub"
 20	"github.com/kujtimiihoxha/termai/internal/session"
 21	"github.com/kujtimiihoxha/termai/internal/tui/layout"
 22	"github.com/kujtimiihoxha/termai/internal/tui/styles"
 23)
 24
 25type MessagesCmp interface {
 26	tea.Model
 27	layout.Focusable
 28	layout.Bordered
 29	layout.Sizeable
 30	layout.Bindings
 31}
 32
// messagesCmp is the concrete MessagesCmp. It keeps the messages of the
// current session in memory and renders them into a scrollable viewport.
type messagesCmp struct {
	// app gives access to the session/message stores and the LSP clients.
	app *app.App
	// messages are the messages of the current session, in display order.
	messages []message.Message
	// selectedMsgIdx is an index into messages; that message's border is drawn
	// in the primary color. NOTE(review): nothing in this file changes it, so
	// it stays 0 unless set elsewhere — confirm.
	selectedMsgIdx int // Index of the selected message
	// session is the currently selected session.
	session session.Session
	// viewport is the scrollable area the rendered messages are written into.
	viewport viewport.Model
	// mdRenderer is not referenced in this file; renderView builds its own
	// glamour renderer on every call.
	mdRenderer *glamour.TermRenderer
	// width and height are the outer size last passed to SetSize.
	width  int
	height int
	// focused controls whether Update forwards messages to the viewport.
	focused bool
	// cachedView is not referenced in this file.
	cachedView string
	// timeLoaded is set by Init and read by projectDiagnostics.
	timeLoaded time.Time
}
 46
 47func (m *messagesCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
 48	switch msg := msg.(type) {
 49	case pubsub.Event[message.Message]:
 50		if msg.Type == pubsub.CreatedEvent {
 51			if msg.Payload.SessionID == m.session.ID {
 52				m.messages = append(m.messages, msg.Payload)
 53				m.renderView()
 54				m.viewport.GotoBottom()
 55			}
 56			for _, v := range m.messages {
 57				for _, c := range v.ToolCalls() {
 58					// the message is being added to the session of a tool called
 59					if c.ID == msg.Payload.SessionID {
 60						m.renderView()
 61						m.viewport.GotoBottom()
 62					}
 63				}
 64			}
 65		} else if msg.Type == pubsub.UpdatedEvent && msg.Payload.SessionID == m.session.ID {
 66			for i, v := range m.messages {
 67				if v.ID == msg.Payload.ID {
 68					m.messages[i] = msg.Payload
 69					m.renderView()
 70					if i == len(m.messages)-1 {
 71						m.viewport.GotoBottom()
 72					}
 73					break
 74				}
 75			}
 76		}
 77	case pubsub.Event[session.Session]:
 78		if msg.Type == pubsub.UpdatedEvent && m.session.ID == msg.Payload.ID {
 79			m.session = msg.Payload
 80		}
 81	case SelectedSessionMsg:
 82		m.session, _ = m.app.Sessions.Get(msg.SessionID)
 83		m.messages, _ = m.app.Messages.List(m.session.ID)
 84		m.renderView()
 85		m.viewport.GotoBottom()
 86	}
 87	if m.focused {
 88		u, cmd := m.viewport.Update(msg)
 89		m.viewport = u
 90		return m, cmd
 91	}
 92	return m, nil
 93}
 94
 95func borderColor(role message.MessageRole) lipgloss.TerminalColor {
 96	switch role {
 97	case message.Assistant:
 98		return styles.Mauve
 99	case message.User:
100		return styles.Rosewater
101	}
102	return styles.Blue
103}
104
105func borderText(msgRole message.MessageRole, currentMessage int) map[layout.BorderPosition]string {
106	role := ""
107	icon := ""
108	switch msgRole {
109	case message.Assistant:
110		role = "Assistant"
111		icon = styles.BotIcon
112	case message.User:
113		role = "User"
114		icon = styles.UserIcon
115	}
116	return map[layout.BorderPosition]string{
117		layout.TopLeftBorder: lipgloss.NewStyle().
118			Padding(0, 1).
119			Bold(true).
120			Foreground(styles.Crust).
121			Background(borderColor(msgRole)).
122			Render(fmt.Sprintf("%s %s ", role, icon)),
123		layout.TopRightBorder: lipgloss.NewStyle().
124			Padding(0, 1).
125			Bold(true).
126			Foreground(styles.Crust).
127			Background(borderColor(msgRole)).
128			Render(fmt.Sprintf("#%d ", currentMessage)),
129	}
130}
131
132func hasUnfinishedMessages(messages []message.Message) bool {
133	if len(messages) == 0 {
134		return false
135	}
136	for _, msg := range messages {
137		if !msg.IsFinished() {
138			return true
139		}
140	}
141	return false
142}
143
144func (m *messagesCmp) renderMessageWithToolCall(content string, tools []message.ToolCall, futureMessages []message.Message) string {
145	allParts := []string{content}
146
147	leftPaddingValue := 4
148	connectorStyle := lipgloss.NewStyle().
149		Foreground(styles.Peach).
150		Bold(true)
151
152	toolCallStyle := lipgloss.NewStyle().
153		Border(lipgloss.RoundedBorder()).
154		BorderForeground(styles.Peach).
155		Width(m.width-leftPaddingValue-5).
156		Padding(0, 1)
157
158	toolResultStyle := lipgloss.NewStyle().
159		Border(lipgloss.RoundedBorder()).
160		BorderForeground(styles.Green).
161		Width(m.width-leftPaddingValue-5).
162		Padding(0, 1)
163
164	leftPadding := lipgloss.NewStyle().Padding(0, 0, 0, leftPaddingValue)
165
166	runningStyle := lipgloss.NewStyle().
167		Foreground(styles.Peach).
168		Bold(true)
169
170	renderTool := func(toolCall message.ToolCall) string {
171		toolHeader := lipgloss.NewStyle().
172			Bold(true).
173			Foreground(styles.Blue).
174			Render(fmt.Sprintf("%s %s", styles.ToolIcon, toolCall.Name))
175
176		var paramLines []string
177		var args map[string]interface{}
178		var paramOrder []string
179
180		json.Unmarshal([]byte(toolCall.Input), &args)
181
182		for key := range args {
183			paramOrder = append(paramOrder, key)
184		}
185		sort.Strings(paramOrder)
186
187		for _, name := range paramOrder {
188			value := args[name]
189			paramName := lipgloss.NewStyle().
190				Foreground(styles.Peach).
191				Bold(true).
192				Render(name)
193
194			truncate := m.width - leftPaddingValue*2 - 10
195			if len(fmt.Sprintf("%v", value)) > truncate {
196				value = fmt.Sprintf("%v", value)[:truncate] + lipgloss.NewStyle().Foreground(styles.Blue).Render("... (truncated)")
197			}
198			paramValue := fmt.Sprintf("%v", value)
199			paramLines = append(paramLines, fmt.Sprintf("  %s: %s", paramName, paramValue))
200		}
201
202		paramBlock := lipgloss.JoinVertical(lipgloss.Left, paramLines...)
203
204		toolContent := lipgloss.JoinVertical(lipgloss.Left, toolHeader, paramBlock)
205		return toolCallStyle.Render(toolContent)
206	}
207
208	findToolResult := func(toolCallID string, messages []message.Message) *message.ToolResult {
209		for _, msg := range messages {
210			if msg.Role == message.Tool {
211				for _, result := range msg.ToolResults() {
212					if result.ToolCallID == toolCallID {
213						return &result
214					}
215				}
216			}
217		}
218		return nil
219	}
220
221	renderToolResult := func(result message.ToolResult) string {
222		resultHeader := lipgloss.NewStyle().
223			Bold(true).
224			Foreground(styles.Green).
225			Render(fmt.Sprintf("%s Result", styles.CheckIcon))
226		if result.IsError {
227			resultHeader = lipgloss.NewStyle().
228				Bold(true).
229				Foreground(styles.Red).
230				Render(fmt.Sprintf("%s Error", styles.ErrorIcon))
231		}
232
233		truncate := 200
234		content := result.Content
235		if len(content) > truncate {
236			content = content[:truncate] + lipgloss.NewStyle().Foreground(styles.Blue).Render("... (truncated)")
237		}
238
239		resultContent := lipgloss.JoinVertical(lipgloss.Left, resultHeader, content)
240		return toolResultStyle.Render(resultContent)
241	}
242
243	connector := connectorStyle.Render("└─> Tool Calls:")
244	allParts = append(allParts, connector)
245
246	for _, toolCall := range tools {
247		toolOutput := renderTool(toolCall)
248		allParts = append(allParts, leftPadding.Render(toolOutput))
249
250		result := findToolResult(toolCall.ID, futureMessages)
251		if result != nil {
252
253			resultOutput := renderToolResult(*result)
254			allParts = append(allParts, leftPadding.Render(resultOutput))
255
256		} else if toolCall.Name == agent.AgentToolName {
257
258			runningIndicator := runningStyle.Render(fmt.Sprintf("%s Running...", styles.SpinnerIcon))
259			allParts = append(allParts, leftPadding.Render(runningIndicator))
260			taskSessionMessages, _ := m.app.Messages.List(toolCall.ID)
261			for _, msg := range taskSessionMessages {
262				if msg.Role == message.Assistant {
263					for _, toolCall := range msg.ToolCalls() {
264						toolHeader := lipgloss.NewStyle().
265							Bold(true).
266							Foreground(styles.Blue).
267							Render(fmt.Sprintf("%s %s", styles.ToolIcon, toolCall.Name))
268
269						var paramLines []string
270						var args map[string]interface{}
271						var paramOrder []string
272
273						json.Unmarshal([]byte(toolCall.Input), &args)
274
275						for key := range args {
276							paramOrder = append(paramOrder, key)
277						}
278						sort.Strings(paramOrder)
279
280						for _, name := range paramOrder {
281							value := args[name]
282							paramName := lipgloss.NewStyle().
283								Foreground(styles.Peach).
284								Bold(true).
285								Render(name)
286
287							truncate := 50
288							if len(fmt.Sprintf("%v", value)) > truncate {
289								value = fmt.Sprintf("%v", value)[:truncate] + lipgloss.NewStyle().Foreground(styles.Blue).Render("... (truncated)")
290							}
291							paramValue := fmt.Sprintf("%v", value)
292							paramLines = append(paramLines, fmt.Sprintf("  %s: %s", paramName, paramValue))
293						}
294
295						paramBlock := lipgloss.JoinVertical(lipgloss.Left, paramLines...)
296						toolContent := lipgloss.JoinVertical(lipgloss.Left, toolHeader, paramBlock)
297						toolOutput := toolCallStyle.BorderForeground(styles.Teal).MaxWidth(m.width - leftPaddingValue*2 - 2).Render(toolContent)
298						allParts = append(allParts, lipgloss.NewStyle().Padding(0, 0, 0, leftPaddingValue*2).Render(toolOutput))
299					}
300				}
301			}
302
303		} else {
304			runningIndicator := runningStyle.Render(fmt.Sprintf("%s Running...", styles.SpinnerIcon))
305			allParts = append(allParts, "    "+runningIndicator)
306		}
307	}
308
309	for _, msg := range futureMessages {
310		if msg.Content().String() != "" {
311			break
312		}
313
314		for _, toolCall := range msg.ToolCalls() {
315			toolOutput := renderTool(toolCall)
316			allParts = append(allParts, "    "+strings.ReplaceAll(toolOutput, "\n", "\n    "))
317
318			result := findToolResult(toolCall.ID, futureMessages)
319			if result != nil {
320				resultOutput := renderToolResult(*result)
321				allParts = append(allParts, "    "+strings.ReplaceAll(resultOutput, "\n", "\n    "))
322			} else {
323				runningIndicator := runningStyle.Render(fmt.Sprintf("%s Running...", styles.SpinnerIcon))
324				allParts = append(allParts, "    "+runningIndicator)
325			}
326		}
327	}
328
329	return lipgloss.JoinVertical(lipgloss.Left, allParts...)
330}
331
332func (m *messagesCmp) renderView() {
333	stringMessages := make([]string, 0)
334	r, _ := glamour.NewTermRenderer(
335		glamour.WithStyles(styles.CatppuccinMarkdownStyle()),
336		glamour.WithWordWrap(m.width-20),
337		glamour.WithEmoji(),
338	)
339	textStyle := lipgloss.NewStyle().Width(m.width - 4)
340	currentMessage := 1
341	displayedMsgCount := 0 // Track the actual displayed messages count
342
343	prevMessageWasUser := false
344	for inx, msg := range m.messages {
345		content := msg.Content().String()
346		if content != "" || prevMessageWasUser {
347			if msg.ReasoningContent().String() != "" && content == "" {
348				content = msg.ReasoningContent().String()
349			} else if content == "" {
350				content = "..."
351			}
352			content, _ = r.Render(content)
353
354			isSelected := inx == m.selectedMsgIdx
355
356			border := lipgloss.DoubleBorder()
357			activeColor := borderColor(msg.Role)
358
359			if isSelected {
360				activeColor = styles.Primary // Use primary color for selected message
361			}
362
363			content = layout.Borderize(
364				textStyle.Render(content),
365				layout.BorderOptions{
366					InactiveBorder: border,
367					ActiveBorder:   border,
368					ActiveColor:    activeColor,
369					InactiveColor:  borderColor(msg.Role),
370					EmbeddedText:   borderText(msg.Role, currentMessage),
371				},
372			)
373			if len(msg.ToolCalls()) > 0 {
374				content = m.renderMessageWithToolCall(content, msg.ToolCalls(), m.messages[inx+1:])
375			}
376			stringMessages = append(stringMessages, content)
377			currentMessage++
378			displayedMsgCount++
379		}
380		if msg.Role == message.User && msg.Content().String() != "" {
381			prevMessageWasUser = true
382		} else {
383			prevMessageWasUser = false
384		}
385	}
386	m.viewport.SetContent(lipgloss.JoinVertical(lipgloss.Top, stringMessages...))
387}
388
389func (m *messagesCmp) View() string {
390	return lipgloss.NewStyle().Padding(1).Render(m.viewport.View())
391}
392
393func (m *messagesCmp) BindingKeys() []key.Binding {
394	keys := layout.KeyMapToSlice(m.viewport.KeyMap)
395
396	return keys
397}
398
399func (m *messagesCmp) Blur() tea.Cmd {
400	m.focused = false
401	return nil
402}
403
404func (m *messagesCmp) projectDiagnostics() string {
405	errorDiagnostics := []protocol.Diagnostic{}
406	warnDiagnostics := []protocol.Diagnostic{}
407	hintDiagnostics := []protocol.Diagnostic{}
408	infoDiagnostics := []protocol.Diagnostic{}
409	for _, client := range m.app.LSPClients {
410		for _, d := range client.GetDiagnostics() {
411			for _, diag := range d {
412				switch diag.Severity {
413				case protocol.SeverityError:
414					errorDiagnostics = append(errorDiagnostics, diag)
415				case protocol.SeverityWarning:
416					warnDiagnostics = append(warnDiagnostics, diag)
417				case protocol.SeverityHint:
418					hintDiagnostics = append(hintDiagnostics, diag)
419				case protocol.SeverityInformation:
420					infoDiagnostics = append(infoDiagnostics, diag)
421				}
422			}
423		}
424	}
425
426	if len(errorDiagnostics) == 0 && len(warnDiagnostics) == 0 && len(hintDiagnostics) == 0 && len(infoDiagnostics) == 0 {
427		if time.Since(m.timeLoaded) < time.Second*10 {
428			return "Loading diagnostics..."
429		}
430		return "No diagnostics"
431	}
432
433	diagnostics := []string{}
434
435	if len(errorDiagnostics) > 0 {
436		errStr := lipgloss.NewStyle().Foreground(styles.Error).Render(fmt.Sprintf("%s %d", styles.ErrorIcon, len(errorDiagnostics)))
437		diagnostics = append(diagnostics, errStr)
438	}
439	if len(warnDiagnostics) > 0 {
440		warnStr := lipgloss.NewStyle().Foreground(styles.Warning).Render(fmt.Sprintf("%s %d", styles.WarningIcon, len(warnDiagnostics)))
441		diagnostics = append(diagnostics, warnStr)
442	}
443	if len(hintDiagnostics) > 0 {
444		hintStr := lipgloss.NewStyle().Foreground(styles.Text).Render(fmt.Sprintf("%s %d", styles.HintIcon, len(hintDiagnostics)))
445		diagnostics = append(diagnostics, hintStr)
446	}
447	if len(infoDiagnostics) > 0 {
448		infoStr := lipgloss.NewStyle().Foreground(styles.Peach).Render(fmt.Sprintf("%s %d", styles.InfoIcon, len(infoDiagnostics)))
449		diagnostics = append(diagnostics, infoStr)
450	}
451
452	return strings.Join(diagnostics, " ")
453}
454
455func (m *messagesCmp) BorderText() map[layout.BorderPosition]string {
456	title := m.session.Title
457	titleWidth := m.width / 2
458	if len(title) > titleWidth {
459		title = title[:titleWidth] + "..."
460	}
461	if m.focused {
462		title = lipgloss.NewStyle().Foreground(styles.Primary).Render(title)
463	}
464	borderTest := map[layout.BorderPosition]string{
465		layout.TopLeftBorder:     title,
466		layout.BottomRightBorder: m.projectDiagnostics(),
467	}
468	if hasUnfinishedMessages(m.messages) {
469		borderTest[layout.BottomLeftBorder] = lipgloss.NewStyle().Foreground(styles.Peach).Render("Thinking...")
470	} else {
471		borderTest[layout.BottomLeftBorder] = lipgloss.NewStyle().Foreground(styles.Text).Render("Sleeping " + styles.SleepIcon + " ")
472	}
473
474	return borderTest
475}
476
477func (m *messagesCmp) Focus() tea.Cmd {
478	m.focused = true
479	return nil
480}
481
482func (m *messagesCmp) GetSize() (int, int) {
483	return m.width, m.height
484}
485
486func (m *messagesCmp) IsFocused() bool {
487	return m.focused
488}
489
490func (m *messagesCmp) SetSize(width int, height int) {
491	m.width = width
492	m.height = height
493	m.viewport.Width = width - 2   // padding
494	m.viewport.Height = height - 2 // padding
495	m.renderView()
496}
497
498func (m *messagesCmp) Init() tea.Cmd {
499	m.timeLoaded = time.Now()
500	return nil
501}
502
503func NewMessagesCmp(app *app.App) MessagesCmp {
504	return &messagesCmp{
505		app:      app,
506		messages: []message.Message{},
507		viewport: viewport.New(0, 0),
508	}
509}