messages.go

  1package chat
  2
  3import (
  4	"encoding/json"
  5	"fmt"
  6	"math"
  7	"strings"
  8
  9	"github.com/charmbracelet/bubbles/spinner"
 10	"github.com/charmbracelet/bubbles/viewport"
 11	tea "github.com/charmbracelet/bubbletea"
 12	"github.com/charmbracelet/glamour"
 13	"github.com/charmbracelet/lipgloss"
 14	"github.com/charmbracelet/x/ansi"
 15	"github.com/kujtimiihoxha/termai/internal/app"
 16	"github.com/kujtimiihoxha/termai/internal/llm/agent"
 17	"github.com/kujtimiihoxha/termai/internal/llm/models"
 18	"github.com/kujtimiihoxha/termai/internal/llm/tools"
 19	"github.com/kujtimiihoxha/termai/internal/message"
 20	"github.com/kujtimiihoxha/termai/internal/pubsub"
 21	"github.com/kujtimiihoxha/termai/internal/session"
 22	"github.com/kujtimiihoxha/termai/internal/tui/styles"
 23	"github.com/kujtimiihoxha/termai/internal/tui/util"
 24)
 25
 26type uiMessageType int
 27
 28const (
 29	userMessageType uiMessageType = iota
 30	assistantMessageType
 31	toolMessageType
 32)
 33
// uiMessage is one rendered row of the message list: a user message, the
// text of an assistant message, or a single tool call.
type uiMessage struct {
	// NOTE(review): ID is never set in this file — confirm it is used elsewhere.
	ID          string
	messageType uiMessageType
	// position is the line offset of this row within the viewport content;
	// renderView assigns it.
	position int
	// height is the number of lines in content.
	height int
	// content is the fully styled, rendered text of the row.
	content string
}
 41
