1package chat
2
3import (
4 "encoding/json"
5 "fmt"
6 "math"
7 "strings"
8
9 "github.com/charmbracelet/bubbles/spinner"
10 "github.com/charmbracelet/bubbles/viewport"
11 tea "github.com/charmbracelet/bubbletea"
12 "github.com/charmbracelet/glamour"
13 "github.com/charmbracelet/lipgloss"
14 "github.com/charmbracelet/x/ansi"
15 "github.com/kujtimiihoxha/termai/internal/app"
16 "github.com/kujtimiihoxha/termai/internal/llm/agent"
17 "github.com/kujtimiihoxha/termai/internal/llm/models"
18 "github.com/kujtimiihoxha/termai/internal/llm/tools"
19 "github.com/kujtimiihoxha/termai/internal/message"
20 "github.com/kujtimiihoxha/termai/internal/pubsub"
21 "github.com/kujtimiihoxha/termai/internal/session"
22 "github.com/kujtimiihoxha/termai/internal/tui/styles"
23 "github.com/kujtimiihoxha/termai/internal/tui/util"
24)
25
// uiMessageType tags what kind of entry a rendered uiMessage represents.
type uiMessageType int

// Entry kinds produced by renderView / renderAssistantMessage.
const (
	// userMessageType marks a message written by the user.
	userMessageType uiMessageType = 0
	// assistantMessageType marks the text part of an assistant message.
	assistantMessageType uiMessageType = 1
	// toolMessageType marks a rendered tool call summary.
	toolMessageType uiMessageType = 2
)
33
34type uiMessage struct {
35 ID string
36 messageType uiMessageType
37 position int
38 height int
39 content string
40}
41
42type messagesCmp struct {
43 app *app.App
44 width, height int
45 writingMode bool
46 viewport viewport.Model
47 session session.Session
48 messages []message.Message
49 uiMessages []uiMessage
50 currentMsgID string
51 renderer *glamour.TermRenderer
52 focusRenderer *glamour.TermRenderer
53 cachedContent map[string]string
54 agentWorking bool
55 spinner spinner.Model
56 needsRerender bool
57 lastViewport string
58}
59
60func (m *messagesCmp) Init() tea.Cmd {
61 return tea.Batch(m.viewport.Init())
62}
63
64func (m *messagesCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
65 var cmds []tea.Cmd
66 switch msg := msg.(type) {
67 case AgentWorkingMsg:
68 m.agentWorking = bool(msg)
69 if m.agentWorking {
70 cmds = append(cmds, m.spinner.Tick)
71 }
72 case EditorFocusMsg:
73 m.writingMode = bool(msg)
74 case SessionSelectedMsg:
75 if msg.ID != m.session.ID {
76 cmd := m.SetSession(msg)
77 m.needsRerender = true
78 return m, cmd
79 }
80 return m, nil
81 case SessionClearedMsg:
82 m.session = session.Session{}
83 m.messages = make([]message.Message, 0)
84 m.currentMsgID = ""
85 m.needsRerender = true
86 return m, nil
87
88 case tea.KeyMsg:
89 if m.writingMode {
90 return m, nil
91 }
92 case pubsub.Event[message.Message]:
93 if msg.Type == pubsub.CreatedEvent {
94 if msg.Payload.SessionID == m.session.ID {
95 // check if message exists
96
97 messageExists := false
98 for _, v := range m.messages {
99 if v.ID == msg.Payload.ID {
100 messageExists = true
101 break
102 }
103 }
104
105 if !messageExists {
106 m.messages = append(m.messages, msg.Payload)
107 delete(m.cachedContent, m.currentMsgID)
108 m.currentMsgID = msg.Payload.ID
109 m.needsRerender = true
110 }
111 }
112 for _, v := range m.messages {
113 for _, c := range v.ToolCalls() {
114 // the message is being added to the session of a tool called
115 if c.ID == msg.Payload.SessionID {
116 m.needsRerender = true
117 }
118 }
119 }
120 } else if msg.Type == pubsub.UpdatedEvent && msg.Payload.SessionID == m.session.ID {
121 for i, v := range m.messages {
122 if v.ID == msg.Payload.ID {
123 if !m.messages[i].IsFinished() && msg.Payload.IsFinished() && msg.Payload.FinishReason() == "end_turn" || msg.Payload.FinishReason() == "canceled" {
124 cmds = append(cmds, util.CmdHandler(AgentWorkingMsg(false)))
125 }
126 m.messages[i] = msg.Payload
127 delete(m.cachedContent, msg.Payload.ID)
128 m.needsRerender = true
129 break
130 }
131 }
132 }
133 }
134 if m.agentWorking {
135 u, cmd := m.spinner.Update(msg)
136 m.spinner = u
137 cmds = append(cmds, cmd)
138 }
139 oldPos := m.viewport.YPosition
140 u, cmd := m.viewport.Update(msg)
141 m.viewport = u
142 m.needsRerender = m.needsRerender || m.viewport.YPosition != oldPos
143 cmds = append(cmds, cmd)
144 if m.needsRerender {
145 m.renderView()
146 if len(m.messages) > 0 {
147 if msg, ok := msg.(pubsub.Event[message.Message]); ok {
148 if (msg.Type == pubsub.CreatedEvent) ||
149 (msg.Type == pubsub.UpdatedEvent && msg.Payload.ID == m.messages[len(m.messages)-1].ID) {
150 m.viewport.GotoBottom()
151 }
152 }
153 }
154 m.needsRerender = false
155 }
156 return m, tea.Batch(cmds...)
157}
158
159func (m *messagesCmp) renderSimpleMessage(msg message.Message, info ...string) string {
160 if v, ok := m.cachedContent[msg.ID]; ok {
161 return v
162 }
163 style := styles.BaseStyle.
164 Width(m.width).
165 BorderLeft(true).
166 Foreground(styles.ForgroundDim).
167 BorderForeground(styles.ForgroundDim).
168 BorderStyle(lipgloss.ThickBorder())
169
170 renderer := m.renderer
171 if msg.ID == m.currentMsgID {
172 style = style.
173 Foreground(styles.Forground).
174 BorderForeground(styles.Blue).
175 BorderStyle(lipgloss.ThickBorder())
176 renderer = m.focusRenderer
177 }
178 c, _ := renderer.Render(msg.Content().String())
179 parts := []string{
180 styles.ForceReplaceBackgroundWithLipgloss(c, styles.Background),
181 }
182 // remove newline at the end
183 parts[0] = strings.TrimSuffix(parts[0], "\n")
184 if len(info) > 0 {
185 parts = append(parts, info...)
186 }
187 rendered := style.Render(
188 lipgloss.JoinVertical(
189 lipgloss.Left,
190 parts...,
191 ),
192 )
193 m.cachedContent[msg.ID] = rendered
194 return rendered
195}
196
// formatTimeDifference formats the absolute difference between two
// timestamps (treated as seconds) for display. Spans under a minute are shown
// with one decimal ("5.0s"); longer spans are shown as whole minutes and
// seconds ("2m5s").
func formatTimeDifference(unixTime1, unixTime2 int64) string {
	elapsed := math.Abs(float64(unixTime2 - unixTime1))
	if elapsed >= 60 {
		return fmt.Sprintf("%dm%ds", int(elapsed/60), int(elapsed)%60)
	}
	return fmt.Sprintf("%.1fs", elapsed)
}
208
209func (m *messagesCmp) renderToolCall(toolCall message.ToolCall, isNested bool) string {
210 key := ""
211 value := ""
212 switch toolCall.Name {
213 // TODO: add result data to the tools
214 case agent.AgentToolName:
215 key = "Task"
216 var params agent.AgentParams
217 json.Unmarshal([]byte(toolCall.Input), ¶ms)
218 value = params.Prompt
219 // TODO: handle nested calls
220 case tools.BashToolName:
221 key = "Bash"
222 var params tools.BashParams
223 json.Unmarshal([]byte(toolCall.Input), ¶ms)
224 value = params.Command
225 case tools.EditToolName:
226 key = "Edit"
227 var params tools.EditParams
228 json.Unmarshal([]byte(toolCall.Input), ¶ms)
229 value = params.FilePath
230 case tools.FetchToolName:
231 key = "Fetch"
232 var params tools.FetchParams
233 json.Unmarshal([]byte(toolCall.Input), ¶ms)
234 value = params.URL
235 case tools.GlobToolName:
236 key = "Glob"
237 var params tools.GlobParams
238 json.Unmarshal([]byte(toolCall.Input), ¶ms)
239 if params.Path == "" {
240 params.Path = "."
241 }
242 value = fmt.Sprintf("%s (%s)", params.Pattern, params.Path)
243 case tools.GrepToolName:
244 key = "Grep"
245 var params tools.GrepParams
246 json.Unmarshal([]byte(toolCall.Input), ¶ms)
247 if params.Path == "" {
248 params.Path = "."
249 }
250 value = fmt.Sprintf("%s (%s)", params.Pattern, params.Path)
251 case tools.LSToolName:
252 key = "Ls"
253 var params tools.LSParams
254 json.Unmarshal([]byte(toolCall.Input), ¶ms)
255 if params.Path == "" {
256 params.Path = "."
257 }
258 value = params.Path
259 case tools.SourcegraphToolName:
260 key = "Sourcegraph"
261 var params tools.SourcegraphParams
262 json.Unmarshal([]byte(toolCall.Input), ¶ms)
263 value = params.Query
264 case tools.ViewToolName:
265 key = "View"
266 var params tools.ViewParams
267 json.Unmarshal([]byte(toolCall.Input), ¶ms)
268 value = params.FilePath
269 case tools.WriteToolName:
270 key = "Write"
271 var params tools.WriteParams
272 json.Unmarshal([]byte(toolCall.Input), ¶ms)
273 value = params.FilePath
274 default:
275 key = toolCall.Name
276 var params map[string]any
277 json.Unmarshal([]byte(toolCall.Input), ¶ms)
278 jsonData, _ := json.Marshal(params)
279 value = string(jsonData)
280 }
281
282 style := styles.BaseStyle.
283 Width(m.width).
284 BorderLeft(true).
285 BorderStyle(lipgloss.ThickBorder()).
286 PaddingLeft(1).
287 BorderForeground(styles.Yellow)
288
289 keyStyle := styles.BaseStyle.
290 Foreground(styles.ForgroundDim)
291 valyeStyle := styles.BaseStyle.
292 Foreground(styles.Forground)
293
294 if isNested {
295 valyeStyle = valyeStyle.Foreground(styles.ForgroundMid)
296 }
297 keyValye := keyStyle.Render(
298 fmt.Sprintf("%s: ", key),
299 )
300 if !isNested {
301 value = valyeStyle.
302 Width(m.width - lipgloss.Width(keyValye) - 2).
303 Render(
304 ansi.Truncate(
305 value,
306 m.width-lipgloss.Width(keyValye)-2,
307 "...",
308 ),
309 )
310 } else {
311 keyValye = keyStyle.Render(
312 fmt.Sprintf(" └ %s: ", key),
313 )
314 value = valyeStyle.
315 Width(m.width - lipgloss.Width(keyValye) - 2).
316 Render(
317 ansi.Truncate(
318 value,
319 m.width-lipgloss.Width(keyValye)-2,
320 "...",
321 ),
322 )
323 }
324
325 innerToolCalls := make([]string, 0)
326 if toolCall.Name == agent.AgentToolName {
327 messages, _ := m.app.Messages.List(toolCall.ID)
328 toolCalls := make([]message.ToolCall, 0)
329 for _, v := range messages {
330 toolCalls = append(toolCalls, v.ToolCalls()...)
331 }
332 for _, v := range toolCalls {
333 call := m.renderToolCall(v, true)
334 innerToolCalls = append(innerToolCalls, call)
335 }
336 }
337
338 if isNested {
339 return lipgloss.JoinHorizontal(
340 lipgloss.Left,
341 keyValye,
342 value,
343 )
344 }
345 callContent := lipgloss.JoinHorizontal(
346 lipgloss.Left,
347 keyValye,
348 value,
349 )
350 callContent = strings.ReplaceAll(callContent, "\n", "")
351 if len(innerToolCalls) > 0 {
352 callContent = lipgloss.JoinVertical(
353 lipgloss.Left,
354 callContent,
355 lipgloss.JoinVertical(
356 lipgloss.Left,
357 innerToolCalls...,
358 ),
359 )
360 }
361 return style.Render(callContent)
362}
363
364func (m *messagesCmp) renderAssistantMessage(msg message.Message) []uiMessage {
365 // find the user message that is before this assistant message
366 var userMsg message.Message
367 for i := len(m.messages) - 1; i >= 0; i-- {
368 if m.messages[i].Role == message.User {
369 userMsg = m.messages[i]
370 break
371 }
372 }
373 messages := make([]uiMessage, 0)
374 if msg.Content().String() != "" {
375 info := make([]string, 0)
376 if msg.IsFinished() && msg.FinishReason() == "end_turn" {
377 finish := msg.FinishPart()
378 took := formatTimeDifference(userMsg.CreatedAt, finish.Time)
379
380 info = append(info, styles.BaseStyle.Width(m.width-1).Foreground(styles.ForgroundDim).Render(
381 fmt.Sprintf(" %s (%s)", models.SupportedModels[msg.Model].Name, took),
382 ))
383 }
384 content := m.renderSimpleMessage(msg, info...)
385 messages = append(messages, uiMessage{
386 messageType: assistantMessageType,
387 position: 0, // gets updated in renderView
388 height: lipgloss.Height(content),
389 content: content,
390 })
391 }
392 for _, v := range msg.ToolCalls() {
393 content := m.renderToolCall(v, false)
394 messages = append(messages,
395 uiMessage{
396 messageType: toolMessageType,
397 position: 0, // gets updated in renderView
398 height: lipgloss.Height(content),
399 content: content,
400 },
401 )
402 }
403
404 return messages
405}
406
407func (m *messagesCmp) renderView() {
408 m.uiMessages = make([]uiMessage, 0)
409 pos := 0
410
411 for _, v := range m.messages {
412 switch v.Role {
413 case message.User:
414 content := m.renderSimpleMessage(v)
415 m.uiMessages = append(m.uiMessages, uiMessage{
416 messageType: userMessageType,
417 position: pos,
418 height: lipgloss.Height(content),
419 content: content,
420 })
421 pos += lipgloss.Height(content) + 1 // + 1 for spacing
422 case message.Assistant:
423 assistantMessages := m.renderAssistantMessage(v)
424 for _, msg := range assistantMessages {
425 msg.position = pos
426 m.uiMessages = append(m.uiMessages, msg)
427 pos += msg.height + 1 // + 1 for spacing
428 }
429
430 }
431 }
432
433 messages := make([]string, 0)
434 for _, v := range m.uiMessages {
435 messages = append(messages, v.content,
436 styles.BaseStyle.
437 Width(m.width).
438 Render(
439 "",
440 ),
441 )
442 }
443 m.viewport.SetContent(
444 styles.BaseStyle.
445 Width(m.width).
446 Render(
447 lipgloss.JoinVertical(
448 lipgloss.Top,
449 messages...,
450 ),
451 ),
452 )
453}
454
455func (m *messagesCmp) View() string {
456 if len(m.messages) == 0 {
457 content := styles.BaseStyle.
458 Width(m.width).
459 Height(m.height - 1).
460 Render(
461 m.initialScreen(),
462 )
463
464 return styles.BaseStyle.
465 Width(m.width).
466 Render(
467 lipgloss.JoinVertical(
468 lipgloss.Top,
469 content,
470 m.help(),
471 ),
472 )
473 }
474
475 return styles.BaseStyle.
476 Width(m.width).
477 Render(
478 lipgloss.JoinVertical(
479 lipgloss.Top,
480 m.viewport.View(),
481 m.help(),
482 ),
483 )
484}
485
486func (m *messagesCmp) help() string {
487 text := ""
488
489 if m.agentWorking {
490 text += styles.BaseStyle.Foreground(styles.PrimaryColor).Bold(true).Render(
491 fmt.Sprintf("%s %s ", m.spinner.View(), "Generating..."),
492 )
493 }
494 if m.writingMode {
495 text += lipgloss.JoinHorizontal(
496 lipgloss.Left,
497 styles.BaseStyle.Foreground(styles.ForgroundDim).Bold(true).Render("press "),
498 styles.BaseStyle.Foreground(styles.Forground).Bold(true).Render("esc"),
499 styles.BaseStyle.Foreground(styles.ForgroundDim).Bold(true).Render(" to exit writing mode"),
500 )
501 } else {
502 text += lipgloss.JoinHorizontal(
503 lipgloss.Left,
504 styles.BaseStyle.Foreground(styles.ForgroundDim).Bold(true).Render("press "),
505 styles.BaseStyle.Foreground(styles.Forground).Bold(true).Render("i"),
506 styles.BaseStyle.Foreground(styles.ForgroundDim).Bold(true).Render(" to start writing"),
507 )
508 }
509
510 return styles.BaseStyle.
511 Width(m.width).
512 Render(text)
513}
514
515func (m *messagesCmp) initialScreen() string {
516 return styles.BaseStyle.Width(m.width).Render(
517 lipgloss.JoinVertical(
518 lipgloss.Top,
519 header(m.width),
520 "",
521 lspsConfigured(m.width),
522 ),
523 )
524}
525
526func (m *messagesCmp) SetSize(width, height int) {
527 m.width = width
528 m.height = height
529 m.viewport.Width = width
530 m.viewport.Height = height - 1
531 focusRenderer, _ := glamour.NewTermRenderer(
532 glamour.WithStyles(styles.MarkdownTheme(true)),
533 glamour.WithWordWrap(width-1),
534 )
535 renderer, _ := glamour.NewTermRenderer(
536 glamour.WithStyles(styles.MarkdownTheme(false)),
537 glamour.WithWordWrap(width-1),
538 )
539 m.focusRenderer = focusRenderer
540 // clear the cached content
541 for k := range m.cachedContent {
542 delete(m.cachedContent, k)
543 }
544 m.renderer = renderer
545 if len(m.messages) > 0 {
546 m.renderView()
547 m.viewport.GotoBottom()
548 }
549}
550
551func (m *messagesCmp) GetSize() (int, int) {
552 return m.width, m.height
553}
554
555func (m *messagesCmp) SetSession(session session.Session) tea.Cmd {
556 m.session = session
557 messages, err := m.app.Messages.List(session.ID)
558 if err != nil {
559 return util.ReportError(err)
560 }
561 m.messages = messages
562 m.currentMsgID = m.messages[len(m.messages)-1].ID
563 m.needsRerender = true
564 return nil
565}
566
567func NewMessagesCmp(app *app.App) tea.Model {
568 focusRenderer, _ := glamour.NewTermRenderer(
569 glamour.WithStyles(styles.MarkdownTheme(true)),
570 glamour.WithWordWrap(80),
571 )
572 renderer, _ := glamour.NewTermRenderer(
573 glamour.WithStyles(styles.MarkdownTheme(false)),
574 glamour.WithWordWrap(80),
575 )
576
577 s := spinner.New()
578 s.Spinner = spinner.Pulse
579 return &messagesCmp{
580 app: app,
581 writingMode: true,
582 cachedContent: make(map[string]string),
583 viewport: viewport.New(0, 0),
584 focusRenderer: focusRenderer,
585 renderer: renderer,
586 spinner: s,
587 }
588}