messages.go

  1package chat
  2
  3import (
  4	"context"
  5	"encoding/json"
  6	"fmt"
  7	"math"
  8	"strings"
  9
 10	"github.com/charmbracelet/bubbles/spinner"
 11	"github.com/charmbracelet/bubbles/viewport"
 12	tea "github.com/charmbracelet/bubbletea"
 13	"github.com/charmbracelet/glamour"
 14	"github.com/charmbracelet/lipgloss"
 15	"github.com/charmbracelet/x/ansi"
 16	"github.com/kujtimiihoxha/termai/internal/app"
 17	"github.com/kujtimiihoxha/termai/internal/llm/agent"
 18	"github.com/kujtimiihoxha/termai/internal/llm/models"
 19	"github.com/kujtimiihoxha/termai/internal/llm/tools"
 20	"github.com/kujtimiihoxha/termai/internal/message"
 21	"github.com/kujtimiihoxha/termai/internal/pubsub"
 22	"github.com/kujtimiihoxha/termai/internal/session"
 23	"github.com/kujtimiihoxha/termai/internal/tui/styles"
 24	"github.com/kujtimiihoxha/termai/internal/tui/util"
 25)
 26
 27type uiMessageType int
 28
 29const (
 30	userMessageType uiMessageType = iota
 31	assistantMessageType
 32	toolMessageType
 33)
 34
 35type uiMessage struct {
 36	ID          string
 37	messageType uiMessageType
 38	position    int
 39	height      int
 40	content     string
 41}
 42
 43type messagesCmp struct {
 44	app           *app.App
 45	width, height int
 46	writingMode   bool
 47	viewport      viewport.Model
 48	session       session.Session
 49	messages      []message.Message
 50	uiMessages    []uiMessage
 51	currentMsgID  string
 52	renderer      *glamour.TermRenderer
 53	focusRenderer *glamour.TermRenderer
 54	cachedContent map[string]string
 55	agentWorking  bool
 56	spinner       spinner.Model
 57	needsRerender bool
 58	lastViewport  string
 59}
 60
 61func (m *messagesCmp) Init() tea.Cmd {
 62	return tea.Batch(m.viewport.Init())
 63}
 64
 65func (m *messagesCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
 66	var cmds []tea.Cmd
 67	switch msg := msg.(type) {
 68	case AgentWorkingMsg:
 69		m.agentWorking = bool(msg)
 70		if m.agentWorking {
 71			cmds = append(cmds, m.spinner.Tick)
 72		}
 73	case EditorFocusMsg:
 74		m.writingMode = bool(msg)
 75	case SessionSelectedMsg:
 76		if msg.ID != m.session.ID {
 77			cmd := m.SetSession(msg)
 78			m.needsRerender = true
 79			return m, cmd
 80		}
 81		return m, nil
 82	case SessionClearedMsg:
 83		m.session = session.Session{}
 84		m.messages = make([]message.Message, 0)
 85		m.currentMsgID = ""
 86		m.needsRerender = true
 87		return m, nil
 88
 89	case tea.KeyMsg:
 90		if m.writingMode {
 91			return m, nil
 92		}
 93	case pubsub.Event[message.Message]:
 94		if msg.Type == pubsub.CreatedEvent {
 95			if msg.Payload.SessionID == m.session.ID {
 96				// check if message exists
 97
 98				messageExists := false
 99				for _, v := range m.messages {
100					if v.ID == msg.Payload.ID {
101						messageExists = true
102						break
103					}
104				}
105
106				if !messageExists {
107					m.messages = append(m.messages, msg.Payload)
108					delete(m.cachedContent, m.currentMsgID)
109					m.currentMsgID = msg.Payload.ID
110					m.needsRerender = true
111				}
112			}
113			for _, v := range m.messages {
114				for _, c := range v.ToolCalls() {
115					// the message is being added to the session of a tool called
116					if c.ID == msg.Payload.SessionID {
117						m.needsRerender = true
118					}
119				}
120			}
121		} else if msg.Type == pubsub.UpdatedEvent && msg.Payload.SessionID == m.session.ID {
122			for i, v := range m.messages {
123				if v.ID == msg.Payload.ID {
124					if !m.messages[i].IsFinished() && msg.Payload.IsFinished() && msg.Payload.FinishReason() == "end_turn" || msg.Payload.FinishReason() == "canceled" {
125						cmds = append(cmds, util.CmdHandler(AgentWorkingMsg(false)))
126					}
127					m.messages[i] = msg.Payload
128					delete(m.cachedContent, msg.Payload.ID)
129					m.needsRerender = true
130					break
131				}
132			}
133		}
134	}
135	if m.agentWorking {
136		u, cmd := m.spinner.Update(msg)
137		m.spinner = u
138		cmds = append(cmds, cmd)
139	}
140	oldPos := m.viewport.YPosition
141	u, cmd := m.viewport.Update(msg)
142	m.viewport = u
143	m.needsRerender = m.needsRerender || m.viewport.YPosition != oldPos
144	cmds = append(cmds, cmd)
145	if m.needsRerender {
146		m.renderView()
147		if len(m.messages) > 0 {
148			if msg, ok := msg.(pubsub.Event[message.Message]); ok {
149				if (msg.Type == pubsub.CreatedEvent) ||
150					(msg.Type == pubsub.UpdatedEvent && msg.Payload.ID == m.messages[len(m.messages)-1].ID) {
151					m.viewport.GotoBottom()
152				}
153			}
154		}
155		m.needsRerender = false
156	}
157	return m, tea.Batch(cmds...)
158}
159
// renderSimpleMessage renders msg's text content as markdown inside a
// left-bordered block, with any info lines appended underneath.
//
// The current message (currentMsgID) is rendered with the focus renderer and
// a highlighted border. Results are cached by message ID, and callers delete
// the cache entry when a message changes or loses focus. If markdown rendering
// fails, the raw text is shown instead of an empty block.
func (m *messagesCmp) renderSimpleMessage(msg message.Message, info ...string) string {
	if v, ok := m.cachedContent[msg.ID]; ok {
		return v
	}
	style := styles.BaseStyle.
		Width(m.width).
		BorderLeft(true).
		Foreground(styles.ForgroundDim).
		BorderForeground(styles.ForgroundDim).
		BorderStyle(lipgloss.ThickBorder())

	renderer := m.renderer
	if msg.ID == m.currentMsgID {
		style = style.
			Foreground(styles.Forground).
			BorderForeground(styles.Blue).
			BorderStyle(lipgloss.ThickBorder())
		renderer = m.focusRenderer
	}

	text := msg.Content().String()
	c, err := renderer.Render(text)
	if err != nil {
		// Previously the error was dropped and the message rendered empty.
		c = text
	}
	// Drop the trailing newline of the rendered markdown so info lines sit directly below.
	body := strings.TrimSuffix(styles.ForceReplaceBackgroundWithLipgloss(c, styles.Background), "\n")
	parts := append([]string{body}, info...)

	rendered := style.Render(lipgloss.JoinVertical(lipgloss.Left, parts...))
	m.cachedContent[msg.ID] = rendered
	return rendered
}
197
// formatTimeDifference formats the absolute distance between two timestamps,
// treated as seconds. Below one minute the result looks like "12.0s";
// otherwise it is whole minutes and seconds, e.g. "2m5s".
func formatTimeDifference(unixTime1, unixTime2 int64) string {
	elapsed := math.Abs(float64(unixTime2 - unixTime1))
	if elapsed >= 60 {
		return fmt.Sprintf("%dm%ds", int(elapsed/60), int(elapsed)%60)
	}
	return fmt.Sprintf("%.1fs", elapsed)
}
209
// toolCallLabel returns the display label and the main argument for a tool
// call, decoded from its JSON input.
//
// Decoding is best-effort: malformed input yields whatever fields could be
// decoded, possibly empty. Unknown tools are labelled with their name and show
// their input re-encoded as compact JSON ("null" if it could not be decoded).
func toolCallLabel(toolCall message.ToolCall) (key, value string) {
	input := []byte(toolCall.Input)
	// TODO: add result data to the tools
	switch toolCall.Name {
	case agent.AgentToolName:
		// TODO: handle nested calls
		var params agent.AgentParams
		_ = json.Unmarshal(input, &params)
		return "Task", params.Prompt
	case tools.BashToolName:
		var params tools.BashParams
		_ = json.Unmarshal(input, &params)
		return "Bash", params.Command
	case tools.EditToolName:
		var params tools.EditParams
		_ = json.Unmarshal(input, &params)
		return "Edit", params.FilePath
	case tools.FetchToolName:
		var params tools.FetchParams
		_ = json.Unmarshal(input, &params)
		return "Fetch", params.URL
	case tools.GlobToolName:
		var params tools.GlobParams
		_ = json.Unmarshal(input, &params)
		if params.Path == "" {
			params.Path = "."
		}
		return "Glob", fmt.Sprintf("%s (%s)", params.Pattern, params.Path)
	case tools.GrepToolName:
		var params tools.GrepParams
		_ = json.Unmarshal(input, &params)
		if params.Path == "" {
			params.Path = "."
		}
		return "Grep", fmt.Sprintf("%s (%s)", params.Pattern, params.Path)
	case tools.LSToolName:
		var params tools.LSParams
		_ = json.Unmarshal(input, &params)
		if params.Path == "" {
			params.Path = "."
		}
		return "Ls", params.Path
	case tools.SourcegraphToolName:
		var params tools.SourcegraphParams
		_ = json.Unmarshal(input, &params)
		return "Sourcegraph", params.Query
	case tools.ViewToolName:
		var params tools.ViewParams
		_ = json.Unmarshal(input, &params)
		return "View", params.FilePath
	case tools.WriteToolName:
		var params tools.WriteParams
		_ = json.Unmarshal(input, &params)
		return "Write", params.FilePath
	default:
		var params map[string]any
		_ = json.Unmarshal(input, &params)
		jsonData, _ := json.Marshal(params)
		return toolCall.Name, string(jsonData)
	}
}

