messages.go

  1package repl
  2
  3import (
  4	"encoding/json"
  5	"fmt"
  6	"sort"
  7	"strings"
  8
  9	"github.com/charmbracelet/bubbles/key"
 10	"github.com/charmbracelet/bubbles/viewport"
 11	tea "github.com/charmbracelet/bubbletea"
 12	"github.com/charmbracelet/glamour"
 13	"github.com/charmbracelet/lipgloss"
 14	"github.com/kujtimiihoxha/termai/internal/app"
 15	"github.com/kujtimiihoxha/termai/internal/llm/agent"
 16	"github.com/kujtimiihoxha/termai/internal/lsp/protocol"
 17	"github.com/kujtimiihoxha/termai/internal/message"
 18	"github.com/kujtimiihoxha/termai/internal/pubsub"
 19	"github.com/kujtimiihoxha/termai/internal/session"
 20	"github.com/kujtimiihoxha/termai/internal/tui/layout"
 21	"github.com/kujtimiihoxha/termai/internal/tui/styles"
 22)
 23
 24type MessagesCmp interface {
 25	tea.Model
 26	layout.Focusable
 27	layout.Bordered
 28	layout.Sizeable
 29	layout.Bindings
 30}
 31
 32type messagesCmp struct {
 33	app            *app.App
 34	messages       []message.Message
 35	selectedMsgIdx int // Index of the selected message
 36	session        session.Session
 37	viewport       viewport.Model
 38	mdRenderer     *glamour.TermRenderer
 39	width          int
 40	height         int
 41	focused        bool
 42	cachedView     string
 43}
 44
 45func (m *messagesCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
 46	switch msg := msg.(type) {
 47	case pubsub.Event[message.Message]:
 48		if msg.Type == pubsub.CreatedEvent {
 49			if msg.Payload.SessionID == m.session.ID {
 50				m.messages = append(m.messages, msg.Payload)
 51				m.renderView()
 52				m.viewport.GotoBottom()
 53			}
 54			for _, v := range m.messages {
 55				for _, c := range v.ToolCalls() {
 56					// the message is being added to the session of a tool called
 57					if c.ID == msg.Payload.SessionID {
 58						m.renderView()
 59						m.viewport.GotoBottom()
 60					}
 61				}
 62			}
 63		} else if msg.Type == pubsub.UpdatedEvent && msg.Payload.SessionID == m.session.ID {
 64			for i, v := range m.messages {
 65				if v.ID == msg.Payload.ID {
 66					m.messages[i] = msg.Payload
 67					m.renderView()
 68					if i == len(m.messages)-1 {
 69						m.viewport.GotoBottom()
 70					}
 71					break
 72				}
 73			}
 74		}
 75	case pubsub.Event[session.Session]:
 76		if msg.Type == pubsub.UpdatedEvent && m.session.ID == msg.Payload.ID {
 77			m.session = msg.Payload
 78		}
 79	case SelectedSessionMsg:
 80		m.session, _ = m.app.Sessions.Get(msg.SessionID)
 81		m.messages, _ = m.app.Messages.List(m.session.ID)
 82		m.renderView()
 83		m.viewport.GotoBottom()
 84	}
 85	if m.focused {
 86		u, cmd := m.viewport.Update(msg)
 87		m.viewport = u
 88		return m, cmd
 89	}
 90	return m, nil
 91}
 92
 93func borderColor(role message.MessageRole) lipgloss.TerminalColor {
 94	switch role {
 95	case message.Assistant:
 96		return styles.Mauve
 97	case message.User:
 98		return styles.Rosewater
 99	}
100	return styles.Blue
101}
102
103func borderText(msgRole message.MessageRole, currentMessage int) map[layout.BorderPosition]string {
104	role := ""
105	icon := ""
106	switch msgRole {
107	case message.Assistant:
108		role = "Assistant"
109		icon = styles.BotIcon
110	case message.User:
111		role = "User"
112		icon = styles.UserIcon
113	}
114	return map[layout.BorderPosition]string{
115		layout.TopLeftBorder: lipgloss.NewStyle().
116			Padding(0, 1).
117			Bold(true).
118			Foreground(styles.Crust).
119			Background(borderColor(msgRole)).
120			Render(fmt.Sprintf("%s %s ", role, icon)),
121		layout.TopRightBorder: lipgloss.NewStyle().
122			Padding(0, 1).
123			Bold(true).
124			Foreground(styles.Crust).
125			Background(borderColor(msgRole)).
126			Render(fmt.Sprintf("#%d ", currentMessage)),
127	}
128}
129
130func hasUnfinishedMessages(messages []message.Message) bool {
131	if len(messages) == 0 {
132		return false
133	}
134	for _, msg := range messages {
135		if !msg.IsFinished() {
136			return true
137		}
138	}
139	return false
140}
141
142func (m *messagesCmp) renderMessageWithToolCall(content string, tools []message.ToolCall, futureMessages []message.Message) string {
143	allParts := []string{content}
144
145	leftPaddingValue := 4
146	connectorStyle := lipgloss.NewStyle().
147		Foreground(styles.Peach).
148		Bold(true)
149
150	toolCallStyle := lipgloss.NewStyle().
151		Border(lipgloss.RoundedBorder()).
152		BorderForeground(styles.Peach).
153		Width(m.width-leftPaddingValue-5).
154		Padding(0, 1)
155
156	toolResultStyle := lipgloss.NewStyle().
157		Border(lipgloss.RoundedBorder()).
158		BorderForeground(styles.Green).
159		Width(m.width-leftPaddingValue-5).
160		Padding(0, 1)
161
162	leftPadding := lipgloss.NewStyle().Padding(0, 0, 0, leftPaddingValue)
163
164	runningStyle := lipgloss.NewStyle().
165		Foreground(styles.Peach).
166		Bold(true)
167
168	renderTool := func(toolCall message.ToolCall) string {
169		toolHeader := lipgloss.NewStyle().
170			Bold(true).
171			Foreground(styles.Blue).
172			Render(fmt.Sprintf("%s %s", styles.ToolIcon, toolCall.Name))
173
174		var paramLines []string
175		var args map[string]interface{}
176		var paramOrder []string
177
178		json.Unmarshal([]byte(toolCall.Input), &args)
179
180		for key := range args {
181			paramOrder = append(paramOrder, key)
182		}
183		sort.Strings(paramOrder)
184
185		for _, name := range paramOrder {
186			value := args[name]
187			paramName := lipgloss.NewStyle().
188				Foreground(styles.Peach).
189				Bold(true).
190				Render(name)
191
192			truncate := m.width - leftPaddingValue*2 - 10
193			if len(fmt.Sprintf("%v", value)) > truncate {
194				value = fmt.Sprintf("%v", value)[:truncate] + lipgloss.NewStyle().Foreground(styles.Blue).Render("... (truncated)")
195			}
196			paramValue := fmt.Sprintf("%v", value)
197			paramLines = append(paramLines, fmt.Sprintf("  %s: %s", paramName, paramValue))
198		}
199
200		paramBlock := lipgloss.JoinVertical(lipgloss.Left, paramLines...)
201
202		toolContent := lipgloss.JoinVertical(lipgloss.Left, toolHeader, paramBlock)
203		return toolCallStyle.Render(toolContent)
204	}
205
206	findToolResult := func(toolCallID string, messages []message.Message) *message.ToolResult {
207		for _, msg := range messages {
208			if msg.Role == message.Tool {
209				for _, result := range msg.ToolResults() {
210					if result.ToolCallID == toolCallID {
211						return &result
212					}
213				}
214			}
215		}
216		return nil
217	}
218
219	renderToolResult := func(result message.ToolResult) string {
220		resultHeader := lipgloss.NewStyle().
221			Bold(true).
222			Foreground(styles.Green).
223			Render(fmt.Sprintf("%s Result", styles.CheckIcon))
224
225		// Use the same style for both header and border if it's an error
226		borderColor := styles.Green
227		if result.IsError {
228			resultHeader = lipgloss.NewStyle().
229				Bold(true).
230				Foreground(styles.Red).
231				Render(fmt.Sprintf("%s Error", styles.ErrorIcon))
232			borderColor = styles.Red
233		}
234
235		truncate := 200
236		content := result.Content
237		if len(content) > truncate {
238			content = content[:truncate] + lipgloss.NewStyle().Foreground(styles.Blue).Render("... (truncated)")
239		}
240
241		resultContent := lipgloss.JoinVertical(lipgloss.Left, resultHeader, content)
242		return toolResultStyle.BorderForeground(borderColor).Render(resultContent)
243	}
244
245	connector := connectorStyle.Render("└─> Tool Calls:")
246	allParts = append(allParts, connector)
247
248	for _, toolCall := range tools {
249		toolOutput := renderTool(toolCall)
250		allParts = append(allParts, leftPadding.Render(toolOutput))
251
252		result := findToolResult(toolCall.ID, futureMessages)
253		if result != nil {
254
255			resultOutput := renderToolResult(*result)
256			allParts = append(allParts, leftPadding.Render(resultOutput))
257
258		} else if toolCall.Name == agent.AgentToolName {
259
260			runningIndicator := runningStyle.Render(fmt.Sprintf("%s Running...", styles.SpinnerIcon))
261			allParts = append(allParts, leftPadding.Render(runningIndicator))
262			taskSessionMessages, _ := m.app.Messages.List(toolCall.ID)
263			for _, msg := range taskSessionMessages {
264				if msg.Role == message.Assistant {
265					for _, toolCall := range msg.ToolCalls() {
266						toolHeader := lipgloss.NewStyle().
267							Bold(true).
268							Foreground(styles.Blue).
269							Render(fmt.Sprintf("%s %s", styles.ToolIcon, toolCall.Name))
270
271						var paramLines []string
272						var args map[string]interface{}
273						var paramOrder []string
274
275						json.Unmarshal([]byte(toolCall.Input), &args)
276
277						for key := range args {
278							paramOrder = append(paramOrder, key)
279						}
280						sort.Strings(paramOrder)
281
282						for _, name := range paramOrder {
283							value := args[name]
284							paramName := lipgloss.NewStyle().
285								Foreground(styles.Peach).
286								Bold(true).
287								Render(name)
288
289							truncate := 50
290							if len(fmt.Sprintf("%v", value)) > truncate {
291								value = fmt.Sprintf("%v", value)[:truncate] + lipgloss.NewStyle().Foreground(styles.Blue).Render("... (truncated)")
292							}
293							paramValue := fmt.Sprintf("%v", value)
294							paramLines = append(paramLines, fmt.Sprintf("  %s: %s", paramName, paramValue))
295						}
296
297						paramBlock := lipgloss.JoinVertical(lipgloss.Left, paramLines...)
298						toolContent := lipgloss.JoinVertical(lipgloss.Left, toolHeader, paramBlock)
299						toolOutput := toolCallStyle.BorderForeground(styles.Teal).MaxWidth(m.width - leftPaddingValue*2 - 2).Render(toolContent)
300						allParts = append(allParts, lipgloss.NewStyle().Padding(0, 0, 0, leftPaddingValue*2).Render(toolOutput))
301					}
302				}
303			}
304
305		} else {
306			runningIndicator := runningStyle.Render(fmt.Sprintf("%s Running...", styles.SpinnerIcon))
307			allParts = append(allParts, "    "+runningIndicator)
308		}
309	}
310
311	for _, msg := range futureMessages {
312		if msg.Content().String() != "" {
313			break
314		}
315
316		for _, toolCall := range msg.ToolCalls() {
317			toolOutput := renderTool(toolCall)
318			allParts = append(allParts, "    "+strings.ReplaceAll(toolOutput, "\n", "\n    "))
319
320			result := findToolResult(toolCall.ID, futureMessages)
321			if result != nil {
322				resultOutput := renderToolResult(*result)
323				allParts = append(allParts, "    "+strings.ReplaceAll(resultOutput, "\n", "\n    "))
324			} else {
325				runningIndicator := runningStyle.Render(fmt.Sprintf("%s Running...", styles.SpinnerIcon))
326				allParts = append(allParts, "    "+runningIndicator)
327			}
328		}
329	}
330
331	return lipgloss.JoinVertical(lipgloss.Left, allParts...)
332}
333
334func (m *messagesCmp) renderView() {
335	stringMessages := make([]string, 0)
336	r, _ := glamour.NewTermRenderer(
337		glamour.WithStyles(styles.CatppuccinMarkdownStyle()),
338		glamour.WithWordWrap(m.width-20),
339		glamour.WithEmoji(),
340	)
341	textStyle := lipgloss.NewStyle().Width(m.width - 4)
342	currentMessage := 1
343	displayedMsgCount := 0 // Track the actual displayed messages count
344
345	prevMessageWasUser := false
346	for inx, msg := range m.messages {
347		content := msg.Content().String()
348		if content != "" || prevMessageWasUser {
349			if msg.ReasoningContent().String() != "" && content == "" {
350				content = msg.ReasoningContent().String()
351			} else if content == "" {
352				content = "..."
353			}
354			content, _ = r.Render(content)
355
356			isSelected := inx == m.selectedMsgIdx
357
358			border := lipgloss.DoubleBorder()
359			activeColor := borderColor(msg.Role)
360
361			if isSelected {
362				activeColor = styles.Primary // Use primary color for selected message
363			}
364
365			content = layout.Borderize(
366				textStyle.Render(content),
367				layout.BorderOptions{
368					InactiveBorder: border,
369					ActiveBorder:   border,
370					ActiveColor:    activeColor,
371					InactiveColor:  borderColor(msg.Role),
372					EmbeddedText:   borderText(msg.Role, currentMessage),
373				},
374			)
375			if len(msg.ToolCalls()) > 0 {
376				content = m.renderMessageWithToolCall(content, msg.ToolCalls(), m.messages[inx+1:])
377			}
378			stringMessages = append(stringMessages, content)
379			currentMessage++
380			displayedMsgCount++
381		}
382		if msg.Role == message.User && msg.Content().String() != "" {
383			prevMessageWasUser = true
384		} else {
385			prevMessageWasUser = false
386		}
387	}
388	m.viewport.SetContent(lipgloss.JoinVertical(lipgloss.Top, stringMessages...))
389}
390
391func (m *messagesCmp) View() string {
392	return lipgloss.NewStyle().Padding(1).Render(m.viewport.View())
393}
394
395func (m *messagesCmp) BindingKeys() []key.Binding {
396	keys := layout.KeyMapToSlice(m.viewport.KeyMap)
397
398	return keys
399}
400
401func (m *messagesCmp) Blur() tea.Cmd {
402	m.focused = false
403	return nil
404}
405
406func (m *messagesCmp) projectDiagnostics() string {
407	errorDiagnostics := []protocol.Diagnostic{}
408	warnDiagnostics := []protocol.Diagnostic{}
409	hintDiagnostics := []protocol.Diagnostic{}
410	infoDiagnostics := []protocol.Diagnostic{}
411	for _, client := range m.app.LSPClients {
412		for _, d := range client.GetDiagnostics() {
413			for _, diag := range d {
414				switch diag.Severity {
415				case protocol.SeverityError:
416					errorDiagnostics = append(errorDiagnostics, diag)
417				case protocol.SeverityWarning:
418					warnDiagnostics = append(warnDiagnostics, diag)
419				case protocol.SeverityHint:
420					hintDiagnostics = append(hintDiagnostics, diag)
421				case protocol.SeverityInformation:
422					infoDiagnostics = append(infoDiagnostics, diag)
423				}
424			}
425		}
426	}
427
428	if len(errorDiagnostics) == 0 && len(warnDiagnostics) == 0 && len(hintDiagnostics) == 0 && len(infoDiagnostics) == 0 {
429		return "No diagnostics"
430	}
431
432	diagnostics := []string{}
433
434	if len(errorDiagnostics) > 0 {
435		errStr := lipgloss.NewStyle().Foreground(styles.Error).Render(fmt.Sprintf("%s %d", styles.ErrorIcon, len(errorDiagnostics)))
436		diagnostics = append(diagnostics, errStr)
437	}
438	if len(warnDiagnostics) > 0 {
439		warnStr := lipgloss.NewStyle().Foreground(styles.Warning).Render(fmt.Sprintf("%s %d", styles.WarningIcon, len(warnDiagnostics)))
440		diagnostics = append(diagnostics, warnStr)
441	}
442	if len(hintDiagnostics) > 0 {
443		hintStr := lipgloss.NewStyle().Foreground(styles.Text).Render(fmt.Sprintf("%s %d", styles.HintIcon, len(hintDiagnostics)))
444		diagnostics = append(diagnostics, hintStr)
445	}
446	if len(infoDiagnostics) > 0 {
447		infoStr := lipgloss.NewStyle().Foreground(styles.Peach).Render(fmt.Sprintf("%s %d", styles.InfoIcon, len(infoDiagnostics)))
448		diagnostics = append(diagnostics, infoStr)
449	}
450
451	return strings.Join(diagnostics, " ")
452}
453
454func (m *messagesCmp) BorderText() map[layout.BorderPosition]string {
455	title := m.session.Title
456	titleWidth := m.width / 2
457	if len(title) > titleWidth {
458		title = title[:titleWidth] + "..."
459	}
460	if m.focused {
461		title = lipgloss.NewStyle().Foreground(styles.Primary).Render(title)
462	}
463	borderTest := map[layout.BorderPosition]string{
464		layout.TopLeftBorder:     title,
465		layout.BottomRightBorder: m.projectDiagnostics(),
466	}
467	if hasUnfinishedMessages(m.messages) {
468		borderTest[layout.BottomLeftBorder] = lipgloss.NewStyle().Foreground(styles.Peach).Render("Thinking...")
469	} else {
470		borderTest[layout.BottomLeftBorder] = lipgloss.NewStyle().Foreground(styles.Text).Render("Sleeping " + styles.SleepIcon + " ")
471	}
472
473	return borderTest
474}
475
476func (m *messagesCmp) Focus() tea.Cmd {
477	m.focused = true
478	return nil
479}
480
481func (m *messagesCmp) GetSize() (int, int) {
482	return m.width, m.height
483}
484
485func (m *messagesCmp) IsFocused() bool {
486	return m.focused
487}
488
489func (m *messagesCmp) SetSize(width int, height int) {
490	m.width = width
491	m.height = height
492	m.viewport.Width = width - 2   // padding
493	m.viewport.Height = height - 2 // padding
494	m.renderView()
495}
496
497func (m *messagesCmp) Init() tea.Cmd {
498	return nil
499}
500
501func NewMessagesCmp(app *app.App) MessagesCmp {
502	return &messagesCmp{
503		app:      app,
504		messages: []message.Message{},
505		viewport: viewport.New(0, 0),
506	}
507}