messages.go

  1package repl
  2
  3import (
  4	"encoding/json"
  5	"fmt"
  6	"sort"
  7	"strings"
  8
  9	"github.com/charmbracelet/bubbles/key"
 10	"github.com/charmbracelet/bubbles/viewport"
 11	tea "github.com/charmbracelet/bubbletea"
 12	"github.com/charmbracelet/glamour"
 13	"github.com/charmbracelet/lipgloss"
 14	"github.com/kujtimiihoxha/termai/internal/app"
 15	"github.com/kujtimiihoxha/termai/internal/llm/agent"
 16	"github.com/kujtimiihoxha/termai/internal/message"
 17	"github.com/kujtimiihoxha/termai/internal/pubsub"
 18	"github.com/kujtimiihoxha/termai/internal/session"
 19	"github.com/kujtimiihoxha/termai/internal/tui/layout"
 20	"github.com/kujtimiihoxha/termai/internal/tui/styles"
 21)
 22
 23type MessagesCmp interface {
 24	tea.Model
 25	layout.Focusable
 26	layout.Bordered
 27	layout.Sizeable
 28	layout.Bindings
 29}
 30
 31type messagesCmp struct {
 32	app            *app.App
 33	messages       []message.Message
 34	selectedMsgIdx int // Index of the selected message
 35	session        session.Session
 36	viewport       viewport.Model
 37	mdRenderer     *glamour.TermRenderer
 38	width          int
 39	height         int
 40	focused        bool
 41	cachedView     string
 42}
 43
 44func (m *messagesCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
 45	switch msg := msg.(type) {
 46	case pubsub.Event[message.Message]:
 47		if msg.Type == pubsub.CreatedEvent {
 48			if msg.Payload.SessionID == m.session.ID {
 49				m.messages = append(m.messages, msg.Payload)
 50				m.renderView()
 51				m.viewport.GotoBottom()
 52			}
 53			for _, v := range m.messages {
 54				for _, c := range v.ToolCalls {
 55					if c.ID == msg.Payload.SessionID {
 56						m.renderView()
 57						m.viewport.GotoBottom()
 58					}
 59				}
 60			}
 61		} else if msg.Type == pubsub.UpdatedEvent && msg.Payload.SessionID == m.session.ID {
 62			for i, v := range m.messages {
 63				if v.ID == msg.Payload.ID {
 64					m.messages[i] = msg.Payload
 65					m.renderView()
 66					if i == len(m.messages)-1 {
 67						m.viewport.GotoBottom()
 68					}
 69					break
 70				}
 71			}
 72		}
 73	case pubsub.Event[session.Session]:
 74		if msg.Type == pubsub.UpdatedEvent && m.session.ID == msg.Payload.ID {
 75			m.session = msg.Payload
 76		}
 77	case SelectedSessionMsg:
 78		m.session, _ = m.app.Sessions.Get(msg.SessionID)
 79		m.messages, _ = m.app.Messages.List(m.session.ID)
 80		m.renderView()
 81		m.viewport.GotoBottom()
 82	}
 83	if m.focused {
 84		u, cmd := m.viewport.Update(msg)
 85		m.viewport = u
 86		return m, cmd
 87	}
 88	return m, nil
 89}
 90
 91func borderColor(role message.MessageRole) lipgloss.TerminalColor {
 92	switch role {
 93	case message.Assistant:
 94		return styles.Mauve
 95	case message.User:
 96		return styles.Rosewater
 97	}
 98	return styles.Blue
 99}