// renderToolCall renders one tool call as a "Key: value" line whose value is
// truncated to fit the component width.
//
// Top-level calls (isNested == false) are drawn in a yellow left-bordered
// block. For the agent tool, the tool calls recorded in its child session (the
// session whose ID is the tool call ID) are listed underneath as nested
// " └ Key: value" lines. Nested calls are returned without the block styling.
func (m *messagesCmp) renderToolCall(toolCall message.ToolCall, isNested bool) string {
	key, value := toolCallLabel(toolCall)

	keyStyle := styles.BaseStyle.Foreground(styles.ForgroundDim)
	valueStyle := styles.BaseStyle.Foreground(styles.Forground)
	label := fmt.Sprintf("%s: ", key)
	if isNested {
		valueStyle = valueStyle.Foreground(styles.ForgroundMid)
		label = fmt.Sprintf(" └ %s: ", key)
	}
	renderedKey := keyStyle.Render(label)
	// Leave room for the key plus the enclosing block's left border and padding.
	valueWidth := m.width - lipgloss.Width(renderedKey) - 2
	renderedValue := valueStyle.
		Width(valueWidth).
		Render(ansi.Truncate(value, valueWidth, "..."))

	callContent := lipgloss.JoinHorizontal(lipgloss.Left, renderedKey, renderedValue)
	if isNested {
		return callContent
	}
	callContent = strings.ReplaceAll(callContent, "\n", "")

	// Only top-level agent calls list their children. Nested renders never
	// used the result, so the lookup is skipped for them.
	if toolCall.Name == agent.AgentToolName {
		// Best-effort: a failed lookup simply renders no nested calls.
		children, _ := m.app.Messages.List(context.Background(), toolCall.ID)
		var nested []string
		for i := range children {
			for _, call := range children[i].ToolCalls() {
				nested = append(nested, m.renderToolCall(call, true))
			}
		}
		if len(nested) > 0 {
			callContent = lipgloss.JoinVertical(
				lipgloss.Left,
				callContent,
				lipgloss.JoinVertical(lipgloss.Left, nested...),
			)
		}
	}

	return styles.BaseStyle.
		Width(m.width).
		BorderLeft(true).
		BorderStyle(lipgloss.ThickBorder()).
		PaddingLeft(1).
		BorderForeground(styles.Yellow).
		Render(callContent)
}
364
// renderAssistantMessage turns one assistant message into display rows.
//
// It emits a text row when the message has text content, followed by one row
// per tool call. A text row for a message that finished with "end_turn" also
// gets an info line with the model name and the time elapsed since the user
// message that prompted it. Row positions are left at 0 for renderView to
// fill in.
func (m *messagesCmp) renderAssistantMessage(msg message.Message) []uiMessage {
	messages := make([]uiMessage, 0)
	if msg.Content().String() != "" {
		info := make([]string, 0)
		if msg.IsFinished() && msg.FinishReason() == "end_turn" {
			// Use the closest user message *before* msg. Scanning back from the
			// end of the session picked the latest user message instead, which
			// produced wrong durations for every earlier reply.
			var userMsg message.Message
			for i := range m.messages {
				if m.messages[i].ID == msg.ID {
					break
				}
				if m.messages[i].Role == message.User {
					userMsg = m.messages[i]
				}
			}
			finish := msg.FinishPart()
			took := formatTimeDifference(userMsg.CreatedAt, finish.Time)

			info = append(info, styles.BaseStyle.Width(m.width-1).Foreground(styles.ForgroundDim).Render(
				fmt.Sprintf(" %s (%s)", models.SupportedModels[msg.Model].Name, took),
			))
		}
		content := m.renderSimpleMessage(msg, info...)
		messages = append(messages, uiMessage{
			messageType: assistantMessageType,
			position:    0, // gets updated in renderView
			height:      lipgloss.Height(content),
			content:     content,
		})
	}
	for _, v := range msg.ToolCalls() {
		content := m.renderToolCall(v, false)
		messages = append(messages,
			uiMessage{
				messageType: toolMessageType,
				position:    0, // gets updated in renderView
				height:      lipgloss.Height(content),
				content:     content,
			},
		)
	}

	return messages
}
407
408func (m *messagesCmp) renderView() {
409	m.uiMessages = make([]uiMessage, 0)
410	pos := 0
411
412	for _, v := range m.messages {
413		switch v.Role {
414		case message.User:
415			content := m.renderSimpleMessage(v)
416			m.uiMessages = append(m.uiMessages, uiMessage{
417				messageType: userMessageType,
418				position:    pos,
419				height:      lipgloss.Height(content),
420				content:     content,
421			})
422			pos += lipgloss.Height(content) + 1 // + 1 for spacing
423		case message.Assistant:
424			assistantMessages := m.renderAssistantMessage(v)
425			for _, msg := range assistantMessages {
426				msg.position = pos
427				m.uiMessages = append(m.uiMessages, msg)
428				pos += msg.height + 1 // + 1 for spacing
429			}
430
431		}
432	}
433
434	messages := make([]string, 0)
435	for _, v := range m.uiMessages {
436		messages = append(messages, v.content,
437			styles.BaseStyle.
438				Width(m.width).
439				Render(
440					"",
441				),
442		)
443	}
444	m.viewport.SetContent(
445		styles.BaseStyle.
446			Width(m.width).
447			Render(
448				lipgloss.JoinVertical(
449					lipgloss.Top,
450					messages...,
451				),
452			),
453	)
454}
455
456func (m *messagesCmp) View() string {
457	if len(m.messages) == 0 {
458		content := styles.BaseStyle.
459			Width(m.width).
460			Height(m.height - 1).
461			Render(
462				m.initialScreen(),
463			)
464
465		return styles.BaseStyle.
466			Width(m.width).
467			Render(
468				lipgloss.JoinVertical(
469					lipgloss.Top,
470					content,
471					m.help(),
472				),
473			)
474	}
475
476	return styles.BaseStyle.
477		Width(m.width).
478		Render(
479			lipgloss.JoinVertical(
480				lipgloss.Top,
481				m.viewport.View(),
482				m.help(),
483			),
484		)
485}
486
// help renders the bottom status line. While the agent works it starts with a
// spinner and "Generating...". It then shows the key hint for entering or
// leaving writing mode.
func (m *messagesCmp) help() string {
	dim := styles.BaseStyle.Foreground(styles.ForgroundDim).Bold(true)
	bright := styles.BaseStyle.Foreground(styles.Forground).Bold(true)

	status := ""
	if m.agentWorking {
		status = styles.BaseStyle.Foreground(styles.PrimaryColor).Bold(true).Render(
			fmt.Sprintf("%s %s ", m.spinner.View(), "Generating..."),
		)
	}

	key, action := "i", " to start writing"
	if m.writingMode {
		key, action = "esc", " to exit writing mode"
	}
	hint := lipgloss.JoinHorizontal(
		lipgloss.Left,
		dim.Render("press "),
		bright.Render(key),
		dim.Render(action),
	)

	return styles.BaseStyle.Width(m.width).Render(status + hint)
}
515
516func (m *messagesCmp) initialScreen() string {
517	return styles.BaseStyle.Width(m.width).Render(
518		lipgloss.JoinVertical(
519			lipgloss.Top,
520			header(m.width),
521			"",
522			lspsConfigured(m.width),
523		),
524	)
525}
526
527func (m *messagesCmp) SetSize(width, height int) {
528	m.width = width
529	m.height = height
530	m.viewport.Width = width
531	m.viewport.Height = height - 1
532	focusRenderer, _ := glamour.NewTermRenderer(
533		glamour.WithStyles(styles.MarkdownTheme(true)),
534		glamour.WithWordWrap(width-1),
535	)
536	renderer, _ := glamour.NewTermRenderer(
537		glamour.WithStyles(styles.MarkdownTheme(false)),
538		glamour.WithWordWrap(width-1),
539	)
540	m.focusRenderer = focusRenderer
541	// clear the cached content
542	for k := range m.cachedContent {
543		delete(m.cachedContent, k)
544	}
545	m.renderer = renderer
546	if len(m.messages) > 0 {
547		m.renderView()
548		m.viewport.GotoBottom()
549	}
550}
551
552func (m *messagesCmp) GetSize() (int, int) {
553	return m.width, m.height
554}
555
// SetSession switches the component to the given session, loading its
// messages from the message service.
//
// The most recent message, if any, becomes the focused one, and a re-render is
// scheduled. An empty session is valid; the original indexing of the last
// message panicked on it. If loading fails, the component is left unchanged
// and the error is returned as a reporting command.
func (m *messagesCmp) SetSession(session session.Session) tea.Cmd {
	messages, err := m.app.Messages.List(context.Background(), session.ID)
	if err != nil {
		return util.ReportError(err)
	}
	// Only commit the switch once the messages are loaded, so session and
	// messages never disagree.
	m.session = session
	m.messages = messages
	m.currentMsgID = ""
	if len(messages) > 0 {
		m.currentMsgID = messages[len(messages)-1].ID
	}
	m.needsRerender = true
	return nil
}
567
568func NewMessagesCmp(app *app.App) tea.Model {
569	focusRenderer, _ := glamour.NewTermRenderer(
570		glamour.WithStyles(styles.MarkdownTheme(true)),
571		glamour.WithWordWrap(80),
572	)
573	renderer, _ := glamour.NewTermRenderer(
574		glamour.WithStyles(styles.MarkdownTheme(false)),
575		glamour.WithWordWrap(80),
576	)
577
578	s := spinner.New()
579	s.Spinner = spinner.Pulse
580	return &messagesCmp{
581		app:           app,
582		writingMode:   true,
583		cachedContent: make(map[string]string),
584		viewport:      viewport.New(0, 0),
585		focusRenderer: focusRenderer,
586		renderer:      renderer,
587		spinner:       s,
588	}
589}