messages.go

  1package repl
  2
  3import (
  4	"encoding/json"
  5	"fmt"
  6	"sort"
  7	"strings"
  8
  9	"github.com/charmbracelet/bubbles/key"
 10	"github.com/charmbracelet/bubbles/viewport"
 11	tea "github.com/charmbracelet/bubbletea"
 12	"github.com/charmbracelet/glamour"
 13	"github.com/charmbracelet/lipgloss"
 14	"github.com/kujtimiihoxha/termai/internal/app"
 15	"github.com/kujtimiihoxha/termai/internal/llm/agent"
 16	"github.com/kujtimiihoxha/termai/internal/lsp/protocol"
 17	"github.com/kujtimiihoxha/termai/internal/message"
 18	"github.com/kujtimiihoxha/termai/internal/pubsub"
 19	"github.com/kujtimiihoxha/termai/internal/session"
 20	"github.com/kujtimiihoxha/termai/internal/tui/layout"
 21	"github.com/kujtimiihoxha/termai/internal/tui/styles"
 22)
 23
 24type MessagesCmp interface {
 25	tea.Model
 26	layout.Focusable
 27	layout.Bordered
 28	layout.Sizeable
 29	layout.Bindings
 30}
 31
 32type messagesCmp struct {
 33	app            *app.App
 34	messages       []message.Message
 35	selectedMsgIdx int // Index of the selected message
 36	session        session.Session
 37	viewport       viewport.Model
 38	mdRenderer     *glamour.TermRenderer
 39	width          int
 40	height         int
 41	focused        bool
 42	cachedView     string
 43}
 44
 45func (m *messagesCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
 46	switch msg := msg.(type) {
 47	case pubsub.Event[message.Message]:
 48		if msg.Type == pubsub.CreatedEvent {
 49			if msg.Payload.SessionID == m.session.ID {
 50				m.messages = append(m.messages, msg.Payload)
 51				m.renderView()
 52				m.viewport.GotoBottom()
 53			}
 54			for _, v := range m.messages {
 55				for _, c := range v.ToolCalls() {
 56					// the message is being added to the session of a tool called
 57					if c.ID == msg.Payload.SessionID {
 58						m.renderView()
 59						m.viewport.GotoBottom()
 60					}
 61				}
 62			}
 63		} else if msg.Type == pubsub.UpdatedEvent && msg.Payload.SessionID == m.session.ID {
 64			for i, v := range m.messages {
 65				if v.ID == msg.Payload.ID {
 66					m.messages[i] = msg.Payload
 67					m.renderView()
 68					if i == len(m.messages)-1 {
 69						m.viewport.GotoBottom()
 70					}
 71					break
 72				}
 73			}
 74		}
 75	case pubsub.Event[session.Session]:
 76		if msg.Type == pubsub.UpdatedEvent && m.session.ID == msg.Payload.ID {
 77			m.session = msg.Payload
 78		}
 79	case SelectedSessionMsg:
 80		m.session, _ = m.app.Sessions.Get(msg.SessionID)
 81		m.messages, _ = m.app.Messages.List(m.session.ID)
 82		m.renderView()
 83		m.viewport.GotoBottom()
 84	}
 85	if m.focused {
 86		u, cmd := m.viewport.Update(msg)
 87		m.viewport = u
 88		return m, cmd
 89	}
 90	return m, nil
 91}
 92
 93func borderColor(role message.MessageRole) lipgloss.TerminalColor {
 94	switch role {
 95	case message.Assistant:
 96		return styles.Mauve
 97	case message.User:
 98		return styles.Rosewater
 99	}
100	return styles.Blue
101}
102
103func borderText(msgRole message.MessageRole, currentMessage int) map[layout.BorderPosition]string {
104	role := ""
105	icon := ""
106	switch msgRole {
107	case message.Assistant:
108		role = "Assistant"
109		icon = styles.BotIcon
110	case message.User:
111		role = "User"
112		icon = styles.UserIcon
113	}
114	return map[layout.BorderPosition]string{
115		layout.TopLeftBorder: lipgloss.NewStyle().
116			Padding(0, 1).
117			Bold(true).
118			Foreground(styles.Crust).
119			Background(borderColor(msgRole)).
120			Render(fmt.Sprintf("%s %s ", role, icon)),
121		layout.TopRightBorder: lipgloss.NewStyle().
122			Padding(0, 1).
123			Bold(true).
124			Foreground(styles.Crust).
125			Background(borderColor(msgRole)).
126			Render(fmt.Sprintf("#%d ", currentMessage)),
127	}
128}
129
130func hasUnfinishedMessages(messages []message.Message) bool {
131	if len(messages) == 0 {
132		return false
133	}
134	for _, msg := range messages {
135		if !msg.IsFinished() {
136			return true
137		}
138	}
139	return false
140}
141
142func (m *messagesCmp) renderMessageWithToolCall(content string, tools []message.ToolCall, futureMessages []message.Message) string {
143	allParts := []string{content}
144
145	leftPaddingValue := 4
146	connectorStyle := lipgloss.NewStyle().
147		Foreground(styles.Peach).
148		Bold(true)
149
150	toolCallStyle := lipgloss.NewStyle().
151		Border(lipgloss.RoundedBorder()).
152		BorderForeground(styles.Peach).
153		Width(m.width-leftPaddingValue-5).
154		Padding(0, 1)
155
156	toolResultStyle := lipgloss.NewStyle().
157		Border(lipgloss.RoundedBorder()).
158		BorderForeground(styles.Green).
159		Width(m.width-leftPaddingValue-5).
160		Padding(0, 1)
161
162	leftPadding := lipgloss.NewStyle().Padding(0, 0, 0, leftPaddingValue)
163
164	runningStyle := lipgloss.NewStyle().
165		Foreground(styles.Peach).
166		Bold(true)
167
168	renderTool := func(toolCall message.ToolCall) string {
169		toolHeader := lipgloss.NewStyle().
170			Bold(true).
171			Foreground(styles.Blue).
172			Render(fmt.Sprintf("%s %s", styles.ToolIcon, toolCall.Name))
173
174		var paramLines []string
175		var args map[string]interface{}
176		var paramOrder []string
177
178		json.Unmarshal([]byte(toolCall.Input), &args)
179
180		for key := range args {
181			paramOrder = append(paramOrder, key)
182		}
183		sort.Strings(paramOrder)
184
185		for _, name := range paramOrder {
186			value := args[name]
187			paramName := lipgloss.NewStyle().
188				Foreground(styles.Peach).
189				Bold(true).
190				Render(name)
191
192			truncate := m.width - leftPaddingValue*2 - 10
193			if len(fmt.Sprintf("%v", value)) > truncate {
194				value = fmt.Sprintf("%v", value)[:truncate] + lipgloss.NewStyle().Foreground(styles.Blue).Render("... (truncated)")
195			}
196			paramValue := fmt.Sprintf("%v", value)
197			paramLines = append(paramLines, fmt.Sprintf("  %s: %s", paramName, paramValue))
198		}
199
200		paramBlock := lipgloss.JoinVertical(lipgloss.Left, paramLines...)
201
202		toolContent := lipgloss.JoinVertical(lipgloss.Left, toolHeader, paramBlock)
203		return toolCallStyle.Render(toolContent)
204	}
205
206	findToolResult := func(toolCallID string, messages []message.Message) *message.ToolResult {
207		for _, msg := range messages {
208			if msg.Role == message.Tool {
209				for _, result := range msg.ToolResults() {
210					if result.ToolCallID == toolCallID {
211						return &result
212					}
213				}
214			}
215		}
216		return nil
217	}
218
219	renderToolResult := func(result message.ToolResult) string {
220		resultHeader := lipgloss.NewStyle().
221			Bold(true).
222			Foreground(styles.Green).
223			Render(fmt.Sprintf("%s Result", styles.CheckIcon))
224		if result.IsError {
225			resultHeader = lipgloss.NewStyle().
226				Bold(true).
227				Foreground(styles.Red).
228				Render(fmt.Sprintf("%s Error", styles.ErrorIcon))
229		}
230
231		truncate := 200
232		content := result.Content
233		if len(content) > truncate {
234			content = content[:truncate] + lipgloss.NewStyle().Foreground(styles.Blue).Render("... (truncated)")
235		}
236
237		resultContent := lipgloss.JoinVertical(lipgloss.Left, resultHeader, content)
238		return toolResultStyle.Render(resultContent)
239	}
240
241	connector := connectorStyle.Render("└─> Tool Calls:")
242	allParts = append(allParts, connector)
243
244	for _, toolCall := range tools {
245		toolOutput := renderTool(toolCall)
246		allParts = append(allParts, leftPadding.Render(toolOutput))
247
248		result := findToolResult(toolCall.ID, futureMessages)
249		if result != nil {
250
251			resultOutput := renderToolResult(*result)
252			allParts = append(allParts, leftPadding.Render(resultOutput))
253
254		} else if toolCall.Name == agent.AgentToolName {
255
256			runningIndicator := runningStyle.Render(fmt.Sprintf("%s Running...", styles.SpinnerIcon))
257			allParts = append(allParts, leftPadding.Render(runningIndicator))
258			taskSessionMessages, _ := m.app.Messages.List(toolCall.ID)
259			for _, msg := range taskSessionMessages {
260				if msg.Role == message.Assistant {
261					for _, toolCall := range msg.ToolCalls() {
262						toolHeader := lipgloss.NewStyle().
263							Bold(true).
264							Foreground(styles.Blue).
265							Render(fmt.Sprintf("%s %s", styles.ToolIcon, toolCall.Name))
266
267						var paramLines []string
268						var args map[string]interface{}
269						var paramOrder []string
270
271						json.Unmarshal([]byte(toolCall.Input), &args)
272
273						for key := range args {
274							paramOrder = append(paramOrder, key)
275						}
276						sort.Strings(paramOrder)
277
278						for _, name := range paramOrder {
279							value := args[name]
280							paramName := lipgloss.NewStyle().
281								Foreground(styles.Peach).
282								Bold(true).
283								Render(name)
284
285							truncate := 50
286							if len(fmt.Sprintf("%v", value)) > truncate {
287								value = fmt.Sprintf("%v", value)[:truncate] + lipgloss.NewStyle().Foreground(styles.Blue).Render("... (truncated)")
288							}
289							paramValue := fmt.Sprintf("%v", value)
290							paramLines = append(paramLines, fmt.Sprintf("  %s: %s", paramName, paramValue))
291						}
292
293						paramBlock := lipgloss.JoinVertical(lipgloss.Left, paramLines...)
294						toolContent := lipgloss.JoinVertical(lipgloss.Left, toolHeader, paramBlock)
295						toolOutput := toolCallStyle.BorderForeground(styles.Teal).MaxWidth(m.width - leftPaddingValue*2 - 2).Render(toolContent)
296						allParts = append(allParts, lipgloss.NewStyle().Padding(0, 0, 0, leftPaddingValue*2).Render(toolOutput))
297					}
298				}
299			}
300
301		} else {
302			runningIndicator := runningStyle.Render(fmt.Sprintf("%s Running...", styles.SpinnerIcon))
303			allParts = append(allParts, "    "+runningIndicator)
304		}
305	}
306
307	for _, msg := range futureMessages {
308		if msg.Content().String() != "" {
309			break
310		}
311
312		for _, toolCall := range msg.ToolCalls() {
313			toolOutput := renderTool(toolCall)
314			allParts = append(allParts, "    "+strings.ReplaceAll(toolOutput, "\n", "\n    "))
315
316			result := findToolResult(toolCall.ID, futureMessages)
317			if result != nil {
318				resultOutput := renderToolResult(*result)
319				allParts = append(allParts, "    "+strings.ReplaceAll(resultOutput, "\n", "\n    "))
320			} else {
321				runningIndicator := runningStyle.Render(fmt.Sprintf("%s Running...", styles.SpinnerIcon))
322				allParts = append(allParts, "    "+runningIndicator)
323			}
324		}
325	}
326
327	return lipgloss.JoinVertical(lipgloss.Left, allParts...)
328}
329
330func (m *messagesCmp) renderView() {
331	stringMessages := make([]string, 0)
332	r, _ := glamour.NewTermRenderer(
333		glamour.WithStyles(styles.CatppuccinMarkdownStyle()),
334		glamour.WithWordWrap(m.width-20),
335		glamour.WithEmoji(),
336	)
337	textStyle := lipgloss.NewStyle().Width(m.width - 4)
338	currentMessage := 1
339	displayedMsgCount := 0 // Track the actual displayed messages count
340
341	prevMessageWasUser := false
342	for inx, msg := range m.messages {
343		content := msg.Content().String()
344		if content != "" || prevMessageWasUser {
345			if msg.ReasoningContent().String() != "" && content == "" {
346				content = msg.ReasoningContent().String()
347			} else if content == "" {
348				content = "..."
349			}
350			content, _ = r.Render(content)
351
352			isSelected := inx == m.selectedMsgIdx
353
354			border := lipgloss.DoubleBorder()
355			activeColor := borderColor(msg.Role)
356
357			if isSelected {
358				activeColor = styles.Primary // Use primary color for selected message
359			}
360
361			content = layout.Borderize(
362				textStyle.Render(content),
363				layout.BorderOptions{
364					InactiveBorder: border,
365					ActiveBorder:   border,
366					ActiveColor:    activeColor,
367					InactiveColor:  borderColor(msg.Role),
368					EmbeddedText:   borderText(msg.Role, currentMessage),
369				},
370			)
371			if len(msg.ToolCalls()) > 0 {
372				content = m.renderMessageWithToolCall(content, msg.ToolCalls(), m.messages[inx+1:])
373			}
374			stringMessages = append(stringMessages, content)
375			currentMessage++
376			displayedMsgCount++
377		}
378		if msg.Role == message.User && msg.Content().String() != "" {
379			prevMessageWasUser = true
380		} else {
381			prevMessageWasUser = false
382		}
383	}
384	m.viewport.SetContent(lipgloss.JoinVertical(lipgloss.Top, stringMessages...))
385}
386
387func (m *messagesCmp) View() string {
388	return lipgloss.NewStyle().Padding(1).Render(m.viewport.View())
389}
390
391func (m *messagesCmp) BindingKeys() []key.Binding {
392	keys := layout.KeyMapToSlice(m.viewport.KeyMap)
393
394	return keys
395}
396
397func (m *messagesCmp) Blur() tea.Cmd {
398	m.focused = false
399	return nil
400}
401
402func (m *messagesCmp) projectDiagnostics() string {
403	errorDiagnostics := []protocol.Diagnostic{}
404	warnDiagnostics := []protocol.Diagnostic{}
405	hintDiagnostics := []protocol.Diagnostic{}
406	infoDiagnostics := []protocol.Diagnostic{}
407	for _, client := range m.app.LSPClients {
408		for _, d := range client.GetDiagnostics() {
409			for _, diag := range d {
410				switch diag.Severity {
411				case protocol.SeverityError:
412					errorDiagnostics = append(errorDiagnostics, diag)
413				case protocol.SeverityWarning:
414					warnDiagnostics = append(warnDiagnostics, diag)
415				case protocol.SeverityHint:
416					hintDiagnostics = append(hintDiagnostics, diag)
417				case protocol.SeverityInformation:
418					infoDiagnostics = append(infoDiagnostics, diag)
419				}
420			}
421		}
422	}
423
424	if len(errorDiagnostics) == 0 && len(warnDiagnostics) == 0 && len(hintDiagnostics) == 0 && len(infoDiagnostics) == 0 {
425		return "No diagnostics"
426	}
427
428	diagnostics := []string{}
429
430	if len(errorDiagnostics) > 0 {
431		errStr := lipgloss.NewStyle().Foreground(styles.Error).Render(fmt.Sprintf("%s %d", styles.ErrorIcon, len(errorDiagnostics)))
432		diagnostics = append(diagnostics, errStr)
433	}
434	if len(warnDiagnostics) > 0 {
435		warnStr := lipgloss.NewStyle().Foreground(styles.Warning).Render(fmt.Sprintf("%s %d", styles.WarningIcon, len(warnDiagnostics)))
436		diagnostics = append(diagnostics, warnStr)
437	}
438	if len(hintDiagnostics) > 0 {
439		hintStr := lipgloss.NewStyle().Foreground(styles.Text).Render(fmt.Sprintf("%s %d", styles.HintIcon, len(hintDiagnostics)))
440		diagnostics = append(diagnostics, hintStr)
441	}
442	if len(infoDiagnostics) > 0 {
443		infoStr := lipgloss.NewStyle().Foreground(styles.Peach).Render(fmt.Sprintf("%s %d", styles.InfoIcon, len(infoDiagnostics)))
444		diagnostics = append(diagnostics, infoStr)
445	}
446
447	return strings.Join(diagnostics, " ")
448}
449
450func (m *messagesCmp) BorderText() map[layout.BorderPosition]string {
451	title := m.session.Title
452	titleWidth := m.width / 2
453	if len(title) > titleWidth {
454		title = title[:titleWidth] + "..."
455	}
456	if m.focused {
457		title = lipgloss.NewStyle().Foreground(styles.Primary).Render(title)
458	}
459	borderTest := map[layout.BorderPosition]string{
460		layout.TopLeftBorder:     title,
461		layout.BottomRightBorder: m.projectDiagnostics(),
462	}
463	if hasUnfinishedMessages(m.messages) {
464		borderTest[layout.BottomLeftBorder] = lipgloss.NewStyle().Foreground(styles.Peach).Render("Thinking...")
465	} else {
466		borderTest[layout.BottomLeftBorder] = lipgloss.NewStyle().Foreground(styles.Text).Render("Sleeping " + styles.SleepIcon + " ")
467	}
468
469	return borderTest
470}
471
472func (m *messagesCmp) Focus() tea.Cmd {
473	m.focused = true
474	return nil
475}
476
477func (m *messagesCmp) GetSize() (int, int) {
478	return m.width, m.height
479}
480
481func (m *messagesCmp) IsFocused() bool {
482	return m.focused
483}
484
485func (m *messagesCmp) SetSize(width int, height int) {
486	m.width = width
487	m.height = height
488	m.viewport.Width = width - 2   // padding
489	m.viewport.Height = height - 2 // padding
490	m.renderView()
491}
492
493func (m *messagesCmp) Init() tea.Cmd {
494	return nil
495}
496
497func NewMessagesCmp(app *app.App) MessagesCmp {
498	return &messagesCmp{
499		app:      app,
500		messages: []message.Message{},
501		viewport: viewport.New(0, 0),
502	}
503}