100
101func borderText(msgRole message.MessageRole, currentMessage int) map[layout.BorderPosition]string {
102	role := ""
103	icon := ""
104	switch msgRole {
105	case message.Assistant:
106		role = "Assistant"
107		icon = styles.BotIcon
108	case message.User:
109		role = "User"
110		icon = styles.UserIcon
111	}
112	return map[layout.BorderPosition]string{
113		layout.TopLeftBorder: lipgloss.NewStyle().
114			Padding(0, 1).
115			Bold(true).
116			Foreground(styles.Crust).
117			Background(borderColor(msgRole)).
118			Render(fmt.Sprintf("%s %s ", role, icon)),
119		layout.TopRightBorder: lipgloss.NewStyle().
120			Padding(0, 1).
121			Bold(true).
122			Foreground(styles.Crust).
123			Background(borderColor(msgRole)).
124			Render(fmt.Sprintf("#%d ", currentMessage)),
125	}
126}
127
128func hasUnfinishedMessages(messages []message.Message) bool {
129	if len(messages) == 0 {
130		return false
131	}
132	for _, msg := range messages {
133		if !msg.Finished {
134			return true
135		}
136	}
137	lastMessage := messages[len(messages)-1]
138	return lastMessage.Role != message.Assistant
139}
140
// renderMessageWithToolCall appends the tool activity of a message below its
// already-rendered content and returns the combined block.
//
// Each call in tools is drawn as a bordered box showing the tool name and its
// JSON arguments (sorted by name, long values truncated). It is followed by the
// matching result from futureMessages when one exists. A call without a result
// shows a "Running..." indicator. For the agent tool, it also lists the tool
// calls issued so far by the sub-agent, whose messages are stored under a
// session ID equal to the tool call ID.
//
// Finally, the leading futureMessages that have no text content are rendered
// here as continuations of this turn: their tool calls and results are shown.
// This stops at the first message that has content.
func (m *messagesCmp) renderMessageWithToolCall(content string, tools []message.ToolCall, futureMessages []message.Message) string {
	allParts := []string{content}

	leftPaddingValue := 4
	// Longest argument value shown for a top-level tool call; may be negative
	// on very narrow terminals, which truncateText clamps.
	argValueLimit := m.width - leftPaddingValue*2 - 10
	// Longest argument value shown for a sub-agent's tool call.
	const nestedArgValueLimit = 50
	const resultLimit = 200

	connectorStyle := lipgloss.NewStyle().
		Foreground(styles.Peach).
		Bold(true)

	toolCallStyle := lipgloss.NewStyle().
		Border(lipgloss.RoundedBorder()).
		BorderForeground(styles.Peach).
		Width(m.width-leftPaddingValue-5).
		Padding(0, 1)

	// Sub-agent tool calls: teal border, nested one level deeper. Built from
	// scratch rather than derived from toolCallStyle so the two styles can
	// never share state.
	nestedToolCallStyle := lipgloss.NewStyle().
		Border(lipgloss.RoundedBorder()).
		BorderForeground(styles.Teal).
		Width(m.width-leftPaddingValue-5).
		MaxWidth(m.width-leftPaddingValue*2-2).
		Padding(0, 1)

	toolResultStyle := lipgloss.NewStyle().
		Border(lipgloss.RoundedBorder()).
		BorderForeground(styles.Green).
		Width(m.width-leftPaddingValue-5).
		Padding(0, 1)

	leftPadding := lipgloss.NewStyle().Padding(0, 0, 0, leftPaddingValue)
	nestedPadding := lipgloss.NewStyle().Padding(0, 0, 0, leftPaddingValue*2)

	runningStyle := lipgloss.NewStyle().
		Foreground(styles.Peach).
		Bold(true)
	runningIndicator := runningStyle.Render(fmt.Sprintf("%s Running...", styles.SpinnerIcon))
	truncatedMarker := lipgloss.NewStyle().Foreground(styles.Blue).Render("... (truncated)")

	// truncateText shortens s to at most limit runes and appends a marker when
	// it cuts. Cutting on rune boundaries keeps multi-byte characters intact.
	// A negative limit is clamped to 0 instead of becoming a negative slice
	// bound, which would panic.
	truncateText := func(s string, limit int) string {
		if limit < 0 {
			limit = 0
		}
		if len(s) <= limit { // byte length is an upper bound on rune count
			return s
		}
		runes := []rune(s)
		if len(runes) <= limit {
			return s
		}
		return string(runes[:limit]) + truncatedMarker
	}

	// indent shifts every line of a multi-line block right by four columns.
	indent := func(s string) string {
		return "    " + strings.ReplaceAll(s, "\n", "\n    ")
	}

	// renderToolCall draws a tool call's header and arguments inside style,
	// truncating each argument value to valueLimit runes.
	renderToolCall := func(toolCall message.ToolCall, style lipgloss.Style, valueLimit int) string {
		toolHeader := lipgloss.NewStyle().
			Bold(true).
			Foreground(styles.Blue).
			Render(fmt.Sprintf("%s %s", styles.ToolIcon, toolCall.Name))

		// Best effort: input that does not decode as a JSON object leaves args
		// nil and the call is shown without parameters.
		var args map[string]interface{}
		_ = json.Unmarshal([]byte(toolCall.Input), &args)

		names := make([]string, 0, len(args))
		for name := range args {
			names = append(names, name)
		}
		sort.Strings(names)

		paramLines := make([]string, 0, len(names))
		for _, name := range names {
			paramName := lipgloss.NewStyle().
				Foreground(styles.Peach).
				Bold(true).
				Render(name)
			paramValue := truncateText(fmt.Sprintf("%v", args[name]), valueLimit)
			paramLines = append(paramLines, fmt.Sprintf("  %s: %s", paramName, paramValue))
		}

		paramBlock := lipgloss.JoinVertical(lipgloss.Left, paramLines...)
		return style.Render(lipgloss.JoinVertical(lipgloss.Left, toolHeader, paramBlock))
	}

	// findToolResult returns the result for toolCallID from the first Tool
	// message in messages that carries it, or nil when there is none yet.
	findToolResult := func(toolCallID string, messages []message.Message) *message.ToolResult {
		for _, msg := range messages {
			if msg.Role == message.Tool {
				for _, result := range msg.ToolResults {
					if result.ToolCallID == toolCallID {
						return &result
					}
				}
			}
		}
		return nil
	}

	// renderToolResult draws a success or error box with the (truncated)
	// result content.
	renderToolResult := func(result message.ToolResult) string {
		resultHeader := lipgloss.NewStyle().
			Bold(true).
			Foreground(styles.Green).
			Render(fmt.Sprintf("%s Result", styles.CheckIcon))
		if result.IsError {
			resultHeader = lipgloss.NewStyle().
				Bold(true).
				Foreground(styles.Red).
				Render(fmt.Sprintf("%s Error", styles.ErrorIcon))
		}

		body := truncateText(result.Content, resultLimit)
		return toolResultStyle.Render(lipgloss.JoinVertical(lipgloss.Left, resultHeader, body))
	}

	allParts = append(allParts, connectorStyle.Render("└─> Tool Calls:"))

	for _, toolCall := range tools {
		allParts = append(allParts, leftPadding.Render(renderToolCall(toolCall, toolCallStyle, argValueLimit)))

		if result := findToolResult(toolCall.ID, futureMessages); result != nil {
			allParts = append(allParts, leftPadding.Render(renderToolResult(*result)))
			continue
		}

		if toolCall.Name != agent.AgentToolName {
			allParts = append(allParts, "    "+runningIndicator)
			continue
		}

		// A sub-agent is still running: list the tool calls it has made so far.
		allParts = append(allParts, leftPadding.Render(runningIndicator))
		// NOTE(review): this calls m.app.Messages.List on every render; a
		// failed lookup is treated as "no sub-agent activity yet".
		taskSessionMessages, _ := m.app.Messages.List(toolCall.ID)
		for _, taskMsg := range taskSessionMessages {
			if taskMsg.Role != message.Assistant {
				continue
			}
			for _, nestedCall := range taskMsg.ToolCalls {
				nestedOutput := renderToolCall(nestedCall, nestedToolCallStyle, nestedArgValueLimit)
				allParts = append(allParts, nestedPadding.Render(nestedOutput))
			}
		}
	}

	// Content-less follow-up messages belong to this turn; show their tool
	// calls here until the next message that has text of its own.
	for _, next := range futureMessages {
		if next.Content != "" {
			break
		}
		for _, toolCall := range next.ToolCalls {
			allParts = append(allParts, indent(renderToolCall(toolCall, toolCallStyle, argValueLimit)))

			if result := findToolResult(toolCall.ID, futureMessages); result != nil {
				allParts = append(allParts, indent(renderToolResult(*result)))
			} else {
				allParts = append(allParts, "    "+runningIndicator)
			}
		}
	}

	return lipgloss.JoinVertical(lipgloss.Left, allParts...)
}
328
329func (m *messagesCmp) renderView() {
330	stringMessages := make([]string, 0)
331	r, _ := glamour.NewTermRenderer(
332		glamour.WithStyles(styles.CatppuccinMarkdownStyle()),
333		glamour.WithWordWrap(m.width-20),
334		glamour.WithEmoji(),
335	)
336	textStyle := lipgloss.NewStyle().Width(m.width - 4)
337	currentMessage := 1
338	displayedMsgCount := 0 // Track the actual displayed messages count
339
340	prevMessageWasUser := false
341	for inx, msg := range m.messages {
342		content := msg.Content
343		if content != "" || prevMessageWasUser {
344			if msg.Thinking != "" && content == "" {
345				content = msg.Thinking
346			} else if content == "" {
347				content = "..."
348			}
349			content, _ = r.Render(content)
350
351			isSelected := inx == m.selectedMsgIdx
352
353			border := lipgloss.DoubleBorder()
354			activeColor := borderColor(msg.Role)
355
356			if isSelected {
357				activeColor = styles.Primary // Use primary color for selected message
358			}
359
360			content = layout.Borderize(
361				textStyle.Render(content),
362				layout.BorderOptions{
363					InactiveBorder: border,
364					ActiveBorder:   border,
365					ActiveColor:    activeColor,
366					InactiveColor:  borderColor(msg.Role),
367					EmbeddedText:   borderText(msg.Role, currentMessage),
368				},
369			)
370			if len(msg.ToolCalls) > 0 {
371				content = m.renderMessageWithToolCall(content, msg.ToolCalls, m.messages[inx+1:])
372			}
373			stringMessages = append(stringMessages, content)
374			currentMessage++
375			displayedMsgCount++
376		}
377		if msg.Role == message.User && msg.Content != "" {
378			prevMessageWasUser = true
379		} else {
380			prevMessageWasUser = false
381		}
382	}
383	m.viewport.SetContent(lipgloss.JoinVertical(lipgloss.Top, stringMessages...))
384}
385
386func (m *messagesCmp) View() string {
387	return lipgloss.NewStyle().Padding(1).Render(m.viewport.View())
388}
389
390func (m *messagesCmp) BindingKeys() []key.Binding {
391	keys := layout.KeyMapToSlice(m.viewport.KeyMap)
392
393	return keys
394}
395
396func (m *messagesCmp) Blur() tea.Cmd {
397	m.focused = false
398	return nil
399}
400
401func (m *messagesCmp) BorderText() map[layout.BorderPosition]string {
402	title := m.session.Title
403	titleWidth := m.width / 2
404	if len(title) > titleWidth {
405		title = title[:titleWidth] + "..."
406	}
407	if m.focused {
408		title = lipgloss.NewStyle().Foreground(styles.Primary).Render(title)
409	}
410	borderTest := map[layout.BorderPosition]string{
411		layout.TopLeftBorder:     title,
412		layout.BottomRightBorder: formatTokensAndCost(m.session.CompletionTokens+m.session.PromptTokens, m.session.Cost),
413	}
414	if hasUnfinishedMessages(m.messages) {
415		borderTest[layout.BottomLeftBorder] = lipgloss.NewStyle().Foreground(styles.Peach).Render("Thinking...")
416	} else {
417		borderTest[layout.BottomLeftBorder] = lipgloss.NewStyle().Foreground(styles.Text).Render("Sleeping " + styles.SleepIcon + " ")
418	}
419
420	return borderTest
421}
422
423func (m *messagesCmp) Focus() tea.Cmd {
424	m.focused = true
425	return nil
426}
427
428func (m *messagesCmp) GetSize() (int, int) {
429	return m.width, m.height
430}
431
432func (m *messagesCmp) IsFocused() bool {
433	return m.focused
434}
435
436func (m *messagesCmp) SetSize(width int, height int) {
437	m.width = width
438	m.height = height
439	m.viewport.Width = width - 2   // padding
440	m.viewport.Height = height - 2 // padding
441	m.renderView()
442}
443
444func (m *messagesCmp) Init() tea.Cmd {
445	return nil
446}
447
448func NewMessagesCmp(app *app.App) MessagesCmp {
449	return &messagesCmp{
450		app:      app,
451		messages: []message.Message{},
452		viewport: viewport.New(0, 0),
453	}
454}