1package chat
2
3import (
4 "context"
5 "encoding/json"
6 "fmt"
7 "math"
8 "strings"
9
10 "github.com/charmbracelet/bubbles/spinner"
11 "github.com/charmbracelet/bubbles/viewport"
12 tea "github.com/charmbracelet/bubbletea"
13 "github.com/charmbracelet/glamour"
14 "github.com/charmbracelet/lipgloss"
15 "github.com/charmbracelet/x/ansi"
16 "github.com/kujtimiihoxha/termai/internal/app"
17 "github.com/kujtimiihoxha/termai/internal/llm/agent"
18 "github.com/kujtimiihoxha/termai/internal/llm/models"
19 "github.com/kujtimiihoxha/termai/internal/llm/tools"
20 "github.com/kujtimiihoxha/termai/internal/message"
21 "github.com/kujtimiihoxha/termai/internal/pubsub"
22 "github.com/kujtimiihoxha/termai/internal/session"
23 "github.com/kujtimiihoxha/termai/internal/tui/styles"
24 "github.com/kujtimiihoxha/termai/internal/tui/util"
25)
26
// uiMessageType identifies what kind of rendered entry a uiMessage holds.
type uiMessageType int

const (
	// userMessageType marks an entry rendered from a user message.
	userMessageType uiMessageType = iota
	// assistantMessageType marks the text portion of an assistant message.
	assistantMessageType
	// toolMessageType marks a single tool call made by an assistant message.
	toolMessageType
)
34
// uiMessage is one rendered entry in the message list together with its
// layout inside the viewport content.
type uiMessage struct {
	// ID is not assigned anywhere in the visible code.
	// NOTE(review): confirm whether it is used elsewhere in the package.
	ID          string
	messageType uiMessageType
	// position is the line at which this entry starts within the combined
	// viewport content; it is assigned by renderView.
	position int
	// height is the number of lines the rendered content occupies.
	height int
	// content is the fully styled, rendered text of the entry.
	content string
}
42
// messagesCmp is the Bubble Tea component that displays the messages of the
// current chat session in a scrollable viewport, with a help/status line
// underneath.
type messagesCmp struct {
	// app provides access to the message service used to list messages.
	app *app.App
	// width and height are the dimensions last set through SetSize.
	width, height int
	// writingMode is true while the editor has focus (see EditorFocusMsg);
	// key messages are then not forwarded to the viewport.
	writingMode bool
	// viewport holds the rendered conversation and handles scrolling.
	viewport viewport.Model
	// session is the session whose messages are displayed.
	session session.Session
	// messages are the messages of session, in display order.
	messages []message.Message
	// uiMessages are the rendered entries built by renderView.
	uiMessages []uiMessage
	// currentMsgID is the ID of the focused (most recently added) message,
	// which is drawn with the focus style and renderer.
	currentMsgID string
	// renderer renders markdown for unfocused messages.
	renderer *glamour.TermRenderer
	// focusRenderer renders markdown for the focused message.
	focusRenderer *glamour.TermRenderer
	// cachedContent maps a message ID to its output from renderSimpleMessage.
	cachedContent map[string]string
	// agentWorking is true while the agent is generating; it drives the
	// spinner and the "Generating..." status in the help line.
	agentWorking bool
	spinner      spinner.Model
	// needsRerender asks the end of Update to rebuild the viewport content.
	needsRerender bool
	// lastViewport is not referenced in the visible code.
	// NOTE(review): possibly unused; confirm before removing.
	lastViewport string
}
60
61func (m *messagesCmp) Init() tea.Cmd {
62 return tea.Batch(m.viewport.Init())
63}
64
65func (m *messagesCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
66 var cmds []tea.Cmd
67 switch msg := msg.(type) {
68 case AgentWorkingMsg:
69 m.agentWorking = bool(msg)
70 if m.agentWorking {
71 cmds = append(cmds, m.spinner.Tick)
72 }
73 case EditorFocusMsg:
74 m.writingMode = bool(msg)
75 case SessionSelectedMsg:
76 if msg.ID != m.session.ID {
77 cmd := m.SetSession(msg)
78 m.needsRerender = true
79 return m, cmd
80 }
81 return m, nil
82 case SessionClearedMsg:
83 m.session = session.Session{}
84 m.messages = make([]message.Message, 0)
85 m.currentMsgID = ""
86 m.needsRerender = true
87 return m, nil
88
89 case tea.KeyMsg:
90 if m.writingMode {
91 return m, nil
92 }
93 case pubsub.Event[message.Message]:
94 if msg.Type == pubsub.CreatedEvent {
95 if msg.Payload.SessionID == m.session.ID {
96 // check if message exists
97
98 messageExists := false
99 for _, v := range m.messages {
100 if v.ID == msg.Payload.ID {
101 messageExists = true
102 break
103 }
104 }
105
106 if !messageExists {
107 m.messages = append(m.messages, msg.Payload)
108 delete(m.cachedContent, m.currentMsgID)
109 m.currentMsgID = msg.Payload.ID
110 m.needsRerender = true
111 }
112 }
113 for _, v := range m.messages {
114 for _, c := range v.ToolCalls() {
115 // the message is being added to the session of a tool called
116 if c.ID == msg.Payload.SessionID {
117 m.needsRerender = true
118 }
119 }
120 }
121 } else if msg.Type == pubsub.UpdatedEvent && msg.Payload.SessionID == m.session.ID {
122 for i, v := range m.messages {
123 if v.ID == msg.Payload.ID {
124 if !m.messages[i].IsFinished() && msg.Payload.IsFinished() && msg.Payload.FinishReason() == "end_turn" || msg.Payload.FinishReason() == "canceled" {
125 cmds = append(cmds, util.CmdHandler(AgentWorkingMsg(false)))
126 }
127 m.messages[i] = msg.Payload
128 delete(m.cachedContent, msg.Payload.ID)
129 m.needsRerender = true
130 break
131 }
132 }
133 }
134 }
135 if m.agentWorking {
136 u, cmd := m.spinner.Update(msg)
137 m.spinner = u
138 cmds = append(cmds, cmd)
139 }
140 oldPos := m.viewport.YPosition
141 u, cmd := m.viewport.Update(msg)
142 m.viewport = u
143 m.needsRerender = m.needsRerender || m.viewport.YPosition != oldPos
144 cmds = append(cmds, cmd)
145 if m.needsRerender {
146 m.renderView()
147 if len(m.messages) > 0 {
148 if msg, ok := msg.(pubsub.Event[message.Message]); ok {
149 if (msg.Type == pubsub.CreatedEvent) ||
150 (msg.Type == pubsub.UpdatedEvent && msg.Payload.ID == m.messages[len(m.messages)-1].ID) {
151 m.viewport.GotoBottom()
152 }
153 }
154 }
155 m.needsRerender = false
156 }
157 return m, tea.Batch(cmds...)
158}
159
160func (m *messagesCmp) renderSimpleMessage(msg message.Message, info ...string) string {
161 if v, ok := m.cachedContent[msg.ID]; ok {
162 return v
163 }
164 style := styles.BaseStyle.
165 Width(m.width).
166 BorderLeft(true).
167 Foreground(styles.ForgroundDim).
168 BorderForeground(styles.ForgroundDim).
169 BorderStyle(lipgloss.ThickBorder())
170
171 renderer := m.renderer
172 if msg.ID == m.currentMsgID {
173 style = style.
174 Foreground(styles.Forground).
175 BorderForeground(styles.Blue).
176 BorderStyle(lipgloss.ThickBorder())
177 renderer = m.focusRenderer
178 }
179 c, _ := renderer.Render(msg.Content().String())
180 parts := []string{
181 styles.ForceReplaceBackgroundWithLipgloss(c, styles.Background),
182 }
183 // remove newline at the end
184 parts[0] = strings.TrimSuffix(parts[0], "\n")
185 if len(info) > 0 {
186 parts = append(parts, info...)
187 }
188 rendered := style.Render(
189 lipgloss.JoinVertical(
190 lipgloss.Left,
191 parts...,
192 ),
193 )
194 m.cachedContent[msg.ID] = rendered
195 return rendered
196}
197
// formatTimeDifference formats the absolute difference between two
// timestamps, treating that difference as a number of seconds.
//
// Gaps under a minute are shown with one decimal place (for example "42.0s").
// Longer gaps are shown as whole minutes and seconds (for example "2m5s").
func formatTimeDifference(unixTime1, unixTime2 int64) string {
	elapsed := math.Abs(float64(unixTime2 - unixTime1))
	if elapsed >= 60 {
		return fmt.Sprintf("%dm%ds", int(elapsed/60), int(elapsed)%60)
	}
	return fmt.Sprintf("%.1fs", elapsed)
}
209
210func (m *messagesCmp) renderToolCall(toolCall message.ToolCall, isNested bool) string {
211 key := ""
212 value := ""
213 switch toolCall.Name {
214 // TODO: add result data to the tools
215 case agent.AgentToolName:
216 key = "Task"
217 var params agent.AgentParams
218 json.Unmarshal([]byte(toolCall.Input), ¶ms)
219 value = params.Prompt
220 // TODO: handle nested calls
221 case tools.BashToolName:
222 key = "Bash"
223 var params tools.BashParams
224 json.Unmarshal([]byte(toolCall.Input), ¶ms)
225 value = params.Command
226 case tools.EditToolName:
227 key = "Edit"
228 var params tools.EditParams
229 json.Unmarshal([]byte(toolCall.Input), ¶ms)
230 value = params.FilePath
231 case tools.FetchToolName:
232 key = "Fetch"
233 var params tools.FetchParams
234 json.Unmarshal([]byte(toolCall.Input), ¶ms)
235 value = params.URL
236 case tools.GlobToolName:
237 key = "Glob"
238 var params tools.GlobParams
239 json.Unmarshal([]byte(toolCall.Input), ¶ms)
240 if params.Path == "" {
241 params.Path = "."
242 }
243 value = fmt.Sprintf("%s (%s)", params.Pattern, params.Path)
244 case tools.GrepToolName:
245 key = "Grep"
246 var params tools.GrepParams
247 json.Unmarshal([]byte(toolCall.Input), ¶ms)
248 if params.Path == "" {
249 params.Path = "."
250 }
251 value = fmt.Sprintf("%s (%s)", params.Pattern, params.Path)
252 case tools.LSToolName:
253 key = "Ls"
254 var params tools.LSParams
255 json.Unmarshal([]byte(toolCall.Input), ¶ms)
256 if params.Path == "" {
257 params.Path = "."
258 }
259 value = params.Path
260 case tools.SourcegraphToolName:
261 key = "Sourcegraph"
262 var params tools.SourcegraphParams
263 json.Unmarshal([]byte(toolCall.Input), ¶ms)
264 value = params.Query
265 case tools.ViewToolName:
266 key = "View"
267 var params tools.ViewParams
268 json.Unmarshal([]byte(toolCall.Input), ¶ms)
269 value = params.FilePath
270 case tools.WriteToolName:
271 key = "Write"
272 var params tools.WriteParams
273 json.Unmarshal([]byte(toolCall.Input), ¶ms)
274 value = params.FilePath
275 default:
276 key = toolCall.Name
277 var params map[string]any
278 json.Unmarshal([]byte(toolCall.Input), ¶ms)
279 jsonData, _ := json.Marshal(params)
280 value = string(jsonData)
281 }
282
283 style := styles.BaseStyle.
284 Width(m.width).
285 BorderLeft(true).
286 BorderStyle(lipgloss.ThickBorder()).
287 PaddingLeft(1).
288 BorderForeground(styles.Yellow)
289
290 keyStyle := styles.BaseStyle.
291 Foreground(styles.ForgroundDim)
292 valyeStyle := styles.BaseStyle.
293 Foreground(styles.Forground)
294
295 if isNested {
296 valyeStyle = valyeStyle.Foreground(styles.ForgroundMid)
297 }
298 keyValye := keyStyle.Render(
299 fmt.Sprintf("%s: ", key),
300 )
301 if !isNested {
302 value = valyeStyle.
303 Width(m.width - lipgloss.Width(keyValye) - 2).
304 Render(
305 ansi.Truncate(
306 value,
307 m.width-lipgloss.Width(keyValye)-2,
308 "...",
309 ),
310 )
311 } else {
312 keyValye = keyStyle.Render(
313 fmt.Sprintf(" └ %s: ", key),
314 )
315 value = valyeStyle.
316 Width(m.width - lipgloss.Width(keyValye) - 2).
317 Render(
318 ansi.Truncate(
319 value,
320 m.width-lipgloss.Width(keyValye)-2,
321 "...",
322 ),
323 )
324 }
325
326 innerToolCalls := make([]string, 0)
327 if toolCall.Name == agent.AgentToolName {
328 messages, _ := m.app.Messages.List(context.Background(), toolCall.ID)
329 toolCalls := make([]message.ToolCall, 0)
330 for _, v := range messages {
331 toolCalls = append(toolCalls, v.ToolCalls()...)
332 }
333 for _, v := range toolCalls {
334 call := m.renderToolCall(v, true)
335 innerToolCalls = append(innerToolCalls, call)
336 }
337 }
338
339 if isNested {
340 return lipgloss.JoinHorizontal(
341 lipgloss.Left,
342 keyValye,
343 value,
344 )
345 }
346 callContent := lipgloss.JoinHorizontal(
347 lipgloss.Left,
348 keyValye,
349 value,
350 )
351 callContent = strings.ReplaceAll(callContent, "\n", "")
352 if len(innerToolCalls) > 0 {
353 callContent = lipgloss.JoinVertical(
354 lipgloss.Left,
355 callContent,
356 lipgloss.JoinVertical(
357 lipgloss.Left,
358 innerToolCalls...,
359 ),
360 )
361 }
362 return style.Render(callContent)
363}
364
365func (m *messagesCmp) renderAssistantMessage(msg message.Message) []uiMessage {
366 // find the user message that is before this assistant message
367 var userMsg message.Message
368 for i := len(m.messages) - 1; i >= 0; i-- {
369 if m.messages[i].Role == message.User {
370 userMsg = m.messages[i]
371 break
372 }
373 }
374 messages := make([]uiMessage, 0)
375 if msg.Content().String() != "" {
376 info := make([]string, 0)
377 if msg.IsFinished() && msg.FinishReason() == "end_turn" {
378 finish := msg.FinishPart()
379 took := formatTimeDifference(userMsg.CreatedAt, finish.Time)
380
381 info = append(info, styles.BaseStyle.Width(m.width-1).Foreground(styles.ForgroundDim).Render(
382 fmt.Sprintf(" %s (%s)", models.SupportedModels[msg.Model].Name, took),
383 ))
384 }
385 content := m.renderSimpleMessage(msg, info...)
386 messages = append(messages, uiMessage{
387 messageType: assistantMessageType,
388 position: 0, // gets updated in renderView
389 height: lipgloss.Height(content),
390 content: content,
391 })
392 }
393 for _, v := range msg.ToolCalls() {
394 content := m.renderToolCall(v, false)
395 messages = append(messages,
396 uiMessage{
397 messageType: toolMessageType,
398 position: 0, // gets updated in renderView
399 height: lipgloss.Height(content),
400 content: content,
401 },
402 )
403 }
404
405 return messages
406}
407
408func (m *messagesCmp) renderView() {
409 m.uiMessages = make([]uiMessage, 0)
410 pos := 0
411
412 for _, v := range m.messages {
413 switch v.Role {
414 case message.User:
415 content := m.renderSimpleMessage(v)
416 m.uiMessages = append(m.uiMessages, uiMessage{
417 messageType: userMessageType,
418 position: pos,
419 height: lipgloss.Height(content),
420 content: content,
421 })
422 pos += lipgloss.Height(content) + 1 // + 1 for spacing
423 case message.Assistant:
424 assistantMessages := m.renderAssistantMessage(v)
425 for _, msg := range assistantMessages {
426 msg.position = pos
427 m.uiMessages = append(m.uiMessages, msg)
428 pos += msg.height + 1 // + 1 for spacing
429 }
430
431 }
432 }
433
434 messages := make([]string, 0)
435 for _, v := range m.uiMessages {
436 messages = append(messages, v.content,
437 styles.BaseStyle.
438 Width(m.width).
439 Render(
440 "",
441 ),
442 )
443 }
444 m.viewport.SetContent(
445 styles.BaseStyle.
446 Width(m.width).
447 Render(
448 lipgloss.JoinVertical(
449 lipgloss.Top,
450 messages...,
451 ),
452 ),
453 )
454}
455
456func (m *messagesCmp) View() string {
457 if len(m.messages) == 0 {
458 content := styles.BaseStyle.
459 Width(m.width).
460 Height(m.height - 1).
461 Render(
462 m.initialScreen(),
463 )
464
465 return styles.BaseStyle.
466 Width(m.width).
467 Render(
468 lipgloss.JoinVertical(
469 lipgloss.Top,
470 content,
471 m.help(),
472 ),
473 )
474 }
475
476 return styles.BaseStyle.
477 Width(m.width).
478 Render(
479 lipgloss.JoinVertical(
480 lipgloss.Top,
481 m.viewport.View(),
482 m.help(),
483 ),
484 )
485}
486
487func (m *messagesCmp) help() string {
488 text := ""
489
490 if m.agentWorking {
491 text += styles.BaseStyle.Foreground(styles.PrimaryColor).Bold(true).Render(
492 fmt.Sprintf("%s %s ", m.spinner.View(), "Generating..."),
493 )
494 }
495 if m.writingMode {
496 text += lipgloss.JoinHorizontal(
497 lipgloss.Left,
498 styles.BaseStyle.Foreground(styles.ForgroundDim).Bold(true).Render("press "),
499 styles.BaseStyle.Foreground(styles.Forground).Bold(true).Render("esc"),
500 styles.BaseStyle.Foreground(styles.ForgroundDim).Bold(true).Render(" to exit writing mode"),
501 )
502 } else {
503 text += lipgloss.JoinHorizontal(
504 lipgloss.Left,
505 styles.BaseStyle.Foreground(styles.ForgroundDim).Bold(true).Render("press "),
506 styles.BaseStyle.Foreground(styles.Forground).Bold(true).Render("i"),
507 styles.BaseStyle.Foreground(styles.ForgroundDim).Bold(true).Render(" to start writing"),
508 )
509 }
510
511 return styles.BaseStyle.
512 Width(m.width).
513 Render(text)
514}
515
516func (m *messagesCmp) initialScreen() string {
517 return styles.BaseStyle.Width(m.width).Render(
518 lipgloss.JoinVertical(
519 lipgloss.Top,
520 header(m.width),
521 "",
522 lspsConfigured(m.width),
523 ),
524 )
525}
526
527func (m *messagesCmp) SetSize(width, height int) {
528 m.width = width
529 m.height = height
530 m.viewport.Width = width
531 m.viewport.Height = height - 1
532 focusRenderer, _ := glamour.NewTermRenderer(
533 glamour.WithStyles(styles.MarkdownTheme(true)),
534 glamour.WithWordWrap(width-1),
535 )
536 renderer, _ := glamour.NewTermRenderer(
537 glamour.WithStyles(styles.MarkdownTheme(false)),
538 glamour.WithWordWrap(width-1),
539 )
540 m.focusRenderer = focusRenderer
541 // clear the cached content
542 for k := range m.cachedContent {
543 delete(m.cachedContent, k)
544 }
545 m.renderer = renderer
546 if len(m.messages) > 0 {
547 m.renderView()
548 m.viewport.GotoBottom()
549 }
550}
551
552func (m *messagesCmp) GetSize() (int, int) {
553 return m.width, m.height
554}
555
556func (m *messagesCmp) SetSession(session session.Session) tea.Cmd {
557 m.session = session
558 messages, err := m.app.Messages.List(context.Background(), session.ID)
559 if err != nil {
560 return util.ReportError(err)
561 }
562 m.messages = messages
563 m.currentMsgID = m.messages[len(m.messages)-1].ID
564 m.needsRerender = true
565 return nil
566}
567
568func NewMessagesCmp(app *app.App) tea.Model {
569 focusRenderer, _ := glamour.NewTermRenderer(
570 glamour.WithStyles(styles.MarkdownTheme(true)),
571 glamour.WithWordWrap(80),
572 )
573 renderer, _ := glamour.NewTermRenderer(
574 glamour.WithStyles(styles.MarkdownTheme(false)),
575 glamour.WithWordWrap(80),
576 )
577
578 s := spinner.New()
579 s.Spinner = spinner.Pulse
580 return &messagesCmp{
581 app: app,
582 writingMode: true,
583 cachedContent: make(map[string]string),
584 viewport: viewport.New(0, 0),
585 focusRenderer: focusRenderer,
586 renderer: renderer,
587 spinner: s,
588 }
589}