// messagesCmp is the chat message list component: it renders the messages of
// the current session into a scrollable viewport with a help line below it.
type messagesCmp struct {
	app           *app.App
	width, height int
	// writingMode is true while the editor has focus; key presses are then
	// not forwarded to the viewport.
	writingMode bool
	viewport    viewport.Model
	session     session.Session
	// messages holds the current session's messages in display order.
	messages []message.Message
	// uiMessages holds the rows built by the last renderView call.
	uiMessages []uiMessage
	// currentMsgID is the newest message, rendered with focusRenderer and a
	// highlighted border.
	currentMsgID  string
	renderer      *glamour.TermRenderer
	focusRenderer *glamour.TermRenderer
	// cachedContent caches renderSimpleMessage output by message ID.
	cachedContent map[string]string
	// agentWorking shows the spinner in the help line while true.
	agentWorking bool
	spinner      spinner.Model
	// needsRerender asks Update to call renderView before returning.
	needsRerender bool
	// NOTE(review): lastViewport is never read or written in this file —
	// confirm it is still needed.
	lastViewport string
}
 59
 60func (m *messagesCmp) Init() tea.Cmd {
 61	return tea.Batch(m.viewport.Init())
 62}
 63
 64func (m *messagesCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
 65	var cmds []tea.Cmd
 66	switch msg := msg.(type) {
 67	case AgentWorkingMsg:
 68		m.agentWorking = bool(msg)
 69		if m.agentWorking {
 70			cmds = append(cmds, m.spinner.Tick)
 71		}
 72	case EditorFocusMsg:
 73		m.writingMode = bool(msg)
 74	case SessionSelectedMsg:
 75		if msg.ID != m.session.ID {
 76			cmd := m.SetSession(msg)
 77			m.needsRerender = true
 78			return m, cmd
 79		}
 80		return m, nil
 81	case SessionClearedMsg:
 82		m.session = session.Session{}
 83		m.messages = make([]message.Message, 0)
 84		m.currentMsgID = ""
 85		m.needsRerender = true
 86		return m, nil
 87
 88	case tea.KeyMsg:
 89		if m.writingMode {
 90			return m, nil
 91		}
 92	case pubsub.Event[message.Message]:
 93		if msg.Type == pubsub.CreatedEvent {
 94			if msg.Payload.SessionID == m.session.ID {
 95				// check if message exists
 96
 97				messageExists := false
 98				for _, v := range m.messages {
 99					if v.ID == msg.Payload.ID {
100						messageExists = true
101						break
102					}
103				}
104
105				if !messageExists {
106					m.messages = append(m.messages, msg.Payload)
107					delete(m.cachedContent, m.currentMsgID)
108					m.currentMsgID = msg.Payload.ID
109					m.needsRerender = true
110				}
111			}
112			for _, v := range m.messages {
113				for _, c := range v.ToolCalls() {
114					// the message is being added to the session of a tool called
115					if c.ID == msg.Payload.SessionID {
116						m.needsRerender = true
117					}
118				}
119			}
120		} else if msg.Type == pubsub.UpdatedEvent && msg.Payload.SessionID == m.session.ID {
121			for i, v := range m.messages {
122				if v.ID == msg.Payload.ID {
123					if !m.messages[i].IsFinished() && msg.Payload.IsFinished() && msg.Payload.FinishReason() == "end_turn" || msg.Payload.FinishReason() == "canceled" {
124						cmds = append(cmds, util.CmdHandler(AgentWorkingMsg(false)))
125					}
126					m.messages[i] = msg.Payload
127					delete(m.cachedContent, msg.Payload.ID)
128					m.needsRerender = true
129					break
130				}
131			}
132		}
133	}
134	if m.agentWorking {
135		u, cmd := m.spinner.Update(msg)
136		m.spinner = u
137		cmds = append(cmds, cmd)
138	}
139	oldPos := m.viewport.YPosition
140	u, cmd := m.viewport.Update(msg)
141	m.viewport = u
142	m.needsRerender = m.needsRerender || m.viewport.YPosition != oldPos
143	cmds = append(cmds, cmd)
144	if m.needsRerender {
145		m.renderView()
146		if len(m.messages) > 0 {
147			if msg, ok := msg.(pubsub.Event[message.Message]); ok {
148				if (msg.Type == pubsub.CreatedEvent) ||
149					(msg.Type == pubsub.UpdatedEvent && msg.Payload.ID == m.messages[len(m.messages)-1].ID) {
150					m.viewport.GotoBottom()
151				}
152			}
153		}
154		m.needsRerender = false
155	}
156	return m, tea.Batch(cmds...)
157}
158
159func (m *messagesCmp) renderSimpleMessage(msg message.Message, info ...string) string {
160	if v, ok := m.cachedContent[msg.ID]; ok {
161		return v
162	}
163	style := styles.BaseStyle.
164		Width(m.width).
165		BorderLeft(true).
166		Foreground(styles.ForgroundDim).
167		BorderForeground(styles.ForgroundDim).
168		BorderStyle(lipgloss.ThickBorder())
169
170	renderer := m.renderer
171	if msg.ID == m.currentMsgID {
172		style = style.
173			Foreground(styles.Forground).
174			BorderForeground(styles.Blue).
175			BorderStyle(lipgloss.ThickBorder())
176		renderer = m.focusRenderer
177	}
178	c, _ := renderer.Render(msg.Content().String())
179	parts := []string{
180		styles.ForceReplaceBackgroundWithLipgloss(c, styles.Background),
181	}
182	// remove newline at the end
183	parts[0] = strings.TrimSuffix(parts[0], "\n")
184	if len(info) > 0 {
185		parts = append(parts, info...)
186	}
187	rendered := style.Render(
188		lipgloss.JoinVertical(
189			lipgloss.Left,
190			parts...,
191		),
192	)
193	m.cachedContent[msg.ID] = rendered
194	return rendered
195}
196
// formatTimeDifference formats the absolute distance between two timestamps
// (in the same unit, treated as seconds). Spans under a minute print with one
// decimal ("5.0s"); longer spans print as whole minutes and seconds ("2m5s").
func formatTimeDifference(unixTime1, unixTime2 int64) string {
	elapsed := math.Abs(float64(unixTime2 - unixTime1))
	if elapsed >= 60 {
		return fmt.Sprintf("%dm%ds", int(elapsed/60), int(elapsed)%60)
	}
	return fmt.Sprintf("%.1fs", elapsed)
}
208
209func (m *messagesCmp) renderToolCall(toolCall message.ToolCall, isNested bool) string {
210	key := ""
211	value := ""
212	switch toolCall.Name {
213	// TODO: add result data to the tools
214	case agent.AgentToolName:
215		key = "Task"
216		var params agent.AgentParams
217		json.Unmarshal([]byte(toolCall.Input), &params)
218		value = params.Prompt
219	// TODO: handle nested calls
220	case tools.BashToolName:
221		key = "Bash"
222		var params tools.BashParams
223		json.Unmarshal([]byte(toolCall.Input), &params)
224		value = params.Command
225	case tools.EditToolName:
226		key = "Edit"
227		var params tools.EditParams
228		json.Unmarshal([]byte(toolCall.Input), &params)
229		value = params.FilePath
230	case tools.FetchToolName:
231		key = "Fetch"
232		var params tools.FetchParams
233		json.Unmarshal([]byte(toolCall.Input), &params)
234		value = params.URL
235	case tools.GlobToolName:
236		key = "Glob"
237		var params tools.GlobParams
238		json.Unmarshal([]byte(toolCall.Input), &params)
239		if params.Path == "" {
240			params.Path = "."
241		}
242		value = fmt.Sprintf("%s (%s)", params.Pattern, params.Path)
243	case tools.GrepToolName:
244		key = "Grep"
245		var params tools.GrepParams
246		json.Unmarshal([]byte(toolCall.Input), &params)
247		if params.Path == "" {
248			params.Path = "."
249		}
250		value = fmt.Sprintf("%s (%s)", params.Pattern, params.Path)
251	case tools.LSToolName:
252		key = "Ls"
253		var params tools.LSParams
254		json.Unmarshal([]byte(toolCall.Input), &params)
255		if params.Path == "" {
256			params.Path = "."
257		}
258		value = params.Path
259	case tools.SourcegraphToolName:
260		key = "Sourcegraph"
261		var params tools.SourcegraphParams
262		json.Unmarshal([]byte(toolCall.Input), &params)
263		value = params.Query
264	case tools.ViewToolName:
265		key = "View"
266		var params tools.ViewParams
267		json.Unmarshal([]byte(toolCall.Input), &params)
268		value = params.FilePath
269	case tools.WriteToolName:
270		key = "Write"
271		var params tools.WriteParams
272		json.Unmarshal([]byte(toolCall.Input), &params)
273		value = params.FilePath
274	default:
275		key = toolCall.Name
276		var params map[string]any
277		json.Unmarshal([]byte(toolCall.Input), &params)
278		jsonData, _ := json.Marshal(params)
279		value = string(jsonData)
280	}
281
282	style := styles.BaseStyle.
283		Width(m.width).
284		BorderLeft(true).
285		BorderStyle(lipgloss.ThickBorder()).
286		PaddingLeft(1).
287		BorderForeground(styles.Yellow)
288
289	keyStyle := styles.BaseStyle.
290		Foreground(styles.ForgroundDim)
291	valyeStyle := styles.BaseStyle.
292		Foreground(styles.Forground)
293
294	if isNested {
295		valyeStyle = valyeStyle.Foreground(styles.ForgroundMid)
296	}
297	keyValye := keyStyle.Render(
298		fmt.Sprintf("%s: ", key),
299	)
300	if !isNested {
301		value = valyeStyle.
302			Width(m.width - lipgloss.Width(keyValye) - 2).
303			Render(
304				ansi.Truncate(
305					value,
306					m.width-lipgloss.Width(keyValye)-2,
307					"...",
308				),
309			)
310	} else {
311		keyValye = keyStyle.Render(
312			fmt.Sprintf(" └ %s: ", key),
313		)
314		value = valyeStyle.
315			Width(m.width - lipgloss.Width(keyValye) - 2).
316			Render(
317				ansi.Truncate(
318					value,
319					m.width-lipgloss.Width(keyValye)-2,
320					"...",
321				),
322			)
323	}
324
325	innerToolCalls := make([]string, 0)
326	if toolCall.Name == agent.AgentToolName {
327		messages, _ := m.app.Messages.List(toolCall.ID)
328		toolCalls := make([]message.ToolCall, 0)
329		for _, v := range messages {
330			toolCalls = append(toolCalls, v.ToolCalls()...)
331		}
332		for _, v := range toolCalls {
333			call := m.renderToolCall(v, true)
334			innerToolCalls = append(innerToolCalls, call)
335		}
336	}
337
338	if isNested {
339		return lipgloss.JoinHorizontal(
340			lipgloss.Left,
341			keyValye,
342			value,
343		)
344	}
345	callContent := lipgloss.JoinHorizontal(
346		lipgloss.Left,
347		keyValye,
348		value,
349	)
350	callContent = strings.ReplaceAll(callContent, "\n", "")
351	if len(innerToolCalls) > 0 {
352		callContent = lipgloss.JoinVertical(
353			lipgloss.Left,
354			callContent,
355			lipgloss.JoinVertical(
356				lipgloss.Left,
357				innerToolCalls...,
358			),
359		)
360	}
361	return style.Render(callContent)
362}
363
364func (m *messagesCmp) renderAssistantMessage(msg message.Message) []uiMessage {
365	// find the user message that is before this assistant message
366	var userMsg message.Message
367	for i := len(m.messages) - 1; i >= 0; i-- {
368		if m.messages[i].Role == message.User {
369			userMsg = m.messages[i]
370			break
371		}
372	}
373	messages := make([]uiMessage, 0)
374	if msg.Content().String() != "" {
375		info := make([]string, 0)
376		if msg.IsFinished() && msg.FinishReason() == "end_turn" {
377			finish := msg.FinishPart()
378			took := formatTimeDifference(userMsg.CreatedAt, finish.Time)
379
380			info = append(info, styles.BaseStyle.Width(m.width-1).Foreground(styles.ForgroundDim).Render(
381				fmt.Sprintf(" %s (%s)", models.SupportedModels[msg.Model].Name, took),
382			))
383		}
384		content := m.renderSimpleMessage(msg, info...)
385		messages = append(messages, uiMessage{
386			messageType: assistantMessageType,
387			position:    0, // gets updated in renderView
388			height:      lipgloss.Height(content),
389			content:     content,
390		})
391	}
392	for _, v := range msg.ToolCalls() {
393		content := m.renderToolCall(v, false)
394		messages = append(messages,
395			uiMessage{
396				messageType: toolMessageType,
397				position:    0, // gets updated in renderView
398				height:      lipgloss.Height(content),
399				content:     content,
400			},
401		)
402	}
403
404	return messages
405}
406
407func (m *messagesCmp) renderView() {
408	m.uiMessages = make([]uiMessage, 0)
409	pos := 0
410
411	for _, v := range m.messages {
412		switch v.Role {
413		case message.User:
414			content := m.renderSimpleMessage(v)
415			m.uiMessages = append(m.uiMessages, uiMessage{
416				messageType: userMessageType,
417				position:    pos,
418				height:      lipgloss.Height(content),
419				content:     content,
420			})
421			pos += lipgloss.Height(content) + 1 // + 1 for spacing
422		case message.Assistant:
423			assistantMessages := m.renderAssistantMessage(v)
424			for _, msg := range assistantMessages {
425				msg.position = pos
426				m.uiMessages = append(m.uiMessages, msg)
427				pos += msg.height + 1 // + 1 for spacing
428			}
429
430		}
431	}
432
433	messages := make([]string, 0)
434	for _, v := range m.uiMessages {
435		messages = append(messages, v.content,
436			styles.BaseStyle.
437				Width(m.width).
438				Render(
439					"",
440				),
441		)
442	}
443	m.viewport.SetContent(
444		styles.BaseStyle.
445			Width(m.width).
446			Render(
447				lipgloss.JoinVertical(
448					lipgloss.Top,
449					messages...,
450				),
451			),
452	)
453}
454
455func (m *messagesCmp) View() string {
456	if len(m.messages) == 0 {
457		content := styles.BaseStyle.
458			Width(m.width).
459			Height(m.height - 1).
460			Render(
461				m.initialScreen(),
462			)
463
464		return styles.BaseStyle.
465			Width(m.width).
466			Render(
467				lipgloss.JoinVertical(
468					lipgloss.Top,
469					content,
470					m.help(),
471				),
472			)
473	}
474
475	return styles.BaseStyle.
476		Width(m.width).
477		Render(
478			lipgloss.JoinVertical(
479				lipgloss.Top,
480				m.viewport.View(),
481				m.help(),
482			),
483		)
484}
485
486func (m *messagesCmp) help() string {
487	text := ""
488
489	if m.agentWorking {
490		text += styles.BaseStyle.Foreground(styles.PrimaryColor).Bold(true).Render(
491			fmt.Sprintf("%s %s ", m.spinner.View(), "Generating..."),
492		)
493	}
494	if m.writingMode {
495		text += lipgloss.JoinHorizontal(
496			lipgloss.Left,
497			styles.BaseStyle.Foreground(styles.ForgroundDim).Bold(true).Render("press "),
498			styles.BaseStyle.Foreground(styles.Forground).Bold(true).Render("esc"),
499			styles.BaseStyle.Foreground(styles.ForgroundDim).Bold(true).Render(" to exit writing mode"),
500		)
501	} else {
502		text += lipgloss.JoinHorizontal(
503			lipgloss.Left,
504			styles.BaseStyle.Foreground(styles.ForgroundDim).Bold(true).Render("press "),
505			styles.BaseStyle.Foreground(styles.Forground).Bold(true).Render("i"),
506			styles.BaseStyle.Foreground(styles.ForgroundDim).Bold(true).Render(" to start writing"),
507		)
508	}
509
510	return styles.BaseStyle.
511		Width(m.width).
512		Render(text)
513}
514
515func (m *messagesCmp) initialScreen() string {
516	return styles.BaseStyle.Width(m.width).Render(
517		lipgloss.JoinVertical(
518			lipgloss.Top,
519			header(m.width),
520			"",
521			lspsConfigured(m.width),
522		),
523	)
524}
525
526func (m *messagesCmp) SetSize(width, height int) {
527	m.width = width
528	m.height = height
529	m.viewport.Width = width
530	m.viewport.Height = height - 1
531	focusRenderer, _ := glamour.NewTermRenderer(
532		glamour.WithStyles(styles.MarkdownTheme(true)),
533		glamour.WithWordWrap(width-1),
534	)
535	renderer, _ := glamour.NewTermRenderer(
536		glamour.WithStyles(styles.MarkdownTheme(false)),
537		glamour.WithWordWrap(width-1),
538	)
539	m.focusRenderer = focusRenderer
540	// clear the cached content
541	for k := range m.cachedContent {
542		delete(m.cachedContent, k)
543	}
544	m.renderer = renderer
545	if len(m.messages) > 0 {
546		m.renderView()
547		m.viewport.GotoBottom()
548	}
549}
550
551func (m *messagesCmp) GetSize() (int, int) {
552	return m.width, m.height
553}
554
555func (m *messagesCmp) SetSession(session session.Session) tea.Cmd {
556	m.session = session
557	messages, err := m.app.Messages.List(session.ID)
558	if err != nil {
559		return util.ReportError(err)
560	}
561	m.messages = messages
562	m.currentMsgID = m.messages[len(m.messages)-1].ID
563	m.needsRerender = true
564	return nil
565}
566
567func NewMessagesCmp(app *app.App) tea.Model {
568	focusRenderer, _ := glamour.NewTermRenderer(
569		glamour.WithStyles(styles.MarkdownTheme(true)),
570		glamour.WithWordWrap(80),
571	)
572	renderer, _ := glamour.NewTermRenderer(
573		glamour.WithStyles(styles.MarkdownTheme(false)),
574		glamour.WithWordWrap(80),
575	)
576
577	s := spinner.New()
578	s.Spinner = spinner.Pulse
579	return &messagesCmp{
580		app:           app,
581		writingMode:   true,
582		cachedContent: make(map[string]string),
583		viewport:      viewport.New(0, 0),
584		focusRenderer: focusRenderer,
585		renderer:      renderer,
586		spinner:       s,
587	}
588}