From 23e7a95083a8d875420c90e0479647f18a278c5f Mon Sep 17 00:00:00 2001 From: Kujtim Hoxha Date: Fri, 11 Apr 2025 18:54:13 +0200 Subject: [PATCH 01/41] intiial layout --- internal/tui/components/chat/editor.go | 62 +++++ internal/tui/components/chat/messages.go | 21 ++ internal/tui/components/chat/sidebar.go | 21 ++ internal/tui/components/dialog/permission.go | 2 +- internal/tui/components/repl/editor.go | 48 ++-- internal/tui/layout/container.go | 224 ++++++++++++++++++ internal/tui/layout/single.go | 4 +- internal/tui/layout/split.go | 229 +++++++++++++++++++ internal/tui/page/chat.go | 30 +++ internal/tui/page/init.go | 2 +- internal/tui/styles/styles.go | 16 ++ internal/tui/tui.go | 5 +- 12 files changed, 638 insertions(+), 26 deletions(-) create mode 100644 internal/tui/components/chat/editor.go create mode 100644 internal/tui/components/chat/messages.go create mode 100644 internal/tui/components/chat/sidebar.go create mode 100644 internal/tui/layout/container.go create mode 100644 internal/tui/layout/split.go create mode 100644 internal/tui/page/chat.go diff --git a/internal/tui/components/chat/editor.go b/internal/tui/components/chat/editor.go new file mode 100644 index 0000000000000000000000000000000000000000..0b7617174e661f6553774374bf71730207555b7e --- /dev/null +++ b/internal/tui/components/chat/editor.go @@ -0,0 +1,62 @@ +package chat + +import ( + "github.com/charmbracelet/bubbles/key" + "github.com/charmbracelet/bubbles/textarea" + tea "github.com/charmbracelet/bubbletea" + "github.com/charmbracelet/lipgloss" + "github.com/kujtimiihoxha/termai/internal/tui/layout" + "github.com/kujtimiihoxha/termai/internal/tui/styles" +) + +type editorCmp struct { + textarea textarea.Model +} + +func (m *editorCmp) Init() tea.Cmd { + return textarea.Blink +} + +func (m *editorCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) { + var cmd tea.Cmd + m.textarea, cmd = m.textarea.Update(msg) + return m, cmd +} + +func (m *editorCmp) View() string { + style := lipgloss.NewStyle().Padding(0, 0, 0, 1).Bold(true) + + return lipgloss.JoinHorizontal(lipgloss.Top, style.Render(">"), m.textarea.View()) +} + +func (m *editorCmp) SetSize(width, height int) { + m.textarea.SetWidth(width - 3) // account for the prompt and padding right + m.textarea.SetHeight(height) +} + +func (m *editorCmp) GetSize() (int, int) { + return m.textarea.Width(), m.textarea.Height() +} + +func (m *editorCmp) BindingKeys() []key.Binding { + return layout.KeyMapToSlice(m.textarea.KeyMap) +} + +func NewEditorCmp() tea.Model { + ti := textarea.New() + ti.Prompt = " " + ti.ShowLineNumbers = false + ti.BlurredStyle.Base = ti.BlurredStyle.Base.Background(styles.Background) + ti.BlurredStyle.CursorLine = ti.BlurredStyle.CursorLine.Background(styles.Background) + ti.BlurredStyle.Placeholder = ti.BlurredStyle.Placeholder.Background(styles.Background) + ti.BlurredStyle.Text = ti.BlurredStyle.Text.Background(styles.Background) + + ti.FocusedStyle.Base = ti.FocusedStyle.Base.Background(styles.Background) + ti.FocusedStyle.CursorLine = ti.FocusedStyle.CursorLine.Background(styles.Background) + ti.FocusedStyle.Placeholder = ti.FocusedStyle.Placeholder.Background(styles.Background) + ti.FocusedStyle.Text = ti.BlurredStyle.Text.Background(styles.Background) + ti.Focus() + return &editorCmp{ + textarea: ti, + } +} diff --git a/internal/tui/components/chat/messages.go b/internal/tui/components/chat/messages.go new file mode 100644 index 0000000000000000000000000000000000000000..691954767aa71de5e5a4b3c2b92f30f2ea084866 --- /dev/null +++ b/internal/tui/components/chat/messages.go @@ -0,0 +1,21 @@ +package chat + +import tea "github.com/charmbracelet/bubbletea" + +type messagesCmp struct{} + +func (m *messagesCmp) Init() tea.Cmd { + return nil +} + +func (m *messagesCmp) Update(tea.Msg) (tea.Model, tea.Cmd) { + return m, nil +} + +func (m *messagesCmp) View() string { + return "Messages" +} + +func NewMessagesCmp() tea.Model { + return &messagesCmp{} +} diff --git a/internal/tui/components/chat/sidebar.go b/internal/tui/components/chat/sidebar.go new file mode 100644 index 0000000000000000000000000000000000000000..afdd241f44e9114957da435f870c2570e61304a5 --- /dev/null +++ b/internal/tui/components/chat/sidebar.go @@ -0,0 +1,21 @@ +package chat + +import tea "github.com/charmbracelet/bubbletea" + +type sidebarCmp struct{} + +func (m *sidebarCmp) Init() tea.Cmd { + return nil +} + +func (m *sidebarCmp) Update(tea.Msg) (tea.Model, tea.Cmd) { + return m, nil +} + +func (m *sidebarCmp) View() string { + return "Sidebar" +} + +func NewSidebarCmp() tea.Model { + return &sidebarCmp{} +} diff --git a/internal/tui/components/dialog/permission.go b/internal/tui/components/dialog/permission.go index 465f475d54ea4366a9010440fc42afcb738a999a..088697d5542259437d9c8e84f460309c14e6896c 100644 --- a/internal/tui/components/dialog/permission.go +++ b/internal/tui/components/dialog/permission.go @@ -441,7 +441,7 @@ func NewPermissionDialogCmd(permission permission.PermissionRequest) tea.Cmd { layout.WithSinglePaneBordered(true), layout.WithSinglePaneFocusable(true), layout.WithSinglePaneActiveColor(styles.Warning), - layout.WithSignlePaneBorderText(map[layout.BorderPosition]string{ + layout.WithSinglePaneBorderText(map[layout.BorderPosition]string{ layout.TopMiddleBorder: " Permission Required ", }), ) diff --git a/internal/tui/components/repl/editor.go b/internal/tui/components/repl/editor.go index 37ac275e3f11db19fd464ba94394fbc65bcdc054..e9493129d9a266d703ba123351ff81b266fe4230 100644 --- a/internal/tui/components/repl/editor.go +++ b/internal/tui/components/repl/editor.go @@ -156,30 +156,36 @@ func (m *editorCmp) Cancel() tea.Cmd { } func (m *editorCmp) Send() tea.Cmd { - return func() tea.Msg { - messages, err := m.app.Messages.List(m.sessionID) - if err != nil { - return util.ReportError(err) - } - if hasUnfinishedMessages(messages) { - return util.ReportWarn("Assistant is still working on the previous message") - } - a, err := agent.NewCoderAgent(m.app) - if err != nil { - return util.ReportError(err) - } + if m.cancelMessage != nil { + return util.ReportWarn("Assistant is still working on the previous message") + } - content := strings.Join(m.editor.GetBuffer().Lines(), "\n") - ctx, cancel := context.WithCancel(m.app.Context) - m.cancelMessage = cancel - go func() { - defer cancel() - a.Generate(ctx, m.sessionID, content) - m.cancelMessage = nil - }() + messages, err := m.app.Messages.List(m.sessionID) + if err != nil { + return util.ReportError(err) + } + if hasUnfinishedMessages(messages) { + return util.ReportWarn("Assistant is still working on the previous message") + } + + a, err := agent.NewCoderAgent(m.app) + if err != nil { + return util.ReportError(err) + } - return m.editor.Reset() + content := strings.Join(m.editor.GetBuffer().Lines(), "\n") + if len(content) == 0 { + return util.ReportWarn("Message is empty") } + ctx, cancel := context.WithCancel(m.app.Context) + m.cancelMessage = cancel + go func() { + defer cancel() + a.Generate(ctx, m.sessionID, content) + m.cancelMessage = nil + }() + + return m.editor.Reset() } func (m *editorCmp) View() string { diff --git a/internal/tui/layout/container.go b/internal/tui/layout/container.go new file mode 100644 index 0000000000000000000000000000000000000000..db07d49fb925ad829f3e05f907c0207e6e0dbe89 --- /dev/null +++ b/internal/tui/layout/container.go @@ -0,0 +1,224 @@ +package layout + +import ( + "github.com/charmbracelet/bubbles/key" + tea "github.com/charmbracelet/bubbletea" + "github.com/charmbracelet/lipgloss" + "github.com/kujtimiihoxha/termai/internal/tui/styles" +) + +type Container interface { + tea.Model + Sizeable +} +type container struct { + width int + height int + + content tea.Model + + // Style options + paddingTop int + paddingRight int + paddingBottom int + paddingLeft int + + borderTop bool + borderRight bool + borderBottom bool + borderLeft bool + borderStyle lipgloss.Border + borderColor lipgloss.TerminalColor + + backgroundColor lipgloss.TerminalColor +} + +func (c *container) Init() tea.Cmd { + return c.content.Init() +} + +func (c *container) Update(msg tea.Msg) (tea.Model, tea.Cmd) { + u, cmd := c.content.Update(msg) + c.content = u + return c, cmd +} + +func (c *container) View() string { + style := lipgloss.NewStyle() + width := c.width + height := c.height + // Apply background color if specified + if c.backgroundColor != nil { + style = style.Background(c.backgroundColor) + } + + // Apply border if any side is enabled + if c.borderTop || c.borderRight || c.borderBottom || c.borderLeft { + // Adjust width and height for borders + if c.borderTop { + height-- + } + if c.borderBottom { + height-- + } + if c.borderLeft { + width-- + } + if c.borderRight { + width-- + } + style = style.Border(c.borderStyle, c.borderTop, c.borderRight, c.borderBottom, c.borderLeft) + + // Apply border color if specified + if c.borderColor != nil { + style = style.BorderBackground(c.backgroundColor).BorderForeground(c.borderColor) + } + } + style = style. + Width(width). + Height(height). + PaddingTop(c.paddingTop). + PaddingRight(c.paddingRight). + PaddingBottom(c.paddingBottom). + PaddingLeft(c.paddingLeft) + + return style.Render(c.content.View()) +} + +func (c *container) SetSize(width, height int) { + c.width = width + c.height = height + + // If the content implements Sizeable, adjust its size to account for padding and borders + if sizeable, ok := c.content.(Sizeable); ok { + // Calculate horizontal space taken by padding and borders + horizontalSpace := c.paddingLeft + c.paddingRight + if c.borderLeft { + horizontalSpace++ + } + if c.borderRight { + horizontalSpace++ + } + + // Calculate vertical space taken by padding and borders + verticalSpace := c.paddingTop + c.paddingBottom + if c.borderTop { + verticalSpace++ + } + if c.borderBottom { + verticalSpace++ + } + + // Set content size with adjusted dimensions + contentWidth := max(0, width-horizontalSpace) + contentHeight := max(0, height-verticalSpace) + sizeable.SetSize(contentWidth, contentHeight) + } +} + +func (c *container) GetSize() (int, int) { + return c.width, c.height +} + +func (c *container) BindingKeys() []key.Binding { + if b, ok := c.content.(Bindings); ok { + return b.BindingKeys() + } + return []key.Binding{} +} + +type ContainerOption func(*container) + +func NewContainer(content tea.Model, options ...ContainerOption) Container { + c := &container{ + content: content, + borderColor: styles.BorderColor, + borderStyle: lipgloss.NormalBorder(), + backgroundColor: styles.Background, + } + + for _, option := range options { + option(c) + } + + return c +} + +// Padding options +func WithPadding(top, right, bottom, left int) ContainerOption { + return func(c *container) { + c.paddingTop = top + c.paddingRight = right + c.paddingBottom = bottom + c.paddingLeft = left + } +} + +func WithPaddingAll(padding int) ContainerOption { + return WithPadding(padding, padding, padding, padding) +} + +func WithPaddingHorizontal(padding int) ContainerOption { + return func(c *container) { + c.paddingLeft = padding + c.paddingRight = padding + } +} + +func WithPaddingVertical(padding int) ContainerOption { + return func(c *container) { + c.paddingTop = padding + c.paddingBottom = padding + } +} + +func WithBorder(top, right, bottom, left bool) ContainerOption { + return func(c *container) { + c.borderTop = top + c.borderRight = right + c.borderBottom = bottom + c.borderLeft = left + } +} + +func WithBorderAll() ContainerOption { + return WithBorder(true, true, true, true) +} + +func WithBorderHorizontal() ContainerOption { + return WithBorder(true, false, true, false) +} + +func WithBorderVertical() ContainerOption { + return WithBorder(false, true, false, true) +} + +func WithBorderStyle(style lipgloss.Border) ContainerOption { + return func(c *container) { + c.borderStyle = style + } +} + +func WithBorderColor(color lipgloss.TerminalColor) ContainerOption { + return func(c *container) { + c.borderColor = color + } +} + +func WithRoundedBorder() ContainerOption { + return WithBorderStyle(lipgloss.RoundedBorder()) +} + +func WithThickBorder() ContainerOption { + return WithBorderStyle(lipgloss.ThickBorder()) +} + +func WithDoubleBorder() ContainerOption { + return WithBorderStyle(lipgloss.DoubleBorder()) +} + +func WithBackgroundColor(color lipgloss.TerminalColor) ContainerOption { + return func(c *container) { + c.backgroundColor = color + } +} diff --git a/internal/tui/layout/single.go b/internal/tui/layout/single.go index e5c9a61c45d656f12ef0241bba02ca65cd47bd78..c77fa0d78e4b73bb3a79e4164ef264bcdc38aa02 100644 --- a/internal/tui/layout/single.go +++ b/internal/tui/layout/single.go @@ -151,7 +151,7 @@ func NewSinglePane(content tea.Model, opts ...SinglePaneOption) SinglePaneLayout return layout } -func WithSignlePaneSize(width, height int) SinglePaneOption { +func WithSinglePaneSize(width, height int) SinglePaneOption { return func(opts *singlePaneLayout) { opts.width = width opts.height = height @@ -170,7 +170,7 @@ func WithSinglePaneBordered(bordered bool) SinglePaneOption { } } -func WithSignlePaneBorderText(borderText map[BorderPosition]string) SinglePaneOption { +func WithSinglePaneBorderText(borderText map[BorderPosition]string) SinglePaneOption { return func(opts *singlePaneLayout) { opts.borderText = borderText } diff --git a/internal/tui/layout/split.go b/internal/tui/layout/split.go new file mode 100644 index 0000000000000000000000000000000000000000..2a6822c7edbd2ee5c8ef2db9500270271673bd46 --- /dev/null +++ b/internal/tui/layout/split.go @@ -0,0 +1,229 @@ +package layout + +import ( + "github.com/charmbracelet/bubbles/key" + tea "github.com/charmbracelet/bubbletea" + "github.com/charmbracelet/lipgloss" + "github.com/kujtimiihoxha/termai/internal/tui/styles" +) + +type SplitPaneLayout interface { + tea.Model + Sizeable +} + +type splitPaneLayout struct { + width int + height int + ratio float64 + verticalRatio float64 + + rightPanel Container + leftPanel Container + bottomPanel Container + + backgroundColor lipgloss.TerminalColor +} + +type SplitPaneOption func(*splitPaneLayout) + +func (s *splitPaneLayout) Init() tea.Cmd { + var cmds []tea.Cmd + + if s.leftPanel != nil { + cmds = append(cmds, s.leftPanel.Init()) + } + + if s.rightPanel != nil { + cmds = append(cmds, s.rightPanel.Init()) + } + + if s.bottomPanel != nil { + cmds = append(cmds, s.bottomPanel.Init()) + } + + return tea.Batch(cmds...) +} + +func (s *splitPaneLayout) Update(msg tea.Msg) (tea.Model, tea.Cmd) { + var cmds []tea.Cmd + switch msg := msg.(type) { + case tea.WindowSizeMsg: + s.SetSize(msg.Width, msg.Height) + return s, nil + } + + if s.rightPanel != nil { + u, cmd := s.rightPanel.Update(msg) + s.rightPanel = u.(Container) + if cmd != nil { + cmds = append(cmds, cmd) + } + } + + if s.leftPanel != nil { + u, cmd := s.leftPanel.Update(msg) + s.leftPanel = u.(Container) + if cmd != nil { + cmds = append(cmds, cmd) + } + } + + if s.bottomPanel != nil { + u, cmd := s.bottomPanel.Update(msg) + s.bottomPanel = u.(Container) + if cmd != nil { + cmds = append(cmds, cmd) + } + } + + return s, tea.Batch(cmds...) +} + +func (s *splitPaneLayout) View() string { + var topSection string + + if s.leftPanel != nil && s.rightPanel != nil { + leftView := s.leftPanel.View() + rightView := s.rightPanel.View() + topSection = lipgloss.JoinHorizontal(lipgloss.Top, leftView, rightView) + } else if s.leftPanel != nil { + topSection = s.leftPanel.View() + } else if s.rightPanel != nil { + topSection = s.rightPanel.View() + } else { + topSection = "" + } + + var finalView string + + if s.bottomPanel != nil && topSection != "" { + bottomView := s.bottomPanel.View() + finalView = lipgloss.JoinVertical(lipgloss.Left, topSection, bottomView) + } else if s.bottomPanel != nil { + finalView = s.bottomPanel.View() + } else { + finalView = topSection + } + + if s.backgroundColor != nil && finalView != "" { + style := lipgloss.NewStyle(). + Width(s.width). + Height(s.height). + Background(s.backgroundColor) + + return style.Render(finalView) + } + + return finalView +} + +func (s *splitPaneLayout) SetSize(width, height int) { + s.width = width + s.height = height + + var topHeight, bottomHeight int + if s.bottomPanel != nil { + topHeight = int(float64(height) * s.verticalRatio) + bottomHeight = height - topHeight + } else { + topHeight = height + bottomHeight = 0 + } + + var leftWidth, rightWidth int + if s.leftPanel != nil && s.rightPanel != nil { + leftWidth = int(float64(width) * s.ratio) + rightWidth = width - leftWidth + } else if s.leftPanel != nil { + leftWidth = width + rightWidth = 0 + } else if s.rightPanel != nil { + leftWidth = 0 + rightWidth = width + } + + if s.leftPanel != nil { + s.leftPanel.SetSize(leftWidth, topHeight) + } + + if s.rightPanel != nil { + s.rightPanel.SetSize(rightWidth, topHeight) + } + + if s.bottomPanel != nil { + s.bottomPanel.SetSize(width, bottomHeight) + } +} + +func (s *splitPaneLayout) GetSize() (int, int) { + return s.width, s.height +} + +func (s *splitPaneLayout) BindingKeys() []key.Binding { + keys := []key.Binding{} + if s.leftPanel != nil { + if b, ok := s.leftPanel.(Bindings); ok { + keys = append(keys, b.BindingKeys()...) + } + } + if s.rightPanel != nil { + if b, ok := s.rightPanel.(Bindings); ok { + keys = append(keys, b.BindingKeys()...) + } + } + if s.bottomPanel != nil { + if b, ok := s.bottomPanel.(Bindings); ok { + keys = append(keys, b.BindingKeys()...) + } + } + return keys +} + +func NewSplitPane(options ...SplitPaneOption) SplitPaneLayout { + layout := &splitPaneLayout{ + ratio: 0.7, + verticalRatio: 0.9, // Default 80% for top section, 20% for bottom + backgroundColor: styles.Background, + } + for _, option := range options { + option(layout) + } + return layout +} + +func WithLeftPanel(panel Container) SplitPaneOption { + return func(s *splitPaneLayout) { + s.leftPanel = panel + } +} + +func WithRightPanel(panel Container) SplitPaneOption { + return func(s *splitPaneLayout) { + s.rightPanel = panel + } +} + +func WithRatio(ratio float64) SplitPaneOption { + return func(s *splitPaneLayout) { + s.ratio = ratio + } +} + +func WithSplitBackgroundColor(color lipgloss.TerminalColor) SplitPaneOption { + return func(s *splitPaneLayout) { + s.backgroundColor = color + } +} + +func WithBottomPanel(panel Container) SplitPaneOption { + return func(s *splitPaneLayout) { + s.bottomPanel = panel + } +} + +func WithVerticalRatio(ratio float64) SplitPaneOption { + return func(s *splitPaneLayout) { + s.verticalRatio = ratio + } +} diff --git a/internal/tui/page/chat.go b/internal/tui/page/chat.go new file mode 100644 index 0000000000000000000000000000000000000000..de5b3910fec4e0efbca11a059cfeea33a235c953 --- /dev/null +++ b/internal/tui/page/chat.go @@ -0,0 +1,30 @@ +package page + +import ( + tea "github.com/charmbracelet/bubbletea" + "github.com/kujtimiihoxha/termai/internal/app" + "github.com/kujtimiihoxha/termai/internal/tui/components/chat" + "github.com/kujtimiihoxha/termai/internal/tui/layout" +) + +var ChatPage PageID = "chat" + +func NewChatPage(app *app.App) tea.Model { + messagesContainer := layout.NewContainer( + chat.NewMessagesCmp(), + layout.WithPadding(1, 1, 1, 1), + ) + sidebarContainer := layout.NewContainer( + chat.NewSidebarCmp(), + layout.WithPadding(1, 1, 1, 1), + ) + editorContainer := layout.NewContainer( + chat.NewEditorCmp(), + layout.WithBorder(true, false, false, false), + ) + return layout.NewSplitPane( + layout.WithRightPanel(sidebarContainer), + layout.WithLeftPanel(messagesContainer), + layout.WithBottomPanel(editorContainer), + ) +} diff --git a/internal/tui/page/init.go b/internal/tui/page/init.go index 93a5e6fba4a39f2ee10eb3ee98998e065e7ab691..0a5c6f82a522300e52e0a33676ce9eee71796072 100644 --- a/internal/tui/page/init.go +++ b/internal/tui/page/init.go @@ -299,7 +299,7 @@ func NewInitPage() tea.Model { initModel, layout.WithSinglePaneFocusable(true), layout.WithSinglePaneBordered(true), - layout.WithSignlePaneBorderText( + layout.WithSinglePaneBorderText( map[layout.BorderPosition]string{ layout.TopMiddleBorder: "Welcome to termai - Initial Setup", }, diff --git a/internal/tui/styles/styles.go b/internal/tui/styles/styles.go index fe92959e15e4ce2d730fb6e1803ce05b2bcaa93b..86ee3649035fc70773232142cfbecdc9fb3c4f9e 100644 --- a/internal/tui/styles/styles.go +++ b/internal/tui/styles/styles.go @@ -10,6 +10,22 @@ var ( dark = catppuccin.Mocha ) +// NEW STYLES +var ( + Background = lipgloss.AdaptiveColor{ + Dark: "#212121", + Light: "#212121", + } + BackgroundDarker = lipgloss.AdaptiveColor{ + Dark: "#181818", + Light: "#181818", + } + BorderColor = lipgloss.AdaptiveColor{ + Dark: "#4b4c5c", + Light: "#4b4c5c", + } +) + var ( Regular = lipgloss.NewStyle() Bold = Regular.Bold(true) diff --git a/internal/tui/tui.go b/internal/tui/tui.go index 9e863d2ac844c3f3a034124b0eadbd6700893b01..eb996d44863ffa8af398082ab7fc52417f6daf85 100644 --- a/internal/tui/tui.go +++ b/internal/tui/tui.go @@ -200,6 +200,8 @@ func (a appModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) { } case key.Matches(msg, keys.Logs): return a, a.moveToPage(page.LogsPage) + case msg.String() == "O": + return a, a.moveToPage(page.ReplPage) case key.Matches(msg, keys.Help): a.ToggleHelp() return a, nil @@ -292,7 +294,7 @@ func New(app *app.App) tea.Model { // homedir, _ := os.UserHomeDir() // configPath := filepath.Join(homedir, ".termai.yaml") // - startPage := page.ReplPage + startPage := page.ChatPage // if _, err := os.Stat(configPath); os.IsNotExist(err) { // startPage = page.InitPage // } @@ -305,6 +307,7 @@ func New(app *app.App) tea.Model { dialog: core.NewDialogCmp(), app: app, pages: map[page.PageID]tea.Model{ + page.ChatPage: page.NewChatPage(app), page.LogsPage: page.NewLogsPage(), page.InitPage: page.NewInitPage(), page.ReplPage: page.NewReplPage(app), From 08bd75bb6e1fde0427dfd37204ee9a3c43bb1e5b Mon Sep 17 00:00:00 2001 From: Kujtim Hoxha Date: Fri, 11 Apr 2025 20:31:24 +0200 Subject: [PATCH 02/41] add initial mock sidebar --- internal/tui/components/chat/editor.go | 73 +++++++++- internal/tui/components/chat/sidebar.go | 183 +++++++++++++++++++++++- internal/tui/styles/icons.go | 1 + internal/tui/styles/styles.go | 23 +++ 4 files changed, 274 insertions(+), 6 deletions(-) diff --git a/internal/tui/components/chat/editor.go b/internal/tui/components/chat/editor.go index 0b7617174e661f6553774374bf71730207555b7e..ea20d7e4420054486511a31f5a660194d236a29a 100644 --- a/internal/tui/components/chat/editor.go +++ b/internal/tui/components/chat/editor.go @@ -13,14 +13,75 @@ type editorCmp struct { textarea textarea.Model } +type focusedEditorKeyMaps struct { + Send key.Binding + Blur key.Binding +} + +type bluredEditorKeyMaps struct { + Send key.Binding + Focus key.Binding +} + +var focusedKeyMaps = focusedEditorKeyMaps{ + Send: key.NewBinding( + key.WithKeys("ctrl+s"), + key.WithHelp("ctrl+s", "send message"), + ), + Blur: key.NewBinding( + key.WithKeys("esc"), + key.WithHelp("esc", "blur editor"), + ), +} + +var bluredKeyMaps = bluredEditorKeyMaps{ + Send: key.NewBinding( + key.WithKeys("ctrl+s", "enter"), + key.WithHelp("ctrl+s/enter", "send message"), + ), + Focus: key.NewBinding( + key.WithKeys("i"), + key.WithHelp("i", "focus editor"), + ), +} + func (m *editorCmp) Init() tea.Cmd { return textarea.Blink } func (m *editorCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) { var cmd tea.Cmd - m.textarea, cmd = m.textarea.Update(msg) - return m, cmd + if m.textarea.Focused() { + switch msg := msg.(type) { + case tea.KeyMsg: + if key.Matches(msg, focusedKeyMaps.Send) { + // TODO: send message + m.textarea.Reset() + m.textarea.Blur() + return m, nil + } + if key.Matches(msg, focusedKeyMaps.Blur) { + m.textarea.Blur() + return m, nil + } + } + m.textarea, cmd = m.textarea.Update(msg) + return m, cmd + } + switch msg := msg.(type) { + case tea.KeyMsg: + if key.Matches(msg, bluredKeyMaps.Send) { + // TODO: send message + m.textarea.Reset() + return m, nil + } + if key.Matches(msg, bluredKeyMaps.Focus) { + m.textarea.Focus() + return m, textarea.Blink + } + } + + return m, nil } func (m *editorCmp) View() string { @@ -39,7 +100,13 @@ func (m *editorCmp) GetSize() (int, int) { } func (m *editorCmp) BindingKeys() []key.Binding { - return layout.KeyMapToSlice(m.textarea.KeyMap) + bindings := layout.KeyMapToSlice(m.textarea.KeyMap) + if m.textarea.Focused() { + bindings = append(bindings, layout.KeyMapToSlice(focusedKeyMaps)...) + } else { + bindings = append(bindings, layout.KeyMapToSlice(bluredKeyMaps)...) + } + return bindings } func NewEditorCmp() tea.Model { diff --git a/internal/tui/components/chat/sidebar.go b/internal/tui/components/chat/sidebar.go index afdd241f44e9114957da435f870c2570e61304a5..4a563157738246b6dc15c657146c01563b5bfd54 100644 --- a/internal/tui/components/chat/sidebar.go +++ b/internal/tui/components/chat/sidebar.go @@ -1,8 +1,18 @@ package chat -import tea "github.com/charmbracelet/bubbletea" +import ( + "fmt" -type sidebarCmp struct{} + tea "github.com/charmbracelet/bubbletea" + "github.com/charmbracelet/lipgloss" + "github.com/kujtimiihoxha/termai/internal/config" + "github.com/kujtimiihoxha/termai/internal/tui/styles" + "github.com/kujtimiihoxha/termai/internal/version" +) + +type sidebarCmp struct { + width, height int +} func (m *sidebarCmp) Init() tea.Cmd { return nil @@ -13,7 +23,174 @@ func (m *sidebarCmp) Update(tea.Msg) (tea.Model, tea.Cmd) { } func (m *sidebarCmp) View() string { - return "Sidebar" + return styles.BaseStyle.Width(m.width).Render( + lipgloss.JoinVertical( + lipgloss.Top, + m.header(), + " ", + m.session(), + " ", + m.modifiedFiles(), + " ", + m.lspsConfigured(), + ), + ) +} + +func (m *sidebarCmp) session() string { + sessionKey := styles.BaseStyle.Foreground(styles.PrimaryColor).Render("Session") + sessionValue := styles.BaseStyle. + Foreground(styles.Forground). + Width(m.width - lipgloss.Width(sessionKey)). + Render(": New Session") + return lipgloss.JoinHorizontal( + lipgloss.Left, + sessionKey, + sessionValue, + ) +} + +func (m *sidebarCmp) modifiedFile(filePath string, additions, removals int) string { + stats := "" + if additions > 0 && removals > 0 { + stats = styles.BaseStyle.Foreground(styles.ForgroundDim).Render(fmt.Sprintf("%d additions and %d removals", additions, removals)) + } else if additions > 0 { + stats = styles.BaseStyle.Foreground(styles.ForgroundDim).Render(fmt.Sprintf("%d additions", additions)) + } else if removals > 0 { + stats = styles.BaseStyle.Foreground(styles.ForgroundDim).Render(fmt.Sprintf("%d removals", removals)) + } + filePathStr := styles.BaseStyle.Foreground(styles.Forground).Render(filePath) + + return styles.BaseStyle. + Width(m.width). + Render( + lipgloss.JoinHorizontal( + lipgloss.Left, + filePathStr, + " ", + stats, + ), + ) +} + +func (m *sidebarCmp) lspsConfigured() string { + lsps := styles.BaseStyle.Width(m.width).Foreground(styles.PrimaryColor).Render("LSP Configuration:") + lspsConfigured := []struct { + name string + path string + }{ + {"golsp", "path/to/lsp1"}, + {"vtsls", "path/to/lsp2"}, + } + + var lspViews []string + for _, lsp := range lspsConfigured { + lspName := styles.BaseStyle.Foreground(styles.Forground).Render( + fmt.Sprintf("• %s", lsp.name), + ) + lspPath := styles.BaseStyle.Foreground(styles.ForgroundDim).Render( + fmt.Sprintf("(%s)", lsp.path), + ) + lspViews = append(lspViews, + styles.BaseStyle. + Width(m.width). + Render( + lipgloss.JoinHorizontal( + lipgloss.Left, + lspName, + " ", + lspPath, + ), + ), + ) + + } + return styles.BaseStyle. + Width(m.width). + Render( + lipgloss.JoinVertical( + lipgloss.Left, + lsps, + lipgloss.JoinVertical( + lipgloss.Left, + lspViews..., + ), + ), + ) +} + +func (m *sidebarCmp) modifiedFiles() string { + modifiedFiles := styles.BaseStyle.Width(m.width).Foreground(styles.PrimaryColor).Render("Modified Files:") + files := []struct { + path string + additions int + removals int + }{ + {"file1.txt", 10, 5}, + {"file2.txt", 20, 0}, + {"file3.txt", 0, 15}, + } + var fileViews []string + for _, file := range files { + fileViews = append(fileViews, m.modifiedFile(file.path, file.additions, file.removals)) + } + + return styles.BaseStyle. + Width(m.width). + Render( + lipgloss.JoinVertical( + lipgloss.Top, + modifiedFiles, + lipgloss.JoinVertical( + lipgloss.Left, + fileViews..., + ), + ), + ) +} + +func (m *sidebarCmp) logo() string { + logo := fmt.Sprintf("%s %s", styles.OpenCodeIcon, "OpenCode") + + version := styles.BaseStyle.Foreground(styles.ForgroundDim).Render(version.Version) + + return styles.BaseStyle. + Bold(true). + Width(m.width). + Render( + lipgloss.JoinHorizontal( + lipgloss.Left, + logo, + " ", + version, + ), + ) +} + +func (m *sidebarCmp) header() string { + header := lipgloss.JoinVertical( + lipgloss.Top, + m.logo(), + m.cwd(), + ) + return header +} + +func (m *sidebarCmp) cwd() string { + cwd := fmt.Sprintf("cwd: %s", config.WorkingDirectory()) + return styles.BaseStyle. + Foreground(styles.ForgroundDim). + Width(m.width). + Render(cwd) +} + +func (m *sidebarCmp) SetSize(width, height int) { + m.width = width + m.height = height +} + +func (m *sidebarCmp) GetSize() (int, int) { + return m.width, m.height } func NewSidebarCmp() tea.Model { diff --git a/internal/tui/styles/icons.go b/internal/tui/styles/icons.go index f641984e773b47d33c89a51a1d77f469b53334ca..aa0df1e31ca994cab9b294694851787fb66f2e02 100644 --- a/internal/tui/styles/icons.go +++ b/internal/tui/styles/icons.go @@ -1,6 +1,7 @@ package styles const ( + OpenCodeIcon string = "⌬" SessionsIcon string = "󰧑" ChatIcon string = "󰭹" diff --git a/internal/tui/styles/styles.go b/internal/tui/styles/styles.go index 86ee3649035fc70773232142cfbecdc9fb3c4f9e..41863cf1b79a4d36ce9c3d27bb87d1c2b2ddedc4 100644 --- a/internal/tui/styles/styles.go +++ b/internal/tui/styles/styles.go @@ -16,6 +16,10 @@ var ( Dark: "#212121", Light: "#212121", } + BackgroundDim = lipgloss.AdaptiveColor{ + Dark: "#2c2c2c", + Light: "#2c2c2c", + } BackgroundDarker = lipgloss.AdaptiveColor{ Dark: "#181818", Light: "#181818", @@ -24,6 +28,25 @@ var ( Dark: "#4b4c5c", Light: "#4b4c5c", } + + Forground = lipgloss.AdaptiveColor{ + Dark: "#d3d3d3", + Light: "#d3d3d3", + } + + ForgroundDim = lipgloss.AdaptiveColor{ + Dark: "#737373", + Light: "#737373", + } + + BaseStyle = lipgloss.NewStyle(). + Background(Background). + Foreground(Forground) + + PrimaryColor = lipgloss.AdaptiveColor{ + Dark: "#fab283", + Light: "#fab283", + } ) var ( From 8d874b839db169906e18e4277cd198504018e022 Mon Sep 17 00:00:00 2001 From: Kujtim Hoxha Date: Sat, 12 Apr 2025 02:01:45 +0200 Subject: [PATCH 03/41] add initial message handling --- internal/db/messages.sql.go | 34 +- internal/db/migrations/000001_initial.up.sql | 2 + internal/db/models.go | 14 +- internal/db/sql/messages.sql | 5 +- internal/message/content.go | 14 +- internal/message/message.go | 14 +- internal/tui/components/chat/chat.go | 113 +++++ internal/tui/components/chat/editor.go | 65 +-- internal/tui/components/chat/messages.go | 339 +++++++++++++- internal/tui/components/chat/sidebar.go | 133 ++---- internal/tui/components/core/button.go | 287 ------------ internal/tui/layout/split.go | 24 + internal/tui/page/chat.go | 92 +++- internal/tui/styles/markdown.go | 446 ++++++++++++++++++- 14 files changed, 1124 insertions(+), 458 deletions(-) create mode 100644 internal/tui/components/chat/chat.go delete mode 100644 internal/tui/components/core/button.go diff --git a/internal/db/messages.sql.go b/internal/db/messages.sql.go index 4309db181c29b117c3c436edbfb55ee1da0fdc03..0555b4330d79089c0d5a7127c311f55af567e604 100644 --- a/internal/db/messages.sql.go +++ b/internal/db/messages.sql.go @@ -7,6 +7,7 @@ package db import ( "context" + "database/sql" ) const createMessage = `-- name: CreateMessage :one @@ -15,19 +16,21 @@ INSERT INTO messages ( session_id, role, parts, + model, created_at, updated_at ) VALUES ( - ?, ?, ?, ?, strftime('%s', 'now'), strftime('%s', 'now') + ?, ?, ?, ?, ?, strftime('%s', 'now'), strftime('%s', 'now') ) -RETURNING id, session_id, role, parts, created_at, updated_at +RETURNING id, session_id, role, parts, model, created_at, updated_at, finished_at ` type CreateMessageParams struct { - ID string `json:"id"` - SessionID string `json:"session_id"` - Role string `json:"role"` - Parts string `json:"parts"` + ID string `json:"id"` + SessionID string `json:"session_id"` + Role string `json:"role"` + Parts string `json:"parts"` + Model sql.NullString `json:"model"` } func (q *Queries) CreateMessage(ctx context.Context, arg CreateMessageParams) (Message, error) { @@ -36,6 +39,7 @@ func (q *Queries) CreateMessage(ctx context.Context, arg CreateMessageParams) (M arg.SessionID, arg.Role, arg.Parts, + arg.Model, ) var i Message err := row.Scan( @@ -43,8 +47,10 @@ func (q *Queries) CreateMessage(ctx context.Context, arg CreateMessageParams) (M &i.SessionID, &i.Role, &i.Parts, + &i.Model, &i.CreatedAt, &i.UpdatedAt, + &i.FinishedAt, ) return i, err } @@ -70,7 +76,7 @@ func (q *Queries) DeleteSessionMessages(ctx context.Context, sessionID string) e } const getMessage = `-- name: GetMessage :one -SELECT id, session_id, role, parts, created_at, updated_at +SELECT id, session_id, role, parts, model, created_at, updated_at, finished_at FROM messages WHERE id = ? LIMIT 1 ` @@ -83,14 +89,16 @@ func (q *Queries) GetMessage(ctx context.Context, id string) (Message, error) { &i.SessionID, &i.Role, &i.Parts, + &i.Model, &i.CreatedAt, &i.UpdatedAt, + &i.FinishedAt, ) return i, err } const listMessagesBySession = `-- name: ListMessagesBySession :many -SELECT id, session_id, role, parts, created_at, updated_at +SELECT id, session_id, role, parts, model, created_at, updated_at, finished_at FROM messages WHERE session_id = ? ORDER BY created_at ASC @@ -110,8 +118,10 @@ func (q *Queries) ListMessagesBySession(ctx context.Context, sessionID string) ( &i.SessionID, &i.Role, &i.Parts, + &i.Model, &i.CreatedAt, &i.UpdatedAt, + &i.FinishedAt, ); err != nil { return nil, err } @@ -130,16 +140,18 @@ const updateMessage = `-- name: UpdateMessage :exec UPDATE messages SET parts = ?, + finished_at = ?, updated_at = strftime('%s', 'now') WHERE id = ? ` type UpdateMessageParams struct { - Parts string `json:"parts"` - ID string `json:"id"` + Parts string `json:"parts"` + FinishedAt sql.NullInt64 `json:"finished_at"` + ID string `json:"id"` } func (q *Queries) UpdateMessage(ctx context.Context, arg UpdateMessageParams) error { - _, err := q.exec(ctx, q.updateMessageStmt, updateMessage, arg.Parts, arg.ID) + _, err := q.exec(ctx, q.updateMessageStmt, updateMessage, arg.Parts, arg.FinishedAt, arg.ID) return err } diff --git a/internal/db/migrations/000001_initial.up.sql b/internal/db/migrations/000001_initial.up.sql index 2fbe5547e9996e6f97110294aa91533f3ca71a0d..03479449d24492c04cda61f56056d9c3d7fb73fa 100644 --- a/internal/db/migrations/000001_initial.up.sql +++ b/internal/db/migrations/000001_initial.up.sql @@ -24,8 +24,10 @@ CREATE TABLE IF NOT EXISTS messages ( session_id TEXT NOT NULL, role TEXT NOT NULL, parts TEXT NOT NULL default '[]', + model TEXT, created_at INTEGER NOT NULL, -- Unix timestamp in milliseconds updated_at INTEGER NOT NULL, -- Unix timestamp in milliseconds + finished_at INTEGER, -- Unix timestamp in milliseconds FOREIGN KEY (session_id) REFERENCES sessions (id) ON DELETE CASCADE ); diff --git a/internal/db/models.go b/internal/db/models.go index 1ad8607a9654c227bc5a17c73667d2d94d0e137d..2fad913be831d4d475642def6d94f2f7fadd960d 100644 --- a/internal/db/models.go +++ b/internal/db/models.go @@ -9,12 +9,14 @@ import ( ) type Message struct { - ID string `json:"id"` - SessionID string `json:"session_id"` - Role string `json:"role"` - Parts string `json:"parts"` - CreatedAt int64 `json:"created_at"` - UpdatedAt int64 `json:"updated_at"` + ID string `json:"id"` + SessionID string `json:"session_id"` + Role string `json:"role"` + Parts string `json:"parts"` + Model sql.NullString `json:"model"` + CreatedAt int64 `json:"created_at"` + UpdatedAt int64 `json:"updated_at"` + FinishedAt sql.NullInt64 `json:"finished_at"` } type Session struct { diff --git a/internal/db/sql/messages.sql b/internal/db/sql/messages.sql index 64571158fe55199fe79c4d1753180ee6ffbe383b..a59cebe7d00fe5fd7cbd449df681df45e832979a 100644 --- a/internal/db/sql/messages.sql +++ b/internal/db/sql/messages.sql @@ -15,10 +15,11 @@ INSERT INTO messages ( session_id, role, parts, + model, created_at, updated_at ) VALUES ( - ?, ?, ?, ?, strftime('%s', 'now'), strftime('%s', 'now') + ?, ?, ?, ?, ?, strftime('%s', 'now'), strftime('%s', 'now') ) RETURNING *; @@ -26,9 +27,11 @@ RETURNING *; UPDATE messages SET parts = ?, + finished_at = ?, updated_at = strftime('%s', 'now') WHERE id = ?; + -- name: DeleteMessage :exec DELETE FROM messages WHERE id = ?; diff --git a/internal/message/content.go b/internal/message/content.go index 2604cd68ada39656c9ad089c27b8f33401bf538e..cd263798b35e8fc1df3278dca1fff288c4a2806c 100644 --- a/internal/message/content.go +++ b/internal/message/content.go @@ -2,6 +2,7 @@ package message import ( "encoding/base64" + "time" ) type MessageRole string @@ -64,6 +65,7 @@ type ToolCall struct { Name string `json:"name"` Input string `json:"input"` Type string `json:"type"` + Metadata any `json:"metadata"` Finished bool `json:"finished"` } @@ -80,6 +82,7 @@ func (ToolResult) isPart() {} type Finish struct { Reason string `json:"reason"` + Time int64 `json:"time"` } func (Finish) isPart() {} @@ -161,6 +164,15 @@ func (m *Message) IsFinished() bool { return false } +func (m *Message) FinishPart() *Finish { + for _, part := range m.Parts { + if c, ok := part.(Finish); ok { + return &c + } + } + return nil +} + func (m *Message) FinishReason() string { for _, part := range m.Parts { if c, ok := part.(Finish); ok { @@ -232,7 +244,7 @@ func (m *Message) SetToolResults(tr []ToolResult) { } func (m *Message) AddFinish(reason string) { - m.Parts = append(m.Parts, Finish{Reason: reason}) + m.Parts = append(m.Parts, Finish{Reason: reason, Time: time.Now().Unix()}) } func (m *Message) AddImageURL(url, detail string) { diff --git a/internal/message/message.go b/internal/message/message.go index 13cf54048fe45ea31b14c2e74537d7fd2a499b6a..eeeb83ed2e8cebc6f69bc569cf0f9f2784879307 100644 --- a/internal/message/message.go +++ b/internal/message/message.go @@ -2,17 +2,20 @@ package message import ( "context" + "database/sql" "encoding/json" "fmt" "github.com/google/uuid" "github.com/kujtimiihoxha/termai/internal/db" + "github.com/kujtimiihoxha/termai/internal/llm/models" "github.com/kujtimiihoxha/termai/internal/pubsub" ) type CreateMessageParams struct { Role MessageRole Parts []ContentPart + Model models.ModelID } type Service interface { @@ -68,6 +71,7 @@ func (s *service) Create(sessionID string, params CreateMessageParams) (Message, SessionID: sessionID, Role: string(params.Role), Parts: string(partsJSON), + Model: sql.NullString{String: string(params.Model), Valid: true}, }) if err != nil { return Message{}, err @@ -101,9 +105,15 @@ func (s *service) Update(message Message) error { if err != nil { return err } + finishedAt := sql.NullInt64{} + if f := message.FinishPart(); f != nil { + finishedAt.Int64 = f.Time + finishedAt.Valid = true + } err = s.q.UpdateMessage(s.ctx, db.UpdateMessageParams{ - ID: message.ID, - Parts: string(parts), + ID: message.ID, + Parts: string(parts), + FinishedAt: finishedAt, }) if err != nil { return err diff --git a/internal/tui/components/chat/chat.go b/internal/tui/components/chat/chat.go new file mode 100644 index 0000000000000000000000000000000000000000..e893ec2f5f962643024bd6003d3e97351858a8ae --- /dev/null +++ b/internal/tui/components/chat/chat.go @@ -0,0 +1,113 @@ +package chat + +import ( + "fmt" + + "github.com/charmbracelet/lipgloss" + "github.com/charmbracelet/x/ansi" + "github.com/kujtimiihoxha/termai/internal/config" + "github.com/kujtimiihoxha/termai/internal/session" + "github.com/kujtimiihoxha/termai/internal/tui/styles" + "github.com/kujtimiihoxha/termai/internal/version" +) + +type SendMsg struct { + Text string +} + +type SessionSelectedMsg = session.Session + +type SessionClearedMsg struct{} + +type AgentWorkingMsg bool + +type EditorFocusMsg bool + +func lspsConfigured(width int) string { + cfg := config.Get() + title := "LSP Configuration" + title = ansi.Truncate(title, width, "…") + + lsps := styles.BaseStyle.Width(width).Foreground(styles.PrimaryColor).Bold(true).Render(title) + + var lspViews []string + for name, lsp := range cfg.LSP { + lspName := styles.BaseStyle.Foreground(styles.Forground).Render( + fmt.Sprintf("• %s", name), + ) + cmd := lsp.Command + cmd = ansi.Truncate(cmd, width-lipgloss.Width(lspName)-3, "…") + lspPath := styles.BaseStyle.Foreground(styles.ForgroundDim).Render( + fmt.Sprintf(" (%s)", cmd), + ) + lspViews = append(lspViews, + styles.BaseStyle. + Width(width). + Render( + lipgloss.JoinHorizontal( + lipgloss.Left, + lspName, + lspPath, + ), + ), + ) + + } + return styles.BaseStyle. + Width(width). + Render( + lipgloss.JoinVertical( + lipgloss.Left, + lsps, + lipgloss.JoinVertical( + lipgloss.Left, + lspViews..., + ), + ), + ) +} + +func logo(width int) string { + logo := fmt.Sprintf("%s %s", styles.OpenCodeIcon, "OpenCode") + + version := styles.BaseStyle.Foreground(styles.ForgroundDim).Render(version.Version) + + return styles.BaseStyle. + Bold(true). + Width(width). + Render( + lipgloss.JoinHorizontal( + lipgloss.Left, + logo, + " ", + version, + ), + ) +} + +func repo(width int) string { + repo := "https://github.com/kujtimiihoxha/opencode" + return styles.BaseStyle. + Foreground(styles.ForgroundDim). + Width(width). + Render(repo) +} + +func cwd(width int) string { + cwd := fmt.Sprintf("cwd: %s", config.WorkingDirectory()) + return styles.BaseStyle. + Foreground(styles.ForgroundDim). + Width(width). + Render(cwd) +} + +func header(width int) string { + header := lipgloss.JoinVertical( + lipgloss.Top, + logo(width), + repo(width), + "", + cwd(width), + ) + return header +} diff --git a/internal/tui/components/chat/editor.go b/internal/tui/components/chat/editor.go index ea20d7e4420054486511a31f5a660194d236a29a..df336818ce1b28fe76fed100c7eeba68bc4772f8 100644 --- a/internal/tui/components/chat/editor.go +++ b/internal/tui/components/chat/editor.go @@ -7,10 +7,12 @@ import ( "github.com/charmbracelet/lipgloss" "github.com/kujtimiihoxha/termai/internal/tui/layout" "github.com/kujtimiihoxha/termai/internal/tui/styles" + "github.com/kujtimiihoxha/termai/internal/tui/util" ) type editorCmp struct { - textarea textarea.Model + textarea textarea.Model + agentWorking bool } type focusedEditorKeyMaps struct { @@ -49,39 +51,51 @@ func (m *editorCmp) Init() tea.Cmd { return textarea.Blink } +func (m *editorCmp) send() tea.Cmd { + if m.agentWorking { + return util.ReportWarn("Agent is working, please wait...") + } + + value := m.textarea.Value() + m.textarea.Reset() + m.textarea.Blur() + if value == "" { + return nil + } + return tea.Batch( + util.CmdHandler(SendMsg{ + Text: value, + }), + util.CmdHandler(AgentWorkingMsg(true)), + util.CmdHandler(EditorFocusMsg(false)), + ) +} + func (m *editorCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) { var cmd tea.Cmd - if m.textarea.Focused() { - switch msg := msg.(type) { - case tea.KeyMsg: - if key.Matches(msg, focusedKeyMaps.Send) { - // TODO: send message - m.textarea.Reset() - m.textarea.Blur() - return m, nil - } - if key.Matches(msg, focusedKeyMaps.Blur) { - m.textarea.Blur() - return m, nil - } - } - m.textarea, cmd = m.textarea.Update(msg) - return m, cmd - } switch msg := msg.(type) { + case AgentWorkingMsg: + m.agentWorking = bool(msg) case tea.KeyMsg: + if key.Matches(msg, focusedKeyMaps.Send) { + return m, m.send() + } if key.Matches(msg, bluredKeyMaps.Send) { - // TODO: send message - m.textarea.Reset() - return m, nil + return m, m.send() + } + if key.Matches(msg, focusedKeyMaps.Blur) { + m.textarea.Blur() + return m, util.CmdHandler(EditorFocusMsg(false)) } if key.Matches(msg, bluredKeyMaps.Focus) { - m.textarea.Focus() - return m, textarea.Blink + if !m.textarea.Focused() { + m.textarea.Focus() + return m, tea.Batch(textarea.Blink, util.CmdHandler(EditorFocusMsg(true))) + } } } - - return m, nil + m.textarea, cmd = m.textarea.Update(msg) + return m, cmd } func (m *editorCmp) View() string { @@ -122,6 +136,7 @@ func NewEditorCmp() tea.Model { ti.FocusedStyle.CursorLine = ti.FocusedStyle.CursorLine.Background(styles.Background) ti.FocusedStyle.Placeholder = ti.FocusedStyle.Placeholder.Background(styles.Background) ti.FocusedStyle.Text = ti.BlurredStyle.Text.Background(styles.Background) + ti.CharLimit = -1 ti.Focus() return &editorCmp{ textarea: ti, diff --git a/internal/tui/components/chat/messages.go b/internal/tui/components/chat/messages.go index 691954767aa71de5e5a4b3c2b92f30f2ea084866..0a7e6e2a499e4d14ad19a327ba8de4dbfb1e59cd 100644 --- a/internal/tui/components/chat/messages.go +++ b/internal/tui/components/chat/messages.go @@ -1,21 +1,344 @@ package chat -import tea "github.com/charmbracelet/bubbletea" +import ( + "fmt" + "regexp" + "strconv" + "strings" -type messagesCmp struct{} + "github.com/charmbracelet/bubbles/viewport" + tea "github.com/charmbracelet/bubbletea" + "github.com/charmbracelet/glamour" + "github.com/charmbracelet/lipgloss" + "github.com/kujtimiihoxha/termai/internal/app" + "github.com/kujtimiihoxha/termai/internal/message" + "github.com/kujtimiihoxha/termai/internal/pubsub" + "github.com/kujtimiihoxha/termai/internal/session" + "github.com/kujtimiihoxha/termai/internal/tui/styles" + "github.com/kujtimiihoxha/termai/internal/tui/util" +) + +type uiMessage struct { + position int + height int + content string +} + +type messagesCmp struct { + app *app.App + width, height int + writingMode bool + viewport viewport.Model + session session.Session + messages []message.Message + uiMessages []uiMessage + currentIndex int + renderer *glamour.TermRenderer + focusRenderer *glamour.TermRenderer + cachedContent map[string]string +} func (m *messagesCmp) Init() tea.Cmd { - return nil + return m.viewport.Init() +} + +var ansiEscape = regexp.MustCompile("\x1b\\[[0-9;]*m") + +func hexToBgSGR(hex string) (string, error) { + hex = strings.TrimPrefix(hex, "#") + if len(hex) != 6 { + return "", fmt.Errorf("invalid hex color: must be 6 hexadecimal digits") + } + + // Parse RGB components in one block + rgb := make([]uint64, 3) + for i := 0; i < 3; i++ { + val, err := strconv.ParseUint(hex[i*2:i*2+2], 16, 8) + if err != nil { + return "", err + } + rgb[i] = val + } + + return fmt.Sprintf("48;2;%d;%d;%d", rgb[0], rgb[1], rgb[2]), nil +} + +func forceReplaceBackgroundColors(input string, newBg string) string { + return ansiEscape.ReplaceAllStringFunc(input, func(seq string) string { + // Extract content between "\x1b[" and "m" + content := seq[2 : len(seq)-1] + tokens := strings.Split(content, ";") + var newTokens []string + + // Skip background color tokens + for i := 0; i < len(tokens); i++ { + if tokens[i] == "" { + continue + } + + val, err := strconv.Atoi(tokens[i]) + if err != nil { + newTokens = append(newTokens, tokens[i]) + continue + } + + // Skip background color tokens + if val == 48 { + // Skip "48;5;N" or "48;2;R;G;B" sequences + if i+1 < len(tokens) { + if nextVal, err := strconv.Atoi(tokens[i+1]); err == nil { + switch nextVal { + case 5: + i += 2 // Skip "5" and color index + case 2: + i += 4 // Skip "2" and RGB components + } + } + } + } else if (val < 40 || val > 47) && (val < 100 || val > 107) && val != 49 { + // Keep non-background tokens + newTokens = append(newTokens, tokens[i]) + } + } + + // Add new background if provided + if newBg != "" { + newTokens = append(newTokens, strings.Split(newBg, ";")...) + } + + if len(newTokens) == 0 { + return "" + } + + return "\x1b[" + strings.Join(newTokens, ";") + "m" + }) +} + +func (m *messagesCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) { + switch msg := msg.(type) { + case EditorFocusMsg: + m.writingMode = bool(msg) + case SessionSelectedMsg: + if msg.ID != m.session.ID { + cmd := m.SetSession(msg) + return m, cmd + } + return m, nil + case pubsub.Event[message.Message]: + if msg.Type == pubsub.CreatedEvent { + if msg.Payload.SessionID == m.session.ID { + // check if message exists + for _, v := range m.messages { + if v.ID == msg.Payload.ID { + return m, nil + } + } + + m.messages = append(m.messages, msg.Payload) + m.renderView() + m.viewport.GotoBottom() + } + for _, v := range m.messages { + for _, c := range v.ToolCalls() { + // the message is being added to the session of a tool called + if c.ID == msg.Payload.SessionID { + m.renderView() + m.viewport.GotoBottom() + } + } + } + } else if msg.Type == pubsub.UpdatedEvent && msg.Payload.SessionID == m.session.ID { + for i, v := range m.messages { + if v.ID == msg.Payload.ID { + m.messages[i] = msg.Payload + delete(m.cachedContent, msg.Payload.ID) + m.renderView() + if i == len(m.messages)-1 { + m.viewport.GotoBottom() + } + break + } + } + } + } + u, cmd := m.viewport.Update(msg) + m.viewport = u + return m, cmd } -func (m *messagesCmp) Update(tea.Msg) (tea.Model, tea.Cmd) { - return m, nil +func (m *messagesCmp) renderUserMessage(inx int, msg message.Message) string { + if v, ok := m.cachedContent[msg.ID]; ok { + return v + } + style := styles.BaseStyle. + Width(m.width). + BorderLeft(true). + Foreground(styles.ForgroundDim). + BorderForeground(styles.ForgroundDim). + BorderStyle(lipgloss.ThickBorder()) + + renderer := m.renderer + if inx == m.currentIndex { + style = style. + Foreground(styles.Forground). + BorderForeground(styles.Blue). + BorderStyle(lipgloss.ThickBorder()) + renderer = m.focusRenderer + } + c, _ := renderer.Render(msg.Content().String()) + col, _ := hexToBgSGR(styles.Background.Dark) + rendered := style.Render(forceReplaceBackgroundColors(c, col)) + m.cachedContent[msg.ID] = rendered + return rendered +} + +func (m *messagesCmp) renderView() { + m.uiMessages = make([]uiMessage, 0) + pos := 0 + + for _, v := range m.messages { + content := "" + switch v.Role { + case message.User: + content = m.renderUserMessage(pos, v) + } + m.uiMessages = append(m.uiMessages, uiMessage{ + position: pos, + height: lipgloss.Height(content), + content: content, + }) + pos += lipgloss.Height(content) + 1 // + 1 for spacing + } + + messages := make([]string, 0) + for _, v := range m.uiMessages { + messages = append(messages, v.content) + } + m.viewport.SetContent( + styles.BaseStyle. + Width(m.width). + Render( + lipgloss.JoinVertical( + lipgloss.Top, + messages..., + ), + ), + ) } func (m *messagesCmp) View() string { - return "Messages" + if len(m.messages) == 0 { + content := styles.BaseStyle. + Width(m.width). + Height(m.height - 1). + Render( + m.initialScreen(), + ) + + return styles.BaseStyle. + Width(m.width). + Render( + lipgloss.JoinVertical( + lipgloss.Top, + content, + m.help(), + ), + ) + } + + m.renderView() + return styles.BaseStyle. + Width(m.width). + Render( + lipgloss.JoinVertical( + lipgloss.Top, + m.viewport.View(), + m.help(), + ), + ) +} + +func (m *messagesCmp) help() string { + text := "" + if m.writingMode { + text = lipgloss.JoinHorizontal( + lipgloss.Left, + styles.BaseStyle.Foreground(styles.ForgroundDim).Bold(true).Render("press "), + styles.BaseStyle.Foreground(styles.Forground).Bold(true).Render("esc"), + styles.BaseStyle.Foreground(styles.ForgroundDim).Bold(true).Render(" to exit writing mode"), + ) + } else { + text = lipgloss.JoinHorizontal( + lipgloss.Left, + styles.BaseStyle.Foreground(styles.ForgroundDim).Bold(true).Render("press "), + styles.BaseStyle.Foreground(styles.Forground).Bold(true).Render("i"), + styles.BaseStyle.Foreground(styles.ForgroundDim).Bold(true).Render(" to start writing"), + ) + } + + return styles.BaseStyle. + Width(m.width). + Render(text) +} + +func (m *messagesCmp) initialScreen() string { + return styles.BaseStyle.Width(m.width).Render( + lipgloss.JoinVertical( + lipgloss.Top, + header(m.width), + "", + lspsConfigured(m.width), + ), + ) +} + +func (m *messagesCmp) SetSize(width, height int) { + m.width = width + m.height = height + m.viewport.Width = width + m.viewport.Height = height - 1 + focusRenderer, _ := glamour.NewTermRenderer( + glamour.WithStyles(styles.MarkdownTheme(true)), + glamour.WithWordWrap(width-1), + ) + renderer, _ := glamour.NewTermRenderer( + glamour.WithStyles(styles.MarkdownTheme(false)), + glamour.WithWordWrap(width-1), + ) + m.focusRenderer = focusRenderer + m.renderer = renderer +} + +func (m *messagesCmp) GetSize() (int, int) { + return m.width, m.height +} + +func (m *messagesCmp) SetSession(session session.Session) tea.Cmd { + m.session = session + messages, err := m.app.Messages.List(session.ID) + if err != nil { + return util.ReportError(err) + } + m.messages = messages + m.messages = append(m.messages, m.messages[0]) + return nil } -func NewMessagesCmp() tea.Model { - return &messagesCmp{} +func NewMessagesCmp(app *app.App) tea.Model { + focusRenderer, _ := glamour.NewTermRenderer( + glamour.WithStyles(styles.MarkdownTheme(true)), + glamour.WithWordWrap(80), + ) + renderer, _ := glamour.NewTermRenderer( + glamour.WithStyles(styles.MarkdownTheme(false)), + glamour.WithWordWrap(80), + ) + return &messagesCmp{ + app: app, + writingMode: true, + cachedContent: make(map[string]string), + viewport: viewport.New(0, 0), + focusRenderer: focusRenderer, + renderer: renderer, + } } diff --git a/internal/tui/components/chat/sidebar.go b/internal/tui/components/chat/sidebar.go index 4a563157738246b6dc15c657146c01563b5bfd54..65c06f4a168e04a8bfc1bdff37368034518ea1dc 100644 --- a/internal/tui/components/chat/sidebar.go +++ b/internal/tui/components/chat/sidebar.go @@ -5,40 +5,43 @@ import ( tea "github.com/charmbracelet/bubbletea" "github.com/charmbracelet/lipgloss" - "github.com/kujtimiihoxha/termai/internal/config" + "github.com/kujtimiihoxha/termai/internal/session" "github.com/kujtimiihoxha/termai/internal/tui/styles" - "github.com/kujtimiihoxha/termai/internal/version" ) type sidebarCmp struct { width, height int + session session.Session } func (m *sidebarCmp) Init() tea.Cmd { return nil } -func (m *sidebarCmp) Update(tea.Msg) (tea.Model, tea.Cmd) { +func (m *sidebarCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) { return m, nil } func (m *sidebarCmp) View() string { - return styles.BaseStyle.Width(m.width).Render( - lipgloss.JoinVertical( - lipgloss.Top, - m.header(), - " ", - m.session(), - " ", - m.modifiedFiles(), - " ", - m.lspsConfigured(), - ), - ) + return styles.BaseStyle. + Width(m.width). + Height(m.height - 1). + Render( + lipgloss.JoinVertical( + lipgloss.Top, + header(m.width), + " ", + m.sessionSection(), + " ", + m.modifiedFiles(), + " ", + lspsConfigured(m.width), + ), + ) } -func (m *sidebarCmp) session() string { - sessionKey := styles.BaseStyle.Foreground(styles.PrimaryColor).Render("Session") +func (m *sidebarCmp) sessionSection() string { + sessionKey := styles.BaseStyle.Foreground(styles.PrimaryColor).Bold(true).Render("Session") sessionValue := styles.BaseStyle. Foreground(styles.Forground). Width(m.width - lipgloss.Width(sessionKey)). @@ -53,11 +56,11 @@ func (m *sidebarCmp) session() string { func (m *sidebarCmp) modifiedFile(filePath string, additions, removals int) string { stats := "" if additions > 0 && removals > 0 { - stats = styles.BaseStyle.Foreground(styles.ForgroundDim).Render(fmt.Sprintf("%d additions and %d removals", additions, removals)) + stats = styles.BaseStyle.Foreground(styles.ForgroundDim).Render(fmt.Sprintf(" %d additions and %d removals", additions, removals)) } else if additions > 0 { - stats = styles.BaseStyle.Foreground(styles.ForgroundDim).Render(fmt.Sprintf("%d additions", additions)) + stats = styles.BaseStyle.Foreground(styles.ForgroundDim).Render(fmt.Sprintf(" %d additions", additions)) } else if removals > 0 { - stats = styles.BaseStyle.Foreground(styles.ForgroundDim).Render(fmt.Sprintf("%d removals", removals)) + stats = styles.BaseStyle.Foreground(styles.ForgroundDim).Render(fmt.Sprintf(" %d removals", removals)) } filePathStr := styles.BaseStyle.Foreground(styles.Forground).Render(filePath) @@ -67,60 +70,13 @@ func (m *sidebarCmp) modifiedFile(filePath string, additions, removals int) stri lipgloss.JoinHorizontal( lipgloss.Left, filePathStr, - " ", stats, ), ) } -func (m *sidebarCmp) lspsConfigured() string { - lsps := styles.BaseStyle.Width(m.width).Foreground(styles.PrimaryColor).Render("LSP Configuration:") - lspsConfigured := []struct { - name string - path string - }{ - {"golsp", "path/to/lsp1"}, - {"vtsls", "path/to/lsp2"}, - } - - var lspViews []string - for _, lsp := range lspsConfigured { - lspName := styles.BaseStyle.Foreground(styles.Forground).Render( - fmt.Sprintf("• %s", lsp.name), - ) - lspPath := styles.BaseStyle.Foreground(styles.ForgroundDim).Render( - fmt.Sprintf("(%s)", lsp.path), - ) - lspViews = append(lspViews, - styles.BaseStyle. - Width(m.width). - Render( - lipgloss.JoinHorizontal( - lipgloss.Left, - lspName, - " ", - lspPath, - ), - ), - ) - - } - return styles.BaseStyle. - Width(m.width). - Render( - lipgloss.JoinVertical( - lipgloss.Left, - lsps, - lipgloss.JoinVertical( - lipgloss.Left, - lspViews..., - ), - ), - ) -} - func (m *sidebarCmp) modifiedFiles() string { - modifiedFiles := styles.BaseStyle.Width(m.width).Foreground(styles.PrimaryColor).Render("Modified Files:") + modifiedFiles := styles.BaseStyle.Width(m.width).Foreground(styles.PrimaryColor).Bold(true).Render("Modified Files:") files := []struct { path string additions int @@ -149,41 +105,6 @@ func (m *sidebarCmp) modifiedFiles() string { ) } -func (m *sidebarCmp) logo() string { - logo := fmt.Sprintf("%s %s", styles.OpenCodeIcon, "OpenCode") - - version := styles.BaseStyle.Foreground(styles.ForgroundDim).Render(version.Version) - - return styles.BaseStyle. - Bold(true). - Width(m.width). - Render( - lipgloss.JoinHorizontal( - lipgloss.Left, - logo, - " ", - version, - ), - ) -} - -func (m *sidebarCmp) header() string { - header := lipgloss.JoinVertical( - lipgloss.Top, - m.logo(), - m.cwd(), - ) - return header -} - -func (m *sidebarCmp) cwd() string { - cwd := fmt.Sprintf("cwd: %s", config.WorkingDirectory()) - return styles.BaseStyle. - Foreground(styles.ForgroundDim). - Width(m.width). - Render(cwd) -} - func (m *sidebarCmp) SetSize(width, height int) { m.width = width m.height = height @@ -193,6 +114,8 @@ func (m *sidebarCmp) GetSize() (int, int) { return m.width, m.height } -func NewSidebarCmp() tea.Model { - return &sidebarCmp{} +func NewSidebarCmp(session session.Session) tea.Model { + return &sidebarCmp{ + session: session, + } } diff --git a/internal/tui/components/core/button.go b/internal/tui/components/core/button.go deleted file mode 100644 index 090fbc1ee510fb6912a3b47a76556f5c1bf0140a..0000000000000000000000000000000000000000 --- a/internal/tui/components/core/button.go +++ /dev/null @@ -1,287 +0,0 @@ -package core - -import ( - "github.com/charmbracelet/bubbles/key" - tea "github.com/charmbracelet/bubbletea" - "github.com/charmbracelet/lipgloss" - "github.com/kujtimiihoxha/termai/internal/tui/styles" -) - -// ButtonKeyMap defines key bindings for the button component -type ButtonKeyMap struct { - Enter key.Binding -} - -// DefaultButtonKeyMap returns default key bindings for the button -func DefaultButtonKeyMap() ButtonKeyMap { - return ButtonKeyMap{ - Enter: key.NewBinding( - key.WithKeys("enter"), - key.WithHelp("enter", "select"), - ), - } -} - -// ShortHelp returns keybinding help -func (k ButtonKeyMap) ShortHelp() []key.Binding { - return []key.Binding{k.Enter} -} - -// FullHelp returns full help info for keybindings -func (k ButtonKeyMap) FullHelp() [][]key.Binding { - return [][]key.Binding{ - {k.Enter}, - } -} - -// ButtonState represents the state of a button -type ButtonState int - -const ( - // ButtonNormal is the default state - ButtonNormal ButtonState = iota - // ButtonHovered is when the button is focused/hovered - ButtonHovered - // ButtonPressed is when the button is being pressed - ButtonPressed - // ButtonDisabled is when the button is disabled - ButtonDisabled -) - -// ButtonVariant defines the visual style variant of a button -type ButtonVariant int - -const ( - // ButtonPrimary uses primary color styling - ButtonPrimary ButtonVariant = iota - // ButtonSecondary uses secondary color styling - ButtonSecondary - // ButtonDanger uses danger/error color styling - ButtonDanger - // ButtonWarning uses warning color styling - ButtonWarning - // ButtonNeutral uses neutral color styling - ButtonNeutral -) - -// ButtonMsg is sent when a button is clicked -type ButtonMsg struct { - ID string - Payload any -} - -// ButtonCmp represents a clickable button component -type ButtonCmp struct { - id string - label string - width int - height int - state ButtonState - variant ButtonVariant - keyMap ButtonKeyMap - payload any - style lipgloss.Style - hoverStyle lipgloss.Style -} - -// NewButtonCmp creates a new button component -func NewButtonCmp(id, label string) *ButtonCmp { - b := &ButtonCmp{ - id: id, - label: label, - state: ButtonNormal, - variant: ButtonPrimary, - keyMap: DefaultButtonKeyMap(), - width: len(label) + 4, // add some padding - height: 1, - } - b.updateStyles() - return b -} - -// WithVariant sets the button variant -func (b *ButtonCmp) WithVariant(variant ButtonVariant) *ButtonCmp { - b.variant = variant - b.updateStyles() - return b -} - -// WithPayload sets the payload sent with button events -func (b *ButtonCmp) WithPayload(payload any) *ButtonCmp { - b.payload = payload - return b -} - -// WithWidth sets a custom width -func (b *ButtonCmp) WithWidth(width int) *ButtonCmp { - b.width = width - b.updateStyles() - return b -} - -// updateStyles recalculates styles based on current state and variant -func (b *ButtonCmp) updateStyles() { - // Base styles - b.style = styles.Regular. - Padding(0, 1). - Width(b.width). - Align(lipgloss.Center). - BorderStyle(lipgloss.RoundedBorder()) - - b.hoverStyle = b.style. - Bold(true) - - // Variant-specific styling - switch b.variant { - case ButtonPrimary: - b.style = b.style. - Foreground(styles.Base). - Background(styles.Primary). - BorderForeground(styles.Primary) - - b.hoverStyle = b.hoverStyle. - Foreground(styles.Base). - Background(styles.Blue). - BorderForeground(styles.Blue) - - case ButtonSecondary: - b.style = b.style. - Foreground(styles.Base). - Background(styles.Secondary). - BorderForeground(styles.Secondary) - - b.hoverStyle = b.hoverStyle. - Foreground(styles.Base). - Background(styles.Mauve). - BorderForeground(styles.Mauve) - - case ButtonDanger: - b.style = b.style. - Foreground(styles.Base). - Background(styles.Error). - BorderForeground(styles.Error) - - b.hoverStyle = b.hoverStyle. - Foreground(styles.Base). - Background(styles.Red). - BorderForeground(styles.Red) - - case ButtonWarning: - b.style = b.style. - Foreground(styles.Text). - Background(styles.Warning). - BorderForeground(styles.Warning) - - b.hoverStyle = b.hoverStyle. - Foreground(styles.Text). - Background(styles.Peach). - BorderForeground(styles.Peach) - - case ButtonNeutral: - b.style = b.style. - Foreground(styles.Text). - Background(styles.Grey). - BorderForeground(styles.Grey) - - b.hoverStyle = b.hoverStyle. - Foreground(styles.Text). - Background(styles.DarkGrey). - BorderForeground(styles.DarkGrey) - } - - // Disabled style override - if b.state == ButtonDisabled { - b.style = b.style. - Foreground(styles.SubText0). - Background(styles.LightGrey). - BorderForeground(styles.LightGrey) - } -} - -// SetSize sets the button size -func (b *ButtonCmp) SetSize(width, height int) { - b.width = width - b.height = height - b.updateStyles() -} - -// Focus sets the button to focused state -func (b *ButtonCmp) Focus() tea.Cmd { - if b.state != ButtonDisabled { - b.state = ButtonHovered - } - return nil -} - -// Blur sets the button to normal state -func (b *ButtonCmp) Blur() tea.Cmd { - if b.state != ButtonDisabled { - b.state = ButtonNormal - } - return nil -} - -// Disable sets the button to disabled state -func (b *ButtonCmp) Disable() { - b.state = ButtonDisabled - b.updateStyles() -} - -// Enable enables the button if disabled -func (b *ButtonCmp) Enable() { - if b.state == ButtonDisabled { - b.state = ButtonNormal - b.updateStyles() - } -} - -// IsDisabled returns whether the button is disabled -func (b *ButtonCmp) IsDisabled() bool { - return b.state == ButtonDisabled -} - -// IsFocused returns whether the button is focused -func (b *ButtonCmp) IsFocused() bool { - return b.state == ButtonHovered -} - -// Init initializes the button -func (b *ButtonCmp) Init() tea.Cmd { - return nil -} - -// Update handles messages and user input -func (b *ButtonCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) { - // Skip updates if disabled - if b.state == ButtonDisabled { - return b, nil - } - - switch msg := msg.(type) { - case tea.KeyMsg: - // Handle key presses when focused - if b.state == ButtonHovered { - switch { - case key.Matches(msg, b.keyMap.Enter): - b.state = ButtonPressed - return b, func() tea.Msg { - return ButtonMsg{ - ID: b.id, - Payload: b.payload, - } - } - } - } - } - - return b, nil -} - -// View renders the button -func (b *ButtonCmp) View() string { - if b.state == ButtonHovered || b.state == ButtonPressed { - return b.hoverStyle.Render(b.label) - } - return b.style.Render(b.label) -} - diff --git a/internal/tui/layout/split.go b/internal/tui/layout/split.go index 2a6822c7edbd2ee5c8ef2db9500270271673bd46..0ed85dd6fc20f8859b550c9f4ead13b5876b8eec 100644 --- a/internal/tui/layout/split.go +++ b/internal/tui/layout/split.go @@ -10,6 +10,9 @@ import ( type SplitPaneLayout interface { tea.Model Sizeable + SetLeftPanel(panel Container) + SetRightPanel(panel Container) + SetBottomPanel(panel Container) } type splitPaneLayout struct { @@ -160,6 +163,27 @@ func (s *splitPaneLayout) GetSize() (int, int) { return s.width, s.height } +func (s *splitPaneLayout) SetLeftPanel(panel Container) { + s.leftPanel = panel + if s.width > 0 && s.height > 0 { + s.SetSize(s.width, s.height) + } +} + +func (s *splitPaneLayout) SetRightPanel(panel Container) { + s.rightPanel = panel + if s.width > 0 && s.height > 0 { + s.SetSize(s.width, s.height) + } +} + +func (s *splitPaneLayout) SetBottomPanel(panel Container) { + s.bottomPanel = panel + if s.width > 0 && s.height > 0 { + s.SetSize(s.width, s.height) + } +} + func (s *splitPaneLayout) BindingKeys() []key.Binding { keys := []key.Binding{} if s.leftPanel != nil { diff --git a/internal/tui/page/chat.go b/internal/tui/page/chat.go index de5b3910fec4e0efbca11a059cfeea33a235c953..7ac0d2293f5b49bed1fee844a43319266c5c5263 100644 --- a/internal/tui/page/chat.go +++ b/internal/tui/page/chat.go @@ -3,28 +3,100 @@ package page import ( tea "github.com/charmbracelet/bubbletea" "github.com/kujtimiihoxha/termai/internal/app" + "github.com/kujtimiihoxha/termai/internal/message" + "github.com/kujtimiihoxha/termai/internal/session" "github.com/kujtimiihoxha/termai/internal/tui/components/chat" "github.com/kujtimiihoxha/termai/internal/tui/layout" + "github.com/kujtimiihoxha/termai/internal/tui/util" ) var ChatPage PageID = "chat" -func NewChatPage(app *app.App) tea.Model { - messagesContainer := layout.NewContainer( - chat.NewMessagesCmp(), +type chatPage struct { + app *app.App + layout layout.SplitPaneLayout + session session.Session +} + +func (p *chatPage) Init() tea.Cmd { + return p.layout.Init() +} + +func (p *chatPage) Update(msg tea.Msg) (tea.Model, tea.Cmd) { + switch msg := msg.(type) { + case tea.WindowSizeMsg: + p.layout.SetSize(msg.Width, msg.Height) + case chat.SendMsg: + cmd := p.sendMessage(msg.Text) + if cmd != nil { + return p, cmd + } + } + u, cmd := p.layout.Update(msg) + p.layout = u.(layout.SplitPaneLayout) + if cmd != nil { + return p, cmd + } + return p, nil +} + +func (p *chatPage) setSidebar() tea.Cmd { + sidebarContainer := layout.NewContainer( + chat.NewSidebarCmp(p.session), layout.WithPadding(1, 1, 1, 1), ) - sidebarContainer := layout.NewContainer( - chat.NewSidebarCmp(), + p.layout.SetRightPanel(sidebarContainer) + width, height := p.layout.GetSize() + p.layout.SetSize(width, height) + return sidebarContainer.Init() +} + +func (p *chatPage) sendMessage(text string) tea.Cmd { + var cmds []tea.Cmd + if p.session.ID == "" { + session, err := p.app.Sessions.Create("New Session") + if err != nil { + return util.ReportError(err) + } + + p.session = session + cmd := p.setSidebar() + if cmd != nil { + cmds = append(cmds, cmd) + } + cmds = append(cmds, util.CmdHandler(chat.SessionSelectedMsg(session))) + } + // TODO: actually call agent + p.app.Messages.Create(p.session.ID, message.CreateMessageParams{ + Role: message.User, + Parts: []message.ContentPart{ + message.TextContent{ + Text: text, + }, + }, + }) + return tea.Batch(cmds...) +} + +func (p *chatPage) View() string { + return p.layout.View() +} + +func NewChatPage(app *app.App) tea.Model { + messagesContainer := layout.NewContainer( + chat.NewMessagesCmp(app), layout.WithPadding(1, 1, 1, 1), ) + editorContainer := layout.NewContainer( chat.NewEditorCmp(), layout.WithBorder(true, false, false, false), ) - return layout.NewSplitPane( - layout.WithRightPanel(sidebarContainer), - layout.WithLeftPanel(messagesContainer), - layout.WithBottomPanel(editorContainer), - ) + return &chatPage{ + app: app, + layout: layout.NewSplitPane( + layout.WithLeftPanel(messagesContainer), + layout.WithBottomPanel(editorContainer), + ), + } } diff --git a/internal/tui/styles/markdown.go b/internal/tui/styles/markdown.go index 77dc314f51343797210093b7d83d90d01fe408ea..b4e71c51ef6615b9638156ac30129b1905669585 100644 --- a/internal/tui/styles/markdown.go +++ b/internal/tui/styles/markdown.go @@ -36,12 +36,13 @@ var catppuccinDark = ansi.StyleConfig{ Italic: boolPtr(true), Prefix: "┃ ", }, - Indent: uintPtr(1), - Margin: uintPtr(defaultMargin), + Indent: uintPtr(1), + IndentToken: stringPtr(BaseStyle.Render(" ")), }, List: ansi.StyleList{ LevelIndent: defaultMargin, StyleBlock: ansi.StyleBlock{ + IndentToken: stringPtr(BaseStyle.Render(" ")), StylePrimitive: ansi.StylePrimitive{ Color: stringPtr(dark.Text().Hex), }, @@ -496,3 +497,444 @@ var catppuccinLight = ansi.StyleConfig{ Color: stringPtr(light.Sapphire().Hex), }, } + +func MarkdownTheme(focused bool) ansi.StyleConfig { + if !focused { + return ASCIIStyleConfig + } else { + return DraculaStyleConfig + } +} + +const ( + defaultListIndent = 2 + defaultListLevelIndent = 4 +) + +var ASCIIStyleConfig = ansi.StyleConfig{ + Document: ansi.StyleBlock{ + StylePrimitive: ansi.StylePrimitive{ + BackgroundColor: stringPtr(Background.Dark), + }, + Indent: uintPtr(1), + IndentToken: stringPtr(BaseStyle.Render(" ")), + }, + BlockQuote: ansi.StyleBlock{ + StylePrimitive: ansi.StylePrimitive{ + BackgroundColor: stringPtr(Background.Dark), + }, + Indent: uintPtr(1), + IndentToken: stringPtr("| "), + }, + Paragraph: ansi.StyleBlock{ + StylePrimitive: ansi.StylePrimitive{ + BackgroundColor: stringPtr(Background.Dark), + }, + }, + List: ansi.StyleList{ + StyleBlock: ansi.StyleBlock{ + IndentToken: stringPtr(BaseStyle.Render(" ")), + StylePrimitive: ansi.StylePrimitive{ + BackgroundColor: stringPtr(Background.Dark), + }, + }, + LevelIndent: defaultListLevelIndent, + }, + Heading: ansi.StyleBlock{ + StylePrimitive: ansi.StylePrimitive{ + BackgroundColor: stringPtr(Background.Dark), + BlockSuffix: "\n", + }, + }, + H1: ansi.StyleBlock{ + StylePrimitive: ansi.StylePrimitive{ + BackgroundColor: stringPtr(Background.Dark), + Prefix: "# ", + }, + }, + H2: ansi.StyleBlock{ + StylePrimitive: ansi.StylePrimitive{ + BackgroundColor: stringPtr(Background.Dark), + Prefix: "## ", + }, + }, + H3: ansi.StyleBlock{ + StylePrimitive: ansi.StylePrimitive{ + BackgroundColor: stringPtr(Background.Dark), + Prefix: "### ", + }, + }, + H4: ansi.StyleBlock{ + StylePrimitive: ansi.StylePrimitive{ + BackgroundColor: stringPtr(Background.Dark), + Prefix: "#### ", + }, + }, + H5: ansi.StyleBlock{ + StylePrimitive: ansi.StylePrimitive{ + BackgroundColor: stringPtr(Background.Dark), + Prefix: "##### ", + }, + }, + H6: ansi.StyleBlock{ + StylePrimitive: ansi.StylePrimitive{ + BackgroundColor: stringPtr(Background.Dark), + Prefix: "###### ", + }, + }, + Strikethrough: ansi.StylePrimitive{ + BackgroundColor: stringPtr(Background.Dark), + BlockPrefix: "~~", + BlockSuffix: "~~", + }, + Emph: ansi.StylePrimitive{ + BackgroundColor: stringPtr(Background.Dark), + BlockPrefix: "*", + BlockSuffix: "*", + }, + Strong: ansi.StylePrimitive{ + BackgroundColor: stringPtr(Background.Dark), + BlockPrefix: "**", + BlockSuffix: "**", + }, + HorizontalRule: ansi.StylePrimitive{ + BackgroundColor: stringPtr(Background.Dark), + Format: "\n--------\n", + }, + Item: ansi.StylePrimitive{ + BlockPrefix: "• ", + BackgroundColor: stringPtr(Background.Dark), + }, + Enumeration: ansi.StylePrimitive{ + BlockPrefix: ". ", + BackgroundColor: stringPtr(Background.Dark), + }, + Task: ansi.StyleTask{ + Ticked: "[x] ", + Unticked: "[ ] ", + StylePrimitive: ansi.StylePrimitive{ + BackgroundColor: stringPtr(Background.Dark), + }, + }, + ImageText: ansi.StylePrimitive{ + BackgroundColor: stringPtr(Background.Dark), + Format: "Image: {{.text}} →", + }, + Code: ansi.StyleBlock{ + StylePrimitive: ansi.StylePrimitive{ + BlockPrefix: "`", + BlockSuffix: "`", + BackgroundColor: stringPtr(Background.Dark), + }, + }, + CodeBlock: ansi.StyleCodeBlock{ + StyleBlock: ansi.StyleBlock{ + StylePrimitive: ansi.StylePrimitive{ + BackgroundColor: stringPtr(Background.Dark), + }, + Margin: uintPtr(defaultMargin), + }, + }, + Table: ansi.StyleTable{ + StyleBlock: ansi.StyleBlock{ + StylePrimitive: ansi.StylePrimitive{ + BackgroundColor: stringPtr(Background.Dark), + }, + IndentToken: stringPtr(BaseStyle.Render(" ")), + }, + CenterSeparator: stringPtr("|"), + ColumnSeparator: stringPtr("|"), + RowSeparator: stringPtr("-"), + }, + DefinitionDescription: ansi.StylePrimitive{ + BackgroundColor: stringPtr(Background.Dark), + BlockPrefix: "\n* ", + }, +} + +var DraculaStyleConfig = ansi.StyleConfig{ + Document: ansi.StyleBlock{ + StylePrimitive: ansi.StylePrimitive{ + Color: stringPtr(Forground.Dark), + BackgroundColor: stringPtr(Background.Dark), + }, + Indent: uintPtr(defaultMargin), + IndentToken: stringPtr(BaseStyle.Render(" ")), + }, + BlockQuote: ansi.StyleBlock{ + StylePrimitive: ansi.StylePrimitive{ + Color: stringPtr("#f1fa8c"), + Italic: boolPtr(true), + BackgroundColor: stringPtr(Background.Dark), + }, + Indent: uintPtr(defaultMargin), + IndentToken: stringPtr(BaseStyle.Render(" ")), + }, + Paragraph: ansi.StyleBlock{ + StylePrimitive: ansi.StylePrimitive{ + BackgroundColor: stringPtr(Background.Dark), + }, + }, + List: ansi.StyleList{ + LevelIndent: defaultMargin, + StyleBlock: ansi.StyleBlock{ + IndentToken: stringPtr(BaseStyle.Render(" ")), + StylePrimitive: ansi.StylePrimitive{ + Color: stringPtr(Forground.Dark), + BackgroundColor: stringPtr(Background.Dark), + }, + }, + }, + Heading: ansi.StyleBlock{ + StylePrimitive: ansi.StylePrimitive{ + BlockSuffix: "\n", + Color: stringPtr("#bd93f9"), + Bold: boolPtr(true), + BackgroundColor: stringPtr(Background.Dark), + }, + }, + H1: ansi.StyleBlock{ + StylePrimitive: ansi.StylePrimitive{ + Prefix: "# ", + BackgroundColor: stringPtr(Background.Dark), + }, + }, + H2: ansi.StyleBlock{ + StylePrimitive: ansi.StylePrimitive{ + Prefix: "## ", + BackgroundColor: stringPtr(Background.Dark), + }, + }, + H3: ansi.StyleBlock{ + StylePrimitive: ansi.StylePrimitive{ + Prefix: "### ", + BackgroundColor: stringPtr(Background.Dark), + }, + }, + H4: ansi.StyleBlock{ + StylePrimitive: ansi.StylePrimitive{ + Prefix: "#### ", + BackgroundColor: stringPtr(Background.Dark), + }, + }, + H5: ansi.StyleBlock{ + StylePrimitive: ansi.StylePrimitive{ + Prefix: "##### ", + BackgroundColor: stringPtr(Background.Dark), + }, + }, + H6: ansi.StyleBlock{ + StylePrimitive: ansi.StylePrimitive{ + Prefix: "###### ", + BackgroundColor: stringPtr(Background.Dark), + }, + }, + Strikethrough: ansi.StylePrimitive{ + CrossedOut: boolPtr(true), + BackgroundColor: stringPtr(Background.Dark), + }, + Emph: ansi.StylePrimitive{ + Color: stringPtr("#f1fa8c"), + Italic: boolPtr(true), + BackgroundColor: stringPtr(Background.Dark), + }, + Strong: ansi.StylePrimitive{ + Bold: boolPtr(true), + Color: stringPtr("#ffb86c"), + BackgroundColor: stringPtr(Background.Dark), + }, + HorizontalRule: ansi.StylePrimitive{ + Color: stringPtr("#6272A4"), + Format: "\n--------\n", + BackgroundColor: stringPtr(Background.Dark), + }, + Item: ansi.StylePrimitive{ + BlockPrefix: "• ", + BackgroundColor: stringPtr(Background.Dark), + }, + Enumeration: ansi.StylePrimitive{ + BlockPrefix: ". ", + Color: stringPtr("#8be9fd"), + BackgroundColor: stringPtr(Background.Dark), + }, + Task: ansi.StyleTask{ + StylePrimitive: ansi.StylePrimitive{ + BackgroundColor: stringPtr(Background.Dark), + }, + Ticked: "[✓] ", + Unticked: "[ ] ", + }, + Link: ansi.StylePrimitive{ + Color: stringPtr("#8be9fd"), + Underline: boolPtr(true), + BackgroundColor: stringPtr(Background.Dark), + }, + LinkText: ansi.StylePrimitive{ + Color: stringPtr("#ff79c6"), + BackgroundColor: stringPtr(Background.Dark), + }, + Image: ansi.StylePrimitive{ + Color: stringPtr("#8be9fd"), + Underline: boolPtr(true), + BackgroundColor: stringPtr(Background.Dark), + }, + ImageText: ansi.StylePrimitive{ + Color: stringPtr("#ff79c6"), + Format: "Image: {{.text}} →", + BackgroundColor: stringPtr(Background.Dark), + }, + Code: ansi.StyleBlock{ + StylePrimitive: ansi.StylePrimitive{ + Color: stringPtr("#50fa7b"), + BackgroundColor: stringPtr(Background.Dark), + }, + }, + Text: ansi.StylePrimitive{ + BackgroundColor: stringPtr(Background.Dark), + }, + DefinitionList: ansi.StyleBlock{}, + CodeBlock: ansi.StyleCodeBlock{ + StyleBlock: ansi.StyleBlock{ + StylePrimitive: ansi.StylePrimitive{ + Color: stringPtr("#ffb86c"), + BackgroundColor: stringPtr(Background.Dark), + }, + Margin: uintPtr(defaultMargin), + }, + Chroma: &ansi.Chroma{ + NameOther: ansi.StylePrimitive{ + BackgroundColor: stringPtr(Background.Dark), + }, + Literal: ansi.StylePrimitive{ + BackgroundColor: stringPtr(Background.Dark), + }, + NameException: ansi.StylePrimitive{ + BackgroundColor: stringPtr(Background.Dark), + }, + LiteralDate: ansi.StylePrimitive{ + BackgroundColor: stringPtr(Background.Dark), + }, + Text: ansi.StylePrimitive{ + Color: stringPtr(Forground.Dark), + BackgroundColor: stringPtr(Background.Dark), + }, + Error: ansi.StylePrimitive{ + Color: stringPtr("#f8f8f2"), + BackgroundColor: stringPtr("#ff5555"), + }, + Comment: ansi.StylePrimitive{ + Color: stringPtr("#6272A4"), + BackgroundColor: stringPtr(Background.Dark), + }, + CommentPreproc: ansi.StylePrimitive{ + Color: stringPtr("#ff79c6"), + BackgroundColor: stringPtr(Background.Dark), + }, + Keyword: ansi.StylePrimitive{ + Color: stringPtr("#ff79c6"), + BackgroundColor: stringPtr(Background.Dark), + }, + KeywordReserved: ansi.StylePrimitive{ + Color: stringPtr("#ff79c6"), + BackgroundColor: stringPtr(Background.Dark), + }, + KeywordNamespace: ansi.StylePrimitive{ + Color: stringPtr("#ff79c6"), + BackgroundColor: stringPtr(Background.Dark), + }, + KeywordType: ansi.StylePrimitive{ + Color: stringPtr("#8be9fd"), + BackgroundColor: stringPtr(Background.Dark), + }, + Operator: ansi.StylePrimitive{ + Color: stringPtr("#ff79c6"), + BackgroundColor: stringPtr(Background.Dark), + }, + Punctuation: ansi.StylePrimitive{ + Color: stringPtr(Forground.Dark), + BackgroundColor: stringPtr(Background.Dark), + }, + Name: ansi.StylePrimitive{ + Color: stringPtr("#8be9fd"), + BackgroundColor: stringPtr(Background.Dark), + }, + NameBuiltin: ansi.StylePrimitive{ + Color: stringPtr("#8be9fd"), + BackgroundColor: stringPtr(Background.Dark), + }, + NameTag: ansi.StylePrimitive{ + Color: stringPtr("#ff79c6"), + BackgroundColor: stringPtr(Background.Dark), + }, + NameAttribute: ansi.StylePrimitive{ + Color: stringPtr("#50fa7b"), + BackgroundColor: stringPtr(Background.Dark), + }, + NameClass: ansi.StylePrimitive{ + Color: stringPtr("#8be9fd"), + BackgroundColor: stringPtr(Background.Dark), + }, + NameConstant: ansi.StylePrimitive{ + Color: stringPtr("#bd93f9"), + BackgroundColor: stringPtr(Background.Dark), + }, + NameDecorator: ansi.StylePrimitive{ + Color: stringPtr("#50fa7b"), + BackgroundColor: stringPtr(Background.Dark), + }, + NameFunction: ansi.StylePrimitive{ + Color: stringPtr("#50fa7b"), + BackgroundColor: stringPtr(Background.Dark), + }, + LiteralNumber: ansi.StylePrimitive{ + Color: stringPtr("#6EEFC0"), + BackgroundColor: stringPtr(Background.Dark), + }, + LiteralString: ansi.StylePrimitive{ + Color: stringPtr("#f1fa8c"), + BackgroundColor: stringPtr(Background.Dark), + }, + LiteralStringEscape: ansi.StylePrimitive{ + Color: stringPtr("#ff79c6"), + BackgroundColor: stringPtr(Background.Dark), + }, + GenericDeleted: ansi.StylePrimitive{ + Color: stringPtr("#ff5555"), + BackgroundColor: stringPtr(Background.Dark), + }, + GenericEmph: ansi.StylePrimitive{ + Color: stringPtr("#f1fa8c"), + Italic: boolPtr(true), + BackgroundColor: stringPtr(Background.Dark), + }, + GenericInserted: ansi.StylePrimitive{ + Color: stringPtr("#50fa7b"), + BackgroundColor: stringPtr(Background.Dark), + }, + GenericStrong: ansi.StylePrimitive{ + Color: stringPtr("#ffb86c"), + Bold: boolPtr(true), + BackgroundColor: stringPtr(Background.Dark), + }, + GenericSubheading: ansi.StylePrimitive{ + Color: stringPtr("#bd93f9"), + BackgroundColor: stringPtr(Background.Dark), + }, + Background: ansi.StylePrimitive{ + BackgroundColor: stringPtr(Background.Dark), + }, + }, + }, + Table: ansi.StyleTable{ + StyleBlock: ansi.StyleBlock{ + StylePrimitive: ansi.StylePrimitive{ + BackgroundColor: stringPtr(Background.Dark), + }, + IndentToken: stringPtr(BaseStyle.Render(" ")), + }, + }, + DefinitionDescription: ansi.StylePrimitive{ + BlockPrefix: "\n* ", + BackgroundColor: stringPtr(Background.Dark), + }, +} From 0697dcc1d9c7330d8c9d8a2be0bb94b3d46c9345 Mon Sep 17 00:00:00 2001 From: Kujtim Hoxha Date: Sat, 12 Apr 2025 14:49:01 +0200 Subject: [PATCH 04/41] implement nested tool calls and initial setup for result metadata --- go.mod | 21 +- go.sum | 43 ++- internal/llm/agent/agent.go | 1 + internal/llm/tools/bash.go | 13 +- internal/llm/tools/tools.go | 23 +- internal/message/content.go | 5 +- internal/message/message.go | 1 + internal/tui/components/chat/editor.go | 15 +- internal/tui/components/chat/messages.go | 458 +++++++++++++++++------ internal/tui/components/chat/sidebar.go | 11 +- internal/tui/page/chat.go | 62 ++- internal/tui/styles/background.go | 81 ++++ internal/tui/styles/markdown.go | 7 +- internal/tui/styles/styles.go | 10 + 14 files changed, 584 insertions(+), 167 deletions(-) create mode 100644 internal/tui/styles/background.go diff --git a/go.mod b/go.mod index 63df37fba20eb74d1f2f24ccccb760acf9992bad..3b8bd99b1f42db365b1c50e483e8c4ad1f559837 100644 --- a/go.mod +++ b/go.mod @@ -15,6 +15,7 @@ require ( github.com/charmbracelet/glamour v0.9.1 github.com/charmbracelet/huh v0.6.0 github.com/charmbracelet/lipgloss v1.1.0 + github.com/charmbracelet/x/ansi v0.8.0 github.com/fsnotify/fsnotify v1.8.0 github.com/go-logfmt/logfmt v0.6.0 github.com/golang-migrate/migrate/v4 v4.18.2 @@ -29,11 +30,11 @@ require ( github.com/muesli/reflow v0.3.0 github.com/muesli/termenv v0.16.0 github.com/openai/openai-go v0.1.0-beta.2 - github.com/sergi/go-diff v1.3.1 + github.com/sergi/go-diff v1.3.2-0.20230802210424-5b0b94c5c0d3 github.com/spf13/cobra v1.9.1 github.com/spf13/viper v1.20.0 github.com/stretchr/testify v1.10.0 - golang.org/x/net v0.34.0 + golang.org/x/net v0.39.0 google.golang.org/api v0.215.0 ) @@ -64,7 +65,6 @@ require ( github.com/aymanbagabas/go-osc52/v2 v2.0.1 // indirect github.com/aymerick/douceur v0.2.0 // indirect github.com/charmbracelet/colorprofile v0.2.3-0.20250311203215-f60798e515dc // indirect - github.com/charmbracelet/x/ansi v0.8.0 // indirect github.com/charmbracelet/x/cellbuf v0.0.13-0.20250311204145-2c3ea96c31dd // indirect github.com/charmbracelet/x/exp/strings v0.0.0-20240722160745-212f7b056ed0 // indirect github.com/charmbracelet/x/term v0.2.1 // indirect @@ -76,6 +76,7 @@ require ( github.com/go-logr/logr v1.4.2 // indirect github.com/go-logr/stdr v1.2.2 // indirect github.com/go-viper/mapstructure/v2 v2.2.1 // indirect + github.com/google/go-cmp v0.7.0 // indirect github.com/google/s2a-go v0.1.8 // indirect github.com/googleapis/enterprise-certificate-proxy v0.3.4 // indirect github.com/googleapis/gax-go/v2 v2.14.1 // indirect @@ -92,6 +93,7 @@ require ( github.com/pelletier/go-toml/v2 v2.2.3 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect github.com/rivo/uniseg v0.4.7 // indirect + github.com/rogpeppe/go-internal v1.14.1 // indirect github.com/sagikazarmark/locafero v0.7.0 // indirect github.com/sahilm/fuzzy v0.1.1 // indirect github.com/sourcegraph/conc v0.3.0 // indirect @@ -115,20 +117,21 @@ require ( go.uber.org/atomic v1.9.0 // indirect go.uber.org/multierr v1.9.0 // indirect golang.design/x/clipboard v0.7.0 // indirect - golang.org/x/crypto v0.33.0 // indirect - golang.org/x/exp v0.0.0-20230713183714-613f0c0eb8a1 // indirect + golang.org/x/crypto v0.37.0 // indirect + golang.org/x/exp v0.0.0-20240719175910-8a7402abbf56 // indirect golang.org/x/exp/shiny v0.0.0-20250305212735-054e65f0b394 // indirect golang.org/x/image v0.14.0 // indirect golang.org/x/mobile v0.0.0-20231127183840-76ac6878050a // indirect golang.org/x/oauth2 v0.25.0 // indirect - golang.org/x/sync v0.12.0 // indirect - golang.org/x/sys v0.31.0 // indirect - golang.org/x/term v0.30.0 // indirect - golang.org/x/text v0.23.0 // indirect + golang.org/x/sync v0.13.0 // indirect + golang.org/x/sys v0.32.0 // indirect + golang.org/x/term v0.31.0 // indirect + golang.org/x/text v0.24.0 // indirect golang.org/x/time v0.8.0 // indirect google.golang.org/genproto/googleapis/api v0.0.0-20241209162323-e6fa225c2576 // indirect google.golang.org/genproto/googleapis/rpc v0.0.0-20241223144023-3abc09e42ca8 // indirect google.golang.org/grpc v1.67.3 // indirect google.golang.org/protobuf v1.36.1 // indirect + gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c // indirect gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/go.sum b/go.sum index c4b32ef32bed371a72ab4ba5f917753a77077d18..08e7e7c42e61f69fd454df0683507bb91ddd3cd9 100644 --- a/go.sum +++ b/go.sum @@ -117,8 +117,8 @@ github.com/golang-migrate/migrate/v4 v4.18.2 h1:2VSCMz7x7mjyTXx3m2zPokOY82LTRgxK github.com/golang-migrate/migrate/v4 v4.18.2/go.mod h1:2CM6tJvn2kqPXwnXO/d3rAQYiyoIm180VsO8PRX6Rpk= github.com/google/generative-ai-go v0.19.0 h1:R71szggh8wHMCUlEMsW2A/3T+5LdEIkiaHSYgSpUgdg= github.com/google/generative-ai-go v0.19.0/go.mod h1:JYolL13VG7j79kM5BtHz4qwONHkeJQzOCkKXnpqtS/E= -github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= -github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= +github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= +github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= github.com/google/s2a-go v0.1.8 h1:zZDs9gcbt9ZPLV0ndSyQk6Kacx2g/X+SKYovpnz3SMM= github.com/google/s2a-go v0.1.8/go.mod h1:6iNWHTpQ+nfNRN5E00MSdfDwVesa8hhS32PhPO8deJA= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= @@ -139,6 +139,7 @@ github.com/hexops/gotextdiff v1.0.3/go.mod h1:pSWU5MAI3yDq+fZBTazCSJysOMbxWL1BSo github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8= github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw= github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= +github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI= github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= @@ -189,8 +190,8 @@ github.com/rivo/uniseg v0.1.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJ github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc= github.com/rivo/uniseg v0.4.7 h1:WUdvkW8uEhrYfLC4ZzdpI2ztxP1I582+49Oc5Mq64VQ= github.com/rivo/uniseg v0.4.7/go.mod h1:FN3SvrM+Zdj16jyLfmOkMNblXMcoc8DfTHruCPUcx88= -github.com/rogpeppe/go-internal v1.12.0 h1:exVL4IDcn6na9z1rAb56Vxr+CgyK3nn3O+epU5NdKM8= -github.com/rogpeppe/go-internal v1.12.0/go.mod h1:E+RYuTGaKKdloAfM02xzb0FW3Paa99yedzYV+kq4uf4= +github.com/rogpeppe/go-internal v1.14.1 h1:UQB4HGPB6osV0SQTLymcB4TgvyWu6ZyliaW0tI/otEQ= +github.com/rogpeppe/go-internal v1.14.1/go.mod h1:MaRKkUm5W0goXpeCfT7UZI6fk/L7L7so1lCWt35ZSgc= github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= github.com/sagikazarmark/locafero v0.7.0 h1:5MqpDsTGNDhY8sGp0Aowyf0qKsPrhewaLSsFaodPcyo= github.com/sagikazarmark/locafero v0.7.0/go.mod h1:2za3Cg5rMaTMoG/2Ulr9AwtFaIppKXTRYnozin4aB5k= @@ -199,8 +200,9 @@ github.com/sahilm/fuzzy v0.1.1/go.mod h1:VFvziUEIMCrT6A6tw2RFIXPXXmzXbOsSHF0DOI8 github.com/sebdah/goldie/v2 v2.5.3 h1:9ES/mNN+HNUbNWpVAlrzuZ7jE+Nrczbj8uFRjM7624Y= github.com/sebdah/goldie/v2 v2.5.3/go.mod h1:oZ9fp0+se1eapSRjfYbsV/0Hqhbuu3bJVvKI/NNtssI= github.com/sergi/go-diff v1.0.0/go.mod h1:0CfEIISq7TuYL3j771MWULgwwjU+GofnZX9QAmXWZgo= -github.com/sergi/go-diff v1.3.1 h1:xkr+Oxo4BOQKmkn/B9eMK0g5Kg/983T9DqqPHwYqD+8= github.com/sergi/go-diff v1.3.1/go.mod h1:aMJSSKb2lpPvRNec0+w3fl7LP9IOFzdc9Pa4NFbPK1I= +github.com/sergi/go-diff v1.3.2-0.20230802210424-5b0b94c5c0d3 h1:n661drycOFuPLCN3Uc8sB6B/s6Z4t2xvBgU1htSHuq8= +github.com/sergi/go-diff v1.3.2-0.20230802210424-5b0b94c5c0d3/go.mod h1:A0bzQcvG0E7Rwjx0REVgAGH58e96+X0MeOfepqsbeW4= github.com/sourcegraph/conc v0.3.0 h1:OQTbbt6P72L20UqAkXXuLOj79LfEanQ+YQFNpLA9ySo= github.com/sourcegraph/conc v0.3.0/go.mod h1:Sdozi7LEKbFPqYX2/J+iBAM6HpqSLTASQIKqDmF7Mt0= github.com/spf13/afero v1.12.0 h1:UcOPyRBYczmFn6yvphxkn9ZEOY65cpwGKb5mL36mrqs= @@ -261,10 +263,10 @@ golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5y golang.org/x/crypto v0.19.0/go.mod h1:Iy9bg/ha4yyC70EfRS8jz+B6ybOBKMaSxLj6P6oBDfU= golang.org/x/crypto v0.22.0/go.mod h1:vr6Su+7cTlO45qkww3VDJlzDn0ctJvRgYbC2NvXHt+M= golang.org/x/crypto v0.23.0/go.mod h1:CKFgDieR+mRhux2Lsu27y0fO304Db0wZe70UKqHu0v8= -golang.org/x/crypto v0.33.0 h1:IOBPskki6Lysi0lo9qQvbxiQ+FvsCC/YWOecCHAixus= -golang.org/x/crypto v0.33.0/go.mod h1:bVdXmD7IV/4GdElGPozy6U7lWdRXA4qyRVGJV57uQ5M= -golang.org/x/exp v0.0.0-20230713183714-613f0c0eb8a1 h1:MGwJjxBy0HJshjDNfLsYO8xppfqWlA5ZT9OhtUUhTNw= -golang.org/x/exp v0.0.0-20230713183714-613f0c0eb8a1/go.mod h1:FXUEEKJgO7OQYeo8N01OfiKP8RXMtf6e8aTskBGqWdc= +golang.org/x/crypto v0.37.0 h1:kJNSjF/Xp7kU0iB2Z+9viTPMW4EqqsrywMXLJOOsXSE= +golang.org/x/crypto v0.37.0/go.mod h1:vg+k43peMZ0pUMhYmVAWysMK35e6ioLh3wB8ZCAfbVc= +golang.org/x/exp v0.0.0-20240719175910-8a7402abbf56 h1:2dVuKD2vS7b0QIHQbpyTISPd0LeHDbnYEryqj5Q1ug8= +golang.org/x/exp v0.0.0-20240719175910-8a7402abbf56/go.mod h1:M4RDyNAINzryxdtnbRXRL/OHtkFuWGRjvuhBJpk2IlY= golang.org/x/exp/shiny v0.0.0-20250305212735-054e65f0b394 h1:bFYqOIMdeiCEdzPJkLiOoMDzW/v3tjW4AA/RmUZYsL8= golang.org/x/exp/shiny v0.0.0-20250305212735-054e65f0b394/go.mod h1:ygj7T6vSGhhm/9yTpOQQNvuAUFziTH7RUiH74EoE2C8= golang.org/x/image v0.14.0 h1:tNgSxAFe3jC4uYqvZdTr84SZoM1KfwdC9SKIFrLjFn4= @@ -282,15 +284,15 @@ golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg= golang.org/x/net v0.21.0/go.mod h1:bIjVDfnllIU7BJ2DNgfnXvpSvtn8VRwhlsaeUTyUS44= golang.org/x/net v0.24.0/go.mod h1:2Q7sJY5mzlzWjKtYUEXSlBWCdyaioyXzRB2RtU8KVE8= golang.org/x/net v0.25.0/go.mod h1:JkAGAh7GEvH74S6FOH42FLoXpXbE/aqXSrIQjXgsiwM= -golang.org/x/net v0.34.0 h1:Mb7Mrk043xzHgnRM88suvJFwzVrRfHEHJEl5/71CKw0= -golang.org/x/net v0.34.0/go.mod h1:di0qlW3YNM5oh6GqDGQr92MyTozJPmybPK4Ev/Gm31k= +golang.org/x/net v0.39.0 h1:ZCu7HMWDxpXpaiKdhzIfaltL9Lp31x/3fCP11bc6/fY= +golang.org/x/net v0.39.0/go.mod h1:X7NRbYVEA+ewNkCNyJ513WmMdQ3BineSwVtN2zD/d+E= golang.org/x/oauth2 v0.25.0 h1:CY4y7XT9v0cRI9oupztF8AgiIu99L/ksR/Xp/6jrZ70= golang.org/x/oauth2 v0.25.0/go.mod h1:XYTD2NtWslqkgxebSiOHnXEap4TF09sJSc7H1sXbhtI= golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.12.0 h1:MHc5BpPuC30uJk597Ri8TV3CNZcTLu6B6z4lJy+g6Jw= -golang.org/x/sync v0.12.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA= +golang.org/x/sync v0.13.0 h1:AauUjRAJ9OSnvULf/ARrrVywoJDy0YS2AwQ98I37610= +golang.org/x/sync v0.13.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= @@ -304,8 +306,8 @@ golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.17.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/sys v0.19.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/sys v0.20.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= -golang.org/x/sys v0.31.0 h1:ioabZlmFYtWhL+TRYpcnNlLwhyxaM9kWTDEmfnprqik= -golang.org/x/sys v0.31.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k= +golang.org/x/sys v0.32.0 h1:s77OFDvIQeibCmezSnk/q6iAfkdiQaJi4VzroCFrN20= +golang.org/x/sys v0.32.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k= @@ -314,8 +316,8 @@ golang.org/x/term v0.8.0/go.mod h1:xPskH00ivmX89bAKVGSKKtLOWNx2+17Eiy94tnKShWo= golang.org/x/term v0.17.0/go.mod h1:lLRBjIVuehSbZlaOtGMbcMncT+aqLLLmKrsjNrUguwk= golang.org/x/term v0.19.0/go.mod h1:2CuTdWZ7KHSQwUzKva0cbMg6q2DMI3Mmxp+gKJbskEk= golang.org/x/term v0.20.0/go.mod h1:8UkIAJTvZgivsXaD6/pH6U9ecQzZ45awqEOzuCvwpFY= -golang.org/x/term v0.30.0 h1:PQ39fJZ+mfadBm0y5WlL4vlM7Sx1Hgf13sMIY2+QS9Y= -golang.org/x/term v0.30.0/go.mod h1:NYYFdzHoI5wRh/h5tDMdMqCqPJZEuNqVR5xJLd/n67g= +golang.org/x/term v0.31.0 h1:erwDkOK1Msy6offm1mOgvspSkslFnIGsFnxOKoufg3o= +golang.org/x/term v0.31.0/go.mod h1:R4BeIy7D95HzImkxGkTW1UQTtP54tio2RyHz7PwK0aw= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= @@ -323,8 +325,8 @@ golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8= golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= golang.org/x/text v0.15.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= -golang.org/x/text v0.23.0 h1:D71I7dUrlY+VX0gQShAThNGHFxZ13dGLBHQLVl1mJlY= -golang.org/x/text v0.23.0/go.mod h1:/BLNzu4aZCJ1+kcD0DNRotWKage4q2rGVAg4o22unh4= +golang.org/x/text v0.24.0 h1:dd5Bzh4yt5KYA8f9CJHCP4FB4D51c2c6JvN37xJJkJ0= +golang.org/x/text v0.24.0/go.mod h1:L8rBsPeo2pSS+xqN0d5u2ikmjtmoJbDBT1b7nHvFCdU= golang.org/x/time v0.8.0 h1:9i3RxcPv3PZnitoVGMPDKZSq1xW1gK1Xy3ArNOGZfEg= golang.org/x/time v0.8.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= @@ -343,8 +345,9 @@ google.golang.org/grpc v1.67.3/go.mod h1:YGaHCc6Oap+FzBJTZLBzkGSYt/cvGPFTPxkn7Qf google.golang.org/protobuf v1.36.1 h1:yBPeRvTftaleIgM3PZ/WBIZ7XM/eEYAaEyCwvyjq/gk= google.golang.org/protobuf v1.36.1/go.mod h1:9fA7Ob0pmnwhb644+1+CVWFRbNajQ6iRojtC/QF5bRE= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= -gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15 h1:YR8cESwS4TdDjEe65xsg0ogRM/Nc3DYOhEAlW+xobZo= gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY= gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ= diff --git a/internal/llm/agent/agent.go b/internal/llm/agent/agent.go index 998dc1551e8adc3bbc5facfb80338803a8f0afb0..b01ffec3cc312848450fec6e5c016ba377eeec7c 100644 --- a/internal/llm/agent/agent.go +++ b/internal/llm/agent/agent.go @@ -305,6 +305,7 @@ func (c *agent) generate(ctx context.Context, sessionID string, content string) assistantMsg, err := c.Messages.Create(sessionID, message.CreateMessageParams{ Role: message.Assistant, Parts: []message.ContentPart{}, + Model: c.model.ID, }) if err != nil { return err diff --git a/internal/llm/tools/bash.go b/internal/llm/tools/bash.go index 4e80ae60a3e4de34da6a800f339be22ab2439785..d20afb7f28629db045452a9687e6b00e8acf5931 100644 --- a/internal/llm/tools/bash.go +++ b/internal/llm/tools/bash.go @@ -5,6 +5,7 @@ import ( "encoding/json" "fmt" "strings" + "time" "github.com/kujtimiihoxha/termai/internal/config" "github.com/kujtimiihoxha/termai/internal/llm/tools/shell" @@ -21,6 +22,9 @@ type BashPermissionsParams struct { Timeout int `json:"timeout"` } +type BashToolResponseMetadata struct { + Took int64 `json:"took"` +} type bashTool struct { permissions permission.Service } @@ -272,11 +276,13 @@ func (b *bashTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error) return NewTextErrorResponse("permission denied"), nil } } + startTime := time.Now() shell := shell.GetPersistentShell(config.WorkingDirectory()) stdout, stderr, exitCode, interrupted, err := shell.Exec(ctx, params.Command, params.Timeout) if err != nil { return NewTextErrorResponse(fmt.Sprintf("error executing command: %s", err)), nil } + took := time.Since(startTime).Milliseconds() stdout = truncateOutput(stdout) stderr = truncateOutput(stderr) @@ -304,10 +310,13 @@ func (b *bashTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error) stdout += "\n" + errorMessage } + metadata := BashToolResponseMetadata{ + Took: took, + } if stdout == "" { - return NewTextResponse("no output"), nil + return WithResponseMetadata(NewTextResponse("no output"), metadata), nil } - return NewTextResponse(stdout), nil + return WithResponseMetadata(NewTextResponse(stdout), metadata), nil } func truncateOutput(content string) string { diff --git a/internal/llm/tools/tools.go b/internal/llm/tools/tools.go index e15c1c31f49484c41bf9621bebb7f4e10a466e57..6bb5286863128b26ca16a1d2965b3c4901be4471 100644 --- a/internal/llm/tools/tools.go +++ b/internal/llm/tools/tools.go @@ -1,6 +1,9 @@ package tools -import "context" +import ( + "context" + "encoding/json" +) type ToolInfo struct { Name string @@ -17,9 +20,10 @@ const ( ) type ToolResponse struct { - Type toolResponseType `json:"type"` - Content string `json:"content"` - IsError bool `json:"is_error"` + Type toolResponseType `json:"type"` + Content string `json:"content"` + Metadata string `json:"metadata,omitempty"` + IsError bool `json:"is_error"` } func NewTextResponse(content string) ToolResponse { @@ -29,6 +33,17 @@ func NewTextResponse(content string) ToolResponse { } } +func WithResponseMetadata(response ToolResponse, metadata any) ToolResponse { + if metadata != nil { + metadataBytes, err := json.Marshal(metadata) + if err != nil { + return response + } + response.Metadata = string(metadataBytes) + } + return response +} + func NewTextErrorResponse(content string) ToolResponse { return ToolResponse{ Type: ToolResponseTypeText, diff --git a/internal/message/content.go b/internal/message/content.go index cd263798b35e8fc1df3278dca1fff288c4a2806c..422c04f52ca0e5546986ea62b610f467502b08f1 100644 --- a/internal/message/content.go +++ b/internal/message/content.go @@ -3,6 +3,8 @@ package message import ( "encoding/base64" "time" + + "github.com/kujtimiihoxha/termai/internal/llm/models" ) type MessageRole string @@ -65,7 +67,6 @@ type ToolCall struct { Name string `json:"name"` Input string `json:"input"` Type string `json:"type"` - Metadata any `json:"metadata"` Finished bool `json:"finished"` } @@ -75,6 +76,7 @@ type ToolResult struct { ToolCallID string `json:"tool_call_id"` Name string `json:"name"` Content string `json:"content"` + Metadata string `json:"metadata"` IsError bool `json:"is_error"` } @@ -92,6 +94,7 @@ type Message struct { Role MessageRole SessionID string Parts []ContentPart + Model models.ModelID CreatedAt int64 UpdatedAt int64 diff --git a/internal/message/message.go b/internal/message/message.go index eeeb83ed2e8cebc6f69bc569cf0f9f2784879307..06dae13a57a8ae1ca7ede1e5fd22be6bea5ee669 100644 --- a/internal/message/message.go +++ b/internal/message/message.go @@ -155,6 +155,7 @@ func (s *service) fromDBItem(item db.Message) (Message, error) { SessionID: item.SessionID, Role: MessageRole(item.Role), Parts: parts, + Model: models.ModelID(item.Model.String), CreatedAt: item.CreatedAt, UpdatedAt: item.UpdatedAt, }, nil diff --git a/internal/tui/components/chat/editor.go b/internal/tui/components/chat/editor.go index df336818ce1b28fe76fed100c7eeba68bc4772f8..e87f1ffae79914fcacccac5dd5bc336c31a0b980 100644 --- a/internal/tui/components/chat/editor.go +++ b/internal/tui/components/chat/editor.go @@ -77,21 +77,20 @@ func (m *editorCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) { case AgentWorkingMsg: m.agentWorking = bool(msg) case tea.KeyMsg: - if key.Matches(msg, focusedKeyMaps.Send) { + // if the key does not match any binding, return + if m.textarea.Focused() && key.Matches(msg, focusedKeyMaps.Send) { return m, m.send() } - if key.Matches(msg, bluredKeyMaps.Send) { + if !m.textarea.Focused() && key.Matches(msg, bluredKeyMaps.Send) { return m, m.send() } - if key.Matches(msg, focusedKeyMaps.Blur) { + if m.textarea.Focused() && key.Matches(msg, focusedKeyMaps.Blur) { m.textarea.Blur() return m, util.CmdHandler(EditorFocusMsg(false)) } - if key.Matches(msg, bluredKeyMaps.Focus) { - if !m.textarea.Focused() { - m.textarea.Focus() - return m, tea.Batch(textarea.Blink, util.CmdHandler(EditorFocusMsg(true))) - } + if !m.textarea.Focused() && key.Matches(msg, bluredKeyMaps.Focus) { + m.textarea.Focus() + return m, tea.Batch(textarea.Blink, util.CmdHandler(EditorFocusMsg(true))) } } m.textarea, cmd = m.textarea.Update(msg) diff --git a/internal/tui/components/chat/messages.go b/internal/tui/components/chat/messages.go index 0a7e6e2a499e4d14ad19a327ba8de4dbfb1e59cd..b5a36139239b08f514a36d3a07d308f107a67520 100644 --- a/internal/tui/components/chat/messages.go +++ b/internal/tui/components/chat/messages.go @@ -1,16 +1,21 @@ package chat import ( + "encoding/json" "fmt" - "regexp" - "strconv" + "math" "strings" + "github.com/charmbracelet/bubbles/spinner" "github.com/charmbracelet/bubbles/viewport" tea "github.com/charmbracelet/bubbletea" "github.com/charmbracelet/glamour" "github.com/charmbracelet/lipgloss" + "github.com/charmbracelet/x/ansi" "github.com/kujtimiihoxha/termai/internal/app" + "github.com/kujtimiihoxha/termai/internal/llm/agent" + "github.com/kujtimiihoxha/termai/internal/llm/models" + "github.com/kujtimiihoxha/termai/internal/llm/tools" "github.com/kujtimiihoxha/termai/internal/message" "github.com/kujtimiihoxha/termai/internal/pubsub" "github.com/kujtimiihoxha/termai/internal/session" @@ -18,10 +23,20 @@ import ( "github.com/kujtimiihoxha/termai/internal/tui/util" ) +type uiMessageType int + +const ( + userMessageType uiMessageType = iota + assistantMessageType + toolMessageType +) + type uiMessage struct { - position int - height int - content string + ID string + messageType uiMessageType + position int + height int + content string } type messagesCmp struct { @@ -32,141 +47,116 @@ type messagesCmp struct { session session.Session messages []message.Message uiMessages []uiMessage - currentIndex int + currentMsgID string renderer *glamour.TermRenderer focusRenderer *glamour.TermRenderer cachedContent map[string]string + agentWorking bool + spinner spinner.Model + needsRerender bool + lastViewport string } func (m *messagesCmp) Init() tea.Cmd { - return m.viewport.Init() -} - -var ansiEscape = regexp.MustCompile("\x1b\\[[0-9;]*m") - -func hexToBgSGR(hex string) (string, error) { - hex = strings.TrimPrefix(hex, "#") - if len(hex) != 6 { - return "", fmt.Errorf("invalid hex color: must be 6 hexadecimal digits") - } - - // Parse RGB components in one block - rgb := make([]uint64, 3) - for i := 0; i < 3; i++ { - val, err := strconv.ParseUint(hex[i*2:i*2+2], 16, 8) - if err != nil { - return "", err - } - rgb[i] = val - } - - return fmt.Sprintf("48;2;%d;%d;%d", rgb[0], rgb[1], rgb[2]), nil -} - -func forceReplaceBackgroundColors(input string, newBg string) string { - return ansiEscape.ReplaceAllStringFunc(input, func(seq string) string { - // Extract content between "\x1b[" and "m" - content := seq[2 : len(seq)-1] - tokens := strings.Split(content, ";") - var newTokens []string - - // Skip background color tokens - for i := 0; i < len(tokens); i++ { - if tokens[i] == "" { - continue - } - - val, err := strconv.Atoi(tokens[i]) - if err != nil { - newTokens = append(newTokens, tokens[i]) - continue - } - - // Skip background color tokens - if val == 48 { - // Skip "48;5;N" or "48;2;R;G;B" sequences - if i+1 < len(tokens) { - if nextVal, err := strconv.Atoi(tokens[i+1]); err == nil { - switch nextVal { - case 5: - i += 2 // Skip "5" and color index - case 2: - i += 4 // Skip "2" and RGB components - } - } - } - } else if (val < 40 || val > 47) && (val < 100 || val > 107) && val != 49 { - // Keep non-background tokens - newTokens = append(newTokens, tokens[i]) - } - } - - // Add new background if provided - if newBg != "" { - newTokens = append(newTokens, strings.Split(newBg, ";")...) - } - - if len(newTokens) == 0 { - return "" - } - - return "\x1b[" + strings.Join(newTokens, ";") + "m" - }) + return tea.Batch(m.viewport.Init()) } func (m *messagesCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) { + var cmds []tea.Cmd switch msg := msg.(type) { + case AgentWorkingMsg: + m.agentWorking = bool(msg) + if m.agentWorking { + cmds = append(cmds, m.spinner.Tick) + } case EditorFocusMsg: m.writingMode = bool(msg) case SessionSelectedMsg: if msg.ID != m.session.ID { cmd := m.SetSession(msg) + m.needsRerender = true return m, cmd } return m, nil + case SessionClearedMsg: + m.session = session.Session{} + m.messages = make([]message.Message, 0) + m.currentMsgID = "" + m.needsRerender = true + return m, nil + + case tea.KeyMsg: + if m.writingMode { + return m, nil + } case pubsub.Event[message.Message]: if msg.Type == pubsub.CreatedEvent { if msg.Payload.SessionID == m.session.ID { // check if message exists + + messageExists := false for _, v := range m.messages { if v.ID == msg.Payload.ID { - return m, nil + messageExists = true + break } } - m.messages = append(m.messages, msg.Payload) - m.renderView() - m.viewport.GotoBottom() + if !messageExists { + m.messages = append(m.messages, msg.Payload) + delete(m.cachedContent, m.currentMsgID) + m.currentMsgID = msg.Payload.ID + m.needsRerender = true + } } for _, v := range m.messages { for _, c := range v.ToolCalls() { // the message is being added to the session of a tool called if c.ID == msg.Payload.SessionID { - m.renderView() - m.viewport.GotoBottom() + m.needsRerender = true } } } } else if msg.Type == pubsub.UpdatedEvent && msg.Payload.SessionID == m.session.ID { for i, v := range m.messages { if v.ID == msg.Payload.ID { + if !m.messages[i].IsFinished() && msg.Payload.IsFinished() && msg.Payload.FinishReason() == "end_turn" || msg.Payload.FinishReason() == "canceled" { + cmds = append(cmds, util.CmdHandler(AgentWorkingMsg(false))) + } m.messages[i] = msg.Payload delete(m.cachedContent, msg.Payload.ID) - m.renderView() - if i == len(m.messages)-1 { - m.viewport.GotoBottom() - } + m.needsRerender = true break } } } } + if m.agentWorking { + u, cmd := m.spinner.Update(msg) + m.spinner = u + cmds = append(cmds, cmd) + } + oldPos := m.viewport.YPosition u, cmd := m.viewport.Update(msg) m.viewport = u - return m, cmd + m.needsRerender = m.needsRerender || m.viewport.YPosition != oldPos + cmds = append(cmds, cmd) + if m.needsRerender { + m.renderView() + if len(m.messages) > 0 { + if msg, ok := msg.(pubsub.Event[message.Message]); ok { + if (msg.Type == pubsub.CreatedEvent) || + (msg.Type == pubsub.UpdatedEvent && msg.Payload.ID == m.messages[len(m.messages)-1].ID) { + m.viewport.GotoBottom() + } + } + } + m.needsRerender = false + } + return m, tea.Batch(cmds...) } -func (m *messagesCmp) renderUserMessage(inx int, msg message.Message) string { +func (m *messagesCmp) renderSimpleMessage(msg message.Message, info ...string) string { if v, ok := m.cachedContent[msg.ID]; ok { return v } @@ -178,7 +168,7 @@ func (m *messagesCmp) renderUserMessage(inx int, msg message.Message) string { BorderStyle(lipgloss.ThickBorder()) renderer := m.renderer - if inx == m.currentIndex { + if msg.ID == m.currentMsgID { style = style. Foreground(styles.Forground). BorderForeground(styles.Blue). @@ -186,33 +176,269 @@ func (m *messagesCmp) renderUserMessage(inx int, msg message.Message) string { renderer = m.focusRenderer } c, _ := renderer.Render(msg.Content().String()) - col, _ := hexToBgSGR(styles.Background.Dark) - rendered := style.Render(forceReplaceBackgroundColors(c, col)) + parts := []string{ + styles.ForceReplaceBackgroundWithLipgloss(c, styles.Background), + } + // remove newline at the end + parts[0] = strings.TrimSuffix(parts[0], "\n") + if len(info) > 0 { + parts = append(parts, info...) + } + rendered := style.Render( + lipgloss.JoinVertical( + lipgloss.Left, + parts..., + ), + ) m.cachedContent[msg.ID] = rendered return rendered } +func formatTimeDifference(unixTime1, unixTime2 int64) string { + diffSeconds := float64(math.Abs(float64(unixTime2 - unixTime1))) + + if diffSeconds < 60 { + return fmt.Sprintf("%.1fs", diffSeconds) + } + + minutes := int(diffSeconds / 60) + seconds := int(diffSeconds) % 60 + return fmt.Sprintf("%dm%ds", minutes, seconds) +} + +func (m *messagesCmp) renderToolCall(toolCall message.ToolCall, isNested bool) string { + key := "" + value := "" + switch toolCall.Name { + // TODO: add result data to the tools + case agent.AgentToolName: + key = "Task" + var params agent.AgentParams + json.Unmarshal([]byte(toolCall.Input), ¶ms) + value = params.Prompt + // TODO: handle nested calls + case tools.BashToolName: + key = "Bash" + var params tools.BashParams + json.Unmarshal([]byte(toolCall.Input), ¶ms) + value = params.Command + case tools.EditToolName: + key = "Edit" + var params tools.EditParams + json.Unmarshal([]byte(toolCall.Input), ¶ms) + value = params.FilePath + case tools.FetchToolName: + key = "Fetch" + var params tools.FetchParams + json.Unmarshal([]byte(toolCall.Input), ¶ms) + value = params.URL + case tools.GlobToolName: + key = "Glob" + var params tools.GlobParams + json.Unmarshal([]byte(toolCall.Input), ¶ms) + if params.Path == "" { + params.Path = "." + } + value = fmt.Sprintf("%s (%s)", params.Pattern, params.Path) + case tools.GrepToolName: + key = "Grep" + var params tools.GrepParams + json.Unmarshal([]byte(toolCall.Input), ¶ms) + if params.Path == "" { + params.Path = "." + } + value = fmt.Sprintf("%s (%s)", params.Pattern, params.Path) + case tools.LSToolName: + key = "Ls" + var params tools.LSParams + json.Unmarshal([]byte(toolCall.Input), ¶ms) + if params.Path == "" { + params.Path = "." + } + value = params.Path + case tools.SourcegraphToolName: + key = "Sourcegraph" + var params tools.SourcegraphParams + json.Unmarshal([]byte(toolCall.Input), ¶ms) + value = params.Query + case tools.ViewToolName: + key = "View" + var params tools.ViewParams + json.Unmarshal([]byte(toolCall.Input), ¶ms) + value = params.FilePath + case tools.WriteToolName: + key = "Write" + var params tools.WriteParams + json.Unmarshal([]byte(toolCall.Input), ¶ms) + value = params.FilePath + default: + key = toolCall.Name + var params map[string]any + json.Unmarshal([]byte(toolCall.Input), ¶ms) + jsonData, _ := json.Marshal(params) + value = string(jsonData) + } + + style := styles.BaseStyle. + Width(m.width). + BorderLeft(true). + BorderStyle(lipgloss.ThickBorder()). + PaddingLeft(1). + BorderForeground(styles.Yellow) + + keyStyle := styles.BaseStyle. + Foreground(styles.ForgroundDim) + valyeStyle := styles.BaseStyle. + Foreground(styles.Forground) + + if isNested { + valyeStyle = valyeStyle.Foreground(styles.ForgroundMid) + } + keyValye := keyStyle.Render( + fmt.Sprintf("%s: ", key), + ) + if !isNested { + value = valyeStyle. + Width(m.width - lipgloss.Width(keyValye) - 2). + Render( + ansi.Truncate( + value, + m.width-lipgloss.Width(keyValye)-2, + "...", + ), + ) + } else { + keyValye = keyStyle.Render( + fmt.Sprintf(" └ %s: ", key), + ) + value = valyeStyle. + Width(m.width - lipgloss.Width(keyValye) - 2). + Render( + ansi.Truncate( + value, + m.width-lipgloss.Width(keyValye)-2, + "...", + ), + ) + } + + innerToolCalls := make([]string, 0) + if toolCall.Name == agent.AgentToolName { + messages, _ := m.app.Messages.List(toolCall.ID) + toolCalls := make([]message.ToolCall, 0) + for _, v := range messages { + toolCalls = append(toolCalls, v.ToolCalls()...) + } + for _, v := range toolCalls { + call := m.renderToolCall(v, true) + innerToolCalls = append(innerToolCalls, call) + } + } + + if isNested { + return lipgloss.JoinHorizontal( + lipgloss.Left, + keyValye, + value, + ) + } + callContent := lipgloss.JoinHorizontal( + lipgloss.Left, + keyValye, + value, + ) + callContent = strings.ReplaceAll(callContent, "\n", "") + if len(innerToolCalls) > 0 { + callContent = lipgloss.JoinVertical( + lipgloss.Left, + callContent, + lipgloss.JoinVertical( + lipgloss.Left, + innerToolCalls..., + ), + ) + } + return style.Render(callContent) +} + +func (m *messagesCmp) renderAssistantMessage(msg message.Message) []uiMessage { + // find the user message that is before this assistant message + var userMsg message.Message + for i := len(m.messages) - 1; i >= 0; i-- { + if m.messages[i].Role == message.User { + userMsg = m.messages[i] + break + } + } + messages := make([]uiMessage, 0) + if msg.Content().String() != "" { + info := make([]string, 0) + if msg.IsFinished() && msg.FinishReason() == "end_turn" { + finish := msg.FinishPart() + took := formatTimeDifference(userMsg.CreatedAt, finish.Time) + + info = append(info, styles.BaseStyle.Width(m.width-1).Foreground(styles.ForgroundDim).Render( + fmt.Sprintf(" %s (%s)", models.SupportedModels[msg.Model].Name, took), + )) + } + content := m.renderSimpleMessage(msg, info...) + messages = append(messages, uiMessage{ + messageType: assistantMessageType, + position: 0, // gets updated in renderView + height: lipgloss.Height(content), + content: content, + }) + } + for _, v := range msg.ToolCalls() { + content := m.renderToolCall(v, false) + messages = append(messages, + uiMessage{ + messageType: toolMessageType, + position: 0, // gets updated in renderView + height: lipgloss.Height(content), + content: content, + }, + ) + } + + return messages +} + func (m *messagesCmp) renderView() { m.uiMessages = make([]uiMessage, 0) pos := 0 for _, v := range m.messages { - content := "" switch v.Role { case message.User: - content = m.renderUserMessage(pos, v) + content := m.renderSimpleMessage(v) + m.uiMessages = append(m.uiMessages, uiMessage{ + messageType: userMessageType, + position: pos, + height: lipgloss.Height(content), + content: content, + }) + pos += lipgloss.Height(content) + 1 // + 1 for spacing + case message.Assistant: + assistantMessages := m.renderAssistantMessage(v) + for _, msg := range assistantMessages { + msg.position = pos + m.uiMessages = append(m.uiMessages, msg) + pos += msg.height + 1 // + 1 for spacing + } + } - m.uiMessages = append(m.uiMessages, uiMessage{ - position: pos, - height: lipgloss.Height(content), - content: content, - }) - pos += lipgloss.Height(content) + 1 // + 1 for spacing } messages := make([]string, 0) for _, v := range m.uiMessages { - messages = append(messages, v.content) + messages = append(messages, v.content, + styles.BaseStyle. + Width(m.width). + Render( + "", + ), + ) } m.viewport.SetContent( styles.BaseStyle. @@ -246,7 +472,6 @@ func (m *messagesCmp) View() string { ) } - m.renderView() return styles.BaseStyle. Width(m.width). Render( @@ -260,15 +485,21 @@ func (m *messagesCmp) View() string { func (m *messagesCmp) help() string { text := "" + + if m.agentWorking { + text += styles.BaseStyle.Foreground(styles.PrimaryColor).Bold(true).Render( + fmt.Sprintf("%s %s ", m.spinner.View(), "Generating..."), + ) + } if m.writingMode { - text = lipgloss.JoinHorizontal( + text += lipgloss.JoinHorizontal( lipgloss.Left, styles.BaseStyle.Foreground(styles.ForgroundDim).Bold(true).Render("press "), styles.BaseStyle.Foreground(styles.Forground).Bold(true).Render("esc"), styles.BaseStyle.Foreground(styles.ForgroundDim).Bold(true).Render(" to exit writing mode"), ) } else { - text = lipgloss.JoinHorizontal( + text += lipgloss.JoinHorizontal( lipgloss.Left, styles.BaseStyle.Foreground(styles.ForgroundDim).Bold(true).Render("press "), styles.BaseStyle.Foreground(styles.Forground).Bold(true).Render("i"), @@ -306,7 +537,15 @@ func (m *messagesCmp) SetSize(width, height int) { glamour.WithWordWrap(width-1), ) m.focusRenderer = focusRenderer + // clear the cached content + for k := range m.cachedContent { + delete(m.cachedContent, k) + } m.renderer = renderer + if len(m.messages) > 0 { + m.renderView() + m.viewport.GotoBottom() + } } func (m *messagesCmp) GetSize() (int, int) { @@ -320,7 +559,8 @@ func (m *messagesCmp) SetSession(session session.Session) tea.Cmd { return util.ReportError(err) } m.messages = messages - m.messages = append(m.messages, m.messages[0]) + m.currentMsgID = m.messages[len(m.messages)-1].ID + m.needsRerender = true return nil } @@ -333,6 +573,9 @@ func NewMessagesCmp(app *app.App) tea.Model { glamour.WithStyles(styles.MarkdownTheme(false)), glamour.WithWordWrap(80), ) + + s := spinner.New() + s.Spinner = spinner.Pulse return &messagesCmp{ app: app, writingMode: true, @@ -340,5 +583,6 @@ func NewMessagesCmp(app *app.App) tea.Model { viewport: viewport.New(0, 0), focusRenderer: focusRenderer, renderer: renderer, + spinner: s, } } diff --git a/internal/tui/components/chat/sidebar.go b/internal/tui/components/chat/sidebar.go index 65c06f4a168e04a8bfc1bdff37368034518ea1dc..51192cf9a6e7a58c94251ae1f608d834a79a49f2 100644 --- a/internal/tui/components/chat/sidebar.go +++ b/internal/tui/components/chat/sidebar.go @@ -5,6 +5,7 @@ import ( tea "github.com/charmbracelet/bubbletea" "github.com/charmbracelet/lipgloss" + "github.com/kujtimiihoxha/termai/internal/pubsub" "github.com/kujtimiihoxha/termai/internal/session" "github.com/kujtimiihoxha/termai/internal/tui/styles" ) @@ -19,6 +20,14 @@ func (m *sidebarCmp) Init() tea.Cmd { } func (m *sidebarCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) { + switch msg := msg.(type) { + case pubsub.Event[session.Session]: + if msg.Type == pubsub.UpdatedEvent { + if m.session.ID == msg.Payload.ID { + m.session = msg.Payload + } + } + } return m, nil } @@ -45,7 +54,7 @@ func (m *sidebarCmp) sessionSection() string { sessionValue := styles.BaseStyle. Foreground(styles.Forground). Width(m.width - lipgloss.Width(sessionKey)). - Render(": New Session") + Render(fmt.Sprintf(": %s", m.session.Title)) return lipgloss.JoinHorizontal( lipgloss.Left, sessionKey, diff --git a/internal/tui/page/chat.go b/internal/tui/page/chat.go index 7ac0d2293f5b49bed1fee844a43319266c5c5263..a7a51bb844640ee0bd0e819d7a78531ac81fdc94 100644 --- a/internal/tui/page/chat.go +++ b/internal/tui/page/chat.go @@ -1,9 +1,10 @@ package page import ( + "github.com/charmbracelet/bubbles/key" tea "github.com/charmbracelet/bubbletea" "github.com/kujtimiihoxha/termai/internal/app" - "github.com/kujtimiihoxha/termai/internal/message" + "github.com/kujtimiihoxha/termai/internal/llm/agent" "github.com/kujtimiihoxha/termai/internal/session" "github.com/kujtimiihoxha/termai/internal/tui/components/chat" "github.com/kujtimiihoxha/termai/internal/tui/layout" @@ -18,8 +19,32 @@ type chatPage struct { session session.Session } +type ChatKeyMap struct { + NewSession key.Binding +} + +var keyMap = ChatKeyMap{ + NewSession: key.NewBinding( + key.WithKeys("ctrl+n"), + key.WithHelp("ctrl+n", "new session"), + ), +} + func (p *chatPage) Init() tea.Cmd { - return p.layout.Init() + // TODO: remove + cmds := []tea.Cmd{ + p.layout.Init(), + } + + sessions, _ := p.app.Sessions.List() + if len(sessions) > 0 { + p.session = sessions[0] + cmd := p.setSidebar() + cmds = append(cmds, util.CmdHandler(chat.SessionSelectedMsg(p.session)), cmd) + } + return tea.Batch( + cmds..., + ) } func (p *chatPage) Update(msg tea.Msg) (tea.Model, tea.Cmd) { @@ -31,6 +56,13 @@ func (p *chatPage) Update(msg tea.Msg) (tea.Model, tea.Cmd) { if cmd != nil { return p, cmd } + case tea.KeyMsg: + switch { + case key.Matches(msg, keyMap.NewSession): + p.session = session.Session{} + p.clearSidebar() + return p, util.CmdHandler(chat.SessionClearedMsg{}) + } } u, cmd := p.layout.Update(msg) p.layout = u.(layout.SplitPaneLayout) @@ -51,6 +83,12 @@ func (p *chatPage) setSidebar() tea.Cmd { return sidebarContainer.Init() } +func (p *chatPage) clearSidebar() { + p.layout.SetRightPanel(nil) + width, height := p.layout.GetSize() + p.layout.SetSize(width, height) +} + func (p *chatPage) sendMessage(text string) tea.Cmd { var cmds []tea.Cmd if p.session.ID == "" { @@ -66,15 +104,15 @@ func (p *chatPage) sendMessage(text string) tea.Cmd { } cmds = append(cmds, util.CmdHandler(chat.SessionSelectedMsg(session))) } - // TODO: actually call agent - p.app.Messages.Create(p.session.ID, message.CreateMessageParams{ - Role: message.User, - Parts: []message.ContentPart{ - message.TextContent{ - Text: text, - }, - }, - }) + // TODO: move this to a service + a, err := agent.NewCoderAgent(p.app) + if err != nil { + return util.ReportError(err) + } + go func() { + a.Generate(p.app.Context, p.session.ID, text) + }() + return tea.Batch(cmds...) } @@ -85,7 +123,7 @@ func (p *chatPage) View() string { func NewChatPage(app *app.App) tea.Model { messagesContainer := layout.NewContainer( chat.NewMessagesCmp(app), - layout.WithPadding(1, 1, 1, 1), + layout.WithPadding(1, 1, 0, 1), ) editorContainer := layout.NewContainer( diff --git a/internal/tui/styles/background.go b/internal/tui/styles/background.go new file mode 100644 index 0000000000000000000000000000000000000000..bf6cbc1059f81d54cda3e7c3de10b925ab11a160 --- /dev/null +++ b/internal/tui/styles/background.go @@ -0,0 +1,81 @@ +package styles + +import ( + "fmt" + "regexp" + "strconv" + "strings" + + "github.com/charmbracelet/lipgloss" +) + +var ansiEscape = regexp.MustCompile("\x1b\\[[0-9;]*m") + +func getColorRGB(c lipgloss.TerminalColor) (uint8, uint8, uint8) { + r, g, b, a := c.RGBA() + + // Un-premultiply alpha if needed + if a > 0 && a < 0xffff { + r = (r * 0xffff) / a + g = (g * 0xffff) / a + b = (b * 0xffff) / a + } + + // Convert from 16-bit to 8-bit color + return uint8(r >> 8), uint8(g >> 8), uint8(b >> 8) +} + +func ForceReplaceBackgroundWithLipgloss(input string, newBgColor lipgloss.TerminalColor) string { + r, g, b := getColorRGB(newBgColor) + + newBg := fmt.Sprintf("48;2;%d;%d;%d", r, g, b) + + return ansiEscape.ReplaceAllStringFunc(input, func(seq string) string { + // Extract content between "\x1b[" and "m" + content := seq[2 : len(seq)-1] + tokens := strings.Split(content, ";") + var newTokens []string + + // Skip background color tokens + for i := 0; i < len(tokens); i++ { + if tokens[i] == "" { + continue + } + + val, err := strconv.Atoi(tokens[i]) + if err != nil { + newTokens = append(newTokens, tokens[i]) + continue + } + + // Skip background color tokens + if val == 48 { + // Skip "48;5;N" or "48;2;R;G;B" sequences + if i+1 < len(tokens) { + if nextVal, err := strconv.Atoi(tokens[i+1]); err == nil { + switch nextVal { + case 5: + i += 2 // Skip "5" and color index + case 2: + i += 4 // Skip "2" and RGB components + } + } + } + } else if (val < 40 || val > 47) && (val < 100 || val > 107) && val != 49 { + // Keep non-background tokens + newTokens = append(newTokens, tokens[i]) + } + } + + // Add new background if provided + if newBg != "" { + newTokens = append(newTokens, strings.Split(newBg, ";")...) + } + + if len(newTokens) == 0 { + return "" + } + + return "\x1b[" + strings.Join(newTokens, ";") + "m" + }) +} diff --git a/internal/tui/styles/markdown.go b/internal/tui/styles/markdown.go index b4e71c51ef6615b9638156ac30129b1905669585..52816eab3ac0d104a0add3cf935c4327f21fc469 100644 --- a/internal/tui/styles/markdown.go +++ b/internal/tui/styles/markdown.go @@ -515,6 +515,7 @@ var ASCIIStyleConfig = ansi.StyleConfig{ Document: ansi.StyleBlock{ StylePrimitive: ansi.StylePrimitive{ BackgroundColor: stringPtr(Background.Dark), + Color: stringPtr(ForgroundDim.Dark), }, Indent: uintPtr(1), IndentToken: stringPtr(BaseStyle.Render(" ")), @@ -688,7 +689,7 @@ var DraculaStyleConfig = ansi.StyleConfig{ Heading: ansi.StyleBlock{ StylePrimitive: ansi.StylePrimitive{ BlockSuffix: "\n", - Color: stringPtr("#bd93f9"), + Color: stringPtr(PrimaryColor.Dark), Bold: boolPtr(true), BackgroundColor: stringPtr(Background.Dark), }, @@ -740,7 +741,7 @@ var DraculaStyleConfig = ansi.StyleConfig{ }, Strong: ansi.StylePrimitive{ Bold: boolPtr(true), - Color: stringPtr("#ffb86c"), + Color: stringPtr(Blue.Dark), BackgroundColor: stringPtr(Background.Dark), }, HorizontalRule: ansi.StylePrimitive{ @@ -796,7 +797,7 @@ var DraculaStyleConfig = ansi.StyleConfig{ CodeBlock: ansi.StyleCodeBlock{ StyleBlock: ansi.StyleBlock{ StylePrimitive: ansi.StylePrimitive{ - Color: stringPtr("#ffb86c"), + Color: stringPtr(Blue.Dark), BackgroundColor: stringPtr(Background.Dark), }, Margin: uintPtr(defaultMargin), diff --git a/internal/tui/styles/styles.go b/internal/tui/styles/styles.go index 41863cf1b79a4d36ce9c3d27bb87d1c2b2ddedc4..476339b57a72157de2dd2035fe89a8b11fe56860 100644 --- a/internal/tui/styles/styles.go +++ b/internal/tui/styles/styles.go @@ -34,6 +34,11 @@ var ( Light: "#d3d3d3", } + ForgroundMid = lipgloss.AdaptiveColor{ + Dark: "#a0a0a0", + Light: "#a0a0a0", + } + ForgroundDim = lipgloss.AdaptiveColor{ Dark: "#737373", Light: "#737373", @@ -159,6 +164,11 @@ var ( Light: light.Peach().Hex, } + Yellow = lipgloss.AdaptiveColor{ + Dark: dark.Yellow().Hex, + Light: light.Yellow().Hex, + } + Primary = Blue Secondary = Mauve From bd2ec29b65e430f83f430db5fdc424c7d631989d Mon Sep 17 00:00:00 2001 From: Kujtim Hoxha Date: Sat, 12 Apr 2025 18:45:36 +0200 Subject: [PATCH 05/41] add initial git support --- .gitignore | 1 + README.md | 13 + cmd/diff/main.go | 102 +++++++ cmd/git/main.go | 4 + cmd/root.go | 6 + go.mod | 20 +- go.sum | 57 +++- internal/assets/diff/themes/dark.json | 73 +++++ internal/assets/embed.go | 6 + internal/assets/write.go | 60 +++++ internal/git/diff.go | 265 +++++++++++++++++++ internal/llm/agent/agent.go | 3 + internal/llm/tools/bash.go | 4 +- internal/llm/tools/edit.go | 234 ++++++++-------- internal/llm/tools/edit_test.go | 48 ---- internal/llm/tools/file.go | 10 + internal/llm/tools/tools.go | 15 ++ internal/llm/tools/write.go | 29 +- internal/tui/components/dialog/permission.go | 17 +- 19 files changed, 791 insertions(+), 176 deletions(-) create mode 100644 cmd/diff/main.go create mode 100644 cmd/git/main.go create mode 100644 internal/assets/diff/themes/dark.json create mode 100644 internal/assets/embed.go create mode 100644 internal/assets/write.go create mode 100644 internal/git/diff.go diff --git a/.gitignore b/.gitignore index 894a451094749c5324c5eee6e583de1aed7f74d9..388f8b2cac6bb87ad9917595e84d5d98f007ad54 100644 --- a/.gitignore +++ b/.gitignore @@ -43,3 +43,4 @@ debug.log .termai +internal/assets/diff/index.mjs diff --git a/README.md b/README.md index 3c8df7345374e524900411604ba86114ec277ae2..ebef72cad3500d910217239a015c9f2a780de034 100644 --- a/README.md +++ b/README.md @@ -102,6 +102,9 @@ TermAI is built with a modular architecture: git clone https://github.com/kujtimiihoxha/termai.git cd termai +# Build the diff script first +go run cmd/diff/main.go + # Build go build -o termai @@ -109,6 +112,16 @@ go build -o termai ./termai ``` +### Important: Building the Diff Script + +Before building or running the application, you must first build the diff script by running: + +```bash +go run cmd/diff/main.go +``` + +This command generates the necessary JavaScript file (`index.mjs`) used by the diff functionality in the application. + ## Acknowledgments TermAI builds upon the work of several open source projects and developers: diff --git a/cmd/diff/main.go b/cmd/diff/main.go new file mode 100644 index 0000000000000000000000000000000000000000..da93e4660069912f4081495ab3d2eac8d93c0729 --- /dev/null +++ b/cmd/diff/main.go @@ -0,0 +1,102 @@ +package main + +import ( + "fmt" + "io" + "os" + "os/exec" + "path/filepath" +) + +func main() { + // Create a temporary directory + tempDir, err := os.MkdirTemp("", "git-split-diffs") + if err != nil { + fmt.Printf("Error creating temp directory: %v\n", err) + os.Exit(1) + } + defer func() { + fmt.Printf("Cleaning up temporary directory: %s\n", tempDir) + os.RemoveAll(tempDir) + }() + fmt.Printf("Created temporary directory: %s\n", tempDir) + + // Clone the repository with minimum depth + fmt.Println("Cloning git-split-diffs repository with minimum depth...") + cmd := exec.Command("git", "clone", "--depth=1", "https://github.com/kujtimiihoxha/git-split-diffs", tempDir) + cmd.Stdout = os.Stdout + cmd.Stderr = os.Stderr + if err := cmd.Run(); err != nil { + fmt.Printf("Error cloning repository: %v\n", err) + os.Exit(1) + } + + // Run npm install + fmt.Println("Running npm install...") + cmdNpmInstall := exec.Command("npm", "install") + cmdNpmInstall.Dir = tempDir + cmdNpmInstall.Stdout = os.Stdout + cmdNpmInstall.Stderr = os.Stderr + if err := cmdNpmInstall.Run(); err != nil { + fmt.Printf("Error running npm install: %v\n", err) + os.Exit(1) + } + + // Run npm run build + fmt.Println("Running npm run build...") + cmdNpmBuild := exec.Command("npm", "run", "build") + cmdNpmBuild.Dir = tempDir + cmdNpmBuild.Stdout = os.Stdout + cmdNpmBuild.Stderr = os.Stderr + if err := cmdNpmBuild.Run(); err != nil { + fmt.Printf("Error running npm run build: %v\n", err) + os.Exit(1) + } + + destDir := filepath.Join(".", "internal", "assets", "diff") + destFile := filepath.Join(destDir, "index.mjs") + + // Make sure the destination directory exists + if err := os.MkdirAll(destDir, 0o755); err != nil { + fmt.Printf("Error creating destination directory: %v\n", err) + os.Exit(1) + } + + // Copy the file + srcFile := filepath.Join(tempDir, "build", "index.mjs") + fmt.Printf("Copying %s to %s\n", srcFile, destFile) + if err := copyFile(srcFile, destFile); err != nil { + fmt.Printf("Error copying file: %v\n", err) + os.Exit(1) + } + + fmt.Println("Successfully completed the process!") +} + +// copyFile copies a file from src to dst +func copyFile(src, dst string) error { + sourceFile, err := os.Open(src) + if err != nil { + return err + } + defer sourceFile.Close() + + destFile, err := os.Create(dst) + if err != nil { + return err + } + defer destFile.Close() + + _, err = io.Copy(destFile, sourceFile) + if err != nil { + return err + } + + // Make sure the file is written to disk + err = destFile.Sync() + if err != nil { + return err + } + + return nil +} diff --git a/cmd/git/main.go b/cmd/git/main.go new file mode 100644 index 0000000000000000000000000000000000000000..da29a2cadf1e00b14b1a4bd0a52780888bf3e532 --- /dev/null +++ b/cmd/git/main.go @@ -0,0 +1,4 @@ +package main + +func main() { +} diff --git a/cmd/root.go b/cmd/root.go index bdab53e14969d3b4852d482e9c6443e596a03f87..d846a14c263790d9c27df4ad8aa151cc555c2e53 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -8,6 +8,7 @@ import ( tea "github.com/charmbracelet/bubbletea" "github.com/kujtimiihoxha/termai/internal/app" + "github.com/kujtimiihoxha/termai/internal/assets" "github.com/kujtimiihoxha/termai/internal/config" "github.com/kujtimiihoxha/termai/internal/db" "github.com/kujtimiihoxha/termai/internal/llm/agent" @@ -28,6 +29,9 @@ var rootCmd = &cobra.Command{ } debug, _ := cmd.Flags().GetBool("debug") err := config.Load(debug) + if err != nil { + return err + } cfg := config.Get() defaultLevel := slog.LevelInfo if cfg.Debug { @@ -38,9 +42,11 @@ var rootCmd = &cobra.Command{ })) slog.SetDefault(logger) + err = assets.WriteAssets() if err != nil { return err } + conn, err := db.Connect() if err != nil { return err diff --git a/go.mod b/go.mod index 3b8bd99b1f42db365b1c50e483e8c4ad1f559837..617dad3a1f541801dd8d2b8af1dd3254c2a69623 100644 --- a/go.mod +++ b/go.mod @@ -17,6 +17,7 @@ require ( github.com/charmbracelet/lipgloss v1.1.0 github.com/charmbracelet/x/ansi v0.8.0 github.com/fsnotify/fsnotify v1.8.0 + github.com/go-git/go-git/v5 v5.15.0 github.com/go-logfmt/logfmt v0.6.0 github.com/golang-migrate/migrate/v4 v4.18.2 github.com/google/generative-ai-go v0.19.0 @@ -45,6 +46,9 @@ require ( cloud.google.com/go/auth/oauth2adapt v0.2.6 // indirect cloud.google.com/go/compute/metadata v0.6.0 // indirect cloud.google.com/go/longrunning v0.5.7 // indirect + dario.cat/mergo v1.0.0 // indirect + github.com/Microsoft/go-winio v0.6.2 // indirect + github.com/ProtonMail/go-crypto v1.1.6 // indirect github.com/alecthomas/chroma/v2 v2.15.0 // indirect github.com/andybalholm/cascadia v1.3.2 // indirect github.com/atotto/clipboard v0.1.4 // indirect @@ -68,15 +72,20 @@ require ( github.com/charmbracelet/x/cellbuf v0.0.13-0.20250311204145-2c3ea96c31dd // indirect github.com/charmbracelet/x/exp/strings v0.0.0-20240722160745-212f7b056ed0 // indirect github.com/charmbracelet/x/term v0.2.1 // indirect + github.com/cloudflare/circl v1.6.1 // indirect + github.com/cyphar/filepath-securejoin v0.4.1 // indirect github.com/davecgh/go-spew v1.1.1 // indirect github.com/dlclark/regexp2 v1.11.4 // indirect github.com/dustin/go-humanize v1.0.1 // indirect + github.com/emirpasic/gods v1.18.1 // indirect github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f // indirect github.com/felixge/httpsnoop v1.0.4 // indirect + github.com/go-git/gcfg v1.5.1-0.20230307220236-3a3c6141e376 // indirect + github.com/go-git/go-billy/v5 v5.6.2 // indirect github.com/go-logr/logr v1.4.2 // indirect github.com/go-logr/stdr v1.2.2 // indirect github.com/go-viper/mapstructure/v2 v2.2.1 // indirect - github.com/google/go-cmp v0.7.0 // indirect + github.com/golang/groupcache v0.0.0-20241129210726-2c02b8208cf8 // indirect github.com/google/s2a-go v0.1.8 // indirect github.com/googleapis/enterprise-certificate-proxy v0.3.4 // indirect github.com/googleapis/gax-go/v2 v2.14.1 // indirect @@ -84,6 +93,8 @@ require ( github.com/hashicorp/errwrap v1.1.0 // indirect github.com/hashicorp/go-multierror v1.1.1 // indirect github.com/inconshreveable/mousetrap v1.1.0 // indirect + github.com/jbenet/go-context v0.0.0-20150711004518-d14ea06fba99 // indirect + github.com/kevinburke/ssh_config v1.2.0 // indirect github.com/lucasb-eyer/go-colorful v1.2.0 // indirect github.com/mattn/go-isatty v0.0.20 // indirect github.com/mattn/go-localereader v0.0.1 // indirect @@ -91,11 +102,12 @@ require ( github.com/mitchellh/hashstructure/v2 v2.0.2 // indirect github.com/muesli/cancelreader v0.2.2 // indirect github.com/pelletier/go-toml/v2 v2.2.3 // indirect + github.com/pjbgf/sha1cd v0.3.2 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect github.com/rivo/uniseg v0.4.7 // indirect - github.com/rogpeppe/go-internal v1.14.1 // indirect github.com/sagikazarmark/locafero v0.7.0 // indirect github.com/sahilm/fuzzy v0.1.1 // indirect + github.com/skeema/knownhosts v1.3.1 // indirect github.com/sourcegraph/conc v0.3.0 // indirect github.com/spf13/afero v1.12.0 // indirect github.com/spf13/cast v1.7.1 // indirect @@ -105,6 +117,7 @@ require ( github.com/tidwall/match v1.1.1 // indirect github.com/tidwall/pretty v1.2.1 // indirect github.com/tidwall/sjson v1.2.5 // indirect + github.com/xanzy/ssh-agent v0.3.3 // indirect github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e // indirect github.com/yosida95/uritemplate/v3 v3.0.2 // indirect github.com/yuin/goldmark v1.7.8 // indirect @@ -118,7 +131,6 @@ require ( go.uber.org/multierr v1.9.0 // indirect golang.design/x/clipboard v0.7.0 // indirect golang.org/x/crypto v0.37.0 // indirect - golang.org/x/exp v0.0.0-20240719175910-8a7402abbf56 // indirect golang.org/x/exp/shiny v0.0.0-20250305212735-054e65f0b394 // indirect golang.org/x/image v0.14.0 // indirect golang.org/x/mobile v0.0.0-20231127183840-76ac6878050a // indirect @@ -132,6 +144,6 @@ require ( google.golang.org/genproto/googleapis/rpc v0.0.0-20241223144023-3abc09e42ca8 // indirect google.golang.org/grpc v1.67.3 // indirect google.golang.org/protobuf v1.36.1 // indirect - gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c // indirect + gopkg.in/warnings.v0 v0.1.2 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/go.sum b/go.sum index 08e7e7c42e61f69fd454df0683507bb91ddd3cd9..9c2c2df8fbcb0d909a73a67ead79339bf3814892 100644 --- a/go.sum +++ b/go.sum @@ -10,10 +10,17 @@ cloud.google.com/go/compute/metadata v0.6.0 h1:A6hENjEsCDtC1k8byVsgwvVcioamEHvZ4 cloud.google.com/go/compute/metadata v0.6.0/go.mod h1:FjyFAW1MW0C203CEOMDTu3Dk1FlqW3Rga40jzHL4hfg= cloud.google.com/go/longrunning v0.5.7 h1:WLbHekDbjK1fVFD3ibpFFVoyizlLRl73I7YKuAKilhU= cloud.google.com/go/longrunning v0.5.7/go.mod h1:8GClkudohy1Fxm3owmBGid8W0pSgodEMwEAztp38Xng= +dario.cat/mergo v1.0.0 h1:AGCNq9Evsj31mOgNPcLyXc+4PNABt905YmuqPYYpBWk= +dario.cat/mergo v1.0.0/go.mod h1:uNxQE+84aUszobStD9th8a29P2fMDhsBdgRYvZOxGmk= github.com/JohannesKaufmann/html-to-markdown v1.6.0 h1:04VXMiE50YYfCfLboJCLcgqF5x+rHJnb1ssNmqpLH/k= github.com/JohannesKaufmann/html-to-markdown v1.6.0/go.mod h1:NUI78lGg/a7vpEJTz/0uOcYMaibytE4BUOQS8k78yPQ= github.com/MakeNowJust/heredoc v1.0.0 h1:cXCdzVdstXyiTqTvfqk9SDHpKNjxuom+DOlyEeQ4pzQ= github.com/MakeNowJust/heredoc v1.0.0/go.mod h1:mG5amYoWBHf8vpLOuehzbGGw0EHxpZZ6lCpQ4fNJ8LE= +github.com/Microsoft/go-winio v0.5.2/go.mod h1:WpS1mjBmmwHBEWmogvA2mj8546UReBk4v8QkMxJ6pZY= +github.com/Microsoft/go-winio v0.6.2 h1:F2VQgta7ecxGYO8k3ZZz3RS8fVIXVxONVUPlNERoyfY= +github.com/Microsoft/go-winio v0.6.2/go.mod h1:yd8OoFMLzJbo9gZq8j5qaps8bJ9aShtEA8Ipt1oGCvU= +github.com/ProtonMail/go-crypto v1.1.6 h1:ZcV+Ropw6Qn0AX9brlQLAUXfqLBc7Bl+f/DmNxpLfdw= +github.com/ProtonMail/go-crypto v1.1.6/go.mod h1:rA3QumHc/FZ8pAHreoekgiAbzpNsfQAosU5td4SnOrE= github.com/PuerkitoBio/goquery v1.9.2 h1:4/wZksC3KgkQw7SQgkKotmKljk0M6V8TUvA8Wb4yPeE= github.com/PuerkitoBio/goquery v1.9.2/go.mod h1:GHPCaP0ODyyxqcNoFGYlAprUFH81NuRPd0GX3Zu2Mvk= github.com/alecthomas/assert/v2 v2.11.0 h1:2Q9r3ki8+JYXvGsDyBXwH3LcJ+WK5D0gc5E8vS6K3D0= @@ -24,8 +31,12 @@ github.com/alecthomas/repr v0.4.0 h1:GhI2A8MACjfegCPVq9f1FLvIBS+DrQ2KQBFZP1iFzXc github.com/alecthomas/repr v0.4.0/go.mod h1:Fr0507jx4eOXV7AlPV6AVZLYrLIuIeSOWtW57eE/O/4= github.com/andybalholm/cascadia v1.3.2 h1:3Xi6Dw5lHF15JtdcmAHD3i1+T8plmv7BQ/nsViSLyss= github.com/andybalholm/cascadia v1.3.2/go.mod h1:7gtRlve5FxPPgIgX36uWBX58OdBsSS6lUvCFb+h7KvU= +github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be h1:9AeTilPcZAjCFIImctFaOjnTIavg87rW78vTPkQqLI8= +github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be/go.mod h1:ySMOLuWl6zY27l47sB3qLNK6tF2fkHG55UZxx8oIVo4= github.com/anthropics/anthropic-sdk-go v0.2.0-beta.2 h1:h7qxtumNjKPWFv1QM/HJy60MteeW23iKeEtBoY7bYZk= github.com/anthropics/anthropic-sdk-go v0.2.0-beta.2/go.mod h1:AapDW22irxK2PSumZiQXYUFvsdQgkwIWlpESweWZI/c= +github.com/armon/go-socks5 v0.0.0-20160902184237-e75332964ef5 h1:0CwZNZbxp69SHPdPJAN/hZIm0C4OItdklCFmMRWYpio= +github.com/armon/go-socks5 v0.0.0-20160902184237-e75332964ef5/go.mod h1:wHh0iHkYZB8zMSxRWpUBQtwG5a7fFgvEO+odwuTv2gs= github.com/atotto/clipboard v0.1.4 h1:EH0zSVneZPSuFR11BlR9YppQTVDbh5+16AmcJi4g1z4= github.com/atotto/clipboard v0.1.4/go.mod h1:ZY9tmq7sm5xIbd9bOK4onWV4S6X0u6GY7Vn0Yu86PYI= github.com/aws/aws-sdk-go-v2 v1.30.3 h1:jUeBtG0Ih+ZIFH0F4UkmL9w3cSpaMv9tYYDbzILP8dY= @@ -88,7 +99,11 @@ github.com/charmbracelet/x/exp/strings v0.0.0-20240722160745-212f7b056ed0 h1:qko github.com/charmbracelet/x/exp/strings v0.0.0-20240722160745-212f7b056ed0/go.mod h1:pBhA0ybfXv6hDjQUZ7hk1lVxBiUbupdw5R31yPUViVQ= github.com/charmbracelet/x/term v0.2.1 h1:AQeHeLZ1OqSXhrAWpYUtZyX1T3zVxfpZuEQMIQaGIAQ= github.com/charmbracelet/x/term v0.2.1/go.mod h1:oQ4enTYFV7QN4m0i9mzHrViD7TQKvNEEkHUMCmsxdUg= +github.com/cloudflare/circl v1.6.1 h1:zqIqSPIndyBh1bjLVVDHMPpVKqp8Su/V+6MeDzzQBQ0= +github.com/cloudflare/circl v1.6.1/go.mod h1:uddAzsPgqdMAYatqJ0lsjX1oECcQLIlRpzZh3pJrofs= github.com/cpuguy83/go-md2man/v2 v2.0.6/go.mod h1:oOW0eioCTA6cOiMLiUPZOpcVxMig6NIQQ7OS05n1F4g= +github.com/cyphar/filepath-securejoin v0.4.1 h1:JyxxyPEaktOD+GAnqIqTf9A8tHyAG22rowi7HkoSU1s= +github.com/cyphar/filepath-securejoin v0.4.1/go.mod h1:Sdj7gXlvMcPZsbhwhQ33GguGLDGQL7h7bg04C/+u9jI= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= @@ -96,6 +111,10 @@ github.com/dlclark/regexp2 v1.11.4 h1:rPYF9/LECdNymJufQKmri9gV604RvvABwgOA8un7yA github.com/dlclark/regexp2 v1.11.4/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8= github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY= github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto= +github.com/elazarl/goproxy v1.7.2 h1:Y2o6urb7Eule09PjlhQRGNsqRfPmYI3KKQLFpCAV3+o= +github.com/elazarl/goproxy v1.7.2/go.mod h1:82vkLNir0ALaW14Rc399OTTjyNREgmdL2cVoIbS6XaE= +github.com/emirpasic/gods v1.18.1 h1:FXtiHYKDGKCW2KzwZKx0iC0PQmdlorYgdFG9jPXJ1Bc= +github.com/emirpasic/gods v1.18.1/go.mod h1:8tpGGwCnJ5H4r6BWwaV6OrWmMoPhUl5jm/FMNAnJvWQ= github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f h1:Y/CXytFA4m6baUTXGLOoWe4PQhGxaX0KpnayAqC48p4= github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f/go.mod h1:vw97MGsxSvLiUE2X8qFplwetxpGLQrlU1Q9AUEIzCaM= github.com/felixge/httpsnoop v1.0.4 h1:NFTV2Zj1bL4mc9sqWACXbQFVBBg2W3GPvqp8/ESS2Wg= @@ -104,6 +123,16 @@ github.com/frankban/quicktest v1.14.6 h1:7Xjx+VpznH+oBnejlPUj8oUpdxnVs4f8XU8WnHk github.com/frankban/quicktest v1.14.6/go.mod h1:4ptaffx2x8+WTWXmUCuVU6aPUX1/Mz7zb5vbUoiM6w0= github.com/fsnotify/fsnotify v1.8.0 h1:dAwr6QBTBZIkG8roQaJjGof0pp0EeF+tNV7YBP3F/8M= github.com/fsnotify/fsnotify v1.8.0/go.mod h1:8jBTzvmWwFyi3Pb8djgCCO5IBqzKJ/Jwo8TRcHyHii0= +github.com/gliderlabs/ssh v0.3.8 h1:a4YXD1V7xMF9g5nTkdfnja3Sxy1PVDCj1Zg4Wb8vY6c= +github.com/gliderlabs/ssh v0.3.8/go.mod h1:xYoytBv1sV0aL3CavoDuJIQNURXkkfPA/wxQ1pL1fAU= +github.com/go-git/gcfg v1.5.1-0.20230307220236-3a3c6141e376 h1:+zs/tPmkDkHx3U66DAb0lQFJrpS6731Oaa12ikc+DiI= +github.com/go-git/gcfg v1.5.1-0.20230307220236-3a3c6141e376/go.mod h1:an3vInlBmSxCcxctByoQdvwPiA7DTK7jaaFDBTtu0ic= +github.com/go-git/go-billy/v5 v5.6.2 h1:6Q86EsPXMa7c3YZ3aLAQsMA0VlWmy43r6FHqa/UNbRM= +github.com/go-git/go-billy/v5 v5.6.2/go.mod h1:rcFC2rAsp/erv7CMz9GczHcuD0D32fWzH+MJAU+jaUU= +github.com/go-git/go-git-fixtures/v4 v4.3.2-0.20231010084843-55a94097c399 h1:eMje31YglSBqCdIqdhKBW8lokaMrL3uTkpGYlE2OOT4= +github.com/go-git/go-git-fixtures/v4 v4.3.2-0.20231010084843-55a94097c399/go.mod h1:1OCfN199q1Jm3HZlxleg+Dw/mwps2Wbk9frAWm+4FII= +github.com/go-git/go-git/v5 v5.15.0 h1:f5Qn0W0F7ry1iN0ZwIU5m/n7/BKB4hiZfc+zlZx7ly0= +github.com/go-git/go-git/v5 v5.15.0/go.mod h1:4Ge4alE/5gPs30F2H1esi2gPd69R0C39lolkucHBOp8= github.com/go-logfmt/logfmt v0.6.0 h1:wGYYu3uicYdqXVgoYbvnkrPVXkuLM1p1ifugDMEdRi4= github.com/go-logfmt/logfmt v0.6.0/go.mod h1:WYhtIu8zTZfxdn5+rREduYbwxfcBr/Vr6KEVveWlfTs= github.com/go-logr/logr v1.2.2/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A= @@ -115,6 +144,8 @@ github.com/go-viper/mapstructure/v2 v2.2.1 h1:ZAaOCxANMuZx5RCeg0mBdEZk7DZasvvZIx github.com/go-viper/mapstructure/v2 v2.2.1/go.mod h1:oJDH3BJKyqBA2TXFhDsKDGDTlndYOZ6rGS0BRZIxGhM= github.com/golang-migrate/migrate/v4 v4.18.2 h1:2VSCMz7x7mjyTXx3m2zPokOY82LTRgxK1yQYKo6wWQ8= github.com/golang-migrate/migrate/v4 v4.18.2/go.mod h1:2CM6tJvn2kqPXwnXO/d3rAQYiyoIm180VsO8PRX6Rpk= +github.com/golang/groupcache v0.0.0-20241129210726-2c02b8208cf8 h1:f+oWsMOmNPc8JmEHVZIycC7hBoQxHH9pNKQORJNozsQ= +github.com/golang/groupcache v0.0.0-20241129210726-2c02b8208cf8/go.mod h1:wcDNUvekVysuuOpQKo3191zZyTpiI6se1N1ULghS0sw= github.com/google/generative-ai-go v0.19.0 h1:R71szggh8wHMCUlEMsW2A/3T+5LdEIkiaHSYgSpUgdg= github.com/google/generative-ai-go v0.19.0/go.mod h1:JYolL13VG7j79kM5BtHz4qwONHkeJQzOCkKXnpqtS/E= github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= @@ -138,8 +169,11 @@ github.com/hexops/gotextdiff v1.0.3 h1:gitA9+qJrrTCsiCl7+kh75nPqQt1cx4ZkudSTLoUq github.com/hexops/gotextdiff v1.0.3/go.mod h1:pSWU5MAI3yDq+fZBTazCSJysOMbxWL1BSow5/V2vxeg= github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8= github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw= +github.com/jbenet/go-context v0.0.0-20150711004518-d14ea06fba99 h1:BQSFePA1RWJOlocH6Fxy8MmwDt+yVQYULKfN0RoTN8A= +github.com/jbenet/go-context v0.0.0-20150711004518-d14ea06fba99/go.mod h1:1lJo3i6rXxKeerYnT8Nvf0QmHCRC1n8sfWVwXF2Frvo= +github.com/kevinburke/ssh_config v1.2.0 h1:x584FjTGwHzMwvHx18PXxbBVzfnxogHaAReU4gf13a4= +github.com/kevinburke/ssh_config v1.2.0/go.mod h1:CT57kijsi8u/K/BOFA39wgDQJ9CxiF4nAY/ojJ6r6mM= github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= -github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI= github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= @@ -179,11 +213,17 @@ github.com/muesli/reflow v0.3.0 h1:IFsN6K9NfGtjeggFP+68I4chLZV2yIKsXJFNZ+eWh6s= github.com/muesli/reflow v0.3.0/go.mod h1:pbwTDkVPibjO2kyvBQRBxTWEEGDGq0FlB1BIKtnHY/8= github.com/muesli/termenv v0.16.0 h1:S5AlUN9dENB57rsbnkPyfdGuWIlkmzJjbFf0Tf5FWUc= github.com/muesli/termenv v0.16.0/go.mod h1:ZRfOIKPFDYQoDFF4Olj7/QJbW60Ol/kL1pU3VfY/Cnk= +github.com/onsi/gomega v1.34.1 h1:EUMJIKUjM8sKjYbtxQI9A4z2o+rruxnzNvpknOXie6k= +github.com/onsi/gomega v1.34.1/go.mod h1:kU1QgUvBDLXBJq618Xvm2LUX6rSAfRaFRTcdOeDLwwY= github.com/openai/openai-go v0.1.0-beta.2 h1:Ra5nCFkbEl9w+UJwAciC4kqnIBUCcJazhmMA0/YN894= github.com/openai/openai-go v0.1.0-beta.2/go.mod h1:g461MYGXEXBVdV5SaR/5tNzNbSfwTBBefwc+LlDCK0Y= github.com/pelletier/go-toml/v2 v2.2.3 h1:YmeHyLY8mFWbdkNWwpr+qIL2bEqT0o95WSdkNHvL12M= github.com/pelletier/go-toml/v2 v2.2.3/go.mod h1:MfCQTFTvCcUyyvvwm1+G6H/jORL20Xlb6rzQu9GuUkc= +github.com/pjbgf/sha1cd v0.3.2 h1:a9wb0bp1oC2TGwStyn0Umc/IGKQnEgF0vVaZ8QF8eo4= +github.com/pjbgf/sha1cd v0.3.2/go.mod h1:zQWigSxVmsHEZow5qaLtPYxpcKMMQpa09ixqBxuCS6A= github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= +github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= +github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/rivo/uniseg v0.1.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc= @@ -203,6 +243,9 @@ github.com/sergi/go-diff v1.0.0/go.mod h1:0CfEIISq7TuYL3j771MWULgwwjU+GofnZX9QAm github.com/sergi/go-diff v1.3.1/go.mod h1:aMJSSKb2lpPvRNec0+w3fl7LP9IOFzdc9Pa4NFbPK1I= github.com/sergi/go-diff v1.3.2-0.20230802210424-5b0b94c5c0d3 h1:n661drycOFuPLCN3Uc8sB6B/s6Z4t2xvBgU1htSHuq8= github.com/sergi/go-diff v1.3.2-0.20230802210424-5b0b94c5c0d3/go.mod h1:A0bzQcvG0E7Rwjx0REVgAGH58e96+X0MeOfepqsbeW4= +github.com/sirupsen/logrus v1.7.0/go.mod h1:yWOB1SBYBC5VeMP7gHvWumXLIWorT60ONWic61uBYv0= +github.com/skeema/knownhosts v1.3.1 h1:X2osQ+RAjK76shCbvhHHHVl3ZlgDm8apHEHFqRjnBY8= +github.com/skeema/knownhosts v1.3.1/go.mod h1:r7KTdC8l4uxWRyK2TpQZ/1o5HaSzh06ePQNxPwTcfiY= github.com/sourcegraph/conc v0.3.0 h1:OQTbbt6P72L20UqAkXXuLOj79LfEanQ+YQFNpLA9ySo= github.com/sourcegraph/conc v0.3.0/go.mod h1:Sdozi7LEKbFPqYX2/J+iBAM6HpqSLTASQIKqDmF7Mt0= github.com/spf13/afero v1.12.0 h1:UcOPyRBYczmFn6yvphxkn9ZEOY65cpwGKb5mL36mrqs= @@ -216,6 +259,7 @@ github.com/spf13/pflag v1.0.6/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An github.com/spf13/viper v1.20.0 h1:zrxIyR3RQIOsarIrgL8+sAvALXul9jeEPa06Y0Ph6vY= github.com/spf13/viper v1.20.0/go.mod h1:P9Mdzt1zoHIG8m2eZQinpiBjo6kCmZSKBClNNqjJvu4= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA= @@ -232,6 +276,8 @@ github.com/tidwall/pretty v1.2.1 h1:qjsOFOWWQl+N3RsoF5/ssm1pHmJJwhjlSbZ51I6wMl4= github.com/tidwall/pretty v1.2.1/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU= github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY= github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28= +github.com/xanzy/ssh-agent v0.3.3 h1:+/15pJfg/RsTxqYcX6fHqOXZwwMP+2VyYWJeWM2qQFM= +github.com/xanzy/ssh-agent v0.3.3/go.mod h1:6dzNDKs0J9rVPHPhaGCukekBHKqfl+L3KghI1Bc68Uw= github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e h1:JVG44RsyaB9T2KIHavMF/ppJZNG9ZpyihvCd0w101no= github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e/go.mod h1:RbqR21r5mrJuqunuUZ/Dhy/avygyECGrLceyNeo4LiM= github.com/yosida95/uritemplate/v3 v3.0.2 h1:Ed3Oyj9yrmi9087+NczuL5BwkIc4wvTb5zIM+UJPGz4= @@ -260,6 +306,7 @@ golang.design/x/clipboard v0.7.0 h1:4Je8M/ys9AJumVnl8m+rZnIvstSnYj1fvzqYrU3TXvo= golang.design/x/clipboard v0.7.0/go.mod h1:PQIvqYO9GP29yINEfsEn5zSQKAz3UgXmZKzDA6dnq2E= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= +golang.org/x/crypto v0.0.0-20220622213112-05595931fe9d/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= golang.org/x/crypto v0.19.0/go.mod h1:Iy9bg/ha4yyC70EfRS8jz+B6ybOBKMaSxLj6P6oBDfU= golang.org/x/crypto v0.22.0/go.mod h1:vr6Su+7cTlO45qkww3VDJlzDn0ctJvRgYbC2NvXHt+M= golang.org/x/crypto v0.23.0/go.mod h1:CKFgDieR+mRhux2Lsu27y0fO304Db0wZe70UKqHu0v8= @@ -277,6 +324,7 @@ golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91 golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= +golang.org/x/net v0.0.0-20211112202133-69e39bad7dc2/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c= golang.org/x/net v0.6.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs= golang.org/x/net v0.9.0/go.mod h1:d48xBJpPfHeWQsugry2m+kC02ZBRGRgulfHnEXEuWns= @@ -294,10 +342,14 @@ golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.13.0 h1:AauUjRAJ9OSnvULf/ARrrVywoJDy0YS2AwQ98I37610= golang.org/x/sync v0.13.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20191026070338-33540a1f6037/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210124154548-22da62e12c0c/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20210809222454-d867a43fc93e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= @@ -320,6 +372,7 @@ golang.org/x/term v0.31.0 h1:erwDkOK1Msy6offm1mOgvspSkslFnIGsFnxOKoufg3o= golang.org/x/term v0.31.0/go.mod h1:R4BeIy7D95HzImkxGkTW1UQTtP54tio2RyHz7PwK0aw= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= +golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8= @@ -348,6 +401,8 @@ gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8 gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= +gopkg.in/warnings.v0 v0.1.2 h1:wFXVbFY8DY5/xOe1ECiWdKCzZlxgshcYVNkBHstARME= +gopkg.in/warnings.v0 v0.1.2/go.mod h1:jksf8JmL6Qr/oQM2OXTHunEvvTAsrWBLb6OOjuVWRNI= gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY= gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ= diff --git a/internal/assets/diff/themes/dark.json b/internal/assets/diff/themes/dark.json new file mode 100644 index 0000000000000000000000000000000000000000..05c18e08c327178af34397cd2aafe736cc99a93a --- /dev/null +++ b/internal/assets/diff/themes/dark.json @@ -0,0 +1,73 @@ +{ + "SYNTAX_HIGHLIGHTING_THEME": "dark-plus", + "DEFAULT_COLOR": { + "color": "#ffffff", + "backgroundColor": "#212121" + }, + "COMMIT_HEADER_COLOR": { + "color": "#cccccc" + }, + "COMMIT_HEADER_LABEL_COLOR": { + "color": "#00000022" + }, + "COMMIT_SHA_COLOR": { + "color": "#00eeaa" + }, + "COMMIT_AUTHOR_COLOR": { + "color": "#00aaee" + }, + "COMMIT_DATE_COLOR": { + "color": "#cccccc" + }, + "COMMIT_MESSAGE_COLOR": { + "color": "#cccccc" + }, + "COMMIT_TITLE_COLOR": { + "modifiers": [ + "bold" + ] + }, + "FILE_NAME_COLOR": { + "color": "#ffdd99" + }, + "BORDER_COLOR": { + "color": "#ffdd9966", + "modifiers": [ + "dim" + ] + }, + "HUNK_HEADER_COLOR": { + "modifiers": [ + "dim" + ] + }, + "DELETED_WORD_COLOR": { + "color": "#ffcccc", + "backgroundColor": "#ff000033" + }, + "INSERTED_WORD_COLOR": { + "color": "#ccffcc", + "backgroundColor": "#00ff0033" + }, + "DELETED_LINE_NO_COLOR": { + "color": "#00000022", + "backgroundColor": "#00000022" + }, + "INSERTED_LINE_NO_COLOR": { + "color": "#00000022", + "backgroundColor": "#00000022" + }, + "UNMODIFIED_LINE_NO_COLOR": { + "color": "#666666" + }, + "DELETED_LINE_COLOR": { + "color": "#cc6666", + "backgroundColor": "#3a3030" + }, + "INSERTED_LINE_COLOR": { + "color": "#66cc66", + "backgroundColor": "#303a30" + }, + "UNMODIFIED_LINE_COLOR": {}, + "MISSING_LINE_COLOR": {} +} diff --git a/internal/assets/embed.go b/internal/assets/embed.go new file mode 100644 index 0000000000000000000000000000000000000000..9e1316d08e8a5e0a4c9586db63a88d69ff767ae3 --- /dev/null +++ b/internal/assets/embed.go @@ -0,0 +1,6 @@ +package assets + +import "embed" + +//go:embed diff +var FS embed.FS diff --git a/internal/assets/write.go b/internal/assets/write.go new file mode 100644 index 0000000000000000000000000000000000000000..602b589ce353a994f4ea063f007639dc21b86fe5 --- /dev/null +++ b/internal/assets/write.go @@ -0,0 +1,60 @@ +package assets + +import ( + "os" + "path/filepath" + + "github.com/kujtimiihoxha/termai/internal/config" +) + +func WriteAssets() error { + appCfg := config.Get() + appWd := config.WorkingDirectory() + scriptDir := filepath.Join( + appWd, + appCfg.Data.Directory, + "diff", + ) + scriptPath := filepath.Join(scriptDir, "index.mjs") + // Before, run the script in cmd/diff/main.go to build this file + if _, err := os.Stat(scriptPath); err != nil { + scriptData, err := FS.ReadFile("diff/index.mjs") + if err != nil { + return err + } + + err = os.MkdirAll(scriptDir, 0o755) + if err != nil { + return err + } + err = os.WriteFile(scriptPath, scriptData, 0o755) + if err != nil { + return err + } + } + + themeDir := filepath.Join( + appWd, + appCfg.Data.Directory, + "themes", + ) + + themePath := filepath.Join(themeDir, "dark.json") + + if _, err := os.Stat(themePath); err != nil { + themeData, err := FS.ReadFile("diff/themes/dark.json") + if err != nil { + return err + } + + err = os.MkdirAll(themeDir, 0o755) + if err != nil { + return err + } + err = os.WriteFile(themePath, themeData, 0o755) + if err != nil { + return err + } + } + return nil +} diff --git a/internal/git/diff.go b/internal/git/diff.go new file mode 100644 index 0000000000000000000000000000000000000000..d87956f0172ce92872db975a0d209e49cbbac7c5 --- /dev/null +++ b/internal/git/diff.go @@ -0,0 +1,265 @@ +package git + +import ( + "bytes" + "fmt" + "os" + "os/exec" + "path/filepath" + "strings" + "time" + + "github.com/go-git/go-git/v5" + "github.com/go-git/go-git/v5/plumbing/object" + "github.com/kujtimiihoxha/termai/internal/config" +) + +type DiffStats struct { + Additions int + Removals int +} + +func GenerateGitDiff(filePath string, contentBefore string, contentAfter string) (string, error) { + tempDir, err := os.MkdirTemp("", "git-diff-temp") + if err != nil { + return "", fmt.Errorf("failed to create temp dir: %w", err) + } + defer os.RemoveAll(tempDir) + + repo, err := git.PlainInit(tempDir, false) + if err != nil { + return "", fmt.Errorf("failed to initialize git repo: %w", err) + } + + wt, err := repo.Worktree() + if err != nil { + return "", fmt.Errorf("failed to get worktree: %w", err) + } + + fullPath := filepath.Join(tempDir, filePath) + if err = os.MkdirAll(filepath.Dir(fullPath), 0o755); err != nil { + return "", fmt.Errorf("failed to create directories: %w", err) + } + if err = os.WriteFile(fullPath, []byte(contentBefore), 0o644); err != nil { + return "", fmt.Errorf("failed to write 'before' content: %w", err) + } + + _, err = wt.Add(filePath) + if err != nil { + return "", fmt.Errorf("failed to add file to git: %w", err) + } + + beforeCommit, err := wt.Commit("Before", &git.CommitOptions{ + Author: &object.Signature{ + Name: "OpenCode", + Email: "coder@opencode.ai", + When: time.Now(), + }, + }) + if err != nil { + return "", fmt.Errorf("failed to commit 'before' version: %w", err) + } + + if err = os.WriteFile(fullPath, []byte(contentAfter), 0o644); err != nil { + return "", fmt.Errorf("failed to write 'after' content: %w", err) + } + + _, err = wt.Add(filePath) + if err != nil { + return "", fmt.Errorf("failed to add updated file to git: %w", err) + } + + afterCommit, err := wt.Commit("After", &git.CommitOptions{ + Author: &object.Signature{ + Name: "OpenCode", + Email: "coder@opencode.ai", + When: time.Now(), + }, + }) + if err != nil { + return "", fmt.Errorf("failed to commit 'after' version: %w", err) + } + + beforeCommitObj, err := repo.CommitObject(beforeCommit) + if err != nil { + return "", fmt.Errorf("failed to get 'before' commit: %w", err) + } + + afterCommitObj, err := repo.CommitObject(afterCommit) + if err != nil { + return "", fmt.Errorf("failed to get 'after' commit: %w", err) + } + + patch, err := beforeCommitObj.Patch(afterCommitObj) + if err != nil { + return "", fmt.Errorf("failed to generate patch: %w", err) + } + + return patch.String(), nil +} + +func GenerateGitDiffWithStats(filePath string, contentBefore string, contentAfter string) (string, DiffStats, error) { + tempDir, err := os.MkdirTemp("", "git-diff-temp") + if err != nil { + return "", DiffStats{}, fmt.Errorf("failed to create temp dir: %w", err) + } + defer os.RemoveAll(tempDir) + + repo, err := git.PlainInit(tempDir, false) + if err != nil { + return "", DiffStats{}, fmt.Errorf("failed to initialize git repo: %w", err) + } + + wt, err := repo.Worktree() + if err != nil { + return "", DiffStats{}, fmt.Errorf("failed to get worktree: %w", err) + } + + fullPath := filepath.Join(tempDir, filePath) + if err = os.MkdirAll(filepath.Dir(fullPath), 0o755); err != nil { + return "", DiffStats{}, fmt.Errorf("failed to create directories: %w", err) + } + if err = os.WriteFile(fullPath, []byte(contentBefore), 0o644); err != nil { + return "", DiffStats{}, fmt.Errorf("failed to write 'before' content: %w", err) + } + + _, err = wt.Add(filePath) + if err != nil { + return "", DiffStats{}, fmt.Errorf("failed to add file to git: %w", err) + } + + beforeCommit, err := wt.Commit("Before", &git.CommitOptions{ + Author: &object.Signature{ + Name: "OpenCode", + Email: "coder@opencode.ai", + When: time.Now(), + }, + }) + if err != nil { + return "", DiffStats{}, fmt.Errorf("failed to commit 'before' version: %w", err) + } + + if err = os.WriteFile(fullPath, []byte(contentAfter), 0o644); err != nil { + return "", DiffStats{}, fmt.Errorf("failed to write 'after' content: %w", err) + } + + _, err = wt.Add(filePath) + if err != nil { + return "", DiffStats{}, fmt.Errorf("failed to add updated file to git: %w", err) + } + + afterCommit, err := wt.Commit("After", &git.CommitOptions{ + Author: &object.Signature{ + Name: "OpenCode", + Email: "coder@opencode.ai", + When: time.Now(), + }, + }) + if err != nil { + return "", DiffStats{}, fmt.Errorf("failed to commit 'after' version: %w", err) + } + + beforeCommitObj, err := repo.CommitObject(beforeCommit) + if err != nil { + return "", DiffStats{}, fmt.Errorf("failed to get 'before' commit: %w", err) + } + + afterCommitObj, err := repo.CommitObject(afterCommit) + if err != nil { + return "", DiffStats{}, fmt.Errorf("failed to get 'after' commit: %w", err) + } + + patch, err := beforeCommitObj.Patch(afterCommitObj) + if err != nil { + return "", DiffStats{}, fmt.Errorf("failed to generate patch: %w", err) + } + + stats := DiffStats{} + for _, fileStat := range patch.Stats() { + stats.Additions += fileStat.Addition + stats.Removals += fileStat.Deletion + } + + return patch.String(), stats, nil +} + +func FormatDiff(diffText string, width int) (string, error) { + if isSplitDiffsAvailable() { + return formatWithSplitDiffs(diffText, width) + } + + return formatSimple(diffText), nil +} + +func isSplitDiffsAvailable() bool { + _, err := exec.LookPath("node") + return err == nil +} + +func formatWithSplitDiffs(diffText string, width int) (string, error) { + var cmd *exec.Cmd + + appCfg := config.Get() + appWd := config.WorkingDirectory() + script := filepath.Join( + appWd, + appCfg.Data.Directory, + "diff", + "index.mjs", + ) + + cmd = exec.Command("node", script, "--color") + + cmd.Env = append(os.Environ(), fmt.Sprintf("COLUMNS=%d", width)) + + cmd.Stdin = strings.NewReader(diffText) + + var out bytes.Buffer + cmd.Stdout = &out + + var stderr bytes.Buffer + cmd.Stderr = &stderr + + err := cmd.Run() + if err != nil { + return "", fmt.Errorf("git-split-diffs error: %v, stderr: %s", err, stderr.String()) + } + + return out.String(), nil +} + +func formatSimple(diffText string) string { + lines := strings.Split(diffText, "\n") + var result strings.Builder + + for _, line := range lines { + if len(line) == 0 { + result.WriteString("\n") + continue + } + + switch line[0] { + case '+': + result.WriteString("\033[32m" + line + "\033[0m\n") + case '-': + result.WriteString("\033[31m" + line + "\033[0m\n") + case '@': + result.WriteString("\033[36m" + line + "\033[0m\n") + case 'd': + if strings.HasPrefix(line, "diff --git") { + result.WriteString("\033[1m" + line + "\033[0m\n") + } else { + result.WriteString(line + "\n") + } + default: + result.WriteString(line + "\n") + } + } + + if !strings.HasSuffix(diffText, "\n") { + output := result.String() + return output[:len(output)-1] + } + + return result.String() +} diff --git a/internal/llm/agent/agent.go b/internal/llm/agent/agent.go index b01ffec3cc312848450fec6e5c016ba377eeec7c..89de627f7ebdaa42b43cbdfb8b610474768ab8a9 100644 --- a/internal/llm/agent/agent.go +++ b/internal/llm/agent/agent.go @@ -246,6 +246,7 @@ func (c *agent) handleToolExecution( } func (c *agent) generate(ctx context.Context, sessionID string, content string) error { + ctx = context.WithValue(ctx, tools.SessionIDContextKey, sessionID) messages, err := c.Messages.List(sessionID) if err != nil { return err @@ -310,6 +311,8 @@ func (c *agent) generate(ctx context.Context, sessionID string, content string) if err != nil { return err } + + ctx = context.WithValue(ctx, tools.MessageIDContextKey, assistantMsg.ID) for event := range eventChan { err = c.processEvent(sessionID, &assistantMsg, event) if err != nil { diff --git a/internal/llm/tools/bash.go b/internal/llm/tools/bash.go index d20afb7f28629db045452a9687e6b00e8acf5931..d55cb241b9c0ade5a8b49518ee6b42f7115dfc7e 100644 --- a/internal/llm/tools/bash.go +++ b/internal/llm/tools/bash.go @@ -22,7 +22,7 @@ type BashPermissionsParams struct { Timeout int `json:"timeout"` } -type BashToolResponseMetadata struct { +type BashResponseMetadata struct { Took int64 `json:"took"` } type bashTool struct { @@ -310,7 +310,7 @@ func (b *bashTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error) stdout += "\n" + errorMessage } - metadata := BashToolResponseMetadata{ + metadata := BashResponseMetadata{ Took: took, } if stdout == "" { diff --git a/internal/llm/tools/edit.go b/internal/llm/tools/edit.go index 32e2034e451063f0f5e8e33328793c0f44bca4d1..c9a0be07956bc65725abf93a5bdd1774bda371ae 100644 --- a/internal/llm/tools/edit.go +++ b/internal/llm/tools/edit.go @@ -10,9 +10,9 @@ import ( "time" "github.com/kujtimiihoxha/termai/internal/config" + "github.com/kujtimiihoxha/termai/internal/git" "github.com/kujtimiihoxha/termai/internal/lsp" "github.com/kujtimiihoxha/termai/internal/permission" - "github.com/sergi/go-diff/diffmatchpatch" ) type EditParams struct { @@ -22,10 +22,13 @@ type EditParams struct { } type EditPermissionsParams struct { - FilePath string `json:"file_path"` - OldString string `json:"old_string"` - NewString string `json:"new_string"` - Diff string `json:"diff"` + FilePath string `json:"file_path"` + Diff string `json:"diff"` +} + +type EditResponseMetadata struct { + Additions int `json:"additions"` + Removals int `json:"removals"` } type editTool struct { @@ -129,48 +132,77 @@ func (e *editTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error) } if params.OldString == "" { - result, err := e.createNewFile(params.FilePath, params.NewString) + result, err := e.createNewFile(ctx, params.FilePath, params.NewString) if err != nil { return NewTextErrorResponse(fmt.Sprintf("error creating file: %s", err)), nil } - return NewTextResponse(result), nil + return WithResponseMetadata(NewTextResponse(result.text), EditResponseMetadata{ + Additions: result.additions, + Removals: result.removals, + }), nil } if params.NewString == "" { - result, err := e.deleteContent(params.FilePath, params.OldString) + result, err := e.deleteContent(ctx, params.FilePath, params.OldString) if err != nil { return NewTextErrorResponse(fmt.Sprintf("error deleting content: %s", err)), nil } - return NewTextResponse(result), nil + return WithResponseMetadata(NewTextResponse(result.text), EditResponseMetadata{ + Additions: result.additions, + Removals: result.removals, + }), nil } - result, err := e.replaceContent(params.FilePath, params.OldString, params.NewString) + result, err := e.replaceContent(ctx, params.FilePath, params.OldString, params.NewString) if err != nil { return NewTextErrorResponse(fmt.Sprintf("error replacing content: %s", err)), nil } waitForLspDiagnostics(ctx, params.FilePath, e.lspClients) - result = fmt.Sprintf("\n%s\n\n", result) - result += appendDiagnostics(params.FilePath, e.lspClients) - return NewTextResponse(result), nil + text := fmt.Sprintf("\n%s\n\n", result.text) + text += appendDiagnostics(params.FilePath, e.lspClients) + return WithResponseMetadata(NewTextResponse(text), EditResponseMetadata{ + Additions: result.additions, + Removals: result.removals, + }), nil +} + +type editResponse struct { + text string + additions int + removals int } -func (e *editTool) createNewFile(filePath, content string) (string, error) { +func (e *editTool) createNewFile(ctx context.Context, filePath, content string) (editResponse, error) { + er := editResponse{} fileInfo, err := os.Stat(filePath) if err == nil { if fileInfo.IsDir() { - return "", fmt.Errorf("path is a directory, not a file: %s", filePath) + return er, fmt.Errorf("path is a directory, not a file: %s", filePath) } - return "", fmt.Errorf("file already exists: %s. Use the Replace tool to overwrite an existing file", filePath) + return er, fmt.Errorf("file already exists: %s. Use the Replace tool to overwrite an existing file", filePath) } else if !os.IsNotExist(err) { - return "", fmt.Errorf("failed to access file: %w", err) + return er, fmt.Errorf("failed to access file: %w", err) } dir := filepath.Dir(filePath) if err = os.MkdirAll(dir, 0o755); err != nil { - return "", fmt.Errorf("failed to create parent directories: %w", err) + return er, fmt.Errorf("failed to create parent directories: %w", err) } + sessionID, messageID := getContextValues(ctx) + if sessionID == "" || messageID == "" { + return er, fmt.Errorf("session ID and message ID are required for creating a new file") + } + + diff, stats, err := git.GenerateGitDiffWithStats( + removeWorkingDirectoryPrefix(filePath), + "", + content, + ) + if err != nil { + return er, fmt.Errorf("failed to get file diff: %w", err) + } p := e.permissions.Request( permission.CreatePermissionRequest{ Path: filepath.Dir(filePath), @@ -178,71 +210,88 @@ func (e *editTool) createNewFile(filePath, content string) (string, error) { Action: "create", Description: fmt.Sprintf("Create file %s", filePath), Params: EditPermissionsParams{ - FilePath: filePath, - OldString: "", - NewString: content, - Diff: GenerateDiff("", content), + FilePath: filePath, + Diff: diff, }, }, ) if !p { - return "", fmt.Errorf("permission denied") + return er, fmt.Errorf("permission denied") } err = os.WriteFile(filePath, []byte(content), 0o644) if err != nil { - return "", fmt.Errorf("failed to write file: %w", err) + return er, fmt.Errorf("failed to write file: %w", err) } recordFileWrite(filePath) recordFileRead(filePath) - return "File created: " + filePath, nil + er.text = "File created: " + filePath + er.additions = stats.Additions + er.removals = stats.Removals + return er, nil } -func (e *editTool) deleteContent(filePath, oldString string) (string, error) { +func (e *editTool) deleteContent(ctx context.Context, filePath, oldString string) (editResponse, error) { + er := editResponse{} fileInfo, err := os.Stat(filePath) if err != nil { if os.IsNotExist(err) { - return "", fmt.Errorf("file not found: %s", filePath) + return er, fmt.Errorf("file not found: %s", filePath) } - return "", fmt.Errorf("failed to access file: %w", err) + return er, fmt.Errorf("failed to access file: %w", err) } if fileInfo.IsDir() { - return "", fmt.Errorf("path is a directory, not a file: %s", filePath) + return er, fmt.Errorf("path is a directory, not a file: %s", filePath) } if getLastReadTime(filePath).IsZero() { - return "", fmt.Errorf("you must read the file before editing it. Use the View tool first") + return er, fmt.Errorf("you must read the file before editing it. Use the View tool first") } modTime := fileInfo.ModTime() lastRead := getLastReadTime(filePath) if modTime.After(lastRead) { - return "", fmt.Errorf("file %s has been modified since it was last read (mod time: %s, last read: %s)", + return er, fmt.Errorf("file %s has been modified since it was last read (mod time: %s, last read: %s)", filePath, modTime.Format(time.RFC3339), lastRead.Format(time.RFC3339)) } content, err := os.ReadFile(filePath) if err != nil { - return "", fmt.Errorf("failed to read file: %w", err) + return er, fmt.Errorf("failed to read file: %w", err) } oldContent := string(content) index := strings.Index(oldContent, oldString) if index == -1 { - return "", fmt.Errorf("old_string not found in file. Make sure it matches exactly, including whitespace and line breaks") + return er, fmt.Errorf("old_string not found in file. Make sure it matches exactly, including whitespace and line breaks") } lastIndex := strings.LastIndex(oldContent, oldString) if index != lastIndex { - return "", fmt.Errorf("old_string appears multiple times in the file. Please provide more context to ensure a unique match") + return er, fmt.Errorf("old_string appears multiple times in the file. Please provide more context to ensure a unique match") } newContent := oldContent[:index] + oldContent[index+len(oldString):] + sessionID, messageID := getContextValues(ctx) + + if sessionID == "" || messageID == "" { + return er, fmt.Errorf("session ID and message ID are required for creating a new file") + } + + diff, stats, err := git.GenerateGitDiffWithStats( + removeWorkingDirectoryPrefix(filePath), + oldContent, + newContent, + ) + if err != nil { + return er, fmt.Errorf("failed to get file diff: %w", err) + } + p := e.permissions.Request( permission.CreatePermissionRequest{ Path: filepath.Dir(filePath), @@ -250,76 +299,85 @@ func (e *editTool) deleteContent(filePath, oldString string) (string, error) { Action: "delete", Description: fmt.Sprintf("Delete content from file %s", filePath), Params: EditPermissionsParams{ - FilePath: filePath, - OldString: oldString, - NewString: "", - Diff: GenerateDiff(oldContent, newContent), + FilePath: filePath, + Diff: diff, }, }, ) if !p { - return "", fmt.Errorf("permission denied") + return er, fmt.Errorf("permission denied") } err = os.WriteFile(filePath, []byte(newContent), 0o644) if err != nil { - return "", fmt.Errorf("failed to write file: %w", err) + return er, fmt.Errorf("failed to write file: %w", err) } - recordFileWrite(filePath) recordFileRead(filePath) - return "Content deleted from file: " + filePath, nil + er.text = "Content deleted from file: " + filePath + er.additions = stats.Additions + er.removals = stats.Removals + return er, nil } -func (e *editTool) replaceContent(filePath, oldString, newString string) (string, error) { +func (e *editTool) replaceContent(ctx context.Context, filePath, oldString, newString string) (editResponse, error) { + er := editResponse{} fileInfo, err := os.Stat(filePath) if err != nil { if os.IsNotExist(err) { - return "", fmt.Errorf("file not found: %s", filePath) + return er, fmt.Errorf("file not found: %s", filePath) } - return "", fmt.Errorf("failed to access file: %w", err) + return er, fmt.Errorf("failed to access file: %w", err) } if fileInfo.IsDir() { - return "", fmt.Errorf("path is a directory, not a file: %s", filePath) + return er, fmt.Errorf("path is a directory, not a file: %s", filePath) } if getLastReadTime(filePath).IsZero() { - return "", fmt.Errorf("you must read the file before editing it. Use the View tool first") + return er, fmt.Errorf("you must read the file before editing it. Use the View tool first") } modTime := fileInfo.ModTime() lastRead := getLastReadTime(filePath) if modTime.After(lastRead) { - return "", fmt.Errorf("file %s has been modified since it was last read (mod time: %s, last read: %s)", + return er, fmt.Errorf("file %s has been modified since it was last read (mod time: %s, last read: %s)", filePath, modTime.Format(time.RFC3339), lastRead.Format(time.RFC3339)) } content, err := os.ReadFile(filePath) if err != nil { - return "", fmt.Errorf("failed to read file: %w", err) + return er, fmt.Errorf("failed to read file: %w", err) } oldContent := string(content) index := strings.Index(oldContent, oldString) if index == -1 { - return "", fmt.Errorf("old_string not found in file. Make sure it matches exactly, including whitespace and line breaks") + return er, fmt.Errorf("old_string not found in file. Make sure it matches exactly, including whitespace and line breaks") } lastIndex := strings.LastIndex(oldContent, oldString) if index != lastIndex { - return "", fmt.Errorf("old_string appears multiple times in the file. Please provide more context to ensure a unique match") + return er, fmt.Errorf("old_string appears multiple times in the file. Please provide more context to ensure a unique match") } newContent := oldContent[:index] + newString + oldContent[index+len(oldString):] - startIndex := max(0, index-3) - oldEndIndex := min(len(oldContent), index+len(oldString)+3) - newEndIndex := min(len(newContent), index+len(newString)+3) + sessionID, messageID := getContextValues(ctx) - diff := GenerateDiff(oldContent[startIndex:oldEndIndex], newContent[startIndex:newEndIndex]) + if sessionID == "" || messageID == "" { + return er, fmt.Errorf("session ID and message ID are required for creating a new file") + } + diff, stats, err := git.GenerateGitDiffWithStats( + removeWorkingDirectoryPrefix(filePath), + oldContent, + newContent, + ) + if err != nil { + return er, fmt.Errorf("failed to get file diff: %w", err) + } p := e.permissions.Request( permission.CreatePermissionRequest{ @@ -328,75 +386,27 @@ func (e *editTool) replaceContent(filePath, oldString, newString string) (string Action: "replace", Description: fmt.Sprintf("Replace content in file %s", filePath), Params: EditPermissionsParams{ - FilePath: filePath, - OldString: oldString, - NewString: newString, - Diff: diff, + FilePath: filePath, + + Diff: diff, }, }, ) if !p { - return "", fmt.Errorf("permission denied") + return er, fmt.Errorf("permission denied") } err = os.WriteFile(filePath, []byte(newContent), 0o644) if err != nil { - return "", fmt.Errorf("failed to write file: %w", err) + return er, fmt.Errorf("failed to write file: %w", err) } recordFileWrite(filePath) recordFileRead(filePath) + er.text = "Content replaced in file: " + filePath + er.additions = stats.Additions + er.removals = stats.Removals - return "Content replaced in file: " + filePath, nil + return er, nil } -func GenerateDiff(oldContent, newContent string) string { - dmp := diffmatchpatch.New() - fileAdmp, fileBdmp, dmpStrings := dmp.DiffLinesToChars(oldContent, newContent) - diffs := dmp.DiffMain(fileAdmp, fileBdmp, false) - diffs = dmp.DiffCharsToLines(diffs, dmpStrings) - diffs = dmp.DiffCleanupSemantic(diffs) - buff := strings.Builder{} - - buff.WriteString("Changes:\n") - - for _, diff := range diffs { - text := diff.Text - - switch diff.Type { - case diffmatchpatch.DiffInsert: - for line := range strings.SplitSeq(text, "\n") { - if line == "" { - continue - } - _, _ = buff.WriteString("+ " + line + "\n") - } - case diffmatchpatch.DiffDelete: - for line := range strings.SplitSeq(text, "\n") { - if line == "" { - continue - } - _, _ = buff.WriteString("- " + line + "\n") - } - case diffmatchpatch.DiffEqual: - lines := strings.Split(text, "\n") - if len(lines) > 3 { - if lines[0] != "" { - _, _ = buff.WriteString(" " + lines[0] + "\n") - } - _, _ = buff.WriteString(" ...\n") - if lines[len(lines)-1] != "" { - _, _ = buff.WriteString(" " + lines[len(lines)-1] + "\n") - } - } else { - for _, line := range lines { - if line == "" { - continue - } - _, _ = buff.WriteString(" " + line + "\n") - } - } - } - } - return buff.String() -} diff --git a/internal/llm/tools/edit_test.go b/internal/llm/tools/edit_test.go index dbc6e488f822378432ed41e6cd3d3909651bc3e9..48a34ed75c2261f022cbed1cd0e08a6f1949d642 100644 --- a/internal/llm/tools/edit_test.go +++ b/internal/llm/tools/edit_test.go @@ -459,51 +459,3 @@ func TestEditTool_Run(t *testing.T) { assert.Equal(t, initialContent, string(fileContent)) }) } - -func TestGenerateDiff(t *testing.T) { - testCases := []struct { - name string - oldContent string - newContent string - expectedDiff string - }{ - { - name: "add content", - oldContent: "Line 1\nLine 2\n", - newContent: "Line 1\nLine 2\nLine 3\n", - expectedDiff: "Changes:\n Line 1\n Line 2\n+ Line 3\n", - }, - { - name: "remove content", - oldContent: "Line 1\nLine 2\nLine 3\n", - newContent: "Line 1\nLine 3\n", - expectedDiff: "Changes:\n Line 1\n- Line 2\n Line 3\n", - }, - { - name: "replace content", - oldContent: "Line 1\nLine 2\nLine 3\n", - newContent: "Line 1\nModified Line\nLine 3\n", - expectedDiff: "Changes:\n Line 1\n- Line 2\n+ Modified Line\n Line 3\n", - }, - { - name: "empty to content", - oldContent: "", - newContent: "Line 1\nLine 2\n", - expectedDiff: "Changes:\n+ Line 1\n+ Line 2\n", - }, - { - name: "content to empty", - oldContent: "Line 1\nLine 2\n", - newContent: "", - expectedDiff: "Changes:\n- Line 1\n- Line 2\n", - }, - } - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - diff := GenerateDiff(tc.oldContent, tc.newContent) - assert.Contains(t, diff, tc.expectedDiff) - }) - } -} - diff --git a/internal/llm/tools/file.go b/internal/llm/tools/file.go index 7f34fdc1f615031decf00706c58aac37a235b57e..9c9707c9c3f387b59e7f8a528344dd2879e1ee71 100644 --- a/internal/llm/tools/file.go +++ b/internal/llm/tools/file.go @@ -3,6 +3,8 @@ package tools import ( "sync" "time" + + "github.com/kujtimiihoxha/termai/internal/config" ) // File record to track when files were read/written @@ -17,6 +19,14 @@ var ( fileRecordMutex sync.RWMutex ) +func removeWorkingDirectoryPrefix(path string) string { + wd := config.WorkingDirectory() + if len(path) > len(wd) && path[:len(wd)] == wd { + return path[len(wd)+1:] + } + return path +} + func recordFileRead(path string) { fileRecordMutex.Lock() defer fileRecordMutex.Unlock() diff --git a/internal/llm/tools/tools.go b/internal/llm/tools/tools.go index 6bb5286863128b26ca16a1d2965b3c4901be4471..473b787bbbe9c2e20c3c08aa7ed348b7dafc009f 100644 --- a/internal/llm/tools/tools.go +++ b/internal/llm/tools/tools.go @@ -17,6 +17,9 @@ type toolResponseType string const ( ToolResponseTypeText toolResponseType = "text" ToolResponseTypeImage toolResponseType = "image" + + SessionIDContextKey = "session_id" + MessageIDContextKey = "message_id" ) type ToolResponse struct { @@ -62,3 +65,15 @@ type BaseTool interface { Info() ToolInfo Run(ctx context.Context, params ToolCall) (ToolResponse, error) } + +func getContextValues(ctx context.Context) (string, string) { + sessionID := ctx.Value(SessionIDContextKey) + messageID := ctx.Value(MessageIDContextKey) + if sessionID == nil { + return "", "" + } + if messageID == nil { + return sessionID.(string), "" + } + return sessionID.(string), messageID.(string) +} diff --git a/internal/llm/tools/write.go b/internal/llm/tools/write.go index 7b698d2d8dcc5de7bbd32d26c9e94749fd5ed996..27c98bb9d244ff0054165d6d3e036a803c5f9908 100644 --- a/internal/llm/tools/write.go +++ b/internal/llm/tools/write.go @@ -9,6 +9,7 @@ import ( "time" "github.com/kujtimiihoxha/termai/internal/config" + "github.com/kujtimiihoxha/termai/internal/git" "github.com/kujtimiihoxha/termai/internal/lsp" "github.com/kujtimiihoxha/termai/internal/permission" ) @@ -20,7 +21,7 @@ type WriteParams struct { type WritePermissionsParams struct { FilePath string `json:"file_path"` - Content string `json:"content"` + Diff string `json:"diff"` } type writeTool struct { @@ -28,6 +29,11 @@ type writeTool struct { permissions permission.Service } +type WriteResponseMetadata struct { + Additions int `json:"additions"` + Removals int `json:"removals"` +} + const ( WriteToolName = "write" writeDescription = `File writing tool that creates or updates files in the filesystem, allowing you to save or modify text content. @@ -138,6 +144,18 @@ func (w *writeTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error } } + sessionID, messageID := getContextValues(ctx) + if sessionID == "" || messageID == "" { + return NewTextErrorResponse("session ID or message ID is missing"), nil + } + diff, stats, err := git.GenerateGitDiffWithStats( + removeWorkingDirectoryPrefix(filePath), + oldContent, + params.Content, + ) + if err != nil { + return NewTextErrorResponse(fmt.Sprintf("Failed to get file diff: %s", err)), nil + } p := w.permissions.Request( permission.CreatePermissionRequest{ Path: filePath, @@ -146,7 +164,7 @@ func (w *writeTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error Description: fmt.Sprintf("Create file %s", filePath), Params: WritePermissionsParams{ FilePath: filePath, - Content: GenerateDiff(oldContent, params.Content), + Diff: diff, }, }, ) @@ -166,5 +184,10 @@ func (w *writeTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error result := fmt.Sprintf("File successfully written: %s", filePath) result = fmt.Sprintf("\n%s\n", result) result += appendDiagnostics(filePath, w.lspClients) - return NewTextResponse(result), nil + return WithResponseMetadata(NewTextResponse(result), + WriteResponseMetadata{ + Additions: stats.Additions, + Removals: stats.Removals, + }, + ), nil } diff --git a/internal/tui/components/dialog/permission.go b/internal/tui/components/dialog/permission.go index 088697d5542259437d9c8e84f460309c14e6896c..344310eb6563688cf83e989eecbf7e24ec0bbd78 100644 --- a/internal/tui/components/dialog/permission.go +++ b/internal/tui/components/dialog/permission.go @@ -9,6 +9,7 @@ import ( tea "github.com/charmbracelet/bubbletea" "github.com/charmbracelet/glamour" "github.com/charmbracelet/lipgloss" + "github.com/kujtimiihoxha/termai/internal/git" "github.com/kujtimiihoxha/termai/internal/llm/tools" "github.com/kujtimiihoxha/termai/internal/permission" "github.com/kujtimiihoxha/termai/internal/tui/components/core" @@ -234,7 +235,6 @@ func (p *permissionDialogCmp) render() string { headerContent = lipgloss.NewStyle().Padding(0, 1).Render(lipgloss.JoinVertical(lipgloss.Left, headerParts...)) // Format the diff with colors - formattedDiff := formatDiff(pr.Diff) // Set up viewport for the diff content p.contentViewPort.Width = p.width - 2 - 2 @@ -242,7 +242,11 @@ func (p *permissionDialogCmp) render() string { // Calculate content height dynamically based on window size maxContentHeight := p.height - lipgloss.Height(headerContent) - lipgloss.Height(form) - 2 - 2 - 1 p.contentViewPort.Height = maxContentHeight - p.contentViewPort.SetContent(formattedDiff) + diff, err := git.FormatDiff(pr.Diff, p.contentViewPort.Width) + if err != nil { + diff = fmt.Sprintf("Error formatting diff: %v", err) + } + p.contentViewPort.SetContent(diff) // Style the viewport var contentBorder lipgloss.Border @@ -281,16 +285,17 @@ func (p *permissionDialogCmp) render() string { // Recreate header content with the updated headerParts headerContent = lipgloss.NewStyle().Padding(0, 1).Render(lipgloss.JoinVertical(lipgloss.Left, headerParts...)) - // Format the diff with colors - formattedDiff := formatDiff(pr.Content) - // Set up viewport for the content p.contentViewPort.Width = p.width - 2 - 2 // Calculate content height dynamically based on window size maxContentHeight := p.height - lipgloss.Height(headerContent) - lipgloss.Height(form) - 2 - 2 - 1 p.contentViewPort.Height = maxContentHeight - p.contentViewPort.SetContent(formattedDiff) + diff, err := git.FormatDiff(pr.Diff, p.contentViewPort.Width) + if err != nil { + diff = fmt.Sprintf("Error formatting diff: %v", err) + } + p.contentViewPort.SetContent(diff) // Style the viewport var contentBorder lipgloss.Border From d63d0c4c44cb2356645fd112cec01598301f5949 Mon Sep 17 00:00:00 2001 From: Kujtim Hoxha Date: Sun, 13 Apr 2025 11:29:20 +0200 Subject: [PATCH 06/41] wip files --- internal/app/services.go | 4 + internal/db/db.go | 152 +++++++-- internal/db/files.sql.go | 309 ++++++++++++++++++ .../db/migrations/000001_initial.down.sql | 2 + internal/db/migrations/000001_initial.up.sql | 22 ++ internal/db/models.go | 10 + internal/db/querier.go | 10 + internal/db/sql/files.sql | 69 ++++ internal/history/file.go | 206 ++++++++++++ internal/tui/tui.go | 4 +- 10 files changed, 760 insertions(+), 28 deletions(-) create mode 100644 internal/db/files.sql.go create mode 100644 internal/db/sql/files.sql create mode 100644 internal/history/file.go diff --git a/internal/app/services.go b/internal/app/services.go index 76b2226aef288848efb4c2473971052f7b911d9c..6ecdef03c2f7d7766431aa80e4e9eeaf9f154a0f 100644 --- a/internal/app/services.go +++ b/internal/app/services.go @@ -6,6 +6,7 @@ import ( "github.com/kujtimiihoxha/termai/internal/config" "github.com/kujtimiihoxha/termai/internal/db" + "github.com/kujtimiihoxha/termai/internal/history" "github.com/kujtimiihoxha/termai/internal/logging" "github.com/kujtimiihoxha/termai/internal/lsp" "github.com/kujtimiihoxha/termai/internal/lsp/watcher" @@ -19,6 +20,7 @@ type App struct { Sessions session.Service Messages message.Service + Files history.Service Permissions permission.Service LSPClients map[string]*lsp.Client @@ -31,11 +33,13 @@ func New(ctx context.Context, conn *sql.DB) *App { q := db.New(conn) sessions := session.NewService(ctx, q) messages := message.NewService(ctx, q) + files := history.NewService(ctx, q) app := &App{ Context: ctx, Sessions: sessions, Messages: messages, + Files: files, Permissions: permission.NewPermissionService(), LSPClients: make(map[string]*lsp.Client), } diff --git a/internal/db/db.go b/internal/db/db.go index 75f6260139472ff2fd30855b7d5d927bab3815d9..16e66380405615910bd1ebf1449cb40e0fca5756 100644 --- a/internal/db/db.go +++ b/internal/db/db.go @@ -24,33 +24,63 @@ func New(db DBTX) *Queries { func Prepare(ctx context.Context, db DBTX) (*Queries, error) { q := Queries{db: db} var err error + if q.createFileStmt, err = db.PrepareContext(ctx, createFile); err != nil { + return nil, fmt.Errorf("error preparing query CreateFile: %w", err) + } if q.createMessageStmt, err = db.PrepareContext(ctx, createMessage); err != nil { return nil, fmt.Errorf("error preparing query CreateMessage: %w", err) } if q.createSessionStmt, err = db.PrepareContext(ctx, createSession); err != nil { return nil, fmt.Errorf("error preparing query CreateSession: %w", err) } + if q.deleteFileStmt, err = db.PrepareContext(ctx, deleteFile); err != nil { + return nil, fmt.Errorf("error preparing query DeleteFile: %w", err) + } if q.deleteMessageStmt, err = db.PrepareContext(ctx, deleteMessage); err != nil { return nil, fmt.Errorf("error preparing query DeleteMessage: %w", err) } if q.deleteSessionStmt, err = db.PrepareContext(ctx, deleteSession); err != nil { return nil, fmt.Errorf("error preparing query DeleteSession: %w", err) } + if q.deleteSessionFilesStmt, err = db.PrepareContext(ctx, deleteSessionFiles); err != nil { + return nil, fmt.Errorf("error preparing query DeleteSessionFiles: %w", err) + } if q.deleteSessionMessagesStmt, err = db.PrepareContext(ctx, deleteSessionMessages); err != nil { return nil, fmt.Errorf("error preparing query DeleteSessionMessages: %w", err) } + if q.getFileStmt, err = db.PrepareContext(ctx, getFile); err != nil { + return nil, fmt.Errorf("error preparing query GetFile: %w", err) + } + if q.getFileByPathAndSessionStmt, err = db.PrepareContext(ctx, getFileByPathAndSession); err != nil { + return nil, fmt.Errorf("error preparing query GetFileByPathAndSession: %w", err) + } if q.getMessageStmt, err = db.PrepareContext(ctx, getMessage); err != nil { return nil, fmt.Errorf("error preparing query GetMessage: %w", err) } if q.getSessionByIDStmt, err = db.PrepareContext(ctx, getSessionByID); err != nil { return nil, fmt.Errorf("error preparing query GetSessionByID: %w", err) } + if q.listFilesByPathStmt, err = db.PrepareContext(ctx, listFilesByPath); err != nil { + return nil, fmt.Errorf("error preparing query ListFilesByPath: %w", err) + } + if q.listFilesBySessionStmt, err = db.PrepareContext(ctx, listFilesBySession); err != nil { + return nil, fmt.Errorf("error preparing query ListFilesBySession: %w", err) + } + if q.listLatestSessionFilesStmt, err = db.PrepareContext(ctx, listLatestSessionFiles); err != nil { + return nil, fmt.Errorf("error preparing query ListLatestSessionFiles: %w", err) + } if q.listMessagesBySessionStmt, err = db.PrepareContext(ctx, listMessagesBySession); err != nil { return nil, fmt.Errorf("error preparing query ListMessagesBySession: %w", err) } + if q.listNewFilesStmt, err = db.PrepareContext(ctx, listNewFiles); err != nil { + return nil, fmt.Errorf("error preparing query ListNewFiles: %w", err) + } if q.listSessionsStmt, err = db.PrepareContext(ctx, listSessions); err != nil { return nil, fmt.Errorf("error preparing query ListSessions: %w", err) } + if q.updateFileStmt, err = db.PrepareContext(ctx, updateFile); err != nil { + return nil, fmt.Errorf("error preparing query UpdateFile: %w", err) + } if q.updateMessageStmt, err = db.PrepareContext(ctx, updateMessage); err != nil { return nil, fmt.Errorf("error preparing query UpdateMessage: %w", err) } @@ -62,6 +92,11 @@ func Prepare(ctx context.Context, db DBTX) (*Queries, error) { func (q *Queries) Close() error { var err error + if q.createFileStmt != nil { + if cerr := q.createFileStmt.Close(); cerr != nil { + err = fmt.Errorf("error closing createFileStmt: %w", cerr) + } + } if q.createMessageStmt != nil { if cerr := q.createMessageStmt.Close(); cerr != nil { err = fmt.Errorf("error closing createMessageStmt: %w", cerr) @@ -72,6 +107,11 @@ func (q *Queries) Close() error { err = fmt.Errorf("error closing createSessionStmt: %w", cerr) } } + if q.deleteFileStmt != nil { + if cerr := q.deleteFileStmt.Close(); cerr != nil { + err = fmt.Errorf("error closing deleteFileStmt: %w", cerr) + } + } if q.deleteMessageStmt != nil { if cerr := q.deleteMessageStmt.Close(); cerr != nil { err = fmt.Errorf("error closing deleteMessageStmt: %w", cerr) @@ -82,11 +122,26 @@ func (q *Queries) Close() error { err = fmt.Errorf("error closing deleteSessionStmt: %w", cerr) } } + if q.deleteSessionFilesStmt != nil { + if cerr := q.deleteSessionFilesStmt.Close(); cerr != nil { + err = fmt.Errorf("error closing deleteSessionFilesStmt: %w", cerr) + } + } if q.deleteSessionMessagesStmt != nil { if cerr := q.deleteSessionMessagesStmt.Close(); cerr != nil { err = fmt.Errorf("error closing deleteSessionMessagesStmt: %w", cerr) } } + if q.getFileStmt != nil { + if cerr := q.getFileStmt.Close(); cerr != nil { + err = fmt.Errorf("error closing getFileStmt: %w", cerr) + } + } + if q.getFileByPathAndSessionStmt != nil { + if cerr := q.getFileByPathAndSessionStmt.Close(); cerr != nil { + err = fmt.Errorf("error closing getFileByPathAndSessionStmt: %w", cerr) + } + } if q.getMessageStmt != nil { if cerr := q.getMessageStmt.Close(); cerr != nil { err = fmt.Errorf("error closing getMessageStmt: %w", cerr) @@ -97,16 +152,41 @@ func (q *Queries) Close() error { err = fmt.Errorf("error closing getSessionByIDStmt: %w", cerr) } } + if q.listFilesByPathStmt != nil { + if cerr := q.listFilesByPathStmt.Close(); cerr != nil { + err = fmt.Errorf("error closing listFilesByPathStmt: %w", cerr) + } + } + if q.listFilesBySessionStmt != nil { + if cerr := q.listFilesBySessionStmt.Close(); cerr != nil { + err = fmt.Errorf("error closing listFilesBySessionStmt: %w", cerr) + } + } + if q.listLatestSessionFilesStmt != nil { + if cerr := q.listLatestSessionFilesStmt.Close(); cerr != nil { + err = fmt.Errorf("error closing listLatestSessionFilesStmt: %w", cerr) + } + } if q.listMessagesBySessionStmt != nil { if cerr := q.listMessagesBySessionStmt.Close(); cerr != nil { err = fmt.Errorf("error closing listMessagesBySessionStmt: %w", cerr) } } + if q.listNewFilesStmt != nil { + if cerr := q.listNewFilesStmt.Close(); cerr != nil { + err = fmt.Errorf("error closing listNewFilesStmt: %w", cerr) + } + } if q.listSessionsStmt != nil { if cerr := q.listSessionsStmt.Close(); cerr != nil { err = fmt.Errorf("error closing listSessionsStmt: %w", cerr) } } + if q.updateFileStmt != nil { + if cerr := q.updateFileStmt.Close(); cerr != nil { + err = fmt.Errorf("error closing updateFileStmt: %w", cerr) + } + } if q.updateMessageStmt != nil { if cerr := q.updateMessageStmt.Close(); cerr != nil { err = fmt.Errorf("error closing updateMessageStmt: %w", cerr) @@ -154,35 +234,55 @@ func (q *Queries) queryRow(ctx context.Context, stmt *sql.Stmt, query string, ar } type Queries struct { - db DBTX - tx *sql.Tx - createMessageStmt *sql.Stmt - createSessionStmt *sql.Stmt - deleteMessageStmt *sql.Stmt - deleteSessionStmt *sql.Stmt - deleteSessionMessagesStmt *sql.Stmt - getMessageStmt *sql.Stmt - getSessionByIDStmt *sql.Stmt - listMessagesBySessionStmt *sql.Stmt - listSessionsStmt *sql.Stmt - updateMessageStmt *sql.Stmt - updateSessionStmt *sql.Stmt + db DBTX + tx *sql.Tx + createFileStmt *sql.Stmt + createMessageStmt *sql.Stmt + createSessionStmt *sql.Stmt + deleteFileStmt *sql.Stmt + deleteMessageStmt *sql.Stmt + deleteSessionStmt *sql.Stmt + deleteSessionFilesStmt *sql.Stmt + deleteSessionMessagesStmt *sql.Stmt + getFileStmt *sql.Stmt + getFileByPathAndSessionStmt *sql.Stmt + getMessageStmt *sql.Stmt + getSessionByIDStmt *sql.Stmt + listFilesByPathStmt *sql.Stmt + listFilesBySessionStmt *sql.Stmt + listLatestSessionFilesStmt *sql.Stmt + listMessagesBySessionStmt *sql.Stmt + listNewFilesStmt *sql.Stmt + listSessionsStmt *sql.Stmt + updateFileStmt *sql.Stmt + updateMessageStmt *sql.Stmt + updateSessionStmt *sql.Stmt } func (q *Queries) WithTx(tx *sql.Tx) *Queries { return &Queries{ - db: tx, - tx: tx, - createMessageStmt: q.createMessageStmt, - createSessionStmt: q.createSessionStmt, - deleteMessageStmt: q.deleteMessageStmt, - deleteSessionStmt: q.deleteSessionStmt, - deleteSessionMessagesStmt: q.deleteSessionMessagesStmt, - getMessageStmt: q.getMessageStmt, - getSessionByIDStmt: q.getSessionByIDStmt, - listMessagesBySessionStmt: q.listMessagesBySessionStmt, - listSessionsStmt: q.listSessionsStmt, - updateMessageStmt: q.updateMessageStmt, - updateSessionStmt: q.updateSessionStmt, + db: tx, + tx: tx, + createFileStmt: q.createFileStmt, + createMessageStmt: q.createMessageStmt, + createSessionStmt: q.createSessionStmt, + deleteFileStmt: q.deleteFileStmt, + deleteMessageStmt: q.deleteMessageStmt, + deleteSessionStmt: q.deleteSessionStmt, + deleteSessionFilesStmt: q.deleteSessionFilesStmt, + deleteSessionMessagesStmt: q.deleteSessionMessagesStmt, + getFileStmt: q.getFileStmt, + getFileByPathAndSessionStmt: q.getFileByPathAndSessionStmt, + getMessageStmt: q.getMessageStmt, + getSessionByIDStmt: q.getSessionByIDStmt, + listFilesByPathStmt: q.listFilesByPathStmt, + listFilesBySessionStmt: q.listFilesBySessionStmt, + listLatestSessionFilesStmt: q.listLatestSessionFilesStmt, + listMessagesBySessionStmt: q.listMessagesBySessionStmt, + listNewFilesStmt: q.listNewFilesStmt, + listSessionsStmt: q.listSessionsStmt, + updateFileStmt: q.updateFileStmt, + updateMessageStmt: q.updateMessageStmt, + updateSessionStmt: q.updateSessionStmt, } } diff --git a/internal/db/files.sql.go b/internal/db/files.sql.go new file mode 100644 index 0000000000000000000000000000000000000000..b45731098451eaee1ed2b0b198ce5db39ac40094 --- /dev/null +++ b/internal/db/files.sql.go @@ -0,0 +1,309 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.27.0 +// source: files.sql + +package db + +import ( + "context" +) + +const createFile = `-- name: CreateFile :one +INSERT INTO files ( + id, + session_id, + path, + content, + version, + created_at, + updated_at +) VALUES ( + ?, ?, ?, ?, ?, strftime('%s', 'now'), strftime('%s', 'now') +) +RETURNING id, session_id, path, content, version, created_at, updated_at +` + +type CreateFileParams struct { + ID string `json:"id"` + SessionID string `json:"session_id"` + Path string `json:"path"` + Content string `json:"content"` + Version string `json:"version"` +} + +func (q *Queries) CreateFile(ctx context.Context, arg CreateFileParams) (File, error) { + row := q.queryRow(ctx, q.createFileStmt, createFile, + arg.ID, + arg.SessionID, + arg.Path, + arg.Content, + arg.Version, + ) + var i File + err := row.Scan( + &i.ID, + &i.SessionID, + &i.Path, + &i.Content, + &i.Version, + &i.CreatedAt, + &i.UpdatedAt, + ) + return i, err +} + +const deleteFile = `-- name: DeleteFile :exec +DELETE FROM files +WHERE id = ? +` + +func (q *Queries) DeleteFile(ctx context.Context, id string) error { + _, err := q.exec(ctx, q.deleteFileStmt, deleteFile, id) + return err +} + +const deleteSessionFiles = `-- name: DeleteSessionFiles :exec +DELETE FROM files +WHERE session_id = ? +` + +func (q *Queries) DeleteSessionFiles(ctx context.Context, sessionID string) error { + _, err := q.exec(ctx, q.deleteSessionFilesStmt, deleteSessionFiles, sessionID) + return err +} + +const getFile = `-- name: GetFile :one +SELECT id, session_id, path, content, version, created_at, updated_at +FROM files +WHERE id = ? LIMIT 1 +` + +func (q *Queries) GetFile(ctx context.Context, id string) (File, error) { + row := q.queryRow(ctx, q.getFileStmt, getFile, id) + var i File + err := row.Scan( + &i.ID, + &i.SessionID, + &i.Path, + &i.Content, + &i.Version, + &i.CreatedAt, + &i.UpdatedAt, + ) + return i, err +} + +const getFileByPathAndSession = `-- name: GetFileByPathAndSession :one +SELECT id, session_id, path, content, version, created_at, updated_at +FROM files +WHERE path = ? AND session_id = ? LIMIT 1 +` + +type GetFileByPathAndSessionParams struct { + Path string `json:"path"` + SessionID string `json:"session_id"` +} + +func (q *Queries) GetFileByPathAndSession(ctx context.Context, arg GetFileByPathAndSessionParams) (File, error) { + row := q.queryRow(ctx, q.getFileByPathAndSessionStmt, getFileByPathAndSession, arg.Path, arg.SessionID) + var i File + err := row.Scan( + &i.ID, + &i.SessionID, + &i.Path, + &i.Content, + &i.Version, + &i.CreatedAt, + &i.UpdatedAt, + ) + return i, err +} + +const listFilesByPath = `-- name: ListFilesByPath :many +SELECT id, session_id, path, content, version, created_at, updated_at +FROM files +WHERE path = ? +ORDER BY created_at DESC +` + +func (q *Queries) ListFilesByPath(ctx context.Context, path string) ([]File, error) { + rows, err := q.query(ctx, q.listFilesByPathStmt, listFilesByPath, path) + if err != nil { + return nil, err + } + defer rows.Close() + items := []File{} + for rows.Next() { + var i File + if err := rows.Scan( + &i.ID, + &i.SessionID, + &i.Path, + &i.Content, + &i.Version, + &i.CreatedAt, + &i.UpdatedAt, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const listFilesBySession = `-- name: ListFilesBySession :many +SELECT id, session_id, path, content, version, created_at, updated_at +FROM files +WHERE session_id = ? +ORDER BY created_at ASC +` + +func (q *Queries) ListFilesBySession(ctx context.Context, sessionID string) ([]File, error) { + rows, err := q.query(ctx, q.listFilesBySessionStmt, listFilesBySession, sessionID) + if err != nil { + return nil, err + } + defer rows.Close() + items := []File{} + for rows.Next() { + var i File + if err := rows.Scan( + &i.ID, + &i.SessionID, + &i.Path, + &i.Content, + &i.Version, + &i.CreatedAt, + &i.UpdatedAt, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const listLatestSessionFiles = `-- name: ListLatestSessionFiles :many +SELECT f.id, f.session_id, f.path, f.content, f.version, f.created_at, f.updated_at +FROM files f +INNER JOIN ( + SELECT path, MAX(created_at) as max_created_at + FROM files + GROUP BY path +) latest ON f.path = latest.path AND f.created_at = latest.max_created_at +WHERE f.session_id = ? +ORDER BY f.path +` + +func (q *Queries) ListLatestSessionFiles(ctx context.Context, sessionID string) ([]File, error) { + rows, err := q.query(ctx, q.listLatestSessionFilesStmt, listLatestSessionFiles, sessionID) + if err != nil { + return nil, err + } + defer rows.Close() + items := []File{} + for rows.Next() { + var i File + if err := rows.Scan( + &i.ID, + &i.SessionID, + &i.Path, + &i.Content, + &i.Version, + &i.CreatedAt, + &i.UpdatedAt, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const listNewFiles = `-- name: ListNewFiles :many +SELECT id, session_id, path, content, version, created_at, updated_at +FROM files +WHERE is_new = 1 +ORDER BY created_at DESC +` + +func (q *Queries) ListNewFiles(ctx context.Context) ([]File, error) { + rows, err := q.query(ctx, q.listNewFilesStmt, listNewFiles) + if err != nil { + return nil, err + } + defer rows.Close() + items := []File{} + for rows.Next() { + var i File + if err := rows.Scan( + &i.ID, + &i.SessionID, + &i.Path, + &i.Content, + &i.Version, + &i.CreatedAt, + &i.UpdatedAt, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const updateFile = `-- name: UpdateFile :one +UPDATE files +SET + content = ?, + version = ?, + updated_at = strftime('%s', 'now') +WHERE id = ? +RETURNING id, session_id, path, content, version, created_at, updated_at +` + +type UpdateFileParams struct { + Content string `json:"content"` + Version string `json:"version"` + ID string `json:"id"` +} + +func (q *Queries) UpdateFile(ctx context.Context, arg UpdateFileParams) (File, error) { + row := q.queryRow(ctx, q.updateFileStmt, updateFile, arg.Content, arg.Version, arg.ID) + var i File + err := row.Scan( + &i.ID, + &i.SessionID, + &i.Path, + &i.Content, + &i.Version, + &i.CreatedAt, + &i.UpdatedAt, + ) + return i, err +} diff --git a/internal/db/migrations/000001_initial.down.sql b/internal/db/migrations/000001_initial.down.sql index 4f2712d869e0e736ea084bee8f41ee3ca6f676f0..a2b0d13214ebd4e603904fd2028cd49f48857cdd 100644 --- a/internal/db/migrations/000001_initial.down.sql +++ b/internal/db/migrations/000001_initial.down.sql @@ -1,8 +1,10 @@ DROP TRIGGER IF EXISTS update_sessions_updated_at; DROP TRIGGER IF EXISTS update_messages_updated_at; +DROP TRIGGER IF EXISTS update_files_updated_at; DROP TRIGGER IF EXISTS update_session_message_count_on_delete; DROP TRIGGER IF EXISTS update_session_message_count_on_insert; DROP TABLE IF EXISTS sessions; DROP TABLE IF EXISTS messages; +DROP TABLE IF EXISTS files; diff --git a/internal/db/migrations/000001_initial.up.sql b/internal/db/migrations/000001_initial.up.sql index 03479449d24492c04cda61f56056d9c3d7fb73fa..4ac297dc5f1ab16f38eeb81a7b3135f64cbf9860 100644 --- a/internal/db/migrations/000001_initial.up.sql +++ b/internal/db/migrations/000001_initial.up.sql @@ -18,6 +18,28 @@ UPDATE sessions SET updated_at = strftime('%s', 'now') WHERE id = new.id; END; +-- Files +CREATE TABLE IF NOT EXISTS files ( + id TEXT PRIMARY KEY, + session_id TEXT NOT NULL, + path TEXT NOT NULL, + content TEXT NOT NULL, + version TEXT NOT NULL, + created_at INTEGER NOT NULL, -- Unix timestamp in milliseconds + updated_at INTEGER NOT NULL, -- Unix timestamp in milliseconds + FOREIGN KEY (session_id) REFERENCES sessions (id) ON DELETE CASCADE +); + +CREATE INDEX IF NOT EXISTS idx_files_session_id ON files (session_id); +CREATE INDEX IF NOT EXISTS idx_files_path ON files (path); + +CREATE TRIGGER IF NOT EXISTS update_files_updated_at +AFTER UPDATE ON files +BEGIN +UPDATE files SET updated_at = strftime('%s', 'now') +WHERE id = new.id; +END; + -- Messages CREATE TABLE IF NOT EXISTS messages ( id TEXT PRIMARY KEY, diff --git a/internal/db/models.go b/internal/db/models.go index 2fad913be831d4d475642def6d94f2f7fadd960d..f00cb6ad17ec5f9426502bb7612191dd6065f255 100644 --- a/internal/db/models.go +++ b/internal/db/models.go @@ -8,6 +8,16 @@ import ( "database/sql" ) +type File struct { + ID string `json:"id"` + SessionID string `json:"session_id"` + Path string `json:"path"` + Content string `json:"content"` + Version string `json:"version"` + CreatedAt int64 `json:"created_at"` + UpdatedAt int64 `json:"updated_at"` +} + type Message struct { ID string `json:"id"` SessionID string `json:"session_id"` diff --git a/internal/db/querier.go b/internal/db/querier.go index c9d73ec39662eaa736f07af4955dc17a8286023a..704a97da26c7feaf022ff3d8fa228b918ab298b6 100644 --- a/internal/db/querier.go +++ b/internal/db/querier.go @@ -9,15 +9,25 @@ import ( ) type Querier interface { + CreateFile(ctx context.Context, arg CreateFileParams) (File, error) CreateMessage(ctx context.Context, arg CreateMessageParams) (Message, error) CreateSession(ctx context.Context, arg CreateSessionParams) (Session, error) + DeleteFile(ctx context.Context, id string) error DeleteMessage(ctx context.Context, id string) error DeleteSession(ctx context.Context, id string) error + DeleteSessionFiles(ctx context.Context, sessionID string) error DeleteSessionMessages(ctx context.Context, sessionID string) error + GetFile(ctx context.Context, id string) (File, error) + GetFileByPathAndSession(ctx context.Context, arg GetFileByPathAndSessionParams) (File, error) GetMessage(ctx context.Context, id string) (Message, error) GetSessionByID(ctx context.Context, id string) (Session, error) + ListFilesByPath(ctx context.Context, path string) ([]File, error) + ListFilesBySession(ctx context.Context, sessionID string) ([]File, error) + ListLatestSessionFiles(ctx context.Context, sessionID string) ([]File, error) ListMessagesBySession(ctx context.Context, sessionID string) ([]Message, error) + ListNewFiles(ctx context.Context) ([]File, error) ListSessions(ctx context.Context) ([]Session, error) + UpdateFile(ctx context.Context, arg UpdateFileParams) (File, error) UpdateMessage(ctx context.Context, arg UpdateMessageParams) error UpdateSession(ctx context.Context, arg UpdateSessionParams) (Session, error) } diff --git a/internal/db/sql/files.sql b/internal/db/sql/files.sql new file mode 100644 index 0000000000000000000000000000000000000000..c2e7990764fc71a827ce92e7297cb4b155a2eafd --- /dev/null +++ b/internal/db/sql/files.sql @@ -0,0 +1,69 @@ +-- name: GetFile :one +SELECT * +FROM files +WHERE id = ? LIMIT 1; + +-- name: GetFileByPathAndSession :one +SELECT * +FROM files +WHERE path = ? AND session_id = ? LIMIT 1; + +-- name: ListFilesBySession :many +SELECT * +FROM files +WHERE session_id = ? +ORDER BY created_at ASC; + +-- name: ListFilesByPath :many +SELECT * +FROM files +WHERE path = ? +ORDER BY created_at DESC; + +-- name: CreateFile :one +INSERT INTO files ( + id, + session_id, + path, + content, + version, + created_at, + updated_at +) VALUES ( + ?, ?, ?, ?, ?, strftime('%s', 'now'), strftime('%s', 'now') +) +RETURNING *; + +-- name: UpdateFile :one +UPDATE files +SET + content = ?, + version = ?, + updated_at = strftime('%s', 'now') +WHERE id = ? +RETURNING *; + +-- name: DeleteFile :exec +DELETE FROM files +WHERE id = ?; + +-- name: DeleteSessionFiles :exec +DELETE FROM files +WHERE session_id = ?; + +-- name: ListLatestSessionFiles :many +SELECT f.* +FROM files f +INNER JOIN ( + SELECT path, MAX(created_at) as max_created_at + FROM files + GROUP BY path +) latest ON f.path = latest.path AND f.created_at = latest.max_created_at +WHERE f.session_id = ? +ORDER BY f.path; + +-- name: ListNewFiles :many +SELECT * +FROM files +WHERE is_new = 1 +ORDER BY created_at DESC; diff --git a/internal/history/file.go b/internal/history/file.go new file mode 100644 index 0000000000000000000000000000000000000000..25953b27325bc1266d993850337448097901fd02 --- /dev/null +++ b/internal/history/file.go @@ -0,0 +1,206 @@ +package history + +import ( + "context" + "fmt" + "strconv" + "strings" + + "github.com/google/uuid" + "github.com/kujtimiihoxha/termai/internal/db" + "github.com/kujtimiihoxha/termai/internal/pubsub" +) + +const ( + InitialVersion = "initial" +) + +type File struct { + ID string + SessionID string + Path string + Content string + Version string + CreatedAt int64 + UpdatedAt int64 +} + +type Service interface { + pubsub.Suscriber[File] + Create(sessionID, path, content string) (File, error) + CreateVersion(sessionID, path, content string) (File, error) + Get(id string) (File, error) + GetByPathAndSession(path, sessionID string) (File, error) + ListBySession(sessionID string) ([]File, error) + ListLatestSessionFiles(sessionID string) ([]File, error) + Update(file File) (File, error) + Delete(id string) error + DeleteSessionFiles(sessionID string) error +} + +type service struct { + *pubsub.Broker[File] + q db.Querier + ctx context.Context +} + +func NewService(ctx context.Context, q db.Querier) Service { + return &service{ + Broker: pubsub.NewBroker[File](), + q: q, + ctx: ctx, + } +} + +func (s *service) Create(sessionID, path, content string) (File, error) { + return s.createWithVersion(sessionID, path, content, InitialVersion) +} + +func (s *service) CreateVersion(sessionID, path, content string) (File, error) { + // Get the latest version for this path + files, err := s.q.ListFilesByPath(s.ctx, path) + if err != nil { + return File{}, err + } + + if len(files) == 0 { + // No previous versions, create initial + return s.Create(sessionID, path, content) + } + + // Get the latest version + latestFile := files[0] // Files are ordered by created_at DESC + latestVersion := latestFile.Version + + // Generate the next version + var nextVersion string + if latestVersion == InitialVersion { + nextVersion = "v1" + } else if strings.HasPrefix(latestVersion, "v") { + versionNum, err := strconv.Atoi(latestVersion[1:]) + if err != nil { + // If we can't parse the version, just use a timestamp-based version + nextVersion = fmt.Sprintf("v%d", latestFile.CreatedAt) + } else { + nextVersion = fmt.Sprintf("v%d", versionNum+1) + } + } else { + // If the version format is unexpected, use a timestamp-based version + nextVersion = fmt.Sprintf("v%d", latestFile.CreatedAt) + } + + return s.createWithVersion(sessionID, path, content, nextVersion) +} + +func (s *service) createWithVersion(sessionID, path, content, version string) (File, error) { + dbFile, err := s.q.CreateFile(s.ctx, db.CreateFileParams{ + ID: uuid.New().String(), + SessionID: sessionID, + Path: path, + Content: content, + Version: version, + }) + if err != nil { + return File{}, err + } + file := s.fromDBItem(dbFile) + s.Publish(pubsub.CreatedEvent, file) + return file, nil +} + +func (s *service) Get(id string) (File, error) { + dbFile, err := s.q.GetFile(s.ctx, id) + if err != nil { + return File{}, err + } + return s.fromDBItem(dbFile), nil +} + +func (s *service) GetByPathAndSession(path, sessionID string) (File, error) { + dbFile, err := s.q.GetFileByPathAndSession(s.ctx, db.GetFileByPathAndSessionParams{ + Path: path, + SessionID: sessionID, + }) + if err != nil { + return File{}, err + } + return s.fromDBItem(dbFile), nil +} + +func (s *service) ListBySession(sessionID string) ([]File, error) { + dbFiles, err := s.q.ListFilesBySession(s.ctx, sessionID) + if err != nil { + return nil, err + } + files := make([]File, len(dbFiles)) + for i, dbFile := range dbFiles { + files[i] = s.fromDBItem(dbFile) + } + return files, nil +} + +func (s *service) ListLatestSessionFiles(sessionID string) ([]File, error) { + dbFiles, err := s.q.ListLatestSessionFiles(s.ctx, sessionID) + if err != nil { + return nil, err + } + files := make([]File, len(dbFiles)) + for i, dbFile := range dbFiles { + files[i] = s.fromDBItem(dbFile) + } + return files, nil +} + +func (s *service) Update(file File) (File, error) { + dbFile, err := s.q.UpdateFile(s.ctx, db.UpdateFileParams{ + ID: file.ID, + Content: file.Content, + Version: file.Version, + }) + if err != nil { + return File{}, err + } + updatedFile := s.fromDBItem(dbFile) + s.Publish(pubsub.UpdatedEvent, updatedFile) + return updatedFile, nil +} + +func (s *service) Delete(id string) error { + file, err := s.Get(id) + if err != nil { + return err + } + err = s.q.DeleteFile(s.ctx, id) + if err != nil { + return err + } + s.Publish(pubsub.DeletedEvent, file) + return nil +} + +func (s *service) DeleteSessionFiles(sessionID string) error { + files, err := s.ListBySession(sessionID) + if err != nil { + return err + } + for _, file := range files { + err = s.Delete(file.ID) + if err != nil { + return err + } + } + return nil +} + +func (s *service) fromDBItem(item db.File) File { + return File{ + ID: item.ID, + SessionID: item.SessionID, + Path: item.Path, + Content: item.Content, + Version: item.Version, + CreatedAt: item.CreatedAt, + UpdatedAt: item.UpdatedAt, + } +} + diff --git a/internal/tui/tui.go b/internal/tui/tui.go index eb996d44863ffa8af398082ab7fc52417f6daf85..db9ac9ff6ef5eb0277210a9bf5e26188ff79ee00 100644 --- a/internal/tui/tui.go +++ b/internal/tui/tui.go @@ -198,8 +198,8 @@ func (a appModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) { } return a, util.CmdHandler(repl.SelectedSessionMsg{SessionID: s.ID}) } - case key.Matches(msg, keys.Logs): - return a, a.moveToPage(page.LogsPage) + // case key.Matches(msg, keys.Logs): + // return a, a.moveToPage(page.LogsPage) case msg.String() == "O": return a, a.moveToPage(page.ReplPage) case key.Matches(msg, keys.Help): From 5601466fe1610b777895682050b1b458f80c0ac8 Mon Sep 17 00:00:00 2001 From: Kujtim Hoxha Date: Sun, 13 Apr 2025 12:25:11 +0200 Subject: [PATCH 07/41] cleanup config --- internal/config/config.go | 282 +++++++++++++------- internal/config/config_test.go | 465 --------------------------------- 2 files changed, 183 insertions(+), 564 deletions(-) delete mode 100644 internal/config/config_test.go diff --git a/internal/config/config.go b/internal/config/config.go index fdfacd11af84bd9a0841ea48d7efa1b1cff6ac4c..6f757b3f48957185c9b256ea7b6dc1b07986a62d 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -1,21 +1,27 @@ +// Package config manages application configuration from various sources. package config import ( "fmt" + "log/slog" "os" "strings" "github.com/kujtimiihoxha/termai/internal/llm/models" + "github.com/kujtimiihoxha/termai/internal/logging" "github.com/spf13/viper" ) +// MCPType defines the type of MCP (Model Control Protocol) server. type MCPType string +// Supported MCP types const ( MCPStdio MCPType = "stdio" MCPSse MCPType = "sse" ) +// MCPServer defines the configuration for a Model Control Protocol server. type MCPServer struct { Command string `json:"command"` Env []string `json:"env"` @@ -23,37 +29,28 @@ type MCPServer struct { Type MCPType `json:"type"` URL string `json:"url"` Headers map[string]string `json:"headers"` - // TODO: add permissions configuration - // TODO: add the ability to specify the tools to import } +// Model defines configuration for different LLM models and their token limits. type Model struct { Coder models.ModelID `json:"coder"` CoderMaxTokens int64 `json:"coderMaxTokens"` - - Task models.ModelID `json:"task"` - TaskMaxTokens int64 `json:"taskMaxTokens"` - // TODO: Maybe support multiple models for different purposes -} - -type AnthropicConfig struct { - DisableCache bool `json:"disableCache"` - UseBedrock bool `json:"useBedrock"` + Task models.ModelID `json:"task"` + TaskMaxTokens int64 `json:"taskMaxTokens"` } +// Provider defines configuration for an LLM provider. type Provider struct { - APIKey string `json:"apiKey"` - Enabled bool `json:"enabled"` + APIKey string `json:"apiKey"` + Disabled bool `json:"disabled"` } +// Data defines storage configuration. type Data struct { Directory string `json:"directory"` } -type Log struct { - Level string `json:"level"` -} - +// LSPConfig defines configuration for Language Server Protocol integration. type LSPConfig struct { Disabled bool `json:"enabled"` Command string `json:"command"` @@ -61,41 +58,88 @@ type LSPConfig struct { Options any `json:"options"` } +// Config is the main configuration structure for the application. type Config struct { - Data *Data `json:"data,omitempty"` - Log *Log `json:"log,omitempty"` + Data Data `json:"data"` + WorkingDir string `json:"wd,omitempty"` MCPServers map[string]MCPServer `json:"mcpServers,omitempty"` Providers map[models.ModelProvider]Provider `json:"providers,omitempty"` - - LSP map[string]LSPConfig `json:"lsp,omitempty"` - - Model *Model `json:"model,omitempty"` - - Debug bool `json:"debug,omitempty"` + LSP map[string]LSPConfig `json:"lsp,omitempty"` + Model Model `json:"model"` + Debug bool `json:"debug,omitempty"` } -var cfg *Config - +// Application constants const ( - defaultDataDirectory = ".termai" + defaultDataDirectory = ".opencode" defaultLogLevel = "info" defaultMaxTokens = int64(5000) - termai = "termai" + appName = "opencode" ) -func Load(debug bool) error { +// Global configuration instance +var cfg *Config + +// Load initializes the configuration from environment variables and config files. +// If debug is true, debug mode is enabled and log level is set to debug. +// It returns an error if configuration loading fails. +func Load(workingDir string, debug bool) error { if cfg != nil { return nil } - viper.SetConfigName(fmt.Sprintf(".%s", termai)) + cfg = &Config{ + WorkingDir: workingDir, + MCPServers: make(map[string]MCPServer), + Providers: make(map[models.ModelProvider]Provider), + LSP: make(map[string]LSPConfig), + } + + configureViper() + setDefaults(debug) + setProviderDefaults() + + // Read global config + if err := readConfig(viper.ReadInConfig()); err != nil { + return err + } + + // Load and merge local config + mergeLocalConfig(workingDir) + + // Apply configuration to the struct + if err := viper.Unmarshal(cfg); err != nil { + return err + } + + applyDefaultValues() + + defaultLevel := slog.LevelInfo + if cfg.Debug { + defaultLevel = slog.LevelDebug + } + // Configure logger + logger := slog.New(slog.NewTextHandler(logging.NewWriter(), &slog.HandlerOptions{ + Level: defaultLevel, + })) + slog.SetDefault(logger) + return nil +} + +// configureViper sets up viper's configuration paths and environment variables. +func configureViper() { + viper.SetConfigName(fmt.Sprintf(".%s", appName)) viper.SetConfigType("json") viper.AddConfigPath("$HOME") - viper.AddConfigPath(fmt.Sprintf("$XDG_CONFIG_HOME/%s", termai)) - viper.SetEnvPrefix(strings.ToUpper(termai)) + viper.AddConfigPath(fmt.Sprintf("$XDG_CONFIG_HOME/%s", appName)) + viper.SetEnvPrefix(strings.ToUpper(appName)) + viper.AutomaticEnv() +} - // Add defaults +// setDefaults configures default values for configuration options. +func setDefaults(debug bool) { viper.SetDefault("data.directory", defaultDataDirectory) + if debug { viper.SetDefault("debug", true) viper.Set("log.level", "debug") @@ -103,98 +147,138 @@ func Load(debug bool) error { viper.SetDefault("debug", false) viper.SetDefault("log.level", defaultLogLevel) } +} + +// setProviderDefaults configures LLM provider defaults based on environment variables. +// the default model priority is: +// 1. Anthropic +// 2. OpenAI +// 3. Google Gemini +// 4. AWS Bedrock +func setProviderDefaults() { + // Groq configuration + if apiKey := os.Getenv("GROQ_API_KEY"); apiKey != "" { + viper.SetDefault("providers.groq.apiKey", apiKey) + viper.SetDefault("model.coder", models.QWENQwq) + viper.SetDefault("model.coderMaxTokens", defaultMaxTokens) + viper.SetDefault("model.task", models.QWENQwq) + viper.SetDefault("model.taskMaxTokens", defaultMaxTokens) + } + + // Google Gemini configuration + if apiKey := os.Getenv("GEMINI_API_KEY"); apiKey != "" { + viper.SetDefault("providers.gemini.apiKey", apiKey) + viper.SetDefault("model.coder", models.GRMINI20Flash) + viper.SetDefault("model.coderMaxTokens", defaultMaxTokens) + viper.SetDefault("model.task", models.GRMINI20Flash) + viper.SetDefault("model.taskMaxTokens", defaultMaxTokens) + } - defaultModelSet := false - if os.Getenv("ANTHROPIC_API_KEY") != "" { - viper.SetDefault("providers.anthropic.apiKey", os.Getenv("ANTHROPIC_API_KEY")) - viper.SetDefault("providers.anthropic.enabled", true) + // OpenAI configuration + if apiKey := os.Getenv("OPENAI_API_KEY"); apiKey != "" { + viper.SetDefault("providers.openai.apiKey", apiKey) + viper.SetDefault("model.coder", models.GPT4o) + viper.SetDefault("model.coderMaxTokens", defaultMaxTokens) + viper.SetDefault("model.task", models.GPT4o) + viper.SetDefault("model.taskMaxTokens", defaultMaxTokens) + } + + // Anthropic configuration + if apiKey := os.Getenv("ANTHROPIC_API_KEY"); apiKey != "" { + viper.SetDefault("providers.anthropic.apiKey", apiKey) viper.SetDefault("model.coder", models.Claude37Sonnet) + viper.SetDefault("model.coderMaxTokens", defaultMaxTokens) viper.SetDefault("model.task", models.Claude37Sonnet) - defaultModelSet = true - } - if os.Getenv("OPENAI_API_KEY") != "" { - viper.SetDefault("providers.openai.apiKey", os.Getenv("OPENAI_API_KEY")) - viper.SetDefault("providers.openai.enabled", true) - if !defaultModelSet { - viper.SetDefault("model.coder", models.GPT41) - viper.SetDefault("model.task", models.GPT41) - defaultModelSet = true - } + viper.SetDefault("model.taskMaxTokens", defaultMaxTokens) } - if os.Getenv("GEMINI_API_KEY") != "" { - viper.SetDefault("providers.gemini.apiKey", os.Getenv("GEMINI_API_KEY")) - viper.SetDefault("providers.gemini.enabled", true) - if !defaultModelSet { - viper.SetDefault("model.coder", models.GRMINI20Flash) - viper.SetDefault("model.task", models.GRMINI20Flash) - defaultModelSet = true - } + + if hasAWSCredentials() { + viper.SetDefault("model.coder", models.BedrockClaude37Sonnet) + viper.SetDefault("model.coderMaxTokens", defaultMaxTokens) + viper.SetDefault("model.task", models.BedrockClaude37Sonnet) + viper.SetDefault("model.taskMaxTokens", defaultMaxTokens) } - if os.Getenv("GROQ_API_KEY") != "" { - viper.SetDefault("providers.groq.apiKey", os.Getenv("GROQ_API_KEY")) - viper.SetDefault("providers.groq.enabled", true) - if !defaultModelSet { - viper.SetDefault("model.coder", models.QWENQwq) - viper.SetDefault("model.task", models.QWENQwq) - defaultModelSet = true - } +} + +// hasAWSCredentials checks if AWS credentials are available in the environment. +func hasAWSCredentials() bool { + // Check for explicit AWS credentials + if os.Getenv("AWS_ACCESS_KEY_ID") != "" && os.Getenv("AWS_SECRET_ACCESS_KEY") != "" { + return true } - viper.SetDefault("providers.bedrock.enabled", true) - // TODO: add more providers - cfg = &Config{} + // Check for AWS profile + if os.Getenv("AWS_PROFILE") != "" || os.Getenv("AWS_DEFAULT_PROFILE") != "" { + return true + } - err := viper.ReadInConfig() - if err != nil { - if _, ok := err.(viper.ConfigFileNotFoundError); !ok { - return err - } + // Check for AWS region + if os.Getenv("AWS_REGION") != "" || os.Getenv("AWS_DEFAULT_REGION") != "" { + return true } - local := viper.New() - local.SetConfigName(fmt.Sprintf(".%s", termai)) - local.SetConfigType("json") - local.AddConfigPath(".") - // load local config, this will override the global config - if err = local.ReadInConfig(); err == nil { - viper.MergeConfigMap(local.AllSettings()) + + // Check if running on EC2 with instance profile + if os.Getenv("AWS_CONTAINER_CREDENTIALS_RELATIVE_URI") != "" || + os.Getenv("AWS_CONTAINER_CREDENTIALS_FULL_URI") != "" { + return true + } + + return false +} + +// readConfig handles the result of reading a configuration file. +func readConfig(err error) error { + if err == nil { + return nil } - viper.Unmarshal(cfg) - if cfg.Model != nil && cfg.Model.CoderMaxTokens <= 0 { - cfg.Model.CoderMaxTokens = defaultMaxTokens + // It's okay if the config file doesn't exist + if _, ok := err.(viper.ConfigFileNotFoundError); ok { + return nil } - if cfg.Model != nil && cfg.Model.TaskMaxTokens <= 0 { - cfg.Model.TaskMaxTokens = defaultMaxTokens + + return err +} + +// mergeLocalConfig loads and merges configuration from the local directory. +func mergeLocalConfig(workingDir string) { + local := viper.New() + local.SetConfigName(fmt.Sprintf(".%s", appName)) + local.SetConfigType("json") + local.AddConfigPath(workingDir) + + // Merge local config if it exists + if err := local.ReadInConfig(); err == nil { + viper.MergeConfigMap(local.AllSettings()) } +} - for _, v := range cfg.MCPServers { +// applyDefaultValues sets default values for configuration fields that need processing. +func applyDefaultValues() { + // Set default MCP type if not specified + for k, v := range cfg.MCPServers { if v.Type == "" { v.Type = MCPStdio + cfg.MCPServers[k] = v } } +} +// setWorkingDirectory stores the current working directory in the configuration. +func setWorkingDirectory() { workdir, err := os.Getwd() - if err != nil { - return err + if err == nil { + viper.Set("wd", workdir) } - viper.Set("wd", workdir) - return nil } +// Get returns the current configuration. +// It's safe to call this function multiple times. func Get() *Config { - if cfg == nil { - err := Load(false) - if err != nil { - panic(err) - } - } return cfg } +// WorkingDirectory returns the current working directory from the configuration. func WorkingDirectory() string { return viper.GetString("wd") } - -func Write() error { - return viper.WriteConfig() -} diff --git a/internal/config/config_test.go b/internal/config/config_test.go deleted file mode 100644 index 9111aa0fa667de52a1fc30586c5a9c2425adf81d..0000000000000000000000000000000000000000 --- a/internal/config/config_test.go +++ /dev/null @@ -1,465 +0,0 @@ -package config - -import ( - "fmt" - "os" - "path/filepath" - "testing" - - "github.com/kujtimiihoxha/termai/internal/llm/models" - "github.com/spf13/viper" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -func TestLoad(t *testing.T) { - setupTest(t) - - t.Run("loads configuration successfully", func(t *testing.T) { - homeDir := t.TempDir() - t.Setenv("HOME", homeDir) - configPath := filepath.Join(homeDir, ".termai.json") - - configContent := `{ - "data": { - "directory": "custom-dir" - }, - "log": { - "level": "debug" - }, - "mcpServers": { - "test-server": { - "command": "test-command", - "env": ["TEST_ENV=value"], - "args": ["--arg1", "--arg2"], - "type": "stdio", - "url": "", - "headers": {} - }, - "sse-server": { - "command": "", - "env": [], - "args": [], - "type": "sse", - "url": "https://api.example.com/events", - "headers": { - "Authorization": "Bearer token123", - "Content-Type": "application/json" - } - } - }, - "providers": { - "anthropic": { - "apiKey": "test-api-key", - "enabled": true - } - }, - "model": { - "coder": "claude-3-haiku", - "task": "claude-3-haiku" - } - }` - err := os.WriteFile(configPath, []byte(configContent), 0o644) - require.NoError(t, err) - - cfg = nil - viper.Reset() - - err = Load(false) - require.NoError(t, err) - - config := Get() - assert.NotNil(t, config) - assert.Equal(t, "custom-dir", config.Data.Directory) - assert.Equal(t, "debug", config.Log.Level) - - assert.Contains(t, config.MCPServers, "test-server") - stdioServer := config.MCPServers["test-server"] - assert.Equal(t, "test-command", stdioServer.Command) - assert.Equal(t, []string{"TEST_ENV=value"}, stdioServer.Env) - assert.Equal(t, []string{"--arg1", "--arg2"}, stdioServer.Args) - assert.Equal(t, MCPStdio, stdioServer.Type) - assert.Equal(t, "", stdioServer.URL) - assert.Empty(t, stdioServer.Headers) - - assert.Contains(t, config.MCPServers, "sse-server") - sseServer := config.MCPServers["sse-server"] - assert.Equal(t, "", sseServer.Command) - assert.Empty(t, sseServer.Env) - assert.Empty(t, sseServer.Args) - assert.Equal(t, MCPSse, sseServer.Type) - assert.Equal(t, "https://api.example.com/events", sseServer.URL) - assert.Equal(t, map[string]string{ - "authorization": "Bearer token123", - "content-type": "application/json", - }, sseServer.Headers) - - assert.Contains(t, config.Providers, models.ModelProvider("anthropic")) - provider := config.Providers[models.ModelProvider("anthropic")] - assert.Equal(t, "test-api-key", provider.APIKey) - assert.True(t, provider.Enabled) - - assert.NotNil(t, config.Model) - assert.Equal(t, models.Claude3Haiku, config.Model.Coder) - assert.Equal(t, models.Claude3Haiku, config.Model.Task) - assert.Equal(t, defaultMaxTokens, config.Model.CoderMaxTokens) - }) - - t.Run("loads configuration with environment variables", func(t *testing.T) { - homeDir := t.TempDir() - t.Setenv("HOME", homeDir) - configPath := filepath.Join(homeDir, ".termai.json") - err := os.WriteFile(configPath, []byte("{}"), 0o644) - require.NoError(t, err) - - t.Setenv("ANTHROPIC_API_KEY", "env-anthropic-key") - t.Setenv("OPENAI_API_KEY", "env-openai-key") - t.Setenv("GEMINI_API_KEY", "env-gemini-key") - - cfg = nil - viper.Reset() - - err = Load(false) - require.NoError(t, err) - - config := Get() - assert.NotNil(t, config) - - assert.Equal(t, defaultDataDirectory, config.Data.Directory) - assert.Equal(t, defaultLogLevel, config.Log.Level) - - assert.Contains(t, config.Providers, models.ModelProvider("anthropic")) - assert.Equal(t, "env-anthropic-key", config.Providers[models.ModelProvider("anthropic")].APIKey) - assert.True(t, config.Providers[models.ModelProvider("anthropic")].Enabled) - - assert.Contains(t, config.Providers, models.ModelProvider("openai")) - assert.Equal(t, "env-openai-key", config.Providers[models.ModelProvider("openai")].APIKey) - assert.True(t, config.Providers[models.ModelProvider("openai")].Enabled) - - assert.Contains(t, config.Providers, models.ModelProvider("gemini")) - assert.Equal(t, "env-gemini-key", config.Providers[models.ModelProvider("gemini")].APIKey) - assert.True(t, config.Providers[models.ModelProvider("gemini")].Enabled) - - assert.Equal(t, models.Claude37Sonnet, config.Model.Coder) - }) - - t.Run("local config overrides global config", func(t *testing.T) { - homeDir := t.TempDir() - t.Setenv("HOME", homeDir) - globalConfigPath := filepath.Join(homeDir, ".termai.json") - globalConfig := `{ - "data": { - "directory": "global-dir" - }, - "log": { - "level": "info" - } - }` - err := os.WriteFile(globalConfigPath, []byte(globalConfig), 0o644) - require.NoError(t, err) - - workDir := t.TempDir() - origDir, err := os.Getwd() - require.NoError(t, err) - defer os.Chdir(origDir) - err = os.Chdir(workDir) - require.NoError(t, err) - - localConfigPath := filepath.Join(workDir, ".termai.json") - localConfig := `{ - "data": { - "directory": "local-dir" - }, - "log": { - "level": "debug" - } - }` - err = os.WriteFile(localConfigPath, []byte(localConfig), 0o644) - require.NoError(t, err) - - cfg = nil - viper.Reset() - - err = Load(false) - require.NoError(t, err) - - config := Get() - assert.NotNil(t, config) - - assert.Equal(t, "local-dir", config.Data.Directory) - assert.Equal(t, "debug", config.Log.Level) - }) - - t.Run("missing config file should not return error", func(t *testing.T) { - emptyDir := t.TempDir() - t.Setenv("HOME", emptyDir) - - cfg = nil - viper.Reset() - - err := Load(false) - assert.NoError(t, err) - }) - - t.Run("model priority and fallbacks", func(t *testing.T) { - testCases := []struct { - name string - anthropicKey string - openaiKey string - geminiKey string - expectedModel models.ModelID - explicitModel models.ModelID - useExplicitModel bool - }{ - { - name: "anthropic has priority", - anthropicKey: "test-key", - openaiKey: "test-key", - geminiKey: "test-key", - expectedModel: models.Claude37Sonnet, - }, - { - name: "fallback to openai when no anthropic", - anthropicKey: "", - openaiKey: "test-key", - geminiKey: "test-key", - expectedModel: models.GPT41, - }, - { - name: "fallback to gemini when no others", - anthropicKey: "", - openaiKey: "", - geminiKey: "test-key", - expectedModel: models.GRMINI20Flash, - }, - { - name: "explicit model overrides defaults", - anthropicKey: "test-key", - openaiKey: "test-key", - geminiKey: "test-key", - explicitModel: models.GPT41, - useExplicitModel: true, - expectedModel: models.GPT41, - }, - } - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - homeDir := t.TempDir() - t.Setenv("HOME", homeDir) - configPath := filepath.Join(homeDir, ".termai.json") - - configContent := "{}" - if tc.useExplicitModel { - configContent = fmt.Sprintf(`{"model":{"coder":"%s"}}`, tc.explicitModel) - } - - err := os.WriteFile(configPath, []byte(configContent), 0o644) - require.NoError(t, err) - - if tc.anthropicKey != "" { - t.Setenv("ANTHROPIC_API_KEY", tc.anthropicKey) - } else { - t.Setenv("ANTHROPIC_API_KEY", "") - } - - if tc.openaiKey != "" { - t.Setenv("OPENAI_API_KEY", tc.openaiKey) - } else { - t.Setenv("OPENAI_API_KEY", "") - } - - if tc.geminiKey != "" { - t.Setenv("GEMINI_API_KEY", tc.geminiKey) - } else { - t.Setenv("GEMINI_API_KEY", "") - } - - cfg = nil - viper.Reset() - - err = Load(false) - require.NoError(t, err) - - config := Get() - assert.NotNil(t, config) - assert.Equal(t, tc.expectedModel, config.Model.Coder) - }) - } - }) -} - -func TestGet(t *testing.T) { - t.Run("get returns same config instance", func(t *testing.T) { - setupTest(t) - homeDir := t.TempDir() - t.Setenv("HOME", homeDir) - configPath := filepath.Join(homeDir, ".termai.json") - err := os.WriteFile(configPath, []byte("{}"), 0o644) - require.NoError(t, err) - - cfg = nil - viper.Reset() - - config1 := Get() - require.NotNil(t, config1) - - config2 := Get() - require.NotNil(t, config2) - - assert.Same(t, config1, config2) - }) - - t.Run("get loads config if not loaded", func(t *testing.T) { - setupTest(t) - homeDir := t.TempDir() - t.Setenv("HOME", homeDir) - configPath := filepath.Join(homeDir, ".termai.json") - configContent := `{"data":{"directory":"test-dir"}}` - err := os.WriteFile(configPath, []byte(configContent), 0o644) - require.NoError(t, err) - - cfg = nil - viper.Reset() - - config := Get() - require.NotNil(t, config) - assert.Equal(t, "test-dir", config.Data.Directory) - }) -} - -func TestWorkingDirectory(t *testing.T) { - t.Run("returns current working directory", func(t *testing.T) { - setupTest(t) - homeDir := t.TempDir() - t.Setenv("HOME", homeDir) - configPath := filepath.Join(homeDir, ".termai.json") - err := os.WriteFile(configPath, []byte("{}"), 0o644) - require.NoError(t, err) - - cfg = nil - viper.Reset() - - err = Load(false) - require.NoError(t, err) - - wd := WorkingDirectory() - expectedWd, err := os.Getwd() - require.NoError(t, err) - assert.Equal(t, expectedWd, wd) - }) -} - -func TestWrite(t *testing.T) { - t.Run("writes config to file", func(t *testing.T) { - setupTest(t) - homeDir := t.TempDir() - t.Setenv("HOME", homeDir) - configPath := filepath.Join(homeDir, ".termai.json") - err := os.WriteFile(configPath, []byte("{}"), 0o644) - require.NoError(t, err) - - cfg = nil - viper.Reset() - - err = Load(false) - require.NoError(t, err) - - viper.Set("data.directory", "modified-dir") - - err = Write() - require.NoError(t, err) - - content, err := os.ReadFile(configPath) - require.NoError(t, err) - assert.Contains(t, string(content), "modified-dir") - }) -} - -func TestMCPType(t *testing.T) { - t.Run("MCPType constants", func(t *testing.T) { - assert.Equal(t, MCPType("stdio"), MCPStdio) - assert.Equal(t, MCPType("sse"), MCPSse) - }) - - t.Run("MCPType JSON unmarshaling", func(t *testing.T) { - homeDir := t.TempDir() - t.Setenv("HOME", homeDir) - configPath := filepath.Join(homeDir, ".termai.json") - - configContent := `{ - "mcpServers": { - "stdio-server": { - "type": "stdio" - }, - "sse-server": { - "type": "sse" - }, - "invalid-server": { - "type": "invalid" - } - } - }` - err := os.WriteFile(configPath, []byte(configContent), 0o644) - require.NoError(t, err) - - cfg = nil - viper.Reset() - - err = Load(false) - require.NoError(t, err) - - config := Get() - assert.NotNil(t, config) - - assert.Equal(t, MCPStdio, config.MCPServers["stdio-server"].Type) - assert.Equal(t, MCPSse, config.MCPServers["sse-server"].Type) - assert.Equal(t, MCPType("invalid"), config.MCPServers["invalid-server"].Type) - }) - - t.Run("default MCPType", func(t *testing.T) { - homeDir := t.TempDir() - t.Setenv("HOME", homeDir) - configPath := filepath.Join(homeDir, ".termai.json") - - configContent := `{ - "mcpServers": { - "test-server": { - "command": "test-command" - } - } - }` - err := os.WriteFile(configPath, []byte(configContent), 0o644) - require.NoError(t, err) - - cfg = nil - viper.Reset() - - err = Load(false) - require.NoError(t, err) - - config := Get() - assert.NotNil(t, config) - - assert.Equal(t, MCPType(""), config.MCPServers["test-server"].Type) - }) -} - -func setupTest(t *testing.T) { - origHome := os.Getenv("HOME") - origXdgConfigHome := os.Getenv("XDG_CONFIG_HOME") - origAnthropicKey := os.Getenv("ANTHROPIC_API_KEY") - origOpenAIKey := os.Getenv("OPENAI_API_KEY") - origGeminiKey := os.Getenv("GEMINI_API_KEY") - - t.Cleanup(func() { - t.Setenv("HOME", origHome) - t.Setenv("XDG_CONFIG_HOME", origXdgConfigHome) - t.Setenv("ANTHROPIC_API_KEY", origAnthropicKey) - t.Setenv("OPENAI_API_KEY", origOpenAIKey) - t.Setenv("GEMINI_API_KEY", origGeminiKey) - - cfg = nil - viper.Reset() - }) -} From 3ad983db0f2c08826d56cb5de274d706c95b3353 Mon Sep 17 00:00:00 2001 From: Kujtim Hoxha Date: Sun, 13 Apr 2025 13:17:17 +0200 Subject: [PATCH 08/41] cleanup app, config and root --- .gitignore | 2 +- .termai.json => .opencode.json | 0 cmd/git/main.go | 4 - cmd/root.go | 253 +++++++++++++++++------ internal/app/app.go | 76 +++++++ internal/app/lsp.go | 108 ++++++++++ internal/app/services.go | 64 ------ internal/config/config.go | 20 +- internal/history/file.go | 73 ++++--- internal/llm/agent/agent-tool.go | 10 +- internal/llm/agent/agent.go | 53 ++--- internal/llm/agent/coder.go | 5 +- internal/llm/agent/task.go | 3 +- internal/message/message.go | 46 ++--- internal/session/session.go | 44 ++-- internal/tui/components/chat/messages.go | 5 +- internal/tui/components/repl/editor.go | 4 +- internal/tui/components/repl/messages.go | 7 +- internal/tui/components/repl/sessions.go | 4 +- internal/tui/page/chat.go | 8 +- internal/tui/tui.go | 6 +- 21 files changed, 514 insertions(+), 281 deletions(-) rename .termai.json => .opencode.json (100%) delete mode 100644 cmd/git/main.go create mode 100644 internal/app/app.go create mode 100644 internal/app/lsp.go delete mode 100644 internal/app/services.go diff --git a/.gitignore b/.gitignore index 388f8b2cac6bb87ad9917595e84d5d98f007ad54..0ef6e2aefaf866005e2c8cd8376c9ff447210ff5 100644 --- a/.gitignore +++ b/.gitignore @@ -41,6 +41,6 @@ debug.log .env .env.local -.termai +.opencode internal/assets/diff/index.mjs diff --git a/.termai.json b/.opencode.json similarity index 100% rename from .termai.json rename to .opencode.json diff --git a/cmd/git/main.go b/cmd/git/main.go deleted file mode 100644 index da29a2cadf1e00b14b1a4bd0a52780888bf3e532..0000000000000000000000000000000000000000 --- a/cmd/git/main.go +++ /dev/null @@ -1,4 +0,0 @@ -package main - -func main() { -} diff --git a/cmd/root.go b/cmd/root.go index d846a14c263790d9c27df4ad8aa151cc555c2e53..092606de7ff6ad1dfde4f9238f5199ceca5333e4 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -2,9 +2,10 @@ package cmd import ( "context" - "log/slog" + "fmt" "os" "sync" + "time" tea "github.com/charmbracelet/bubbletea" "github.com/kujtimiihoxha/termai/internal/app" @@ -13,6 +14,7 @@ import ( "github.com/kujtimiihoxha/termai/internal/db" "github.com/kujtimiihoxha/termai/internal/llm/agent" "github.com/kujtimiihoxha/termai/internal/logging" + "github.com/kujtimiihoxha/termai/internal/pubsub" "github.com/kujtimiihoxha/termai/internal/tui" zone "github.com/lrstanley/bubblezone" "github.com/spf13/cobra" @@ -23,111 +25,229 @@ var rootCmd = &cobra.Command{ Short: "A terminal ai assistant", Long: `A terminal ai assistant`, RunE: func(cmd *cobra.Command, args []string) error { + // If the help flag is set, show the help message if cmd.Flag("help").Changed { cmd.Help() return nil } + + // Load the config debug, _ := cmd.Flags().GetBool("debug") - err := config.Load(debug) + cwd, _ := cmd.Flags().GetString("cwd") + if cwd != "" { + err := os.Chdir(cwd) + if err != nil { + return fmt.Errorf("failed to change directory: %v", err) + } + } + if cwd == "" { + c, err := os.Getwd() + if err != nil { + return fmt.Errorf("failed to get current working directory: %v", err) + } + cwd = c + } + _, err := config.Load(cwd, debug) if err != nil { return err } - cfg := config.Get() - defaultLevel := slog.LevelInfo - if cfg.Debug { - defaultLevel = slog.LevelDebug - } - logger := slog.New(slog.NewTextHandler(logging.NewWriter(), &slog.HandlerOptions{ - Level: defaultLevel, - })) - slog.SetDefault(logger) err = assets.WriteAssets() if err != nil { - return err + logging.Error("Error writing assets: %v", err) } + // Connect DB, this will also run migrations conn, err := db.Connect() if err != nil { return err } - ctx := context.Background() + + // Create main context for the application + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() app := app.New(ctx, conn) - logging.Info("Starting termai...") + + // Set up the TUI zone.NewGlobal() - tui := tea.NewProgram( + program := tea.NewProgram( tui.New(app), tea.WithAltScreen(), tea.WithMouseCellMotion(), ) - logging.Info("Setting up subscriptions...") - ch, unsub := setupSubscriptions(app) - defer unsub() + // Initialize MCP tools in the background + initMCPTools(ctx, app) + + // Setup the subscriptions, this will send services events to the TUI + ch, cancelSubs := setupSubscriptions(app) + + // Create a context for the TUI message handler + tuiCtx, tuiCancel := context.WithCancel(ctx) + var tuiWg sync.WaitGroup + tuiWg.Add(1) + + // Set up message handling for the TUI go func() { - // Set this up once - agent.GetMcpTools(ctx, app.Permissions) - for msg := range ch { - tui.Send(msg) + defer tuiWg.Done() + defer func() { + if r := recover(); r != nil { + logging.Error("Panic in TUI message handling: %v", r) + attemptTUIRecovery(program) + } + }() + + for { + select { + case <-tuiCtx.Done(): + logging.Info("TUI message handler shutting down") + return + case msg, ok := <-ch: + if !ok { + logging.Info("TUI message channel closed") + return + } + program.Send(msg) + } } }() - if _, err := tui.Run(); err != nil { - return err + + // Cleanup function for when the program exits + cleanup := func() { + // Shutdown the app + app.Shutdown() + + // Cancel subscriptions first + cancelSubs() + + // Then cancel TUI message handler + tuiCancel() + + // Wait for TUI message handler to finish + tuiWg.Wait() + + logging.Info("All goroutines cleaned up") + } + + // Run the TUI + result, err := program.Run() + cleanup() + + if err != nil { + logging.Error("TUI error: %v", err) + return fmt.Errorf("TUI error: %v", err) } + + logging.Info("TUI exited with result: %v", result) return nil }, } -func setupSubscriptions(app *app.App) (chan tea.Msg, func()) { - ch := make(chan tea.Msg) - wg := sync.WaitGroup{} - ctx, cancel := context.WithCancel(app.Context) - { - sub := logging.Subscribe(ctx) - wg.Add(1) - go func() { - for ev := range sub { - ch <- ev +// attemptTUIRecovery tries to recover the TUI after a panic +func attemptTUIRecovery(program *tea.Program) { + logging.Info("Attempting to recover TUI after panic") + + // We could try to restart the TUI or gracefully exit + // For now, we'll just quit the program to avoid further issues + program.Quit() +} + +func initMCPTools(ctx context.Context, app *app.App) { + go func() { + defer func() { + if r := recover(); r != nil { + logging.Error("Panic in MCP goroutine: %v", r) } - wg.Done() }() - } - { - sub := app.Sessions.Subscribe(ctx) - wg.Add(1) - go func() { - for ev := range sub { - ch <- ev + + // Create a context with timeout for the initial MCP tools fetch + ctxWithTimeout, cancel := context.WithTimeout(ctx, 30*time.Second) + defer cancel() + + // Set this up once with proper error handling + agent.GetMcpTools(ctxWithTimeout, app.Permissions) + logging.Info("MCP message handling goroutine exiting") + }() +} + +func setupSubscriber[T any]( + ctx context.Context, + wg *sync.WaitGroup, + name string, + subscriber func(context.Context) <-chan pubsub.Event[T], + outputCh chan<- tea.Msg, +) { + wg.Add(1) + go func() { + defer wg.Done() + defer func() { + if r := recover(); r != nil { + logging.Error("Panic in %s subscription goroutine: %v", name, r) } - wg.Done() }() - } - { - sub := app.Messages.Subscribe(ctx) - wg.Add(1) - go func() { - for ev := range sub { - ch <- ev + + for { + select { + case event, ok := <-subscriber(ctx): + if !ok { + logging.Info("%s subscription channel closed", name) + return + } + + // Convert generic event to tea.Msg if needed + var msg tea.Msg = event + + // Non-blocking send with timeout to prevent deadlocks + select { + case outputCh <- msg: + case <-time.After(500 * time.Millisecond): + logging.Warn("%s message dropped due to slow consumer", name) + case <-ctx.Done(): + logging.Info("%s subscription cancelled", name) + return + } + case <-ctx.Done(): + logging.Info("%s subscription cancelled", name) + return } - wg.Done() - }() - } - { - sub := app.Permissions.Subscribe(ctx) - wg.Add(1) + } + }() +} + +func setupSubscriptions(app *app.App) (chan tea.Msg, func()) { + ch := make(chan tea.Msg, 100) + // Add a buffer to prevent blocking + wg := sync.WaitGroup{} + ctx, cancel := context.WithCancel(context.Background()) + // Setup each subscription using the helper + setupSubscriber(ctx, &wg, "logging", logging.Subscribe, ch) + setupSubscriber(ctx, &wg, "sessions", app.Sessions.Subscribe, ch) + setupSubscriber(ctx, &wg, "messages", app.Messages.Subscribe, ch) + setupSubscriber(ctx, &wg, "permissions", app.Permissions.Subscribe, ch) + + // Return channel and a cleanup function + cleanupFunc := func() { + logging.Info("Cancelling all subscriptions") + cancel() // Signal all goroutines to stop + + // Wait with a timeout for all goroutines to complete + waitCh := make(chan struct{}) go func() { - for ev := range sub { - ch <- ev - } - wg.Done() + wg.Wait() + close(waitCh) }() + + select { + case <-waitCh: + logging.Info("All subscription goroutines completed successfully") + case <-time.After(5 * time.Second): + logging.Warn("Timed out waiting for some subscription goroutines to complete") + } + + close(ch) // Safe to close after all writers are done or timed out } - return ch, func() { - cancel() - wg.Wait() - close(ch) - } + return ch, cleanupFunc } func Execute() { @@ -139,5 +259,6 @@ func Execute() { func init() { rootCmd.Flags().BoolP("help", "h", false, "Help") - rootCmd.Flags().BoolP("debug", "d", false, "Help") + rootCmd.Flags().BoolP("debug", "d", false, "Debug") + rootCmd.Flags().StringP("cwd", "c", "", "Current working directory") } diff --git a/internal/app/app.go b/internal/app/app.go new file mode 100644 index 0000000000000000000000000000000000000000..fa4a6ee9008b051f51a9f7b73195592ab93531bf --- /dev/null +++ b/internal/app/app.go @@ -0,0 +1,76 @@ +package app + +import ( + "context" + "database/sql" + "maps" + "sync" + "time" + + "github.com/kujtimiihoxha/termai/internal/db" + "github.com/kujtimiihoxha/termai/internal/history" + "github.com/kujtimiihoxha/termai/internal/logging" + "github.com/kujtimiihoxha/termai/internal/lsp" + "github.com/kujtimiihoxha/termai/internal/message" + "github.com/kujtimiihoxha/termai/internal/permission" + "github.com/kujtimiihoxha/termai/internal/session" +) + +type App struct { + Sessions session.Service + Messages message.Service + Files history.Service + Permissions permission.Service + + LSPClients map[string]*lsp.Client + + clientsMutex sync.RWMutex + + watcherCancelFuncs []context.CancelFunc + cancelFuncsMutex sync.Mutex + watcherWG sync.WaitGroup +} + +func New(ctx context.Context, conn *sql.DB) *App { + q := db.New(conn) + sessions := session.NewService(q) + messages := message.NewService(q) + files := history.NewService(q) + + app := &App{ + Sessions: sessions, + Messages: messages, + Files: files, + Permissions: permission.NewPermissionService(), + LSPClients: make(map[string]*lsp.Client), + } + + app.initLSPClients(ctx) + + return app +} + +// Shutdown performs a clean shutdown of the application +func (app *App) Shutdown() { + // Cancel all watcher goroutines + app.cancelFuncsMutex.Lock() + for _, cancel := range app.watcherCancelFuncs { + cancel() + } + app.cancelFuncsMutex.Unlock() + app.watcherWG.Wait() + + // Perform additional cleanup for LSP clients + app.clientsMutex.RLock() + clients := make(map[string]*lsp.Client, len(app.LSPClients)) + maps.Copy(clients, app.LSPClients) + app.clientsMutex.RUnlock() + + for name, client := range clients { + shutdownCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + if err := client.Shutdown(shutdownCtx); err != nil { + logging.Error("Failed to shutdown LSP client", "name", name, "error", err) + } + cancel() + } +} diff --git a/internal/app/lsp.go b/internal/app/lsp.go new file mode 100644 index 0000000000000000000000000000000000000000..4e0568f071f71031f898361748c1516b23d756fe --- /dev/null +++ b/internal/app/lsp.go @@ -0,0 +1,108 @@ +package app + +import ( + "context" + "time" + + "github.com/kujtimiihoxha/termai/internal/config" + "github.com/kujtimiihoxha/termai/internal/logging" + "github.com/kujtimiihoxha/termai/internal/lsp" + "github.com/kujtimiihoxha/termai/internal/lsp/watcher" +) + +func (app *App) initLSPClients(ctx context.Context) { + cfg := config.Get() + + // Initialize LSP clients + for name, clientConfig := range cfg.LSP { + app.createAndStartLSPClient(ctx, name, clientConfig.Command, clientConfig.Args...) + } +} + +// createAndStartLSPClient creates a new LSP client, initializes it, and starts its workspace watcher +func (app *App) createAndStartLSPClient(ctx context.Context, name string, command string, args ...string) { + // Create a specific context for initialization with a timeout + initCtx, initCancel := context.WithTimeout(context.Background(), 30*time.Second) + defer initCancel() + + // Create the LSP client + lspClient, err := lsp.NewClient(initCtx, command, args...) + if err != nil { + logging.Error("Failed to create LSP client for", name, err) + return + } + + // Initialize with the initialization context + _, err = lspClient.InitializeLSPClient(initCtx, config.WorkingDirectory()) + if err != nil { + logging.Error("Initialize failed", "name", name, "error", err) + // Clean up the client to prevent resource leaks + lspClient.Close() + return + } + + // Create a child context that can be canceled when the app is shutting down + watchCtx, cancelFunc := context.WithCancel(ctx) + workspaceWatcher := watcher.NewWorkspaceWatcher(lspClient) + + // Store the cancel function to be called during cleanup + app.cancelFuncsMutex.Lock() + app.watcherCancelFuncs = append(app.watcherCancelFuncs, cancelFunc) + app.cancelFuncsMutex.Unlock() + + // Add the watcher to a WaitGroup to track active goroutines + app.watcherWG.Add(1) + + // Add to map with mutex protection before starting goroutine + app.clientsMutex.Lock() + app.LSPClients[name] = lspClient + app.clientsMutex.Unlock() + + go app.runWorkspaceWatcher(watchCtx, name, workspaceWatcher) +} + +// runWorkspaceWatcher executes the workspace watcher for an LSP client +func (app *App) runWorkspaceWatcher(ctx context.Context, name string, workspaceWatcher *watcher.WorkspaceWatcher) { + defer app.watcherWG.Done() + defer func() { + if r := recover(); r != nil { + logging.Error("LSP client crashed", "client", name, "panic", r) + + // Try to restart the client + app.restartLSPClient(ctx, name) + } + }() + + workspaceWatcher.WatchWorkspace(ctx, config.WorkingDirectory()) + logging.Info("Workspace watcher stopped", "client", name) +} + +// restartLSPClient attempts to restart a crashed or failed LSP client +func (app *App) restartLSPClient(ctx context.Context, name string) { + // Get the original configuration + cfg := config.Get() + clientConfig, exists := cfg.LSP[name] + if !exists { + logging.Error("Cannot restart client, configuration not found", "client", name) + return + } + + // Clean up the old client if it exists + app.clientsMutex.Lock() + oldClient, exists := app.LSPClients[name] + if exists { + delete(app.LSPClients, name) // Remove from map before potentially slow shutdown + } + app.clientsMutex.Unlock() + + if exists && oldClient != nil { + // Try to shut it down gracefully, but don't block on errors + shutdownCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + _ = oldClient.Shutdown(shutdownCtx) + cancel() + } + + // Create a new client using the shared function + app.createAndStartLSPClient(ctx, name, clientConfig.Command, clientConfig.Args...) + logging.Info("Successfully restarted LSP client", "client", name) +} diff --git a/internal/app/services.go b/internal/app/services.go deleted file mode 100644 index 6ecdef03c2f7d7766431aa80e4e9eeaf9f154a0f..0000000000000000000000000000000000000000 --- a/internal/app/services.go +++ /dev/null @@ -1,64 +0,0 @@ -package app - -import ( - "context" - "database/sql" - - "github.com/kujtimiihoxha/termai/internal/config" - "github.com/kujtimiihoxha/termai/internal/db" - "github.com/kujtimiihoxha/termai/internal/history" - "github.com/kujtimiihoxha/termai/internal/logging" - "github.com/kujtimiihoxha/termai/internal/lsp" - "github.com/kujtimiihoxha/termai/internal/lsp/watcher" - "github.com/kujtimiihoxha/termai/internal/message" - "github.com/kujtimiihoxha/termai/internal/permission" - "github.com/kujtimiihoxha/termai/internal/session" -) - -type App struct { - Context context.Context - - Sessions session.Service - Messages message.Service - Files history.Service - Permissions permission.Service - - LSPClients map[string]*lsp.Client -} - -func New(ctx context.Context, conn *sql.DB) *App { - cfg := config.Get() - logging.Info("Debug mode enabled") - - q := db.New(conn) - sessions := session.NewService(ctx, q) - messages := message.NewService(ctx, q) - files := history.NewService(ctx, q) - - app := &App{ - Context: ctx, - Sessions: sessions, - Messages: messages, - Files: files, - Permissions: permission.NewPermissionService(), - LSPClients: make(map[string]*lsp.Client), - } - - for name, client := range cfg.LSP { - lspClient, err := lsp.NewClient(ctx, client.Command, client.Args...) - workspaceWatcher := watcher.NewWorkspaceWatcher(lspClient) - if err != nil { - logging.Error("Failed to create LSP client for", name, err) - continue - } - - _, err = lspClient.InitializeLSPClient(ctx, config.WorkingDirectory()) - if err != nil { - logging.Error("Initialize failed", "error", err) - continue - } - go workspaceWatcher.WatchWorkspace(ctx, config.WorkingDirectory()) - app.LSPClients[name] = lspClient - } - return app -} diff --git a/internal/config/config.go b/internal/config/config.go index 6f757b3f48957185c9b256ea7b6dc1b07986a62d..1f3091ff3ea3df6b7c569fe013c7c15933b3c7e4 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -83,9 +83,9 @@ var cfg *Config // Load initializes the configuration from environment variables and config files. // If debug is true, debug mode is enabled and log level is set to debug. // It returns an error if configuration loading fails. -func Load(workingDir string, debug bool) error { +func Load(workingDir string, debug bool) (*Config, error) { if cfg != nil { - return nil + return cfg, nil } cfg = &Config{ @@ -101,7 +101,7 @@ func Load(workingDir string, debug bool) error { // Read global config if err := readConfig(viper.ReadInConfig()); err != nil { - return err + return cfg, err } // Load and merge local config @@ -109,7 +109,7 @@ func Load(workingDir string, debug bool) error { // Apply configuration to the struct if err := viper.Unmarshal(cfg); err != nil { - return err + return cfg, fmt.Errorf("failed to unmarshal config: %w", err) } applyDefaultValues() @@ -123,7 +123,7 @@ func Load(workingDir string, debug bool) error { Level: defaultLevel, })) slog.SetDefault(logger) - return nil + return cfg, nil } // configureViper sets up viper's configuration paths and environment variables. @@ -237,7 +237,7 @@ func readConfig(err error) error { return nil } - return err + return fmt.Errorf("failed to read config: %w", err) } // mergeLocalConfig loads and merges configuration from the local directory. @@ -264,14 +264,6 @@ func applyDefaultValues() { } } -// setWorkingDirectory stores the current working directory in the configuration. -func setWorkingDirectory() { - workdir, err := os.Getwd() - if err == nil { - viper.Set("wd", workdir) - } -} - // Get returns the current configuration. // It's safe to call this function multiple times. func Get() *Config { diff --git a/internal/history/file.go b/internal/history/file.go index 25953b27325bc1266d993850337448097901fd02..82017d4cf84c855158cf934fc3799a0e7c18762f 100644 --- a/internal/history/file.go +++ b/internal/history/file.go @@ -27,45 +27,43 @@ type File struct { type Service interface { pubsub.Suscriber[File] - Create(sessionID, path, content string) (File, error) - CreateVersion(sessionID, path, content string) (File, error) - Get(id string) (File, error) - GetByPathAndSession(path, sessionID string) (File, error) - ListBySession(sessionID string) ([]File, error) - ListLatestSessionFiles(sessionID string) ([]File, error) - Update(file File) (File, error) - Delete(id string) error - DeleteSessionFiles(sessionID string) error + Create(ctx context.Context, sessionID, path, content string) (File, error) + CreateVersion(ctx context.Context, sessionID, path, content string) (File, error) + Get(ctx context.Context, id string) (File, error) + GetByPathAndSession(ctx context.Context, path, sessionID string) (File, error) + ListBySession(ctx context.Context, sessionID string) ([]File, error) + ListLatestSessionFiles(ctx context.Context, sessionID string) ([]File, error) + Update(ctx context.Context, file File) (File, error) + Delete(ctx context.Context, id string) error + DeleteSessionFiles(ctx context.Context, sessionID string) error } type service struct { *pubsub.Broker[File] - q db.Querier - ctx context.Context + q db.Querier } -func NewService(ctx context.Context, q db.Querier) Service { +func NewService(q db.Querier) Service { return &service{ Broker: pubsub.NewBroker[File](), q: q, - ctx: ctx, } } -func (s *service) Create(sessionID, path, content string) (File, error) { - return s.createWithVersion(sessionID, path, content, InitialVersion) +func (s *service) Create(ctx context.Context, sessionID, path, content string) (File, error) { + return s.createWithVersion(ctx, sessionID, path, content, InitialVersion) } -func (s *service) CreateVersion(sessionID, path, content string) (File, error) { +func (s *service) CreateVersion(ctx context.Context, sessionID, path, content string) (File, error) { // Get the latest version for this path - files, err := s.q.ListFilesByPath(s.ctx, path) + files, err := s.q.ListFilesByPath(ctx, path) if err != nil { return File{}, err } if len(files) == 0 { // No previous versions, create initial - return s.Create(sessionID, path, content) + return s.Create(ctx, sessionID, path, content) } // Get the latest version @@ -89,11 +87,11 @@ func (s *service) CreateVersion(sessionID, path, content string) (File, error) { nextVersion = fmt.Sprintf("v%d", latestFile.CreatedAt) } - return s.createWithVersion(sessionID, path, content, nextVersion) + return s.createWithVersion(ctx, sessionID, path, content, nextVersion) } -func (s *service) createWithVersion(sessionID, path, content, version string) (File, error) { - dbFile, err := s.q.CreateFile(s.ctx, db.CreateFileParams{ +func (s *service) createWithVersion(ctx context.Context, sessionID, path, content, version string) (File, error) { + dbFile, err := s.q.CreateFile(ctx, db.CreateFileParams{ ID: uuid.New().String(), SessionID: sessionID, Path: path, @@ -108,16 +106,16 @@ func (s *service) createWithVersion(sessionID, path, content, version string) (F return file, nil } -func (s *service) Get(id string) (File, error) { - dbFile, err := s.q.GetFile(s.ctx, id) +func (s *service) Get(ctx context.Context, id string) (File, error) { + dbFile, err := s.q.GetFile(ctx, id) if err != nil { return File{}, err } return s.fromDBItem(dbFile), nil } -func (s *service) GetByPathAndSession(path, sessionID string) (File, error) { - dbFile, err := s.q.GetFileByPathAndSession(s.ctx, db.GetFileByPathAndSessionParams{ +func (s *service) GetByPathAndSession(ctx context.Context, path, sessionID string) (File, error) { + dbFile, err := s.q.GetFileByPathAndSession(ctx, db.GetFileByPathAndSessionParams{ Path: path, SessionID: sessionID, }) @@ -127,8 +125,8 @@ func (s *service) GetByPathAndSession(path, sessionID string) (File, error) { return s.fromDBItem(dbFile), nil } -func (s *service) ListBySession(sessionID string) ([]File, error) { - dbFiles, err := s.q.ListFilesBySession(s.ctx, sessionID) +func (s *service) ListBySession(ctx context.Context, sessionID string) ([]File, error) { + dbFiles, err := s.q.ListFilesBySession(ctx, sessionID) if err != nil { return nil, err } @@ -139,8 +137,8 @@ func (s *service) ListBySession(sessionID string) ([]File, error) { return files, nil } -func (s *service) ListLatestSessionFiles(sessionID string) ([]File, error) { - dbFiles, err := s.q.ListLatestSessionFiles(s.ctx, sessionID) +func (s *service) ListLatestSessionFiles(ctx context.Context, sessionID string) ([]File, error) { + dbFiles, err := s.q.ListLatestSessionFiles(ctx, sessionID) if err != nil { return nil, err } @@ -151,8 +149,8 @@ func (s *service) ListLatestSessionFiles(sessionID string) ([]File, error) { return files, nil } -func (s *service) Update(file File) (File, error) { - dbFile, err := s.q.UpdateFile(s.ctx, db.UpdateFileParams{ +func (s *service) Update(ctx context.Context, file File) (File, error) { + dbFile, err := s.q.UpdateFile(ctx, db.UpdateFileParams{ ID: file.ID, Content: file.Content, Version: file.Version, @@ -165,12 +163,12 @@ func (s *service) Update(file File) (File, error) { return updatedFile, nil } -func (s *service) Delete(id string) error { - file, err := s.Get(id) +func (s *service) Delete(ctx context.Context, id string) error { + file, err := s.Get(ctx, id) if err != nil { return err } - err = s.q.DeleteFile(s.ctx, id) + err = s.q.DeleteFile(ctx, id) if err != nil { return err } @@ -178,13 +176,13 @@ func (s *service) Delete(id string) error { return nil } -func (s *service) DeleteSessionFiles(sessionID string) error { - files, err := s.ListBySession(sessionID) +func (s *service) DeleteSessionFiles(ctx context.Context, sessionID string) error { + files, err := s.ListBySession(ctx, sessionID) if err != nil { return err } for _, file := range files { - err = s.Delete(file.ID) + err = s.Delete(ctx, file.ID) if err != nil { return err } @@ -203,4 +201,3 @@ func (s *service) fromDBItem(item db.File) File { UpdatedAt: item.UpdatedAt, } } - diff --git a/internal/llm/agent/agent-tool.go b/internal/llm/agent/agent-tool.go index deb6aed608625816b6f9715f33a823a99c337d47..91c46da8b0bfe7b44e9c1ff65fd3eb30790ec9f7 100644 --- a/internal/llm/agent/agent-tool.go +++ b/internal/llm/agent/agent-tool.go @@ -51,7 +51,7 @@ func (b *agentTool) Run(ctx context.Context, call tools.ToolCall) (tools.ToolRes return tools.NewTextErrorResponse(fmt.Sprintf("error creating agent: %s", err)), nil } - session, err := b.app.Sessions.CreateTaskSession(call.ID, b.parentSessionID, "New Agent Session") + session, err := b.app.Sessions.CreateTaskSession(ctx, call.ID, b.parentSessionID, "New Agent Session") if err != nil { return tools.NewTextErrorResponse(fmt.Sprintf("error creating session: %s", err)), nil } @@ -61,7 +61,7 @@ func (b *agentTool) Run(ctx context.Context, call tools.ToolCall) (tools.ToolRes return tools.NewTextErrorResponse(fmt.Sprintf("error generating agent: %s", err)), nil } - messages, err := b.app.Messages.List(session.ID) + messages, err := b.app.Messages.List(ctx, session.ID) if err != nil { return tools.NewTextErrorResponse(fmt.Sprintf("error listing messages: %s", err)), nil } @@ -74,11 +74,11 @@ func (b *agentTool) Run(ctx context.Context, call tools.ToolCall) (tools.ToolRes return tools.NewTextErrorResponse("no assistant message found"), nil } - updatedSession, err := b.app.Sessions.Get(session.ID) + updatedSession, err := b.app.Sessions.Get(ctx, session.ID) if err != nil { return tools.NewTextErrorResponse(fmt.Sprintf("error: %s", err)), nil } - parentSession, err := b.app.Sessions.Get(b.parentSessionID) + parentSession, err := b.app.Sessions.Get(ctx, b.parentSessionID) if err != nil { return tools.NewTextErrorResponse(fmt.Sprintf("error: %s", err)), nil } @@ -87,7 +87,7 @@ func (b *agentTool) Run(ctx context.Context, call tools.ToolCall) (tools.ToolRes parentSession.PromptTokens += updatedSession.PromptTokens parentSession.CompletionTokens += updatedSession.CompletionTokens - _, err = b.app.Sessions.Save(parentSession) + _, err = b.app.Sessions.Save(ctx, parentSession) if err != nil { return tools.NewTextErrorResponse(fmt.Sprintf("error: %s", err)), nil } diff --git a/internal/llm/agent/agent.go b/internal/llm/agent/agent.go index 89de627f7ebdaa42b43cbdfb8b610474768ab8a9..b7c736e6c38f5cfa6f4029fac414fe65b494b9d3 100644 --- a/internal/llm/agent/agent.go +++ b/internal/llm/agent/agent.go @@ -48,7 +48,7 @@ func (c *agent) handleTitleGeneration(ctx context.Context, sessionID, content st return } - session, err := c.Sessions.Get(sessionID) + session, err := c.Sessions.Get(ctx, sessionID) if err != nil { return } @@ -56,12 +56,12 @@ func (c *agent) handleTitleGeneration(ctx context.Context, sessionID, content st session.Title = response.Content session.Title = strings.TrimSpace(session.Title) session.Title = strings.ReplaceAll(session.Title, "\n", " ") - c.Sessions.Save(session) + c.Sessions.Save(ctx, session) } } -func (c *agent) TrackUsage(sessionID string, model models.Model, usage provider.TokenUsage) error { - session, err := c.Sessions.Get(sessionID) +func (c *agent) TrackUsage(ctx context.Context, sessionID string, model models.Model, usage provider.TokenUsage) error { + session, err := c.Sessions.Get(ctx, sessionID) if err != nil { return err } @@ -75,11 +75,12 @@ func (c *agent) TrackUsage(sessionID string, model models.Model, usage provider. session.CompletionTokens += usage.OutputTokens session.PromptTokens += usage.InputTokens - _, err = c.Sessions.Save(session) + _, err = c.Sessions.Save(ctx, session) return err } func (c *agent) processEvent( + ctx context.Context, sessionID string, assistantMsg *message.Message, event provider.ProviderEvent, @@ -87,10 +88,10 @@ func (c *agent) processEvent( switch event.Type { case provider.EventThinkingDelta: assistantMsg.AppendReasoningContent(event.Content) - return c.Messages.Update(*assistantMsg) + return c.Messages.Update(ctx, *assistantMsg) case provider.EventContentDelta: assistantMsg.AppendContent(event.Content) - return c.Messages.Update(*assistantMsg) + return c.Messages.Update(ctx, *assistantMsg) case provider.EventError: if errors.Is(event.Error, context.Canceled) { return nil @@ -105,11 +106,11 @@ func (c *agent) processEvent( case provider.EventComplete: assistantMsg.SetToolCalls(event.Response.ToolCalls) assistantMsg.AddFinish(event.Response.FinishReason) - err := c.Messages.Update(*assistantMsg) + err := c.Messages.Update(ctx, *assistantMsg) if err != nil { return err } - return c.TrackUsage(sessionID, c.model, event.Response.Usage) + return c.TrackUsage(ctx, sessionID, c.model, event.Response.Usage) } return nil @@ -237,7 +238,7 @@ func (c *agent) handleToolExecution( for _, toolResult := range toolResults { parts = append(parts, toolResult) } - msg, err := c.Messages.Create(assistantMsg.SessionID, message.CreateMessageParams{ + msg, err := c.Messages.Create(ctx, assistantMsg.SessionID, message.CreateMessageParams{ Role: message.Tool, Parts: parts, }) @@ -247,7 +248,7 @@ func (c *agent) handleToolExecution( func (c *agent) generate(ctx context.Context, sessionID string, content string) error { ctx = context.WithValue(ctx, tools.SessionIDContextKey, sessionID) - messages, err := c.Messages.List(sessionID) + messages, err := c.Messages.List(ctx, sessionID) if err != nil { return err } @@ -256,7 +257,7 @@ func (c *agent) generate(ctx context.Context, sessionID string, content string) go c.handleTitleGeneration(ctx, sessionID, content) } - userMsg, err := c.Messages.Create(sessionID, message.CreateMessageParams{ + userMsg, err := c.Messages.Create(ctx, sessionID, message.CreateMessageParams{ Role: message.User, Parts: []message.ContentPart{ message.TextContent{ @@ -272,7 +273,7 @@ func (c *agent) generate(ctx context.Context, sessionID string, content string) for { select { case <-ctx.Done(): - assistantMsg, err := c.Messages.Create(sessionID, message.CreateMessageParams{ + assistantMsg, err := c.Messages.Create(ctx, sessionID, message.CreateMessageParams{ Role: message.Assistant, Parts: []message.ContentPart{}, }) @@ -280,7 +281,7 @@ func (c *agent) generate(ctx context.Context, sessionID string, content string) return err } assistantMsg.AddFinish("canceled") - c.Messages.Update(assistantMsg) + c.Messages.Update(ctx, assistantMsg) return context.Canceled default: // Continue processing @@ -289,7 +290,7 @@ func (c *agent) generate(ctx context.Context, sessionID string, content string) eventChan, err := c.agent.StreamResponse(ctx, messages, c.tools) if err != nil { if errors.Is(err, context.Canceled) { - assistantMsg, err := c.Messages.Create(sessionID, message.CreateMessageParams{ + assistantMsg, err := c.Messages.Create(ctx, sessionID, message.CreateMessageParams{ Role: message.Assistant, Parts: []message.ContentPart{}, }) @@ -297,13 +298,13 @@ func (c *agent) generate(ctx context.Context, sessionID string, content string) return err } assistantMsg.AddFinish("canceled") - c.Messages.Update(assistantMsg) + c.Messages.Update(ctx, assistantMsg) return context.Canceled } return err } - assistantMsg, err := c.Messages.Create(sessionID, message.CreateMessageParams{ + assistantMsg, err := c.Messages.Create(ctx, sessionID, message.CreateMessageParams{ Role: message.Assistant, Parts: []message.ContentPart{}, Model: c.model.ID, @@ -314,22 +315,22 @@ func (c *agent) generate(ctx context.Context, sessionID string, content string) ctx = context.WithValue(ctx, tools.MessageIDContextKey, assistantMsg.ID) for event := range eventChan { - err = c.processEvent(sessionID, &assistantMsg, event) + err = c.processEvent(ctx, sessionID, &assistantMsg, event) if err != nil { if errors.Is(err, context.Canceled) { assistantMsg.AddFinish("canceled") - c.Messages.Update(assistantMsg) + c.Messages.Update(ctx, assistantMsg) return context.Canceled } assistantMsg.AddFinish("error:" + err.Error()) - c.Messages.Update(assistantMsg) + c.Messages.Update(ctx, assistantMsg) return err } select { case <-ctx.Done(): assistantMsg.AddFinish("canceled") - c.Messages.Update(assistantMsg) + c.Messages.Update(ctx, assistantMsg) return context.Canceled default: } @@ -339,7 +340,7 @@ func (c *agent) generate(ctx context.Context, sessionID string, content string) select { case <-ctx.Done(): assistantMsg.AddFinish("canceled") - c.Messages.Update(assistantMsg) + c.Messages.Update(ctx, assistantMsg) return context.Canceled default: // Continue processing @@ -349,13 +350,13 @@ func (c *agent) generate(ctx context.Context, sessionID string, content string) if err != nil { if errors.Is(err, context.Canceled) { assistantMsg.AddFinish("canceled") - c.Messages.Update(assistantMsg) + c.Messages.Update(ctx, assistantMsg) return context.Canceled } return err } - c.Messages.Update(assistantMsg) + c.Messages.Update(ctx, assistantMsg) if len(assistantMsg.ToolCalls()) == 0 { break @@ -370,7 +371,7 @@ func (c *agent) generate(ctx context.Context, sessionID string, content string) select { case <-ctx.Done(): assistantMsg.AddFinish("canceled") - c.Messages.Update(assistantMsg) + c.Messages.Update(ctx, assistantMsg) return context.Canceled default: // Continue processing @@ -383,7 +384,7 @@ func getAgentProviders(ctx context.Context, model models.Model) (provider.Provid maxTokens := config.Get().Model.CoderMaxTokens providerConfig, ok := config.Get().Providers[model.Provider] - if !ok || !providerConfig.Enabled { + if !ok || providerConfig.Disabled { return nil, nil, errors.New("provider is not enabled") } var agentProvider provider.Provider diff --git a/internal/llm/agent/coder.go b/internal/llm/agent/coder.go index 5deff05a8a85e5e8b1d38ece4715695809e19ec6..f8e1c40a084ddf2be3d0684363ddbdc3d91114c6 100644 --- a/internal/llm/agent/coder.go +++ b/internal/llm/agent/coder.go @@ -40,12 +40,13 @@ func NewCoderAgent(app *app.App) (Agent, error) { return nil, errors.New("model not supported") } - agentProvider, titleGenerator, err := getAgentProviders(app.Context, model) + ctx := context.Background() + agentProvider, titleGenerator, err := getAgentProviders(ctx, model) if err != nil { return nil, err } - otherTools := GetMcpTools(app.Context, app.Permissions) + otherTools := GetMcpTools(ctx, app.Permissions) if len(app.LSPClients) > 0 { otherTools = append(otherTools, tools.NewDiagnosticsTool(app.LSPClients)) } diff --git a/internal/llm/agent/task.go b/internal/llm/agent/task.go index 034e9346003c747884afd12098682397cf80dc46..c196cb107b1926068a8d32b5703809a155303e53 100644 --- a/internal/llm/agent/task.go +++ b/internal/llm/agent/task.go @@ -24,7 +24,8 @@ func NewTaskAgent(app *app.App) (Agent, error) { return nil, errors.New("model not supported") } - agentProvider, titleGenerator, err := getAgentProviders(app.Context, model) + ctx := context.Background() + agentProvider, titleGenerator, err := getAgentProviders(ctx, model) if err != nil { return nil, err } diff --git a/internal/message/message.go b/internal/message/message.go index 06dae13a57a8ae1ca7ede1e5fd22be6bea5ee669..2871780a79f91cca018c6ad1a398c023123310c6 100644 --- a/internal/message/message.go +++ b/internal/message/message.go @@ -20,34 +20,32 @@ type CreateMessageParams struct { type Service interface { pubsub.Suscriber[Message] - Create(sessionID string, params CreateMessageParams) (Message, error) - Update(message Message) error - Get(id string) (Message, error) - List(sessionID string) ([]Message, error) - Delete(id string) error - DeleteSessionMessages(sessionID string) error + Create(ctx context.Context, sessionID string, params CreateMessageParams) (Message, error) + Update(ctx context.Context, message Message) error + Get(ctx context.Context, id string) (Message, error) + List(ctx context.Context, sessionID string) ([]Message, error) + Delete(ctx context.Context, id string) error + DeleteSessionMessages(ctx context.Context, sessionID string) error } type service struct { *pubsub.Broker[Message] - q db.Querier - ctx context.Context + q db.Querier } -func NewService(ctx context.Context, q db.Querier) Service { +func NewService(q db.Querier) Service { return &service{ Broker: pubsub.NewBroker[Message](), q: q, - ctx: ctx, } } -func (s *service) Delete(id string) error { - message, err := s.Get(id) +func (s *service) Delete(ctx context.Context, id string) error { + message, err := s.Get(ctx, id) if err != nil { return err } - err = s.q.DeleteMessage(s.ctx, message.ID) + err = s.q.DeleteMessage(ctx, message.ID) if err != nil { return err } @@ -55,7 +53,7 @@ func (s *service) Delete(id string) error { return nil } -func (s *service) Create(sessionID string, params CreateMessageParams) (Message, error) { +func (s *service) Create(ctx context.Context, sessionID string, params CreateMessageParams) (Message, error) { if params.Role != Assistant { params.Parts = append(params.Parts, Finish{ Reason: "stop", @@ -66,7 +64,7 @@ func (s *service) Create(sessionID string, params CreateMessageParams) (Message, return Message{}, err } - dbMessage, err := s.q.CreateMessage(s.ctx, db.CreateMessageParams{ + dbMessage, err := s.q.CreateMessage(ctx, db.CreateMessageParams{ ID: uuid.New().String(), SessionID: sessionID, Role: string(params.Role), @@ -84,14 +82,14 @@ func (s *service) Create(sessionID string, params CreateMessageParams) (Message, return message, nil } -func (s *service) DeleteSessionMessages(sessionID string) error { - messages, err := s.List(sessionID) +func (s *service) DeleteSessionMessages(ctx context.Context, sessionID string) error { + messages, err := s.List(ctx, sessionID) if err != nil { return err } for _, message := range messages { if message.SessionID == sessionID { - err = s.Delete(message.ID) + err = s.Delete(ctx, message.ID) if err != nil { return err } @@ -100,7 +98,7 @@ func (s *service) DeleteSessionMessages(sessionID string) error { return nil } -func (s *service) Update(message Message) error { +func (s *service) Update(ctx context.Context, message Message) error { parts, err := marshallParts(message.Parts) if err != nil { return err @@ -110,7 +108,7 @@ func (s *service) Update(message Message) error { finishedAt.Int64 = f.Time finishedAt.Valid = true } - err = s.q.UpdateMessage(s.ctx, db.UpdateMessageParams{ + err = s.q.UpdateMessage(ctx, db.UpdateMessageParams{ ID: message.ID, Parts: string(parts), FinishedAt: finishedAt, @@ -122,16 +120,16 @@ func (s *service) Update(message Message) error { return nil } -func (s *service) Get(id string) (Message, error) { - dbMessage, err := s.q.GetMessage(s.ctx, id) +func (s *service) Get(ctx context.Context, id string) (Message, error) { + dbMessage, err := s.q.GetMessage(ctx, id) if err != nil { return Message{}, err } return s.fromDBItem(dbMessage) } -func (s *service) List(sessionID string) ([]Message, error) { - dbMessages, err := s.q.ListMessagesBySession(s.ctx, sessionID) +func (s *service) List(ctx context.Context, sessionID string) ([]Message, error) { + dbMessages, err := s.q.ListMessagesBySession(ctx, sessionID) if err != nil { return nil, err } diff --git a/internal/session/session.go b/internal/session/session.go index 13f420b7c09ba4f074489c3c0be9b1978fd3b630..9a16224c3b3c9cc84c50d0d4c829a03c8f7de5d8 100644 --- a/internal/session/session.go +++ b/internal/session/session.go @@ -23,22 +23,21 @@ type Session struct { type Service interface { pubsub.Suscriber[Session] - Create(title string) (Session, error) - CreateTaskSession(toolCallID, parentSessionID, title string) (Session, error) - Get(id string) (Session, error) - List() ([]Session, error) - Save(session Session) (Session, error) - Delete(id string) error + Create(ctx context.Context, title string) (Session, error) + CreateTaskSession(ctx context.Context, toolCallID, parentSessionID, title string) (Session, error) + Get(ctx context.Context, id string) (Session, error) + List(ctx context.Context) ([]Session, error) + Save(ctx context.Context, session Session) (Session, error) + Delete(ctx context.Context, id string) error } type service struct { *pubsub.Broker[Session] - q db.Querier - ctx context.Context + q db.Querier } -func (s *service) Create(title string) (Session, error) { - dbSession, err := s.q.CreateSession(s.ctx, db.CreateSessionParams{ +func (s *service) Create(ctx context.Context, title string) (Session, error) { + dbSession, err := s.q.CreateSession(ctx, db.CreateSessionParams{ ID: uuid.New().String(), Title: title, }) @@ -50,8 +49,8 @@ func (s *service) Create(title string) (Session, error) { return session, nil } -func (s *service) CreateTaskSession(toolCallID, parentSessionID, title string) (Session, error) { - dbSession, err := s.q.CreateSession(s.ctx, db.CreateSessionParams{ +func (s *service) CreateTaskSession(ctx context.Context, toolCallID, parentSessionID, title string) (Session, error) { + dbSession, err := s.q.CreateSession(ctx, db.CreateSessionParams{ ID: toolCallID, ParentSessionID: sql.NullString{String: parentSessionID, Valid: true}, Title: title, @@ -64,12 +63,12 @@ func (s *service) CreateTaskSession(toolCallID, parentSessionID, title string) ( return session, nil } -func (s *service) Delete(id string) error { - session, err := s.Get(id) +func (s *service) Delete(ctx context.Context, id string) error { + session, err := s.Get(ctx, id) if err != nil { return err } - err = s.q.DeleteSession(s.ctx, session.ID) + err = s.q.DeleteSession(ctx, session.ID) if err != nil { return err } @@ -77,16 +76,16 @@ func (s *service) Delete(id string) error { return nil } -func (s *service) Get(id string) (Session, error) { - dbSession, err := s.q.GetSessionByID(s.ctx, id) +func (s *service) Get(ctx context.Context, id string) (Session, error) { + dbSession, err := s.q.GetSessionByID(ctx, id) if err != nil { return Session{}, err } return s.fromDBItem(dbSession), nil } -func (s *service) Save(session Session) (Session, error) { - dbSession, err := s.q.UpdateSession(s.ctx, db.UpdateSessionParams{ +func (s *service) Save(ctx context.Context, session Session) (Session, error) { + dbSession, err := s.q.UpdateSession(ctx, db.UpdateSessionParams{ ID: session.ID, Title: session.Title, PromptTokens: session.PromptTokens, @@ -101,8 +100,8 @@ func (s *service) Save(session Session) (Session, error) { return session, nil } -func (s *service) List() ([]Session, error) { - dbSessions, err := s.q.ListSessions(s.ctx) +func (s *service) List(ctx context.Context) ([]Session, error) { + dbSessions, err := s.q.ListSessions(ctx) if err != nil { return nil, err } @@ -127,11 +126,10 @@ func (s service) fromDBItem(item db.Session) Session { } } -func NewService(ctx context.Context, q db.Querier) Service { +func NewService(q db.Querier) Service { broker := pubsub.NewBroker[Session]() return &service{ broker, q, - ctx, } } diff --git a/internal/tui/components/chat/messages.go b/internal/tui/components/chat/messages.go index b5a36139239b08f514a36d3a07d308f107a67520..dc21fca2916f1019dbafa313c6d010d096ca9b34 100644 --- a/internal/tui/components/chat/messages.go +++ b/internal/tui/components/chat/messages.go @@ -1,6 +1,7 @@ package chat import ( + "context" "encoding/json" "fmt" "math" @@ -324,7 +325,7 @@ func (m *messagesCmp) renderToolCall(toolCall message.ToolCall, isNested bool) s innerToolCalls := make([]string, 0) if toolCall.Name == agent.AgentToolName { - messages, _ := m.app.Messages.List(toolCall.ID) + messages, _ := m.app.Messages.List(context.Background(), toolCall.ID) toolCalls := make([]message.ToolCall, 0) for _, v := range messages { toolCalls = append(toolCalls, v.ToolCalls()...) @@ -554,7 +555,7 @@ func (m *messagesCmp) GetSize() (int, int) { func (m *messagesCmp) SetSession(session session.Session) tea.Cmd { m.session = session - messages, err := m.app.Messages.List(session.ID) + messages, err := m.app.Messages.List(context.Background(), session.ID) if err != nil { return util.ReportError(err) } diff --git a/internal/tui/components/repl/editor.go b/internal/tui/components/repl/editor.go index e9493129d9a266d703ba123351ff81b266fe4230..b1e39e65540cc6a5cfb43857dfee0ad7d54a57d0 100644 --- a/internal/tui/components/repl/editor.go +++ b/internal/tui/components/repl/editor.go @@ -160,7 +160,7 @@ func (m *editorCmp) Send() tea.Cmd { return util.ReportWarn("Assistant is still working on the previous message") } - messages, err := m.app.Messages.List(m.sessionID) + messages, err := m.app.Messages.List(context.Background(), m.sessionID) if err != nil { return util.ReportError(err) } @@ -177,7 +177,7 @@ func (m *editorCmp) Send() tea.Cmd { if len(content) == 0 { return util.ReportWarn("Message is empty") } - ctx, cancel := context.WithCancel(m.app.Context) + ctx, cancel := context.WithCancel(context.Background()) m.cancelMessage = cancel go func() { defer cancel() diff --git a/internal/tui/components/repl/messages.go b/internal/tui/components/repl/messages.go index 57a55c579b432edec387e5e8c8f05ec858dc2c5c..260be220e82fc6d03928db950f2b244cb0ae7e8b 100644 --- a/internal/tui/components/repl/messages.go +++ b/internal/tui/components/repl/messages.go @@ -1,6 +1,7 @@ package repl import ( + "context" "encoding/json" "fmt" "sort" @@ -77,8 +78,8 @@ func (m *messagesCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) { m.session = msg.Payload } case SelectedSessionMsg: - m.session, _ = m.app.Sessions.Get(msg.SessionID) - m.messages, _ = m.app.Messages.List(m.session.ID) + m.session, _ = m.app.Sessions.Get(context.Background(), msg.SessionID) + m.messages, _ = m.app.Messages.List(context.Background(), m.session.ID) m.renderView() m.viewport.GotoBottom() } @@ -259,7 +260,7 @@ func (m *messagesCmp) renderMessageWithToolCall(content string, tools []message. runningIndicator := runningStyle.Render(fmt.Sprintf("%s Running...", styles.SpinnerIcon)) allParts = append(allParts, leftPadding.Render(runningIndicator)) - taskSessionMessages, _ := m.app.Messages.List(toolCall.ID) + taskSessionMessages, _ := m.app.Messages.List(context.Background(), toolCall.ID) for _, msg := range taskSessionMessages { if msg.Role == message.Assistant { for _, toolCall := range msg.ToolCalls() { diff --git a/internal/tui/components/repl/sessions.go b/internal/tui/components/repl/sessions.go index 093337b188f3bac6a08a77ced937b6dc4500fb7b..c83c4036728138675bc08313454743245723b359 100644 --- a/internal/tui/components/repl/sessions.go +++ b/internal/tui/components/repl/sessions.go @@ -1,6 +1,7 @@ package repl import ( + "context" "fmt" "strings" @@ -57,12 +58,13 @@ var sessionKeyMapValue = sessionsKeyMap{ } func (i *sessionsCmp) Init() tea.Cmd { - existing, err := i.app.Sessions.List() + existing, err := i.app.Sessions.List(context.Background()) if err != nil { return util.ReportError(err) } if len(existing) == 0 || existing[0].MessageCount > 0 { newSession, err := i.app.Sessions.Create( + context.Background(), "New Session", ) if err != nil { diff --git a/internal/tui/page/chat.go b/internal/tui/page/chat.go index a7a51bb844640ee0bd0e819d7a78531ac81fdc94..9b9924909ab9fa638bba4a96ab685d740fd65569 100644 --- a/internal/tui/page/chat.go +++ b/internal/tui/page/chat.go @@ -1,6 +1,8 @@ package page import ( + "context" + "github.com/charmbracelet/bubbles/key" tea "github.com/charmbracelet/bubbletea" "github.com/kujtimiihoxha/termai/internal/app" @@ -36,7 +38,7 @@ func (p *chatPage) Init() tea.Cmd { p.layout.Init(), } - sessions, _ := p.app.Sessions.List() + sessions, _ := p.app.Sessions.List(context.Background()) if len(sessions) > 0 { p.session = sessions[0] cmd := p.setSidebar() @@ -92,7 +94,7 @@ func (p *chatPage) clearSidebar() { func (p *chatPage) sendMessage(text string) tea.Cmd { var cmds []tea.Cmd if p.session.ID == "" { - session, err := p.app.Sessions.Create("New Session") + session, err := p.app.Sessions.Create(context.Background(), "New Session") if err != nil { return util.ReportError(err) } @@ -110,7 +112,7 @@ func (p *chatPage) sendMessage(text string) tea.Cmd { return util.ReportError(err) } go func() { - a.Generate(p.app.Context, p.session.ID, text) + a.Generate(context.Background(), p.session.ID, text) }() return tea.Batch(cmds...) diff --git a/internal/tui/tui.go b/internal/tui/tui.go index db9ac9ff6ef5eb0277210a9bf5e26188ff79ee00..1b1a1ed50f97943f2388326583bab9798b55b9c1 100644 --- a/internal/tui/tui.go +++ b/internal/tui/tui.go @@ -1,6 +1,8 @@ package tui import ( + "context" + "github.com/charmbracelet/bubbles/key" tea "github.com/charmbracelet/bubbletea" "github.com/charmbracelet/lipgloss" @@ -184,7 +186,7 @@ func (a appModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) { } case key.Matches(msg, replKeyMap): if a.currentPage == page.ReplPage { - sessions, err := a.app.Sessions.List() + sessions, err := a.app.Sessions.List(context.Background()) if err != nil { return a, util.CmdHandler(util.ReportError(err)) } @@ -192,7 +194,7 @@ func (a appModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) { if lastSession.MessageCount == 0 { return a, util.CmdHandler(repl.SelectedSessionMsg{SessionID: lastSession.ID}) } - s, err := a.app.Sessions.Create("New Session") + s, err := a.app.Sessions.Create(context.Background(), "New Session") if err != nil { return a, util.CmdHandler(util.ReportError(err)) } From cdc5f209dccdc980714f2ca1aeb52133d6e93cce Mon Sep 17 00:00:00 2001 From: Kujtim Hoxha Date: Sun, 13 Apr 2025 14:37:05 +0200 Subject: [PATCH 09/41] cleanup diff, cleanup agent --- README.md | 14 +- cmd/diff/main.go | 102 ----- cmd/root.go | 12 +- go.mod | 2 +- internal/app/app.go | 20 +- internal/assets/diff/themes/dark.json | 73 ---- internal/assets/embed.go | 6 - internal/assets/write.go | 60 --- internal/git/diff.go | 35 +- internal/llm/agent/agent-tool.go | 34 +- internal/llm/agent/agent.go | 522 ++++++++++++++++--------- internal/llm/agent/coder.go | 83 ++-- internal/llm/agent/task.go | 7 +- internal/llm/provider/provider.go | 4 +- internal/llm/tools/edit.go | 7 +- internal/llm/tools/tools.go | 2 +- internal/llm/tools/write.go | 2 +- internal/tui/components/repl/editor.go | 8 +- internal/tui/page/chat.go | 15 +- 19 files changed, 456 insertions(+), 552 deletions(-) delete mode 100644 cmd/diff/main.go delete mode 100644 internal/assets/diff/themes/dark.json delete mode 100644 internal/assets/embed.go delete mode 100644 internal/assets/write.go diff --git a/README.md b/README.md index ebef72cad3500d910217239a015c9f2a780de034..23a1906a1566d418857f2b2bd2605ffd914413b5 100644 --- a/README.md +++ b/README.md @@ -53,6 +53,7 @@ termai -d ### Keyboard Shortcuts #### Global Shortcuts + - `?`: Toggle help panel - `Ctrl+C` or `q`: Quit application - `L`: View logs @@ -60,10 +61,12 @@ termai -d - `Esc`: Close current view/dialog or return to normal mode #### Session Management + - `N`: Create new session - `Enter` or `Space`: Select session (in sessions list) #### Editor Shortcuts (Vim-like) + - `i`: Enter insert mode - `Esc`: Enter normal mode - `v`: Enter visual mode @@ -72,6 +75,7 @@ termai -d - `Ctrl+S`: Send message (in insert mode) #### Navigation + - Arrow keys: Navigate through lists and content - Page Up/Down: Scroll through content @@ -112,16 +116,6 @@ go build -o termai ./termai ``` -### Important: Building the Diff Script - -Before building or running the application, you must first build the diff script by running: - -```bash -go run cmd/diff/main.go -``` - -This command generates the necessary JavaScript file (`index.mjs`) used by the diff functionality in the application. - ## Acknowledgments TermAI builds upon the work of several open source projects and developers: diff --git a/cmd/diff/main.go b/cmd/diff/main.go deleted file mode 100644 index da93e4660069912f4081495ab3d2eac8d93c0729..0000000000000000000000000000000000000000 --- a/cmd/diff/main.go +++ /dev/null @@ -1,102 +0,0 @@ -package main - -import ( - "fmt" - "io" - "os" - "os/exec" - "path/filepath" -) - -func main() { - // Create a temporary directory - tempDir, err := os.MkdirTemp("", "git-split-diffs") - if err != nil { - fmt.Printf("Error creating temp directory: %v\n", err) - os.Exit(1) - } - defer func() { - fmt.Printf("Cleaning up temporary directory: %s\n", tempDir) - os.RemoveAll(tempDir) - }() - fmt.Printf("Created temporary directory: %s\n", tempDir) - - // Clone the repository with minimum depth - fmt.Println("Cloning git-split-diffs repository with minimum depth...") - cmd := exec.Command("git", "clone", "--depth=1", "https://github.com/kujtimiihoxha/git-split-diffs", tempDir) - cmd.Stdout = os.Stdout - cmd.Stderr = os.Stderr - if err := cmd.Run(); err != nil { - fmt.Printf("Error cloning repository: %v\n", err) - os.Exit(1) - } - - // Run npm install - fmt.Println("Running npm install...") - cmdNpmInstall := exec.Command("npm", "install") - cmdNpmInstall.Dir = tempDir - cmdNpmInstall.Stdout = os.Stdout - cmdNpmInstall.Stderr = os.Stderr - if err := cmdNpmInstall.Run(); err != nil { - fmt.Printf("Error running npm install: %v\n", err) - os.Exit(1) - } - - // Run npm run build - fmt.Println("Running npm run build...") - cmdNpmBuild := exec.Command("npm", "run", "build") - cmdNpmBuild.Dir = tempDir - cmdNpmBuild.Stdout = os.Stdout - cmdNpmBuild.Stderr = os.Stderr - if err := cmdNpmBuild.Run(); err != nil { - fmt.Printf("Error running npm run build: %v\n", err) - os.Exit(1) - } - - destDir := filepath.Join(".", "internal", "assets", "diff") - destFile := filepath.Join(destDir, "index.mjs") - - // Make sure the destination directory exists - if err := os.MkdirAll(destDir, 0o755); err != nil { - fmt.Printf("Error creating destination directory: %v\n", err) - os.Exit(1) - } - - // Copy the file - srcFile := filepath.Join(tempDir, "build", "index.mjs") - fmt.Printf("Copying %s to %s\n", srcFile, destFile) - if err := copyFile(srcFile, destFile); err != nil { - fmt.Printf("Error copying file: %v\n", err) - os.Exit(1) - } - - fmt.Println("Successfully completed the process!") -} - -// copyFile copies a file from src to dst -func copyFile(src, dst string) error { - sourceFile, err := os.Open(src) - if err != nil { - return err - } - defer sourceFile.Close() - - destFile, err := os.Create(dst) - if err != nil { - return err - } - defer destFile.Close() - - _, err = io.Copy(destFile, sourceFile) - if err != nil { - return err - } - - // Make sure the file is written to disk - err = destFile.Sync() - if err != nil { - return err - } - - return nil -} diff --git a/cmd/root.go b/cmd/root.go index 092606de7ff6ad1dfde4f9238f5199ceca5333e4..a2e63006f195b406598388c538e6abcfd8f525bc 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -9,7 +9,6 @@ import ( tea "github.com/charmbracelet/bubbletea" "github.com/kujtimiihoxha/termai/internal/app" - "github.com/kujtimiihoxha/termai/internal/assets" "github.com/kujtimiihoxha/termai/internal/config" "github.com/kujtimiihoxha/termai/internal/db" "github.com/kujtimiihoxha/termai/internal/llm/agent" @@ -52,11 +51,6 @@ var rootCmd = &cobra.Command{ return err } - err = assets.WriteAssets() - if err != nil { - logging.Error("Error writing assets: %v", err) - } - // Connect DB, this will also run migrations conn, err := db.Connect() if err != nil { @@ -67,7 +61,11 @@ var rootCmd = &cobra.Command{ ctx, cancel := context.WithCancel(context.Background()) defer cancel() - app := app.New(ctx, conn) + app, err := app.New(ctx, conn) + if err != nil { + logging.Error("Failed to create app: %v", err) + return err + } // Set up the TUI zone.NewGlobal() diff --git a/go.mod b/go.mod index 617dad3a1f541801dd8d2b8af1dd3254c2a69623..e3dc2bd96d60afca3c522c3260c3d92ebf4d9bf5 100644 --- a/go.mod +++ b/go.mod @@ -31,7 +31,6 @@ require ( github.com/muesli/reflow v0.3.0 github.com/muesli/termenv v0.16.0 github.com/openai/openai-go v0.1.0-beta.2 - github.com/sergi/go-diff v1.3.2-0.20230802210424-5b0b94c5c0d3 github.com/spf13/cobra v1.9.1 github.com/spf13/viper v1.20.0 github.com/stretchr/testify v1.10.0 @@ -107,6 +106,7 @@ require ( github.com/rivo/uniseg v0.4.7 // indirect github.com/sagikazarmark/locafero v0.7.0 // indirect github.com/sahilm/fuzzy v0.1.1 // indirect + github.com/sergi/go-diff v1.3.2-0.20230802210424-5b0b94c5c0d3 // indirect github.com/skeema/knownhosts v1.3.1 // indirect github.com/sourcegraph/conc v0.3.0 // indirect github.com/spf13/afero v1.12.0 // indirect diff --git a/internal/app/app.go b/internal/app/app.go index fa4a6ee9008b051f51a9f7b73195592ab93531bf..9f575cac324127469b34c8a0ef5f4df9ed0cf52e 100644 --- a/internal/app/app.go +++ b/internal/app/app.go @@ -9,6 +9,7 @@ import ( "github.com/kujtimiihoxha/termai/internal/db" "github.com/kujtimiihoxha/termai/internal/history" + "github.com/kujtimiihoxha/termai/internal/llm/agent" "github.com/kujtimiihoxha/termai/internal/logging" "github.com/kujtimiihoxha/termai/internal/lsp" "github.com/kujtimiihoxha/termai/internal/message" @@ -22,6 +23,8 @@ type App struct { Files history.Service Permissions permission.Service + CoderAgent agent.Service + LSPClients map[string]*lsp.Client clientsMutex sync.RWMutex @@ -31,7 +34,7 @@ type App struct { watcherWG sync.WaitGroup } -func New(ctx context.Context, conn *sql.DB) *App { +func New(ctx context.Context, conn *sql.DB) (*App, error) { q := db.New(conn) sessions := session.NewService(q) messages := message.NewService(q) @@ -45,9 +48,22 @@ func New(ctx context.Context, conn *sql.DB) *App { LSPClients: make(map[string]*lsp.Client), } + var err error + app.CoderAgent, err = agent.NewCoderAgent( + + app.Permissions, + app.Sessions, + app.Messages, + app.LSPClients, + ) + if err != nil { + logging.Error("Failed to create coder agent", err) + return nil, err + } + app.initLSPClients(ctx) - return app + return app, nil } // Shutdown performs a clean shutdown of the application diff --git a/internal/assets/diff/themes/dark.json b/internal/assets/diff/themes/dark.json deleted file mode 100644 index 05c18e08c327178af34397cd2aafe736cc99a93a..0000000000000000000000000000000000000000 --- a/internal/assets/diff/themes/dark.json +++ /dev/null @@ -1,73 +0,0 @@ -{ - "SYNTAX_HIGHLIGHTING_THEME": "dark-plus", - "DEFAULT_COLOR": { - "color": "#ffffff", - "backgroundColor": "#212121" - }, - "COMMIT_HEADER_COLOR": { - "color": "#cccccc" - }, - "COMMIT_HEADER_LABEL_COLOR": { - "color": "#00000022" - }, - "COMMIT_SHA_COLOR": { - "color": "#00eeaa" - }, - "COMMIT_AUTHOR_COLOR": { - "color": "#00aaee" - }, - "COMMIT_DATE_COLOR": { - "color": "#cccccc" - }, - "COMMIT_MESSAGE_COLOR": { - "color": "#cccccc" - }, - "COMMIT_TITLE_COLOR": { - "modifiers": [ - "bold" - ] - }, - "FILE_NAME_COLOR": { - "color": "#ffdd99" - }, - "BORDER_COLOR": { - "color": "#ffdd9966", - "modifiers": [ - "dim" - ] - }, - "HUNK_HEADER_COLOR": { - "modifiers": [ - "dim" - ] - }, - "DELETED_WORD_COLOR": { - "color": "#ffcccc", - "backgroundColor": "#ff000033" - }, - "INSERTED_WORD_COLOR": { - "color": "#ccffcc", - "backgroundColor": "#00ff0033" - }, - "DELETED_LINE_NO_COLOR": { - "color": "#00000022", - "backgroundColor": "#00000022" - }, - "INSERTED_LINE_NO_COLOR": { - "color": "#00000022", - "backgroundColor": "#00000022" - }, - "UNMODIFIED_LINE_NO_COLOR": { - "color": "#666666" - }, - "DELETED_LINE_COLOR": { - "color": "#cc6666", - "backgroundColor": "#3a3030" - }, - "INSERTED_LINE_COLOR": { - "color": "#66cc66", - "backgroundColor": "#303a30" - }, - "UNMODIFIED_LINE_COLOR": {}, - "MISSING_LINE_COLOR": {} -} diff --git a/internal/assets/embed.go b/internal/assets/embed.go deleted file mode 100644 index 9e1316d08e8a5e0a4c9586db63a88d69ff767ae3..0000000000000000000000000000000000000000 --- a/internal/assets/embed.go +++ /dev/null @@ -1,6 +0,0 @@ -package assets - -import "embed" - -//go:embed diff -var FS embed.FS diff --git a/internal/assets/write.go b/internal/assets/write.go deleted file mode 100644 index 602b589ce353a994f4ea063f007639dc21b86fe5..0000000000000000000000000000000000000000 --- a/internal/assets/write.go +++ /dev/null @@ -1,60 +0,0 @@ -package assets - -import ( - "os" - "path/filepath" - - "github.com/kujtimiihoxha/termai/internal/config" -) - -func WriteAssets() error { - appCfg := config.Get() - appWd := config.WorkingDirectory() - scriptDir := filepath.Join( - appWd, - appCfg.Data.Directory, - "diff", - ) - scriptPath := filepath.Join(scriptDir, "index.mjs") - // Before, run the script in cmd/diff/main.go to build this file - if _, err := os.Stat(scriptPath); err != nil { - scriptData, err := FS.ReadFile("diff/index.mjs") - if err != nil { - return err - } - - err = os.MkdirAll(scriptDir, 0o755) - if err != nil { - return err - } - err = os.WriteFile(scriptPath, scriptData, 0o755) - if err != nil { - return err - } - } - - themeDir := filepath.Join( - appWd, - appCfg.Data.Directory, - "themes", - ) - - themePath := filepath.Join(themeDir, "dark.json") - - if _, err := os.Stat(themePath); err != nil { - themeData, err := FS.ReadFile("diff/themes/dark.json") - if err != nil { - return err - } - - err = os.MkdirAll(themeDir, 0o755) - if err != nil { - return err - } - err = os.WriteFile(themePath, themeData, 0o755) - if err != nil { - return err - } - } - return nil -} diff --git a/internal/git/diff.go b/internal/git/diff.go index d87956f0172ce92872db975a0d209e49cbbac7c5..2ab13964246d8c2c814823f3c8e2009859b991bf 100644 --- a/internal/git/diff.go +++ b/internal/git/diff.go @@ -11,7 +11,6 @@ import ( "github.com/go-git/go-git/v5" "github.com/go-git/go-git/v5/plumbing/object" - "github.com/kujtimiihoxha/termai/internal/config" ) type DiffStats struct { @@ -197,32 +196,32 @@ func isSplitDiffsAvailable() bool { } func formatWithSplitDiffs(diffText string, width int) (string, error) { - var cmd *exec.Cmd + args := []string{ + "--color", + } - appCfg := config.Get() - appWd := config.WorkingDirectory() - script := filepath.Join( - appWd, - appCfg.Data.Directory, - "diff", - "index.mjs", - ) + var diffCmd *exec.Cmd - cmd = exec.Command("node", script, "--color") + if _, err := exec.LookPath("git-split-diffs-opencode"); err == nil { + fullArgs := append([]string{"git-split-diffs-opencode"}, args...) + diffCmd = exec.Command(fullArgs[0], fullArgs[1:]...) + } else { + npxArgs := append([]string{"git-split-diffs-opencode"}, args...) + diffCmd = exec.Command("npx", npxArgs...) + } - cmd.Env = append(os.Environ(), fmt.Sprintf("COLUMNS=%d", width)) + diffCmd.Env = append(os.Environ(), fmt.Sprintf("DIFF_COLUMNS=%d", width)) - cmd.Stdin = strings.NewReader(diffText) + diffCmd.Stdin = strings.NewReader(diffText) var out bytes.Buffer - cmd.Stdout = &out + diffCmd.Stdout = &out var stderr bytes.Buffer - cmd.Stderr = &stderr + diffCmd.Stderr = &stderr - err := cmd.Run() - if err != nil { - return "", fmt.Errorf("git-split-diffs error: %v, stderr: %s", err, stderr.String()) + if err := diffCmd.Run(); err != nil { + return "", fmt.Errorf("git-split-diffs-opencode error: %w, stderr: %s", err, stderr.String()) } return out.String(), nil diff --git a/internal/llm/agent/agent-tool.go b/internal/llm/agent/agent-tool.go index 91c46da8b0bfe7b44e9c1ff65fd3eb30790ec9f7..a9c6f93a7e04b74b485c79b5942891fdb392dbe0 100644 --- a/internal/llm/agent/agent-tool.go +++ b/internal/llm/agent/agent-tool.go @@ -5,14 +5,16 @@ import ( "encoding/json" "fmt" - "github.com/kujtimiihoxha/termai/internal/app" "github.com/kujtimiihoxha/termai/internal/llm/tools" + "github.com/kujtimiihoxha/termai/internal/lsp" "github.com/kujtimiihoxha/termai/internal/message" + "github.com/kujtimiihoxha/termai/internal/session" ) type agentTool struct { - parentSessionID string - app *app.App + sessions session.Service + messages message.Service + lspClients map[string]*lsp.Client } const ( @@ -46,12 +48,17 @@ func (b *agentTool) Run(ctx context.Context, call tools.ToolCall) (tools.ToolRes return tools.NewTextErrorResponse("prompt is required"), nil } - agent, err := NewTaskAgent(b.app) + sessionID, messageID := tools.GetContextValues(ctx) + if sessionID == "" || messageID == "" { + return tools.NewTextErrorResponse("session ID and message ID are required"), nil + } + + agent, err := NewTaskAgent(b.lspClients) if err != nil { return tools.NewTextErrorResponse(fmt.Sprintf("error creating agent: %s", err)), nil } - session, err := b.app.Sessions.CreateTaskSession(ctx, call.ID, b.parentSessionID, "New Agent Session") + session, err := b.sessions.CreateTaskSession(ctx, call.ID, sessionID, "New Agent Session") if err != nil { return tools.NewTextErrorResponse(fmt.Sprintf("error creating session: %s", err)), nil } @@ -61,7 +68,7 @@ func (b *agentTool) Run(ctx context.Context, call tools.ToolCall) (tools.ToolRes return tools.NewTextErrorResponse(fmt.Sprintf("error generating agent: %s", err)), nil } - messages, err := b.app.Messages.List(ctx, session.ID) + messages, err := b.messages.List(ctx, session.ID) if err != nil { return tools.NewTextErrorResponse(fmt.Sprintf("error listing messages: %s", err)), nil } @@ -74,11 +81,11 @@ func (b *agentTool) Run(ctx context.Context, call tools.ToolCall) (tools.ToolRes return tools.NewTextErrorResponse("no assistant message found"), nil } - updatedSession, err := b.app.Sessions.Get(ctx, session.ID) + updatedSession, err := b.sessions.Get(ctx, session.ID) if err != nil { return tools.NewTextErrorResponse(fmt.Sprintf("error: %s", err)), nil } - parentSession, err := b.app.Sessions.Get(ctx, b.parentSessionID) + parentSession, err := b.sessions.Get(ctx, sessionID) if err != nil { return tools.NewTextErrorResponse(fmt.Sprintf("error: %s", err)), nil } @@ -87,16 +94,19 @@ func (b *agentTool) Run(ctx context.Context, call tools.ToolCall) (tools.ToolRes parentSession.PromptTokens += updatedSession.PromptTokens parentSession.CompletionTokens += updatedSession.CompletionTokens - _, err = b.app.Sessions.Save(ctx, parentSession) + _, err = b.sessions.Save(ctx, parentSession) if err != nil { return tools.NewTextErrorResponse(fmt.Sprintf("error: %s", err)), nil } return tools.NewTextResponse(response.Content().String()), nil } -func NewAgentTool(parentSessionID string, app *app.App) tools.BaseTool { +func NewAgentTool( + Sessions session.Service, + Messages message.Service, +) tools.BaseTool { return &agentTool{ - parentSessionID: parentSessionID, - app: app, + sessions: Sessions, + messages: Messages, } } diff --git a/internal/llm/agent/agent.go b/internal/llm/agent/agent.go index b7c736e6c38f5cfa6f4029fac414fe65b494b9d3..997004e123208835b038ef4360c8455b03867bd5 100644 --- a/internal/llm/agent/agent.go +++ b/internal/llm/agent/agent.go @@ -7,7 +7,6 @@ import ( "strings" "sync" - "github.com/kujtimiihoxha/termai/internal/app" "github.com/kujtimiihoxha/termai/internal/config" "github.com/kujtimiihoxha/termai/internal/llm/models" "github.com/kujtimiihoxha/termai/internal/llm/prompt" @@ -15,22 +14,118 @@ import ( "github.com/kujtimiihoxha/termai/internal/llm/tools" "github.com/kujtimiihoxha/termai/internal/logging" "github.com/kujtimiihoxha/termai/internal/message" + "github.com/kujtimiihoxha/termai/internal/session" ) -type Agent interface { +// Common errors +var ( + ErrProviderNotEnabled = errors.New("provider is not enabled") + ErrRequestCancelled = errors.New("request cancelled by user") + ErrSessionBusy = errors.New("session is currently processing another request") +) + +// Service defines the interface for generating responses +type Service interface { Generate(ctx context.Context, sessionID string, content string) error + Cancel(sessionID string) error } type agent struct { - *app.App + sessions session.Service + messages message.Service model models.Model tools []tools.BaseTool agent provider.Provider titleGenerator provider.Provider + activeRequests sync.Map // map[sessionID]context.CancelFunc +} + +// NewAgent creates a new agent instance with the given model and tools +func NewAgent(ctx context.Context, sessions session.Service, messages message.Service, model models.Model, tools []tools.BaseTool) (Service, error) { + agentProvider, titleGenerator, err := getAgentProviders(ctx, model) + if err != nil { + return nil, fmt.Errorf("failed to initialize providers: %w", err) + } + + return &agent{ + model: model, + tools: tools, + sessions: sessions, + messages: messages, + agent: agentProvider, + titleGenerator: titleGenerator, + activeRequests: sync.Map{}, + }, nil +} + +// Cancel cancels an active request by session ID +func (a *agent) Cancel(sessionID string) error { + if cancelFunc, exists := a.activeRequests.LoadAndDelete(sessionID); exists { + if cancel, ok := cancelFunc.(context.CancelFunc); ok { + logging.InfoPersist(fmt.Sprintf("Request cancellation initiated for session: %s", sessionID)) + cancel() + return nil + } + } + return errors.New("no active request found for this session") } -func (c *agent) handleTitleGeneration(ctx context.Context, sessionID, content string) { - response, err := c.titleGenerator.SendMessages( +// Generate starts the generation process +func (a *agent) Generate(ctx context.Context, sessionID string, content string) error { + // Check if this session already has an active request + if _, busy := a.activeRequests.Load(sessionID); busy { + return ErrSessionBusy + } + + // Create a cancellable context + genCtx, cancel := context.WithCancel(ctx) + + // Store cancel function to allow user cancellation + a.activeRequests.Store(sessionID, cancel) + + // Launch the generation in a goroutine + go func() { + defer func() { + if r := recover(); r != nil { + logging.ErrorPersist(fmt.Sprintf("Panic in Generate: %v", r)) + } + }() + defer a.activeRequests.Delete(sessionID) + defer cancel() + + if err := a.generate(genCtx, sessionID, content); err != nil { + if !errors.Is(err, ErrRequestCancelled) && !errors.Is(err, context.Canceled) { + // Log the error (avoid logging cancellations as they're expected) + logging.ErrorPersist(fmt.Sprintf("Generation error for session %s: %v", sessionID, err)) + + // You may want to create an error message in the chat + bgCtx := context.Background() + errorMsg := fmt.Sprintf("Sorry, an error occurred: %v", err) + _, createErr := a.messages.Create(bgCtx, sessionID, message.CreateMessageParams{ + Role: message.System, + Parts: []message.ContentPart{ + message.TextContent{ + Text: errorMsg, + }, + }, + }) + if createErr != nil { + logging.ErrorPersist(fmt.Sprintf("Failed to create error message: %v", createErr)) + } + } + } + }() + + return nil +} + +// IsSessionBusy checks if a session currently has an active request +func (a *agent) IsSessionBusy(sessionID string) bool { + _, busy := a.activeRequests.Load(sessionID) + return busy +} // handleTitleGeneration asynchronously generates a title for new sessions +func (a *agent) handleTitleGeneration(ctx context.Context, sessionID, content string) { + response, err := a.titleGenerator.SendMessages( ctx, []message.Message{ { @@ -45,25 +140,30 @@ func (c *agent) handleTitleGeneration(ctx context.Context, sessionID, content st nil, ) if err != nil { + logging.ErrorPersist(fmt.Sprintf("Failed to generate title: %v", err)) return } - session, err := c.Sessions.Get(ctx, sessionID) + session, err := a.sessions.Get(ctx, sessionID) if err != nil { + logging.ErrorPersist(fmt.Sprintf("Failed to get session: %v", err)) return } + if response.Content != "" { - session.Title = response.Content - session.Title = strings.TrimSpace(session.Title) + session.Title = strings.TrimSpace(response.Content) session.Title = strings.ReplaceAll(session.Title, "\n", " ") - c.Sessions.Save(ctx, session) + if _, err := a.sessions.Save(ctx, session); err != nil { + logging.ErrorPersist(fmt.Sprintf("Failed to save session title: %v", err)) + } } } -func (c *agent) TrackUsage(ctx context.Context, sessionID string, model models.Model, usage provider.TokenUsage) error { - session, err := c.Sessions.Get(ctx, sessionID) +// TrackUsage updates token usage statistics for the session +func (a *agent) TrackUsage(ctx context.Context, sessionID string, model models.Model, usage provider.TokenUsage) error { + session, err := a.sessions.Get(ctx, sessionID) if err != nil { - return err + return fmt.Errorf("failed to get session: %w", err) } cost := model.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) + @@ -75,189 +175,241 @@ func (c *agent) TrackUsage(ctx context.Context, sessionID string, model models.M session.CompletionTokens += usage.OutputTokens session.PromptTokens += usage.InputTokens - _, err = c.Sessions.Save(ctx, session) - return err + _, err = a.sessions.Save(ctx, session) + if err != nil { + return fmt.Errorf("failed to save session: %w", err) + } + return nil } -func (c *agent) processEvent( +// processEvent handles different types of events during generation +func (a *agent) processEvent( ctx context.Context, sessionID string, assistantMsg *message.Message, event provider.ProviderEvent, ) error { + select { + case <-ctx.Done(): + return ctx.Err() + default: + // Continue processing + } + switch event.Type { case provider.EventThinkingDelta: assistantMsg.AppendReasoningContent(event.Content) - return c.Messages.Update(ctx, *assistantMsg) + return a.messages.Update(ctx, *assistantMsg) case provider.EventContentDelta: assistantMsg.AppendContent(event.Content) - return c.Messages.Update(ctx, *assistantMsg) + return a.messages.Update(ctx, *assistantMsg) case provider.EventError: if errors.Is(event.Error, context.Canceled) { - return nil + logging.InfoPersist(fmt.Sprintf("Event processing canceled for session: %s", sessionID)) + return context.Canceled } logging.ErrorPersist(event.Error.Error()) return event.Error case provider.EventWarning: logging.WarnPersist(event.Info) - return nil case provider.EventInfo: logging.InfoPersist(event.Info) case provider.EventComplete: assistantMsg.SetToolCalls(event.Response.ToolCalls) assistantMsg.AddFinish(event.Response.FinishReason) - err := c.Messages.Update(ctx, *assistantMsg) - if err != nil { - return err + if err := a.messages.Update(ctx, *assistantMsg); err != nil { + return fmt.Errorf("failed to update message: %w", err) } - return c.TrackUsage(ctx, sessionID, c.model, event.Response.Usage) + return a.TrackUsage(ctx, sessionID, a.model, event.Response.Usage) } return nil } -func (c *agent) ExecuteTools(ctx context.Context, toolCalls []message.ToolCall, tls []tools.BaseTool) ([]message.ToolResult, error) { - var wg sync.WaitGroup +// ExecuteTools runs all tool calls sequentially and returns the results +func (a *agent) ExecuteTools(ctx context.Context, toolCalls []message.ToolCall, tls []tools.BaseTool) ([]message.ToolResult, error) { toolResults := make([]message.ToolResult, len(toolCalls)) - mutex := &sync.Mutex{} - errChan := make(chan error, 1) // Create a child context that can be canceled ctx, cancel := context.WithCancel(ctx) defer cancel() - for i, tc := range toolCalls { - wg.Add(1) - go func(index int, toolCall message.ToolCall) { - defer wg.Done() + // Check if already canceled before starting any execution + if ctx.Err() != nil { + // Mark all tools as canceled + for i, toolCall := range toolCalls { + toolResults[i] = message.ToolResult{ + ToolCallID: toolCall.ID, + Content: "Tool execution canceled by user", + IsError: true, + } + } + return toolResults, ctx.Err() + } - // Check if context is already canceled - select { - case <-ctx.Done(): - mutex.Lock() - toolResults[index] = message.ToolResult{ - ToolCallID: toolCall.ID, - Content: "Tool execution canceled", + for i, toolCall := range toolCalls { + // Check for cancellation before executing each tool + select { + case <-ctx.Done(): + // Mark this and all remaining tools as canceled + for j := i; j < len(toolCalls); j++ { + toolResults[j] = message.ToolResult{ + ToolCallID: toolCalls[j].ID, + Content: "Tool execution canceled by user", IsError: true, } - mutex.Unlock() - - // Send cancellation error to error channel if it's empty - select { - case errChan <- ctx.Err(): - default: - } - return - default: } + return toolResults, ctx.Err() + default: + // Continue processing + } - response := "" - isError := false - found := false - - for _, tool := range tls { - if tool.Info().Name == toolCall.Name { - found = true - toolResult, toolErr := tool.Run(ctx, tools.ToolCall{ - ID: toolCall.ID, - Name: toolCall.Name, - Input: toolCall.Input, - }) - - if toolErr != nil { - if errors.Is(toolErr, context.Canceled) { - response = "Tool execution canceled" - - // Send cancellation error to error channel if it's empty - select { - case errChan <- ctx.Err(): - default: - } - } else { - response = fmt.Sprintf("error running tool: %s", toolErr) - } - isError = true + response := "" + isError := false + found := false + + // Find and execute the appropriate tool + for _, tool := range tls { + if tool.Info().Name == toolCall.Name { + found = true + toolResult, toolErr := tool.Run(ctx, tools.ToolCall{ + ID: toolCall.ID, + Name: toolCall.Name, + Input: toolCall.Input, + }) + + if toolErr != nil { + if errors.Is(toolErr, context.Canceled) { + response = "Tool execution canceled by user" } else { - response = toolResult.Content - isError = toolResult.IsError + response = fmt.Sprintf("Error running tool: %s", toolErr) } - break + isError = true + } else { + response = toolResult.Content + isError = toolResult.IsError } + break } + } - if !found { - response = fmt.Sprintf("tool not found: %s", toolCall.Name) - isError = true - } - - mutex.Lock() - defer mutex.Unlock() - - toolResults[index] = message.ToolResult{ - ToolCallID: toolCall.ID, - Content: response, - IsError: isError, - } - }(i, tc) - } - - // Wait for all goroutines to finish or context to be canceled - done := make(chan struct{}) - go func() { - wg.Wait() - close(done) - }() + if !found { + response = fmt.Sprintf("Tool not found: %s", toolCall.Name) + isError = true + } - select { - case <-done: - // All tools completed successfully - case err := <-errChan: - // One of the tools encountered a cancellation - return toolResults, err - case <-ctx.Done(): - // Context was canceled externally - return toolResults, ctx.Err() + toolResults[i] = message.ToolResult{ + ToolCallID: toolCall.ID, + Content: response, + IsError: isError, + } } return toolResults, nil } -func (c *agent) handleToolExecution( +// handleToolExecution processes tool calls and creates tool result messages +func (a *agent) handleToolExecution( ctx context.Context, assistantMsg message.Message, ) (*message.Message, error) { + select { + case <-ctx.Done(): + // If cancelled, create tool results that indicate cancellation + if len(assistantMsg.ToolCalls()) > 0 { + toolResults := make([]message.ToolResult, 0, len(assistantMsg.ToolCalls())) + for _, tc := range assistantMsg.ToolCalls() { + toolResults = append(toolResults, message.ToolResult{ + ToolCallID: tc.ID, + Content: "Tool execution canceled by user", + IsError: true, + }) + } + + // Use background context to ensure the message is created even if original context is cancelled + bgCtx := context.Background() + parts := make([]message.ContentPart, 0) + for _, toolResult := range toolResults { + parts = append(parts, toolResult) + } + msg, err := a.messages.Create(bgCtx, assistantMsg.SessionID, message.CreateMessageParams{ + Role: message.Tool, + Parts: parts, + }) + if err != nil { + return nil, fmt.Errorf("failed to create cancelled tool message: %w", err) + } + return &msg, ctx.Err() + } + return nil, ctx.Err() + default: + // Continue processing + } + if len(assistantMsg.ToolCalls()) == 0 { return nil, nil } - toolResults, err := c.ExecuteTools(ctx, assistantMsg.ToolCalls(), c.tools) + toolResults, err := a.ExecuteTools(ctx, assistantMsg.ToolCalls(), a.tools) if err != nil { + // If error is from cancellation, still return the partial results we have + if errors.Is(err, context.Canceled) { + // Use background context to ensure the message is created even if original context is cancelled + bgCtx := context.Background() + parts := make([]message.ContentPart, 0) + for _, toolResult := range toolResults { + parts = append(parts, toolResult) + } + + msg, createErr := a.messages.Create(bgCtx, assistantMsg.SessionID, message.CreateMessageParams{ + Role: message.Tool, + Parts: parts, + }) + if createErr != nil { + logging.ErrorPersist(fmt.Sprintf("Failed to create tool message after cancellation: %v", createErr)) + return nil, err + } + return &msg, err + } return nil, err } - parts := make([]message.ContentPart, 0) + + parts := make([]message.ContentPart, 0, len(toolResults)) for _, toolResult := range toolResults { parts = append(parts, toolResult) } - msg, err := c.Messages.Create(ctx, assistantMsg.SessionID, message.CreateMessageParams{ + + msg, err := a.messages.Create(ctx, assistantMsg.SessionID, message.CreateMessageParams{ Role: message.Tool, Parts: parts, }) + if err != nil { + return nil, fmt.Errorf("failed to create tool message: %w", err) + } - return &msg, err + return &msg, nil } -func (c *agent) generate(ctx context.Context, sessionID string, content string) error { +// generate handles the main generation workflow +func (a *agent) generate(ctx context.Context, sessionID string, content string) error { ctx = context.WithValue(ctx, tools.SessionIDContextKey, sessionID) - messages, err := c.Messages.List(ctx, sessionID) + + // Handle context cancellation at any point + if err := ctx.Err(); err != nil { + return ErrRequestCancelled + } + + messages, err := a.messages.List(ctx, sessionID) if err != nil { - return err + return fmt.Errorf("failed to list messages: %w", err) } if len(messages) == 0 { - go c.handleTitleGeneration(ctx, sessionID, content) + titleCtx := context.Background() + go a.handleTitleGeneration(titleCtx, sessionID, content) } - userMsg, err := c.Messages.Create(ctx, sessionID, message.CreateMessageParams{ + userMsg, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{ Role: message.User, Parts: []message.ContentPart{ message.TextContent{ @@ -266,133 +418,125 @@ func (c *agent) generate(ctx context.Context, sessionID string, content string) }, }) if err != nil { - return err + return fmt.Errorf("failed to create user message: %w", err) } messages = append(messages, userMsg) + for { + // Check for cancellation before each iteration select { case <-ctx.Done(): - assistantMsg, err := c.Messages.Create(ctx, sessionID, message.CreateMessageParams{ - Role: message.Assistant, - Parts: []message.ContentPart{}, - }) - if err != nil { - return err - } - assistantMsg.AddFinish("canceled") - c.Messages.Update(ctx, assistantMsg) - return context.Canceled + return ErrRequestCancelled default: // Continue processing } - eventChan, err := c.agent.StreamResponse(ctx, messages, c.tools) + eventChan, err := a.agent.StreamResponse(ctx, messages, a.tools) if err != nil { if errors.Is(err, context.Canceled) { - assistantMsg, err := c.Messages.Create(ctx, sessionID, message.CreateMessageParams{ - Role: message.Assistant, - Parts: []message.ContentPart{}, - }) - if err != nil { - return err - } - assistantMsg.AddFinish("canceled") - c.Messages.Update(ctx, assistantMsg) - return context.Canceled + return ErrRequestCancelled } - return err + return fmt.Errorf("failed to stream response: %w", err) } - assistantMsg, err := c.Messages.Create(ctx, sessionID, message.CreateMessageParams{ + assistantMsg, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{ Role: message.Assistant, Parts: []message.ContentPart{}, - Model: c.model.ID, + Model: a.model.ID, }) if err != nil { - return err + return fmt.Errorf("failed to create assistant message: %w", err) } ctx = context.WithValue(ctx, tools.MessageIDContextKey, assistantMsg.ID) + + // Process events from the LLM provider for event := range eventChan { - err = c.processEvent(ctx, sessionID, &assistantMsg, event) - if err != nil { + if err := a.processEvent(ctx, sessionID, &assistantMsg, event); err != nil { if errors.Is(err, context.Canceled) { + // Mark as canceled but don't create separate message assistantMsg.AddFinish("canceled") - c.Messages.Update(ctx, assistantMsg) - return context.Canceled + _ = a.messages.Update(context.Background(), assistantMsg) + return ErrRequestCancelled } assistantMsg.AddFinish("error:" + err.Error()) - c.Messages.Update(ctx, assistantMsg) - return err + _ = a.messages.Update(ctx, assistantMsg) + return fmt.Errorf("event processing error: %w", err) } + // Check for cancellation during event processing select { case <-ctx.Done(): + // Mark as canceled assistantMsg.AddFinish("canceled") - c.Messages.Update(ctx, assistantMsg) - return context.Canceled + _ = a.messages.Update(context.Background(), assistantMsg) + return ErrRequestCancelled default: } } - // Check for context cancellation before tool execution + // Check for cancellation before tool execution select { case <-ctx.Done(): - assistantMsg.AddFinish("canceled") - c.Messages.Update(ctx, assistantMsg) - return context.Canceled + assistantMsg.AddFinish("canceled_by_user") + _ = a.messages.Update(context.Background(), assistantMsg) + return ErrRequestCancelled default: - // Continue processing } - msg, err := c.handleToolExecution(ctx, assistantMsg) + // Execute any tool calls + toolMsg, err := a.handleToolExecution(ctx, assistantMsg) if err != nil { if errors.Is(err, context.Canceled) { - assistantMsg.AddFinish("canceled") - c.Messages.Update(ctx, assistantMsg) - return context.Canceled + assistantMsg.AddFinish("canceled_by_user") + _ = a.messages.Update(context.Background(), assistantMsg) + return ErrRequestCancelled } - return err + return fmt.Errorf("tool execution error: %w", err) } - c.Messages.Update(ctx, assistantMsg) + if err := a.messages.Update(ctx, assistantMsg); err != nil { + return fmt.Errorf("failed to update assistant message: %w", err) + } + // If no tool calls, we're done if len(assistantMsg.ToolCalls()) == 0 { break } + // Add messages for next iteration messages = append(messages, assistantMsg) - if msg != nil { - messages = append(messages, *msg) + if toolMsg != nil { + messages = append(messages, *toolMsg) } - // Check for context cancellation after tool execution + // Check for cancellation after tool execution select { case <-ctx.Done(): - assistantMsg.AddFinish("canceled") - c.Messages.Update(ctx, assistantMsg) - return context.Canceled + return ErrRequestCancelled default: - // Continue processing } } + return nil } +// getAgentProviders initializes the LLM providers based on the chosen model func getAgentProviders(ctx context.Context, model models.Model) (provider.Provider, provider.Provider, error) { maxTokens := config.Get().Model.CoderMaxTokens providerConfig, ok := config.Get().Providers[model.Provider] if !ok || providerConfig.Disabled { - return nil, nil, errors.New("provider is not enabled") + return nil, nil, ErrProviderNotEnabled } + var agentProvider provider.Provider var titleGenerator provider.Provider + var err error switch model.Provider { case models.ProviderOpenAI: - var err error agentProvider, err = provider.NewOpenAIProvider( provider.WithOpenAISystemMessage( prompt.CoderOpenAISystemPrompt(), @@ -402,8 +546,9 @@ func getAgentProviders(ctx context.Context, model models.Model) (provider.Provid provider.WithOpenAIKey(providerConfig.APIKey), ) if err != nil { - return nil, nil, err + return nil, nil, fmt.Errorf("failed to create OpenAI agent provider: %w", err) } + titleGenerator, err = provider.NewOpenAIProvider( provider.WithOpenAISystemMessage( prompt.TitlePrompt(), @@ -413,10 +558,10 @@ func getAgentProviders(ctx context.Context, model models.Model) (provider.Provid provider.WithOpenAIKey(providerConfig.APIKey), ) if err != nil { - return nil, nil, err + return nil, nil, fmt.Errorf("failed to create OpenAI title generator: %w", err) } + case models.ProviderAnthropic: - var err error agentProvider, err = provider.NewAnthropicProvider( provider.WithAnthropicSystemMessage( prompt.CoderAnthropicSystemPrompt(), @@ -426,8 +571,9 @@ func getAgentProviders(ctx context.Context, model models.Model) (provider.Provid provider.WithAnthropicModel(model), ) if err != nil { - return nil, nil, err + return nil, nil, fmt.Errorf("failed to create Anthropic agent provider: %w", err) } + titleGenerator, err = provider.NewAnthropicProvider( provider.WithAnthropicSystemMessage( prompt.TitlePrompt(), @@ -437,11 +583,10 @@ func getAgentProviders(ctx context.Context, model models.Model) (provider.Provid provider.WithAnthropicModel(model), ) if err != nil { - return nil, nil, err + return nil, nil, fmt.Errorf("failed to create Anthropic title generator: %w", err) } case models.ProviderGemini: - var err error agentProvider, err = provider.NewGeminiProvider( ctx, provider.WithGeminiSystemMessage( @@ -452,8 +597,9 @@ func getAgentProviders(ctx context.Context, model models.Model) (provider.Provid provider.WithGeminiModel(model), ) if err != nil { - return nil, nil, err + return nil, nil, fmt.Errorf("failed to create Gemini agent provider: %w", err) } + titleGenerator, err = provider.NewGeminiProvider( ctx, provider.WithGeminiSystemMessage( @@ -464,10 +610,10 @@ func getAgentProviders(ctx context.Context, model models.Model) (provider.Provid provider.WithGeminiModel(model), ) if err != nil { - return nil, nil, err + return nil, nil, fmt.Errorf("failed to create Gemini title generator: %w", err) } + case models.ProviderGROQ: - var err error agentProvider, err = provider.NewOpenAIProvider( provider.WithOpenAISystemMessage( prompt.CoderAnthropicSystemPrompt(), @@ -478,8 +624,9 @@ func getAgentProviders(ctx context.Context, model models.Model) (provider.Provid provider.WithOpenAIBaseURL("https://api.groq.com/openai/v1"), ) if err != nil { - return nil, nil, err + return nil, nil, fmt.Errorf("failed to create GROQ agent provider: %w", err) } + titleGenerator, err = provider.NewOpenAIProvider( provider.WithOpenAISystemMessage( prompt.TitlePrompt(), @@ -490,11 +637,10 @@ func getAgentProviders(ctx context.Context, model models.Model) (provider.Provid provider.WithOpenAIBaseURL("https://api.groq.com/openai/v1"), ) if err != nil { - return nil, nil, err + return nil, nil, fmt.Errorf("failed to create GROQ title generator: %w", err) } case models.ProviderBedrock: - var err error agentProvider, err = provider.NewBedrockProvider( provider.WithBedrockSystemMessage( prompt.CoderAnthropicSystemPrompt(), @@ -503,19 +649,21 @@ func getAgentProviders(ctx context.Context, model models.Model) (provider.Provid provider.WithBedrockModel(model), ) if err != nil { - return nil, nil, err + return nil, nil, fmt.Errorf("failed to create Bedrock agent provider: %w", err) } + titleGenerator, err = provider.NewBedrockProvider( provider.WithBedrockSystemMessage( prompt.TitlePrompt(), ), - provider.WithBedrockMaxTokens(maxTokens), + provider.WithBedrockMaxTokens(80), provider.WithBedrockModel(model), ) if err != nil { - return nil, nil, err + return nil, nil, fmt.Errorf("failed to create Bedrock title generator: %w", err) } - + default: + return nil, nil, fmt.Errorf("unsupported provider: %s", model.Provider) } return agentProvider, titleGenerator, nil diff --git a/internal/llm/agent/coder.go b/internal/llm/agent/coder.go index f8e1c40a084ddf2be3d0684363ddbdc3d91114c6..8eea5704163fb3bb02b14aba98054e3fa9477f95 100644 --- a/internal/llm/agent/coder.go +++ b/internal/llm/agent/coder.go @@ -4,71 +4,60 @@ import ( "context" "errors" - "github.com/kujtimiihoxha/termai/internal/app" "github.com/kujtimiihoxha/termai/internal/config" "github.com/kujtimiihoxha/termai/internal/llm/models" "github.com/kujtimiihoxha/termai/internal/llm/tools" + "github.com/kujtimiihoxha/termai/internal/lsp" + "github.com/kujtimiihoxha/termai/internal/message" + "github.com/kujtimiihoxha/termai/internal/permission" + "github.com/kujtimiihoxha/termai/internal/session" ) type coderAgent struct { - *agent + Service } -func (c *coderAgent) setAgentTool(sessionID string) { - inx := -1 - for i, tool := range c.tools { - if tool.Info().Name == AgentToolName { - inx = i - break - } - } - if inx == -1 { - c.tools = append(c.tools, NewAgentTool(sessionID, c.App)) - } else { - c.tools[inx] = NewAgentTool(sessionID, c.App) - } -} - -func (c *coderAgent) Generate(ctx context.Context, sessionID string, content string) error { - c.setAgentTool(sessionID) - return c.generate(ctx, sessionID, content) -} - -func NewCoderAgent(app *app.App) (Agent, error) { +func NewCoderAgent( + permissions permission.Service, + sessions session.Service, + messages message.Service, + lspClients map[string]*lsp.Client, +) (Service, error) { model, ok := models.SupportedModels[config.Get().Model.Coder] if !ok { return nil, errors.New("model not supported") } ctx := context.Background() - agentProvider, titleGenerator, err := getAgentProviders(ctx, model) + otherTools := GetMcpTools(ctx, permissions) + if len(lspClients) > 0 { + otherTools = append(otherTools, tools.NewDiagnosticsTool(lspClients)) + } + agent, err := NewAgent( + ctx, + sessions, + messages, + model, + append( + []tools.BaseTool{ + tools.NewBashTool(permissions), + tools.NewEditTool(lspClients, permissions), + tools.NewFetchTool(permissions), + tools.NewGlobTool(), + tools.NewGrepTool(), + tools.NewLsTool(), + tools.NewSourcegraphTool(), + tools.NewViewTool(lspClients), + tools.NewWriteTool(lspClients, permissions), + NewAgentTool(sessions, messages), + }, otherTools..., + ), + ) if err != nil { return nil, err } - otherTools := GetMcpTools(ctx, app.Permissions) - if len(app.LSPClients) > 0 { - otherTools = append(otherTools, tools.NewDiagnosticsTool(app.LSPClients)) - } return &coderAgent{ - agent: &agent{ - App: app, - tools: append( - []tools.BaseTool{ - tools.NewBashTool(app.Permissions), - tools.NewEditTool(app.LSPClients, app.Permissions), - tools.NewFetchTool(app.Permissions), - tools.NewGlobTool(), - tools.NewGrepTool(), - tools.NewLsTool(), - tools.NewSourcegraphTool(), - tools.NewViewTool(app.LSPClients), - tools.NewWriteTool(app.LSPClients, app.Permissions), - }, otherTools..., - ), - model: model, - agent: agentProvider, - titleGenerator: titleGenerator, - }, + agent, }, nil } diff --git a/internal/llm/agent/task.go b/internal/llm/agent/task.go index c196cb107b1926068a8d32b5703809a155303e53..0a072044c4db4e4261848146d6d8d83a93787d0e 100644 --- a/internal/llm/agent/task.go +++ b/internal/llm/agent/task.go @@ -4,10 +4,10 @@ import ( "context" "errors" - "github.com/kujtimiihoxha/termai/internal/app" "github.com/kujtimiihoxha/termai/internal/config" "github.com/kujtimiihoxha/termai/internal/llm/models" "github.com/kujtimiihoxha/termai/internal/llm/tools" + "github.com/kujtimiihoxha/termai/internal/lsp" ) type taskAgent struct { @@ -18,7 +18,7 @@ func (c *taskAgent) Generate(ctx context.Context, sessionID string, content stri return c.generate(ctx, sessionID, content) } -func NewTaskAgent(app *app.App) (Agent, error) { +func NewTaskAgent(lspClients map[string]*lsp.Client) (Service, error) { model, ok := models.SupportedModels[config.Get().Model.Coder] if !ok { return nil, errors.New("model not supported") @@ -31,13 +31,12 @@ func NewTaskAgent(app *app.App) (Agent, error) { } return &taskAgent{ agent: &agent{ - App: app, tools: []tools.BaseTool{ tools.NewGlobTool(), tools.NewGrepTool(), tools.NewLsTool(), tools.NewSourcegraphTool(), - tools.NewViewTool(app.LSPClients), + tools.NewViewTool(lspClients), }, model: model, agent: agentProvider, diff --git a/internal/llm/provider/provider.go b/internal/llm/provider/provider.go index 938a8c0adbb1a5b109f38071bf03806d027d3052..34d91f2b771dd918ff6432bb9aff5f12d667e0e0 100644 --- a/internal/llm/provider/provider.go +++ b/internal/llm/provider/provider.go @@ -57,7 +57,9 @@ func cleanupMessages(messages []message.Message) []message.Message { // First pass: filter out canceled messages var cleanedMessages []message.Message for _, msg := range messages { - if msg.FinishReason() != "canceled" { + if msg.FinishReason() != "canceled" || len(msg.ToolCalls()) > 0 { + // if there are toolCalls this means we want to return it to the LLM telling it that those tools have been + // cancelled cleanedMessages = append(cleanedMessages, msg) } } diff --git a/internal/llm/tools/edit.go b/internal/llm/tools/edit.go index c9a0be07956bc65725abf93a5bdd1774bda371ae..647b8d35ffeb098011ca41f7a305fe14e4e68989 100644 --- a/internal/llm/tools/edit.go +++ b/internal/llm/tools/edit.go @@ -190,7 +190,7 @@ func (e *editTool) createNewFile(ctx context.Context, filePath, content string) return er, fmt.Errorf("failed to create parent directories: %w", err) } - sessionID, messageID := getContextValues(ctx) + sessionID, messageID := GetContextValues(ctx) if sessionID == "" || messageID == "" { return er, fmt.Errorf("session ID and message ID are required for creating a new file") } @@ -277,7 +277,7 @@ func (e *editTool) deleteContent(ctx context.Context, filePath, oldString string newContent := oldContent[:index] + oldContent[index+len(oldString):] - sessionID, messageID := getContextValues(ctx) + sessionID, messageID := GetContextValues(ctx) if sessionID == "" || messageID == "" { return er, fmt.Errorf("session ID and message ID are required for creating a new file") @@ -365,7 +365,7 @@ func (e *editTool) replaceContent(ctx context.Context, filePath, oldString, newS newContent := oldContent[:index] + newString + oldContent[index+len(oldString):] - sessionID, messageID := getContextValues(ctx) + sessionID, messageID := GetContextValues(ctx) if sessionID == "" || messageID == "" { return er, fmt.Errorf("session ID and message ID are required for creating a new file") @@ -409,4 +409,3 @@ func (e *editTool) replaceContent(ctx context.Context, filePath, oldString, newS return er, nil } - diff --git a/internal/llm/tools/tools.go b/internal/llm/tools/tools.go index 473b787bbbe9c2e20c3c08aa7ed348b7dafc009f..07afe1363fc114ca17daf0211d656f4b21f2723d 100644 --- a/internal/llm/tools/tools.go +++ b/internal/llm/tools/tools.go @@ -66,7 +66,7 @@ type BaseTool interface { Run(ctx context.Context, params ToolCall) (ToolResponse, error) } -func getContextValues(ctx context.Context) (string, string) { +func GetContextValues(ctx context.Context) (string, string) { sessionID := ctx.Value(SessionIDContextKey) messageID := ctx.Value(MessageIDContextKey) if sessionID == nil { diff --git a/internal/llm/tools/write.go b/internal/llm/tools/write.go index 27c98bb9d244ff0054165d6d3e036a803c5f9908..1b087c1934d67585281e48a3fc4c75791915cdd8 100644 --- a/internal/llm/tools/write.go +++ b/internal/llm/tools/write.go @@ -144,7 +144,7 @@ func (w *writeTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error } } - sessionID, messageID := getContextValues(ctx) + sessionID, messageID := GetContextValues(ctx) if sessionID == "" || messageID == "" { return NewTextErrorResponse("session ID or message ID is missing"), nil } diff --git a/internal/tui/components/repl/editor.go b/internal/tui/components/repl/editor.go index b1e39e65540cc6a5cfb43857dfee0ad7d54a57d0..b659775e0e6fc92676fa66dbd6a7da2a904a189b 100644 --- a/internal/tui/components/repl/editor.go +++ b/internal/tui/components/repl/editor.go @@ -7,7 +7,6 @@ import ( tea "github.com/charmbracelet/bubbletea" "github.com/charmbracelet/lipgloss" "github.com/kujtimiihoxha/termai/internal/app" - "github.com/kujtimiihoxha/termai/internal/llm/agent" "github.com/kujtimiihoxha/termai/internal/tui/layout" "github.com/kujtimiihoxha/termai/internal/tui/styles" "github.com/kujtimiihoxha/termai/internal/tui/util" @@ -168,11 +167,6 @@ func (m *editorCmp) Send() tea.Cmd { return util.ReportWarn("Assistant is still working on the previous message") } - a, err := agent.NewCoderAgent(m.app) - if err != nil { - return util.ReportError(err) - } - content := strings.Join(m.editor.GetBuffer().Lines(), "\n") if len(content) == 0 { return util.ReportWarn("Message is empty") @@ -181,7 +175,7 @@ func (m *editorCmp) Send() tea.Cmd { m.cancelMessage = cancel go func() { defer cancel() - a.Generate(ctx, m.sessionID, content) + m.app.CoderAgent.Generate(ctx, m.sessionID, content) m.cancelMessage = nil }() diff --git a/internal/tui/page/chat.go b/internal/tui/page/chat.go index 9b9924909ab9fa638bba4a96ab685d740fd65569..439c89e1f683cb6296266db465bdd5237548f692 100644 --- a/internal/tui/page/chat.go +++ b/internal/tui/page/chat.go @@ -6,7 +6,6 @@ import ( "github.com/charmbracelet/bubbles/key" tea "github.com/charmbracelet/bubbletea" "github.com/kujtimiihoxha/termai/internal/app" - "github.com/kujtimiihoxha/termai/internal/llm/agent" "github.com/kujtimiihoxha/termai/internal/session" "github.com/kujtimiihoxha/termai/internal/tui/components/chat" "github.com/kujtimiihoxha/termai/internal/tui/layout" @@ -23,6 +22,7 @@ type chatPage struct { type ChatKeyMap struct { NewSession key.Binding + Cancel key.Binding } var keyMap = ChatKeyMap{ @@ -30,6 +30,10 @@ var keyMap = ChatKeyMap{ key.WithKeys("ctrl+n"), key.WithHelp("ctrl+n", "new session"), ), + Cancel: key.NewBinding( + key.WithKeys("ctrl+x"), + key.WithHelp("ctrl+x", "cancel"), + ), } func (p *chatPage) Init() tea.Cmd { @@ -106,15 +110,8 @@ func (p *chatPage) sendMessage(text string) tea.Cmd { } cmds = append(cmds, util.CmdHandler(chat.SessionSelectedMsg(session))) } - // TODO: move this to a service - a, err := agent.NewCoderAgent(p.app) - if err != nil { - return util.ReportError(err) - } - go func() { - a.Generate(context.Background(), p.session.ID, text) - }() + p.app.CoderAgent.Generate(context.Background(), p.session.ID, text) return tea.Batch(cmds...) } From 80cd75c4fb21eb28d82c1f0d672cbd8466c35ed5 Mon Sep 17 00:00:00 2001 From: Kujtim Hoxha Date: Mon, 14 Apr 2025 10:49:32 +0200 Subject: [PATCH 10/41] handle errors correctly in the agent tool --- internal/llm/agent/agent-tool.go | 21 +++++++++++---------- 1 file changed, 11 insertions(+), 10 deletions(-) diff --git a/internal/llm/agent/agent-tool.go b/internal/llm/agent/agent-tool.go index a9c6f93a7e04b74b485c79b5942891fdb392dbe0..a92ea44a4e31c56baacc1b0772a151f153a2bd7b 100644 --- a/internal/llm/agent/agent-tool.go +++ b/internal/llm/agent/agent-tool.go @@ -50,44 +50,45 @@ func (b *agentTool) Run(ctx context.Context, call tools.ToolCall) (tools.ToolRes sessionID, messageID := tools.GetContextValues(ctx) if sessionID == "" || messageID == "" { - return tools.NewTextErrorResponse("session ID and message ID are required"), nil + return tools.ToolResponse{}, fmt.Errorf("session_id and message_id are required") } agent, err := NewTaskAgent(b.lspClients) if err != nil { - return tools.NewTextErrorResponse(fmt.Sprintf("error creating agent: %s", err)), nil + return tools.ToolResponse{}, fmt.Errorf("error creating agent: %s", err) } session, err := b.sessions.CreateTaskSession(ctx, call.ID, sessionID, "New Agent Session") if err != nil { - return tools.NewTextErrorResponse(fmt.Sprintf("error creating session: %s", err)), nil + return tools.ToolResponse{}, fmt.Errorf("error creating session: %s", err) } err = agent.Generate(ctx, session.ID, params.Prompt) if err != nil { - return tools.NewTextErrorResponse(fmt.Sprintf("error generating agent: %s", err)), nil + return tools.ToolResponse{}, fmt.Errorf("error generating agent: %s", err) } messages, err := b.messages.List(ctx, session.ID) if err != nil { - return tools.NewTextErrorResponse(fmt.Sprintf("error listing messages: %s", err)), nil + return tools.ToolResponse{}, fmt.Errorf("error listing messages: %s", err) } + if len(messages) == 0 { - return tools.NewTextErrorResponse("no messages found"), nil + return tools.NewTextErrorResponse("no response"), nil } response := messages[len(messages)-1] if response.Role != message.Assistant { - return tools.NewTextErrorResponse("no assistant message found"), nil + return tools.NewTextErrorResponse("no response"), nil } updatedSession, err := b.sessions.Get(ctx, session.ID) if err != nil { - return tools.NewTextErrorResponse(fmt.Sprintf("error: %s", err)), nil + return tools.ToolResponse{}, fmt.Errorf("error getting session: %s", err) } parentSession, err := b.sessions.Get(ctx, sessionID) if err != nil { - return tools.NewTextErrorResponse(fmt.Sprintf("error: %s", err)), nil + return tools.ToolResponse{}, fmt.Errorf("error getting parent session: %s", err) } parentSession.Cost += updatedSession.Cost @@ -96,7 +97,7 @@ func (b *agentTool) Run(ctx context.Context, call tools.ToolCall) (tools.ToolRes _, err = b.sessions.Save(ctx, parentSession) if err != nil { - return tools.NewTextErrorResponse(fmt.Sprintf("error: %s", err)), nil + return tools.ToolResponse{}, fmt.Errorf("error saving parent session: %s", err) } return tools.NewTextResponse(response.Content().String()), nil } From 9ae05fea12ad05ea356a057f67afdde46d548843 Mon Sep 17 00:00:00 2001 From: Kujtim Hoxha Date: Mon, 14 Apr 2025 10:52:04 +0200 Subject: [PATCH 11/41] handle errros correctly in the bash tool --- internal/llm/tools/bash.go | 4 ++-- internal/permission/permission.go | 4 ++++ 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/internal/llm/tools/bash.go b/internal/llm/tools/bash.go index d55cb241b9c0ade5a8b49518ee6b42f7115dfc7e..0cea20878731b24902796dba7efeace9f2d91312 100644 --- a/internal/llm/tools/bash.go +++ b/internal/llm/tools/bash.go @@ -273,14 +273,14 @@ func (b *bashTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error) }, ) if !p { - return NewTextErrorResponse("permission denied"), nil + return ToolResponse{}, permission.ErrorPermissionDenied } } startTime := time.Now() shell := shell.GetPersistentShell(config.WorkingDirectory()) stdout, stderr, exitCode, interrupted, err := shell.Exec(ctx, params.Command, params.Timeout) if err != nil { - return NewTextErrorResponse(fmt.Sprintf("error executing command: %s", err)), nil + return ToolResponse{}, fmt.Errorf("error executing command: %w", err) } took := time.Since(startTime).Milliseconds() diff --git a/internal/permission/permission.go b/internal/permission/permission.go index ebf3fe0925b719afbe94f058ac6195611ed17341..8aa280906a397dd5766b4fcd52daf86b76506f8b 100644 --- a/internal/permission/permission.go +++ b/internal/permission/permission.go @@ -1,6 +1,7 @@ package permission import ( + "errors" "sync" "time" @@ -8,6 +9,8 @@ import ( "github.com/kujtimiihoxha/termai/internal/pubsub" ) +var ErrorPermissionDenied = errors.New("permission denied") + type CreatePermissionRequest struct { ToolName string `json:"tool_name"` Description string `json:"description"` @@ -15,6 +18,7 @@ type CreatePermissionRequest struct { Params any `json:"params"` Path string `json:"path"` } + type PermissionRequest struct { ID string `json:"id"` SessionID string `json:"session_id"` From 921f5ee5bd74837ff4566fc2d1e45051c87d9c38 Mon Sep 17 00:00:00 2001 From: Kujtim Hoxha Date: Mon, 14 Apr 2025 11:08:17 +0200 Subject: [PATCH 12/41] handle errors correctly in the edit tool --- internal/llm/tools/diagnostics.go | 4 +- internal/llm/tools/edit.go | 152 +++++++++++++++--------------- internal/llm/tools/view.go | 2 +- internal/llm/tools/write.go | 2 +- 4 files changed, 78 insertions(+), 82 deletions(-) diff --git a/internal/llm/tools/diagnostics.go b/internal/llm/tools/diagnostics.go index 1bb02098e131ba1c2f193954de7246cca0a2fd7c..b7b2bb8bab55b0ecf507c7bd6bd25b9454131cbc 100644 --- a/internal/llm/tools/diagnostics.go +++ b/internal/llm/tools/diagnostics.go @@ -82,7 +82,7 @@ func (b *diagnosticsTool) Run(ctx context.Context, call ToolCall) (ToolResponse, waitForLspDiagnostics(ctx, params.FilePath, lsps) } - output := appendDiagnostics(params.FilePath, lsps) + output := getDiagnostics(params.FilePath, lsps) return NewTextResponse(output), nil } @@ -154,7 +154,7 @@ func hasDiagnosticsChanged(current, original map[protocol.DocumentUri][]protocol return false } -func appendDiagnostics(filePath string, lsps map[string]*lsp.Client) string { +func getDiagnostics(filePath string, lsps map[string]*lsp.Client) string { fileDiagnostics := []string{} projectDiagnostics := []string{} diff --git a/internal/llm/tools/edit.go b/internal/llm/tools/edit.go index 647b8d35ffeb098011ca41f7a305fe14e4e68989..46b95aea23eb96163a9db546160c732d43bcc260 100644 --- a/internal/llm/tools/edit.go +++ b/internal/llm/tools/edit.go @@ -131,68 +131,54 @@ func (e *editTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error) params.FilePath = filepath.Join(wd, params.FilePath) } + var response ToolResponse + var err error + if params.OldString == "" { - result, err := e.createNewFile(ctx, params.FilePath, params.NewString) + response, err = e.createNewFile(ctx, params.FilePath, params.NewString) if err != nil { - return NewTextErrorResponse(fmt.Sprintf("error creating file: %s", err)), nil + return response, nil } - return WithResponseMetadata(NewTextResponse(result.text), EditResponseMetadata{ - Additions: result.additions, - Removals: result.removals, - }), nil } if params.NewString == "" { - result, err := e.deleteContent(ctx, params.FilePath, params.OldString) + response, err = e.deleteContent(ctx, params.FilePath, params.OldString) if err != nil { - return NewTextErrorResponse(fmt.Sprintf("error deleting content: %s", err)), nil + return response, nil } - return WithResponseMetadata(NewTextResponse(result.text), EditResponseMetadata{ - Additions: result.additions, - Removals: result.removals, - }), nil } - result, err := e.replaceContent(ctx, params.FilePath, params.OldString, params.NewString) + response, err = e.replaceContent(ctx, params.FilePath, params.OldString, params.NewString) if err != nil { - return NewTextErrorResponse(fmt.Sprintf("error replacing content: %s", err)), nil + return response, nil } waitForLspDiagnostics(ctx, params.FilePath, e.lspClients) - text := fmt.Sprintf("\n%s\n\n", result.text) - text += appendDiagnostics(params.FilePath, e.lspClients) - return WithResponseMetadata(NewTextResponse(text), EditResponseMetadata{ - Additions: result.additions, - Removals: result.removals, - }), nil -} - -type editResponse struct { - text string - additions int - removals int + text := fmt.Sprintf("\n%s\n\n", response.Content) + text += getDiagnostics(params.FilePath, e.lspClients) + response.Content = text + return response, nil } -func (e *editTool) createNewFile(ctx context.Context, filePath, content string) (editResponse, error) { - er := editResponse{} +func (e *editTool) createNewFile(ctx context.Context, filePath, content string) (ToolResponse, error) { fileInfo, err := os.Stat(filePath) if err == nil { if fileInfo.IsDir() { - return er, fmt.Errorf("path is a directory, not a file: %s", filePath) + return NewTextErrorResponse(fmt.Sprintf("path is a directory, not a file: %s", filePath)), nil } - return er, fmt.Errorf("file already exists: %s. Use the Replace tool to overwrite an existing file", filePath) + return NewTextErrorResponse(fmt.Sprintf("file already exists: %s", filePath)), nil } else if !os.IsNotExist(err) { - return er, fmt.Errorf("failed to access file: %w", err) + return ToolResponse{}, fmt.Errorf("failed to access file: %w", err) } dir := filepath.Dir(filePath) if err = os.MkdirAll(dir, 0o755); err != nil { - return er, fmt.Errorf("failed to create parent directories: %w", err) + return ToolResponse{}, fmt.Errorf("failed to create parent directories: %w", err) } sessionID, messageID := GetContextValues(ctx) if sessionID == "" || messageID == "" { - return er, fmt.Errorf("session ID and message ID are required for creating a new file") + return ToolResponse{}, fmt.Errorf("session ID and message ID are required for creating a new file") } diff, stats, err := git.GenerateGitDiffWithStats( @@ -201,7 +187,7 @@ func (e *editTool) createNewFile(ctx context.Context, filePath, content string) content, ) if err != nil { - return er, fmt.Errorf("failed to get file diff: %w", err) + return ToolResponse{}, fmt.Errorf("failed to get file diff: %w", err) } p := e.permissions.Request( permission.CreatePermissionRequest{ @@ -216,63 +202,67 @@ func (e *editTool) createNewFile(ctx context.Context, filePath, content string) }, ) if !p { - return er, fmt.Errorf("permission denied") + return ToolResponse{}, permission.ErrorPermissionDenied } err = os.WriteFile(filePath, []byte(content), 0o644) if err != nil { - return er, fmt.Errorf("failed to write file: %w", err) + return ToolResponse{}, fmt.Errorf("failed to write file: %w", err) } recordFileWrite(filePath) recordFileRead(filePath) - er.text = "File created: " + filePath - er.additions = stats.Additions - er.removals = stats.Removals - return er, nil + return WithResponseMetadata( + NewTextResponse("File created: "+filePath), + EditResponseMetadata{ + Additions: stats.Additions, + Removals: stats.Removals, + }, + ), nil } -func (e *editTool) deleteContent(ctx context.Context, filePath, oldString string) (editResponse, error) { - er := editResponse{} +func (e *editTool) deleteContent(ctx context.Context, filePath, oldString string) (ToolResponse, error) { fileInfo, err := os.Stat(filePath) if err != nil { if os.IsNotExist(err) { - return er, fmt.Errorf("file not found: %s", filePath) + return NewTextErrorResponse(fmt.Sprintf("file not found: %s", filePath)), nil } - return er, fmt.Errorf("failed to access file: %w", err) + return ToolResponse{}, fmt.Errorf("failed to access file: %w", err) } if fileInfo.IsDir() { - return er, fmt.Errorf("path is a directory, not a file: %s", filePath) + return NewTextErrorResponse(fmt.Sprintf("path is a directory, not a file: %s", filePath)), nil } if getLastReadTime(filePath).IsZero() { - return er, fmt.Errorf("you must read the file before editing it. Use the View tool first") + return NewTextErrorResponse("you must read the file before editing it. Use the View tool first"), nil } modTime := fileInfo.ModTime() lastRead := getLastReadTime(filePath) if modTime.After(lastRead) { - return er, fmt.Errorf("file %s has been modified since it was last read (mod time: %s, last read: %s)", - filePath, modTime.Format(time.RFC3339), lastRead.Format(time.RFC3339)) + return NewTextErrorResponse( + fmt.Sprintf("file %s has been modified since it was last read (mod time: %s, last read: %s)", + filePath, modTime.Format(time.RFC3339), lastRead.Format(time.RFC3339), + )), nil } content, err := os.ReadFile(filePath) if err != nil { - return er, fmt.Errorf("failed to read file: %w", err) + return ToolResponse{}, fmt.Errorf("failed to read file: %w", err) } oldContent := string(content) index := strings.Index(oldContent, oldString) if index == -1 { - return er, fmt.Errorf("old_string not found in file. Make sure it matches exactly, including whitespace and line breaks") + return NewTextErrorResponse("old_string not found in file. Make sure it matches exactly, including whitespace and line breaks"), nil } lastIndex := strings.LastIndex(oldContent, oldString) if index != lastIndex { - return er, fmt.Errorf("old_string appears multiple times in the file. Please provide more context to ensure a unique match") + return NewTextErrorResponse("old_string appears multiple times in the file. Please provide more context to ensure a unique match"), nil } newContent := oldContent[:index] + oldContent[index+len(oldString):] @@ -280,7 +270,7 @@ func (e *editTool) deleteContent(ctx context.Context, filePath, oldString string sessionID, messageID := GetContextValues(ctx) if sessionID == "" || messageID == "" { - return er, fmt.Errorf("session ID and message ID are required for creating a new file") + return ToolResponse{}, fmt.Errorf("session ID and message ID are required for creating a new file") } diff, stats, err := git.GenerateGitDiffWithStats( @@ -289,7 +279,7 @@ func (e *editTool) deleteContent(ctx context.Context, filePath, oldString string newContent, ) if err != nil { - return er, fmt.Errorf("failed to get file diff: %w", err) + return ToolResponse{}, fmt.Errorf("failed to get file diff: %w", err) } p := e.permissions.Request( @@ -305,62 +295,66 @@ func (e *editTool) deleteContent(ctx context.Context, filePath, oldString string }, ) if !p { - return er, fmt.Errorf("permission denied") + return ToolResponse{}, permission.ErrorPermissionDenied } err = os.WriteFile(filePath, []byte(newContent), 0o644) if err != nil { - return er, fmt.Errorf("failed to write file: %w", err) + return ToolResponse{}, fmt.Errorf("failed to write file: %w", err) } recordFileWrite(filePath) recordFileRead(filePath) - er.text = "Content deleted from file: " + filePath - er.additions = stats.Additions - er.removals = stats.Removals - return er, nil + return WithResponseMetadata( + NewTextResponse("Content deleted from file: "+filePath), + EditResponseMetadata{ + Additions: stats.Additions, + Removals: stats.Removals, + }, + ), nil } -func (e *editTool) replaceContent(ctx context.Context, filePath, oldString, newString string) (editResponse, error) { - er := editResponse{} +func (e *editTool) replaceContent(ctx context.Context, filePath, oldString, newString string) (ToolResponse, error) { fileInfo, err := os.Stat(filePath) if err != nil { if os.IsNotExist(err) { - return er, fmt.Errorf("file not found: %s", filePath) + return NewTextErrorResponse(fmt.Sprintf("file not found: %s", filePath)), nil } - return er, fmt.Errorf("failed to access file: %w", err) + return ToolResponse{}, fmt.Errorf("failed to access file: %w", err) } if fileInfo.IsDir() { - return er, fmt.Errorf("path is a directory, not a file: %s", filePath) + return NewTextErrorResponse(fmt.Sprintf("path is a directory, not a file: %s", filePath)), nil } if getLastReadTime(filePath).IsZero() { - return er, fmt.Errorf("you must read the file before editing it. Use the View tool first") + return NewTextErrorResponse("you must read the file before editing it. Use the View tool first"), nil } modTime := fileInfo.ModTime() lastRead := getLastReadTime(filePath) if modTime.After(lastRead) { - return er, fmt.Errorf("file %s has been modified since it was last read (mod time: %s, last read: %s)", - filePath, modTime.Format(time.RFC3339), lastRead.Format(time.RFC3339)) + return NewTextErrorResponse( + fmt.Sprintf("file %s has been modified since it was last read (mod time: %s, last read: %s)", + filePath, modTime.Format(time.RFC3339), lastRead.Format(time.RFC3339), + )), nil } content, err := os.ReadFile(filePath) if err != nil { - return er, fmt.Errorf("failed to read file: %w", err) + return ToolResponse{}, fmt.Errorf("failed to read file: %w", err) } oldContent := string(content) index := strings.Index(oldContent, oldString) if index == -1 { - return er, fmt.Errorf("old_string not found in file. Make sure it matches exactly, including whitespace and line breaks") + return NewTextErrorResponse("old_string not found in file. Make sure it matches exactly, including whitespace and line breaks"), nil } lastIndex := strings.LastIndex(oldContent, oldString) if index != lastIndex { - return er, fmt.Errorf("old_string appears multiple times in the file. Please provide more context to ensure a unique match") + return NewTextErrorResponse("old_string appears multiple times in the file. Please provide more context to ensure a unique match"), nil } newContent := oldContent[:index] + newString + oldContent[index+len(oldString):] @@ -368,7 +362,7 @@ func (e *editTool) replaceContent(ctx context.Context, filePath, oldString, newS sessionID, messageID := GetContextValues(ctx) if sessionID == "" || messageID == "" { - return er, fmt.Errorf("session ID and message ID are required for creating a new file") + return ToolResponse{}, fmt.Errorf("session ID and message ID are required for creating a new file") } diff, stats, err := git.GenerateGitDiffWithStats( removeWorkingDirectoryPrefix(filePath), @@ -376,7 +370,7 @@ func (e *editTool) replaceContent(ctx context.Context, filePath, oldString, newS newContent, ) if err != nil { - return er, fmt.Errorf("failed to get file diff: %w", err) + return ToolResponse{}, fmt.Errorf("failed to get file diff: %w", err) } p := e.permissions.Request( @@ -393,19 +387,21 @@ func (e *editTool) replaceContent(ctx context.Context, filePath, oldString, newS }, ) if !p { - return er, fmt.Errorf("permission denied") + return ToolResponse{}, permission.ErrorPermissionDenied } err = os.WriteFile(filePath, []byte(newContent), 0o644) if err != nil { - return er, fmt.Errorf("failed to write file: %w", err) + return ToolResponse{}, fmt.Errorf("failed to write file: %w", err) } recordFileWrite(filePath) recordFileRead(filePath) - er.text = "Content replaced in file: " + filePath - er.additions = stats.Additions - er.removals = stats.Removals - return er, nil + return WithResponseMetadata( + NewTextResponse("Content replaced in file: "+filePath), + EditResponseMetadata{ + Additions: stats.Additions, + Removals: stats.Removals, + }), nil } diff --git a/internal/llm/tools/view.go b/internal/llm/tools/view.go index a687be015353afdbdc1e84fe1eaf384a72aa350e..a0600a2a1b5b776307fa0714c5ff866c8c48ba66 100644 --- a/internal/llm/tools/view.go +++ b/internal/llm/tools/view.go @@ -177,7 +177,7 @@ func (v *viewTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error) params.Offset+len(strings.Split(content, "\n"))) } output += "\n\n" - output += appendDiagnostics(filePath, v.lspClients) + output += getDiagnostics(filePath, v.lspClients) recordFileRead(filePath) return NewTextResponse(output), nil } diff --git a/internal/llm/tools/write.go b/internal/llm/tools/write.go index 1b087c1934d67585281e48a3fc4c75791915cdd8..9797239d9dbb234cb93908e63e3e8d1f9acae585 100644 --- a/internal/llm/tools/write.go +++ b/internal/llm/tools/write.go @@ -183,7 +183,7 @@ func (w *writeTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error result := fmt.Sprintf("File successfully written: %s", filePath) result = fmt.Sprintf("\n%s\n", result) - result += appendDiagnostics(filePath, w.lspClients) + result += getDiagnostics(filePath, w.lspClients) return WithResponseMetadata(NewTextResponse(result), WriteResponseMetadata{ Additions: stats.Additions, From 0b3e5f5bd42a02c2a15b394b3768e517dc43f39c Mon Sep 17 00:00:00 2001 From: Kujtim Hoxha Date: Mon, 14 Apr 2025 11:24:36 +0200 Subject: [PATCH 13/41] handle errors correctly in the other tools --- internal/llm/tools/edit.go | 8 ++++++-- internal/llm/tools/fetch.go | 7 ++++--- internal/llm/tools/glob.go | 15 +++++++++++++-- internal/llm/tools/grep.go | 15 +++++++++++++-- internal/llm/tools/ls.go | 15 +++++++++++++-- internal/llm/tools/sourcegraph.go | 15 ++++++++++----- internal/llm/tools/view.go | 5 +++-- internal/llm/tools/write.go | 18 ++++++++++-------- 8 files changed, 72 insertions(+), 26 deletions(-) diff --git a/internal/llm/tools/edit.go b/internal/llm/tools/edit.go index 46b95aea23eb96163a9db546160c732d43bcc260..b74e427296e4cca28349e6de920cef6399f7f823 100644 --- a/internal/llm/tools/edit.go +++ b/internal/llm/tools/edit.go @@ -27,8 +27,9 @@ type EditPermissionsParams struct { } type EditResponseMetadata struct { - Additions int `json:"additions"` - Removals int `json:"removals"` + Diff string `json:"diff"` + Additions int `json:"additions"` + Removals int `json:"removals"` } type editTool struct { @@ -216,6 +217,7 @@ func (e *editTool) createNewFile(ctx context.Context, filePath, content string) return WithResponseMetadata( NewTextResponse("File created: "+filePath), EditResponseMetadata{ + Diff: diff, Additions: stats.Additions, Removals: stats.Removals, }, @@ -308,6 +310,7 @@ func (e *editTool) deleteContent(ctx context.Context, filePath, oldString string return WithResponseMetadata( NewTextResponse("Content deleted from file: "+filePath), EditResponseMetadata{ + Diff: diff, Additions: stats.Additions, Removals: stats.Removals, }, @@ -401,6 +404,7 @@ func (e *editTool) replaceContent(ctx context.Context, filePath, oldString, newS return WithResponseMetadata( NewTextResponse("Content replaced in file: "+filePath), EditResponseMetadata{ + Diff: diff, Additions: stats.Additions, Removals: stats.Removals, }), nil diff --git a/internal/llm/tools/fetch.go b/internal/llm/tools/fetch.go index 19e6442819c03af278b995cf4605df757f9034d7..91bcb36a0696e293658ea11cf333b18a9aa2d767 100644 --- a/internal/llm/tools/fetch.go +++ b/internal/llm/tools/fetch.go @@ -86,6 +86,7 @@ func (t *fetchTool) Info() ToolInfo { "format": map[string]any{ "type": "string", "description": "The format to return the content in (text, markdown, or html)", + "enum": []string{"text", "markdown", "html"}, }, "timeout": map[string]any{ "type": "number", @@ -126,7 +127,7 @@ func (t *fetchTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error ) if !p { - return NewTextErrorResponse("Permission denied to fetch from URL: " + params.URL), nil + return ToolResponse{}, permission.ErrorPermissionDenied } client := t.client @@ -142,14 +143,14 @@ func (t *fetchTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error req, err := http.NewRequestWithContext(ctx, "GET", params.URL, nil) if err != nil { - return NewTextErrorResponse("Failed to create request: " + err.Error()), nil + return ToolResponse{}, fmt.Errorf("failed to create request: %w", err) } req.Header.Set("User-Agent", "termai/1.0") resp, err := client.Do(req) if err != nil { - return NewTextErrorResponse("Failed to execute request: " + err.Error()), nil + return ToolResponse{}, fmt.Errorf("failed to fetch URL: %w", err) } defer resp.Body.Close() diff --git a/internal/llm/tools/glob.go b/internal/llm/tools/glob.go index 4de7971e63e088b1de953beb9bad28c5e68bb2d4..bdfc23b4ababde88978b3888878f4265fd52eae6 100644 --- a/internal/llm/tools/glob.go +++ b/internal/llm/tools/glob.go @@ -63,6 +63,11 @@ type GlobParams struct { Path string `json:"path"` } +type GlobMetadata struct { + NumberOfFiles int `json:"number_of_files"` + Truncated bool `json:"truncated"` +} + type globTool struct{} func NewGlobTool() BaseTool { @@ -104,7 +109,7 @@ func (g *globTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error) files, truncated, err := globFiles(params.Pattern, searchPath, 100) if err != nil { - return NewTextErrorResponse(fmt.Sprintf("error performing glob search: %s", err)), nil + return ToolResponse{}, fmt.Errorf("error finding files: %w", err) } var output string @@ -117,7 +122,13 @@ func (g *globTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error) } } - return NewTextResponse(output), nil + return WithResponseMetadata( + NewTextResponse(output), + GlobMetadata{ + NumberOfFiles: len(files), + Truncated: truncated, + }, + ), nil } func globFiles(pattern, searchPath string, limit int) ([]string, bool, error) { diff --git a/internal/llm/tools/grep.go b/internal/llm/tools/grep.go index f349e83709dc4ec5f6b3bddfb3851b2d01b6a441..7e52821d07bc6dbc71acdf6d7a5eeb42d2bc2dc1 100644 --- a/internal/llm/tools/grep.go +++ b/internal/llm/tools/grep.go @@ -27,6 +27,11 @@ type grepMatch struct { modTime time.Time } +type GrepMetadata struct { + NumberOfMatches int `json:"number_of_matches"` + Truncated bool `json:"truncated"` +} + type grepTool struct{} const ( @@ -110,7 +115,7 @@ func (g *grepTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error) matches, truncated, err := searchFiles(params.Pattern, searchPath, params.Include, 100) if err != nil { - return NewTextErrorResponse(fmt.Sprintf("error searching files: %s", err)), nil + return ToolResponse{}, fmt.Errorf("error searching files: %w", err) } var output string @@ -127,7 +132,13 @@ func (g *grepTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error) } } - return NewTextResponse(output), nil + return WithResponseMetadata( + NewTextResponse(output), + GrepMetadata{ + NumberOfMatches: len(matches), + Truncated: truncated, + }, + ), nil } func pluralize(count int) string { diff --git a/internal/llm/tools/ls.go b/internal/llm/tools/ls.go index 59e8dcd21141a33491aaac96a50f78607a4293b0..a679f261b1ae28b3551245a8e5d25f76185ef14d 100644 --- a/internal/llm/tools/ls.go +++ b/internal/llm/tools/ls.go @@ -23,6 +23,11 @@ type TreeNode struct { Children []*TreeNode `json:"children,omitempty"` } +type LSMetadata struct { + NumberOfFiles int `json:"number_of_files"` + Truncated bool `json:"truncated"` +} + type lsTool struct{} const ( @@ -104,7 +109,7 @@ func (l *lsTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error) { files, truncated, err := listDirectory(searchPath, params.Ignore, MaxLSFiles) if err != nil { - return NewTextErrorResponse(fmt.Sprintf("error listing directory: %s", err)), nil + return ToolResponse{}, fmt.Errorf("error listing directory: %w", err) } tree := createFileTree(files) @@ -114,7 +119,13 @@ func (l *lsTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error) { output = fmt.Sprintf("There are more than %d files in the directory. Use a more specific path or use the Glob tool to find specific files. The first %d files and directories are included below:\n\n%s", MaxLSFiles, MaxLSFiles, output) } - return NewTextResponse(output), nil + return WithResponseMetadata( + NewTextResponse(output), + LSMetadata{ + NumberOfFiles: len(files), + Truncated: truncated, + }, + ), nil } func listDirectory(initialPath string, ignorePatterns []string, limit int) ([]string, bool, error) { diff --git a/internal/llm/tools/sourcegraph.go b/internal/llm/tools/sourcegraph.go index e1ea962d4efaf032302ddf2732e2e3cfb3b19d5a..17bc610ea8fe2597ba8a210efa0f0e5a82c620bd 100644 --- a/internal/llm/tools/sourcegraph.go +++ b/internal/llm/tools/sourcegraph.go @@ -18,6 +18,11 @@ type SourcegraphParams struct { Timeout int `json:"timeout,omitempty"` } +type SourcegraphMetadata struct { + NumberOfMatches int `json:"number_of_matches"` + Truncated bool `json:"truncated"` +} + type sourcegraphTool struct { client *http.Client } @@ -198,7 +203,7 @@ func (t *sourcegraphTool) Run(ctx context.Context, call ToolCall) (ToolResponse, graphqlQueryBytes, err := json.Marshal(request) if err != nil { - return NewTextErrorResponse("Failed to create GraphQL request: " + err.Error()), nil + return ToolResponse{}, fmt.Errorf("failed to marshal GraphQL request: %w", err) } graphqlQuery := string(graphqlQueryBytes) @@ -209,7 +214,7 @@ func (t *sourcegraphTool) Run(ctx context.Context, call ToolCall) (ToolResponse, bytes.NewBuffer([]byte(graphqlQuery)), ) if err != nil { - return NewTextErrorResponse("Failed to create request: " + err.Error()), nil + return ToolResponse{}, fmt.Errorf("failed to create request: %w", err) } req.Header.Set("Content-Type", "application/json") @@ -217,7 +222,7 @@ func (t *sourcegraphTool) Run(ctx context.Context, call ToolCall) (ToolResponse, resp, err := client.Do(req) if err != nil { - return NewTextErrorResponse("Failed to execute request: " + err.Error()), nil + return ToolResponse{}, fmt.Errorf("failed to fetch URL: %w", err) } defer resp.Body.Close() @@ -231,12 +236,12 @@ func (t *sourcegraphTool) Run(ctx context.Context, call ToolCall) (ToolResponse, } body, err := io.ReadAll(resp.Body) if err != nil { - return NewTextErrorResponse("Failed to read response body: " + err.Error()), nil + return ToolResponse{}, fmt.Errorf("failed to read response body: %w", err) } var result map[string]any if err = json.Unmarshal(body, &result); err != nil { - return NewTextErrorResponse("Failed to parse response: " + err.Error()), nil + return ToolResponse{}, fmt.Errorf("failed to unmarshal response: %w", err) } formattedResults, err := formatSourcegraphResults(result, params.ContextWindow) diff --git a/internal/llm/tools/view.go b/internal/llm/tools/view.go index a0600a2a1b5b776307fa0714c5ff866c8c48ba66..7450a84bfb1bcedf4c4c265f202d35ca5ffb6259 100644 --- a/internal/llm/tools/view.go +++ b/internal/llm/tools/view.go @@ -135,7 +135,7 @@ func (v *viewTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error) return NewTextErrorResponse(fmt.Sprintf("File not found: %s", filePath)), nil } - return NewTextErrorResponse(fmt.Sprintf("Failed to access file: %s", err)), nil + return ToolResponse{}, fmt.Errorf("error accessing file: %w", err) } // Check if it's a directory @@ -156,6 +156,7 @@ func (v *viewTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error) // Check if it's an image file isImage, imageType := isImageFile(filePath) + // TODO: handle images if isImage { return NewTextErrorResponse(fmt.Sprintf("This is an image file of type: %s\nUse a different tool to process images", imageType)), nil } @@ -163,7 +164,7 @@ func (v *viewTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error) // Read the file content content, lineCount, err := readTextFile(filePath, params.Offset, params.Limit) if err != nil { - return NewTextErrorResponse(fmt.Sprintf("Failed to read file: %s", err)), nil + return ToolResponse{}, fmt.Errorf("error reading file: %w", err) } notifyLspOpenFile(ctx, filePath, v.lspClients) diff --git a/internal/llm/tools/write.go b/internal/llm/tools/write.go index 9797239d9dbb234cb93908e63e3e8d1f9acae585..8318f2851fa0a9e830a842fe7823eab567858c1c 100644 --- a/internal/llm/tools/write.go +++ b/internal/llm/tools/write.go @@ -30,8 +30,9 @@ type writeTool struct { } type WriteResponseMetadata struct { - Additions int `json:"additions"` - Removals int `json:"removals"` + Diff string `json:"diff"` + Additions int `json:"additions"` + Removals int `json:"removals"` } const ( @@ -128,12 +129,12 @@ func (w *writeTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error return NewTextErrorResponse(fmt.Sprintf("File %s already contains the exact content. No changes made.", filePath)), nil } } else if !os.IsNotExist(err) { - return NewTextErrorResponse(fmt.Sprintf("Failed to access file: %s", err)), nil + return ToolResponse{}, fmt.Errorf("error checking file: %w", err) } dir := filepath.Dir(filePath) if err = os.MkdirAll(dir, 0o755); err != nil { - return NewTextErrorResponse(fmt.Sprintf("Failed to create parent directories: %s", err)), nil + return ToolResponse{}, fmt.Errorf("error creating directory: %w", err) } oldContent := "" @@ -146,7 +147,7 @@ func (w *writeTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error sessionID, messageID := GetContextValues(ctx) if sessionID == "" || messageID == "" { - return NewTextErrorResponse("session ID or message ID is missing"), nil + return ToolResponse{}, fmt.Errorf("session_id and message_id are required") } diff, stats, err := git.GenerateGitDiffWithStats( removeWorkingDirectoryPrefix(filePath), @@ -154,7 +155,7 @@ func (w *writeTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error params.Content, ) if err != nil { - return NewTextErrorResponse(fmt.Sprintf("Failed to get file diff: %s", err)), nil + return ToolResponse{}, fmt.Errorf("error generating diff: %w", err) } p := w.permissions.Request( permission.CreatePermissionRequest{ @@ -169,12 +170,12 @@ func (w *writeTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error }, ) if !p { - return NewTextErrorResponse(fmt.Sprintf("Permission denied to create file: %s", filePath)), nil + return ToolResponse{}, permission.ErrorPermissionDenied } err = os.WriteFile(filePath, []byte(params.Content), 0o644) if err != nil { - return NewTextErrorResponse(fmt.Sprintf("Failed to write file: %s", err)), nil + return ToolResponse{}, fmt.Errorf("error writing file: %w", err) } recordFileWrite(filePath) @@ -186,6 +187,7 @@ func (w *writeTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error result += getDiagnostics(filePath, w.lspClients) return WithResponseMetadata(NewTextResponse(result), WriteResponseMetadata{ + Diff: diff, Additions: stats.Additions, Removals: stats.Removals, }, From 0130bde1edabb81d82dbce9d2d562966d2dee133 Mon Sep 17 00:00:00 2001 From: Kujtim Hoxha Date: Mon, 14 Apr 2025 14:09:17 +0200 Subject: [PATCH 14/41] remove node dependency and implement diff format --- go.mod | 24 +- go.sum | 57 +- internal/diff/diff.go | 995 +++++++++++++++++++ internal/git/diff.go | 264 ----- internal/llm/tools/edit.go | 39 +- internal/llm/tools/write.go | 15 +- internal/tui/components/dialog/permission.go | 6 +- 7 files changed, 1028 insertions(+), 372 deletions(-) create mode 100644 internal/diff/diff.go delete mode 100644 internal/git/diff.go diff --git a/go.mod b/go.mod index e3dc2bd96d60afca3c522c3260c3d92ebf4d9bf5..b201be8005cc1893589e1fc022f944a70094f48c 100644 --- a/go.mod +++ b/go.mod @@ -7,6 +7,7 @@ toolchain go1.24.2 require ( github.com/JohannesKaufmann/html-to-markdown v1.6.0 github.com/PuerkitoBio/goquery v1.9.2 + github.com/alecthomas/chroma/v2 v2.15.0 github.com/anthropics/anthropic-sdk-go v0.2.0-beta.2 github.com/bmatcuk/doublestar/v4 v4.8.1 github.com/catppuccin/go v0.3.0 @@ -17,7 +18,6 @@ require ( github.com/charmbracelet/lipgloss v1.1.0 github.com/charmbracelet/x/ansi v0.8.0 github.com/fsnotify/fsnotify v1.8.0 - github.com/go-git/go-git/v5 v5.15.0 github.com/go-logfmt/logfmt v0.6.0 github.com/golang-migrate/migrate/v4 v4.18.2 github.com/google/generative-ai-go v0.19.0 @@ -31,6 +31,7 @@ require ( github.com/muesli/reflow v0.3.0 github.com/muesli/termenv v0.16.0 github.com/openai/openai-go v0.1.0-beta.2 + github.com/sergi/go-diff v1.3.2-0.20230802210424-5b0b94c5c0d3 github.com/spf13/cobra v1.9.1 github.com/spf13/viper v1.20.0 github.com/stretchr/testify v1.10.0 @@ -45,10 +46,6 @@ require ( cloud.google.com/go/auth/oauth2adapt v0.2.6 // indirect cloud.google.com/go/compute/metadata v0.6.0 // indirect cloud.google.com/go/longrunning v0.5.7 // indirect - dario.cat/mergo v1.0.0 // indirect - github.com/Microsoft/go-winio v0.6.2 // indirect - github.com/ProtonMail/go-crypto v1.1.6 // indirect - github.com/alecthomas/chroma/v2 v2.15.0 // indirect github.com/andybalholm/cascadia v1.3.2 // indirect github.com/atotto/clipboard v0.1.4 // indirect github.com/aws/aws-sdk-go-v2 v1.30.3 // indirect @@ -71,20 +68,15 @@ require ( github.com/charmbracelet/x/cellbuf v0.0.13-0.20250311204145-2c3ea96c31dd // indirect github.com/charmbracelet/x/exp/strings v0.0.0-20240722160745-212f7b056ed0 // indirect github.com/charmbracelet/x/term v0.2.1 // indirect - github.com/cloudflare/circl v1.6.1 // indirect - github.com/cyphar/filepath-securejoin v0.4.1 // indirect github.com/davecgh/go-spew v1.1.1 // indirect github.com/dlclark/regexp2 v1.11.4 // indirect github.com/dustin/go-humanize v1.0.1 // indirect - github.com/emirpasic/gods v1.18.1 // indirect github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f // indirect github.com/felixge/httpsnoop v1.0.4 // indirect - github.com/go-git/gcfg v1.5.1-0.20230307220236-3a3c6141e376 // indirect - github.com/go-git/go-billy/v5 v5.6.2 // indirect github.com/go-logr/logr v1.4.2 // indirect github.com/go-logr/stdr v1.2.2 // indirect github.com/go-viper/mapstructure/v2 v2.2.1 // indirect - github.com/golang/groupcache v0.0.0-20241129210726-2c02b8208cf8 // indirect + github.com/google/go-cmp v0.7.0 // indirect github.com/google/s2a-go v0.1.8 // indirect github.com/googleapis/enterprise-certificate-proxy v0.3.4 // indirect github.com/googleapis/gax-go/v2 v2.14.1 // indirect @@ -92,8 +84,6 @@ require ( github.com/hashicorp/errwrap v1.1.0 // indirect github.com/hashicorp/go-multierror v1.1.1 // indirect github.com/inconshreveable/mousetrap v1.1.0 // indirect - github.com/jbenet/go-context v0.0.0-20150711004518-d14ea06fba99 // indirect - github.com/kevinburke/ssh_config v1.2.0 // indirect github.com/lucasb-eyer/go-colorful v1.2.0 // indirect github.com/mattn/go-isatty v0.0.20 // indirect github.com/mattn/go-localereader v0.0.1 // indirect @@ -101,13 +91,11 @@ require ( github.com/mitchellh/hashstructure/v2 v2.0.2 // indirect github.com/muesli/cancelreader v0.2.2 // indirect github.com/pelletier/go-toml/v2 v2.2.3 // indirect - github.com/pjbgf/sha1cd v0.3.2 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect github.com/rivo/uniseg v0.4.7 // indirect + github.com/rogpeppe/go-internal v1.14.1 // indirect github.com/sagikazarmark/locafero v0.7.0 // indirect github.com/sahilm/fuzzy v0.1.1 // indirect - github.com/sergi/go-diff v1.3.2-0.20230802210424-5b0b94c5c0d3 // indirect - github.com/skeema/knownhosts v1.3.1 // indirect github.com/sourcegraph/conc v0.3.0 // indirect github.com/spf13/afero v1.12.0 // indirect github.com/spf13/cast v1.7.1 // indirect @@ -117,7 +105,6 @@ require ( github.com/tidwall/match v1.1.1 // indirect github.com/tidwall/pretty v1.2.1 // indirect github.com/tidwall/sjson v1.2.5 // indirect - github.com/xanzy/ssh-agent v0.3.3 // indirect github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e // indirect github.com/yosida95/uritemplate/v3 v3.0.2 // indirect github.com/yuin/goldmark v1.7.8 // indirect @@ -131,6 +118,7 @@ require ( go.uber.org/multierr v1.9.0 // indirect golang.design/x/clipboard v0.7.0 // indirect golang.org/x/crypto v0.37.0 // indirect + golang.org/x/exp v0.0.0-20240719175910-8a7402abbf56 // indirect golang.org/x/exp/shiny v0.0.0-20250305212735-054e65f0b394 // indirect golang.org/x/image v0.14.0 // indirect golang.org/x/mobile v0.0.0-20231127183840-76ac6878050a // indirect @@ -144,6 +132,6 @@ require ( google.golang.org/genproto/googleapis/rpc v0.0.0-20241223144023-3abc09e42ca8 // indirect google.golang.org/grpc v1.67.3 // indirect google.golang.org/protobuf v1.36.1 // indirect - gopkg.in/warnings.v0 v0.1.2 // indirect + gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c // indirect gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/go.sum b/go.sum index 9c2c2df8fbcb0d909a73a67ead79339bf3814892..08e7e7c42e61f69fd454df0683507bb91ddd3cd9 100644 --- a/go.sum +++ b/go.sum @@ -10,17 +10,10 @@ cloud.google.com/go/compute/metadata v0.6.0 h1:A6hENjEsCDtC1k8byVsgwvVcioamEHvZ4 cloud.google.com/go/compute/metadata v0.6.0/go.mod h1:FjyFAW1MW0C203CEOMDTu3Dk1FlqW3Rga40jzHL4hfg= cloud.google.com/go/longrunning v0.5.7 h1:WLbHekDbjK1fVFD3ibpFFVoyizlLRl73I7YKuAKilhU= cloud.google.com/go/longrunning v0.5.7/go.mod h1:8GClkudohy1Fxm3owmBGid8W0pSgodEMwEAztp38Xng= -dario.cat/mergo v1.0.0 h1:AGCNq9Evsj31mOgNPcLyXc+4PNABt905YmuqPYYpBWk= -dario.cat/mergo v1.0.0/go.mod h1:uNxQE+84aUszobStD9th8a29P2fMDhsBdgRYvZOxGmk= github.com/JohannesKaufmann/html-to-markdown v1.6.0 h1:04VXMiE50YYfCfLboJCLcgqF5x+rHJnb1ssNmqpLH/k= github.com/JohannesKaufmann/html-to-markdown v1.6.0/go.mod h1:NUI78lGg/a7vpEJTz/0uOcYMaibytE4BUOQS8k78yPQ= github.com/MakeNowJust/heredoc v1.0.0 h1:cXCdzVdstXyiTqTvfqk9SDHpKNjxuom+DOlyEeQ4pzQ= github.com/MakeNowJust/heredoc v1.0.0/go.mod h1:mG5amYoWBHf8vpLOuehzbGGw0EHxpZZ6lCpQ4fNJ8LE= -github.com/Microsoft/go-winio v0.5.2/go.mod h1:WpS1mjBmmwHBEWmogvA2mj8546UReBk4v8QkMxJ6pZY= -github.com/Microsoft/go-winio v0.6.2 h1:F2VQgta7ecxGYO8k3ZZz3RS8fVIXVxONVUPlNERoyfY= -github.com/Microsoft/go-winio v0.6.2/go.mod h1:yd8OoFMLzJbo9gZq8j5qaps8bJ9aShtEA8Ipt1oGCvU= -github.com/ProtonMail/go-crypto v1.1.6 h1:ZcV+Ropw6Qn0AX9brlQLAUXfqLBc7Bl+f/DmNxpLfdw= -github.com/ProtonMail/go-crypto v1.1.6/go.mod h1:rA3QumHc/FZ8pAHreoekgiAbzpNsfQAosU5td4SnOrE= github.com/PuerkitoBio/goquery v1.9.2 h1:4/wZksC3KgkQw7SQgkKotmKljk0M6V8TUvA8Wb4yPeE= github.com/PuerkitoBio/goquery v1.9.2/go.mod h1:GHPCaP0ODyyxqcNoFGYlAprUFH81NuRPd0GX3Zu2Mvk= github.com/alecthomas/assert/v2 v2.11.0 h1:2Q9r3ki8+JYXvGsDyBXwH3LcJ+WK5D0gc5E8vS6K3D0= @@ -31,12 +24,8 @@ github.com/alecthomas/repr v0.4.0 h1:GhI2A8MACjfegCPVq9f1FLvIBS+DrQ2KQBFZP1iFzXc github.com/alecthomas/repr v0.4.0/go.mod h1:Fr0507jx4eOXV7AlPV6AVZLYrLIuIeSOWtW57eE/O/4= github.com/andybalholm/cascadia v1.3.2 h1:3Xi6Dw5lHF15JtdcmAHD3i1+T8plmv7BQ/nsViSLyss= github.com/andybalholm/cascadia v1.3.2/go.mod h1:7gtRlve5FxPPgIgX36uWBX58OdBsSS6lUvCFb+h7KvU= -github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be h1:9AeTilPcZAjCFIImctFaOjnTIavg87rW78vTPkQqLI8= -github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be/go.mod h1:ySMOLuWl6zY27l47sB3qLNK6tF2fkHG55UZxx8oIVo4= github.com/anthropics/anthropic-sdk-go v0.2.0-beta.2 h1:h7qxtumNjKPWFv1QM/HJy60MteeW23iKeEtBoY7bYZk= github.com/anthropics/anthropic-sdk-go v0.2.0-beta.2/go.mod h1:AapDW22irxK2PSumZiQXYUFvsdQgkwIWlpESweWZI/c= -github.com/armon/go-socks5 v0.0.0-20160902184237-e75332964ef5 h1:0CwZNZbxp69SHPdPJAN/hZIm0C4OItdklCFmMRWYpio= -github.com/armon/go-socks5 v0.0.0-20160902184237-e75332964ef5/go.mod h1:wHh0iHkYZB8zMSxRWpUBQtwG5a7fFgvEO+odwuTv2gs= github.com/atotto/clipboard v0.1.4 h1:EH0zSVneZPSuFR11BlR9YppQTVDbh5+16AmcJi4g1z4= github.com/atotto/clipboard v0.1.4/go.mod h1:ZY9tmq7sm5xIbd9bOK4onWV4S6X0u6GY7Vn0Yu86PYI= github.com/aws/aws-sdk-go-v2 v1.30.3 h1:jUeBtG0Ih+ZIFH0F4UkmL9w3cSpaMv9tYYDbzILP8dY= @@ -99,11 +88,7 @@ github.com/charmbracelet/x/exp/strings v0.0.0-20240722160745-212f7b056ed0 h1:qko github.com/charmbracelet/x/exp/strings v0.0.0-20240722160745-212f7b056ed0/go.mod h1:pBhA0ybfXv6hDjQUZ7hk1lVxBiUbupdw5R31yPUViVQ= github.com/charmbracelet/x/term v0.2.1 h1:AQeHeLZ1OqSXhrAWpYUtZyX1T3zVxfpZuEQMIQaGIAQ= github.com/charmbracelet/x/term v0.2.1/go.mod h1:oQ4enTYFV7QN4m0i9mzHrViD7TQKvNEEkHUMCmsxdUg= -github.com/cloudflare/circl v1.6.1 h1:zqIqSPIndyBh1bjLVVDHMPpVKqp8Su/V+6MeDzzQBQ0= -github.com/cloudflare/circl v1.6.1/go.mod h1:uddAzsPgqdMAYatqJ0lsjX1oECcQLIlRpzZh3pJrofs= github.com/cpuguy83/go-md2man/v2 v2.0.6/go.mod h1:oOW0eioCTA6cOiMLiUPZOpcVxMig6NIQQ7OS05n1F4g= -github.com/cyphar/filepath-securejoin v0.4.1 h1:JyxxyPEaktOD+GAnqIqTf9A8tHyAG22rowi7HkoSU1s= -github.com/cyphar/filepath-securejoin v0.4.1/go.mod h1:Sdj7gXlvMcPZsbhwhQ33GguGLDGQL7h7bg04C/+u9jI= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= @@ -111,10 +96,6 @@ github.com/dlclark/regexp2 v1.11.4 h1:rPYF9/LECdNymJufQKmri9gV604RvvABwgOA8un7yA github.com/dlclark/regexp2 v1.11.4/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8= github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY= github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto= -github.com/elazarl/goproxy v1.7.2 h1:Y2o6urb7Eule09PjlhQRGNsqRfPmYI3KKQLFpCAV3+o= -github.com/elazarl/goproxy v1.7.2/go.mod h1:82vkLNir0ALaW14Rc399OTTjyNREgmdL2cVoIbS6XaE= -github.com/emirpasic/gods v1.18.1 h1:FXtiHYKDGKCW2KzwZKx0iC0PQmdlorYgdFG9jPXJ1Bc= -github.com/emirpasic/gods v1.18.1/go.mod h1:8tpGGwCnJ5H4r6BWwaV6OrWmMoPhUl5jm/FMNAnJvWQ= github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f h1:Y/CXytFA4m6baUTXGLOoWe4PQhGxaX0KpnayAqC48p4= github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f/go.mod h1:vw97MGsxSvLiUE2X8qFplwetxpGLQrlU1Q9AUEIzCaM= github.com/felixge/httpsnoop v1.0.4 h1:NFTV2Zj1bL4mc9sqWACXbQFVBBg2W3GPvqp8/ESS2Wg= @@ -123,16 +104,6 @@ github.com/frankban/quicktest v1.14.6 h1:7Xjx+VpznH+oBnejlPUj8oUpdxnVs4f8XU8WnHk github.com/frankban/quicktest v1.14.6/go.mod h1:4ptaffx2x8+WTWXmUCuVU6aPUX1/Mz7zb5vbUoiM6w0= github.com/fsnotify/fsnotify v1.8.0 h1:dAwr6QBTBZIkG8roQaJjGof0pp0EeF+tNV7YBP3F/8M= github.com/fsnotify/fsnotify v1.8.0/go.mod h1:8jBTzvmWwFyi3Pb8djgCCO5IBqzKJ/Jwo8TRcHyHii0= -github.com/gliderlabs/ssh v0.3.8 h1:a4YXD1V7xMF9g5nTkdfnja3Sxy1PVDCj1Zg4Wb8vY6c= -github.com/gliderlabs/ssh v0.3.8/go.mod h1:xYoytBv1sV0aL3CavoDuJIQNURXkkfPA/wxQ1pL1fAU= -github.com/go-git/gcfg v1.5.1-0.20230307220236-3a3c6141e376 h1:+zs/tPmkDkHx3U66DAb0lQFJrpS6731Oaa12ikc+DiI= -github.com/go-git/gcfg v1.5.1-0.20230307220236-3a3c6141e376/go.mod h1:an3vInlBmSxCcxctByoQdvwPiA7DTK7jaaFDBTtu0ic= -github.com/go-git/go-billy/v5 v5.6.2 h1:6Q86EsPXMa7c3YZ3aLAQsMA0VlWmy43r6FHqa/UNbRM= -github.com/go-git/go-billy/v5 v5.6.2/go.mod h1:rcFC2rAsp/erv7CMz9GczHcuD0D32fWzH+MJAU+jaUU= -github.com/go-git/go-git-fixtures/v4 v4.3.2-0.20231010084843-55a94097c399 h1:eMje31YglSBqCdIqdhKBW8lokaMrL3uTkpGYlE2OOT4= -github.com/go-git/go-git-fixtures/v4 v4.3.2-0.20231010084843-55a94097c399/go.mod h1:1OCfN199q1Jm3HZlxleg+Dw/mwps2Wbk9frAWm+4FII= -github.com/go-git/go-git/v5 v5.15.0 h1:f5Qn0W0F7ry1iN0ZwIU5m/n7/BKB4hiZfc+zlZx7ly0= -github.com/go-git/go-git/v5 v5.15.0/go.mod h1:4Ge4alE/5gPs30F2H1esi2gPd69R0C39lolkucHBOp8= github.com/go-logfmt/logfmt v0.6.0 h1:wGYYu3uicYdqXVgoYbvnkrPVXkuLM1p1ifugDMEdRi4= github.com/go-logfmt/logfmt v0.6.0/go.mod h1:WYhtIu8zTZfxdn5+rREduYbwxfcBr/Vr6KEVveWlfTs= github.com/go-logr/logr v1.2.2/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A= @@ -144,8 +115,6 @@ github.com/go-viper/mapstructure/v2 v2.2.1 h1:ZAaOCxANMuZx5RCeg0mBdEZk7DZasvvZIx github.com/go-viper/mapstructure/v2 v2.2.1/go.mod h1:oJDH3BJKyqBA2TXFhDsKDGDTlndYOZ6rGS0BRZIxGhM= github.com/golang-migrate/migrate/v4 v4.18.2 h1:2VSCMz7x7mjyTXx3m2zPokOY82LTRgxK1yQYKo6wWQ8= github.com/golang-migrate/migrate/v4 v4.18.2/go.mod h1:2CM6tJvn2kqPXwnXO/d3rAQYiyoIm180VsO8PRX6Rpk= -github.com/golang/groupcache v0.0.0-20241129210726-2c02b8208cf8 h1:f+oWsMOmNPc8JmEHVZIycC7hBoQxHH9pNKQORJNozsQ= -github.com/golang/groupcache v0.0.0-20241129210726-2c02b8208cf8/go.mod h1:wcDNUvekVysuuOpQKo3191zZyTpiI6se1N1ULghS0sw= github.com/google/generative-ai-go v0.19.0 h1:R71szggh8wHMCUlEMsW2A/3T+5LdEIkiaHSYgSpUgdg= github.com/google/generative-ai-go v0.19.0/go.mod h1:JYolL13VG7j79kM5BtHz4qwONHkeJQzOCkKXnpqtS/E= github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= @@ -169,11 +138,8 @@ github.com/hexops/gotextdiff v1.0.3 h1:gitA9+qJrrTCsiCl7+kh75nPqQt1cx4ZkudSTLoUq github.com/hexops/gotextdiff v1.0.3/go.mod h1:pSWU5MAI3yDq+fZBTazCSJysOMbxWL1BSow5/V2vxeg= github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8= github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw= -github.com/jbenet/go-context v0.0.0-20150711004518-d14ea06fba99 h1:BQSFePA1RWJOlocH6Fxy8MmwDt+yVQYULKfN0RoTN8A= -github.com/jbenet/go-context v0.0.0-20150711004518-d14ea06fba99/go.mod h1:1lJo3i6rXxKeerYnT8Nvf0QmHCRC1n8sfWVwXF2Frvo= -github.com/kevinburke/ssh_config v1.2.0 h1:x584FjTGwHzMwvHx18PXxbBVzfnxogHaAReU4gf13a4= -github.com/kevinburke/ssh_config v1.2.0/go.mod h1:CT57kijsi8u/K/BOFA39wgDQJ9CxiF4nAY/ojJ6r6mM= github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= +github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI= github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= @@ -213,17 +179,11 @@ github.com/muesli/reflow v0.3.0 h1:IFsN6K9NfGtjeggFP+68I4chLZV2yIKsXJFNZ+eWh6s= github.com/muesli/reflow v0.3.0/go.mod h1:pbwTDkVPibjO2kyvBQRBxTWEEGDGq0FlB1BIKtnHY/8= github.com/muesli/termenv v0.16.0 h1:S5AlUN9dENB57rsbnkPyfdGuWIlkmzJjbFf0Tf5FWUc= github.com/muesli/termenv v0.16.0/go.mod h1:ZRfOIKPFDYQoDFF4Olj7/QJbW60Ol/kL1pU3VfY/Cnk= -github.com/onsi/gomega v1.34.1 h1:EUMJIKUjM8sKjYbtxQI9A4z2o+rruxnzNvpknOXie6k= -github.com/onsi/gomega v1.34.1/go.mod h1:kU1QgUvBDLXBJq618Xvm2LUX6rSAfRaFRTcdOeDLwwY= github.com/openai/openai-go v0.1.0-beta.2 h1:Ra5nCFkbEl9w+UJwAciC4kqnIBUCcJazhmMA0/YN894= github.com/openai/openai-go v0.1.0-beta.2/go.mod h1:g461MYGXEXBVdV5SaR/5tNzNbSfwTBBefwc+LlDCK0Y= github.com/pelletier/go-toml/v2 v2.2.3 h1:YmeHyLY8mFWbdkNWwpr+qIL2bEqT0o95WSdkNHvL12M= github.com/pelletier/go-toml/v2 v2.2.3/go.mod h1:MfCQTFTvCcUyyvvwm1+G6H/jORL20Xlb6rzQu9GuUkc= -github.com/pjbgf/sha1cd v0.3.2 h1:a9wb0bp1oC2TGwStyn0Umc/IGKQnEgF0vVaZ8QF8eo4= -github.com/pjbgf/sha1cd v0.3.2/go.mod h1:zQWigSxVmsHEZow5qaLtPYxpcKMMQpa09ixqBxuCS6A= github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= -github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= -github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/rivo/uniseg v0.1.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc= @@ -243,9 +203,6 @@ github.com/sergi/go-diff v1.0.0/go.mod h1:0CfEIISq7TuYL3j771MWULgwwjU+GofnZX9QAm github.com/sergi/go-diff v1.3.1/go.mod h1:aMJSSKb2lpPvRNec0+w3fl7LP9IOFzdc9Pa4NFbPK1I= github.com/sergi/go-diff v1.3.2-0.20230802210424-5b0b94c5c0d3 h1:n661drycOFuPLCN3Uc8sB6B/s6Z4t2xvBgU1htSHuq8= github.com/sergi/go-diff v1.3.2-0.20230802210424-5b0b94c5c0d3/go.mod h1:A0bzQcvG0E7Rwjx0REVgAGH58e96+X0MeOfepqsbeW4= -github.com/sirupsen/logrus v1.7.0/go.mod h1:yWOB1SBYBC5VeMP7gHvWumXLIWorT60ONWic61uBYv0= -github.com/skeema/knownhosts v1.3.1 h1:X2osQ+RAjK76shCbvhHHHVl3ZlgDm8apHEHFqRjnBY8= -github.com/skeema/knownhosts v1.3.1/go.mod h1:r7KTdC8l4uxWRyK2TpQZ/1o5HaSzh06ePQNxPwTcfiY= github.com/sourcegraph/conc v0.3.0 h1:OQTbbt6P72L20UqAkXXuLOj79LfEanQ+YQFNpLA9ySo= github.com/sourcegraph/conc v0.3.0/go.mod h1:Sdozi7LEKbFPqYX2/J+iBAM6HpqSLTASQIKqDmF7Mt0= github.com/spf13/afero v1.12.0 h1:UcOPyRBYczmFn6yvphxkn9ZEOY65cpwGKb5mL36mrqs= @@ -259,7 +216,6 @@ github.com/spf13/pflag v1.0.6/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An github.com/spf13/viper v1.20.0 h1:zrxIyR3RQIOsarIrgL8+sAvALXul9jeEPa06Y0Ph6vY= github.com/spf13/viper v1.20.0/go.mod h1:P9Mdzt1zoHIG8m2eZQinpiBjo6kCmZSKBClNNqjJvu4= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= -github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA= @@ -276,8 +232,6 @@ github.com/tidwall/pretty v1.2.1 h1:qjsOFOWWQl+N3RsoF5/ssm1pHmJJwhjlSbZ51I6wMl4= github.com/tidwall/pretty v1.2.1/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU= github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY= github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28= -github.com/xanzy/ssh-agent v0.3.3 h1:+/15pJfg/RsTxqYcX6fHqOXZwwMP+2VyYWJeWM2qQFM= -github.com/xanzy/ssh-agent v0.3.3/go.mod h1:6dzNDKs0J9rVPHPhaGCukekBHKqfl+L3KghI1Bc68Uw= github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e h1:JVG44RsyaB9T2KIHavMF/ppJZNG9ZpyihvCd0w101no= github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e/go.mod h1:RbqR21r5mrJuqunuUZ/Dhy/avygyECGrLceyNeo4LiM= github.com/yosida95/uritemplate/v3 v3.0.2 h1:Ed3Oyj9yrmi9087+NczuL5BwkIc4wvTb5zIM+UJPGz4= @@ -306,7 +260,6 @@ golang.design/x/clipboard v0.7.0 h1:4Je8M/ys9AJumVnl8m+rZnIvstSnYj1fvzqYrU3TXvo= golang.design/x/clipboard v0.7.0/go.mod h1:PQIvqYO9GP29yINEfsEn5zSQKAz3UgXmZKzDA6dnq2E= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= -golang.org/x/crypto v0.0.0-20220622213112-05595931fe9d/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= golang.org/x/crypto v0.19.0/go.mod h1:Iy9bg/ha4yyC70EfRS8jz+B6ybOBKMaSxLj6P6oBDfU= golang.org/x/crypto v0.22.0/go.mod h1:vr6Su+7cTlO45qkww3VDJlzDn0ctJvRgYbC2NvXHt+M= golang.org/x/crypto v0.23.0/go.mod h1:CKFgDieR+mRhux2Lsu27y0fO304Db0wZe70UKqHu0v8= @@ -324,7 +277,6 @@ golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91 golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= -golang.org/x/net v0.0.0-20211112202133-69e39bad7dc2/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c= golang.org/x/net v0.6.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs= golang.org/x/net v0.9.0/go.mod h1:d48xBJpPfHeWQsugry2m+kC02ZBRGRgulfHnEXEuWns= @@ -342,14 +294,10 @@ golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.13.0 h1:AauUjRAJ9OSnvULf/ARrrVywoJDy0YS2AwQ98I37610= golang.org/x/sync v0.13.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= -golang.org/x/sys v0.0.0-20191026070338-33540a1f6037/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20210124154548-22da62e12c0c/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20210809222454-d867a43fc93e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= @@ -372,7 +320,6 @@ golang.org/x/term v0.31.0 h1:erwDkOK1Msy6offm1mOgvspSkslFnIGsFnxOKoufg3o= golang.org/x/term v0.31.0/go.mod h1:R4BeIy7D95HzImkxGkTW1UQTtP54tio2RyHz7PwK0aw= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= -golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8= @@ -401,8 +348,6 @@ gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8 gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= -gopkg.in/warnings.v0 v0.1.2 h1:wFXVbFY8DY5/xOe1ECiWdKCzZlxgshcYVNkBHstARME= -gopkg.in/warnings.v0 v0.1.2/go.mod h1:jksf8JmL6Qr/oQM2OXTHunEvvTAsrWBLb6OOjuVWRNI= gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY= gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ= diff --git a/internal/diff/diff.go b/internal/diff/diff.go new file mode 100644 index 0000000000000000000000000000000000000000..4e6aa9f5bcba9386a0f5c0e9c801254470950090 --- /dev/null +++ b/internal/diff/diff.go @@ -0,0 +1,995 @@ +package diff + +import ( + "bytes" + "fmt" + "io" + "regexp" + "strconv" + "strings" + "time" + + "github.com/alecthomas/chroma/v2" + "github.com/alecthomas/chroma/v2/formatters" + "github.com/alecthomas/chroma/v2/lexers" + "github.com/alecthomas/chroma/v2/styles" + "github.com/charmbracelet/lipgloss" + "github.com/charmbracelet/x/ansi" + "github.com/sergi/go-diff/diffmatchpatch" +) + +// LineType represents the kind of line in a diff. +type LineType int + +const ( + // LineContext represents a line that exists in both the old and new file. + LineContext LineType = iota + // LineAdded represents a line added in the new file. + LineAdded + // LineRemoved represents a line removed from the old file. + LineRemoved +) + +// DiffLine represents a single line in a diff, either from the old file, +// the new file, or a context line. +type DiffLine struct { + OldLineNo int // Line number in the old file (0 for added lines) + NewLineNo int // Line number in the new file (0 for removed lines) + Kind LineType // Type of line (added, removed, context) + Content string // Content of the line +} + +// Hunk represents a section of changes in a diff. +type Hunk struct { + Header string + Lines []DiffLine +} + +// DiffResult contains the parsed result of a diff. +type DiffResult struct { + OldFile string + NewFile string + Hunks []Hunk +} + +// HunkDelta represents the change statistics for a hunk. +type HunkDelta struct { + StartLine1 int + LineCount1 int + StartLine2 int + LineCount2 int +} + +// linePair represents a pair of lines to be displayed side by side. +type linePair struct { + left *DiffLine + right *DiffLine +} + +// ------------------------------------------------------------------------- +// Style Configuration with Option Pattern +// ------------------------------------------------------------------------- + +// StyleConfig defines styling for diff rendering. +type StyleConfig struct { + RemovedLineBg lipgloss.Color + AddedLineBg lipgloss.Color + ContextLineBg lipgloss.Color + HunkLineBg lipgloss.Color + HunkLineFg lipgloss.Color + RemovedFg lipgloss.Color + AddedFg lipgloss.Color + LineNumberFg lipgloss.Color + HighlightStyle string + RemovedHighlightBg lipgloss.Color + AddedHighlightBg lipgloss.Color + RemovedLineNumberBg lipgloss.Color + AddedLineNamerBg lipgloss.Color + RemovedHighlightFg lipgloss.Color + AddedHighlightFg lipgloss.Color +} + +// StyleOption defines a function that modifies a StyleConfig. +type StyleOption func(*StyleConfig) + +// NewStyleConfig creates a StyleConfig with default values and applies any provided options. +func NewStyleConfig(opts ...StyleOption) StyleConfig { + // Set default values + config := StyleConfig{ + RemovedLineBg: lipgloss.Color("#3A3030"), + AddedLineBg: lipgloss.Color("#303A30"), + ContextLineBg: lipgloss.Color("#212121"), + HunkLineBg: lipgloss.Color("#2A2822"), + HunkLineFg: lipgloss.Color("#D4AF37"), + RemovedFg: lipgloss.Color("#7C4444"), + AddedFg: lipgloss.Color("#478247"), + LineNumberFg: lipgloss.Color("#888888"), + HighlightStyle: "dracula", + RemovedHighlightBg: lipgloss.Color("#612726"), + AddedHighlightBg: lipgloss.Color("#256125"), + RemovedLineNumberBg: lipgloss.Color("#332929"), + AddedLineNamerBg: lipgloss.Color("#293229"), + RemovedHighlightFg: lipgloss.Color("#FADADD"), + AddedHighlightFg: lipgloss.Color("#DAFADA"), + } + + // Apply all provided options + for _, opt := range opts { + opt(&config) + } + + return config +} + +// WithRemovedLineBg sets the background color for removed lines. +func WithRemovedLineBg(color lipgloss.Color) StyleOption { + return func(s *StyleConfig) { + s.RemovedLineBg = color + } +} + +// WithAddedLineBg sets the background color for added lines. +func WithAddedLineBg(color lipgloss.Color) StyleOption { + return func(s *StyleConfig) { + s.AddedLineBg = color + } +} + +// WithContextLineBg sets the background color for context lines. +func WithContextLineBg(color lipgloss.Color) StyleOption { + return func(s *StyleConfig) { + s.ContextLineBg = color + } +} + +// WithRemovedFg sets the foreground color for removed line markers. +func WithRemovedFg(color lipgloss.Color) StyleOption { + return func(s *StyleConfig) { + s.RemovedFg = color + } +} + +// WithAddedFg sets the foreground color for added line markers. +func WithAddedFg(color lipgloss.Color) StyleOption { + return func(s *StyleConfig) { + s.AddedFg = color + } +} + +// WithLineNumberFg sets the foreground color for line numbers. +func WithLineNumberFg(color lipgloss.Color) StyleOption { + return func(s *StyleConfig) { + s.LineNumberFg = color + } +} + +// WithHighlightStyle sets the syntax highlighting style. +func WithHighlightStyle(style string) StyleOption { + return func(s *StyleConfig) { + s.HighlightStyle = style + } +} + +// WithRemovedHighlightColors sets the colors for highlighted parts in removed text. +func WithRemovedHighlightColors(bg, fg lipgloss.Color) StyleOption { + return func(s *StyleConfig) { + s.RemovedHighlightBg = bg + s.RemovedHighlightFg = fg + } +} + +// WithAddedHighlightColors sets the colors for highlighted parts in added text. +func WithAddedHighlightColors(bg, fg lipgloss.Color) StyleOption { + return func(s *StyleConfig) { + s.AddedHighlightBg = bg + s.AddedHighlightFg = fg + } +} + +// WithRemovedLineNumberBg sets the background color for removed line numbers. +func WithRemovedLineNumberBg(color lipgloss.Color) StyleOption { + return func(s *StyleConfig) { + s.RemovedLineNumberBg = color + } +} + +// WithAddedLineNumberBg sets the background color for added line numbers. +func WithAddedLineNumberBg(color lipgloss.Color) StyleOption { + return func(s *StyleConfig) { + s.AddedLineNamerBg = color + } +} + +func WithHunkLineBg(color lipgloss.Color) StyleOption { + return func(s *StyleConfig) { + s.HunkLineBg = color + } +} + +func WithHunkLineFg(color lipgloss.Color) StyleOption { + return func(s *StyleConfig) { + s.HunkLineFg = color + } +} + +// ------------------------------------------------------------------------- +// Parse Options with Option Pattern +// ------------------------------------------------------------------------- + +// ParseConfig configures the behavior of diff parsing. +type ParseConfig struct { + ContextSize int // Number of context lines to include +} + +// ParseOption defines a function that modifies a ParseConfig. +type ParseOption func(*ParseConfig) + +// NewParseConfig creates a ParseConfig with default values and applies any provided options. +func NewParseConfig(opts ...ParseOption) ParseConfig { + // Set default values + config := ParseConfig{ + ContextSize: 3, + } + + // Apply all provided options + for _, opt := range opts { + opt(&config) + } + + return config +} + +// WithContextSize sets the number of context lines to include. +func WithContextSize(size int) ParseOption { + return func(p *ParseConfig) { + if size >= 0 { + p.ContextSize = size + } + } +} + +// ------------------------------------------------------------------------- +// Side-by-Side Options with Option Pattern +// ------------------------------------------------------------------------- + +// SideBySideConfig configures the rendering of side-by-side diffs. +type SideBySideConfig struct { + TotalWidth int + Style StyleConfig +} + +// SideBySideOption defines a function that modifies a SideBySideConfig. +type SideBySideOption func(*SideBySideConfig) + +// NewSideBySideConfig creates a SideBySideConfig with default values and applies any provided options. +func NewSideBySideConfig(opts ...SideBySideOption) SideBySideConfig { + // Set default values + config := SideBySideConfig{ + TotalWidth: 160, // Default width for side-by-side view + Style: NewStyleConfig(), + } + + // Apply all provided options + for _, opt := range opts { + opt(&config) + } + + return config +} + +// WithTotalWidth sets the total width for side-by-side view. +func WithTotalWidth(width int) SideBySideOption { + return func(s *SideBySideConfig) { + if width > 0 { + s.TotalWidth = width + } + } +} + +// WithStyle sets the styling configuration. +func WithStyle(style StyleConfig) SideBySideOption { + return func(s *SideBySideConfig) { + s.Style = style + } +} + +// WithStyleOptions applies the specified style options. +func WithStyleOptions(opts ...StyleOption) SideBySideOption { + return func(s *SideBySideConfig) { + s.Style = NewStyleConfig(opts...) + } +} + +// ------------------------------------------------------------------------- +// Diff Parsing and Generation +// ------------------------------------------------------------------------- + +// ParseUnifiedDiff parses a unified diff format string into structured data. +func ParseUnifiedDiff(diff string) (DiffResult, error) { + var result DiffResult + var currentHunk *Hunk + + hunkHeaderRe := regexp.MustCompile(`^@@ -(\d+),?(\d*) \+(\d+),?(\d*) @@`) + lines := strings.Split(diff, "\n") + + var oldLine, newLine int + inFileHeader := true + + for _, line := range lines { + // Parse the file headers + if inFileHeader { + if strings.HasPrefix(line, "--- a/") { + result.OldFile = strings.TrimPrefix(line, "--- a/") + continue + } + if strings.HasPrefix(line, "+++ b/") { + result.NewFile = strings.TrimPrefix(line, "+++ b/") + inFileHeader = false + continue + } + } + + // Parse hunk headers + if matches := hunkHeaderRe.FindStringSubmatch(line); matches != nil { + if currentHunk != nil { + result.Hunks = append(result.Hunks, *currentHunk) + } + currentHunk = &Hunk{ + Header: line, + Lines: []DiffLine{}, + } + + oldStart, _ := strconv.Atoi(matches[1]) + newStart, _ := strconv.Atoi(matches[3]) + oldLine = oldStart + newLine = newStart + + continue + } + + if currentHunk == nil { + continue + } + + if len(line) > 0 { + // Process the line based on its prefix + switch line[0] { + case '+': + currentHunk.Lines = append(currentHunk.Lines, DiffLine{ + OldLineNo: 0, + NewLineNo: newLine, + Kind: LineAdded, + Content: line[1:], // skip '+' + }) + newLine++ + case '-': + currentHunk.Lines = append(currentHunk.Lines, DiffLine{ + OldLineNo: oldLine, + NewLineNo: 0, + Kind: LineRemoved, + Content: line[1:], // skip '-' + }) + oldLine++ + default: + currentHunk.Lines = append(currentHunk.Lines, DiffLine{ + OldLineNo: oldLine, + NewLineNo: newLine, + Kind: LineContext, + Content: line, + }) + oldLine++ + newLine++ + } + } else { + // Handle empty lines + currentHunk.Lines = append(currentHunk.Lines, DiffLine{ + OldLineNo: oldLine, + NewLineNo: newLine, + Kind: LineContext, + Content: "", + }) + oldLine++ + newLine++ + } + } + + // Add the last hunk if there is one + if currentHunk != nil { + result.Hunks = append(result.Hunks, *currentHunk) + } + + return result, nil +} + +// HighlightIntralineChanges updates the content of lines in a hunk to show +// character-level differences within lines. +func HighlightIntralineChanges(h *Hunk, style StyleConfig) { + var updated []DiffLine + dmp := diffmatchpatch.New() + + for i := 0; i < len(h.Lines); i++ { + // Look for removed line followed by added line, which might have similar content + if i+1 < len(h.Lines) && + h.Lines[i].Kind == LineRemoved && + h.Lines[i+1].Kind == LineAdded { + + oldLine := h.Lines[i] + newLine := h.Lines[i+1] + + // Find character-level differences + patches := dmp.DiffMain(oldLine.Content, newLine.Content, false) + patches = dmp.DiffCleanupEfficiency(patches) + patches = dmp.DiffCleanupSemantic(patches) + + // Apply highlighting to the differences + oldLine.Content = colorizeSegments(patches, true, style) + newLine.Content = colorizeSegments(patches, false, style) + + updated = append(updated, oldLine, newLine) + i++ // Skip the next line as we've already processed it + } else { + updated = append(updated, h.Lines[i]) + } + } + + h.Lines = updated +} + +// colorizeSegments applies styles to the character-level diff segments. +func colorizeSegments(diffs []diffmatchpatch.Diff, isOld bool, style StyleConfig) string { + var buf strings.Builder + + removeBg := lipgloss.NewStyle(). + Background(style.RemovedHighlightBg). + Foreground(style.RemovedHighlightFg) + + addBg := lipgloss.NewStyle(). + Background(style.AddedHighlightBg). + Foreground(style.AddedHighlightFg) + + removedLineStyle := lipgloss.NewStyle().Background(style.RemovedLineBg) + addedLineStyle := lipgloss.NewStyle().Background(style.AddedLineBg) + + afterBg := false + + for _, d := range diffs { + switch d.Type { + case diffmatchpatch.DiffEqual: + // Handle text that's the same in both versions + if afterBg { + if isOld { + buf.WriteString(removedLineStyle.Render(d.Text)) + } else { + buf.WriteString(addedLineStyle.Render(d.Text)) + } + } else { + buf.WriteString(d.Text) + } + case diffmatchpatch.DiffDelete: + // Handle deleted text (only show in old version) + if isOld { + buf.WriteString(removeBg.Render(d.Text)) + afterBg = true + } + case diffmatchpatch.DiffInsert: + // Handle inserted text (only show in new version) + if !isOld { + buf.WriteString(addBg.Render(d.Text)) + afterBg = true + } + } + } + + return buf.String() +} + +// pairLines converts a flat list of diff lines to pairs for side-by-side display. +func pairLines(lines []DiffLine) []linePair { + var pairs []linePair + i := 0 + + for i < len(lines) { + switch lines[i].Kind { + case LineRemoved: + // Check if the next line is an addition, if so pair them + if i+1 < len(lines) && lines[i+1].Kind == LineAdded { + pairs = append(pairs, linePair{left: &lines[i], right: &lines[i+1]}) + i += 2 + } else { + pairs = append(pairs, linePair{left: &lines[i], right: nil}) + i++ + } + case LineAdded: + pairs = append(pairs, linePair{left: nil, right: &lines[i]}) + i++ + case LineContext: + pairs = append(pairs, linePair{left: &lines[i], right: &lines[i]}) + i++ + } + } + + return pairs +} + +// ------------------------------------------------------------------------- +// Syntax Highlighting +// ------------------------------------------------------------------------- + +// SyntaxHighlight applies syntax highlighting to a string based on the file extension. +func SyntaxHighlight(w io.Writer, source, fileName, formatter string, bg lipgloss.TerminalColor) error { + // Determine the language lexer to use + l := lexers.Match(fileName) + if l == nil { + l = lexers.Analyse(source) + } + if l == nil { + l = lexers.Fallback + } + l = chroma.Coalesce(l) + + // Get the formatter + f := formatters.Get(formatter) + if f == nil { + f = formatters.Fallback + } + + // Get the style + s := styles.Get("dracula") + if s == nil { + s = styles.Fallback + } + + // Modify the style to use the provided background + s, err := s.Builder().Transform( + func(t chroma.StyleEntry) chroma.StyleEntry { + r, g, b, _ := bg.RGBA() + ru8 := uint8(r >> 8) + gu8 := uint8(g >> 8) + bu8 := uint8(b >> 8) + t.Background = chroma.NewColour(ru8, gu8, bu8) + return t + }, + ).Build() + if err != nil { + s = styles.Fallback + } + + // Tokenize and format + it, err := l.Tokenise(nil, source) + if err != nil { + return err + } + + return f.Format(w, s, it) +} + +// highlightLine applies syntax highlighting to a single line. +func highlightLine(fileName string, line string, bg lipgloss.TerminalColor) string { + var buf bytes.Buffer + err := SyntaxHighlight(&buf, line, fileName, "terminal16m", bg) + if err != nil { + return line + } + return buf.String() +} + +// createStyles generates the lipgloss styles needed for rendering diffs. +func createStyles(config StyleConfig) (removedLineStyle, addedLineStyle, contextLineStyle, lineNumberStyle lipgloss.Style) { + removedLineStyle = lipgloss.NewStyle().Background(config.RemovedLineBg) + addedLineStyle = lipgloss.NewStyle().Background(config.AddedLineBg) + contextLineStyle = lipgloss.NewStyle().Background(config.ContextLineBg) + lineNumberStyle = lipgloss.NewStyle().Foreground(config.LineNumberFg) + + return +} + +// renderLeftColumn formats the left side of a side-by-side diff. +func renderLeftColumn(fileName string, dl *DiffLine, colWidth int, styles StyleConfig) string { + if dl == nil { + contextLineStyle := lipgloss.NewStyle().Background(styles.ContextLineBg) + return contextLineStyle.Width(colWidth).Render("") + } + + removedLineStyle, _, contextLineStyle, lineNumberStyle := createStyles(styles) + + var marker string + var bgStyle lipgloss.Style + + switch dl.Kind { + case LineRemoved: + marker = removedLineStyle.Foreground(styles.RemovedFg).Render("-") + bgStyle = removedLineStyle + lineNumberStyle = lineNumberStyle.Foreground(styles.RemovedFg).Background(styles.RemovedLineNumberBg) + case LineAdded: + marker = "?" + bgStyle = contextLineStyle + case LineContext: + marker = contextLineStyle.Render(" ") + bgStyle = contextLineStyle + } + + lineNum := "" + if dl.OldLineNo > 0 { + lineNum = fmt.Sprintf("%6d", dl.OldLineNo) + } + + prefix := lineNumberStyle.Render(lineNum + " " + marker) + content := highlightLine(fileName, dl.Content, bgStyle.GetBackground()) + + if dl.Kind == LineRemoved { + content = bgStyle.Render(" ") + content + } + + lineText := prefix + content + return bgStyle.MaxHeight(1).Width(colWidth).Render(ansi.Truncate(lineText, colWidth, "...")) +} + +// renderRightColumn formats the right side of a side-by-side diff. +func renderRightColumn(fileName string, dl *DiffLine, colWidth int, styles StyleConfig) string { + if dl == nil { + contextLineStyle := lipgloss.NewStyle().Background(styles.ContextLineBg) + return contextLineStyle.Width(colWidth).Render("") + } + + _, addedLineStyle, contextLineStyle, lineNumberStyle := createStyles(styles) + + var marker string + var bgStyle lipgloss.Style + + switch dl.Kind { + case LineAdded: + marker = addedLineStyle.Foreground(styles.AddedFg).Render("+") + bgStyle = addedLineStyle + lineNumberStyle = lineNumberStyle.Foreground(styles.AddedFg).Background(styles.AddedLineNamerBg) + case LineRemoved: + marker = "?" + bgStyle = contextLineStyle + case LineContext: + marker = contextLineStyle.Render(" ") + bgStyle = contextLineStyle + } + + lineNum := "" + if dl.NewLineNo > 0 { + lineNum = fmt.Sprintf("%6d", dl.NewLineNo) + } + + prefix := lineNumberStyle.Render(lineNum + " " + marker) + content := highlightLine(fileName, dl.Content, bgStyle.GetBackground()) + + if dl.Kind == LineAdded { + content = bgStyle.Render(" ") + content + } + + lineText := prefix + content + return bgStyle.MaxHeight(1).Width(colWidth).Render(ansi.Truncate(lineText, colWidth, "...")) +} + +// ------------------------------------------------------------------------- +// Public API Methods +// ------------------------------------------------------------------------- + +// RenderSideBySideHunk formats a hunk for side-by-side display. +func RenderSideBySideHunk(fileName string, h Hunk, opts ...SideBySideOption) string { + // Apply options to create the configuration + config := NewSideBySideConfig(opts...) + + // Make a copy of the hunk so we don't modify the original + hunkCopy := Hunk{Lines: make([]DiffLine, len(h.Lines))} + copy(hunkCopy.Lines, h.Lines) + + // Highlight changes within lines + HighlightIntralineChanges(&hunkCopy, config.Style) + + // Pair lines for side-by-side display + pairs := pairLines(hunkCopy.Lines) + + // Calculate column width + colWidth := config.TotalWidth / 2 + + var sb strings.Builder + for _, p := range pairs { + leftStr := renderLeftColumn(fileName, p.left, colWidth, config.Style) + rightStr := renderRightColumn(fileName, p.right, colWidth, config.Style) + sb.WriteString(leftStr + rightStr + "\n") + } + + return sb.String() +} + +// FormatDiff creates a side-by-side formatted view of a diff. +func FormatDiff(diffText string, opts ...SideBySideOption) (string, error) { + diffResult, err := ParseUnifiedDiff(diffText) + if err != nil { + return "", err + } + + var sb strings.Builder + + config := NewSideBySideConfig(opts...) + for i, h := range diffResult.Hunks { + if i > 0 { + sb.WriteString(lipgloss.NewStyle().Background(config.Style.HunkLineBg).Foreground(config.Style.HunkLineFg).Width(config.TotalWidth).Render(h.Header) + "\n") + } + sb.WriteString(RenderSideBySideHunk(diffResult.OldFile, h, opts...)) + } + + return sb.String(), nil +} + +// GenerateDiff creates a unified diff from two file contents. +func GenerateDiff(beforeContent, afterContent, beforeFilename, afterFilename string, opts ...ParseOption) (string, int, int) { + config := NewParseConfig(opts...) + + var output strings.Builder + + // Ensure we handle newlines correctly + beforeHasNewline := len(beforeContent) > 0 && beforeContent[len(beforeContent)-1] == '\n' + afterHasNewline := len(afterContent) > 0 && afterContent[len(afterContent)-1] == '\n' + + // Split into lines + beforeLines := strings.Split(beforeContent, "\n") + afterLines := strings.Split(afterContent, "\n") + + // Remove empty trailing element from the split if the content ended with a newline + if beforeHasNewline && len(beforeLines) > 0 { + beforeLines = beforeLines[:len(beforeLines)-1] + } + if afterHasNewline && len(afterLines) > 0 { + afterLines = afterLines[:len(afterLines)-1] + } + + dmp := diffmatchpatch.New() + dmp.DiffTimeout = 5 * time.Second + + // Convert lines to characters for efficient diffing + lineArray1, lineArray2, lineArrays := dmp.DiffLinesToChars(beforeContent, afterContent) + diffs := dmp.DiffMain(lineArray1, lineArray2, false) + diffs = dmp.DiffCharsToLines(diffs, lineArrays) + + // Default filenames if not provided + if beforeFilename == "" { + beforeFilename = "a" + } + if afterFilename == "" { + afterFilename = "b" + } + + // Write diff header + output.WriteString(fmt.Sprintf("diff --git a/%s b/%s\n", beforeFilename, afterFilename)) + output.WriteString(fmt.Sprintf("--- a/%s\n", beforeFilename)) + output.WriteString(fmt.Sprintf("+++ b/%s\n", afterFilename)) + + line1 := 0 // Line numbers start from 0 internally + line2 := 0 + additions := 0 + deletions := 0 + + var hunks []string + var currentHunk strings.Builder + var hunkStartLine1, hunkStartLine2 int + var hunkLines1, hunkLines2 int + inHunk := false + + contextSize := config.ContextSize + + // startHunk begins recording a new hunk + startHunk := func(startLine1, startLine2 int) { + inHunk = true + hunkStartLine1 = startLine1 + hunkStartLine2 = startLine2 + hunkLines1 = 0 + hunkLines2 = 0 + currentHunk.Reset() + } + + // writeHunk adds the current hunk to the hunks slice + writeHunk := func() { + if inHunk { + hunkHeader := fmt.Sprintf("@@ -%d,%d +%d,%d @@\n", + hunkStartLine1+1, hunkLines1, + hunkStartLine2+1, hunkLines2) + hunks = append(hunks, hunkHeader+currentHunk.String()) + inHunk = false + } + } + + // Process diffs to create hunks + pendingContext := make([]string, 0, contextSize*2) + var contextLines1, contextLines2 int + + // Helper function to add context lines to the hunk + addContextToHunk := func(lines []string, count int) { + for i := 0; i < count; i++ { + if i < len(lines) { + currentHunk.WriteString(" " + lines[i] + "\n") + hunkLines1++ + hunkLines2++ + } + } + } + + // Process diffs + for _, diff := range diffs { + lines := strings.Split(diff.Text, "\n") + + // Remove empty trailing line that comes from splitting a string that ends with \n + if len(lines) > 0 && lines[len(lines)-1] == "" && diff.Text[len(diff.Text)-1] == '\n' { + lines = lines[:len(lines)-1] + } + + switch diff.Type { + case diffmatchpatch.DiffEqual: + // If we have enough equal lines to serve as context, add them to pending + pendingContext = append(pendingContext, lines...) + + // If pending context grows too large, trim it + if len(pendingContext) > contextSize*2 { + pendingContext = pendingContext[len(pendingContext)-contextSize*2:] + } + + // If we're in a hunk, add the necessary context + if inHunk { + // Only add the first contextSize lines as trailing context + numContextLines := min(contextSize, len(lines)) + addContextToHunk(lines[:numContextLines], numContextLines) + + // If we've added enough trailing context, close the hunk + if numContextLines >= contextSize { + writeHunk() + } + } + + line1 += len(lines) + line2 += len(lines) + contextLines1 += len(lines) + contextLines2 += len(lines) + + case diffmatchpatch.DiffDelete, diffmatchpatch.DiffInsert: + // Start a new hunk if needed + if !inHunk { + // Determine how many context lines we can add before + contextBefore := min(contextSize, len(pendingContext)) + ctxStartIdx := len(pendingContext) - contextBefore + + // Calculate the correct start lines + startLine1 := line1 - contextLines1 + ctxStartIdx + startLine2 := line2 - contextLines2 + ctxStartIdx + + startHunk(startLine1, startLine2) + + // Add the context lines before + addContextToHunk(pendingContext[ctxStartIdx:], contextBefore) + } + + // Reset context tracking when we see a diff + pendingContext = pendingContext[:0] + contextLines1 = 0 + contextLines2 = 0 + + // Add the changes + if diff.Type == diffmatchpatch.DiffDelete { + for _, line := range lines { + currentHunk.WriteString("-" + line + "\n") + hunkLines1++ + deletions++ + } + line1 += len(lines) + } else { // DiffInsert + for _, line := range lines { + currentHunk.WriteString("+" + line + "\n") + hunkLines2++ + additions++ + } + line2 += len(lines) + } + } + } + + // Write the final hunk if there's one pending + if inHunk { + writeHunk() + } + + // Merge hunks that are close to each other (within 2*contextSize lines) + var mergedHunks []string + if len(hunks) > 0 { + mergedHunks = append(mergedHunks, hunks[0]) + + for i := 1; i < len(hunks); i++ { + prevHunk := mergedHunks[len(mergedHunks)-1] + currHunk := hunks[i] + + // Extract line numbers to check proximity + var prevStart, prevLen, currStart, currLen int + fmt.Sscanf(prevHunk, "@@ -%d,%d", &prevStart, &prevLen) + fmt.Sscanf(currHunk, "@@ -%d,%d", &currStart, &currLen) + + prevEnd := prevStart + prevLen - 1 + + // If hunks are close, merge them + if currStart-prevEnd <= contextSize*2 { + // Create a merged hunk - this is a simplification, real git has more complex merging logic + merged := mergeHunks(prevHunk, currHunk) + mergedHunks[len(mergedHunks)-1] = merged + } else { + mergedHunks = append(mergedHunks, currHunk) + } + } + } + + // Write all hunks to output + for _, hunk := range mergedHunks { + output.WriteString(hunk) + } + + // Handle "No newline at end of file" notifications + if !beforeHasNewline && len(beforeLines) > 0 { + // Find the last deletion in the diff and add the notification after it + lastPos := strings.LastIndex(output.String(), "\n-") + if lastPos != -1 { + // Insert the notification after the line + str := output.String() + output.Reset() + output.WriteString(str[:lastPos+1]) + output.WriteString("\\ No newline at end of file\n") + output.WriteString(str[lastPos+1:]) + } + } + + if !afterHasNewline && len(afterLines) > 0 { + // Find the last insertion in the diff and add the notification after it + lastPos := strings.LastIndex(output.String(), "\n+") + if lastPos != -1 { + // Insert the notification after the line + str := output.String() + output.Reset() + output.WriteString(str[:lastPos+1]) + output.WriteString("\\ No newline at end of file\n") + output.WriteString(str[lastPos+1:]) + } + } + + // Return the diff without the summary line + return output.String(), additions, deletions +} + +// Helper function to merge two hunks +func mergeHunks(hunk1, hunk2 string) string { + // This is a simplified implementation + // A full implementation would need to properly recalculate the hunk header + // and remove redundant context lines + + // Extract header info from both hunks + var start1, len1, start2, len2 int + var startB1, lenB1, startB2, lenB2 int + + fmt.Sscanf(hunk1, "@@ -%d,%d +%d,%d @@", &start1, &len1, &startB1, &lenB1) + fmt.Sscanf(hunk2, "@@ -%d,%d +%d,%d @@", &start2, &len2, &startB2, &lenB2) + + // Split the hunks to get content + parts1 := strings.SplitN(hunk1, "\n", 2) + parts2 := strings.SplitN(hunk2, "\n", 2) + + content1 := "" + content2 := "" + + if len(parts1) > 1 { + content1 = parts1[1] + } + if len(parts2) > 1 { + content2 = parts2[1] + } + + // Calculate the new header + newEnd := max(start1+len1-1, start2+len2-1) + newEndB := max(startB1+lenB1-1, startB2+lenB2-1) + + newLen := newEnd - start1 + 1 + newLenB := newEndB - startB1 + 1 + + newHeader := fmt.Sprintf("@@ -%d,%d +%d,%d @@", start1, newLen, startB1, newLenB) + + // Combine the content, potentially with some overlap handling + return newHeader + "\n" + content1 + content2 +} diff --git a/internal/git/diff.go b/internal/git/diff.go deleted file mode 100644 index 2ab13964246d8c2c814823f3c8e2009859b991bf..0000000000000000000000000000000000000000 --- a/internal/git/diff.go +++ /dev/null @@ -1,264 +0,0 @@ -package git - -import ( - "bytes" - "fmt" - "os" - "os/exec" - "path/filepath" - "strings" - "time" - - "github.com/go-git/go-git/v5" - "github.com/go-git/go-git/v5/plumbing/object" -) - -type DiffStats struct { - Additions int - Removals int -} - -func GenerateGitDiff(filePath string, contentBefore string, contentAfter string) (string, error) { - tempDir, err := os.MkdirTemp("", "git-diff-temp") - if err != nil { - return "", fmt.Errorf("failed to create temp dir: %w", err) - } - defer os.RemoveAll(tempDir) - - repo, err := git.PlainInit(tempDir, false) - if err != nil { - return "", fmt.Errorf("failed to initialize git repo: %w", err) - } - - wt, err := repo.Worktree() - if err != nil { - return "", fmt.Errorf("failed to get worktree: %w", err) - } - - fullPath := filepath.Join(tempDir, filePath) - if err = os.MkdirAll(filepath.Dir(fullPath), 0o755); err != nil { - return "", fmt.Errorf("failed to create directories: %w", err) - } - if err = os.WriteFile(fullPath, []byte(contentBefore), 0o644); err != nil { - return "", fmt.Errorf("failed to write 'before' content: %w", err) - } - - _, err = wt.Add(filePath) - if err != nil { - return "", fmt.Errorf("failed to add file to git: %w", err) - } - - beforeCommit, err := wt.Commit("Before", &git.CommitOptions{ - Author: &object.Signature{ - Name: "OpenCode", - Email: "coder@opencode.ai", - When: time.Now(), - }, - }) - if err != nil { - return "", fmt.Errorf("failed to commit 'before' version: %w", err) - } - - if err = os.WriteFile(fullPath, []byte(contentAfter), 0o644); err != nil { - return "", fmt.Errorf("failed to write 'after' content: %w", err) - } - - _, err = wt.Add(filePath) - if err != nil { - return "", fmt.Errorf("failed to add updated file to git: %w", err) - } - - afterCommit, err := wt.Commit("After", &git.CommitOptions{ - Author: &object.Signature{ - Name: "OpenCode", - Email: "coder@opencode.ai", - When: time.Now(), - }, - }) - if err != nil { - return "", fmt.Errorf("failed to commit 'after' version: %w", err) - } - - beforeCommitObj, err := repo.CommitObject(beforeCommit) - if err != nil { - return "", fmt.Errorf("failed to get 'before' commit: %w", err) - } - - afterCommitObj, err := repo.CommitObject(afterCommit) - if err != nil { - return "", fmt.Errorf("failed to get 'after' commit: %w", err) - } - - patch, err := beforeCommitObj.Patch(afterCommitObj) - if err != nil { - return "", fmt.Errorf("failed to generate patch: %w", err) - } - - return patch.String(), nil -} - -func GenerateGitDiffWithStats(filePath string, contentBefore string, contentAfter string) (string, DiffStats, error) { - tempDir, err := os.MkdirTemp("", "git-diff-temp") - if err != nil { - return "", DiffStats{}, fmt.Errorf("failed to create temp dir: %w", err) - } - defer os.RemoveAll(tempDir) - - repo, err := git.PlainInit(tempDir, false) - if err != nil { - return "", DiffStats{}, fmt.Errorf("failed to initialize git repo: %w", err) - } - - wt, err := repo.Worktree() - if err != nil { - return "", DiffStats{}, fmt.Errorf("failed to get worktree: %w", err) - } - - fullPath := filepath.Join(tempDir, filePath) - if err = os.MkdirAll(filepath.Dir(fullPath), 0o755); err != nil { - return "", DiffStats{}, fmt.Errorf("failed to create directories: %w", err) - } - if err = os.WriteFile(fullPath, []byte(contentBefore), 0o644); err != nil { - return "", DiffStats{}, fmt.Errorf("failed to write 'before' content: %w", err) - } - - _, err = wt.Add(filePath) - if err != nil { - return "", DiffStats{}, fmt.Errorf("failed to add file to git: %w", err) - } - - beforeCommit, err := wt.Commit("Before", &git.CommitOptions{ - Author: &object.Signature{ - Name: "OpenCode", - Email: "coder@opencode.ai", - When: time.Now(), - }, - }) - if err != nil { - return "", DiffStats{}, fmt.Errorf("failed to commit 'before' version: %w", err) - } - - if err = os.WriteFile(fullPath, []byte(contentAfter), 0o644); err != nil { - return "", DiffStats{}, fmt.Errorf("failed to write 'after' content: %w", err) - } - - _, err = wt.Add(filePath) - if err != nil { - return "", DiffStats{}, fmt.Errorf("failed to add updated file to git: %w", err) - } - - afterCommit, err := wt.Commit("After", &git.CommitOptions{ - Author: &object.Signature{ - Name: "OpenCode", - Email: "coder@opencode.ai", - When: time.Now(), - }, - }) - if err != nil { - return "", DiffStats{}, fmt.Errorf("failed to commit 'after' version: %w", err) - } - - beforeCommitObj, err := repo.CommitObject(beforeCommit) - if err != nil { - return "", DiffStats{}, fmt.Errorf("failed to get 'before' commit: %w", err) - } - - afterCommitObj, err := repo.CommitObject(afterCommit) - if err != nil { - return "", DiffStats{}, fmt.Errorf("failed to get 'after' commit: %w", err) - } - - patch, err := beforeCommitObj.Patch(afterCommitObj) - if err != nil { - return "", DiffStats{}, fmt.Errorf("failed to generate patch: %w", err) - } - - stats := DiffStats{} - for _, fileStat := range patch.Stats() { - stats.Additions += fileStat.Addition - stats.Removals += fileStat.Deletion - } - - return patch.String(), stats, nil -} - -func FormatDiff(diffText string, width int) (string, error) { - if isSplitDiffsAvailable() { - return formatWithSplitDiffs(diffText, width) - } - - return formatSimple(diffText), nil -} - -func isSplitDiffsAvailable() bool { - _, err := exec.LookPath("node") - return err == nil -} - -func formatWithSplitDiffs(diffText string, width int) (string, error) { - args := []string{ - "--color", - } - - var diffCmd *exec.Cmd - - if _, err := exec.LookPath("git-split-diffs-opencode"); err == nil { - fullArgs := append([]string{"git-split-diffs-opencode"}, args...) - diffCmd = exec.Command(fullArgs[0], fullArgs[1:]...) - } else { - npxArgs := append([]string{"git-split-diffs-opencode"}, args...) - diffCmd = exec.Command("npx", npxArgs...) - } - - diffCmd.Env = append(os.Environ(), fmt.Sprintf("DIFF_COLUMNS=%d", width)) - - diffCmd.Stdin = strings.NewReader(diffText) - - var out bytes.Buffer - diffCmd.Stdout = &out - - var stderr bytes.Buffer - diffCmd.Stderr = &stderr - - if err := diffCmd.Run(); err != nil { - return "", fmt.Errorf("git-split-diffs-opencode error: %w, stderr: %s", err, stderr.String()) - } - - return out.String(), nil -} - -func formatSimple(diffText string) string { - lines := strings.Split(diffText, "\n") - var result strings.Builder - - for _, line := range lines { - if len(line) == 0 { - result.WriteString("\n") - continue - } - - switch line[0] { - case '+': - result.WriteString("\033[32m" + line + "\033[0m\n") - case '-': - result.WriteString("\033[31m" + line + "\033[0m\n") - case '@': - result.WriteString("\033[36m" + line + "\033[0m\n") - case 'd': - if strings.HasPrefix(line, "diff --git") { - result.WriteString("\033[1m" + line + "\033[0m\n") - } else { - result.WriteString(line + "\n") - } - default: - result.WriteString(line + "\n") - } - } - - if !strings.HasSuffix(diffText, "\n") { - output := result.String() - return output[:len(output)-1] - } - - return result.String() -} diff --git a/internal/llm/tools/edit.go b/internal/llm/tools/edit.go index b74e427296e4cca28349e6de920cef6399f7f823..1305879b917816b37269c5b6619ac0c0edb178f0 100644 --- a/internal/llm/tools/edit.go +++ b/internal/llm/tools/edit.go @@ -10,7 +10,7 @@ import ( "time" "github.com/kujtimiihoxha/termai/internal/config" - "github.com/kujtimiihoxha/termai/internal/git" + "github.com/kujtimiihoxha/termai/internal/diff" "github.com/kujtimiihoxha/termai/internal/lsp" "github.com/kujtimiihoxha/termai/internal/permission" ) @@ -182,14 +182,12 @@ func (e *editTool) createNewFile(ctx context.Context, filePath, content string) return ToolResponse{}, fmt.Errorf("session ID and message ID are required for creating a new file") } - diff, stats, err := git.GenerateGitDiffWithStats( - removeWorkingDirectoryPrefix(filePath), + diff, additions, removals := diff.GenerateDiff( "", content, + filePath, + filePath, ) - if err != nil { - return ToolResponse{}, fmt.Errorf("failed to get file diff: %w", err) - } p := e.permissions.Request( permission.CreatePermissionRequest{ Path: filepath.Dir(filePath), @@ -218,8 +216,8 @@ func (e *editTool) createNewFile(ctx context.Context, filePath, content string) NewTextResponse("File created: "+filePath), EditResponseMetadata{ Diff: diff, - Additions: stats.Additions, - Removals: stats.Removals, + Additions: additions, + Removals: removals, }, ), nil } @@ -275,14 +273,12 @@ func (e *editTool) deleteContent(ctx context.Context, filePath, oldString string return ToolResponse{}, fmt.Errorf("session ID and message ID are required for creating a new file") } - diff, stats, err := git.GenerateGitDiffWithStats( - removeWorkingDirectoryPrefix(filePath), + diff, additions, removals := diff.GenerateDiff( oldContent, newContent, + filePath, + filePath, ) - if err != nil { - return ToolResponse{}, fmt.Errorf("failed to get file diff: %w", err) - } p := e.permissions.Request( permission.CreatePermissionRequest{ @@ -311,8 +307,8 @@ func (e *editTool) deleteContent(ctx context.Context, filePath, oldString string NewTextResponse("Content deleted from file: "+filePath), EditResponseMetadata{ Diff: diff, - Additions: stats.Additions, - Removals: stats.Removals, + Additions: additions, + Removals: removals, }, ), nil } @@ -367,15 +363,12 @@ func (e *editTool) replaceContent(ctx context.Context, filePath, oldString, newS if sessionID == "" || messageID == "" { return ToolResponse{}, fmt.Errorf("session ID and message ID are required for creating a new file") } - diff, stats, err := git.GenerateGitDiffWithStats( - removeWorkingDirectoryPrefix(filePath), + diff, additions, removals := diff.GenerateDiff( oldContent, newContent, + filePath, + filePath, ) - if err != nil { - return ToolResponse{}, fmt.Errorf("failed to get file diff: %w", err) - } - p := e.permissions.Request( permission.CreatePermissionRequest{ Path: filepath.Dir(filePath), @@ -405,7 +398,7 @@ func (e *editTool) replaceContent(ctx context.Context, filePath, oldString, newS NewTextResponse("Content replaced in file: "+filePath), EditResponseMetadata{ Diff: diff, - Additions: stats.Additions, - Removals: stats.Removals, + Additions: additions, + Removals: removals, }), nil } diff --git a/internal/llm/tools/write.go b/internal/llm/tools/write.go index 8318f2851fa0a9e830a842fe7823eab567858c1c..ef2ca01f4a09f5efb065607365ac57ec5f25588b 100644 --- a/internal/llm/tools/write.go +++ b/internal/llm/tools/write.go @@ -9,7 +9,7 @@ import ( "time" "github.com/kujtimiihoxha/termai/internal/config" - "github.com/kujtimiihoxha/termai/internal/git" + "github.com/kujtimiihoxha/termai/internal/diff" "github.com/kujtimiihoxha/termai/internal/lsp" "github.com/kujtimiihoxha/termai/internal/permission" ) @@ -149,14 +149,13 @@ func (w *writeTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error if sessionID == "" || messageID == "" { return ToolResponse{}, fmt.Errorf("session_id and message_id are required") } - diff, stats, err := git.GenerateGitDiffWithStats( - removeWorkingDirectoryPrefix(filePath), + + diff, additions, removals := diff.GenerateDiff( oldContent, params.Content, + filePath, + filePath, ) - if err != nil { - return ToolResponse{}, fmt.Errorf("error generating diff: %w", err) - } p := w.permissions.Request( permission.CreatePermissionRequest{ Path: filePath, @@ -188,8 +187,8 @@ func (w *writeTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error return WithResponseMetadata(NewTextResponse(result), WriteResponseMetadata{ Diff: diff, - Additions: stats.Additions, - Removals: stats.Removals, + Additions: additions, + Removals: removals, }, ), nil } diff --git a/internal/tui/components/dialog/permission.go b/internal/tui/components/dialog/permission.go index 344310eb6563688cf83e989eecbf7e24ec0bbd78..d147f89cd0b40095336f989d3d42f9a107822406 100644 --- a/internal/tui/components/dialog/permission.go +++ b/internal/tui/components/dialog/permission.go @@ -9,7 +9,7 @@ import ( tea "github.com/charmbracelet/bubbletea" "github.com/charmbracelet/glamour" "github.com/charmbracelet/lipgloss" - "github.com/kujtimiihoxha/termai/internal/git" + "github.com/kujtimiihoxha/termai/internal/diff" "github.com/kujtimiihoxha/termai/internal/llm/tools" "github.com/kujtimiihoxha/termai/internal/permission" "github.com/kujtimiihoxha/termai/internal/tui/components/core" @@ -242,7 +242,7 @@ func (p *permissionDialogCmp) render() string { // Calculate content height dynamically based on window size maxContentHeight := p.height - lipgloss.Height(headerContent) - lipgloss.Height(form) - 2 - 2 - 1 p.contentViewPort.Height = maxContentHeight - diff, err := git.FormatDiff(pr.Diff, p.contentViewPort.Width) + diff, err := diff.FormatDiff(pr.Diff, diff.WithTotalWidth(p.contentViewPort.Width)) if err != nil { diff = fmt.Sprintf("Error formatting diff: %v", err) } @@ -291,7 +291,7 @@ func (p *permissionDialogCmp) render() string { // Calculate content height dynamically based on window size maxContentHeight := p.height - lipgloss.Height(headerContent) - lipgloss.Height(form) - 2 - 2 - 1 p.contentViewPort.Height = maxContentHeight - diff, err := git.FormatDiff(pr.Diff, p.contentViewPort.Width) + diff, err := diff.FormatDiff(pr.Diff, diff.WithTotalWidth(p.contentViewPort.Width)) if err != nil { diff = fmt.Sprintf("Error formatting diff: %v", err) } From 013694832f4c5a7819bfd9a801346e4c3fb22e77 Mon Sep 17 00:00:00 2001 From: Kujtim Hoxha Date: Mon, 14 Apr 2025 15:48:01 +0200 Subject: [PATCH 15/41] fix diff --- go.mod | 20 +- go.sum | 57 +++++- internal/diff/diff.go | 362 +++++++++--------------------------- internal/llm/tools/edit.go | 3 - internal/llm/tools/write.go | 1 - 5 files changed, 155 insertions(+), 288 deletions(-) diff --git a/go.mod b/go.mod index b201be8005cc1893589e1fc022f944a70094f48c..925a71097a6d55b7cbbc78c16151d0bbb42a4a32 100644 --- a/go.mod +++ b/go.mod @@ -18,6 +18,7 @@ require ( github.com/charmbracelet/lipgloss v1.1.0 github.com/charmbracelet/x/ansi v0.8.0 github.com/fsnotify/fsnotify v1.8.0 + github.com/go-git/go-git/v5 v5.15.0 github.com/go-logfmt/logfmt v0.6.0 github.com/golang-migrate/migrate/v4 v4.18.2 github.com/google/generative-ai-go v0.19.0 @@ -46,6 +47,9 @@ require ( cloud.google.com/go/auth/oauth2adapt v0.2.6 // indirect cloud.google.com/go/compute/metadata v0.6.0 // indirect cloud.google.com/go/longrunning v0.5.7 // indirect + dario.cat/mergo v1.0.0 // indirect + github.com/Microsoft/go-winio v0.6.2 // indirect + github.com/ProtonMail/go-crypto v1.1.6 // indirect github.com/andybalholm/cascadia v1.3.2 // indirect github.com/atotto/clipboard v0.1.4 // indirect github.com/aws/aws-sdk-go-v2 v1.30.3 // indirect @@ -68,15 +72,20 @@ require ( github.com/charmbracelet/x/cellbuf v0.0.13-0.20250311204145-2c3ea96c31dd // indirect github.com/charmbracelet/x/exp/strings v0.0.0-20240722160745-212f7b056ed0 // indirect github.com/charmbracelet/x/term v0.2.1 // indirect + github.com/cloudflare/circl v1.6.1 // indirect + github.com/cyphar/filepath-securejoin v0.4.1 // indirect github.com/davecgh/go-spew v1.1.1 // indirect github.com/dlclark/regexp2 v1.11.4 // indirect github.com/dustin/go-humanize v1.0.1 // indirect + github.com/emirpasic/gods v1.18.1 // indirect github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f // indirect github.com/felixge/httpsnoop v1.0.4 // indirect + github.com/go-git/gcfg v1.5.1-0.20230307220236-3a3c6141e376 // indirect + github.com/go-git/go-billy/v5 v5.6.2 // indirect github.com/go-logr/logr v1.4.2 // indirect github.com/go-logr/stdr v1.2.2 // indirect github.com/go-viper/mapstructure/v2 v2.2.1 // indirect - github.com/google/go-cmp v0.7.0 // indirect + github.com/golang/groupcache v0.0.0-20241129210726-2c02b8208cf8 // indirect github.com/google/s2a-go v0.1.8 // indirect github.com/googleapis/enterprise-certificate-proxy v0.3.4 // indirect github.com/googleapis/gax-go/v2 v2.14.1 // indirect @@ -84,6 +93,8 @@ require ( github.com/hashicorp/errwrap v1.1.0 // indirect github.com/hashicorp/go-multierror v1.1.1 // indirect github.com/inconshreveable/mousetrap v1.1.0 // indirect + github.com/jbenet/go-context v0.0.0-20150711004518-d14ea06fba99 // indirect + github.com/kevinburke/ssh_config v1.2.0 // indirect github.com/lucasb-eyer/go-colorful v1.2.0 // indirect github.com/mattn/go-isatty v0.0.20 // indirect github.com/mattn/go-localereader v0.0.1 // indirect @@ -91,11 +102,12 @@ require ( github.com/mitchellh/hashstructure/v2 v2.0.2 // indirect github.com/muesli/cancelreader v0.2.2 // indirect github.com/pelletier/go-toml/v2 v2.2.3 // indirect + github.com/pjbgf/sha1cd v0.3.2 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect github.com/rivo/uniseg v0.4.7 // indirect - github.com/rogpeppe/go-internal v1.14.1 // indirect github.com/sagikazarmark/locafero v0.7.0 // indirect github.com/sahilm/fuzzy v0.1.1 // indirect + github.com/skeema/knownhosts v1.3.1 // indirect github.com/sourcegraph/conc v0.3.0 // indirect github.com/spf13/afero v1.12.0 // indirect github.com/spf13/cast v1.7.1 // indirect @@ -105,6 +117,7 @@ require ( github.com/tidwall/match v1.1.1 // indirect github.com/tidwall/pretty v1.2.1 // indirect github.com/tidwall/sjson v1.2.5 // indirect + github.com/xanzy/ssh-agent v0.3.3 // indirect github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e // indirect github.com/yosida95/uritemplate/v3 v3.0.2 // indirect github.com/yuin/goldmark v1.7.8 // indirect @@ -118,7 +131,6 @@ require ( go.uber.org/multierr v1.9.0 // indirect golang.design/x/clipboard v0.7.0 // indirect golang.org/x/crypto v0.37.0 // indirect - golang.org/x/exp v0.0.0-20240719175910-8a7402abbf56 // indirect golang.org/x/exp/shiny v0.0.0-20250305212735-054e65f0b394 // indirect golang.org/x/image v0.14.0 // indirect golang.org/x/mobile v0.0.0-20231127183840-76ac6878050a // indirect @@ -132,6 +144,6 @@ require ( google.golang.org/genproto/googleapis/rpc v0.0.0-20241223144023-3abc09e42ca8 // indirect google.golang.org/grpc v1.67.3 // indirect google.golang.org/protobuf v1.36.1 // indirect - gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c // indirect + gopkg.in/warnings.v0 v0.1.2 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/go.sum b/go.sum index 08e7e7c42e61f69fd454df0683507bb91ddd3cd9..9c2c2df8fbcb0d909a73a67ead79339bf3814892 100644 --- a/go.sum +++ b/go.sum @@ -10,10 +10,17 @@ cloud.google.com/go/compute/metadata v0.6.0 h1:A6hENjEsCDtC1k8byVsgwvVcioamEHvZ4 cloud.google.com/go/compute/metadata v0.6.0/go.mod h1:FjyFAW1MW0C203CEOMDTu3Dk1FlqW3Rga40jzHL4hfg= cloud.google.com/go/longrunning v0.5.7 h1:WLbHekDbjK1fVFD3ibpFFVoyizlLRl73I7YKuAKilhU= cloud.google.com/go/longrunning v0.5.7/go.mod h1:8GClkudohy1Fxm3owmBGid8W0pSgodEMwEAztp38Xng= +dario.cat/mergo v1.0.0 h1:AGCNq9Evsj31mOgNPcLyXc+4PNABt905YmuqPYYpBWk= +dario.cat/mergo v1.0.0/go.mod h1:uNxQE+84aUszobStD9th8a29P2fMDhsBdgRYvZOxGmk= github.com/JohannesKaufmann/html-to-markdown v1.6.0 h1:04VXMiE50YYfCfLboJCLcgqF5x+rHJnb1ssNmqpLH/k= github.com/JohannesKaufmann/html-to-markdown v1.6.0/go.mod h1:NUI78lGg/a7vpEJTz/0uOcYMaibytE4BUOQS8k78yPQ= github.com/MakeNowJust/heredoc v1.0.0 h1:cXCdzVdstXyiTqTvfqk9SDHpKNjxuom+DOlyEeQ4pzQ= github.com/MakeNowJust/heredoc v1.0.0/go.mod h1:mG5amYoWBHf8vpLOuehzbGGw0EHxpZZ6lCpQ4fNJ8LE= +github.com/Microsoft/go-winio v0.5.2/go.mod h1:WpS1mjBmmwHBEWmogvA2mj8546UReBk4v8QkMxJ6pZY= +github.com/Microsoft/go-winio v0.6.2 h1:F2VQgta7ecxGYO8k3ZZz3RS8fVIXVxONVUPlNERoyfY= +github.com/Microsoft/go-winio v0.6.2/go.mod h1:yd8OoFMLzJbo9gZq8j5qaps8bJ9aShtEA8Ipt1oGCvU= +github.com/ProtonMail/go-crypto v1.1.6 h1:ZcV+Ropw6Qn0AX9brlQLAUXfqLBc7Bl+f/DmNxpLfdw= +github.com/ProtonMail/go-crypto v1.1.6/go.mod h1:rA3QumHc/FZ8pAHreoekgiAbzpNsfQAosU5td4SnOrE= github.com/PuerkitoBio/goquery v1.9.2 h1:4/wZksC3KgkQw7SQgkKotmKljk0M6V8TUvA8Wb4yPeE= github.com/PuerkitoBio/goquery v1.9.2/go.mod h1:GHPCaP0ODyyxqcNoFGYlAprUFH81NuRPd0GX3Zu2Mvk= github.com/alecthomas/assert/v2 v2.11.0 h1:2Q9r3ki8+JYXvGsDyBXwH3LcJ+WK5D0gc5E8vS6K3D0= @@ -24,8 +31,12 @@ github.com/alecthomas/repr v0.4.0 h1:GhI2A8MACjfegCPVq9f1FLvIBS+DrQ2KQBFZP1iFzXc github.com/alecthomas/repr v0.4.0/go.mod h1:Fr0507jx4eOXV7AlPV6AVZLYrLIuIeSOWtW57eE/O/4= github.com/andybalholm/cascadia v1.3.2 h1:3Xi6Dw5lHF15JtdcmAHD3i1+T8plmv7BQ/nsViSLyss= github.com/andybalholm/cascadia v1.3.2/go.mod h1:7gtRlve5FxPPgIgX36uWBX58OdBsSS6lUvCFb+h7KvU= +github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be h1:9AeTilPcZAjCFIImctFaOjnTIavg87rW78vTPkQqLI8= +github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be/go.mod h1:ySMOLuWl6zY27l47sB3qLNK6tF2fkHG55UZxx8oIVo4= github.com/anthropics/anthropic-sdk-go v0.2.0-beta.2 h1:h7qxtumNjKPWFv1QM/HJy60MteeW23iKeEtBoY7bYZk= github.com/anthropics/anthropic-sdk-go v0.2.0-beta.2/go.mod h1:AapDW22irxK2PSumZiQXYUFvsdQgkwIWlpESweWZI/c= +github.com/armon/go-socks5 v0.0.0-20160902184237-e75332964ef5 h1:0CwZNZbxp69SHPdPJAN/hZIm0C4OItdklCFmMRWYpio= +github.com/armon/go-socks5 v0.0.0-20160902184237-e75332964ef5/go.mod h1:wHh0iHkYZB8zMSxRWpUBQtwG5a7fFgvEO+odwuTv2gs= github.com/atotto/clipboard v0.1.4 h1:EH0zSVneZPSuFR11BlR9YppQTVDbh5+16AmcJi4g1z4= github.com/atotto/clipboard v0.1.4/go.mod h1:ZY9tmq7sm5xIbd9bOK4onWV4S6X0u6GY7Vn0Yu86PYI= github.com/aws/aws-sdk-go-v2 v1.30.3 h1:jUeBtG0Ih+ZIFH0F4UkmL9w3cSpaMv9tYYDbzILP8dY= @@ -88,7 +99,11 @@ github.com/charmbracelet/x/exp/strings v0.0.0-20240722160745-212f7b056ed0 h1:qko github.com/charmbracelet/x/exp/strings v0.0.0-20240722160745-212f7b056ed0/go.mod h1:pBhA0ybfXv6hDjQUZ7hk1lVxBiUbupdw5R31yPUViVQ= github.com/charmbracelet/x/term v0.2.1 h1:AQeHeLZ1OqSXhrAWpYUtZyX1T3zVxfpZuEQMIQaGIAQ= github.com/charmbracelet/x/term v0.2.1/go.mod h1:oQ4enTYFV7QN4m0i9mzHrViD7TQKvNEEkHUMCmsxdUg= +github.com/cloudflare/circl v1.6.1 h1:zqIqSPIndyBh1bjLVVDHMPpVKqp8Su/V+6MeDzzQBQ0= +github.com/cloudflare/circl v1.6.1/go.mod h1:uddAzsPgqdMAYatqJ0lsjX1oECcQLIlRpzZh3pJrofs= github.com/cpuguy83/go-md2man/v2 v2.0.6/go.mod h1:oOW0eioCTA6cOiMLiUPZOpcVxMig6NIQQ7OS05n1F4g= +github.com/cyphar/filepath-securejoin v0.4.1 h1:JyxxyPEaktOD+GAnqIqTf9A8tHyAG22rowi7HkoSU1s= +github.com/cyphar/filepath-securejoin v0.4.1/go.mod h1:Sdj7gXlvMcPZsbhwhQ33GguGLDGQL7h7bg04C/+u9jI= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= @@ -96,6 +111,10 @@ github.com/dlclark/regexp2 v1.11.4 h1:rPYF9/LECdNymJufQKmri9gV604RvvABwgOA8un7yA github.com/dlclark/regexp2 v1.11.4/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8= github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY= github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto= +github.com/elazarl/goproxy v1.7.2 h1:Y2o6urb7Eule09PjlhQRGNsqRfPmYI3KKQLFpCAV3+o= +github.com/elazarl/goproxy v1.7.2/go.mod h1:82vkLNir0ALaW14Rc399OTTjyNREgmdL2cVoIbS6XaE= +github.com/emirpasic/gods v1.18.1 h1:FXtiHYKDGKCW2KzwZKx0iC0PQmdlorYgdFG9jPXJ1Bc= +github.com/emirpasic/gods v1.18.1/go.mod h1:8tpGGwCnJ5H4r6BWwaV6OrWmMoPhUl5jm/FMNAnJvWQ= github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f h1:Y/CXytFA4m6baUTXGLOoWe4PQhGxaX0KpnayAqC48p4= github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f/go.mod h1:vw97MGsxSvLiUE2X8qFplwetxpGLQrlU1Q9AUEIzCaM= github.com/felixge/httpsnoop v1.0.4 h1:NFTV2Zj1bL4mc9sqWACXbQFVBBg2W3GPvqp8/ESS2Wg= @@ -104,6 +123,16 @@ github.com/frankban/quicktest v1.14.6 h1:7Xjx+VpznH+oBnejlPUj8oUpdxnVs4f8XU8WnHk github.com/frankban/quicktest v1.14.6/go.mod h1:4ptaffx2x8+WTWXmUCuVU6aPUX1/Mz7zb5vbUoiM6w0= github.com/fsnotify/fsnotify v1.8.0 h1:dAwr6QBTBZIkG8roQaJjGof0pp0EeF+tNV7YBP3F/8M= github.com/fsnotify/fsnotify v1.8.0/go.mod h1:8jBTzvmWwFyi3Pb8djgCCO5IBqzKJ/Jwo8TRcHyHii0= +github.com/gliderlabs/ssh v0.3.8 h1:a4YXD1V7xMF9g5nTkdfnja3Sxy1PVDCj1Zg4Wb8vY6c= +github.com/gliderlabs/ssh v0.3.8/go.mod h1:xYoytBv1sV0aL3CavoDuJIQNURXkkfPA/wxQ1pL1fAU= +github.com/go-git/gcfg v1.5.1-0.20230307220236-3a3c6141e376 h1:+zs/tPmkDkHx3U66DAb0lQFJrpS6731Oaa12ikc+DiI= +github.com/go-git/gcfg v1.5.1-0.20230307220236-3a3c6141e376/go.mod h1:an3vInlBmSxCcxctByoQdvwPiA7DTK7jaaFDBTtu0ic= +github.com/go-git/go-billy/v5 v5.6.2 h1:6Q86EsPXMa7c3YZ3aLAQsMA0VlWmy43r6FHqa/UNbRM= +github.com/go-git/go-billy/v5 v5.6.2/go.mod h1:rcFC2rAsp/erv7CMz9GczHcuD0D32fWzH+MJAU+jaUU= +github.com/go-git/go-git-fixtures/v4 v4.3.2-0.20231010084843-55a94097c399 h1:eMje31YglSBqCdIqdhKBW8lokaMrL3uTkpGYlE2OOT4= +github.com/go-git/go-git-fixtures/v4 v4.3.2-0.20231010084843-55a94097c399/go.mod h1:1OCfN199q1Jm3HZlxleg+Dw/mwps2Wbk9frAWm+4FII= +github.com/go-git/go-git/v5 v5.15.0 h1:f5Qn0W0F7ry1iN0ZwIU5m/n7/BKB4hiZfc+zlZx7ly0= +github.com/go-git/go-git/v5 v5.15.0/go.mod h1:4Ge4alE/5gPs30F2H1esi2gPd69R0C39lolkucHBOp8= github.com/go-logfmt/logfmt v0.6.0 h1:wGYYu3uicYdqXVgoYbvnkrPVXkuLM1p1ifugDMEdRi4= github.com/go-logfmt/logfmt v0.6.0/go.mod h1:WYhtIu8zTZfxdn5+rREduYbwxfcBr/Vr6KEVveWlfTs= github.com/go-logr/logr v1.2.2/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A= @@ -115,6 +144,8 @@ github.com/go-viper/mapstructure/v2 v2.2.1 h1:ZAaOCxANMuZx5RCeg0mBdEZk7DZasvvZIx github.com/go-viper/mapstructure/v2 v2.2.1/go.mod h1:oJDH3BJKyqBA2TXFhDsKDGDTlndYOZ6rGS0BRZIxGhM= github.com/golang-migrate/migrate/v4 v4.18.2 h1:2VSCMz7x7mjyTXx3m2zPokOY82LTRgxK1yQYKo6wWQ8= github.com/golang-migrate/migrate/v4 v4.18.2/go.mod h1:2CM6tJvn2kqPXwnXO/d3rAQYiyoIm180VsO8PRX6Rpk= +github.com/golang/groupcache v0.0.0-20241129210726-2c02b8208cf8 h1:f+oWsMOmNPc8JmEHVZIycC7hBoQxHH9pNKQORJNozsQ= +github.com/golang/groupcache v0.0.0-20241129210726-2c02b8208cf8/go.mod h1:wcDNUvekVysuuOpQKo3191zZyTpiI6se1N1ULghS0sw= github.com/google/generative-ai-go v0.19.0 h1:R71szggh8wHMCUlEMsW2A/3T+5LdEIkiaHSYgSpUgdg= github.com/google/generative-ai-go v0.19.0/go.mod h1:JYolL13VG7j79kM5BtHz4qwONHkeJQzOCkKXnpqtS/E= github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= @@ -138,8 +169,11 @@ github.com/hexops/gotextdiff v1.0.3 h1:gitA9+qJrrTCsiCl7+kh75nPqQt1cx4ZkudSTLoUq github.com/hexops/gotextdiff v1.0.3/go.mod h1:pSWU5MAI3yDq+fZBTazCSJysOMbxWL1BSow5/V2vxeg= github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8= github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw= +github.com/jbenet/go-context v0.0.0-20150711004518-d14ea06fba99 h1:BQSFePA1RWJOlocH6Fxy8MmwDt+yVQYULKfN0RoTN8A= +github.com/jbenet/go-context v0.0.0-20150711004518-d14ea06fba99/go.mod h1:1lJo3i6rXxKeerYnT8Nvf0QmHCRC1n8sfWVwXF2Frvo= +github.com/kevinburke/ssh_config v1.2.0 h1:x584FjTGwHzMwvHx18PXxbBVzfnxogHaAReU4gf13a4= +github.com/kevinburke/ssh_config v1.2.0/go.mod h1:CT57kijsi8u/K/BOFA39wgDQJ9CxiF4nAY/ojJ6r6mM= github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= -github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI= github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= @@ -179,11 +213,17 @@ github.com/muesli/reflow v0.3.0 h1:IFsN6K9NfGtjeggFP+68I4chLZV2yIKsXJFNZ+eWh6s= github.com/muesli/reflow v0.3.0/go.mod h1:pbwTDkVPibjO2kyvBQRBxTWEEGDGq0FlB1BIKtnHY/8= github.com/muesli/termenv v0.16.0 h1:S5AlUN9dENB57rsbnkPyfdGuWIlkmzJjbFf0Tf5FWUc= github.com/muesli/termenv v0.16.0/go.mod h1:ZRfOIKPFDYQoDFF4Olj7/QJbW60Ol/kL1pU3VfY/Cnk= +github.com/onsi/gomega v1.34.1 h1:EUMJIKUjM8sKjYbtxQI9A4z2o+rruxnzNvpknOXie6k= +github.com/onsi/gomega v1.34.1/go.mod h1:kU1QgUvBDLXBJq618Xvm2LUX6rSAfRaFRTcdOeDLwwY= github.com/openai/openai-go v0.1.0-beta.2 h1:Ra5nCFkbEl9w+UJwAciC4kqnIBUCcJazhmMA0/YN894= github.com/openai/openai-go v0.1.0-beta.2/go.mod h1:g461MYGXEXBVdV5SaR/5tNzNbSfwTBBefwc+LlDCK0Y= github.com/pelletier/go-toml/v2 v2.2.3 h1:YmeHyLY8mFWbdkNWwpr+qIL2bEqT0o95WSdkNHvL12M= github.com/pelletier/go-toml/v2 v2.2.3/go.mod h1:MfCQTFTvCcUyyvvwm1+G6H/jORL20Xlb6rzQu9GuUkc= +github.com/pjbgf/sha1cd v0.3.2 h1:a9wb0bp1oC2TGwStyn0Umc/IGKQnEgF0vVaZ8QF8eo4= +github.com/pjbgf/sha1cd v0.3.2/go.mod h1:zQWigSxVmsHEZow5qaLtPYxpcKMMQpa09ixqBxuCS6A= github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= +github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= +github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/rivo/uniseg v0.1.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc= @@ -203,6 +243,9 @@ github.com/sergi/go-diff v1.0.0/go.mod h1:0CfEIISq7TuYL3j771MWULgwwjU+GofnZX9QAm github.com/sergi/go-diff v1.3.1/go.mod h1:aMJSSKb2lpPvRNec0+w3fl7LP9IOFzdc9Pa4NFbPK1I= github.com/sergi/go-diff v1.3.2-0.20230802210424-5b0b94c5c0d3 h1:n661drycOFuPLCN3Uc8sB6B/s6Z4t2xvBgU1htSHuq8= github.com/sergi/go-diff v1.3.2-0.20230802210424-5b0b94c5c0d3/go.mod h1:A0bzQcvG0E7Rwjx0REVgAGH58e96+X0MeOfepqsbeW4= +github.com/sirupsen/logrus v1.7.0/go.mod h1:yWOB1SBYBC5VeMP7gHvWumXLIWorT60ONWic61uBYv0= +github.com/skeema/knownhosts v1.3.1 h1:X2osQ+RAjK76shCbvhHHHVl3ZlgDm8apHEHFqRjnBY8= +github.com/skeema/knownhosts v1.3.1/go.mod h1:r7KTdC8l4uxWRyK2TpQZ/1o5HaSzh06ePQNxPwTcfiY= github.com/sourcegraph/conc v0.3.0 h1:OQTbbt6P72L20UqAkXXuLOj79LfEanQ+YQFNpLA9ySo= github.com/sourcegraph/conc v0.3.0/go.mod h1:Sdozi7LEKbFPqYX2/J+iBAM6HpqSLTASQIKqDmF7Mt0= github.com/spf13/afero v1.12.0 h1:UcOPyRBYczmFn6yvphxkn9ZEOY65cpwGKb5mL36mrqs= @@ -216,6 +259,7 @@ github.com/spf13/pflag v1.0.6/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An github.com/spf13/viper v1.20.0 h1:zrxIyR3RQIOsarIrgL8+sAvALXul9jeEPa06Y0Ph6vY= github.com/spf13/viper v1.20.0/go.mod h1:P9Mdzt1zoHIG8m2eZQinpiBjo6kCmZSKBClNNqjJvu4= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA= @@ -232,6 +276,8 @@ github.com/tidwall/pretty v1.2.1 h1:qjsOFOWWQl+N3RsoF5/ssm1pHmJJwhjlSbZ51I6wMl4= github.com/tidwall/pretty v1.2.1/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU= github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY= github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28= +github.com/xanzy/ssh-agent v0.3.3 h1:+/15pJfg/RsTxqYcX6fHqOXZwwMP+2VyYWJeWM2qQFM= +github.com/xanzy/ssh-agent v0.3.3/go.mod h1:6dzNDKs0J9rVPHPhaGCukekBHKqfl+L3KghI1Bc68Uw= github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e h1:JVG44RsyaB9T2KIHavMF/ppJZNG9ZpyihvCd0w101no= github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e/go.mod h1:RbqR21r5mrJuqunuUZ/Dhy/avygyECGrLceyNeo4LiM= github.com/yosida95/uritemplate/v3 v3.0.2 h1:Ed3Oyj9yrmi9087+NczuL5BwkIc4wvTb5zIM+UJPGz4= @@ -260,6 +306,7 @@ golang.design/x/clipboard v0.7.0 h1:4Je8M/ys9AJumVnl8m+rZnIvstSnYj1fvzqYrU3TXvo= golang.design/x/clipboard v0.7.0/go.mod h1:PQIvqYO9GP29yINEfsEn5zSQKAz3UgXmZKzDA6dnq2E= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= +golang.org/x/crypto v0.0.0-20220622213112-05595931fe9d/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= golang.org/x/crypto v0.19.0/go.mod h1:Iy9bg/ha4yyC70EfRS8jz+B6ybOBKMaSxLj6P6oBDfU= golang.org/x/crypto v0.22.0/go.mod h1:vr6Su+7cTlO45qkww3VDJlzDn0ctJvRgYbC2NvXHt+M= golang.org/x/crypto v0.23.0/go.mod h1:CKFgDieR+mRhux2Lsu27y0fO304Db0wZe70UKqHu0v8= @@ -277,6 +324,7 @@ golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91 golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= +golang.org/x/net v0.0.0-20211112202133-69e39bad7dc2/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c= golang.org/x/net v0.6.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs= golang.org/x/net v0.9.0/go.mod h1:d48xBJpPfHeWQsugry2m+kC02ZBRGRgulfHnEXEuWns= @@ -294,10 +342,14 @@ golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.13.0 h1:AauUjRAJ9OSnvULf/ARrrVywoJDy0YS2AwQ98I37610= golang.org/x/sync v0.13.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20191026070338-33540a1f6037/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210124154548-22da62e12c0c/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20210809222454-d867a43fc93e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= @@ -320,6 +372,7 @@ golang.org/x/term v0.31.0 h1:erwDkOK1Msy6offm1mOgvspSkslFnIGsFnxOKoufg3o= golang.org/x/term v0.31.0/go.mod h1:R4BeIy7D95HzImkxGkTW1UQTtP54tio2RyHz7PwK0aw= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= +golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8= @@ -348,6 +401,8 @@ gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8 gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= +gopkg.in/warnings.v0 v0.1.2 h1:wFXVbFY8DY5/xOe1ECiWdKCzZlxgshcYVNkBHstARME= +gopkg.in/warnings.v0 v0.1.2/go.mod h1:jksf8JmL6Qr/oQM2OXTHunEvvTAsrWBLb6OOjuVWRNI= gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY= gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ= diff --git a/internal/diff/diff.go b/internal/diff/diff.go index 4e6aa9f5bcba9386a0f5c0e9c801254470950090..c4088d32933ec1580af59de500cce34bffbfcdd9 100644 --- a/internal/diff/diff.go +++ b/internal/diff/diff.go @@ -4,6 +4,8 @@ import ( "bytes" "fmt" "io" + "os" + "path/filepath" "regexp" "strconv" "strings" @@ -15,6 +17,8 @@ import ( "github.com/alecthomas/chroma/v2/styles" "github.com/charmbracelet/lipgloss" "github.com/charmbracelet/x/ansi" + "github.com/go-git/go-git/v5" + "github.com/go-git/go-git/v5/plumbing/object" "github.com/sergi/go-diff/diffmatchpatch" ) @@ -224,21 +228,6 @@ type ParseConfig struct { // ParseOption defines a function that modifies a ParseConfig. type ParseOption func(*ParseConfig) -// NewParseConfig creates a ParseConfig with default values and applies any provided options. -func NewParseConfig(opts ...ParseOption) ParseConfig { - // Set default values - config := ParseConfig{ - ContextSize: 3, - } - - // Apply all provided options - for _, opt := range opts { - opt(&config) - } - - return config -} - // WithContextSize sets the number of context lines to include. func WithContextSize(size int) ParseOption { return func(p *ParseConfig) { @@ -347,6 +336,10 @@ func ParseUnifiedDiff(diff string) (DiffResult, error) { continue } + // ignore the \\ No newline at end of file + if strings.HasPrefix(line, "\\ No newline at end of file") { + continue + } if currentHunk == nil { continue } @@ -450,32 +443,22 @@ func colorizeSegments(diffs []diffmatchpatch.Diff, isOld bool, style StyleConfig removedLineStyle := lipgloss.NewStyle().Background(style.RemovedLineBg) addedLineStyle := lipgloss.NewStyle().Background(style.AddedLineBg) - afterBg := false - for _, d := range diffs { switch d.Type { case diffmatchpatch.DiffEqual: // Handle text that's the same in both versions - if afterBg { - if isOld { - buf.WriteString(removedLineStyle.Render(d.Text)) - } else { - buf.WriteString(addedLineStyle.Render(d.Text)) - } - } else { - buf.WriteString(d.Text) - } + buf.WriteString(d.Text) case diffmatchpatch.DiffDelete: // Handle deleted text (only show in old version) if isOld { buf.WriteString(removeBg.Render(d.Text)) - afterBg = true + buf.WriteString(removedLineStyle.Render("")) } case diffmatchpatch.DiffInsert: // Handle inserted text (only show in new version) if !isOld { buf.WriteString(addBg.Render(d.Text)) - afterBg = true + buf.WriteString(addedLineStyle.Render("")) } } } @@ -621,7 +604,13 @@ func renderLeftColumn(fileName string, dl *DiffLine, colWidth int, styles StyleC } lineText := prefix + content - return bgStyle.MaxHeight(1).Width(colWidth).Render(ansi.Truncate(lineText, colWidth, "...")) + return bgStyle.MaxHeight(1).Width(colWidth).Render( + ansi.Truncate( + lineText, + colWidth, + lipgloss.NewStyle().Background(styles.HunkLineBg).Foreground(styles.HunkLineFg).Render("..."), + ), + ) } // renderRightColumn formats the right side of a side-by-side diff. @@ -662,7 +651,13 @@ func renderRightColumn(fileName string, dl *DiffLine, colWidth int, styles Style } lineText := prefix + content - return bgStyle.MaxHeight(1).Width(colWidth).Render(ansi.Truncate(lineText, colWidth, "...")) + return bgStyle.MaxHeight(1).Width(colWidth).Render( + ansi.Truncate( + lineText, + colWidth, + lipgloss.NewStyle().Background(styles.HunkLineBg).Foreground(styles.HunkLineFg).Render("..."), + ), + ) } // ------------------------------------------------------------------------- @@ -718,278 +713,87 @@ func FormatDiff(diffText string, opts ...SideBySideOption) (string, error) { } // GenerateDiff creates a unified diff from two file contents. -func GenerateDiff(beforeContent, afterContent, beforeFilename, afterFilename string, opts ...ParseOption) (string, int, int) { - config := NewParseConfig(opts...) - - var output strings.Builder - - // Ensure we handle newlines correctly - beforeHasNewline := len(beforeContent) > 0 && beforeContent[len(beforeContent)-1] == '\n' - afterHasNewline := len(afterContent) > 0 && afterContent[len(afterContent)-1] == '\n' - - // Split into lines - beforeLines := strings.Split(beforeContent, "\n") - afterLines := strings.Split(afterContent, "\n") - - // Remove empty trailing element from the split if the content ended with a newline - if beforeHasNewline && len(beforeLines) > 0 { - beforeLines = beforeLines[:len(beforeLines)-1] - } - if afterHasNewline && len(afterLines) > 0 { - afterLines = afterLines[:len(afterLines)-1] +func GenerateDiff(beforeContent, afterContent, fileName string) (string, int, int) { + tempDir, err := os.MkdirTemp("", "git-diff-temp") + if err != nil { + return "", 0, 0 } + defer os.RemoveAll(tempDir) - dmp := diffmatchpatch.New() - dmp.DiffTimeout = 5 * time.Second + repo, err := git.PlainInit(tempDir, false) + if err != nil { + return "", 0, 0 + } - // Convert lines to characters for efficient diffing - lineArray1, lineArray2, lineArrays := dmp.DiffLinesToChars(beforeContent, afterContent) - diffs := dmp.DiffMain(lineArray1, lineArray2, false) - diffs = dmp.DiffCharsToLines(diffs, lineArrays) + wt, err := repo.Worktree() + if err != nil { + return "", 0, 0 + } - // Default filenames if not provided - if beforeFilename == "" { - beforeFilename = "a" + fullPath := filepath.Join(tempDir, fileName) + if err = os.MkdirAll(filepath.Dir(fullPath), 0o755); err != nil { + return "", 0, 0 } - if afterFilename == "" { - afterFilename = "b" + if err = os.WriteFile(fullPath, []byte(beforeContent), 0o644); err != nil { + return "", 0, 0 } - // Write diff header - output.WriteString(fmt.Sprintf("diff --git a/%s b/%s\n", beforeFilename, afterFilename)) - output.WriteString(fmt.Sprintf("--- a/%s\n", beforeFilename)) - output.WriteString(fmt.Sprintf("+++ b/%s\n", afterFilename)) - - line1 := 0 // Line numbers start from 0 internally - line2 := 0 - additions := 0 - deletions := 0 - - var hunks []string - var currentHunk strings.Builder - var hunkStartLine1, hunkStartLine2 int - var hunkLines1, hunkLines2 int - inHunk := false - - contextSize := config.ContextSize - - // startHunk begins recording a new hunk - startHunk := func(startLine1, startLine2 int) { - inHunk = true - hunkStartLine1 = startLine1 - hunkStartLine2 = startLine2 - hunkLines1 = 0 - hunkLines2 = 0 - currentHunk.Reset() - } - - // writeHunk adds the current hunk to the hunks slice - writeHunk := func() { - if inHunk { - hunkHeader := fmt.Sprintf("@@ -%d,%d +%d,%d @@\n", - hunkStartLine1+1, hunkLines1, - hunkStartLine2+1, hunkLines2) - hunks = append(hunks, hunkHeader+currentHunk.String()) - inHunk = false - } + _, err = wt.Add(fileName) + if err != nil { + return "", 0, 0 } - // Process diffs to create hunks - pendingContext := make([]string, 0, contextSize*2) - var contextLines1, contextLines2 int - - // Helper function to add context lines to the hunk - addContextToHunk := func(lines []string, count int) { - for i := 0; i < count; i++ { - if i < len(lines) { - currentHunk.WriteString(" " + lines[i] + "\n") - hunkLines1++ - hunkLines2++ - } - } + beforeCommit, err := wt.Commit("Before", &git.CommitOptions{ + Author: &object.Signature{ + Name: "OpenCode", + Email: "coder@opencode.ai", + When: time.Now(), + }, + }) + if err != nil { + return "", 0, 0 } - // Process diffs - for _, diff := range diffs { - lines := strings.Split(diff.Text, "\n") - - // Remove empty trailing line that comes from splitting a string that ends with \n - if len(lines) > 0 && lines[len(lines)-1] == "" && diff.Text[len(diff.Text)-1] == '\n' { - lines = lines[:len(lines)-1] - } - - switch diff.Type { - case diffmatchpatch.DiffEqual: - // If we have enough equal lines to serve as context, add them to pending - pendingContext = append(pendingContext, lines...) - - // If pending context grows too large, trim it - if len(pendingContext) > contextSize*2 { - pendingContext = pendingContext[len(pendingContext)-contextSize*2:] - } - - // If we're in a hunk, add the necessary context - if inHunk { - // Only add the first contextSize lines as trailing context - numContextLines := min(contextSize, len(lines)) - addContextToHunk(lines[:numContextLines], numContextLines) - - // If we've added enough trailing context, close the hunk - if numContextLines >= contextSize { - writeHunk() - } - } - - line1 += len(lines) - line2 += len(lines) - contextLines1 += len(lines) - contextLines2 += len(lines) - - case diffmatchpatch.DiffDelete, diffmatchpatch.DiffInsert: - // Start a new hunk if needed - if !inHunk { - // Determine how many context lines we can add before - contextBefore := min(contextSize, len(pendingContext)) - ctxStartIdx := len(pendingContext) - contextBefore - - // Calculate the correct start lines - startLine1 := line1 - contextLines1 + ctxStartIdx - startLine2 := line2 - contextLines2 + ctxStartIdx - - startHunk(startLine1, startLine2) - - // Add the context lines before - addContextToHunk(pendingContext[ctxStartIdx:], contextBefore) - } - - // Reset context tracking when we see a diff - pendingContext = pendingContext[:0] - contextLines1 = 0 - contextLines2 = 0 - - // Add the changes - if diff.Type == diffmatchpatch.DiffDelete { - for _, line := range lines { - currentHunk.WriteString("-" + line + "\n") - hunkLines1++ - deletions++ - } - line1 += len(lines) - } else { // DiffInsert - for _, line := range lines { - currentHunk.WriteString("+" + line + "\n") - hunkLines2++ - additions++ - } - line2 += len(lines) - } - } + if err = os.WriteFile(fullPath, []byte(afterContent), 0o644); err != nil { } - // Write the final hunk if there's one pending - if inHunk { - writeHunk() + _, err = wt.Add(fileName) + if err != nil { + return "", 0, 0 } - // Merge hunks that are close to each other (within 2*contextSize lines) - var mergedHunks []string - if len(hunks) > 0 { - mergedHunks = append(mergedHunks, hunks[0]) - - for i := 1; i < len(hunks); i++ { - prevHunk := mergedHunks[len(mergedHunks)-1] - currHunk := hunks[i] - - // Extract line numbers to check proximity - var prevStart, prevLen, currStart, currLen int - fmt.Sscanf(prevHunk, "@@ -%d,%d", &prevStart, &prevLen) - fmt.Sscanf(currHunk, "@@ -%d,%d", &currStart, &currLen) - - prevEnd := prevStart + prevLen - 1 - - // If hunks are close, merge them - if currStart-prevEnd <= contextSize*2 { - // Create a merged hunk - this is a simplification, real git has more complex merging logic - merged := mergeHunks(prevHunk, currHunk) - mergedHunks[len(mergedHunks)-1] = merged - } else { - mergedHunks = append(mergedHunks, currHunk) - } - } + afterCommit, err := wt.Commit("After", &git.CommitOptions{ + Author: &object.Signature{ + Name: "OpenCode", + Email: "coder@opencode.ai", + When: time.Now(), + }, + }) + if err != nil { + return "", 0, 0 } - // Write all hunks to output - for _, hunk := range mergedHunks { - output.WriteString(hunk) + beforeCommitObj, err := repo.CommitObject(beforeCommit) + if err != nil { + return "", 0, 0 } - // Handle "No newline at end of file" notifications - if !beforeHasNewline && len(beforeLines) > 0 { - // Find the last deletion in the diff and add the notification after it - lastPos := strings.LastIndex(output.String(), "\n-") - if lastPos != -1 { - // Insert the notification after the line - str := output.String() - output.Reset() - output.WriteString(str[:lastPos+1]) - output.WriteString("\\ No newline at end of file\n") - output.WriteString(str[lastPos+1:]) - } + afterCommitObj, err := repo.CommitObject(afterCommit) + if err != nil { + return "", 0, 0 } - if !afterHasNewline && len(afterLines) > 0 { - // Find the last insertion in the diff and add the notification after it - lastPos := strings.LastIndex(output.String(), "\n+") - if lastPos != -1 { - // Insert the notification after the line - str := output.String() - output.Reset() - output.WriteString(str[:lastPos+1]) - output.WriteString("\\ No newline at end of file\n") - output.WriteString(str[lastPos+1:]) - } + patch, err := beforeCommitObj.Patch(afterCommitObj) + if err != nil { + return "", 0, 0 } - // Return the diff without the summary line - return output.String(), additions, deletions -} - -// Helper function to merge two hunks -func mergeHunks(hunk1, hunk2 string) string { - // This is a simplified implementation - // A full implementation would need to properly recalculate the hunk header - // and remove redundant context lines - - // Extract header info from both hunks - var start1, len1, start2, len2 int - var startB1, lenB1, startB2, lenB2 int - - fmt.Sscanf(hunk1, "@@ -%d,%d +%d,%d @@", &start1, &len1, &startB1, &lenB1) - fmt.Sscanf(hunk2, "@@ -%d,%d +%d,%d @@", &start2, &len2, &startB2, &lenB2) - - // Split the hunks to get content - parts1 := strings.SplitN(hunk1, "\n", 2) - parts2 := strings.SplitN(hunk2, "\n", 2) - - content1 := "" - content2 := "" - - if len(parts1) > 1 { - content1 = parts1[1] - } - if len(parts2) > 1 { - content2 = parts2[1] + additions := 0 + removals := 0 + for _, fileStat := range patch.Stats() { + additions += fileStat.Addition + removals += fileStat.Deletion } - // Calculate the new header - newEnd := max(start1+len1-1, start2+len2-1) - newEndB := max(startB1+lenB1-1, startB2+lenB2-1) - - newLen := newEnd - start1 + 1 - newLenB := newEndB - startB1 + 1 - - newHeader := fmt.Sprintf("@@ -%d,%d +%d,%d @@", start1, newLen, startB1, newLenB) - - // Combine the content, potentially with some overlap handling - return newHeader + "\n" + content1 + content2 + return patch.String(), additions, removals } diff --git a/internal/llm/tools/edit.go b/internal/llm/tools/edit.go index 1305879b917816b37269c5b6619ac0c0edb178f0..08d6d446c9320b7da3e60040ba61d84022f3affb 100644 --- a/internal/llm/tools/edit.go +++ b/internal/llm/tools/edit.go @@ -186,7 +186,6 @@ func (e *editTool) createNewFile(ctx context.Context, filePath, content string) "", content, filePath, - filePath, ) p := e.permissions.Request( permission.CreatePermissionRequest{ @@ -277,7 +276,6 @@ func (e *editTool) deleteContent(ctx context.Context, filePath, oldString string oldContent, newContent, filePath, - filePath, ) p := e.permissions.Request( @@ -367,7 +365,6 @@ func (e *editTool) replaceContent(ctx context.Context, filePath, oldString, newS oldContent, newContent, filePath, - filePath, ) p := e.permissions.Request( permission.CreatePermissionRequest{ diff --git a/internal/llm/tools/write.go b/internal/llm/tools/write.go index ef2ca01f4a09f5efb065607365ac57ec5f25588b..889561d2af4d5e0f428deb7562535fd4647f2331 100644 --- a/internal/llm/tools/write.go +++ b/internal/llm/tools/write.go @@ -154,7 +154,6 @@ func (w *writeTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error oldContent, params.Content, filePath, - filePath, ) p := w.permissions.Request( permission.CreatePermissionRequest{ From f6be348bf704ab3d012eec549357f5acd9c74796 Mon Sep 17 00:00:00 2001 From: Kujtim Hoxha Date: Mon, 14 Apr 2025 20:12:28 +0200 Subject: [PATCH 16/41] fix segment diff and add new theme --- internal/diff/diff.go | 523 ++++++++++++++++++++++++++++-------------- 1 file changed, 348 insertions(+), 175 deletions(-) diff --git a/internal/diff/diff.go b/internal/diff/diff.go index c4088d32933ec1580af59de500cce34bffbfcdd9..02d4d7140c99d94b4aa0894316cdfb243b78b54b 100644 --- a/internal/diff/diff.go +++ b/internal/diff/diff.go @@ -22,89 +22,95 @@ import ( "github.com/sergi/go-diff/diffmatchpatch" ) +// ------------------------------------------------------------------------- +// Core Types +// ------------------------------------------------------------------------- + // LineType represents the kind of line in a diff. type LineType int const ( - // LineContext represents a line that exists in both the old and new file. - LineContext LineType = iota - // LineAdded represents a line added in the new file. - LineAdded - // LineRemoved represents a line removed from the old file. - LineRemoved + LineContext LineType = iota // Line exists in both files + LineAdded // Line added in the new file + LineRemoved // Line removed from the old file ) -// DiffLine represents a single line in a diff, either from the old file, -// the new file, or a context line. +// Segment represents a portion of a line for intra-line highlighting +type Segment struct { + Start int + End int + Type LineType + Text string +} + +// DiffLine represents a single line in a diff type DiffLine struct { - OldLineNo int // Line number in the old file (0 for added lines) - NewLineNo int // Line number in the new file (0 for removed lines) - Kind LineType // Type of line (added, removed, context) - Content string // Content of the line + OldLineNo int // Line number in old file (0 for added lines) + NewLineNo int // Line number in new file (0 for removed lines) + Kind LineType // Type of line (added, removed, context) + Content string // Content of the line + Segments []Segment // Segments for intraline highlighting } -// Hunk represents a section of changes in a diff. +// Hunk represents a section of changes in a diff type Hunk struct { Header string Lines []DiffLine } -// DiffResult contains the parsed result of a diff. +// DiffResult contains the parsed result of a diff type DiffResult struct { OldFile string NewFile string Hunks []Hunk } -// HunkDelta represents the change statistics for a hunk. -type HunkDelta struct { - StartLine1 int - LineCount1 int - StartLine2 int - LineCount2 int -} - -// linePair represents a pair of lines to be displayed side by side. +// linePair represents a pair of lines for side-by-side display type linePair struct { left *DiffLine right *DiffLine } // ------------------------------------------------------------------------- -// Style Configuration with Option Pattern +// Style Configuration // ------------------------------------------------------------------------- -// StyleConfig defines styling for diff rendering. +// StyleConfig defines styling for diff rendering type StyleConfig struct { + // Background colors RemovedLineBg lipgloss.Color AddedLineBg lipgloss.Color ContextLineBg lipgloss.Color HunkLineBg lipgloss.Color - HunkLineFg lipgloss.Color - RemovedFg lipgloss.Color - AddedFg lipgloss.Color - LineNumberFg lipgloss.Color - HighlightStyle string - RemovedHighlightBg lipgloss.Color - AddedHighlightBg lipgloss.Color RemovedLineNumberBg lipgloss.Color AddedLineNamerBg lipgloss.Color - RemovedHighlightFg lipgloss.Color - AddedHighlightFg lipgloss.Color + + // Foreground colors + HunkLineFg lipgloss.Color + RemovedFg lipgloss.Color + AddedFg lipgloss.Color + LineNumberFg lipgloss.Color + RemovedHighlightFg lipgloss.Color + AddedHighlightFg lipgloss.Color + + // Highlight settings + HighlightStyle string + RemovedHighlightBg lipgloss.Color + AddedHighlightBg lipgloss.Color } -// StyleOption defines a function that modifies a StyleConfig. +// StyleOption is a function that modifies a StyleConfig type StyleOption func(*StyleConfig) -// NewStyleConfig creates a StyleConfig with default values and applies any provided options. +// NewStyleConfig creates a StyleConfig with default values func NewStyleConfig(opts ...StyleOption) StyleConfig { - // Set default values + // Default color scheme config := StyleConfig{ RemovedLineBg: lipgloss.Color("#3A3030"), AddedLineBg: lipgloss.Color("#303A30"), ContextLineBg: lipgloss.Color("#212121"), - HunkLineBg: lipgloss.Color("#2A2822"), - HunkLineFg: lipgloss.Color("#D4AF37"), + HunkLineBg: lipgloss.Color("#23252D"), + HunkLineFg: lipgloss.Color("#8CA3B4"), RemovedFg: lipgloss.Color("#7C4444"), AddedFg: lipgloss.Color("#478247"), LineNumberFg: lipgloss.Color("#888888"), @@ -125,56 +131,35 @@ func NewStyleConfig(opts ...StyleOption) StyleConfig { return config } -// WithRemovedLineBg sets the background color for removed lines. +// Style option functions func WithRemovedLineBg(color lipgloss.Color) StyleOption { - return func(s *StyleConfig) { - s.RemovedLineBg = color - } + return func(s *StyleConfig) { s.RemovedLineBg = color } } -// WithAddedLineBg sets the background color for added lines. func WithAddedLineBg(color lipgloss.Color) StyleOption { - return func(s *StyleConfig) { - s.AddedLineBg = color - } + return func(s *StyleConfig) { s.AddedLineBg = color } } -// WithContextLineBg sets the background color for context lines. func WithContextLineBg(color lipgloss.Color) StyleOption { - return func(s *StyleConfig) { - s.ContextLineBg = color - } + return func(s *StyleConfig) { s.ContextLineBg = color } } -// WithRemovedFg sets the foreground color for removed line markers. func WithRemovedFg(color lipgloss.Color) StyleOption { - return func(s *StyleConfig) { - s.RemovedFg = color - } + return func(s *StyleConfig) { s.RemovedFg = color } } -// WithAddedFg sets the foreground color for added line markers. func WithAddedFg(color lipgloss.Color) StyleOption { - return func(s *StyleConfig) { - s.AddedFg = color - } + return func(s *StyleConfig) { s.AddedFg = color } } -// WithLineNumberFg sets the foreground color for line numbers. func WithLineNumberFg(color lipgloss.Color) StyleOption { - return func(s *StyleConfig) { - s.LineNumberFg = color - } + return func(s *StyleConfig) { s.LineNumberFg = color } } -// WithHighlightStyle sets the syntax highlighting style. func WithHighlightStyle(style string) StyleOption { - return func(s *StyleConfig) { - s.HighlightStyle = style - } + return func(s *StyleConfig) { s.HighlightStyle = style } } -// WithRemovedHighlightColors sets the colors for highlighted parts in removed text. func WithRemovedHighlightColors(bg, fg lipgloss.Color) StyleOption { return func(s *StyleConfig) { s.RemovedHighlightBg = bg @@ -182,7 +167,6 @@ func WithRemovedHighlightColors(bg, fg lipgloss.Color) StyleOption { } } -// WithAddedHighlightColors sets the colors for highlighted parts in added text. func WithAddedHighlightColors(bg, fg lipgloss.Color) StyleOption { return func(s *StyleConfig) { s.AddedHighlightBg = bg @@ -190,45 +174,35 @@ func WithAddedHighlightColors(bg, fg lipgloss.Color) StyleOption { } } -// WithRemovedLineNumberBg sets the background color for removed line numbers. func WithRemovedLineNumberBg(color lipgloss.Color) StyleOption { - return func(s *StyleConfig) { - s.RemovedLineNumberBg = color - } + return func(s *StyleConfig) { s.RemovedLineNumberBg = color } } -// WithAddedLineNumberBg sets the background color for added line numbers. func WithAddedLineNumberBg(color lipgloss.Color) StyleOption { - return func(s *StyleConfig) { - s.AddedLineNamerBg = color - } + return func(s *StyleConfig) { s.AddedLineNamerBg = color } } func WithHunkLineBg(color lipgloss.Color) StyleOption { - return func(s *StyleConfig) { - s.HunkLineBg = color - } + return func(s *StyleConfig) { s.HunkLineBg = color } } func WithHunkLineFg(color lipgloss.Color) StyleOption { - return func(s *StyleConfig) { - s.HunkLineFg = color - } + return func(s *StyleConfig) { s.HunkLineFg = color } } // ------------------------------------------------------------------------- -// Parse Options with Option Pattern +// Parse Configuration // ------------------------------------------------------------------------- -// ParseConfig configures the behavior of diff parsing. +// ParseConfig configures the behavior of diff parsing type ParseConfig struct { ContextSize int // Number of context lines to include } -// ParseOption defines a function that modifies a ParseConfig. +// ParseOption modifies a ParseConfig type ParseOption func(*ParseConfig) -// WithContextSize sets the number of context lines to include. +// WithContextSize sets the number of context lines to include func WithContextSize(size int) ParseOption { return func(p *ParseConfig) { if size >= 0 { @@ -238,27 +212,25 @@ func WithContextSize(size int) ParseOption { } // ------------------------------------------------------------------------- -// Side-by-Side Options with Option Pattern +// Side-by-Side Configuration // ------------------------------------------------------------------------- -// SideBySideConfig configures the rendering of side-by-side diffs. +// SideBySideConfig configures the rendering of side-by-side diffs type SideBySideConfig struct { TotalWidth int Style StyleConfig } -// SideBySideOption defines a function that modifies a SideBySideConfig. +// SideBySideOption modifies a SideBySideConfig type SideBySideOption func(*SideBySideConfig) -// NewSideBySideConfig creates a SideBySideConfig with default values and applies any provided options. +// NewSideBySideConfig creates a SideBySideConfig with default values func NewSideBySideConfig(opts ...SideBySideOption) SideBySideConfig { - // Set default values config := SideBySideConfig{ TotalWidth: 160, // Default width for side-by-side view Style: NewStyleConfig(), } - // Apply all provided options for _, opt := range opts { opt(&config) } @@ -266,7 +238,7 @@ func NewSideBySideConfig(opts ...SideBySideOption) SideBySideConfig { return config } -// WithTotalWidth sets the total width for side-by-side view. +// WithTotalWidth sets the total width for side-by-side view func WithTotalWidth(width int) SideBySideOption { return func(s *SideBySideConfig) { if width > 0 { @@ -275,14 +247,14 @@ func WithTotalWidth(width int) SideBySideOption { } } -// WithStyle sets the styling configuration. +// WithStyle sets the styling configuration func WithStyle(style StyleConfig) SideBySideOption { return func(s *SideBySideConfig) { s.Style = style } } -// WithStyleOptions applies the specified style options. +// WithStyleOptions applies the specified style options func WithStyleOptions(opts ...StyleOption) SideBySideOption { return func(s *SideBySideConfig) { s.Style = NewStyleConfig(opts...) @@ -290,10 +262,10 @@ func WithStyleOptions(opts ...StyleOption) SideBySideOption { } // ------------------------------------------------------------------------- -// Diff Parsing and Generation +// Diff Parsing // ------------------------------------------------------------------------- -// ParseUnifiedDiff parses a unified diff format string into structured data. +// ParseUnifiedDiff parses a unified diff format string into structured data func ParseUnifiedDiff(diff string) (DiffResult, error) { var result DiffResult var currentHunk *Hunk @@ -305,7 +277,7 @@ func ParseUnifiedDiff(diff string) (DiffResult, error) { inFileHeader := true for _, line := range lines { - // Parse the file headers + // Parse file headers if inFileHeader { if strings.HasPrefix(line, "--- a/") { result.OldFile = strings.TrimPrefix(line, "--- a/") @@ -332,27 +304,27 @@ func ParseUnifiedDiff(diff string) (DiffResult, error) { newStart, _ := strconv.Atoi(matches[3]) oldLine = oldStart newLine = newStart - continue } - // ignore the \\ No newline at end of file + // Ignore "No newline at end of file" markers if strings.HasPrefix(line, "\\ No newline at end of file") { continue } + if currentHunk == nil { continue } + // Process the line based on its prefix if len(line) > 0 { - // Process the line based on its prefix switch line[0] { case '+': currentHunk.Lines = append(currentHunk.Lines, DiffLine{ OldLineNo: 0, NewLineNo: newLine, Kind: LineAdded, - Content: line[1:], // skip '+' + Content: line[1:], }) newLine++ case '-': @@ -360,7 +332,7 @@ func ParseUnifiedDiff(diff string) (DiffResult, error) { OldLineNo: oldLine, NewLineNo: 0, Kind: LineRemoved, - Content: line[1:], // skip '-' + Content: line[1:], }) oldLine++ default: @@ -394,14 +366,13 @@ func ParseUnifiedDiff(diff string) (DiffResult, error) { return result, nil } -// HighlightIntralineChanges updates the content of lines in a hunk to show -// character-level differences within lines. +// HighlightIntralineChanges updates lines in a hunk to show character-level differences func HighlightIntralineChanges(h *Hunk, style StyleConfig) { var updated []DiffLine dmp := diffmatchpatch.New() for i := 0; i < len(h.Lines); i++ { - // Look for removed line followed by added line, which might have similar content + // Look for removed line followed by added line if i+1 < len(h.Lines) && h.Lines[i].Kind == LineRemoved && h.Lines[i+1].Kind == LineAdded { @@ -411,12 +382,40 @@ func HighlightIntralineChanges(h *Hunk, style StyleConfig) { // Find character-level differences patches := dmp.DiffMain(oldLine.Content, newLine.Content, false) - patches = dmp.DiffCleanupEfficiency(patches) patches = dmp.DiffCleanupSemantic(patches) + patches = dmp.DiffCleanupMerge(patches) + patches = dmp.DiffCleanupEfficiency(patches) - // Apply highlighting to the differences - oldLine.Content = colorizeSegments(patches, true, style) - newLine.Content = colorizeSegments(patches, false, style) + segments := make([]Segment, 0) + + removeStart := 0 + addStart := 0 + for _, patch := range patches { + switch patch.Type { + case diffmatchpatch.DiffDelete: + segments = append(segments, Segment{ + Start: removeStart, + End: removeStart + len(patch.Text), + Type: LineRemoved, + Text: patch.Text, + }) + removeStart += len(patch.Text) + case diffmatchpatch.DiffInsert: + segments = append(segments, Segment{ + Start: addStart, + End: addStart + len(patch.Text), + Type: LineAdded, + Text: patch.Text, + }) + addStart += len(patch.Text) + default: + // Context text, no highlighting needed + removeStart += len(patch.Text) + addStart += len(patch.Text) + } + } + oldLine.Segments = segments + newLine.Segments = segments updated = append(updated, oldLine, newLine) i++ // Skip the next line as we've already processed it @@ -428,45 +427,7 @@ func HighlightIntralineChanges(h *Hunk, style StyleConfig) { h.Lines = updated } -// colorizeSegments applies styles to the character-level diff segments. -func colorizeSegments(diffs []diffmatchpatch.Diff, isOld bool, style StyleConfig) string { - var buf strings.Builder - - removeBg := lipgloss.NewStyle(). - Background(style.RemovedHighlightBg). - Foreground(style.RemovedHighlightFg) - - addBg := lipgloss.NewStyle(). - Background(style.AddedHighlightBg). - Foreground(style.AddedHighlightFg) - - removedLineStyle := lipgloss.NewStyle().Background(style.RemovedLineBg) - addedLineStyle := lipgloss.NewStyle().Background(style.AddedLineBg) - - for _, d := range diffs { - switch d.Type { - case diffmatchpatch.DiffEqual: - // Handle text that's the same in both versions - buf.WriteString(d.Text) - case diffmatchpatch.DiffDelete: - // Handle deleted text (only show in old version) - if isOld { - buf.WriteString(removeBg.Render(d.Text)) - buf.WriteString(removedLineStyle.Render("")) - } - case diffmatchpatch.DiffInsert: - // Handle inserted text (only show in new version) - if !isOld { - buf.WriteString(addBg.Render(d.Text)) - buf.WriteString(addedLineStyle.Render("")) - } - } - } - - return buf.String() -} - -// pairLines converts a flat list of diff lines to pairs for side-by-side display. +// pairLines converts a flat list of diff lines to pairs for side-by-side display func pairLines(lines []DiffLine) []linePair { var pairs []linePair i := 0 @@ -498,7 +459,7 @@ func pairLines(lines []DiffLine) []linePair { // Syntax Highlighting // ------------------------------------------------------------------------- -// SyntaxHighlight applies syntax highlighting to a string based on the file extension. +// SyntaxHighlight applies syntax highlighting to text based on file extension func SyntaxHighlight(w io.Writer, source, fileName, formatter string, bg lipgloss.TerminalColor) error { // Determine the language lexer to use l := lexers.Match(fileName) @@ -515,21 +476,98 @@ func SyntaxHighlight(w io.Writer, source, fileName, formatter string, bg lipglos if f == nil { f = formatters.Fallback } - - // Get the style - s := styles.Get("dracula") - if s == nil { - s = styles.Fallback - } - + theme := ` + +` + + r := strings.NewReader(theme) + style := chroma.MustNewXMLStyle(r) // Modify the style to use the provided background - s, err := s.Builder().Transform( + s, err := style.Builder().Transform( func(t chroma.StyleEntry) chroma.StyleEntry { r, g, b, _ := bg.RGBA() - ru8 := uint8(r >> 8) - gu8 := uint8(g >> 8) - bu8 := uint8(b >> 8) - t.Background = chroma.NewColour(ru8, gu8, bu8) + t.Background = chroma.NewColour(uint8(r>>8), uint8(g>>8), uint8(b>>8)) return t }, ).Build() @@ -546,7 +584,7 @@ func SyntaxHighlight(w io.Writer, source, fileName, formatter string, bg lipglos return f.Format(w, s, it) } -// highlightLine applies syntax highlighting to a single line. +// highlightLine applies syntax highlighting to a single line func highlightLine(fileName string, line string, bg lipgloss.TerminalColor) string { var buf bytes.Buffer err := SyntaxHighlight(&buf, line, fileName, "terminal16m", bg) @@ -556,7 +594,7 @@ func highlightLine(fileName string, line string, bg lipgloss.TerminalColor) stri return buf.String() } -// createStyles generates the lipgloss styles needed for rendering diffs. +// createStyles generates the lipgloss styles needed for rendering diffs func createStyles(config StyleConfig) (removedLineStyle, addedLineStyle, contextLineStyle, lineNumberStyle lipgloss.Style) { removedLineStyle = lipgloss.NewStyle().Background(config.RemovedLineBg) addedLineStyle = lipgloss.NewStyle().Background(config.AddedLineBg) @@ -566,7 +604,106 @@ func createStyles(config StyleConfig) (removedLineStyle, addedLineStyle, context return } -// renderLeftColumn formats the left side of a side-by-side diff. +// ------------------------------------------------------------------------- +// Rendering Functions +// ------------------------------------------------------------------------- + +// applyHighlighting applies intra-line highlighting to a piece of text +func applyHighlighting(content string, segments []Segment, segmentType LineType, highlightBg lipgloss.Color, +) string { + // Find all ANSI sequences in the content + ansiRegex := regexp.MustCompile(`\x1b(?:[@-Z\\-_]|\[[0-9?]*(?:;[0-9?]*)*[@-~])`) + ansiMatches := ansiRegex.FindAllStringIndex(content, -1) + + // Build a mapping of visible character positions to their actual indices + visibleIdx := 0 + ansiSequences := make(map[int]string) + lastAnsiSeq := "\x1b[0m" // Default reset sequence + + for i := 0; i < len(content); { + isAnsi := false + for _, match := range ansiMatches { + if match[0] == i { + ansiSequences[visibleIdx] = content[match[0]:match[1]] + lastAnsiSeq = content[match[0]:match[1]] + i = match[1] + isAnsi = true + break + } + } + if isAnsi { + continue + } + + // For non-ANSI positions, store the last ANSI sequence + if _, exists := ansiSequences[visibleIdx]; !exists { + ansiSequences[visibleIdx] = lastAnsiSeq + } + visibleIdx++ + i++ + } + + // Apply highlighting + var sb strings.Builder + inSelection := false + currentPos := 0 + + for i := 0; i < len(content); { + // Check if we're at an ANSI sequence + isAnsi := false + for _, match := range ansiMatches { + if match[0] == i { + sb.WriteString(content[match[0]:match[1]]) // Preserve ANSI sequence + i = match[1] + isAnsi = true + break + } + } + if isAnsi { + continue + } + + // Check for segment boundaries + for _, seg := range segments { + if seg.Type == segmentType { + if currentPos == seg.Start { + inSelection = true + } + if currentPos == seg.End { + inSelection = false + } + } + } + + // Get current character + char := string(content[i]) + + if inSelection { + // Get the current styling + currentStyle := ansiSequences[currentPos] + + // Apply background highlight + sb.WriteString("\x1b[48;2;") + r, g, b, _ := highlightBg.RGBA() + sb.WriteString(fmt.Sprintf("%d;%d;%dm", r>>8, g>>8, b>>8)) + sb.WriteString(char) + sb.WriteString("\x1b[49m") // Reset only background + + // Reapply the original ANSI sequence + sb.WriteString(currentStyle) + } else { + // Not in selection, just copy the character + sb.WriteString(char) + } + + currentPos++ + i++ + } + + return sb.String() +} + +// renderLeftColumn formats the left side of a side-by-side diff func renderLeftColumn(fileName string, dl *DiffLine, colWidth int, styles StyleConfig) string { if dl == nil { contextLineStyle := lipgloss.NewStyle().Background(styles.ContextLineBg) @@ -575,9 +712,9 @@ func renderLeftColumn(fileName string, dl *DiffLine, colWidth int, styles StyleC removedLineStyle, _, contextLineStyle, lineNumberStyle := createStyles(styles) + // Determine line style based on line type var marker string var bgStyle lipgloss.Style - switch dl.Kind { case LineRemoved: marker = removedLineStyle.Foreground(styles.RemovedFg).Render("-") @@ -591,18 +728,29 @@ func renderLeftColumn(fileName string, dl *DiffLine, colWidth int, styles StyleC bgStyle = contextLineStyle } + // Format line number lineNum := "" if dl.OldLineNo > 0 { lineNum = fmt.Sprintf("%6d", dl.OldLineNo) } + // Create the line prefix prefix := lineNumberStyle.Render(lineNum + " " + marker) + + // Apply syntax highlighting content := highlightLine(fileName, dl.Content, bgStyle.GetBackground()) + // Apply intra-line highlighting for removed lines + if dl.Kind == LineRemoved && len(dl.Segments) > 0 { + content = applyHighlighting(content, dl.Segments, LineRemoved, styles.RemovedHighlightBg) + } + + // Add a padding space for removed lines if dl.Kind == LineRemoved { content = bgStyle.Render(" ") + content } + // Create the final line and truncate if needed lineText := prefix + content return bgStyle.MaxHeight(1).Width(colWidth).Render( ansi.Truncate( @@ -613,7 +761,7 @@ func renderLeftColumn(fileName string, dl *DiffLine, colWidth int, styles StyleC ) } -// renderRightColumn formats the right side of a side-by-side diff. +// renderRightColumn formats the right side of a side-by-side diff func renderRightColumn(fileName string, dl *DiffLine, colWidth int, styles StyleConfig) string { if dl == nil { contextLineStyle := lipgloss.NewStyle().Background(styles.ContextLineBg) @@ -622,9 +770,9 @@ func renderRightColumn(fileName string, dl *DiffLine, colWidth int, styles Style _, addedLineStyle, contextLineStyle, lineNumberStyle := createStyles(styles) + // Determine line style based on line type var marker string var bgStyle lipgloss.Style - switch dl.Kind { case LineAdded: marker = addedLineStyle.Foreground(styles.AddedFg).Render("+") @@ -638,18 +786,29 @@ func renderRightColumn(fileName string, dl *DiffLine, colWidth int, styles Style bgStyle = contextLineStyle } + // Format line number lineNum := "" if dl.NewLineNo > 0 { lineNum = fmt.Sprintf("%6d", dl.NewLineNo) } + // Create the line prefix prefix := lineNumberStyle.Render(lineNum + " " + marker) + + // Apply syntax highlighting content := highlightLine(fileName, dl.Content, bgStyle.GetBackground()) + // Apply intra-line highlighting for added lines + if dl.Kind == LineAdded && len(dl.Segments) > 0 { + content = applyHighlighting(content, dl.Segments, LineAdded, styles.AddedHighlightBg) + } + + // Add a padding space for added lines if dl.Kind == LineAdded { content = bgStyle.Render(" ") + content } + // Create the final line and truncate if needed lineText := prefix + content return bgStyle.MaxHeight(1).Width(colWidth).Render( ansi.Truncate( @@ -661,10 +820,10 @@ func renderRightColumn(fileName string, dl *DiffLine, colWidth int, styles Style } // ------------------------------------------------------------------------- -// Public API Methods +// Public API // ------------------------------------------------------------------------- -// RenderSideBySideHunk formats a hunk for side-by-side display. +// RenderSideBySideHunk formats a hunk for side-by-side display func RenderSideBySideHunk(fileName string, h Hunk, opts ...SideBySideOption) string { // Apply options to create the configuration config := NewSideBySideConfig(opts...) @@ -692,7 +851,7 @@ func RenderSideBySideHunk(fileName string, h Hunk, opts ...SideBySideOption) str return sb.String() } -// FormatDiff creates a side-by-side formatted view of a diff. +// FormatDiff creates a side-by-side formatted view of a diff func FormatDiff(diffText string, opts ...SideBySideOption) (string, error) { diffResult, err := ParseUnifiedDiff(diffText) if err != nil { @@ -700,11 +859,18 @@ func FormatDiff(diffText string, opts ...SideBySideOption) (string, error) { } var sb strings.Builder - config := NewSideBySideConfig(opts...) + for i, h := range diffResult.Hunks { if i > 0 { - sb.WriteString(lipgloss.NewStyle().Background(config.Style.HunkLineBg).Foreground(config.Style.HunkLineFg).Width(config.TotalWidth).Render(h.Header) + "\n") + // Render hunk header + sb.WriteString( + lipgloss.NewStyle(). + Background(config.Style.HunkLineBg). + Foreground(config.Style.HunkLineFg). + Width(config.TotalWidth). + Render(h.Header) + "\n", + ) } sb.WriteString(RenderSideBySideHunk(diffResult.OldFile, h, opts...)) } @@ -712,14 +878,16 @@ func FormatDiff(diffText string, opts ...SideBySideOption) (string, error) { return sb.String(), nil } -// GenerateDiff creates a unified diff from two file contents. +// GenerateDiff creates a unified diff from two file contents func GenerateDiff(beforeContent, afterContent, fileName string) (string, int, int) { + // Create temporary directory for git operations tempDir, err := os.MkdirTemp("", "git-diff-temp") if err != nil { return "", 0, 0 } defer os.RemoveAll(tempDir) + // Initialize git repo repo, err := git.PlainInit(tempDir, false) if err != nil { return "", 0, 0 @@ -730,6 +898,7 @@ func GenerateDiff(beforeContent, afterContent, fileName string) (string, int, in return "", 0, 0 } + // Write the "before" content and commit it fullPath := filepath.Join(tempDir, fileName) if err = os.MkdirAll(filepath.Dir(fullPath), 0o755); err != nil { return "", 0, 0 @@ -754,7 +923,9 @@ func GenerateDiff(beforeContent, afterContent, fileName string) (string, int, in return "", 0, 0 } + // Write the "after" content and commit it if err = os.WriteFile(fullPath, []byte(afterContent), 0o644); err != nil { + return "", 0, 0 } _, err = wt.Add(fileName) @@ -773,6 +944,7 @@ func GenerateDiff(beforeContent, afterContent, fileName string) (string, int, in return "", 0, 0 } + // Get the diff between the two commits beforeCommitObj, err := repo.CommitObject(beforeCommit) if err != nil { return "", 0, 0 @@ -788,6 +960,7 @@ func GenerateDiff(beforeContent, afterContent, fileName string) (string, int, in return "", 0, 0 } + // Count additions and removals additions := 0 removals := 0 for _, fileStat := range patch.Stats() { From 1cdd24fbc7b45693b65b5d55e4f45c2ebc60a556 Mon Sep 17 00:00:00 2001 From: Kujtim Hoxha Date: Tue, 15 Apr 2025 11:58:01 +0200 Subject: [PATCH 17/41] minor fixes --- internal/app/app.go | 5 ++-- internal/config/config.go | 5 +++- internal/llm/agent/agent-tool.go | 8 ++++--- internal/llm/agent/agent.go | 17 ++++++++++++++ internal/llm/agent/coder.go | 2 +- internal/llm/agent/task.go | 39 ++++++++++++++++---------------- 6 files changed, 49 insertions(+), 27 deletions(-) diff --git a/internal/app/app.go b/internal/app/app.go index 9f575cac324127469b34c8a0ef5f4df9ed0cf52e..ca23b3c404fd651bae9eb2d5524c684e7e591ba3 100644 --- a/internal/app/app.go +++ b/internal/app/app.go @@ -48,9 +48,10 @@ func New(ctx context.Context, conn *sql.DB) (*App, error) { LSPClients: make(map[string]*lsp.Client), } + app.initLSPClients(ctx) + var err error app.CoderAgent, err = agent.NewCoderAgent( - app.Permissions, app.Sessions, app.Messages, @@ -61,8 +62,6 @@ func New(ctx context.Context, conn *sql.DB) (*App, error) { return nil, err } - app.initLSPClients(ctx) - return app, nil } diff --git a/internal/config/config.go b/internal/config/config.go index 1f3091ff3ea3df6b7c569fe013c7c15933b3c7e4..f0afbdd3c9839549403d937bb8982896804f733d 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -272,5 +272,8 @@ func Get() *Config { // WorkingDirectory returns the current working directory from the configuration. func WorkingDirectory() string { - return viper.GetString("wd") + if cfg == nil { + panic("config not loaded") + } + return cfg.WorkingDir } diff --git a/internal/llm/agent/agent-tool.go b/internal/llm/agent/agent-tool.go index a92ea44a4e31c56baacc1b0772a151f153a2bd7b..83160bb645458150d96373de255ac78f1111dadd 100644 --- a/internal/llm/agent/agent-tool.go +++ b/internal/llm/agent/agent-tool.go @@ -53,7 +53,7 @@ func (b *agentTool) Run(ctx context.Context, call tools.ToolCall) (tools.ToolRes return tools.ToolResponse{}, fmt.Errorf("session_id and message_id are required") } - agent, err := NewTaskAgent(b.lspClients) + agent, err := NewTaskAgent(b.messages, b.sessions, b.lspClients) if err != nil { return tools.ToolResponse{}, fmt.Errorf("error creating agent: %s", err) } @@ -105,9 +105,11 @@ func (b *agentTool) Run(ctx context.Context, call tools.ToolCall) (tools.ToolRes func NewAgentTool( Sessions session.Service, Messages message.Service, + LspClients map[string]*lsp.Client, ) tools.BaseTool { return &agentTool{ - sessions: Sessions, - messages: Messages, + sessions: Sessions, + messages: Messages, + lspClients: LspClients, } } diff --git a/internal/llm/agent/agent.go b/internal/llm/agent/agent.go index 997004e123208835b038ef4360c8455b03867bd5..1958111a136ff6956c9370d568831831c5906807 100644 --- a/internal/llm/agent/agent.go +++ b/internal/llm/agent/agent.go @@ -4,6 +4,8 @@ import ( "context" "errors" "fmt" + "os" + "runtime/debug" "strings" "sync" @@ -88,6 +90,21 @@ func (a *agent) Generate(ctx context.Context, sessionID string, content string) defer func() { if r := recover(); r != nil { logging.ErrorPersist(fmt.Sprintf("Panic in Generate: %v", r)) + + // dump stack trace into a file + file, err := os.Create("panic.log") + if err != nil { + logging.ErrorPersist(fmt.Sprintf("Failed to create panic log: %v", err)) + return + } + + defer file.Close() + + stackTrace := debug.Stack() + if _, err := file.Write(stackTrace); err != nil { + logging.ErrorPersist(fmt.Sprintf("Failed to write panic log: %v", err)) + } + } }() defer a.activeRequests.Delete(sessionID) diff --git a/internal/llm/agent/coder.go b/internal/llm/agent/coder.go index 8eea5704163fb3bb02b14aba98054e3fa9477f95..a3db6b55c1f675d4ca77ff03cf9cc781b2e795e9 100644 --- a/internal/llm/agent/coder.go +++ b/internal/llm/agent/coder.go @@ -49,7 +49,7 @@ func NewCoderAgent( tools.NewSourcegraphTool(), tools.NewViewTool(lspClients), tools.NewWriteTool(lspClients, permissions), - NewAgentTool(sessions, messages), + NewAgentTool(sessions, messages, lspClients), }, otherTools..., ), ) diff --git a/internal/llm/agent/task.go b/internal/llm/agent/task.go index 0a072044c4db4e4261848146d6d8d83a93787d0e..fca1f223f505c1756503967b21dc592dc64903d2 100644 --- a/internal/llm/agent/task.go +++ b/internal/llm/agent/task.go @@ -8,39 +8,40 @@ import ( "github.com/kujtimiihoxha/termai/internal/llm/models" "github.com/kujtimiihoxha/termai/internal/llm/tools" "github.com/kujtimiihoxha/termai/internal/lsp" + "github.com/kujtimiihoxha/termai/internal/message" + "github.com/kujtimiihoxha/termai/internal/session" ) type taskAgent struct { - *agent + Service } -func (c *taskAgent) Generate(ctx context.Context, sessionID string, content string) error { - return c.generate(ctx, sessionID, content) -} - -func NewTaskAgent(lspClients map[string]*lsp.Client) (Service, error) { +func NewTaskAgent(messages message.Service, sessions session.Service, lspClients map[string]*lsp.Client) (Service, error) { model, ok := models.SupportedModels[config.Get().Model.Coder] if !ok { return nil, errors.New("model not supported") } ctx := context.Background() - agentProvider, titleGenerator, err := getAgentProviders(ctx, model) + + agent, err := NewAgent( + ctx, + sessions, + messages, + model, + []tools.BaseTool{ + tools.NewGlobTool(), + tools.NewGrepTool(), + tools.NewLsTool(), + tools.NewSourcegraphTool(), + tools.NewViewTool(lspClients), + }, + ) if err != nil { return nil, err } + return &taskAgent{ - agent: &agent{ - tools: []tools.BaseTool{ - tools.NewGlobTool(), - tools.NewGrepTool(), - tools.NewLsTool(), - tools.NewSourcegraphTool(), - tools.NewViewTool(lspClients), - }, - model: model, - agent: agentProvider, - titleGenerator: titleGenerator, - }, + agent, }, nil } From 76b4065f17b87a63092acfd98c997bab53700b35 Mon Sep 17 00:00:00 2001 From: Kujtim Hoxha Date: Tue, 15 Apr 2025 11:59:23 +0200 Subject: [PATCH 18/41] update gitignore --- .gitignore | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/.gitignore b/.gitignore index 0ef6e2aefaf866005e2c8cd8376c9ff447210ff5..b4d5d61ea39cc833b733b8dcb24f142367a9e85c 100644 --- a/.gitignore +++ b/.gitignore @@ -31,7 +31,7 @@ go.work .Trashes ehthumbs.db Thumbs.db -debug.log +*.log # Binary output directory /bin/ @@ -44,3 +44,4 @@ debug.log .opencode internal/assets/diff/index.mjs +cmd/test/* From bbfa60c787f2ec459f1689b9a650ddbec9693ed9 Mon Sep 17 00:00:00 2001 From: Kujtim Hoxha Date: Wed, 16 Apr 2025 20:06:23 +0200 Subject: [PATCH 19/41] reimplement agent,provider and add file history --- .opencode.json | 4 - README.md | 34 +- cmd/root.go | 24 +- go.mod | 8 +- go.sum | 14 - internal/app/app.go | 17 +- internal/app/lsp.go | 19 +- internal/config/config.go | 108 ++- internal/db/files.sql.go | 4 +- internal/db/sql/files.sql | 4 +- internal/diff/diff.go | 99 ++- internal/llm/agent/agent-tool.go | 18 +- internal/llm/agent/agent.go | 861 +++++++------------ internal/llm/agent/coder.go | 63 -- internal/llm/agent/mcp-tools.go | 4 +- internal/llm/agent/task.go | 47 - internal/llm/agent/tools.go | 50 ++ internal/llm/models/anthropic.go | 71 ++ internal/llm/models/models.go | 190 ++-- internal/llm/prompt/coder.go | 28 +- internal/llm/prompt/prompt.go | 19 + internal/llm/prompt/task.go | 5 +- internal/llm/prompt/title.go | 4 +- internal/llm/provider/anthropic.go | 531 ++++++------ internal/llm/provider/bedrock.go | 101 ++- internal/llm/provider/gemini.go | 533 ++++++++---- internal/llm/provider/openai.go | 401 +++++---- internal/llm/provider/provider.go | 169 +++- internal/llm/tools/bash.go | 7 +- internal/llm/tools/bash_test.go | 31 - internal/llm/tools/edit.go | 75 +- internal/llm/tools/edit_test.go | 30 +- internal/llm/tools/file.go | 10 - internal/llm/tools/glob.go | 4 +- internal/llm/tools/grep.go | 4 +- internal/llm/tools/ls.go | 4 +- internal/llm/tools/mocks_test.go | 246 ++++++ internal/llm/tools/shell/shell.go | 12 +- internal/llm/tools/sourcegraph.go | 2 +- internal/llm/tools/tools.go | 9 +- internal/llm/tools/write.go | 27 +- internal/llm/tools/write_test.go | 22 +- internal/logging/logger.go | 41 +- internal/lsp/client.go | 13 +- internal/lsp/handlers.go | 2 +- internal/lsp/transport.go | 28 +- internal/lsp/watcher/watcher.go | 18 +- internal/message/content.go | 30 +- internal/pubsub/broker.go | 2 +- internal/session/session.go | 15 + internal/tui/components/chat/chat.go | 2 - internal/tui/components/chat/editor.go | 22 +- internal/tui/components/chat/messages.go | 205 ++++- internal/tui/components/chat/sidebar.go | 176 +++- internal/tui/components/core/dialog.go | 117 --- internal/tui/components/core/help.go | 119 --- internal/tui/components/core/status.go | 90 +- internal/tui/components/dialog/help.go | 182 ++++ internal/tui/components/dialog/permission.go | 682 +++++++-------- internal/tui/components/dialog/quit.go | 156 ++-- internal/tui/components/logs/details.go | 2 - internal/tui/components/logs/table.go | 22 - internal/tui/components/repl/editor.go | 201 ----- internal/tui/components/repl/messages.go | 513 ----------- internal/tui/components/repl/sessions.go | 249 ------ internal/tui/layout/overlay.go | 11 +- internal/tui/layout/split.go | 1 + internal/tui/page/chat.go | 32 +- internal/tui/page/init.go | 308 ------- internal/tui/page/logs.go | 17 + internal/tui/page/repl.go | 21 - internal/tui/tui.go | 277 +++--- main.go | 7 + 73 files changed, 3595 insertions(+), 3879 deletions(-) delete mode 100644 internal/llm/agent/coder.go delete mode 100644 internal/llm/agent/task.go create mode 100644 internal/llm/agent/tools.go create mode 100644 internal/llm/models/anthropic.go create mode 100644 internal/llm/prompt/prompt.go create mode 100644 internal/llm/tools/mocks_test.go delete mode 100644 internal/tui/components/core/dialog.go delete mode 100644 internal/tui/components/core/help.go create mode 100644 internal/tui/components/dialog/help.go delete mode 100644 internal/tui/components/repl/editor.go delete mode 100644 internal/tui/components/repl/messages.go delete mode 100644 internal/tui/components/repl/sessions.go delete mode 100644 internal/tui/page/init.go delete mode 100644 internal/tui/page/repl.go diff --git a/.opencode.json b/.opencode.json index f63a63dba2ad01766cb59515d84547d6529b40d2..b7fc19b524371cf7e4a625173f2fe305914694d3 100644 --- a/.opencode.json +++ b/.opencode.json @@ -1,8 +1,4 @@ { - "model": { - "coder": "claude-3.7-sonnet", - "coderMaxTokens": 20000 - }, "lsp": { "gopls": { "command": "gopls" diff --git a/README.md b/README.md index 23a1906a1566d418857f2b2bd2605ffd914413b5..564284c7f138fdc54ebca74d64745730c496635e 100644 --- a/README.md +++ b/README.md @@ -1,14 +1,14 @@ -# TermAI +# OpenCode > **⚠️ Early Development Notice:** This project is in early development and is not yet ready for production use. Features may change, break, or be incomplete. Use at your own risk. A powerful terminal-based AI assistant for developers, providing intelligent coding assistance directly in your terminal. -[![TermAI Demo](https://asciinema.org/a/dtc4nJyGSZX79HRUmFLY3gmoy.svg)](https://asciinema.org/a/dtc4nJyGSZX79HRUmFLY3gmoy) +[![OpenCode Demo](https://asciinema.org/a/dtc4nJyGSZX79HRUmFLY3gmoy.svg)](https://asciinema.org/a/dtc4nJyGSZX79HRUmFLY3gmoy) ## Overview -TermAI is a Go-based CLI application that brings AI assistance to your terminal. It provides a TUI (Terminal User Interface) for interacting with various AI models to help with coding tasks, debugging, and more. +OpenCode is a Go-based CLI application that brings AI assistance to your terminal. It provides a TUI (Terminal User Interface) for interacting with various AI models to help with coding tasks, debugging, and more. ## Features @@ -23,16 +23,16 @@ TermAI is a Go-based CLI application that brings AI assistance to your terminal. ```bash # Coming soon -go install github.com/kujtimiihoxha/termai@latest +go install github.com/kujtimiihoxha/opencode@latest ``` ## Configuration -TermAI looks for configuration in the following locations: +OpenCode looks for configuration in the following locations: -- `$HOME/.termai.json` -- `$XDG_CONFIG_HOME/termai/.termai.json` -- `./.termai.json` (local directory) +- `$HOME/.opencode.json` +- `$XDG_CONFIG_HOME/opencode/.opencode.json` +- `./.opencode.json` (local directory) You can also use environment variables: @@ -43,11 +43,11 @@ You can also use environment variables: ## Usage ```bash -# Start TermAI -termai +# Start OpenCode +opencode # Start with debug logging -termai -d +opencode -d ``` ### Keyboard Shortcuts @@ -81,7 +81,7 @@ termai -d ## Architecture -TermAI is built with a modular architecture: +OpenCode is built with a modular architecture: - **cmd**: Command-line interface using Cobra - **internal/app**: Core application services @@ -103,22 +103,22 @@ TermAI is built with a modular architecture: ```bash # Clone the repository -git clone https://github.com/kujtimiihoxha/termai.git -cd termai +git clone https://github.com/kujtimiihoxha/opencode.git +cd opencode # Build the diff script first go run cmd/diff/main.go # Build -go build -o termai +go build -o opencode # Run -./termai +./opencode ``` ## Acknowledgments -TermAI builds upon the work of several open source projects and developers: +OpenCode builds upon the work of several open source projects and developers: - [@isaacphi](https://github.com/isaacphi) - LSP client implementation diff --git a/cmd/root.go b/cmd/root.go index a2e63006f195b406598388c538e6abcfd8f525bc..ff71747d56458c6e8094e19db90a9f36acc2db42 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -20,7 +20,7 @@ import ( ) var rootCmd = &cobra.Command{ - Use: "termai", + Use: "OpenCode", Short: "A terminal ai assistant", Long: `A terminal ai assistant`, RunE: func(cmd *cobra.Command, args []string) error { @@ -89,12 +89,9 @@ var rootCmd = &cobra.Command{ // Set up message handling for the TUI go func() { defer tuiWg.Done() - defer func() { - if r := recover(); r != nil { - logging.Error("Panic in TUI message handling: %v", r) - attemptTUIRecovery(program) - } - }() + defer logging.RecoverPanic("TUI-message-handler", func() { + attemptTUIRecovery(program) + }) for { select { @@ -153,11 +150,7 @@ func attemptTUIRecovery(program *tea.Program) { func initMCPTools(ctx context.Context, app *app.App) { go func() { - defer func() { - if r := recover(); r != nil { - logging.Error("Panic in MCP goroutine: %v", r) - } - }() + defer logging.RecoverPanic("MCP-goroutine", nil) // Create a context with timeout for the initial MCP tools fetch ctxWithTimeout, cancel := context.WithTimeout(ctx, 30*time.Second) @@ -179,11 +172,7 @@ func setupSubscriber[T any]( wg.Add(1) go func() { defer wg.Done() - defer func() { - if r := recover(); r != nil { - logging.Error("Panic in %s subscription goroutine: %v", name, r) - } - }() + defer logging.RecoverPanic(fmt.Sprintf("subscription-%s", name), nil) for { select { @@ -232,6 +221,7 @@ func setupSubscriptions(app *app.App) (chan tea.Msg, func()) { // Wait with a timeout for all goroutines to complete waitCh := make(chan struct{}) go func() { + defer logging.RecoverPanic("subscription-cleanup", nil) wg.Wait() close(waitCh) }() diff --git a/go.mod b/go.mod index 925a71097a6d55b7cbbc78c16151d0bbb42a4a32..16c88d3a61c83d913721aecdf21afb07d26dbe29 100644 --- a/go.mod +++ b/go.mod @@ -23,7 +23,6 @@ require ( github.com/golang-migrate/migrate/v4 v4.18.2 github.com/google/generative-ai-go v0.19.0 github.com/google/uuid v1.6.0 - github.com/kujtimiihoxha/vimtea v0.0.3-0.20250329221256-a250e98498f9 github.com/lrstanley/bubblezone v0.0.0-20250315020633-c249a3fe1231 github.com/mark3labs/mcp-go v0.17.0 github.com/mattn/go-runewidth v0.0.16 @@ -36,7 +35,6 @@ require ( github.com/spf13/cobra v1.9.1 github.com/spf13/viper v1.20.0 github.com/stretchr/testify v1.10.0 - golang.org/x/net v0.39.0 google.golang.org/api v0.215.0 ) @@ -106,7 +104,6 @@ require ( github.com/pmezard/go-difflib v1.0.0 // indirect github.com/rivo/uniseg v0.4.7 // indirect github.com/sagikazarmark/locafero v0.7.0 // indirect - github.com/sahilm/fuzzy v0.1.1 // indirect github.com/skeema/knownhosts v1.3.1 // indirect github.com/sourcegraph/conc v0.3.0 // indirect github.com/spf13/afero v1.12.0 // indirect @@ -129,11 +126,8 @@ require ( go.opentelemetry.io/otel/trace v1.29.0 // indirect go.uber.org/atomic v1.9.0 // indirect go.uber.org/multierr v1.9.0 // indirect - golang.design/x/clipboard v0.7.0 // indirect golang.org/x/crypto v0.37.0 // indirect - golang.org/x/exp/shiny v0.0.0-20250305212735-054e65f0b394 // indirect - golang.org/x/image v0.14.0 // indirect - golang.org/x/mobile v0.0.0-20231127183840-76ac6878050a // indirect + golang.org/x/net v0.39.0 // indirect golang.org/x/oauth2 v0.25.0 // indirect golang.org/x/sync v0.13.0 // indirect golang.org/x/sys v0.32.0 // indirect diff --git a/go.sum b/go.sum index 9c2c2df8fbcb0d909a73a67ead79339bf3814892..4832271f21485670a9779d3c2ee0cc1909d0b8f8 100644 --- a/go.sum +++ b/go.sum @@ -180,10 +180,6 @@ github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= -github.com/kujtimiihoxha/vimtea v0.0.3-0.20250329221256-a250e98498f9 h1:xYfCLI8KUwmXDFp1pOpNX+XsQczQw9VbEuju1pQF5/A= -github.com/kujtimiihoxha/vimtea v0.0.3-0.20250329221256-a250e98498f9/go.mod h1:Ye+kIkTmPO5xuqCQ+PPHDTGIViRRoSpSIlcYgma8YlA= -github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0SNc= -github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw= github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw= github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o= github.com/lrstanley/bubblezone v0.0.0-20250315020633-c249a3fe1231 h1:9rjt7AfnrXKNSZhp36A3/4QAZAwGGCGD/p8Bse26zms= @@ -235,8 +231,6 @@ github.com/rogpeppe/go-internal v1.14.1/go.mod h1:MaRKkUm5W0goXpeCfT7UZI6fk/L7L7 github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= github.com/sagikazarmark/locafero v0.7.0 h1:5MqpDsTGNDhY8sGp0Aowyf0qKsPrhewaLSsFaodPcyo= github.com/sagikazarmark/locafero v0.7.0/go.mod h1:2za3Cg5rMaTMoG/2Ulr9AwtFaIppKXTRYnozin4aB5k= -github.com/sahilm/fuzzy v0.1.1 h1:ceu5RHF8DGgoi+/dR5PsECjCDH1BE3Fnmpo7aVXOdRA= -github.com/sahilm/fuzzy v0.1.1/go.mod h1:VFvziUEIMCrT6A6tw2RFIXPXXmzXbOsSHF0DOI8ZK9Y= github.com/sebdah/goldie/v2 v2.5.3 h1:9ES/mNN+HNUbNWpVAlrzuZ7jE+Nrczbj8uFRjM7624Y= github.com/sebdah/goldie/v2 v2.5.3/go.mod h1:oZ9fp0+se1eapSRjfYbsV/0Hqhbuu3bJVvKI/NNtssI= github.com/sergi/go-diff v1.0.0/go.mod h1:0CfEIISq7TuYL3j771MWULgwwjU+GofnZX9QAmXWZgo= @@ -302,8 +296,6 @@ go.uber.org/atomic v1.9.0 h1:ECmE8Bn/WFTYwEW/bpKD3M8VtR/zQVbavAoalC1PYyE= go.uber.org/atomic v1.9.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc= go.uber.org/multierr v1.9.0 h1:7fIwc/ZtS0q++VgcfqFDxSBZVv/Xo49/SYnDFupUwlI= go.uber.org/multierr v1.9.0/go.mod h1:X2jQV1h+kxSjClGpnseKVIxpmcjrj7MNnI0bnlfKTVQ= -golang.design/x/clipboard v0.7.0 h1:4Je8M/ys9AJumVnl8m+rZnIvstSnYj1fvzqYrU3TXvo= -golang.design/x/clipboard v0.7.0/go.mod h1:PQIvqYO9GP29yINEfsEn5zSQKAz3UgXmZKzDA6dnq2E= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= golang.org/x/crypto v0.0.0-20220622213112-05595931fe9d/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= @@ -314,12 +306,6 @@ golang.org/x/crypto v0.37.0 h1:kJNSjF/Xp7kU0iB2Z+9viTPMW4EqqsrywMXLJOOsXSE= golang.org/x/crypto v0.37.0/go.mod h1:vg+k43peMZ0pUMhYmVAWysMK35e6ioLh3wB8ZCAfbVc= golang.org/x/exp v0.0.0-20240719175910-8a7402abbf56 h1:2dVuKD2vS7b0QIHQbpyTISPd0LeHDbnYEryqj5Q1ug8= golang.org/x/exp v0.0.0-20240719175910-8a7402abbf56/go.mod h1:M4RDyNAINzryxdtnbRXRL/OHtkFuWGRjvuhBJpk2IlY= -golang.org/x/exp/shiny v0.0.0-20250305212735-054e65f0b394 h1:bFYqOIMdeiCEdzPJkLiOoMDzW/v3tjW4AA/RmUZYsL8= -golang.org/x/exp/shiny v0.0.0-20250305212735-054e65f0b394/go.mod h1:ygj7T6vSGhhm/9yTpOQQNvuAUFziTH7RUiH74EoE2C8= -golang.org/x/image v0.14.0 h1:tNgSxAFe3jC4uYqvZdTr84SZoM1KfwdC9SKIFrLjFn4= -golang.org/x/image v0.14.0/go.mod h1:HUYqC05R2ZcZ3ejNQsIHQDQiwWM4JBqmm6MKANTp4LE= -golang.org/x/mobile v0.0.0-20231127183840-76ac6878050a h1:sYbmY3FwUWCBTodZL1S3JUuOvaW6kM2o+clDzzDNBWg= -golang.org/x/mobile v0.0.0-20231127183840-76ac6878050a/go.mod h1:Ede7gF0KGoHlj822RtphAHK1jLdrcuRBZg0sF1Q+SPc= golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4= golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= diff --git a/internal/app/app.go b/internal/app/app.go index ca23b3c404fd651bae9eb2d5524c684e7e591ba3..1c16ccc1186253961e01bd60f76a541bc891364b 100644 --- a/internal/app/app.go +++ b/internal/app/app.go @@ -7,6 +7,7 @@ import ( "sync" "time" + "github.com/kujtimiihoxha/termai/internal/config" "github.com/kujtimiihoxha/termai/internal/db" "github.com/kujtimiihoxha/termai/internal/history" "github.com/kujtimiihoxha/termai/internal/llm/agent" @@ -20,7 +21,7 @@ import ( type App struct { Sessions session.Service Messages message.Service - Files history.Service + History history.Service Permissions permission.Service CoderAgent agent.Service @@ -43,7 +44,7 @@ func New(ctx context.Context, conn *sql.DB) (*App, error) { app := &App{ Sessions: sessions, Messages: messages, - Files: files, + History: files, Permissions: permission.NewPermissionService(), LSPClients: make(map[string]*lsp.Client), } @@ -51,11 +52,17 @@ func New(ctx context.Context, conn *sql.DB) (*App, error) { app.initLSPClients(ctx) var err error - app.CoderAgent, err = agent.NewCoderAgent( - app.Permissions, + app.CoderAgent, err = agent.NewAgent( + config.AgentCoder, app.Sessions, app.Messages, - app.LSPClients, + agent.CoderAgentTools( + app.Permissions, + app.Sessions, + app.Messages, + app.History, + app.LSPClients, + ), ) if err != nil { logging.Error("Failed to create coder agent", err) diff --git a/internal/app/lsp.go b/internal/app/lsp.go index 4e0568f071f71031f898361748c1516b23d756fe..4a762f1a156cc2677ffd3806723d7ea402d60005 100644 --- a/internal/app/lsp.go +++ b/internal/app/lsp.go @@ -22,16 +22,17 @@ func (app *App) initLSPClients(ctx context.Context) { // createAndStartLSPClient creates a new LSP client, initializes it, and starts its workspace watcher func (app *App) createAndStartLSPClient(ctx context.Context, name string, command string, args ...string) { // Create a specific context for initialization with a timeout - initCtx, initCancel := context.WithTimeout(context.Background(), 30*time.Second) - defer initCancel() // Create the LSP client - lspClient, err := lsp.NewClient(initCtx, command, args...) + lspClient, err := lsp.NewClient(ctx, command, args...) if err != nil { logging.Error("Failed to create LSP client for", name, err) return + } + initCtx, cancel := context.WithTimeout(ctx, 15*time.Second) + defer cancel() // Initialize with the initialization context _, err = lspClient.InitializeLSPClient(initCtx, config.WorkingDirectory()) if err != nil { @@ -64,14 +65,10 @@ func (app *App) createAndStartLSPClient(ctx context.Context, name string, comman // runWorkspaceWatcher executes the workspace watcher for an LSP client func (app *App) runWorkspaceWatcher(ctx context.Context, name string, workspaceWatcher *watcher.WorkspaceWatcher) { defer app.watcherWG.Done() - defer func() { - if r := recover(); r != nil { - logging.Error("LSP client crashed", "client", name, "panic", r) - - // Try to restart the client - app.restartLSPClient(ctx, name) - } - }() + defer logging.RecoverPanic("LSP-"+name, func() { + // Try to restart the client + app.restartLSPClient(ctx, name) + }) workspaceWatcher.WatchWorkspace(ctx, config.WorkingDirectory()) logging.Info("Workspace watcher stopped", "client", name) diff --git a/internal/config/config.go b/internal/config/config.go index f0afbdd3c9839549403d937bb8982896804f733d..147d6c83a4bb77bcb7749f81dd6867410ecc8059 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -31,12 +31,18 @@ type MCPServer struct { Headers map[string]string `json:"headers"` } -// Model defines configuration for different LLM models and their token limits. -type Model struct { - Coder models.ModelID `json:"coder"` - CoderMaxTokens int64 `json:"coderMaxTokens"` - Task models.ModelID `json:"task"` - TaskMaxTokens int64 `json:"taskMaxTokens"` +type AgentName string + +const ( + AgentCoder AgentName = "coder" + AgentTask AgentName = "task" + AgentTitle AgentName = "title" +) + +// Agent defines configuration for different LLM models and their token limits. +type Agent struct { + Model models.ModelID `json:"model"` + MaxTokens int64 `json:"maxTokens"` } // Provider defines configuration for an LLM provider. @@ -65,8 +71,9 @@ type Config struct { MCPServers map[string]MCPServer `json:"mcpServers,omitempty"` Providers map[models.ModelProvider]Provider `json:"providers,omitempty"` LSP map[string]LSPConfig `json:"lsp,omitempty"` - Model Model `json:"model"` + Agents map[AgentName]Agent `json:"agents"` Debug bool `json:"debug,omitempty"` + DebugLSP bool `json:"debugLSP,omitempty"` } // Application constants @@ -118,11 +125,42 @@ func Load(workingDir string, debug bool) (*Config, error) { if cfg.Debug { defaultLevel = slog.LevelDebug } - // Configure logger - logger := slog.New(slog.NewTextHandler(logging.NewWriter(), &slog.HandlerOptions{ - Level: defaultLevel, - })) - slog.SetDefault(logger) + // if we are in debug mode make the writer a file + if cfg.Debug { + loggingFile := fmt.Sprintf("%s/%s", cfg.Data.Directory, "debug.log") + + // if file does not exist create it + if _, err := os.Stat(loggingFile); os.IsNotExist(err) { + if err := os.MkdirAll(cfg.Data.Directory, 0o755); err != nil { + return cfg, fmt.Errorf("failed to create directory: %w", err) + } + if _, err := os.Create(loggingFile); err != nil { + return cfg, fmt.Errorf("failed to create log file: %w", err) + } + } + + sloggingFileWriter, err := os.OpenFile(loggingFile, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0o666) + if err != nil { + return cfg, fmt.Errorf("failed to open log file: %w", err) + } + // Configure logger + logger := slog.New(slog.NewTextHandler(sloggingFileWriter, &slog.HandlerOptions{ + Level: defaultLevel, + })) + slog.SetDefault(logger) + } else { + // Configure logger + logger := slog.New(slog.NewTextHandler(logging.NewWriter(), &slog.HandlerOptions{ + Level: defaultLevel, + })) + slog.SetDefault(logger) + } + + // Override the max tokens for title agent + cfg.Agents[AgentTitle] = Agent{ + Model: cfg.Agents[AgentTitle].Model, + MaxTokens: 80, + } return cfg, nil } @@ -159,44 +197,50 @@ func setProviderDefaults() { // Groq configuration if apiKey := os.Getenv("GROQ_API_KEY"); apiKey != "" { viper.SetDefault("providers.groq.apiKey", apiKey) - viper.SetDefault("model.coder", models.QWENQwq) - viper.SetDefault("model.coderMaxTokens", defaultMaxTokens) - viper.SetDefault("model.task", models.QWENQwq) - viper.SetDefault("model.taskMaxTokens", defaultMaxTokens) + viper.SetDefault("agents.coder.model", models.QWENQwq) + viper.SetDefault("agents.coder.maxTokens", defaultMaxTokens) + viper.SetDefault("agents.task.model", models.QWENQwq) + viper.SetDefault("agents.task.maxTokens", defaultMaxTokens) + viper.SetDefault("agents.title.model", models.QWENQwq) } // Google Gemini configuration if apiKey := os.Getenv("GEMINI_API_KEY"); apiKey != "" { viper.SetDefault("providers.gemini.apiKey", apiKey) - viper.SetDefault("model.coder", models.GRMINI20Flash) - viper.SetDefault("model.coderMaxTokens", defaultMaxTokens) - viper.SetDefault("model.task", models.GRMINI20Flash) - viper.SetDefault("model.taskMaxTokens", defaultMaxTokens) + viper.SetDefault("agents.coder.model", models.GRMINI20Flash) + viper.SetDefault("agents.coder.maxTokens", defaultMaxTokens) + viper.SetDefault("agents.task.model", models.GRMINI20Flash) + viper.SetDefault("agents.task.maxTokens", defaultMaxTokens) + viper.SetDefault("agents.title.model", models.GRMINI20Flash) } // OpenAI configuration if apiKey := os.Getenv("OPENAI_API_KEY"); apiKey != "" { viper.SetDefault("providers.openai.apiKey", apiKey) - viper.SetDefault("model.coder", models.GPT4o) - viper.SetDefault("model.coderMaxTokens", defaultMaxTokens) - viper.SetDefault("model.task", models.GPT4o) - viper.SetDefault("model.taskMaxTokens", defaultMaxTokens) + viper.SetDefault("agents.coder.model", models.GPT4o) + viper.SetDefault("agents.coder.maxTokens", defaultMaxTokens) + viper.SetDefault("agents.task.model", models.GPT4o) + viper.SetDefault("agents.task.maxTokens", defaultMaxTokens) + viper.SetDefault("agents.title.model", models.GPT4o) + } // Anthropic configuration if apiKey := os.Getenv("ANTHROPIC_API_KEY"); apiKey != "" { viper.SetDefault("providers.anthropic.apiKey", apiKey) - viper.SetDefault("model.coder", models.Claude37Sonnet) - viper.SetDefault("model.coderMaxTokens", defaultMaxTokens) - viper.SetDefault("model.task", models.Claude37Sonnet) - viper.SetDefault("model.taskMaxTokens", defaultMaxTokens) + viper.SetDefault("agents.coder.model", models.Claude37Sonnet) + viper.SetDefault("agents.coder.maxTokens", defaultMaxTokens) + viper.SetDefault("agents.task.model", models.Claude37Sonnet) + viper.SetDefault("agents.task.maxTokens", defaultMaxTokens) + viper.SetDefault("agents.title.model", models.Claude37Sonnet) } if hasAWSCredentials() { - viper.SetDefault("model.coder", models.BedrockClaude37Sonnet) - viper.SetDefault("model.coderMaxTokens", defaultMaxTokens) - viper.SetDefault("model.task", models.BedrockClaude37Sonnet) - viper.SetDefault("model.taskMaxTokens", defaultMaxTokens) + viper.SetDefault("agents.coder.model", models.BedrockClaude37Sonnet) + viper.SetDefault("agents.coder.maxTokens", defaultMaxTokens) + viper.SetDefault("agents.task.model", models.BedrockClaude37Sonnet) + viper.SetDefault("agents.task.maxTokens", defaultMaxTokens) + viper.SetDefault("agents.title.model", models.BedrockClaude37Sonnet) } } diff --git a/internal/db/files.sql.go b/internal/db/files.sql.go index b45731098451eaee1ed2b0b198ce5db39ac40094..39def271f104addc2fa0057de503c17c2cdfecf7 100644 --- a/internal/db/files.sql.go +++ b/internal/db/files.sql.go @@ -97,7 +97,9 @@ func (q *Queries) GetFile(ctx context.Context, id string) (File, error) { const getFileByPathAndSession = `-- name: GetFileByPathAndSession :one SELECT id, session_id, path, content, version, created_at, updated_at FROM files -WHERE path = ? AND session_id = ? LIMIT 1 +WHERE path = ? AND session_id = ? +ORDER BY created_at DESC +LIMIT 1 ` type GetFileByPathAndSessionParams struct { diff --git a/internal/db/sql/files.sql b/internal/db/sql/files.sql index c2e7990764fc71a827ce92e7297cb4b155a2eafd..aba2a61111088ef7362753dc7b43c79769428473 100644 --- a/internal/db/sql/files.sql +++ b/internal/db/sql/files.sql @@ -6,7 +6,9 @@ WHERE id = ? LIMIT 1; -- name: GetFileByPathAndSession :one SELECT * FROM files -WHERE path = ? AND session_id = ? LIMIT 1; +WHERE path = ? AND session_id = ? +ORDER BY created_at DESC +LIMIT 1; -- name: ListFilesBySession :many SELECT * diff --git a/internal/diff/diff.go b/internal/diff/diff.go index 02d4d7140c99d94b4aa0894316cdfb243b78b54b..829554c7e1052bb95c02dc8b586634da21938b77 100644 --- a/internal/diff/diff.go +++ b/internal/diff/diff.go @@ -19,6 +19,8 @@ import ( "github.com/charmbracelet/x/ansi" "github.com/go-git/go-git/v5" "github.com/go-git/go-git/v5/plumbing/object" + "github.com/kujtimiihoxha/termai/internal/config" + "github.com/kujtimiihoxha/termai/internal/logging" "github.com/sergi/go-diff/diffmatchpatch" ) @@ -77,6 +79,8 @@ type linePair struct { // StyleConfig defines styling for diff rendering type StyleConfig struct { + ShowHeader bool + FileNameFg lipgloss.Color // Background colors RemovedLineBg lipgloss.Color AddedLineBg lipgloss.Color @@ -106,11 +110,13 @@ type StyleOption func(*StyleConfig) func NewStyleConfig(opts ...StyleOption) StyleConfig { // Default color scheme config := StyleConfig{ + ShowHeader: true, + FileNameFg: lipgloss.Color("#fab283"), RemovedLineBg: lipgloss.Color("#3A3030"), AddedLineBg: lipgloss.Color("#303A30"), ContextLineBg: lipgloss.Color("#212121"), - HunkLineBg: lipgloss.Color("#23252D"), - HunkLineFg: lipgloss.Color("#8CA3B4"), + HunkLineBg: lipgloss.Color("#212121"), + HunkLineFg: lipgloss.Color("#a0a0a0"), RemovedFg: lipgloss.Color("#7C4444"), AddedFg: lipgloss.Color("#478247"), LineNumberFg: lipgloss.Color("#888888"), @@ -132,6 +138,10 @@ func NewStyleConfig(opts ...StyleOption) StyleConfig { } // Style option functions +func WithFileNameFg(color lipgloss.Color) StyleOption { + return func(s *StyleConfig) { s.FileNameFg = color } +} + func WithRemovedLineBg(color lipgloss.Color) StyleOption { return func(s *StyleConfig) { s.RemovedLineBg = color } } @@ -190,6 +200,10 @@ func WithHunkLineFg(color lipgloss.Color) StyleOption { return func(s *StyleConfig) { s.HunkLineFg = color } } +func WithShowHeader(show bool) StyleOption { + return func(s *StyleConfig) { s.ShowHeader = show } +} + // ------------------------------------------------------------------------- // Parse Configuration // ------------------------------------------------------------------------- @@ -841,10 +855,12 @@ func RenderSideBySideHunk(fileName string, h Hunk, opts ...SideBySideOption) str // Calculate column width colWidth := config.TotalWidth / 2 + leftWidth := colWidth + rightWidth := config.TotalWidth - colWidth var sb strings.Builder for _, p := range pairs { - leftStr := renderLeftColumn(fileName, p.left, colWidth, config.Style) - rightStr := renderRightColumn(fileName, p.right, colWidth, config.Style) + leftStr := renderLeftColumn(fileName, p.left, leftWidth, config.Style) + rightStr := renderRightColumn(fileName, p.right, rightWidth, config.Style) sb.WriteString(leftStr + rightStr + "\n") } @@ -861,17 +877,50 @@ func FormatDiff(diffText string, opts ...SideBySideOption) (string, error) { var sb strings.Builder config := NewSideBySideConfig(opts...) - for i, h := range diffResult.Hunks { - if i > 0 { - // Render hunk header - sb.WriteString( - lipgloss.NewStyle(). - Background(config.Style.HunkLineBg). - Foreground(config.Style.HunkLineFg). - Width(config.TotalWidth). - Render(h.Header) + "\n", - ) - } + if config.Style.ShowHeader { + removeIcon := lipgloss.NewStyle(). + Background(config.Style.RemovedLineBg). + Foreground(config.Style.RemovedFg). + Render("⏹") + addIcon := lipgloss.NewStyle(). + Background(config.Style.AddedLineBg). + Foreground(config.Style.AddedFg). + Render("⏹") + + fileName := lipgloss.NewStyle(). + Background(config.Style.ContextLineBg). + Foreground(config.Style.FileNameFg). + Render(" " + diffResult.OldFile) + sb.WriteString( + lipgloss.NewStyle(). + Background(config.Style.ContextLineBg). + Padding(0, 1, 0, 1). + Foreground(config.Style.FileNameFg). + BorderStyle(lipgloss.NormalBorder()). + BorderTop(true). + BorderBottom(true). + BorderForeground(config.Style.FileNameFg). + BorderBackground(config.Style.ContextLineBg). + Width(config.TotalWidth). + Render( + lipgloss.JoinHorizontal(lipgloss.Top, + removeIcon, + addIcon, + fileName, + ), + ) + "\n", + ) + } + + for _, h := range diffResult.Hunks { + // Render hunk header + sb.WriteString( + lipgloss.NewStyle(). + Background(config.Style.HunkLineBg). + Foreground(config.Style.HunkLineFg). + Width(config.TotalWidth). + Render(h.Header) + "\n", + ) sb.WriteString(RenderSideBySideHunk(diffResult.OldFile, h, opts...)) } @@ -880,9 +929,15 @@ func FormatDiff(diffText string, opts ...SideBySideOption) (string, error) { // GenerateDiff creates a unified diff from two file contents func GenerateDiff(beforeContent, afterContent, fileName string) (string, int, int) { + // remove the cwd prefix and ensure consistent path format + // this prevents issues with absolute paths in different environments + cwd := config.WorkingDirectory() + fileName = strings.TrimPrefix(fileName, cwd) + fileName = strings.TrimPrefix(fileName, "/") // Create temporary directory for git operations - tempDir, err := os.MkdirTemp("", "git-diff-temp") + tempDir, err := os.MkdirTemp("", fmt.Sprintf("git-diff-%d", time.Now().UnixNano())) if err != nil { + logging.Error("Failed to create temp directory for git diff", "error", err) return "", 0, 0 } defer os.RemoveAll(tempDir) @@ -890,25 +945,30 @@ func GenerateDiff(beforeContent, afterContent, fileName string) (string, int, in // Initialize git repo repo, err := git.PlainInit(tempDir, false) if err != nil { + logging.Error("Failed to initialize git repository", "error", err) return "", 0, 0 } wt, err := repo.Worktree() if err != nil { + logging.Error("Failed to get git worktree", "error", err) return "", 0, 0 } // Write the "before" content and commit it fullPath := filepath.Join(tempDir, fileName) if err = os.MkdirAll(filepath.Dir(fullPath), 0o755); err != nil { + logging.Error("Failed to create directory for file", "error", err) return "", 0, 0 } if err = os.WriteFile(fullPath, []byte(beforeContent), 0o644); err != nil { + logging.Error("Failed to write before content to file", "error", err) return "", 0, 0 } _, err = wt.Add(fileName) if err != nil { + logging.Error("Failed to add file to git", "error", err) return "", 0, 0 } @@ -920,16 +980,19 @@ func GenerateDiff(beforeContent, afterContent, fileName string) (string, int, in }, }) if err != nil { + logging.Error("Failed to commit before content", "error", err) return "", 0, 0 } // Write the "after" content and commit it if err = os.WriteFile(fullPath, []byte(afterContent), 0o644); err != nil { + logging.Error("Failed to write after content to file", "error", err) return "", 0, 0 } _, err = wt.Add(fileName) if err != nil { + logging.Error("Failed to add file to git", "error", err) return "", 0, 0 } @@ -941,22 +1004,26 @@ func GenerateDiff(beforeContent, afterContent, fileName string) (string, int, in }, }) if err != nil { + logging.Error("Failed to commit after content", "error", err) return "", 0, 0 } // Get the diff between the two commits beforeCommitObj, err := repo.CommitObject(beforeCommit) if err != nil { + logging.Error("Failed to get before commit object", "error", err) return "", 0, 0 } afterCommitObj, err := repo.CommitObject(afterCommit) if err != nil { + logging.Error("Failed to get after commit object", "error", err) return "", 0, 0 } patch, err := beforeCommitObj.Patch(afterCommitObj) if err != nil { + logging.Error("Failed to create git diff patch", "error", err) return "", 0, 0 } diff --git a/internal/llm/agent/agent-tool.go b/internal/llm/agent/agent-tool.go index 83160bb645458150d96373de255ac78f1111dadd..308412bde86f8f6743997e8627b64082d2d8866f 100644 --- a/internal/llm/agent/agent-tool.go +++ b/internal/llm/agent/agent-tool.go @@ -5,6 +5,7 @@ import ( "encoding/json" "fmt" + "github.com/kujtimiihoxha/termai/internal/config" "github.com/kujtimiihoxha/termai/internal/llm/tools" "github.com/kujtimiihoxha/termai/internal/lsp" "github.com/kujtimiihoxha/termai/internal/message" @@ -53,7 +54,7 @@ func (b *agentTool) Run(ctx context.Context, call tools.ToolCall) (tools.ToolRes return tools.ToolResponse{}, fmt.Errorf("session_id and message_id are required") } - agent, err := NewTaskAgent(b.messages, b.sessions, b.lspClients) + agent, err := NewAgent(config.AgentTask, b.sessions, b.messages, TaskAgentTools(b.lspClients)) if err != nil { return tools.ToolResponse{}, fmt.Errorf("error creating agent: %s", err) } @@ -63,21 +64,16 @@ func (b *agentTool) Run(ctx context.Context, call tools.ToolCall) (tools.ToolRes return tools.ToolResponse{}, fmt.Errorf("error creating session: %s", err) } - err = agent.Generate(ctx, session.ID, params.Prompt) + done, err := agent.Run(ctx, session.ID, params.Prompt) if err != nil { return tools.ToolResponse{}, fmt.Errorf("error generating agent: %s", err) } - - messages, err := b.messages.List(ctx, session.ID) - if err != nil { - return tools.ToolResponse{}, fmt.Errorf("error listing messages: %s", err) - } - - if len(messages) == 0 { - return tools.NewTextErrorResponse("no response"), nil + result := <-done + if result.Err() != nil { + return tools.ToolResponse{}, fmt.Errorf("error generating agent: %s", result.Err()) } - response := messages[len(messages)-1] + response := result.Response() if response.Role != message.Assistant { return tools.NewTextErrorResponse("no response"), nil } diff --git a/internal/llm/agent/agent.go b/internal/llm/agent/agent.go index 1958111a136ff6956c9370d568831831c5906807..ab2742ec19b54b3c81f2c53df3d609c6f29e4a73 100644 --- a/internal/llm/agent/agent.go +++ b/internal/llm/agent/agent.go @@ -4,8 +4,6 @@ import ( "context" "errors" "fmt" - "os" - "runtime/debug" "strings" "sync" @@ -16,133 +14,101 @@ import ( "github.com/kujtimiihoxha/termai/internal/llm/tools" "github.com/kujtimiihoxha/termai/internal/logging" "github.com/kujtimiihoxha/termai/internal/message" + "github.com/kujtimiihoxha/termai/internal/permission" "github.com/kujtimiihoxha/termai/internal/session" ) // Common errors var ( - ErrProviderNotEnabled = errors.New("provider is not enabled") - ErrRequestCancelled = errors.New("request cancelled by user") - ErrSessionBusy = errors.New("session is currently processing another request") + ErrRequestCancelled = errors.New("request cancelled by user") + ErrSessionBusy = errors.New("session is currently processing another request") ) -// Service defines the interface for generating responses +type AgentEvent struct { + message message.Message + err error +} + +func (e *AgentEvent) Err() error { + return e.err +} + +func (e *AgentEvent) Response() message.Message { + return e.message +} + type Service interface { - Generate(ctx context.Context, sessionID string, content string) error - Cancel(sessionID string) error + Run(ctx context.Context, sessionID string, content string) (<-chan AgentEvent, error) + Cancel(sessionID string) + IsSessionBusy(sessionID string) bool } type agent struct { - sessions session.Service - messages message.Service - model models.Model - tools []tools.BaseTool - agent provider.Provider - titleGenerator provider.Provider - activeRequests sync.Map // map[sessionID]context.CancelFunc + sessions session.Service + messages message.Service + + tools []tools.BaseTool + provider provider.Provider + + titleProvider provider.Provider + + activeRequests sync.Map } -// NewAgent creates a new agent instance with the given model and tools -func NewAgent(ctx context.Context, sessions session.Service, messages message.Service, model models.Model, tools []tools.BaseTool) (Service, error) { - agentProvider, titleGenerator, err := getAgentProviders(ctx, model) +func NewAgent( + agentName config.AgentName, + sessions session.Service, + messages message.Service, + agentTools []tools.BaseTool, +) (Service, error) { + agentProvider, err := createAgentProvider(agentName) if err != nil { - return nil, fmt.Errorf("failed to initialize providers: %w", err) + return nil, err + } + var titleProvider provider.Provider + // Only generate titles for the coder agent + if agentName == config.AgentCoder { + titleProvider, err = createAgentProvider(config.AgentTitle) + if err != nil { + return nil, err + } } - return &agent{ - model: model, - tools: tools, - sessions: sessions, + agent := &agent{ + provider: agentProvider, messages: messages, - agent: agentProvider, - titleGenerator: titleGenerator, + sessions: sessions, + tools: agentTools, + titleProvider: titleProvider, activeRequests: sync.Map{}, - }, nil + } + + return agent, nil } -// Cancel cancels an active request by session ID -func (a *agent) Cancel(sessionID string) error { +func (a *agent) Cancel(sessionID string) { if cancelFunc, exists := a.activeRequests.LoadAndDelete(sessionID); exists { if cancel, ok := cancelFunc.(context.CancelFunc); ok { logging.InfoPersist(fmt.Sprintf("Request cancellation initiated for session: %s", sessionID)) cancel() - return nil } } - return errors.New("no active request found for this session") } -// Generate starts the generation process -func (a *agent) Generate(ctx context.Context, sessionID string, content string) error { - // Check if this session already has an active request - if _, busy := a.activeRequests.Load(sessionID); busy { - return ErrSessionBusy - } - - // Create a cancellable context - genCtx, cancel := context.WithCancel(ctx) - - // Store cancel function to allow user cancellation - a.activeRequests.Store(sessionID, cancel) - - // Launch the generation in a goroutine - go func() { - defer func() { - if r := recover(); r != nil { - logging.ErrorPersist(fmt.Sprintf("Panic in Generate: %v", r)) - - // dump stack trace into a file - file, err := os.Create("panic.log") - if err != nil { - logging.ErrorPersist(fmt.Sprintf("Failed to create panic log: %v", err)) - return - } - - defer file.Close() - - stackTrace := debug.Stack() - if _, err := file.Write(stackTrace); err != nil { - logging.ErrorPersist(fmt.Sprintf("Failed to write panic log: %v", err)) - } - - } - }() - defer a.activeRequests.Delete(sessionID) - defer cancel() - - if err := a.generate(genCtx, sessionID, content); err != nil { - if !errors.Is(err, ErrRequestCancelled) && !errors.Is(err, context.Canceled) { - // Log the error (avoid logging cancellations as they're expected) - logging.ErrorPersist(fmt.Sprintf("Generation error for session %s: %v", sessionID, err)) - - // You may want to create an error message in the chat - bgCtx := context.Background() - errorMsg := fmt.Sprintf("Sorry, an error occurred: %v", err) - _, createErr := a.messages.Create(bgCtx, sessionID, message.CreateMessageParams{ - Role: message.System, - Parts: []message.ContentPart{ - message.TextContent{ - Text: errorMsg, - }, - }, - }) - if createErr != nil { - logging.ErrorPersist(fmt.Sprintf("Failed to create error message: %v", createErr)) - } - } - } - }() - - return nil -} - -// IsSessionBusy checks if a session currently has an active request func (a *agent) IsSessionBusy(sessionID string) bool { _, busy := a.activeRequests.Load(sessionID) return busy -} // handleTitleGeneration asynchronously generates a title for new sessions -func (a *agent) handleTitleGeneration(ctx context.Context, sessionID, content string) { - response, err := a.titleGenerator.SendMessages( +} + +func (a *agent) generateTitle(ctx context.Context, sessionID string, content string) error { + if a.titleProvider == nil { + return nil + } + session, err := a.sessions.Get(ctx, sessionID) + if err != nil { + return err + } + response, err := a.titleProvider.SendMessages( ctx, []message.Message{ { @@ -154,121 +120,152 @@ func (a *agent) handleTitleGeneration(ctx context.Context, sessionID, content st }, }, }, - nil, + make([]tools.BaseTool, 0), ) if err != nil { - logging.ErrorPersist(fmt.Sprintf("Failed to generate title: %v", err)) - return + return err } - session, err := a.sessions.Get(ctx, sessionID) - if err != nil { - logging.ErrorPersist(fmt.Sprintf("Failed to get session: %v", err)) - return + title := strings.TrimSpace(strings.ReplaceAll(response.Content, "\n", " ")) + if title == "" { + return nil } - if response.Content != "" { - session.Title = strings.TrimSpace(response.Content) - session.Title = strings.ReplaceAll(session.Title, "\n", " ") - if _, err := a.sessions.Save(ctx, session); err != nil { - logging.ErrorPersist(fmt.Sprintf("Failed to save session title: %v", err)) - } + session.Title = title + _, err = a.sessions.Save(ctx, session) + return err +} + +func (a *agent) err(err error) AgentEvent { + return AgentEvent{ + err: err, } } -// TrackUsage updates token usage statistics for the session -func (a *agent) TrackUsage(ctx context.Context, sessionID string, model models.Model, usage provider.TokenUsage) error { - session, err := a.sessions.Get(ctx, sessionID) - if err != nil { - return fmt.Errorf("failed to get session: %w", err) +func (a *agent) Run(ctx context.Context, sessionID string, content string) (<-chan AgentEvent, error) { + events := make(chan AgentEvent) + if a.IsSessionBusy(sessionID) { + return nil, ErrSessionBusy } - cost := model.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) + - model.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) + - model.CostPer1MIn/1e6*float64(usage.InputTokens) + - model.CostPer1MOut/1e6*float64(usage.OutputTokens) + genCtx, cancel := context.WithCancel(ctx) + + a.activeRequests.Store(sessionID, cancel) + go func() { + logging.Debug("Request started", "sessionID", sessionID) + defer logging.RecoverPanic("agent.Run", func() { + events <- a.err(fmt.Errorf("panic while running the agent")) + }) - session.Cost += cost - session.CompletionTokens += usage.OutputTokens - session.PromptTokens += usage.InputTokens + result := a.processGeneration(genCtx, sessionID, content) + if result.Err() != nil && !errors.Is(result.Err(), ErrRequestCancelled) && !errors.Is(result.Err(), context.Canceled) { + logging.ErrorPersist(fmt.Sprintf("Generation error for session %s: %v", sessionID, result)) + } + logging.Debug("Request completed", "sessionID", sessionID) + a.activeRequests.Delete(sessionID) + cancel() + events <- result + close(events) + }() + return events, nil +} - _, err = a.sessions.Save(ctx, session) +func (a *agent) processGeneration(ctx context.Context, sessionID, content string) AgentEvent { + // List existing messages; if none, start title generation asynchronously. + msgs, err := a.messages.List(ctx, sessionID) if err != nil { - return fmt.Errorf("failed to save session: %w", err) + return a.err(fmt.Errorf("failed to list messages: %w", err)) + } + if len(msgs) == 0 { + go func() { + defer logging.RecoverPanic("agent.Run", func() { + logging.ErrorPersist("panic while generating title") + }) + titleErr := a.generateTitle(context.Background(), sessionID, content) + if titleErr != nil { + logging.ErrorPersist(fmt.Sprintf("failed to generate title: %v", titleErr)) + } + }() } - return nil -} -// processEvent handles different types of events during generation -func (a *agent) processEvent( - ctx context.Context, - sessionID string, - assistantMsg *message.Message, - event provider.ProviderEvent, -) error { - select { - case <-ctx.Done(): - return ctx.Err() - default: - // Continue processing + userMsg, err := a.createUserMessage(ctx, sessionID, content) + if err != nil { + return a.err(fmt.Errorf("failed to create user message: %w", err)) } - switch event.Type { - case provider.EventThinkingDelta: - assistantMsg.AppendReasoningContent(event.Content) - return a.messages.Update(ctx, *assistantMsg) - case provider.EventContentDelta: - assistantMsg.AppendContent(event.Content) - return a.messages.Update(ctx, *assistantMsg) - case provider.EventError: - if errors.Is(event.Error, context.Canceled) { - logging.InfoPersist(fmt.Sprintf("Event processing canceled for session: %s", sessionID)) - return context.Canceled + // Append the new user message to the conversation history. + msgHistory := append(msgs, userMsg) + for { + // Check for cancellation before each iteration + select { + case <-ctx.Done(): + return a.err(ctx.Err()) + default: + // Continue processing } - logging.ErrorPersist(event.Error.Error()) - return event.Error - case provider.EventWarning: - logging.WarnPersist(event.Info) - case provider.EventInfo: - logging.InfoPersist(event.Info) - case provider.EventComplete: - assistantMsg.SetToolCalls(event.Response.ToolCalls) - assistantMsg.AddFinish(event.Response.FinishReason) - if err := a.messages.Update(ctx, *assistantMsg); err != nil { - return fmt.Errorf("failed to update message: %w", err) + agentMessage, toolResults, err := a.streamAndHandleEvents(ctx, sessionID, msgHistory) + if err != nil { + if errors.Is(err, context.Canceled) { + return a.err(ErrRequestCancelled) + } + return a.err(fmt.Errorf("failed to process events: %w", err)) + } + logging.Info("Result", "message", agentMessage.FinishReason(), "toolResults", toolResults) + if (agentMessage.FinishReason() == message.FinishReasonToolUse) && toolResults != nil { + // We are not done, we need to respond with the tool response + msgHistory = append(msgHistory, agentMessage, *toolResults) + continue + } + return AgentEvent{ + message: agentMessage, } - return a.TrackUsage(ctx, sessionID, a.model, event.Response.Usage) } +} - return nil +func (a *agent) createUserMessage(ctx context.Context, sessionID, content string) (message.Message, error) { + return a.messages.Create(ctx, sessionID, message.CreateMessageParams{ + Role: message.User, + Parts: []message.ContentPart{ + message.TextContent{Text: content}, + }, + }) } -// ExecuteTools runs all tool calls sequentially and returns the results -func (a *agent) ExecuteTools(ctx context.Context, toolCalls []message.ToolCall, tls []tools.BaseTool) ([]message.ToolResult, error) { - toolResults := make([]message.ToolResult, len(toolCalls)) +func (a *agent) streamAndHandleEvents(ctx context.Context, sessionID string, msgHistory []message.Message) (message.Message, *message.Message, error) { + eventChan := a.provider.StreamResponse(ctx, msgHistory, a.tools) + + assistantMsg, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{ + Role: message.Assistant, + Parts: []message.ContentPart{}, + Model: a.provider.Model().ID, + }) + if err != nil { + return assistantMsg, nil, fmt.Errorf("failed to create assistant message: %w", err) + } - // Create a child context that can be canceled - ctx, cancel := context.WithCancel(ctx) - defer cancel() + // Add the session and message ID into the context if needed by tools. + ctx = context.WithValue(ctx, tools.MessageIDContextKey, assistantMsg.ID) + ctx = context.WithValue(ctx, tools.SessionIDContextKey, sessionID) - // Check if already canceled before starting any execution - if ctx.Err() != nil { - // Mark all tools as canceled - for i, toolCall := range toolCalls { - toolResults[i] = message.ToolResult{ - ToolCallID: toolCall.ID, - Content: "Tool execution canceled by user", - IsError: true, - } + // Process each event in the stream. + for event := range eventChan { + if processErr := a.processEvent(ctx, sessionID, &assistantMsg, event); processErr != nil { + a.finishMessage(ctx, &assistantMsg, message.FinishReasonCanceled) + return assistantMsg, nil, processErr + } + if ctx.Err() != nil { + a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled) + return assistantMsg, nil, ctx.Err() } - return toolResults, ctx.Err() } + toolResults := make([]message.ToolResult, len(assistantMsg.ToolCalls())) + toolCalls := assistantMsg.ToolCalls() for i, toolCall := range toolCalls { - // Check for cancellation before executing each tool select { case <-ctx.Done(): - // Mark this and all remaining tools as canceled + a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled) + // Make all future tool calls cancelled for j := i; j < len(toolCalls); j++ { toolResults[j] = message.ToolResult{ ToolCallID: toolCalls[j].ID, @@ -276,412 +273,180 @@ func (a *agent) ExecuteTools(ctx context.Context, toolCalls []message.ToolCall, IsError: true, } } - return toolResults, ctx.Err() + goto out default: // Continue processing - } - - response := "" - isError := false - found := false - - // Find and execute the appropriate tool - for _, tool := range tls { - if tool.Info().Name == toolCall.Name { - found = true - toolResult, toolErr := tool.Run(ctx, tools.ToolCall{ - ID: toolCall.ID, - Name: toolCall.Name, - Input: toolCall.Input, - }) - - if toolErr != nil { - if errors.Is(toolErr, context.Canceled) { - response = "Tool execution canceled by user" - } else { - response = fmt.Sprintf("Error running tool: %s", toolErr) - } - isError = true - } else { - response = toolResult.Content - isError = toolResult.IsError + var tool tools.BaseTool + for _, availableTools := range a.tools { + if availableTools.Info().Name == toolCall.Name { + tool = availableTools } - break } - } - - if !found { - response = fmt.Sprintf("Tool not found: %s", toolCall.Name) - isError = true - } - - toolResults[i] = message.ToolResult{ - ToolCallID: toolCall.ID, - Content: response, - IsError: isError, - } - } - return toolResults, nil -} - -// handleToolExecution processes tool calls and creates tool result messages -func (a *agent) handleToolExecution( - ctx context.Context, - assistantMsg message.Message, -) (*message.Message, error) { - select { - case <-ctx.Done(): - // If cancelled, create tool results that indicate cancellation - if len(assistantMsg.ToolCalls()) > 0 { - toolResults := make([]message.ToolResult, 0, len(assistantMsg.ToolCalls())) - for _, tc := range assistantMsg.ToolCalls() { - toolResults = append(toolResults, message.ToolResult{ - ToolCallID: tc.ID, - Content: "Tool execution canceled by user", + // Tool not found + if tool == nil { + toolResults[i] = message.ToolResult{ + ToolCallID: toolCall.ID, + Content: fmt.Sprintf("Tool not found: %s", toolCall.Name), IsError: true, - }) + } + continue } - // Use background context to ensure the message is created even if original context is cancelled - bgCtx := context.Background() - parts := make([]message.ContentPart, 0) - for _, toolResult := range toolResults { - parts = append(parts, toolResult) - } - msg, err := a.messages.Create(bgCtx, assistantMsg.SessionID, message.CreateMessageParams{ - Role: message.Tool, - Parts: parts, + toolResult, toolErr := tool.Run(ctx, tools.ToolCall{ + ID: toolCall.ID, + Name: toolCall.Name, + Input: toolCall.Input, }) - if err != nil { - return nil, fmt.Errorf("failed to create cancelled tool message: %w", err) - } - return &msg, ctx.Err() - } - return nil, ctx.Err() - default: - // Continue processing - } - - if len(assistantMsg.ToolCalls()) == 0 { - return nil, nil - } - - toolResults, err := a.ExecuteTools(ctx, assistantMsg.ToolCalls(), a.tools) - if err != nil { - // If error is from cancellation, still return the partial results we have - if errors.Is(err, context.Canceled) { - // Use background context to ensure the message is created even if original context is cancelled - bgCtx := context.Background() - parts := make([]message.ContentPart, 0) - for _, toolResult := range toolResults { - parts = append(parts, toolResult) + if toolErr != nil { + if errors.Is(toolErr, permission.ErrorPermissionDenied) { + toolResults[i] = message.ToolResult{ + ToolCallID: toolCall.ID, + Content: "Permission denied", + IsError: true, + } + for j := i + 1; j < len(toolCalls); j++ { + toolResults[j] = message.ToolResult{ + ToolCallID: toolCalls[j].ID, + Content: "Tool execution canceled by user", + IsError: true, + } + } + a.finishMessage(ctx, &assistantMsg, message.FinishReasonPermissionDenied) + } else { + toolResults[i] = message.ToolResult{ + ToolCallID: toolCall.ID, + Content: toolErr.Error(), + IsError: true, + } + for j := i; j < len(toolCalls); j++ { + toolResults[j] = message.ToolResult{ + ToolCallID: toolCalls[j].ID, + Content: "Previous tool failed", + IsError: true, + } + } + a.finishMessage(ctx, &assistantMsg, message.FinishReasonError) + } + // If permission is denied or an error happens we cancel all the following tools + break } - - msg, createErr := a.messages.Create(bgCtx, assistantMsg.SessionID, message.CreateMessageParams{ - Role: message.Tool, - Parts: parts, - }) - if createErr != nil { - logging.ErrorPersist(fmt.Sprintf("Failed to create tool message after cancellation: %v", createErr)) - return nil, err + toolResults[i] = message.ToolResult{ + ToolCallID: toolCall.ID, + Content: toolResult.Content, + Metadata: toolResult.Metadata, + IsError: toolResult.IsError, } - return &msg, err } - return nil, err } - - parts := make([]message.ContentPart, 0, len(toolResults)) - for _, toolResult := range toolResults { - parts = append(parts, toolResult) +out: + if len(toolResults) == 0 { + return assistantMsg, nil, nil } - - msg, err := a.messages.Create(ctx, assistantMsg.SessionID, message.CreateMessageParams{ + parts := make([]message.ContentPart, 0) + for _, tr := range toolResults { + parts = append(parts, tr) + } + msg, err := a.messages.Create(context.Background(), assistantMsg.SessionID, message.CreateMessageParams{ Role: message.Tool, Parts: parts, }) if err != nil { - return nil, fmt.Errorf("failed to create tool message: %w", err) + return assistantMsg, nil, fmt.Errorf("failed to create cancelled tool message: %w", err) } - return &msg, nil + return assistantMsg, &msg, err } -// generate handles the main generation workflow -func (a *agent) generate(ctx context.Context, sessionID string, content string) error { - ctx = context.WithValue(ctx, tools.SessionIDContextKey, sessionID) +func (a *agent) finishMessage(ctx context.Context, msg *message.Message, finishReson message.FinishReason) { + msg.AddFinish(finishReson) + _ = a.messages.Update(ctx, *msg) +} - // Handle context cancellation at any point - if err := ctx.Err(); err != nil { - return ErrRequestCancelled +func (a *agent) processEvent(ctx context.Context, sessionID string, assistantMsg *message.Message, event provider.ProviderEvent) error { + select { + case <-ctx.Done(): + return ctx.Err() + default: + // Continue processing. } - messages, err := a.messages.List(ctx, sessionID) - if err != nil { - return fmt.Errorf("failed to list messages: %w", err) + switch event.Type { + case provider.EventThinkingDelta: + assistantMsg.AppendReasoningContent(event.Content) + return a.messages.Update(ctx, *assistantMsg) + case provider.EventContentDelta: + assistantMsg.AppendContent(event.Content) + return a.messages.Update(ctx, *assistantMsg) + case provider.EventError: + if errors.Is(event.Error, context.Canceled) { + logging.InfoPersist(fmt.Sprintf("Event processing canceled for session: %s", sessionID)) + return context.Canceled + } + logging.ErrorPersist(event.Error.Error()) + return event.Error + case provider.EventComplete: + assistantMsg.SetToolCalls(event.Response.ToolCalls) + assistantMsg.AddFinish(event.Response.FinishReason) + if err := a.messages.Update(ctx, *assistantMsg); err != nil { + return fmt.Errorf("failed to update message: %w", err) + } + return a.TrackUsage(ctx, sessionID, a.provider.Model(), event.Response.Usage) } - if len(messages) == 0 { - titleCtx := context.Background() - go a.handleTitleGeneration(titleCtx, sessionID, content) - } + return nil +} - userMsg, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{ - Role: message.User, - Parts: []message.ContentPart{ - message.TextContent{ - Text: content, - }, - }, - }) +func (a *agent) TrackUsage(ctx context.Context, sessionID string, model models.Model, usage provider.TokenUsage) error { + sess, err := a.sessions.Get(ctx, sessionID) if err != nil { - return fmt.Errorf("failed to create user message: %w", err) + return fmt.Errorf("failed to get session: %w", err) } - messages = append(messages, userMsg) - - for { - // Check for cancellation before each iteration - select { - case <-ctx.Done(): - return ErrRequestCancelled - default: - // Continue processing - } - - eventChan, err := a.agent.StreamResponse(ctx, messages, a.tools) - if err != nil { - if errors.Is(err, context.Canceled) { - return ErrRequestCancelled - } - return fmt.Errorf("failed to stream response: %w", err) - } - - assistantMsg, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{ - Role: message.Assistant, - Parts: []message.ContentPart{}, - Model: a.model.ID, - }) - if err != nil { - return fmt.Errorf("failed to create assistant message: %w", err) - } - - ctx = context.WithValue(ctx, tools.MessageIDContextKey, assistantMsg.ID) - - // Process events from the LLM provider - for event := range eventChan { - if err := a.processEvent(ctx, sessionID, &assistantMsg, event); err != nil { - if errors.Is(err, context.Canceled) { - // Mark as canceled but don't create separate message - assistantMsg.AddFinish("canceled") - _ = a.messages.Update(context.Background(), assistantMsg) - return ErrRequestCancelled - } - assistantMsg.AddFinish("error:" + err.Error()) - _ = a.messages.Update(ctx, assistantMsg) - return fmt.Errorf("event processing error: %w", err) - } - - // Check for cancellation during event processing - select { - case <-ctx.Done(): - // Mark as canceled - assistantMsg.AddFinish("canceled") - _ = a.messages.Update(context.Background(), assistantMsg) - return ErrRequestCancelled - default: - } - } - - // Check for cancellation before tool execution - select { - case <-ctx.Done(): - assistantMsg.AddFinish("canceled_by_user") - _ = a.messages.Update(context.Background(), assistantMsg) - return ErrRequestCancelled - default: - } - - // Execute any tool calls - toolMsg, err := a.handleToolExecution(ctx, assistantMsg) - if err != nil { - if errors.Is(err, context.Canceled) { - assistantMsg.AddFinish("canceled_by_user") - _ = a.messages.Update(context.Background(), assistantMsg) - return ErrRequestCancelled - } - return fmt.Errorf("tool execution error: %w", err) - } - - if err := a.messages.Update(ctx, assistantMsg); err != nil { - return fmt.Errorf("failed to update assistant message: %w", err) - } - - // If no tool calls, we're done - if len(assistantMsg.ToolCalls()) == 0 { - break - } + cost := model.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) + + model.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) + + model.CostPer1MIn/1e6*float64(usage.InputTokens) + + model.CostPer1MOut/1e6*float64(usage.OutputTokens) - // Add messages for next iteration - messages = append(messages, assistantMsg) - if toolMsg != nil { - messages = append(messages, *toolMsg) - } + sess.Cost += cost + sess.CompletionTokens += usage.OutputTokens + sess.PromptTokens += usage.InputTokens - // Check for cancellation after tool execution - select { - case <-ctx.Done(): - return ErrRequestCancelled - default: - } + _, err = a.sessions.Save(ctx, sess) + if err != nil { + return fmt.Errorf("failed to save session: %w", err) } - return nil } -// getAgentProviders initializes the LLM providers based on the chosen model -func getAgentProviders(ctx context.Context, model models.Model) (provider.Provider, provider.Provider, error) { - maxTokens := config.Get().Model.CoderMaxTokens - - providerConfig, ok := config.Get().Providers[model.Provider] - if !ok || providerConfig.Disabled { - return nil, nil, ErrProviderNotEnabled +func createAgentProvider(agentName config.AgentName) (provider.Provider, error) { + cfg := config.Get() + agentConfig, ok := cfg.Agents[agentName] + if !ok { + return nil, fmt.Errorf("agent %s not found", agentName) + } + model, ok := models.SupportedModels[agentConfig.Model] + if !ok { + return nil, fmt.Errorf("model %s not supported", agentConfig.Model) } - var agentProvider provider.Provider - var titleGenerator provider.Provider - var err error - - switch model.Provider { - case models.ProviderOpenAI: - agentProvider, err = provider.NewOpenAIProvider( - provider.WithOpenAISystemMessage( - prompt.CoderOpenAISystemPrompt(), - ), - provider.WithOpenAIMaxTokens(maxTokens), - provider.WithOpenAIModel(model), - provider.WithOpenAIKey(providerConfig.APIKey), - ) - if err != nil { - return nil, nil, fmt.Errorf("failed to create OpenAI agent provider: %w", err) - } - - titleGenerator, err = provider.NewOpenAIProvider( - provider.WithOpenAISystemMessage( - prompt.TitlePrompt(), - ), - provider.WithOpenAIMaxTokens(80), - provider.WithOpenAIModel(model), - provider.WithOpenAIKey(providerConfig.APIKey), - ) - if err != nil { - return nil, nil, fmt.Errorf("failed to create OpenAI title generator: %w", err) - } - - case models.ProviderAnthropic: - agentProvider, err = provider.NewAnthropicProvider( - provider.WithAnthropicSystemMessage( - prompt.CoderAnthropicSystemPrompt(), - ), - provider.WithAnthropicMaxTokens(maxTokens), - provider.WithAnthropicKey(providerConfig.APIKey), - provider.WithAnthropicModel(model), - ) - if err != nil { - return nil, nil, fmt.Errorf("failed to create Anthropic agent provider: %w", err) - } - - titleGenerator, err = provider.NewAnthropicProvider( - provider.WithAnthropicSystemMessage( - prompt.TitlePrompt(), - ), - provider.WithAnthropicMaxTokens(80), - provider.WithAnthropicKey(providerConfig.APIKey), - provider.WithAnthropicModel(model), - ) - if err != nil { - return nil, nil, fmt.Errorf("failed to create Anthropic title generator: %w", err) - } - - case models.ProviderGemini: - agentProvider, err = provider.NewGeminiProvider( - ctx, - provider.WithGeminiSystemMessage( - prompt.CoderOpenAISystemPrompt(), - ), - provider.WithGeminiMaxTokens(int32(maxTokens)), - provider.WithGeminiKey(providerConfig.APIKey), - provider.WithGeminiModel(model), - ) - if err != nil { - return nil, nil, fmt.Errorf("failed to create Gemini agent provider: %w", err) - } - - titleGenerator, err = provider.NewGeminiProvider( - ctx, - provider.WithGeminiSystemMessage( - prompt.TitlePrompt(), - ), - provider.WithGeminiMaxTokens(80), - provider.WithGeminiKey(providerConfig.APIKey), - provider.WithGeminiModel(model), - ) - if err != nil { - return nil, nil, fmt.Errorf("failed to create Gemini title generator: %w", err) - } - - case models.ProviderGROQ: - agentProvider, err = provider.NewOpenAIProvider( - provider.WithOpenAISystemMessage( - prompt.CoderAnthropicSystemPrompt(), - ), - provider.WithOpenAIMaxTokens(maxTokens), - provider.WithOpenAIModel(model), - provider.WithOpenAIKey(providerConfig.APIKey), - provider.WithOpenAIBaseURL("https://api.groq.com/openai/v1"), - ) - if err != nil { - return nil, nil, fmt.Errorf("failed to create GROQ agent provider: %w", err) - } - - titleGenerator, err = provider.NewOpenAIProvider( - provider.WithOpenAISystemMessage( - prompt.TitlePrompt(), - ), - provider.WithOpenAIMaxTokens(80), - provider.WithOpenAIModel(model), - provider.WithOpenAIKey(providerConfig.APIKey), - provider.WithOpenAIBaseURL("https://api.groq.com/openai/v1"), - ) - if err != nil { - return nil, nil, fmt.Errorf("failed to create GROQ title generator: %w", err) - } - - case models.ProviderBedrock: - agentProvider, err = provider.NewBedrockProvider( - provider.WithBedrockSystemMessage( - prompt.CoderAnthropicSystemPrompt(), - ), - provider.WithBedrockMaxTokens(maxTokens), - provider.WithBedrockModel(model), - ) - if err != nil { - return nil, nil, fmt.Errorf("failed to create Bedrock agent provider: %w", err) - } - - titleGenerator, err = provider.NewBedrockProvider( - provider.WithBedrockSystemMessage( - prompt.TitlePrompt(), - ), - provider.WithBedrockMaxTokens(80), - provider.WithBedrockModel(model), - ) - if err != nil { - return nil, nil, fmt.Errorf("failed to create Bedrock title generator: %w", err) - } - default: - return nil, nil, fmt.Errorf("unsupported provider: %s", model.Provider) + providerCfg, ok := cfg.Providers[model.Provider] + if !ok { + return nil, fmt.Errorf("provider %s not supported", model.Provider) + } + if providerCfg.Disabled { + return nil, fmt.Errorf("provider %s is not enabled", model.Provider) + } + agentProvider, err := provider.NewProvider( + model.Provider, + provider.WithAPIKey(providerCfg.APIKey), + provider.WithModel(model), + provider.WithSystemMessage(prompt.GetAgentPrompt(agentName, model.Provider)), + provider.WithMaxTokens(agentConfig.MaxTokens), + ) + if err != nil { + return nil, fmt.Errorf("could not create provider: %v", err) } - return agentProvider, titleGenerator, nil + return agentProvider, nil } diff --git a/internal/llm/agent/coder.go b/internal/llm/agent/coder.go deleted file mode 100644 index a3db6b55c1f675d4ca77ff03cf9cc781b2e795e9..0000000000000000000000000000000000000000 --- a/internal/llm/agent/coder.go +++ /dev/null @@ -1,63 +0,0 @@ -package agent - -import ( - "context" - "errors" - - "github.com/kujtimiihoxha/termai/internal/config" - "github.com/kujtimiihoxha/termai/internal/llm/models" - "github.com/kujtimiihoxha/termai/internal/llm/tools" - "github.com/kujtimiihoxha/termai/internal/lsp" - "github.com/kujtimiihoxha/termai/internal/message" - "github.com/kujtimiihoxha/termai/internal/permission" - "github.com/kujtimiihoxha/termai/internal/session" -) - -type coderAgent struct { - Service -} - -func NewCoderAgent( - permissions permission.Service, - sessions session.Service, - messages message.Service, - lspClients map[string]*lsp.Client, -) (Service, error) { - model, ok := models.SupportedModels[config.Get().Model.Coder] - if !ok { - return nil, errors.New("model not supported") - } - - ctx := context.Background() - otherTools := GetMcpTools(ctx, permissions) - if len(lspClients) > 0 { - otherTools = append(otherTools, tools.NewDiagnosticsTool(lspClients)) - } - agent, err := NewAgent( - ctx, - sessions, - messages, - model, - append( - []tools.BaseTool{ - tools.NewBashTool(permissions), - tools.NewEditTool(lspClients, permissions), - tools.NewFetchTool(permissions), - tools.NewGlobTool(), - tools.NewGrepTool(), - tools.NewLsTool(), - tools.NewSourcegraphTool(), - tools.NewViewTool(lspClients), - tools.NewWriteTool(lspClients, permissions), - NewAgentTool(sessions, messages, lspClients), - }, otherTools..., - ), - ) - if err != nil { - return nil, err - } - - return &coderAgent{ - agent, - }, nil -} diff --git a/internal/llm/agent/mcp-tools.go b/internal/llm/agent/mcp-tools.go index b1c97b512b04f77d6da37636a474b8ef25ccf78c..c7ea4916cad38a69df1e8bfde3fe78155a009496 100644 --- a/internal/llm/agent/mcp-tools.go +++ b/internal/llm/agent/mcp-tools.go @@ -46,7 +46,7 @@ func runTool(ctx context.Context, c MCPClient, toolName string, input string) (t initRequest := mcp.InitializeRequest{} initRequest.Params.ProtocolVersion = mcp.LATEST_PROTOCOL_VERSION initRequest.Params.ClientInfo = mcp.Implementation{ - Name: "termai", + Name: "OpenCode", Version: version.Version, } @@ -135,7 +135,7 @@ func getTools(ctx context.Context, name string, m config.MCPServer, permissions initRequest := mcp.InitializeRequest{} initRequest.Params.ProtocolVersion = mcp.LATEST_PROTOCOL_VERSION initRequest.Params.ClientInfo = mcp.Implementation{ - Name: "termai", + Name: "OpenCode", Version: version.Version, } diff --git a/internal/llm/agent/task.go b/internal/llm/agent/task.go deleted file mode 100644 index fca1f223f505c1756503967b21dc592dc64903d2..0000000000000000000000000000000000000000 --- a/internal/llm/agent/task.go +++ /dev/null @@ -1,47 +0,0 @@ -package agent - -import ( - "context" - "errors" - - "github.com/kujtimiihoxha/termai/internal/config" - "github.com/kujtimiihoxha/termai/internal/llm/models" - "github.com/kujtimiihoxha/termai/internal/llm/tools" - "github.com/kujtimiihoxha/termai/internal/lsp" - "github.com/kujtimiihoxha/termai/internal/message" - "github.com/kujtimiihoxha/termai/internal/session" -) - -type taskAgent struct { - Service -} - -func NewTaskAgent(messages message.Service, sessions session.Service, lspClients map[string]*lsp.Client) (Service, error) { - model, ok := models.SupportedModels[config.Get().Model.Coder] - if !ok { - return nil, errors.New("model not supported") - } - - ctx := context.Background() - - agent, err := NewAgent( - ctx, - sessions, - messages, - model, - []tools.BaseTool{ - tools.NewGlobTool(), - tools.NewGrepTool(), - tools.NewLsTool(), - tools.NewSourcegraphTool(), - tools.NewViewTool(lspClients), - }, - ) - if err != nil { - return nil, err - } - - return &taskAgent{ - agent, - }, nil -} diff --git a/internal/llm/agent/tools.go b/internal/llm/agent/tools.go new file mode 100644 index 0000000000000000000000000000000000000000..a37f1d65d0327ad51d8bfe2830e4c3661c324a30 --- /dev/null +++ b/internal/llm/agent/tools.go @@ -0,0 +1,50 @@ +package agent + +import ( + "context" + + "github.com/kujtimiihoxha/termai/internal/history" + "github.com/kujtimiihoxha/termai/internal/llm/tools" + "github.com/kujtimiihoxha/termai/internal/lsp" + "github.com/kujtimiihoxha/termai/internal/message" + "github.com/kujtimiihoxha/termai/internal/permission" + "github.com/kujtimiihoxha/termai/internal/session" +) + +func CoderAgentTools( + permissions permission.Service, + sessions session.Service, + messages message.Service, + history history.Service, + lspClients map[string]*lsp.Client, +) []tools.BaseTool { + ctx := context.Background() + otherTools := GetMcpTools(ctx, permissions) + if len(lspClients) > 0 { + otherTools = append(otherTools, tools.NewDiagnosticsTool(lspClients)) + } + return append( + []tools.BaseTool{ + tools.NewBashTool(permissions), + tools.NewEditTool(lspClients, permissions, history), + tools.NewFetchTool(permissions), + tools.NewGlobTool(), + tools.NewGrepTool(), + tools.NewLsTool(), + tools.NewSourcegraphTool(), + tools.NewViewTool(lspClients), + tools.NewWriteTool(lspClients, permissions, history), + NewAgentTool(sessions, messages, lspClients), + }, otherTools..., + ) +} + +func TaskAgentTools(lspClients map[string]*lsp.Client) []tools.BaseTool { + return []tools.BaseTool{ + tools.NewGlobTool(), + tools.NewGrepTool(), + tools.NewLsTool(), + tools.NewSourcegraphTool(), + tools.NewViewTool(lspClients), + } +} diff --git a/internal/llm/models/anthropic.go b/internal/llm/models/anthropic.go new file mode 100644 index 0000000000000000000000000000000000000000..48307e6d3fe72af892944234767ca8dba723398f --- /dev/null +++ b/internal/llm/models/anthropic.go @@ -0,0 +1,71 @@ +package models + +const ( + ProviderAnthropic ModelProvider = "anthropic" + + // Models + Claude35Sonnet ModelID = "claude-3.5-sonnet" + Claude3Haiku ModelID = "claude-3-haiku" + Claude37Sonnet ModelID = "claude-3.7-sonnet" + Claude35Haiku ModelID = "claude-3.5-haiku" + Claude3Opus ModelID = "claude-3-opus" +) + +var AnthropicModels = map[ModelID]Model{ + // Anthropic + Claude35Sonnet: { + ID: Claude35Sonnet, + Name: "Claude 3.5 Sonnet", + Provider: ProviderAnthropic, + APIModel: "claude-3-5-sonnet-latest", + CostPer1MIn: 3.0, + CostPer1MInCached: 3.75, + CostPer1MOutCached: 0.30, + CostPer1MOut: 15.0, + ContextWindow: 200000, + }, + Claude3Haiku: { + ID: Claude3Haiku, + Name: "Claude 3 Haiku", + Provider: ProviderAnthropic, + APIModel: "claude-3-haiku-latest", + CostPer1MIn: 0.25, + CostPer1MInCached: 0.30, + CostPer1MOutCached: 0.03, + CostPer1MOut: 1.25, + ContextWindow: 200000, + }, + Claude37Sonnet: { + ID: Claude37Sonnet, + Name: "Claude 3.7 Sonnet", + Provider: ProviderAnthropic, + APIModel: "claude-3-7-sonnet-latest", + CostPer1MIn: 3.0, + CostPer1MInCached: 3.75, + CostPer1MOutCached: 0.30, + CostPer1MOut: 15.0, + ContextWindow: 200000, + }, + Claude35Haiku: { + ID: Claude35Haiku, + Name: "Claude 3.5 Haiku", + Provider: ProviderAnthropic, + APIModel: "claude-3-5-haiku-latest", + CostPer1MIn: 0.80, + CostPer1MInCached: 1.0, + CostPer1MOutCached: 0.08, + CostPer1MOut: 4.0, + ContextWindow: 200000, + }, + Claude3Opus: { + ID: Claude3Opus, + Name: "Claude 3 Opus", + Provider: ProviderAnthropic, + APIModel: "claude-3-opus-latest", + CostPer1MIn: 15.0, + CostPer1MInCached: 18.75, + CostPer1MOutCached: 1.50, + CostPer1MOut: 75.0, + ContextWindow: 200000, + }, +} diff --git a/internal/llm/models/models.go b/internal/llm/models/models.go index 140693237bec620b2fb0120addcf1e52ad7b194a..4d4589bfdf15b777782633fa4c0c09b324cffa9a 100644 --- a/internal/llm/models/models.go +++ b/internal/llm/models/models.go @@ -1,5 +1,7 @@ package models +import "maps" + type ( ModelID string ModelProvider string @@ -14,15 +16,13 @@ type Model struct { CostPer1MOut float64 `json:"cost_per_1m_out"` CostPer1MInCached float64 `json:"cost_per_1m_in_cached"` CostPer1MOutCached float64 `json:"cost_per_1m_out_cached"` + ContextWindow int64 `json:"context_window"` } // Model IDs const ( - // Anthropic - Claude35Sonnet ModelID = "claude-3.5-sonnet" - Claude3Haiku ModelID = "claude-3-haiku" - Claude37Sonnet ModelID = "claude-3.7-sonnet" // OpenAI + GPT4o ModelID = "gpt-4o" GPT41 ModelID = "gpt-4.1" // GEMINI @@ -37,47 +37,59 @@ const ( ) const ( - ProviderOpenAI ModelProvider = "openai" - ProviderAnthropic ModelProvider = "anthropic" - ProviderBedrock ModelProvider = "bedrock" - ProviderGemini ModelProvider = "gemini" - ProviderGROQ ModelProvider = "groq" + ProviderOpenAI ModelProvider = "openai" + ProviderBedrock ModelProvider = "bedrock" + ProviderGemini ModelProvider = "gemini" + ProviderGROQ ModelProvider = "groq" + + // ForTests + ProviderMock ModelProvider = "__mock" ) var SupportedModels = map[ModelID]Model{ - // Anthropic - Claude35Sonnet: { - ID: Claude35Sonnet, - Name: "Claude 3.5 Sonnet", - Provider: ProviderAnthropic, - APIModel: "claude-3-5-sonnet-latest", - CostPer1MIn: 3.0, - CostPer1MInCached: 3.75, - CostPer1MOutCached: 0.30, - CostPer1MOut: 15.0, - }, - Claude3Haiku: { - ID: Claude3Haiku, - Name: "Claude 3 Haiku", - Provider: ProviderAnthropic, - APIModel: "claude-3-haiku-latest", - CostPer1MIn: 0.80, - CostPer1MInCached: 1, - CostPer1MOutCached: 0.08, - CostPer1MOut: 4, - }, - Claude37Sonnet: { - ID: Claude37Sonnet, - Name: "Claude 3.7 Sonnet", - Provider: ProviderAnthropic, - APIModel: "claude-3-7-sonnet-latest", - CostPer1MIn: 3.0, - CostPer1MInCached: 3.75, - CostPer1MOutCached: 0.30, - CostPer1MOut: 15.0, + // // Anthropic + // Claude35Sonnet: { + // ID: Claude35Sonnet, + // Name: "Claude 3.5 Sonnet", + // Provider: ProviderAnthropic, + // APIModel: "claude-3-5-sonnet-latest", + // CostPer1MIn: 3.0, + // CostPer1MInCached: 3.75, + // CostPer1MOutCached: 0.30, + // CostPer1MOut: 15.0, + // }, + // Claude3Haiku: { + // ID: Claude3Haiku, + // Name: "Claude 3 Haiku", + // Provider: ProviderAnthropic, + // APIModel: "claude-3-haiku-latest", + // CostPer1MIn: 0.80, + // CostPer1MInCached: 1, + // CostPer1MOutCached: 0.08, + // CostPer1MOut: 4, + // }, + // Claude37Sonnet: { + // ID: Claude37Sonnet, + // Name: "Claude 3.7 Sonnet", + // Provider: ProviderAnthropic, + // APIModel: "claude-3-7-sonnet-latest", + // CostPer1MIn: 3.0, + // CostPer1MInCached: 3.75, + // CostPer1MOutCached: 0.30, + // CostPer1MOut: 15.0, + // }, + // + // // OpenAI + GPT4o: { + ID: GPT4o, + Name: "GPT-4o", + Provider: ProviderOpenAI, + APIModel: "gpt-4.1", + CostPer1MIn: 2.00, + CostPer1MInCached: 0.50, + CostPer1MOutCached: 0, + CostPer1MOut: 8.00, }, - - // OpenAI GPT41: { ID: GPT41, Name: "GPT-4.1", @@ -88,51 +100,55 @@ var SupportedModels = map[ModelID]Model{ CostPer1MOutCached: 0, CostPer1MOut: 8.00, }, + // + // // GEMINI + // GEMINI25: { + // ID: GEMINI25, + // Name: "Gemini 2.5 Pro", + // Provider: ProviderGemini, + // APIModel: "gemini-2.5-pro-exp-03-25", + // CostPer1MIn: 0, + // CostPer1MInCached: 0, + // CostPer1MOutCached: 0, + // CostPer1MOut: 0, + // }, + // + // GRMINI20Flash: { + // ID: GRMINI20Flash, + // Name: "Gemini 2.0 Flash", + // Provider: ProviderGemini, + // APIModel: "gemini-2.0-flash", + // CostPer1MIn: 0.1, + // CostPer1MInCached: 0, + // CostPer1MOutCached: 0.025, + // CostPer1MOut: 0.4, + // }, + // + // // GROQ + // QWENQwq: { + // ID: QWENQwq, + // Name: "Qwen Qwq", + // Provider: ProviderGROQ, + // APIModel: "qwen-qwq-32b", + // CostPer1MIn: 0, + // CostPer1MInCached: 0, + // CostPer1MOutCached: 0, + // CostPer1MOut: 0, + // }, + // + // // Bedrock + // BedrockClaude37Sonnet: { + // ID: BedrockClaude37Sonnet, + // Name: "Bedrock: Claude 3.7 Sonnet", + // Provider: ProviderBedrock, + // APIModel: "anthropic.claude-3-7-sonnet-20250219-v1:0", + // CostPer1MIn: 3.0, + // CostPer1MInCached: 3.75, + // CostPer1MOutCached: 0.30, + // CostPer1MOut: 15.0, + // }, +} - // GEMINI - GEMINI25: { - ID: GEMINI25, - Name: "Gemini 2.5 Pro", - Provider: ProviderGemini, - APIModel: "gemini-2.5-pro-exp-03-25", - CostPer1MIn: 0, - CostPer1MInCached: 0, - CostPer1MOutCached: 0, - CostPer1MOut: 0, - }, - - GRMINI20Flash: { - ID: GRMINI20Flash, - Name: "Gemini 2.0 Flash", - Provider: ProviderGemini, - APIModel: "gemini-2.0-flash", - CostPer1MIn: 0.1, - CostPer1MInCached: 0, - CostPer1MOutCached: 0.025, - CostPer1MOut: 0.4, - }, - - // GROQ - QWENQwq: { - ID: QWENQwq, - Name: "Qwen Qwq", - Provider: ProviderGROQ, - APIModel: "qwen-qwq-32b", - CostPer1MIn: 0, - CostPer1MInCached: 0, - CostPer1MOutCached: 0, - CostPer1MOut: 0, - }, - - // Bedrock - BedrockClaude37Sonnet: { - ID: BedrockClaude37Sonnet, - Name: "Bedrock: Claude 3.7 Sonnet", - Provider: ProviderBedrock, - APIModel: "anthropic.claude-3-7-sonnet-20250219-v1:0", - CostPer1MIn: 3.0, - CostPer1MInCached: 3.75, - CostPer1MOutCached: 0.30, - CostPer1MOut: 15.0, - }, +func init() { + maps.Copy(SupportedModels, AnthropicModels) } diff --git a/internal/llm/prompt/coder.go b/internal/llm/prompt/coder.go index 47941f976f76e1f71c6d17cf6f951a65c72dcf87..7439fd57064559e26592d1f23ea24abfc217f2c0 100644 --- a/internal/llm/prompt/coder.go +++ b/internal/llm/prompt/coder.go @@ -9,11 +9,22 @@ import ( "time" "github.com/kujtimiihoxha/termai/internal/config" + "github.com/kujtimiihoxha/termai/internal/llm/models" "github.com/kujtimiihoxha/termai/internal/llm/tools" ) -func CoderOpenAISystemPrompt() string { - basePrompt := `You are termAI, an autonomous CLI-based software engineer. Your job is to reduce user effort by proactively reasoning, inferring context, and solving software engineering tasks end-to-end with minimal prompting. +func CoderPrompt(provider models.ModelProvider) string { + basePrompt := baseAnthropicCoderPrompt + switch provider { + case models.ProviderOpenAI: + basePrompt = baseOpenAICoderPrompt + } + envInfo := getEnvironmentInfo() + + return fmt.Sprintf("%s\n\n%s\n%s", basePrompt, envInfo, lspInformation()) +} + +const baseOpenAICoderPrompt = `You are termAI, an autonomous CLI-based software engineer. Your job is to reduce user effort by proactively reasoning, inferring context, and solving software engineering tasks end-to-end with minimal prompting. # Your mindset Act like a competent, efficient software engineer who is familiar with large codebases. You should: @@ -65,13 +76,7 @@ assistant: [searches repo for references, returns file paths and lines] Never commit changes unless the user explicitly asks you to.` - envInfo := getEnvironmentInfo() - - return fmt.Sprintf("%s\n\n%s\n%s", basePrompt, envInfo, lspInformation()) -} - -func CoderAnthropicSystemPrompt() string { - basePrompt := `You are termAI, an interactive CLI tool that helps users with software engineering tasks. Use the instructions below and the tools available to you to assist the user. +const baseAnthropicCoderPrompt = `You are termAI, an interactive CLI tool that helps users with software engineering tasks. Use the instructions below and the tools available to you to assist the user. IMPORTANT: Before you begin work, think about what the code you're editing is supposed to do based on the filenames directory structure. @@ -166,11 +171,6 @@ NEVER commit changes unless the user explicitly asks you to. It is VERY IMPORTAN You MUST answer concisely with fewer than 4 lines of text (not including tool use or code generation), unless user asks for detail.` - envInfo := getEnvironmentInfo() - - return fmt.Sprintf("%s\n\n%s\n%s", basePrompt, envInfo, lspInformation()) -} - func getEnvironmentInfo() string { cwd := config.WorkingDirectory() isGit := isGitRepo(cwd) diff --git a/internal/llm/prompt/prompt.go b/internal/llm/prompt/prompt.go new file mode 100644 index 0000000000000000000000000000000000000000..63fc2df7bcecd30f37ee04259c05e6b425cd75fd --- /dev/null +++ b/internal/llm/prompt/prompt.go @@ -0,0 +1,19 @@ +package prompt + +import ( + "github.com/kujtimiihoxha/termai/internal/config" + "github.com/kujtimiihoxha/termai/internal/llm/models" +) + +func GetAgentPrompt(agentName config.AgentName, provider models.ModelProvider) string { + switch agentName { + case config.AgentCoder: + return CoderPrompt(provider) + case config.AgentTitle: + return TitlePrompt(provider) + case config.AgentTask: + return TaskPrompt(provider) + default: + return "You are a helpful assistant" + } +} diff --git a/internal/llm/prompt/task.go b/internal/llm/prompt/task.go index ee3c707faddf78b8bb7e8df649805794cb81bd37..8bf604ad99750595505d334659a4f596d54edf0c 100644 --- a/internal/llm/prompt/task.go +++ b/internal/llm/prompt/task.go @@ -2,11 +2,12 @@ package prompt import ( "fmt" + + "github.com/kujtimiihoxha/termai/internal/llm/models" ) -func TaskAgentSystemPrompt() string { +func TaskPrompt(_ models.ModelProvider) string { agentPrompt := `You are an agent for termAI. Given the user's prompt, you should use the tools available to you to answer the user's question. - Notes: 1. IMPORTANT: You should be concise, direct, and to the point, since your responses will be displayed on a command line interface. Answer the user's question directly, without elaboration, explanation, or details. One word answers are best. Avoid introductions, conclusions, and explanations. You MUST avoid text before/after your response, such as "The answer is .", "Here is the content of the file..." or "Based on the information provided, the answer is..." or "Here is what I will do next...". 2. When relevant, share file names and code snippets relevant to the query diff --git a/internal/llm/prompt/title.go b/internal/llm/prompt/title.go index 5c47f4d64617f2f91af88bbabdb0bace37d71f4b..3023a8550d18f7c0451d52f907150d76499b1552 100644 --- a/internal/llm/prompt/title.go +++ b/internal/llm/prompt/title.go @@ -1,6 +1,8 @@ package prompt -func TitlePrompt() string { +import "github.com/kujtimiihoxha/termai/internal/llm/models" + +func TitlePrompt(_ models.ModelProvider) string { return `you will generate a short title based on the first message a user begins a conversation with - ensure it is not more than 50 characters long - the title should be a summary of the user's message diff --git a/internal/llm/provider/anthropic.go b/internal/llm/provider/anthropic.go index 93c4308ad77c0fdf7a3cd8d7629719e598f95bf1..c3a4efc49bea916f55a48be32771a1f06d6a9617 100644 --- a/internal/llm/provider/anthropic.go +++ b/internal/llm/provider/anthropic.go @@ -12,187 +12,257 @@ import ( "github.com/anthropics/anthropic-sdk-go" "github.com/anthropics/anthropic-sdk-go/bedrock" "github.com/anthropics/anthropic-sdk-go/option" - "github.com/kujtimiihoxha/termai/internal/llm/models" + "github.com/kujtimiihoxha/termai/internal/config" "github.com/kujtimiihoxha/termai/internal/llm/tools" + "github.com/kujtimiihoxha/termai/internal/logging" "github.com/kujtimiihoxha/termai/internal/message" ) -type anthropicProvider struct { - client anthropic.Client - model models.Model - maxTokens int64 - apiKey string - systemMessage string - useBedrock bool - disableCache bool +type anthropicOptions struct { + useBedrock bool + disableCache bool + shouldThink func(userMessage string) bool } -type AnthropicOption func(*anthropicProvider) +type AnthropicOption func(*anthropicOptions) -func WithAnthropicSystemMessage(message string) AnthropicOption { - return func(a *anthropicProvider) { - a.systemMessage = message - } +type anthropicClient struct { + providerOptions providerClientOptions + options anthropicOptions + client anthropic.Client } -func WithAnthropicMaxTokens(maxTokens int64) AnthropicOption { - return func(a *anthropicProvider) { - a.maxTokens = maxTokens - } -} +type AnthropicClient ProviderClient -func WithAnthropicModel(model models.Model) AnthropicOption { - return func(a *anthropicProvider) { - a.model = model +func newAnthropicClient(opts providerClientOptions) AnthropicClient { + anthropicOpts := anthropicOptions{} + for _, o := range opts.anthropicOptions { + o(&anthropicOpts) } -} -func WithAnthropicKey(apiKey string) AnthropicOption { - return func(a *anthropicProvider) { - a.apiKey = apiKey + anthropicClientOptions := []option.RequestOption{} + if opts.apiKey != "" { + anthropicClientOptions = append(anthropicClientOptions, option.WithAPIKey(opts.apiKey)) } -} - -func WithAnthropicBedrock() AnthropicOption { - return func(a *anthropicProvider) { - a.useBedrock = true + if anthropicOpts.useBedrock { + anthropicClientOptions = append(anthropicClientOptions, bedrock.WithLoadDefaultConfig(context.Background())) } -} -func WithAnthropicDisableCache() AnthropicOption { - return func(a *anthropicProvider) { - a.disableCache = true + client := anthropic.NewClient(anthropicClientOptions...) + return &anthropicClient{ + providerOptions: opts, + options: anthropicOpts, + client: client, } } -func NewAnthropicProvider(opts ...AnthropicOption) (Provider, error) { - provider := &anthropicProvider{ - maxTokens: 1024, - } +func (a *anthropicClient) convertMessages(messages []message.Message) (anthropicMessages []anthropic.MessageParam) { + cachedBlocks := 0 + for _, msg := range messages { + switch msg.Role { + case message.User: + content := anthropic.NewTextBlock(msg.Content().String()) + if cachedBlocks < 2 && !a.options.disableCache { + content.OfRequestTextBlock.CacheControl = anthropic.CacheControlEphemeralParam{ + Type: "ephemeral", + } + cachedBlocks++ + } + anthropicMessages = append(anthropicMessages, anthropic.NewUserMessage(content)) - for _, opt := range opts { - opt(provider) - } + case message.Assistant: + blocks := []anthropic.ContentBlockParamUnion{} + if msg.Content().String() != "" { + content := anthropic.NewTextBlock(msg.Content().String()) + if cachedBlocks < 2 && !a.options.disableCache { + content.OfRequestTextBlock.CacheControl = anthropic.CacheControlEphemeralParam{ + Type: "ephemeral", + } + cachedBlocks++ + } + blocks = append(blocks, content) + } - if provider.systemMessage == "" { - return nil, errors.New("system message is required") - } + for _, toolCall := range msg.ToolCalls() { + var inputMap map[string]any + err := json.Unmarshal([]byte(toolCall.Input), &inputMap) + if err != nil { + continue + } + blocks = append(blocks, anthropic.ContentBlockParamOfRequestToolUseBlock(toolCall.ID, inputMap, toolCall.Name)) + } - anthropicOptions := []option.RequestOption{} + if len(blocks) == 0 { + logging.Warn("There is a message without content, investigate") + // This should never happend but we log this because we might have a bug in our cleanup method + continue + } + anthropicMessages = append(anthropicMessages, anthropic.NewAssistantMessage(blocks...)) - if provider.apiKey != "" { - anthropicOptions = append(anthropicOptions, option.WithAPIKey(provider.apiKey)) - } - if provider.useBedrock { - anthropicOptions = append(anthropicOptions, bedrock.WithLoadDefaultConfig(context.Background())) + case message.Tool: + results := make([]anthropic.ContentBlockParamUnion, len(msg.ToolResults())) + for i, toolResult := range msg.ToolResults() { + results[i] = anthropic.NewToolResultBlock(toolResult.ToolCallID, toolResult.Content, toolResult.IsError) + } + anthropicMessages = append(anthropicMessages, anthropic.NewUserMessage(results...)) + } } - - provider.client = anthropic.NewClient(anthropicOptions...) - return provider, nil + return } -func (a *anthropicProvider) SendMessages(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error) { - messages = cleanupMessages(messages) - anthropicMessages := a.convertToAnthropicMessages(messages) - anthropicTools := a.convertToAnthropicTools(tools) - - response, err := a.client.Messages.New( - ctx, - anthropic.MessageNewParams{ - Model: anthropic.Model(a.model.APIModel), - MaxTokens: a.maxTokens, - Temperature: anthropic.Float(0), - Messages: anthropicMessages, - Tools: anthropicTools, - System: []anthropic.TextBlockParam{ - { - Text: a.systemMessage, - CacheControl: anthropic.CacheControlEphemeralParam{ - Type: "ephemeral", - }, - }, +func (a *anthropicClient) convertTools(tools []tools.BaseTool) []anthropic.ToolUnionParam { + anthropicTools := make([]anthropic.ToolUnionParam, len(tools)) + + for i, tool := range tools { + info := tool.Info() + toolParam := anthropic.ToolParam{ + Name: info.Name, + Description: anthropic.String(info.Description), + InputSchema: anthropic.ToolInputSchemaParam{ + Properties: info.Parameters, + // TODO: figure out how we can tell claude the required fields? }, - }, - ) - if err != nil { - return nil, err - } + } - content := "" - for _, block := range response.Content { - if text, ok := block.AsAny().(anthropic.TextBlock); ok { - content += text.Text + if i == len(tools)-1 && !a.options.disableCache { + toolParam.CacheControl = anthropic.CacheControlEphemeralParam{ + Type: "ephemeral", + } } - } - toolCalls := a.extractToolCalls(response.Content) - tokenUsage := a.extractTokenUsage(response.Usage) + anthropicTools[i] = anthropic.ToolUnionParam{OfTool: &toolParam} + } - return &ProviderResponse{ - Content: content, - ToolCalls: toolCalls, - Usage: tokenUsage, - }, nil + return anthropicTools } -func (a *anthropicProvider) StreamResponse(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (<-chan ProviderEvent, error) { - messages = cleanupMessages(messages) - anthropicMessages := a.convertToAnthropicMessages(messages) - anthropicTools := a.convertToAnthropicTools(tools) +func (a *anthropicClient) finishReason(reason string) message.FinishReason { + switch reason { + case "end_turn": + return message.FinishReasonEndTurn + case "max_tokens": + return message.FinishReasonMaxTokens + case "tool_use": + return message.FinishReasonToolUse + case "stop_sequence": + return message.FinishReasonEndTurn + default: + return message.FinishReasonUnknown + } +} +func (a *anthropicClient) preparedMessages(messages []anthropic.MessageParam, tools []anthropic.ToolUnionParam) anthropic.MessageNewParams { var thinkingParam anthropic.ThinkingConfigParamUnion lastMessage := messages[len(messages)-1] + isUser := lastMessage.Role == anthropic.MessageParamRoleUser + messageContent := "" temperature := anthropic.Float(0) - if lastMessage.Role == message.User && strings.Contains(strings.ToLower(lastMessage.Content().String()), "think") { - thinkingParam = anthropic.ThinkingConfigParamUnion{ - OfThinkingConfigEnabled: &anthropic.ThinkingConfigEnabledParam{ - BudgetTokens: int64(float64(a.maxTokens) * 0.8), - Type: "enabled", - }, + if isUser { + for _, m := range lastMessage.Content { + if m.OfRequestTextBlock != nil && m.OfRequestTextBlock.Text != "" { + messageContent = m.OfRequestTextBlock.Text + } + } + if messageContent != "" && a.options.shouldThink != nil && a.options.shouldThink(messageContent) { + thinkingParam = anthropic.ThinkingConfigParamUnion{ + OfThinkingConfigEnabled: &anthropic.ThinkingConfigEnabledParam{ + BudgetTokens: int64(float64(a.providerOptions.maxTokens) * 0.8), + Type: "enabled", + }, + } + temperature = anthropic.Float(1) } - temperature = anthropic.Float(1) } - eventChan := make(chan ProviderEvent) + return anthropic.MessageNewParams{ + Model: anthropic.Model(a.providerOptions.model.APIModel), + MaxTokens: a.providerOptions.maxTokens, + Temperature: temperature, + Messages: messages, + Tools: tools, + Thinking: thinkingParam, + System: []anthropic.TextBlockParam{ + { + Text: a.providerOptions.systemMessage, + CacheControl: anthropic.CacheControlEphemeralParam{ + Type: "ephemeral", + }, + }, + }, + } +} - go func() { - defer close(eventChan) +func (a *anthropicClient) send(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (resposne *ProviderResponse, err error) { + preparedMessages := a.preparedMessages(a.convertMessages(messages), a.convertTools(tools)) + cfg := config.Get() + if cfg.Debug { + jsonData, _ := json.Marshal(preparedMessages) + logging.Debug("Prepared messages", "messages", string(jsonData)) + } + attempts := 0 + for { + attempts++ + anthropicResponse, err := a.client.Messages.New( + ctx, + preparedMessages, + ) + // If there is an error we are going to see if we can retry the call + if err != nil { + retry, after, retryErr := a.shouldRetry(attempts, err) + if retryErr != nil { + return nil, retryErr + } + if retry { + logging.WarnPersist("Retrying due to rate limit... attempt %d of %d", logging.PersistTimeArg, time.Millisecond*time.Duration(after+100)) + select { + case <-ctx.Done(): + return nil, ctx.Err() + case <-time.After(time.Duration(after) * time.Millisecond): + continue + } + } + return nil, retryErr + } - const maxRetries = 8 - attempts := 0 + content := "" + for _, block := range anthropicResponse.Content { + if text, ok := block.AsAny().(anthropic.TextBlock); ok { + content += text.Text + } + } - for { + return &ProviderResponse{ + Content: content, + ToolCalls: a.toolCalls(*anthropicResponse), + Usage: a.usage(*anthropicResponse), + }, nil + } +} +func (a *anthropicClient) stream(ctx context.Context, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent { + preparedMessages := a.preparedMessages(a.convertMessages(messages), a.convertTools(tools)) + cfg := config.Get() + if cfg.Debug { + jsonData, _ := json.Marshal(preparedMessages) + logging.Debug("Prepared messages", "messages", string(jsonData)) + } + attempts := 0 + eventChan := make(chan ProviderEvent) + go func() { + for { attempts++ - - stream := a.client.Messages.NewStreaming( + anthropicStream := a.client.Messages.NewStreaming( ctx, - anthropic.MessageNewParams{ - Model: anthropic.Model(a.model.APIModel), - MaxTokens: a.maxTokens, - Temperature: temperature, - Messages: anthropicMessages, - Tools: anthropicTools, - Thinking: thinkingParam, - System: []anthropic.TextBlockParam{ - { - Text: a.systemMessage, - CacheControl: anthropic.CacheControlEphemeralParam{ - Type: "ephemeral", - }, - }, - }, - }, + preparedMessages, ) - accumulatedMessage := anthropic.Message{} - for stream.Next() { - event := stream.Current() + for anthropicStream.Next() { + event := anthropicStream.Current() err := accumulatedMessage.Accumulate(event) if err != nil { eventChan <- ProviderEvent{Type: EventError, Error: err} - return // Don't retry on accumulation errors + continue } switch event := event.AsAny().(type) { @@ -211,6 +281,7 @@ func (a *anthropicProvider) StreamResponse(ctx context.Context, messages []messa Content: event.Delta.Text, } } + // TODO: check if we can somehow stream tool calls case anthropic.ContentBlockStopEvent: eventChan <- ProviderEvent{Type: EventContentStop} @@ -223,84 +294,87 @@ func (a *anthropicProvider) StreamResponse(ctx context.Context, messages []messa } } - toolCalls := a.extractToolCalls(accumulatedMessage.Content) - tokenUsage := a.extractTokenUsage(accumulatedMessage.Usage) - eventChan <- ProviderEvent{ Type: EventComplete, Response: &ProviderResponse{ Content: content, - ToolCalls: toolCalls, - Usage: tokenUsage, - FinishReason: string(accumulatedMessage.StopReason), + ToolCalls: a.toolCalls(accumulatedMessage), + Usage: a.usage(accumulatedMessage), + FinishReason: a.finishReason(string(accumulatedMessage.StopReason)), }, } } } - err := stream.Err() + err := anthropicStream.Err() if err == nil || errors.Is(err, io.EOF) { + close(eventChan) return } - - var apierr *anthropic.Error - if !errors.As(err, &apierr) { - eventChan <- ProviderEvent{Type: EventError, Error: err} - return - } - - if apierr.StatusCode != 429 && apierr.StatusCode != 529 { - eventChan <- ProviderEvent{Type: EventError, Error: err} + // If there is an error we are going to see if we can retry the call + retry, after, retryErr := a.shouldRetry(attempts, err) + if retryErr != nil { + eventChan <- ProviderEvent{Type: EventError, Error: retryErr} + close(eventChan) return } - - if attempts > maxRetries { - eventChan <- ProviderEvent{ - Type: EventError, - Error: errors.New("maximum retry attempts reached for rate limit (429)"), - } - return - } - - retryMs := 0 - retryAfterValues := apierr.Response.Header.Values("Retry-After") - if len(retryAfterValues) > 0 { - var retryAfterSec int - if _, err := fmt.Sscanf(retryAfterValues[0], "%d", &retryAfterSec); err == nil { - retryMs = retryAfterSec * 1000 - eventChan <- ProviderEvent{ - Type: EventWarning, - Info: fmt.Sprintf("[Rate limited: waiting %d seconds as specified by API]", retryAfterSec), + if retry { + logging.WarnPersist("Retrying due to rate limit... attempt %d of %d", logging.PersistTimeArg, time.Millisecond*time.Duration(after+100)) + select { + case <-ctx.Done(): + // context cancelled + if ctx.Err() != nil { + eventChan <- ProviderEvent{Type: EventError, Error: ctx.Err()} } + close(eventChan) + return + case <-time.After(time.Duration(after) * time.Millisecond): + continue } - } else { - eventChan <- ProviderEvent{ - Type: EventWarning, - Info: fmt.Sprintf("[Retrying due to rate limit... attempt %d of %d]", attempts, maxRetries), - } - - backoffMs := 2000 * (1 << (attempts - 1)) - jitterMs := int(float64(backoffMs) * 0.2) - retryMs = backoffMs + jitterMs } - select { - case <-ctx.Done(): + if ctx.Err() != nil { eventChan <- ProviderEvent{Type: EventError, Error: ctx.Err()} - return - case <-time.After(time.Duration(retryMs) * time.Millisecond): - continue } + close(eventChan) + return } }() + return eventChan +} - return eventChan, nil +func (a *anthropicClient) shouldRetry(attempts int, err error) (bool, int64, error) { + var apierr *anthropic.Error + if !errors.As(err, &apierr) { + return false, 0, err + } + + if apierr.StatusCode != 429 && apierr.StatusCode != 529 { + return false, 0, err + } + + if attempts > maxRetries { + return false, 0, fmt.Errorf("maximum retry attempts reached for rate limit: %d retries", maxRetries) + } + + retryMs := 0 + retryAfterValues := apierr.Response.Header.Values("Retry-After") + + backoffMs := 2000 * (1 << (attempts - 1)) + jitterMs := int(float64(backoffMs) * 0.2) + retryMs = backoffMs + jitterMs + if len(retryAfterValues) > 0 { + if _, err := fmt.Sscanf(retryAfterValues[0], "%d", &retryMs); err == nil { + retryMs = retryMs * 1000 + } + } + return true, int64(retryMs), nil } -func (a *anthropicProvider) extractToolCalls(content []anthropic.ContentBlockUnion) []message.ToolCall { +func (a *anthropicClient) toolCalls(msg anthropic.Message) []message.ToolCall { var toolCalls []message.ToolCall - for _, block := range content { + for _, block := range msg.Content { switch variant := block.AsAny().(type) { case anthropic.ToolUseBlock: toolCall := message.ToolCall{ @@ -316,90 +390,33 @@ func (a *anthropicProvider) extractToolCalls(content []anthropic.ContentBlockUni return toolCalls } -func (a *anthropicProvider) extractTokenUsage(usage anthropic.Usage) TokenUsage { +func (a *anthropicClient) usage(msg anthropic.Message) TokenUsage { return TokenUsage{ - InputTokens: usage.InputTokens, - OutputTokens: usage.OutputTokens, - CacheCreationTokens: usage.CacheCreationInputTokens, - CacheReadTokens: usage.CacheReadInputTokens, + InputTokens: msg.Usage.InputTokens, + OutputTokens: msg.Usage.OutputTokens, + CacheCreationTokens: msg.Usage.CacheCreationInputTokens, + CacheReadTokens: msg.Usage.CacheReadInputTokens, } } -func (a *anthropicProvider) convertToAnthropicTools(tools []tools.BaseTool) []anthropic.ToolUnionParam { - anthropicTools := make([]anthropic.ToolUnionParam, len(tools)) - - for i, tool := range tools { - info := tool.Info() - toolParam := anthropic.ToolParam{ - Name: info.Name, - Description: anthropic.String(info.Description), - InputSchema: anthropic.ToolInputSchemaParam{ - Properties: info.Parameters, - }, - } - - if i == len(tools)-1 && !a.disableCache { - toolParam.CacheControl = anthropic.CacheControlEphemeralParam{ - Type: "ephemeral", - } - } - - anthropicTools[i] = anthropic.ToolUnionParam{OfTool: &toolParam} +func WithAnthropicBedrock(useBedrock bool) AnthropicOption { + return func(options *anthropicOptions) { + options.useBedrock = useBedrock } - - return anthropicTools } -func (a *anthropicProvider) convertToAnthropicMessages(messages []message.Message) []anthropic.MessageParam { - anthropicMessages := make([]anthropic.MessageParam, 0, len(messages)) - cachedBlocks := 0 - - for _, msg := range messages { - switch msg.Role { - case message.User: - content := anthropic.NewTextBlock(msg.Content().String()) - if cachedBlocks < 2 && !a.disableCache { - content.OfRequestTextBlock.CacheControl = anthropic.CacheControlEphemeralParam{ - Type: "ephemeral", - } - cachedBlocks++ - } - anthropicMessages = append(anthropicMessages, anthropic.NewUserMessage(content)) - - case message.Assistant: - blocks := []anthropic.ContentBlockParamUnion{} - if msg.Content().String() != "" { - content := anthropic.NewTextBlock(msg.Content().String()) - if cachedBlocks < 2 && !a.disableCache { - content.OfRequestTextBlock.CacheControl = anthropic.CacheControlEphemeralParam{ - Type: "ephemeral", - } - cachedBlocks++ - } - blocks = append(blocks, content) - } - - for _, toolCall := range msg.ToolCalls() { - var inputMap map[string]any - err := json.Unmarshal([]byte(toolCall.Input), &inputMap) - if err != nil { - continue - } - blocks = append(blocks, anthropic.ContentBlockParamOfRequestToolUseBlock(toolCall.ID, inputMap, toolCall.Name)) - } +func WithAnthropicDisableCache() AnthropicOption { + return func(options *anthropicOptions) { + options.disableCache = true + } +} - if len(blocks) > 0 { - anthropicMessages = append(anthropicMessages, anthropic.NewAssistantMessage(blocks...)) - } +func DefaultShouldThinkFn(s string) bool { + return strings.Contains(strings.ToLower(s), "think") +} - case message.Tool: - results := make([]anthropic.ContentBlockParamUnion, len(msg.ToolResults())) - for i, toolResult := range msg.ToolResults() { - results[i] = anthropic.NewToolResultBlock(toolResult.ToolCallID, toolResult.Content, toolResult.IsError) - } - anthropicMessages = append(anthropicMessages, anthropic.NewUserMessage(results...)) - } +func WithAnthropicShouldThinkFn(fn func(string) bool) AnthropicOption { + return func(options *anthropicOptions) { + options.shouldThink = fn } - - return anthropicMessages } diff --git a/internal/llm/provider/bedrock.go b/internal/llm/provider/bedrock.go index 677f4676b3333bff3651174e7f7fd3d4d7e3a096..d76925ad10274bd6ded1401e00bbb7954035b122 100644 --- a/internal/llm/provider/bedrock.go +++ b/internal/llm/provider/bedrock.go @@ -7,33 +7,29 @@ import ( "os" "strings" - "github.com/kujtimiihoxha/termai/internal/llm/models" "github.com/kujtimiihoxha/termai/internal/llm/tools" "github.com/kujtimiihoxha/termai/internal/message" ) -type bedrockProvider struct { - childProvider Provider - model models.Model - maxTokens int64 - systemMessage string +type bedrockOptions struct { + // Bedrock specific options can be added here } -func (b *bedrockProvider) SendMessages(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error) { - return b.childProvider.SendMessages(ctx, messages, tools) -} +type BedrockOption func(*bedrockOptions) -func (b *bedrockProvider) StreamResponse(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (<-chan ProviderEvent, error) { - return b.childProvider.StreamResponse(ctx, messages, tools) +type bedrockClient struct { + providerOptions providerClientOptions + options bedrockOptions + childProvider ProviderClient } -func NewBedrockProvider(opts ...BedrockOption) (Provider, error) { - provider := &bedrockProvider{} - for _, opt := range opts { - opt(provider) - } +type BedrockClient ProviderClient + +func newBedrockClient(opts providerClientOptions) BedrockClient { + bedrockOpts := bedrockOptions{} + // Apply bedrock specific options if they are added in the future - // based on the AWS region prefix the model name with, us, eu, ap, sa, etc. + // Get AWS region from environment region := os.Getenv("AWS_REGION") if region == "" { region = os.Getenv("AWS_DEFAULT_REGION") @@ -43,45 +39,62 @@ func NewBedrockProvider(opts ...BedrockOption) (Provider, error) { region = "us-east-1" // default region } if len(region) < 2 { - return nil, errors.New("AWS_REGION or AWS_DEFAULT_REGION environment variable is invalid") + return &bedrockClient{ + providerOptions: opts, + options: bedrockOpts, + childProvider: nil, // Will cause an error when used + } } + + // Prefix the model name with region regionPrefix := region[:2] - provider.model.APIModel = fmt.Sprintf("%s.%s", regionPrefix, provider.model.APIModel) + modelName := opts.model.APIModel + opts.model.APIModel = fmt.Sprintf("%s.%s", regionPrefix, modelName) - if strings.Contains(string(provider.model.APIModel), "anthropic") { - anthropic, err := NewAnthropicProvider( - WithAnthropicModel(provider.model), - WithAnthropicMaxTokens(provider.maxTokens), - WithAnthropicSystemMessage(provider.systemMessage), - WithAnthropicBedrock(), + // Determine which provider to use based on the model + if strings.Contains(string(opts.model.APIModel), "anthropic") { + // Create Anthropic client with Bedrock configuration + anthropicOpts := opts + anthropicOpts.anthropicOptions = append(anthropicOpts.anthropicOptions, + WithAnthropicBedrock(true), WithAnthropicDisableCache(), ) - provider.childProvider = anthropic - if err != nil { - return nil, err + return &bedrockClient{ + providerOptions: opts, + options: bedrockOpts, + childProvider: newAnthropicClient(anthropicOpts), } - } else { - return nil, errors.New("unsupported model for bedrock provider") } - return provider, nil -} - -type BedrockOption func(*bedrockProvider) -func WithBedrockSystemMessage(message string) BedrockOption { - return func(a *bedrockProvider) { - a.systemMessage = message + // Return client with nil childProvider if model is not supported + // This will cause an error when used + return &bedrockClient{ + providerOptions: opts, + options: bedrockOpts, + childProvider: nil, } } -func WithBedrockMaxTokens(maxTokens int64) BedrockOption { - return func(a *bedrockProvider) { - a.maxTokens = maxTokens +func (b *bedrockClient) send(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error) { + if b.childProvider == nil { + return nil, errors.New("unsupported model for bedrock provider") } + return b.childProvider.send(ctx, messages, tools) } -func WithBedrockModel(model models.Model) BedrockOption { - return func(a *bedrockProvider) { - a.model = model +func (b *bedrockClient) stream(ctx context.Context, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent { + eventChan := make(chan ProviderEvent) + + if b.childProvider == nil { + go func() { + eventChan <- ProviderEvent{ + Type: EventError, + Error: errors.New("unsupported model for bedrock provider"), + } + close(eventChan) + }() + return eventChan } -} + + return b.childProvider.stream(ctx, messages, tools) +} \ No newline at end of file diff --git a/internal/llm/provider/gemini.go b/internal/llm/provider/gemini.go index 2d1db2b64878d0aaeb8fcc7a864fb8485c935a09..804baea281bdd9e609e113fed811a684f2049a81 100644 --- a/internal/llm/provider/gemini.go +++ b/internal/llm/provider/gemini.go @@ -4,80 +4,68 @@ import ( "context" "encoding/json" "errors" + "fmt" + "io" + "strings" + "time" "github.com/google/generative-ai-go/genai" "github.com/google/uuid" - "github.com/kujtimiihoxha/termai/internal/llm/models" + "github.com/kujtimiihoxha/termai/internal/config" "github.com/kujtimiihoxha/termai/internal/llm/tools" + "github.com/kujtimiihoxha/termai/internal/logging" "github.com/kujtimiihoxha/termai/internal/message" "google.golang.org/api/iterator" "google.golang.org/api/option" ) -type geminiProvider struct { - client *genai.Client - model models.Model - maxTokens int32 - apiKey string - systemMessage string +type geminiOptions struct { + disableCache bool } -type GeminiOption func(*geminiProvider) +type GeminiOption func(*geminiOptions) -func NewGeminiProvider(ctx context.Context, opts ...GeminiOption) (Provider, error) { - provider := &geminiProvider{ - maxTokens: 5000, - } +type geminiClient struct { + providerOptions providerClientOptions + options geminiOptions + client *genai.Client +} - for _, opt := range opts { - opt(provider) - } +type GeminiClient ProviderClient - if provider.systemMessage == "" { - return nil, errors.New("system message is required") +func newGeminiClient(opts providerClientOptions) GeminiClient { + geminiOpts := geminiOptions{} + for _, o := range opts.geminiOptions { + o(&geminiOpts) } - client, err := genai.NewClient(ctx, option.WithAPIKey(provider.apiKey)) + client, err := genai.NewClient(context.Background(), option.WithAPIKey(opts.apiKey)) if err != nil { - return nil, err - } - provider.client = client - - return provider, nil -} - -func WithGeminiSystemMessage(message string) GeminiOption { - return func(p *geminiProvider) { - p.systemMessage = message + logging.Error("Failed to create Gemini client", "error", err) + return nil } -} -func WithGeminiMaxTokens(maxTokens int32) GeminiOption { - return func(p *geminiProvider) { - p.maxTokens = maxTokens + return &geminiClient{ + providerOptions: opts, + options: geminiOpts, + client: client, } } -func WithGeminiModel(model models.Model) GeminiOption { - return func(p *geminiProvider) { - p.model = model - } -} - -func WithGeminiKey(apiKey string) GeminiOption { - return func(p *geminiProvider) { - p.apiKey = apiKey - } -} +func (g *geminiClient) convertMessages(messages []message.Message) []*genai.Content { + var history []*genai.Content -func (p *geminiProvider) Close() { - if p.client != nil { - p.client.Close() - } -} + // Add system message first + history = append(history, &genai.Content{ + Parts: []genai.Part{genai.Text(g.providerOptions.systemMessage)}, + Role: "user", + }) -func (p *geminiProvider) convertToGeminiHistory(messages []message.Message) []*genai.Content { - var history []*genai.Content + // Add a system response to acknowledge the system message + history = append(history, &genai.Content{ + Parts: []genai.Part{genai.Text("I'll help you with that.")}, + Role: "model", + }) for _, msg := range messages { switch msg.Role { @@ -86,6 +74,7 @@ func (p *geminiProvider) convertToGeminiHistory(messages []message.Message) []*g Parts: []genai.Part{genai.Text(msg.Content().String())}, Role: "user", }) + case message.Assistant: content := &genai.Content{ Role: "model", @@ -107,6 +96,7 @@ func (p *geminiProvider) convertToGeminiHistory(messages []message.Message) []*g } history = append(history, content) + case message.Tool: for _, result := range msg.ToolResults() { response := map[string]interface{}{"result": result.Content} @@ -114,10 +104,11 @@ func (p *geminiProvider) convertToGeminiHistory(messages []message.Message) []*g if err == nil { response = parsed } + var toolCall message.ToolCall - for _, msg := range messages { - if msg.Role == message.Assistant { - for _, call := range msg.ToolCalls() { + for _, m := range messages { + if m.Role == message.Assistant { + for _, call := range m.ToolCalls() { if call.ID == result.ToolCallID { toolCall = call break @@ -140,186 +131,358 @@ func (p *geminiProvider) convertToGeminiHistory(messages []message.Message) []*g return history } -func (p *geminiProvider) extractTokenUsage(resp *genai.GenerateContentResponse) TokenUsage { - if resp == nil || resp.UsageMetadata == nil { - return TokenUsage{} - } +func (g *geminiClient) convertTools(tools []tools.BaseTool) []*genai.Tool { + geminiTools := make([]*genai.Tool, 0, len(tools)) - return TokenUsage{ - InputTokens: int64(resp.UsageMetadata.PromptTokenCount), - OutputTokens: int64(resp.UsageMetadata.CandidatesTokenCount), - CacheCreationTokens: 0, // Not directly provided by Gemini - CacheReadTokens: int64(resp.UsageMetadata.CachedContentTokenCount), + for _, tool := range tools { + info := tool.Info() + declaration := &genai.FunctionDeclaration{ + Name: info.Name, + Description: info.Description, + Parameters: &genai.Schema{ + Type: genai.TypeObject, + Properties: convertSchemaProperties(info.Parameters), + Required: info.Required, + }, + } + + geminiTools = append(geminiTools, &genai.Tool{ + FunctionDeclarations: []*genai.FunctionDeclaration{declaration}, + }) } + + return geminiTools } -func (p *geminiProvider) SendMessages(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error) { - messages = cleanupMessages(messages) - model := p.client.GenerativeModel(p.model.APIModel) - model.SetMaxOutputTokens(p.maxTokens) +func (g *geminiClient) finishReason(reason genai.FinishReason) message.FinishReason { + reasonStr := reason.String() + switch { + case reasonStr == "STOP": + return message.FinishReasonEndTurn + case reasonStr == "MAX_TOKENS": + return message.FinishReasonMaxTokens + case strings.Contains(reasonStr, "FUNCTION") || strings.Contains(reasonStr, "TOOL"): + return message.FinishReasonToolUse + default: + return message.FinishReasonUnknown + } +} - model.SystemInstruction = genai.NewUserContent(genai.Text(p.systemMessage)) +func (g *geminiClient) send(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error) { + model := g.client.GenerativeModel(g.providerOptions.model.APIModel) + model.SetMaxOutputTokens(int32(g.providerOptions.maxTokens)) + // Convert tools if len(tools) > 0 { - declarations := p.convertToolsToGeminiFunctionDeclarations(tools) - for _, declaration := range declarations { - model.Tools = append(model.Tools, &genai.Tool{FunctionDeclarations: []*genai.FunctionDeclaration{declaration}}) - } + model.Tools = g.convertTools(tools) } - chat := model.StartChat() - chat.History = p.convertToGeminiHistory(messages[:len(messages)-1]) // Exclude last message + // Convert messages + geminiMessages := g.convertMessages(messages) - lastUserMsg := messages[len(messages)-1] - resp, err := chat.SendMessage(ctx, genai.Text(lastUserMsg.Content().String())) - if err != nil { - return nil, err + cfg := config.Get() + if cfg.Debug { + jsonData, _ := json.Marshal(geminiMessages) + logging.Debug("Prepared messages", "messages", string(jsonData)) } - var content string - var toolCalls []message.ToolCall + attempts := 0 + for { + attempts++ + chat := model.StartChat() + chat.History = geminiMessages[:len(geminiMessages)-1] // All but last message + + lastMsg := geminiMessages[len(geminiMessages)-1] + var lastText string + for _, part := range lastMsg.Parts { + if text, ok := part.(genai.Text); ok { + lastText = string(text) + break + } + } - if len(resp.Candidates) > 0 && resp.Candidates[0].Content != nil { - for _, part := range resp.Candidates[0].Content.Parts { - switch p := part.(type) { - case genai.Text: - content = string(p) - case genai.FunctionCall: - id := "call_" + uuid.New().String() - args, _ := json.Marshal(p.Args) - toolCalls = append(toolCalls, message.ToolCall{ - ID: id, - Name: p.Name, - Input: string(args), - Type: "function", - }) + resp, err := chat.SendMessage(ctx, genai.Text(lastText)) + // If there is an error we are going to see if we can retry the call + if err != nil { + retry, after, retryErr := g.shouldRetry(attempts, err) + if retryErr != nil { + return nil, retryErr } + if retry { + logging.WarnPersist("Retrying due to rate limit... attempt %d of %d", logging.PersistTimeArg, time.Millisecond*time.Duration(after+100)) + select { + case <-ctx.Done(): + return nil, ctx.Err() + case <-time.After(time.Duration(after) * time.Millisecond): + continue + } + } + return nil, retryErr } - } - tokenUsage := p.extractTokenUsage(resp) + content := "" + var toolCalls []message.ToolCall + + if len(resp.Candidates) > 0 && resp.Candidates[0].Content != nil { + for _, part := range resp.Candidates[0].Content.Parts { + switch p := part.(type) { + case genai.Text: + content = string(p) + case genai.FunctionCall: + id := "call_" + uuid.New().String() + args, _ := json.Marshal(p.Args) + toolCalls = append(toolCalls, message.ToolCall{ + ID: id, + Name: p.Name, + Input: string(args), + Type: "function", + }) + } + } + } - return &ProviderResponse{ - Content: content, - ToolCalls: toolCalls, - Usage: tokenUsage, - }, nil + return &ProviderResponse{ + Content: content, + ToolCalls: toolCalls, + Usage: g.usage(resp), + FinishReason: g.finishReason(resp.Candidates[0].FinishReason), + }, nil + } } -func (p *geminiProvider) StreamResponse(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (<-chan ProviderEvent, error) { - messages = cleanupMessages(messages) - model := p.client.GenerativeModel(p.model.APIModel) - model.SetMaxOutputTokens(p.maxTokens) - - model.SystemInstruction = genai.NewUserContent(genai.Text(p.systemMessage)) +func (g *geminiClient) stream(ctx context.Context, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent { + model := g.client.GenerativeModel(g.providerOptions.model.APIModel) + model.SetMaxOutputTokens(int32(g.providerOptions.maxTokens)) + // Convert tools if len(tools) > 0 { - declarations := p.convertToolsToGeminiFunctionDeclarations(tools) - for _, declaration := range declarations { - model.Tools = append(model.Tools, &genai.Tool{FunctionDeclarations: []*genai.FunctionDeclaration{declaration}}) - } + model.Tools = g.convertTools(tools) } - chat := model.StartChat() - chat.History = p.convertToGeminiHistory(messages[:len(messages)-1]) // Exclude last message + // Convert messages + geminiMessages := g.convertMessages(messages) - lastUserMsg := messages[len(messages)-1] - - iter := chat.SendMessageStream(ctx, genai.Text(lastUserMsg.Content().String())) + cfg := config.Get() + if cfg.Debug { + jsonData, _ := json.Marshal(geminiMessages) + logging.Debug("Prepared messages", "messages", string(jsonData)) + } + attempts := 0 eventChan := make(chan ProviderEvent) go func() { defer close(eventChan) - var finalResp *genai.GenerateContentResponse - currentContent := "" - toolCalls := []message.ToolCall{} - for { - resp, err := iter.Next() - if err == iterator.Done { - break - } - if err != nil { - eventChan <- ProviderEvent{ - Type: EventError, - Error: err, + attempts++ + chat := model.StartChat() + chat.History = geminiMessages[:len(geminiMessages)-1] // All but last message + + lastMsg := geminiMessages[len(geminiMessages)-1] + var lastText string + for _, part := range lastMsg.Parts { + if text, ok := part.(genai.Text); ok { + lastText = string(text) + break } - return } - finalResp = resp + iter := chat.SendMessageStream(ctx, genai.Text(lastText)) - if len(resp.Candidates) > 0 && resp.Candidates[0].Content != nil { - for _, part := range resp.Candidates[0].Content.Parts { - switch p := part.(type) { - case genai.Text: - newText := string(p) - eventChan <- ProviderEvent{ - Type: EventContentDelta, - Content: newText, - } - currentContent += newText - case genai.FunctionCall: - id := "call_" + uuid.New().String() - args, _ := json.Marshal(p.Args) - newCall := message.ToolCall{ - ID: id, - Name: p.Name, - Input: string(args), - Type: "function", - } + currentContent := "" + toolCalls := []message.ToolCall{} + var finalResp *genai.GenerateContentResponse - isNew := true - for _, existing := range toolCalls { - if existing.Name == newCall.Name && existing.Input == newCall.Input { - isNew = false - break + eventChan <- ProviderEvent{Type: EventContentStart} + + for { + resp, err := iter.Next() + if err == iterator.Done { + break + } + if err != nil { + retry, after, retryErr := g.shouldRetry(attempts, err) + if retryErr != nil { + eventChan <- ProviderEvent{Type: EventError, Error: retryErr} + return + } + if retry { + logging.WarnPersist("Retrying due to rate limit... attempt %d of %d", logging.PersistTimeArg, time.Millisecond*time.Duration(after+100)) + select { + case <-ctx.Done(): + if ctx.Err() != nil { + eventChan <- ProviderEvent{Type: EventError, Error: ctx.Err()} } + + return + case <-time.After(time.Duration(after) * time.Millisecond): + break } + } else { + eventChan <- ProviderEvent{Type: EventError, Error: err} + return + } + } + + finalResp = resp + + if len(resp.Candidates) > 0 && resp.Candidates[0].Content != nil { + for _, part := range resp.Candidates[0].Content.Parts { + switch p := part.(type) { + case genai.Text: + newText := string(p) + delta := newText[len(currentContent):] + if delta != "" { + eventChan <- ProviderEvent{ + Type: EventContentDelta, + Content: delta, + } + currentContent = newText + } + case genai.FunctionCall: + id := "call_" + uuid.New().String() + args, _ := json.Marshal(p.Args) + newCall := message.ToolCall{ + ID: id, + Name: p.Name, + Input: string(args), + Type: "function", + } - if isNew { - toolCalls = append(toolCalls, newCall) + isNew := true + for _, existing := range toolCalls { + if existing.Name == newCall.Name && existing.Input == newCall.Input { + isNew = false + break + } + } + + if isNew { + toolCalls = append(toolCalls, newCall) + } } } } } - } - tokenUsage := p.extractTokenUsage(finalResp) + eventChan <- ProviderEvent{Type: EventContentStop} - eventChan <- ProviderEvent{ - Type: EventComplete, - Response: &ProviderResponse{ - Content: currentContent, - ToolCalls: toolCalls, - Usage: tokenUsage, - FinishReason: string(finalResp.Candidates[0].FinishReason.String()), - }, + if finalResp != nil { + eventChan <- ProviderEvent{ + Type: EventComplete, + Response: &ProviderResponse{ + Content: currentContent, + ToolCalls: toolCalls, + Usage: g.usage(finalResp), + FinishReason: g.finishReason(finalResp.Candidates[0].FinishReason), + }, + } + return + } + + // If we get here, we need to retry + if attempts > maxRetries { + eventChan <- ProviderEvent{ + Type: EventError, + Error: fmt.Errorf("maximum retry attempts reached: %d retries", maxRetries), + } + return + } + + // Wait before retrying + select { + case <-ctx.Done(): + if ctx.Err() != nil { + eventChan <- ProviderEvent{Type: EventError, Error: ctx.Err()} + } + return + case <-time.After(time.Duration(2000*(1<<(attempts-1))) * time.Millisecond): + continue + } } }() - return eventChan, nil + return eventChan } -func (p *geminiProvider) convertToolsToGeminiFunctionDeclarations(tools []tools.BaseTool) []*genai.FunctionDeclaration { - declarations := make([]*genai.FunctionDeclaration, len(tools)) +func (g *geminiClient) shouldRetry(attempts int, err error) (bool, int64, error) { + // Check if error is a rate limit error + if attempts > maxRetries { + return false, 0, fmt.Errorf("maximum retry attempts reached for rate limit: %d retries", maxRetries) + } - for i, tool := range tools { - info := tool.Info() - declarations[i] = &genai.FunctionDeclaration{ - Name: info.Name, - Description: info.Description, - Parameters: &genai.Schema{ - Type: genai.TypeObject, - Properties: convertSchemaProperties(info.Parameters), - Required: info.Required, - }, + // Gemini doesn't have a standard error type we can check against + // So we'll check the error message for rate limit indicators + if errors.Is(err, io.EOF) { + return false, 0, err + } + + errMsg := err.Error() + isRateLimit := false + + // Check for common rate limit error messages + if contains(errMsg, "rate limit", "quota exceeded", "too many requests") { + isRateLimit = true + } + + if !isRateLimit { + return false, 0, err + } + + // Calculate backoff with jitter + backoffMs := 2000 * (1 << (attempts - 1)) + jitterMs := int(float64(backoffMs) * 0.2) + retryMs := backoffMs + jitterMs + + return true, int64(retryMs), nil +} + +func (g *geminiClient) toolCalls(resp *genai.GenerateContentResponse) []message.ToolCall { + var toolCalls []message.ToolCall + + if len(resp.Candidates) > 0 && resp.Candidates[0].Content != nil { + for _, part := range resp.Candidates[0].Content.Parts { + if funcCall, ok := part.(genai.FunctionCall); ok { + id := "call_" + uuid.New().String() + args, _ := json.Marshal(funcCall.Args) + toolCalls = append(toolCalls, message.ToolCall{ + ID: id, + Name: funcCall.Name, + Input: string(args), + Type: "function", + }) + } } } - return declarations + return toolCalls +} + +func (g *geminiClient) usage(resp *genai.GenerateContentResponse) TokenUsage { + if resp == nil || resp.UsageMetadata == nil { + return TokenUsage{} + } + + return TokenUsage{ + InputTokens: int64(resp.UsageMetadata.PromptTokenCount), + OutputTokens: int64(resp.UsageMetadata.CandidatesTokenCount), + CacheCreationTokens: 0, // Not directly provided by Gemini + CacheReadTokens: int64(resp.UsageMetadata.CachedContentTokenCount), + } +} + +func WithGeminiDisableCache() GeminiOption { + return func(options *geminiOptions) { + options.disableCache = true + } +} + +// Helper functions +func parseJsonToMap(jsonStr string) (map[string]interface{}, error) { + var result map[string]interface{} + err := json.Unmarshal([]byte(jsonStr), &result) + return result, err } func convertSchemaProperties(parameters map[string]interface{}) map[string]*genai.Schema { @@ -396,8 +559,12 @@ func mapJSONTypeToGenAI(jsonType string) genai.Type { } } -func parseJsonToMap(jsonStr string) (map[string]interface{}, error) { - var result map[string]interface{} - err := json.Unmarshal([]byte(jsonStr), &result) - return result, err +func contains(s string, substrs ...string) bool { + for _, substr := range substrs { + if strings.Contains(strings.ToLower(s), strings.ToLower(substr)) { + return true + } + } + return false } + diff --git a/internal/llm/provider/openai.go b/internal/llm/provider/openai.go index dbfde3fa88aaf67b22c02c5595d208ffff3ee4b1..9c2ad201263d0daf182d6feda8c6351d764f896e 100644 --- a/internal/llm/provider/openai.go +++ b/internal/llm/provider/openai.go @@ -2,89 +2,65 @@ package provider import ( "context" + "encoding/json" "errors" + "fmt" + "io" + "time" - "github.com/kujtimiihoxha/termai/internal/llm/models" + "github.com/kujtimiihoxha/termai/internal/config" "github.com/kujtimiihoxha/termai/internal/llm/tools" + "github.com/kujtimiihoxha/termai/internal/logging" "github.com/kujtimiihoxha/termai/internal/message" "github.com/openai/openai-go" "github.com/openai/openai-go/option" ) -type openaiProvider struct { - client openai.Client - model models.Model - maxTokens int64 - baseURL string - apiKey string - systemMessage string +type openaiOptions struct { + baseURL string + disableCache bool } -type OpenAIOption func(*openaiProvider) +type OpenAIOption func(*openaiOptions) -func NewOpenAIProvider(opts ...OpenAIOption) (Provider, error) { - provider := &openaiProvider{ - maxTokens: 5000, - } - - for _, opt := range opts { - opt(provider) - } - - clientOpts := []option.RequestOption{ - option.WithAPIKey(provider.apiKey), - } - if provider.baseURL != "" { - clientOpts = append(clientOpts, option.WithBaseURL(provider.baseURL)) - } - - provider.client = openai.NewClient(clientOpts...) - if provider.systemMessage == "" { - return nil, errors.New("system message is required") - } - - return provider, nil +type openaiClient struct { + providerOptions providerClientOptions + options openaiOptions + client openai.Client } -func WithOpenAISystemMessage(message string) OpenAIOption { - return func(p *openaiProvider) { - p.systemMessage = message - } -} +type OpenAIClient ProviderClient -func WithOpenAIMaxTokens(maxTokens int64) OpenAIOption { - return func(p *openaiProvider) { - p.maxTokens = maxTokens +func newOpenAIClient(opts providerClientOptions) OpenAIClient { + openaiOpts := openaiOptions{} + for _, o := range opts.openaiOptions { + o(&openaiOpts) } -} -func WithOpenAIModel(model models.Model) OpenAIOption { - return func(p *openaiProvider) { - p.model = model + openaiClientOptions := []option.RequestOption{} + if opts.apiKey != "" { + openaiClientOptions = append(openaiClientOptions, option.WithAPIKey(opts.apiKey)) } -} - -func WithOpenAIBaseURL(baseURL string) OpenAIOption { - return func(p *openaiProvider) { - p.baseURL = baseURL + if openaiOpts.baseURL != "" { + openaiClientOptions = append(openaiClientOptions, option.WithBaseURL(openaiOpts.baseURL)) } -} -func WithOpenAIKey(apiKey string) OpenAIOption { - return func(p *openaiProvider) { - p.apiKey = apiKey + client := openai.NewClient(openaiClientOptions...) + return &openaiClient{ + providerOptions: opts, + options: openaiOpts, + client: client, } } -func (p *openaiProvider) convertToOpenAIMessages(messages []message.Message) []openai.ChatCompletionMessageParamUnion { - var chatMessages []openai.ChatCompletionMessageParamUnion - - chatMessages = append(chatMessages, openai.SystemMessage(p.systemMessage)) +func (o *openaiClient) convertMessages(messages []message.Message) (openaiMessages []openai.ChatCompletionMessageParamUnion) { + // Add system message first + openaiMessages = append(openaiMessages, openai.SystemMessage(o.providerOptions.systemMessage)) for _, msg := range messages { switch msg.Role { case message.User: - chatMessages = append(chatMessages, openai.UserMessage(msg.Content().String())) + openaiMessages = append(openaiMessages, openai.UserMessage(msg.Content().String())) case message.Assistant: assistantMsg := openai.ChatCompletionAssistantMessageParam{ @@ -111,23 +87,23 @@ func (p *openaiProvider) convertToOpenAIMessages(messages []message.Message) []o } } - chatMessages = append(chatMessages, openai.ChatCompletionMessageParamUnion{ + openaiMessages = append(openaiMessages, openai.ChatCompletionMessageParamUnion{ OfAssistant: &assistantMsg, }) case message.Tool: for _, result := range msg.ToolResults() { - chatMessages = append(chatMessages, + openaiMessages = append(openaiMessages, openai.ToolMessage(result.Content, result.ToolCallID), ) } } } - return chatMessages + return } -func (p *openaiProvider) convertToOpenAITools(tools []tools.BaseTool) []openai.ChatCompletionToolParam { +func (o *openaiClient) convertTools(tools []tools.BaseTool) []openai.ChatCompletionToolParam { openaiTools := make([]openai.ChatCompletionToolParam, len(tools)) for i, tool := range tools { @@ -148,133 +124,238 @@ func (p *openaiProvider) convertToOpenAITools(tools []tools.BaseTool) []openai.C return openaiTools } -func (p *openaiProvider) extractTokenUsage(usage openai.CompletionUsage) TokenUsage { - cachedTokens := int64(0) - - cachedTokens = usage.PromptTokensDetails.CachedTokens - inputTokens := usage.PromptTokens - cachedTokens - - return TokenUsage{ - InputTokens: inputTokens, - OutputTokens: usage.CompletionTokens, - CacheCreationTokens: 0, // OpenAI doesn't provide this directly - CacheReadTokens: cachedTokens, +func (o *openaiClient) finishReason(reason string) message.FinishReason { + switch reason { + case "stop": + return message.FinishReasonEndTurn + case "length": + return message.FinishReasonMaxTokens + case "tool_calls": + return message.FinishReasonToolUse + default: + return message.FinishReasonUnknown } } -func (p *openaiProvider) SendMessages(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error) { - messages = cleanupMessages(messages) - chatMessages := p.convertToOpenAIMessages(messages) - openaiTools := p.convertToOpenAITools(tools) - - params := openai.ChatCompletionNewParams{ - Model: openai.ChatModel(p.model.APIModel), - Messages: chatMessages, - MaxTokens: openai.Int(p.maxTokens), - Tools: openaiTools, - } - - response, err := p.client.Chat.Completions.New(ctx, params) - if err != nil { - return nil, err +func (o *openaiClient) preparedParams(messages []openai.ChatCompletionMessageParamUnion, tools []openai.ChatCompletionToolParam) openai.ChatCompletionNewParams { + return openai.ChatCompletionNewParams{ + Model: openai.ChatModel(o.providerOptions.model.APIModel), + Messages: messages, + MaxTokens: openai.Int(o.providerOptions.maxTokens), + Tools: tools, } +} - content := "" - if response.Choices[0].Message.Content != "" { - content = response.Choices[0].Message.Content +func (o *openaiClient) send(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (response *ProviderResponse, err error) { + params := o.preparedParams(o.convertMessages(messages), o.convertTools(tools)) + cfg := config.Get() + if cfg.Debug { + jsonData, _ := json.Marshal(params) + logging.Debug("Prepared messages", "messages", string(jsonData)) } - - var toolCalls []message.ToolCall - if len(response.Choices[0].Message.ToolCalls) > 0 { - toolCalls = make([]message.ToolCall, len(response.Choices[0].Message.ToolCalls)) - for i, call := range response.Choices[0].Message.ToolCalls { - toolCalls[i] = message.ToolCall{ - ID: call.ID, - Name: call.Function.Name, - Input: call.Function.Arguments, - Type: "function", + attempts := 0 + for { + attempts++ + openaiResponse, err := o.client.Chat.Completions.New( + ctx, + params, + ) + // If there is an error we are going to see if we can retry the call + if err != nil { + retry, after, retryErr := o.shouldRetry(attempts, err) + if retryErr != nil { + return nil, retryErr } + if retry { + logging.WarnPersist("Retrying due to rate limit... attempt %d of %d", logging.PersistTimeArg, time.Millisecond*time.Duration(after+100)) + select { + case <-ctx.Done(): + return nil, ctx.Err() + case <-time.After(time.Duration(after) * time.Millisecond): + continue + } + } + return nil, retryErr } - } - tokenUsage := p.extractTokenUsage(response.Usage) + content := "" + if openaiResponse.Choices[0].Message.Content != "" { + content = openaiResponse.Choices[0].Message.Content + } - return &ProviderResponse{ - Content: content, - ToolCalls: toolCalls, - Usage: tokenUsage, - }, nil + return &ProviderResponse{ + Content: content, + ToolCalls: o.toolCalls(*openaiResponse), + Usage: o.usage(*openaiResponse), + FinishReason: o.finishReason(string(openaiResponse.Choices[0].FinishReason)), + }, nil + } } -func (p *openaiProvider) StreamResponse(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (<-chan ProviderEvent, error) { - messages = cleanupMessages(messages) - chatMessages := p.convertToOpenAIMessages(messages) - openaiTools := p.convertToOpenAITools(tools) - - params := openai.ChatCompletionNewParams{ - Model: openai.ChatModel(p.model.APIModel), - Messages: chatMessages, - MaxTokens: openai.Int(p.maxTokens), - Tools: openaiTools, - StreamOptions: openai.ChatCompletionStreamOptionsParam{ - IncludeUsage: openai.Bool(true), - }, +func (o *openaiClient) stream(ctx context.Context, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent { + params := o.preparedParams(o.convertMessages(messages), o.convertTools(tools)) + params.StreamOptions = openai.ChatCompletionStreamOptionsParam{ + IncludeUsage: openai.Bool(true), } - stream := p.client.Chat.Completions.NewStreaming(ctx, params) + cfg := config.Get() + if cfg.Debug { + jsonData, _ := json.Marshal(params) + logging.Debug("Prepared messages", "messages", string(jsonData)) + } + attempts := 0 eventChan := make(chan ProviderEvent) - toolCalls := make([]message.ToolCall, 0) go func() { - defer close(eventChan) - - acc := openai.ChatCompletionAccumulator{} - currentContent := "" - - for stream.Next() { - chunk := stream.Current() - acc.AddChunk(chunk) - - if tool, ok := acc.JustFinishedToolCall(); ok { - toolCalls = append(toolCalls, message.ToolCall{ - ID: tool.Id, - Name: tool.Name, - Input: tool.Arguments, - Type: "function", - }) - } + for { + attempts++ + openaiStream := o.client.Chat.Completions.NewStreaming( + ctx, + params, + ) + + acc := openai.ChatCompletionAccumulator{} + currentContent := "" + toolCalls := make([]message.ToolCall, 0) + + for openaiStream.Next() { + chunk := openaiStream.Current() + acc.AddChunk(chunk) + + if tool, ok := acc.JustFinishedToolCall(); ok { + toolCalls = append(toolCalls, message.ToolCall{ + ID: tool.Id, + Name: tool.Name, + Input: tool.Arguments, + Type: "function", + }) + } - for _, choice := range chunk.Choices { - if choice.Delta.Content != "" { - eventChan <- ProviderEvent{ - Type: EventContentDelta, - Content: choice.Delta.Content, + for _, choice := range chunk.Choices { + if choice.Delta.Content != "" { + eventChan <- ProviderEvent{ + Type: EventContentDelta, + Content: choice.Delta.Content, + } + currentContent += choice.Delta.Content } - currentContent += choice.Delta.Content } } - } - if err := stream.Err(); err != nil { - eventChan <- ProviderEvent{ - Type: EventError, - Error: err, + err := openaiStream.Err() + if err == nil || errors.Is(err, io.EOF) { + // Stream completed successfully + eventChan <- ProviderEvent{ + Type: EventComplete, + Response: &ProviderResponse{ + Content: currentContent, + ToolCalls: toolCalls, + Usage: o.usage(acc.ChatCompletion), + FinishReason: o.finishReason(string(acc.ChatCompletion.Choices[0].FinishReason)), + }, + } + close(eventChan) + return } + + // If there is an error we are going to see if we can retry the call + retry, after, retryErr := o.shouldRetry(attempts, err) + if retryErr != nil { + eventChan <- ProviderEvent{Type: EventError, Error: retryErr} + close(eventChan) + return + } + if retry { + logging.WarnPersist("Retrying due to rate limit... attempt %d of %d", logging.PersistTimeArg, time.Millisecond*time.Duration(after+100)) + select { + case <-ctx.Done(): + // context cancelled + if ctx.Err() == nil { + eventChan <- ProviderEvent{Type: EventError, Error: ctx.Err()} + } + close(eventChan) + return + case <-time.After(time.Duration(after) * time.Millisecond): + continue + } + } + eventChan <- ProviderEvent{Type: EventError, Error: retryErr} + close(eventChan) return } + }() - tokenUsage := p.extractTokenUsage(acc.Usage) + return eventChan +} - eventChan <- ProviderEvent{ - Type: EventComplete, - Response: &ProviderResponse{ - Content: currentContent, - ToolCalls: toolCalls, - Usage: tokenUsage, - }, +func (o *openaiClient) shouldRetry(attempts int, err error) (bool, int64, error) { + var apierr *openai.Error + if !errors.As(err, &apierr) { + return false, 0, err + } + + if apierr.StatusCode != 429 && apierr.StatusCode != 500 { + return false, 0, err + } + + if attempts > maxRetries { + return false, 0, fmt.Errorf("maximum retry attempts reached for rate limit: %d retries", maxRetries) + } + + retryMs := 0 + retryAfterValues := apierr.Response.Header.Values("Retry-After") + + backoffMs := 2000 * (1 << (attempts - 1)) + jitterMs := int(float64(backoffMs) * 0.2) + retryMs = backoffMs + jitterMs + if len(retryAfterValues) > 0 { + if _, err := fmt.Sscanf(retryAfterValues[0], "%d", &retryMs); err == nil { + retryMs = retryMs * 1000 } - }() + } + return true, int64(retryMs), nil +} - return eventChan, nil +func (o *openaiClient) toolCalls(completion openai.ChatCompletion) []message.ToolCall { + var toolCalls []message.ToolCall + + if len(completion.Choices) > 0 && len(completion.Choices[0].Message.ToolCalls) > 0 { + for _, call := range completion.Choices[0].Message.ToolCalls { + toolCall := message.ToolCall{ + ID: call.ID, + Name: call.Function.Name, + Input: call.Function.Arguments, + Type: "function", + } + toolCalls = append(toolCalls, toolCall) + } + } + + return toolCalls } + +func (o *openaiClient) usage(completion openai.ChatCompletion) TokenUsage { + cachedTokens := completion.Usage.PromptTokensDetails.CachedTokens + inputTokens := completion.Usage.PromptTokens - cachedTokens + + return TokenUsage{ + InputTokens: inputTokens, + OutputTokens: completion.Usage.CompletionTokens, + CacheCreationTokens: 0, // OpenAI doesn't provide this directly + CacheReadTokens: cachedTokens, + } +} + +func WithOpenAIBaseURL(baseURL string) OpenAIOption { + return func(options *openaiOptions) { + options.baseURL = baseURL + } +} + +func WithOpenAIDisableCache() OpenAIOption { + return func(options *openaiOptions) { + options.disableCache = true + } +} + diff --git a/internal/llm/provider/provider.go b/internal/llm/provider/provider.go index 34d91f2b771dd918ff6432bb9aff5f12d667e0e0..1a5b3dc8ace7f2b363761c9defd37a53407f42d4 100644 --- a/internal/llm/provider/provider.go +++ b/internal/llm/provider/provider.go @@ -2,14 +2,17 @@ package provider import ( "context" + "fmt" + "github.com/kujtimiihoxha/termai/internal/llm/models" "github.com/kujtimiihoxha/termai/internal/llm/tools" "github.com/kujtimiihoxha/termai/internal/message" ) -// EventType represents the type of streaming event type EventType string +const maxRetries = 8 + const ( EventContentStart EventType = "content_start" EventContentDelta EventType = "content_delta" @@ -18,7 +21,6 @@ const ( EventComplete EventType = "complete" EventError EventType = "error" EventWarning EventType = "warning" - EventInfo EventType = "info" ) type TokenUsage struct { @@ -32,61 +34,152 @@ type ProviderResponse struct { Content string ToolCalls []message.ToolCall Usage TokenUsage - FinishReason string + FinishReason message.FinishReason } type ProviderEvent struct { - Type EventType + Type EventType + Content string Thinking string - ToolCall *message.ToolCall - Error error Response *ProviderResponse - // Used for giving users info on e.x retry - Info string + Error error } - type Provider interface { SendMessages(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error) - StreamResponse(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (<-chan ProviderEvent, error) + StreamResponse(ctx context.Context, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent + + Model() models.Model +} + +type providerClientOptions struct { + apiKey string + model models.Model + maxTokens int64 + systemMessage string + + anthropicOptions []AnthropicOption + openaiOptions []OpenAIOption + geminiOptions []GeminiOption + bedrockOptions []BedrockOption +} + +type ProviderClientOption func(*providerClientOptions) + +type ProviderClient interface { + send(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error) + stream(ctx context.Context, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent +} + +type baseProvider[C ProviderClient] struct { + options providerClientOptions + client C +} + +func NewProvider(providerName models.ModelProvider, opts ...ProviderClientOption) (Provider, error) { + clientOptions := providerClientOptions{} + for _, o := range opts { + o(&clientOptions) + } + switch providerName { + case models.ProviderAnthropic: + return &baseProvider[AnthropicClient]{ + options: clientOptions, + client: newAnthropicClient(clientOptions), + }, nil + case models.ProviderOpenAI: + return &baseProvider[OpenAIClient]{ + options: clientOptions, + client: newOpenAIClient(clientOptions), + }, nil + case models.ProviderGemini: + return &baseProvider[GeminiClient]{ + options: clientOptions, + client: newGeminiClient(clientOptions), + }, nil + case models.ProviderBedrock: + return &baseProvider[BedrockClient]{ + options: clientOptions, + client: newBedrockClient(clientOptions), + }, nil + case models.ProviderMock: + // TODO: implement mock client for test + panic("not implemented") + } + return nil, fmt.Errorf("provider not supported: %s", providerName) } -func cleanupMessages(messages []message.Message) []message.Message { - // First pass: filter out canceled messages - var cleanedMessages []message.Message +func (p *baseProvider[C]) cleanMessages(messages []message.Message) (cleaned []message.Message) { for _, msg := range messages { - if msg.FinishReason() != "canceled" || len(msg.ToolCalls()) > 0 { - // if there are toolCalls this means we want to return it to the LLM telling it that those tools have been - // cancelled - cleanedMessages = append(cleanedMessages, msg) + // The message has no content + if len(msg.Parts) == 0 { + continue } + cleaned = append(cleaned, msg) } + return +} - // Second pass: filter out tool messages without a corresponding tool call - var result []message.Message - toolMessageIDs := make(map[string]bool) +func (p *baseProvider[C]) SendMessages(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error) { + messages = p.cleanMessages(messages) + return p.client.send(ctx, messages, tools) +} - for _, msg := range cleanedMessages { - if msg.Role == message.Assistant { - for _, toolCall := range msg.ToolCalls() { - toolMessageIDs[toolCall.ID] = true // Mark as referenced - } - } +func (p *baseProvider[C]) Model() models.Model { + return p.options.model +} + +func (p *baseProvider[C]) StreamResponse(ctx context.Context, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent { + messages = p.cleanMessages(messages) + return p.client.stream(ctx, messages, tools) +} + +func WithAPIKey(apiKey string) ProviderClientOption { + return func(options *providerClientOptions) { + options.apiKey = apiKey } +} - // Keep only messages that aren't unreferenced tool messages - for _, msg := range cleanedMessages { - if msg.Role == message.Tool { - for _, toolCall := range msg.ToolResults() { - if referenced, exists := toolMessageIDs[toolCall.ToolCallID]; exists && referenced { - result = append(result, msg) - } - } - } else { - result = append(result, msg) - } +func WithModel(model models.Model) ProviderClientOption { + return func(options *providerClientOptions) { + options.model = model + } +} + +func WithMaxTokens(maxTokens int64) ProviderClientOption { + return func(options *providerClientOptions) { + options.maxTokens = maxTokens + } +} + +func WithSystemMessage(systemMessage string) ProviderClientOption { + return func(options *providerClientOptions) { + options.systemMessage = systemMessage + } +} + +func WithAnthropicOptions(anthropicOptions ...AnthropicOption) ProviderClientOption { + return func(options *providerClientOptions) { + options.anthropicOptions = anthropicOptions + } +} + +func WithOpenAIOptions(openaiOptions ...OpenAIOption) ProviderClientOption { + return func(options *providerClientOptions) { + options.openaiOptions = openaiOptions + } +} + +func WithGeminiOptions(geminiOptions ...GeminiOption) ProviderClientOption { + return func(options *providerClientOptions) { + options.geminiOptions = geminiOptions + } +} + +func WithBedrockOptions(bedrockOptions ...BedrockOption) ProviderClientOption { + return func(options *providerClientOptions) { + options.bedrockOptions = bedrockOptions } - return result } diff --git a/internal/llm/tools/bash.go b/internal/llm/tools/bash.go index 0cea20878731b24902796dba7efeace9f2d91312..c7c970e5a1a9b10b3905d51701414bd464be4247 100644 --- a/internal/llm/tools/bash.go +++ b/internal/llm/tools/bash.go @@ -23,7 +23,8 @@ type BashPermissionsParams struct { } type BashResponseMetadata struct { - Took int64 `json:"took"` + StartTime int64 `json:"start_time"` + EndTime int64 `json:"end_time"` } type bashTool struct { permissions permission.Service @@ -282,7 +283,6 @@ func (b *bashTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error) if err != nil { return ToolResponse{}, fmt.Errorf("error executing command: %w", err) } - took := time.Since(startTime).Milliseconds() stdout = truncateOutput(stdout) stderr = truncateOutput(stderr) @@ -311,7 +311,8 @@ func (b *bashTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error) } metadata := BashResponseMetadata{ - Took: took, + StartTime: startTime.UnixMilli(), + EndTime: time.Now().UnixMilli(), } if stdout == "" { return WithResponseMetadata(NewTextResponse("no output"), metadata), nil diff --git a/internal/llm/tools/bash_test.go b/internal/llm/tools/bash_test.go index 97be3683aa243f9c9b74fc84503e973959ac9e8a..dafb0ccc5fe4b7f612c451755b09cb70f17ac388 100644 --- a/internal/llm/tools/bash_test.go +++ b/internal/llm/tools/bash_test.go @@ -8,8 +8,6 @@ import ( "testing" "time" - "github.com/kujtimiihoxha/termai/internal/permission" - "github.com/kujtimiihoxha/termai/internal/pubsub" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -340,32 +338,3 @@ func TestCountLines(t *testing.T) { }) } } - -// Mock permission service for testing -type mockPermissionService struct { - *pubsub.Broker[permission.PermissionRequest] - allow bool -} - -func (m *mockPermissionService) GrantPersistant(permission permission.PermissionRequest) { - // Not needed for tests -} - -func (m *mockPermissionService) Grant(permission permission.PermissionRequest) { - // Not needed for tests -} - -func (m *mockPermissionService) Deny(permission permission.PermissionRequest) { - // Not needed for tests -} - -func (m *mockPermissionService) Request(opts permission.CreatePermissionRequest) bool { - return m.allow -} - -func newMockPermissionService(allow bool) permission.Service { - return &mockPermissionService{ - Broker: pubsub.NewBroker[permission.PermissionRequest](), - allow: allow, - } -} diff --git a/internal/llm/tools/edit.go b/internal/llm/tools/edit.go index 08d6d446c9320b7da3e60040ba61d84022f3affb..148e7aba7a78b11df9eeea9b3ac57658e15dc070 100644 --- a/internal/llm/tools/edit.go +++ b/internal/llm/tools/edit.go @@ -11,6 +11,7 @@ import ( "github.com/kujtimiihoxha/termai/internal/config" "github.com/kujtimiihoxha/termai/internal/diff" + "github.com/kujtimiihoxha/termai/internal/history" "github.com/kujtimiihoxha/termai/internal/lsp" "github.com/kujtimiihoxha/termai/internal/permission" ) @@ -35,6 +36,7 @@ type EditResponseMetadata struct { type editTool struct { lspClients map[string]*lsp.Client permissions permission.Service + files history.Service } const ( @@ -88,10 +90,11 @@ When making edits: Remember: when making multiple file edits in a row to the same file, you should prefer to send all edits in a single message with multiple calls to this tool, rather than multiple messages with a single call each.` ) -func NewEditTool(lspClients map[string]*lsp.Client, permissions permission.Service) BaseTool { +func NewEditTool(lspClients map[string]*lsp.Client, permissions permission.Service, files history.Service) BaseTool { return &editTool{ lspClients: lspClients, permissions: permissions, + files: files, } } @@ -153,6 +156,11 @@ func (e *editTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error) if err != nil { return response, nil } + if response.IsError { + // Return early if there was an error during content replacement + // This prevents unnecessary LSP diagnostics processing + return response, nil + } waitForLspDiagnostics(ctx, params.FilePath, e.lspClients) text := fmt.Sprintf("\n%s\n\n", response.Content) @@ -208,6 +216,20 @@ func (e *editTool) createNewFile(ctx context.Context, filePath, content string) return ToolResponse{}, fmt.Errorf("failed to write file: %w", err) } + // File can't be in the history so we create a new file history + _, err = e.files.Create(ctx, sessionID, filePath, "") + if err != nil { + // Log error but don't fail the operation + return ToolResponse{}, fmt.Errorf("error creating file history: %w", err) + } + + // Add the new content to the file history + _, err = e.files.CreateVersion(ctx, sessionID, filePath, content) + if err != nil { + // Log error but don't fail the operation + fmt.Printf("Error creating file history version: %v\n", err) + } + recordFileWrite(filePath) recordFileRead(filePath) @@ -298,6 +320,29 @@ func (e *editTool) deleteContent(ctx context.Context, filePath, oldString string if err != nil { return ToolResponse{}, fmt.Errorf("failed to write file: %w", err) } + + // Check if file exists in history + file, err := e.files.GetByPathAndSession(ctx, filePath, sessionID) + if err != nil { + _, err = e.files.Create(ctx, sessionID, filePath, oldContent) + if err != nil { + // Log error but don't fail the operation + return ToolResponse{}, fmt.Errorf("error creating file history: %w", err) + } + } + if file.Content != oldContent { + // User Manually changed the content store an intermediate version + _, err = e.files.CreateVersion(ctx, sessionID, filePath, oldContent) + if err != nil { + fmt.Printf("Error creating file history version: %v\n", err) + } + } + // Store the new version + _, err = e.files.CreateVersion(ctx, sessionID, filePath, "") + if err != nil { + fmt.Printf("Error creating file history version: %v\n", err) + } + recordFileWrite(filePath) recordFileRead(filePath) @@ -356,6 +401,9 @@ func (e *editTool) replaceContent(ctx context.Context, filePath, oldString, newS newContent := oldContent[:index] + newString + oldContent[index+len(oldString):] + if oldContent == newContent { + return NewTextErrorResponse("new content is the same as old content. No changes made."), nil + } sessionID, messageID := GetContextValues(ctx) if sessionID == "" || messageID == "" { @@ -374,8 +422,7 @@ func (e *editTool) replaceContent(ctx context.Context, filePath, oldString, newS Description: fmt.Sprintf("Replace content in file %s", filePath), Params: EditPermissionsParams{ FilePath: filePath, - - Diff: diff, + Diff: diff, }, }, ) @@ -388,6 +435,28 @@ func (e *editTool) replaceContent(ctx context.Context, filePath, oldString, newS return ToolResponse{}, fmt.Errorf("failed to write file: %w", err) } + // Check if file exists in history + file, err := e.files.GetByPathAndSession(ctx, filePath, sessionID) + if err != nil { + _, err = e.files.Create(ctx, sessionID, filePath, oldContent) + if err != nil { + // Log error but don't fail the operation + return ToolResponse{}, fmt.Errorf("error creating file history: %w", err) + } + } + if file.Content != oldContent { + // User Manually changed the content store an intermediate version + _, err = e.files.CreateVersion(ctx, sessionID, filePath, oldContent) + if err != nil { + fmt.Printf("Error creating file history version: %v\n", err) + } + } + // Store the new version + _, err = e.files.CreateVersion(ctx, sessionID, filePath, newContent) + if err != nil { + fmt.Printf("Error creating file history version: %v\n", err) + } + recordFileWrite(filePath) recordFileRead(filePath) diff --git a/internal/llm/tools/edit_test.go b/internal/llm/tools/edit_test.go index 48a34ed75c2261f022cbed1cd0e08a6f1949d642..0971775ddb508f50e1bbf3a4cf4a8f372f419dbe 100644 --- a/internal/llm/tools/edit_test.go +++ b/internal/llm/tools/edit_test.go @@ -14,7 +14,7 @@ import ( ) func TestEditTool_Info(t *testing.T) { - tool := NewEditTool(make(map[string]*lsp.Client), newMockPermissionService(true)) + tool := NewEditTool(make(map[string]*lsp.Client), newMockPermissionService(true), newMockFileHistoryService()) info := tool.Info() assert.Equal(t, EditToolName, info.Name) @@ -34,7 +34,7 @@ func TestEditTool_Run(t *testing.T) { defer os.RemoveAll(tempDir) t.Run("creates a new file successfully", func(t *testing.T) { - tool := NewEditTool(make(map[string]*lsp.Client), newMockPermissionService(true)) + tool := NewEditTool(make(map[string]*lsp.Client), newMockPermissionService(true), newMockFileHistoryService()) filePath := filepath.Join(tempDir, "new_file.txt") content := "This is a test content" @@ -64,7 +64,7 @@ func TestEditTool_Run(t *testing.T) { }) t.Run("creates file with nested directories", func(t *testing.T) { - tool := NewEditTool(make(map[string]*lsp.Client), newMockPermissionService(true)) + tool := NewEditTool(make(map[string]*lsp.Client), newMockPermissionService(true), newMockFileHistoryService()) filePath := filepath.Join(tempDir, "nested/dirs/new_file.txt") content := "Content in nested directory" @@ -94,7 +94,7 @@ func TestEditTool_Run(t *testing.T) { }) t.Run("fails to create file that already exists", func(t *testing.T) { - tool := NewEditTool(make(map[string]*lsp.Client), newMockPermissionService(true)) + tool := NewEditTool(make(map[string]*lsp.Client), newMockPermissionService(true), newMockFileHistoryService()) // Create a file first filePath := filepath.Join(tempDir, "existing_file.txt") @@ -123,7 +123,7 @@ func TestEditTool_Run(t *testing.T) { }) t.Run("fails to create file when path is a directory", func(t *testing.T) { - tool := NewEditTool(make(map[string]*lsp.Client), newMockPermissionService(true)) + tool := NewEditTool(make(map[string]*lsp.Client), newMockPermissionService(true), newMockFileHistoryService()) // Create a directory dirPath := filepath.Join(tempDir, "test_dir") @@ -151,7 +151,7 @@ func TestEditTool_Run(t *testing.T) { }) t.Run("replaces content successfully", func(t *testing.T) { - tool := NewEditTool(make(map[string]*lsp.Client), newMockPermissionService(true)) + tool := NewEditTool(make(map[string]*lsp.Client), newMockPermissionService(true), newMockFileHistoryService()) // Create a file first filePath := filepath.Join(tempDir, "replace_content.txt") @@ -191,7 +191,7 @@ func TestEditTool_Run(t *testing.T) { }) t.Run("deletes content successfully", func(t *testing.T) { - tool := NewEditTool(make(map[string]*lsp.Client), newMockPermissionService(true)) + tool := NewEditTool(make(map[string]*lsp.Client), newMockPermissionService(true), newMockFileHistoryService()) // Create a file first filePath := filepath.Join(tempDir, "delete_content.txt") @@ -230,7 +230,7 @@ func TestEditTool_Run(t *testing.T) { }) t.Run("handles invalid parameters", func(t *testing.T) { - tool := NewEditTool(make(map[string]*lsp.Client), newMockPermissionService(true)) + tool := NewEditTool(make(map[string]*lsp.Client), newMockPermissionService(true), newMockFileHistoryService()) call := ToolCall{ Name: EditToolName, @@ -243,7 +243,7 @@ func TestEditTool_Run(t *testing.T) { }) t.Run("handles missing file_path", func(t *testing.T) { - tool := NewEditTool(make(map[string]*lsp.Client), newMockPermissionService(true)) + tool := NewEditTool(make(map[string]*lsp.Client), newMockPermissionService(true), newMockFileHistoryService()) params := EditParams{ FilePath: "", @@ -265,7 +265,7 @@ func TestEditTool_Run(t *testing.T) { }) t.Run("handles file not found", func(t *testing.T) { - tool := NewEditTool(make(map[string]*lsp.Client), newMockPermissionService(true)) + tool := NewEditTool(make(map[string]*lsp.Client), newMockPermissionService(true), newMockFileHistoryService()) filePath := filepath.Join(tempDir, "non_existent_file.txt") params := EditParams{ @@ -288,7 +288,7 @@ func TestEditTool_Run(t *testing.T) { }) t.Run("handles old_string not found in file", func(t *testing.T) { - tool := NewEditTool(make(map[string]*lsp.Client), newMockPermissionService(true)) + tool := NewEditTool(make(map[string]*lsp.Client), newMockPermissionService(true), newMockFileHistoryService()) // Create a file first filePath := filepath.Join(tempDir, "content_not_found.txt") @@ -320,7 +320,7 @@ func TestEditTool_Run(t *testing.T) { }) t.Run("handles multiple occurrences of old_string", func(t *testing.T) { - tool := NewEditTool(make(map[string]*lsp.Client), newMockPermissionService(true)) + tool := NewEditTool(make(map[string]*lsp.Client), newMockPermissionService(true), newMockFileHistoryService()) // Create a file with duplicate content filePath := filepath.Join(tempDir, "duplicate_content.txt") @@ -352,7 +352,7 @@ func TestEditTool_Run(t *testing.T) { }) t.Run("handles file modified since last read", func(t *testing.T) { - tool := NewEditTool(make(map[string]*lsp.Client), newMockPermissionService(true)) + tool := NewEditTool(make(map[string]*lsp.Client), newMockPermissionService(true), newMockFileHistoryService()) // Create a file filePath := filepath.Join(tempDir, "modified_file.txt") @@ -394,7 +394,7 @@ func TestEditTool_Run(t *testing.T) { }) t.Run("handles file not read before editing", func(t *testing.T) { - tool := NewEditTool(make(map[string]*lsp.Client), newMockPermissionService(true)) + tool := NewEditTool(make(map[string]*lsp.Client), newMockPermissionService(true), newMockFileHistoryService()) // Create a file filePath := filepath.Join(tempDir, "not_read_file.txt") @@ -423,7 +423,7 @@ func TestEditTool_Run(t *testing.T) { }) t.Run("handles permission denied", func(t *testing.T) { - tool := NewEditTool(make(map[string]*lsp.Client), newMockPermissionService(false)) + tool := NewEditTool(make(map[string]*lsp.Client), newMockPermissionService(false), newMockFileHistoryService()) // Create a file filePath := filepath.Join(tempDir, "permission_denied.txt") diff --git a/internal/llm/tools/file.go b/internal/llm/tools/file.go index 9c9707c9c3f387b59e7f8a528344dd2879e1ee71..7f34fdc1f615031decf00706c58aac37a235b57e 100644 --- a/internal/llm/tools/file.go +++ b/internal/llm/tools/file.go @@ -3,8 +3,6 @@ package tools import ( "sync" "time" - - "github.com/kujtimiihoxha/termai/internal/config" ) // File record to track when files were read/written @@ -19,14 +17,6 @@ var ( fileRecordMutex sync.RWMutex ) -func removeWorkingDirectoryPrefix(path string) string { - wd := config.WorkingDirectory() - if len(path) > len(wd) && path[:len(wd)] == wd { - return path[len(wd)+1:] - } - return path -} - func recordFileRead(path string) { fileRecordMutex.Lock() defer fileRecordMutex.Unlock() diff --git a/internal/llm/tools/glob.go b/internal/llm/tools/glob.go index bdfc23b4ababde88978b3888878f4265fd52eae6..7b4fb11870aa1f2c79755c5f8de30a7e49ead11b 100644 --- a/internal/llm/tools/glob.go +++ b/internal/llm/tools/glob.go @@ -63,7 +63,7 @@ type GlobParams struct { Path string `json:"path"` } -type GlobMetadata struct { +type GlobResponseMetadata struct { NumberOfFiles int `json:"number_of_files"` Truncated bool `json:"truncated"` } @@ -124,7 +124,7 @@ func (g *globTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error) return WithResponseMetadata( NewTextResponse(output), - GlobMetadata{ + GlobResponseMetadata{ NumberOfFiles: len(files), Truncated: truncated, }, diff --git a/internal/llm/tools/grep.go b/internal/llm/tools/grep.go index 7e52821d07bc6dbc71acdf6d7a5eeb42d2bc2dc1..19333f50b9ac0820588fda90413488c3d99e3aa4 100644 --- a/internal/llm/tools/grep.go +++ b/internal/llm/tools/grep.go @@ -27,7 +27,7 @@ type grepMatch struct { modTime time.Time } -type GrepMetadata struct { +type GrepResponseMetadata struct { NumberOfMatches int `json:"number_of_matches"` Truncated bool `json:"truncated"` } @@ -134,7 +134,7 @@ func (g *grepTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error) return WithResponseMetadata( NewTextResponse(output), - GrepMetadata{ + GrepResponseMetadata{ NumberOfMatches: len(matches), Truncated: truncated, }, diff --git a/internal/llm/tools/ls.go b/internal/llm/tools/ls.go index a679f261b1ae28b3551245a8e5d25f76185ef14d..a63bf0eebfb98869e28b02b1e4bb1c31c81fbd3d 100644 --- a/internal/llm/tools/ls.go +++ b/internal/llm/tools/ls.go @@ -23,7 +23,7 @@ type TreeNode struct { Children []*TreeNode `json:"children,omitempty"` } -type LSMetadata struct { +type LSResponseMetadata struct { NumberOfFiles int `json:"number_of_files"` Truncated bool `json:"truncated"` } @@ -121,7 +121,7 @@ func (l *lsTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error) { return WithResponseMetadata( NewTextResponse(output), - LSMetadata{ + LSResponseMetadata{ NumberOfFiles: len(files), Truncated: truncated, }, diff --git a/internal/llm/tools/mocks_test.go b/internal/llm/tools/mocks_test.go new file mode 100644 index 0000000000000000000000000000000000000000..321f09ac1ab00b8c413176db72cba676c4c45dd1 --- /dev/null +++ b/internal/llm/tools/mocks_test.go @@ -0,0 +1,246 @@ +package tools + +import ( + "context" + "fmt" + "sort" + "strconv" + "strings" + "time" + + "github.com/google/uuid" + "github.com/kujtimiihoxha/termai/internal/history" + "github.com/kujtimiihoxha/termai/internal/permission" + "github.com/kujtimiihoxha/termai/internal/pubsub" +) + +// Mock permission service for testing +type mockPermissionService struct { + *pubsub.Broker[permission.PermissionRequest] + allow bool +} + +func (m *mockPermissionService) GrantPersistant(permission permission.PermissionRequest) { + // Not needed for tests +} + +func (m *mockPermissionService) Grant(permission permission.PermissionRequest) { + // Not needed for tests +} + +func (m *mockPermissionService) Deny(permission permission.PermissionRequest) { + // Not needed for tests +} + +func (m *mockPermissionService) Request(opts permission.CreatePermissionRequest) bool { + return m.allow +} + +func newMockPermissionService(allow bool) permission.Service { + return &mockPermissionService{ + Broker: pubsub.NewBroker[permission.PermissionRequest](), + allow: allow, + } +} + +type mockFileHistoryService struct { + *pubsub.Broker[history.File] + files map[string]history.File // ID -> File + timeNow func() int64 +} + +// Create implements history.Service. +func (m *mockFileHistoryService) Create(ctx context.Context, sessionID string, path string, content string) (history.File, error) { + return m.createWithVersion(ctx, sessionID, path, content, history.InitialVersion) +} + +// CreateVersion implements history.Service. +func (m *mockFileHistoryService) CreateVersion(ctx context.Context, sessionID string, path string, content string) (history.File, error) { + var files []history.File + for _, file := range m.files { + if file.Path == path { + files = append(files, file) + } + } + + if len(files) == 0 { + // No previous versions, create initial + return m.Create(ctx, sessionID, path, content) + } + + // Sort files by CreatedAt in descending order + sort.Slice(files, func(i, j int) bool { + return files[i].CreatedAt > files[j].CreatedAt + }) + + // Get the latest version + latestFile := files[0] + latestVersion := latestFile.Version + + // Generate the next version + var nextVersion string + if latestVersion == history.InitialVersion { + nextVersion = "v1" + } else if strings.HasPrefix(latestVersion, "v") { + versionNum, err := strconv.Atoi(latestVersion[1:]) + if err != nil { + // If we can't parse the version, just use a timestamp-based version + nextVersion = fmt.Sprintf("v%d", latestFile.CreatedAt) + } else { + nextVersion = fmt.Sprintf("v%d", versionNum+1) + } + } else { + // If the version format is unexpected, use a timestamp-based version + nextVersion = fmt.Sprintf("v%d", latestFile.CreatedAt) + } + + return m.createWithVersion(ctx, sessionID, path, content, nextVersion) +} + +func (m *mockFileHistoryService) createWithVersion(_ context.Context, sessionID, path, content, version string) (history.File, error) { + now := m.timeNow() + file := history.File{ + ID: uuid.New().String(), + SessionID: sessionID, + Path: path, + Content: content, + Version: version, + CreatedAt: now, + UpdatedAt: now, + } + + m.files[file.ID] = file + m.Publish(pubsub.CreatedEvent, file) + return file, nil +} + +// Delete implements history.Service. +func (m *mockFileHistoryService) Delete(ctx context.Context, id string) error { + file, ok := m.files[id] + if !ok { + return fmt.Errorf("file not found: %s", id) + } + + delete(m.files, id) + m.Publish(pubsub.DeletedEvent, file) + return nil +} + +// DeleteSessionFiles implements history.Service. +func (m *mockFileHistoryService) DeleteSessionFiles(ctx context.Context, sessionID string) error { + files, err := m.ListBySession(ctx, sessionID) + if err != nil { + return err + } + + for _, file := range files { + err = m.Delete(ctx, file.ID) + if err != nil { + return err + } + } + + return nil +} + +// Get implements history.Service. +func (m *mockFileHistoryService) Get(ctx context.Context, id string) (history.File, error) { + file, ok := m.files[id] + if !ok { + return history.File{}, fmt.Errorf("file not found: %s", id) + } + return file, nil +} + +// GetByPathAndSession implements history.Service. +func (m *mockFileHistoryService) GetByPathAndSession(ctx context.Context, path string, sessionID string) (history.File, error) { + var latestFile history.File + var found bool + var latestTime int64 + + for _, file := range m.files { + if file.Path == path && file.SessionID == sessionID { + if !found || file.CreatedAt > latestTime { + latestFile = file + latestTime = file.CreatedAt + found = true + } + } + } + + if !found { + return history.File{}, fmt.Errorf("file not found: %s for session %s", path, sessionID) + } + return latestFile, nil +} + +// ListBySession implements history.Service. +func (m *mockFileHistoryService) ListBySession(ctx context.Context, sessionID string) ([]history.File, error) { + var files []history.File + for _, file := range m.files { + if file.SessionID == sessionID { + files = append(files, file) + } + } + + // Sort by CreatedAt in descending order + sort.Slice(files, func(i, j int) bool { + return files[i].CreatedAt > files[j].CreatedAt + }) + + return files, nil +} + +// ListLatestSessionFiles implements history.Service. +func (m *mockFileHistoryService) ListLatestSessionFiles(ctx context.Context, sessionID string) ([]history.File, error) { + // Map to track the latest file for each path + latestFiles := make(map[string]history.File) + + for _, file := range m.files { + if file.SessionID == sessionID { + existing, ok := latestFiles[file.Path] + if !ok || file.CreatedAt > existing.CreatedAt { + latestFiles[file.Path] = file + } + } + } + + // Convert map to slice + var result []history.File + for _, file := range latestFiles { + result = append(result, file) + } + + // Sort by CreatedAt in descending order + sort.Slice(result, func(i, j int) bool { + return result[i].CreatedAt > result[j].CreatedAt + }) + + return result, nil +} + +// Subscribe implements history.Service. +func (m *mockFileHistoryService) Subscribe(ctx context.Context) <-chan pubsub.Event[history.File] { + return m.Broker.Subscribe(ctx) +} + +// Update implements history.Service. +func (m *mockFileHistoryService) Update(ctx context.Context, file history.File) (history.File, error) { + _, ok := m.files[file.ID] + if !ok { + return history.File{}, fmt.Errorf("file not found: %s", file.ID) + } + + file.UpdatedAt = m.timeNow() + m.files[file.ID] = file + m.Publish(pubsub.UpdatedEvent, file) + return file, nil +} + +func newMockFileHistoryService() history.Service { + return &mockFileHistoryService{ + Broker: pubsub.NewBroker[history.File](), + files: make(map[string]history.File), + timeNow: func() int64 { return time.Now().Unix() }, + } +} diff --git a/internal/llm/tools/shell/shell.go b/internal/llm/tools/shell/shell.go index 64592f67da8db74ba990b97e8dadd354778f44a1..4a776478ab67bb4031a698b56beaddfb9734fa1d 100644 --- a/internal/llm/tools/shell/shell.go +++ b/internal/llm/tools/shell/shell.go @@ -83,11 +83,21 @@ func newPersistentShell(cwd string) *PersistentShell { commandQueue: make(chan *commandExecution, 10), } - go shell.processCommands() + go func() { + defer func() { + if r := recover(); r != nil { + fmt.Fprintf(os.Stderr, "Panic in shell command processor: %v\n", r) + shell.isAlive = false + close(shell.commandQueue) + } + }() + shell.processCommands() + }() go func() { err := cmd.Wait() if err != nil { + // Log the error if needed } shell.isAlive = false close(shell.commandQueue) diff --git a/internal/llm/tools/sourcegraph.go b/internal/llm/tools/sourcegraph.go index 17bc610ea8fe2597ba8a210efa0f0e5a82c620bd..a6f2c8afb578c4fce7c29bcf1ffabfffdba75d71 100644 --- a/internal/llm/tools/sourcegraph.go +++ b/internal/llm/tools/sourcegraph.go @@ -18,7 +18,7 @@ type SourcegraphParams struct { Timeout int `json:"timeout,omitempty"` } -type SourcegraphMetadata struct { +type SourcegraphResponseMetadata struct { NumberOfMatches int `json:"number_of_matches"` Truncated bool `json:"truncated"` } diff --git a/internal/llm/tools/tools.go b/internal/llm/tools/tools.go index 07afe1363fc114ca17daf0211d656f4b21f2723d..bf0f8df0bacdc34fc6a3f5f89f3591b8b6fd8f15 100644 --- a/internal/llm/tools/tools.go +++ b/internal/llm/tools/tools.go @@ -14,12 +14,17 @@ type ToolInfo struct { type toolResponseType string +type ( + sessionIDContextKey string + messageIDContextKey string +) + const ( ToolResponseTypeText toolResponseType = "text" ToolResponseTypeImage toolResponseType = "image" - SessionIDContextKey = "session_id" - MessageIDContextKey = "message_id" + SessionIDContextKey sessionIDContextKey = "session_id" + MessageIDContextKey messageIDContextKey = "message_id" ) type ToolResponse struct { diff --git a/internal/llm/tools/write.go b/internal/llm/tools/write.go index 889561d2af4d5e0f428deb7562535fd4647f2331..bb49381fd4423f70cad38dfdbc99197c765c3080 100644 --- a/internal/llm/tools/write.go +++ b/internal/llm/tools/write.go @@ -10,6 +10,7 @@ import ( "github.com/kujtimiihoxha/termai/internal/config" "github.com/kujtimiihoxha/termai/internal/diff" + "github.com/kujtimiihoxha/termai/internal/history" "github.com/kujtimiihoxha/termai/internal/lsp" "github.com/kujtimiihoxha/termai/internal/permission" ) @@ -27,6 +28,7 @@ type WritePermissionsParams struct { type writeTool struct { lspClients map[string]*lsp.Client permissions permission.Service + files history.Service } type WriteResponseMetadata struct { @@ -67,10 +69,11 @@ TIPS: - Always include descriptive comments when making changes to existing code` ) -func NewWriteTool(lspClients map[string]*lsp.Client, permissions permission.Service) BaseTool { +func NewWriteTool(lspClients map[string]*lsp.Client, permissions permission.Service, files history.Service) BaseTool { return &writeTool{ lspClients: lspClients, permissions: permissions, + files: files, } } @@ -176,6 +179,28 @@ func (w *writeTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error return ToolResponse{}, fmt.Errorf("error writing file: %w", err) } + // Check if file exists in history + file, err := w.files.GetByPathAndSession(ctx, filePath, sessionID) + if err != nil { + _, err = w.files.Create(ctx, sessionID, filePath, oldContent) + if err != nil { + // Log error but don't fail the operation + return ToolResponse{}, fmt.Errorf("error creating file history: %w", err) + } + } + if file.Content != oldContent { + // User Manually changed the content store an intermediate version + _, err = w.files.CreateVersion(ctx, sessionID, filePath, oldContent) + if err != nil { + fmt.Printf("Error creating file history version: %v\n", err) + } + } + // Store the new version + _, err = w.files.CreateVersion(ctx, sessionID, filePath, params.Content) + if err != nil { + fmt.Printf("Error creating file history version: %v\n", err) + } + recordFileWrite(filePath) recordFileRead(filePath) waitForLspDiagnostics(ctx, filePath, w.lspClients) diff --git a/internal/llm/tools/write_test.go b/internal/llm/tools/write_test.go index 50dafc14f0f8a42b1a848a3f5c1d9b530a59e196..2264f36fb8aff7537095d751f5b847159d9743c2 100644 --- a/internal/llm/tools/write_test.go +++ b/internal/llm/tools/write_test.go @@ -14,7 +14,7 @@ import ( ) func TestWriteTool_Info(t *testing.T) { - tool := NewWriteTool(make(map[string]*lsp.Client), newMockPermissionService(true)) + tool := NewWriteTool(make(map[string]*lsp.Client), newMockPermissionService(true), newMockFileHistoryService()) info := tool.Info() assert.Equal(t, WriteToolName, info.Name) @@ -32,7 +32,7 @@ func TestWriteTool_Run(t *testing.T) { defer os.RemoveAll(tempDir) t.Run("creates a new file successfully", func(t *testing.T) { - tool := NewWriteTool(make(map[string]*lsp.Client), newMockPermissionService(true)) + tool := NewWriteTool(make(map[string]*lsp.Client), newMockPermissionService(true), newMockFileHistoryService()) filePath := filepath.Join(tempDir, "new_file.txt") content := "This is a test content" @@ -61,7 +61,7 @@ func TestWriteTool_Run(t *testing.T) { }) t.Run("creates file with nested directories", func(t *testing.T) { - tool := NewWriteTool(make(map[string]*lsp.Client), newMockPermissionService(true)) + tool := NewWriteTool(make(map[string]*lsp.Client), newMockPermissionService(true), newMockFileHistoryService()) filePath := filepath.Join(tempDir, "nested/dirs/new_file.txt") content := "Content in nested directory" @@ -90,7 +90,7 @@ func TestWriteTool_Run(t *testing.T) { }) t.Run("updates existing file", func(t *testing.T) { - tool := NewWriteTool(make(map[string]*lsp.Client), newMockPermissionService(true)) + tool := NewWriteTool(make(map[string]*lsp.Client), newMockPermissionService(true), newMockFileHistoryService()) // Create a file first filePath := filepath.Join(tempDir, "existing_file.txt") @@ -127,7 +127,7 @@ func TestWriteTool_Run(t *testing.T) { }) t.Run("handles invalid parameters", func(t *testing.T) { - tool := NewWriteTool(make(map[string]*lsp.Client), newMockPermissionService(true)) + tool := NewWriteTool(make(map[string]*lsp.Client), newMockPermissionService(true), newMockFileHistoryService()) call := ToolCall{ Name: WriteToolName, @@ -140,7 +140,7 @@ func TestWriteTool_Run(t *testing.T) { }) t.Run("handles missing file_path", func(t *testing.T) { - tool := NewWriteTool(make(map[string]*lsp.Client), newMockPermissionService(true)) + tool := NewWriteTool(make(map[string]*lsp.Client), newMockPermissionService(true), newMockFileHistoryService()) params := WriteParams{ FilePath: "", @@ -161,7 +161,7 @@ func TestWriteTool_Run(t *testing.T) { }) t.Run("handles missing content", func(t *testing.T) { - tool := NewWriteTool(make(map[string]*lsp.Client), newMockPermissionService(true)) + tool := NewWriteTool(make(map[string]*lsp.Client), newMockPermissionService(true), newMockFileHistoryService()) params := WriteParams{ FilePath: filepath.Join(tempDir, "file.txt"), @@ -182,7 +182,7 @@ func TestWriteTool_Run(t *testing.T) { }) t.Run("handles writing to a directory path", func(t *testing.T) { - tool := NewWriteTool(make(map[string]*lsp.Client), newMockPermissionService(true)) + tool := NewWriteTool(make(map[string]*lsp.Client), newMockPermissionService(true), newMockFileHistoryService()) // Create a directory dirPath := filepath.Join(tempDir, "test_dir") @@ -208,7 +208,7 @@ func TestWriteTool_Run(t *testing.T) { }) t.Run("handles permission denied", func(t *testing.T) { - tool := NewWriteTool(make(map[string]*lsp.Client), newMockPermissionService(false)) + tool := NewWriteTool(make(map[string]*lsp.Client), newMockPermissionService(false), newMockFileHistoryService()) filePath := filepath.Join(tempDir, "permission_denied.txt") params := WriteParams{ @@ -234,7 +234,7 @@ func TestWriteTool_Run(t *testing.T) { }) t.Run("detects file modified since last read", func(t *testing.T) { - tool := NewWriteTool(make(map[string]*lsp.Client), newMockPermissionService(true)) + tool := NewWriteTool(make(map[string]*lsp.Client), newMockPermissionService(true), newMockFileHistoryService()) // Create a file filePath := filepath.Join(tempDir, "modified_file.txt") @@ -275,7 +275,7 @@ func TestWriteTool_Run(t *testing.T) { }) t.Run("skips writing when content is identical", func(t *testing.T) { - tool := NewWriteTool(make(map[string]*lsp.Client), newMockPermissionService(true)) + tool := NewWriteTool(make(map[string]*lsp.Client), newMockPermissionService(true), newMockFileHistoryService()) // Create a file filePath := filepath.Join(tempDir, "identical_content.txt") diff --git a/internal/logging/logger.go b/internal/logging/logger.go index b0639147271b50b4029f1cdbe5baf22c9d60333f..7ae2e7b87ab7f3f71811c793118c79e2a72a3bbf 100644 --- a/internal/logging/logger.go +++ b/internal/logging/logger.go @@ -1,6 +1,12 @@ package logging -import "log/slog" +import ( + "fmt" + "log/slog" + "os" + "runtime/debug" + "time" +) func Info(msg string, args ...any) { slog.Info(msg, args...) @@ -37,3 +43,36 @@ func ErrorPersist(msg string, args ...any) { args = append(args, persistKeyArg, true) slog.Error(msg, args...) } + +// RecoverPanic is a common function to handle panics gracefully. +// It logs the error, creates a panic log file with stack trace, +// and executes an optional cleanup function before returning. +func RecoverPanic(name string, cleanup func()) { + if r := recover(); r != nil { + // Log the panic + ErrorPersist(fmt.Sprintf("Panic in %s: %v", name, r)) + + // Create a timestamped panic log file + timestamp := time.Now().Format("20060102-150405") + filename := fmt.Sprintf("opencode-panic-%s-%s.log", name, timestamp) + + file, err := os.Create(filename) + if err != nil { + ErrorPersist(fmt.Sprintf("Failed to create panic log: %v", err)) + } else { + defer file.Close() + + // Write panic information and stack trace + fmt.Fprintf(file, "Panic in %s: %v\n\n", name, r) + fmt.Fprintf(file, "Time: %s\n\n", time.Now().Format(time.RFC3339)) + fmt.Fprintf(file, "Stack Trace:\n%s\n", debug.Stack()) + + InfoPersist(fmt.Sprintf("Panic details written to %s", filename)) + } + + // Execute cleanup function if provided + if cleanup != nil { + cleanup() + } + } +} diff --git a/internal/lsp/client.go b/internal/lsp/client.go index e2eedc4fcfde8601b89562f7bdf6a20ed1ee666b..0f03e7fcb12f8745accd0e8b33fb9af98d806d08 100644 --- a/internal/lsp/client.go +++ b/internal/lsp/client.go @@ -97,7 +97,12 @@ func NewClient(ctx context.Context, command string, args ...string) (*Client, er }() // Start message handling loop - go client.handleMessages() + go func() { + defer logging.RecoverPanic("LSP-message-handler", func() { + logging.ErrorPersist("LSP message handler crashed, LSP functionality may be impaired") + }) + client.handleMessages() + }() return client, nil } @@ -374,7 +379,7 @@ func (c *Client) CloseFile(ctx context.Context, filepath string) error { }, } - if cnf.Debug { + if cnf.DebugLSP { logging.Debug("Closing file", "file", filepath) } if err := c.Notify(ctx, "textDocument/didClose", params); err != nil { @@ -413,12 +418,12 @@ func (c *Client) CloseAllFiles(ctx context.Context) { // Then close them all for _, filePath := range filesToClose { err := c.CloseFile(ctx, filePath) - if err != nil && cnf.Debug { + if err != nil && cnf.DebugLSP { logging.Warn("Error closing file", "file", filePath, "error", err) } } - if cnf.Debug { + if cnf.DebugLSP { logging.Debug("Closed all files", "files", filesToClose) } } diff --git a/internal/lsp/handlers.go b/internal/lsp/handlers.go index 4913c743d97c5e195dc77986ebacd622baf21958..c3088d6852061ee26ab94aa7e2b783cf3b52ca54 100644 --- a/internal/lsp/handlers.go +++ b/internal/lsp/handlers.go @@ -88,7 +88,7 @@ func HandleServerMessage(params json.RawMessage) { Message string `json:"message"` } if err := json.Unmarshal(params, &msg); err == nil { - if cnf.Debug { + if cnf.DebugLSP { logging.Debug("Server message", "type", msg.Type, "message", msg.Message) } } diff --git a/internal/lsp/transport.go b/internal/lsp/transport.go index 4185966f32d199961f66da25222e06d735ed41a5..89255fd78bfe356afc52cf4e82ac64b0a6553e10 100644 --- a/internal/lsp/transport.go +++ b/internal/lsp/transport.go @@ -20,7 +20,7 @@ func WriteMessage(w io.Writer, msg *Message) error { } cnf := config.Get() - if cnf.Debug { + if cnf.DebugLSP { logging.Debug("Sending message to server", "method", msg.Method, "id", msg.ID) } @@ -49,7 +49,7 @@ func ReadMessage(r *bufio.Reader) (*Message, error) { } line = strings.TrimSpace(line) - if cnf.Debug { + if cnf.DebugLSP { logging.Debug("Received header", "line", line) } @@ -65,7 +65,7 @@ func ReadMessage(r *bufio.Reader) (*Message, error) { } } - if cnf.Debug { + if cnf.DebugLSP { logging.Debug("Content-Length", "length", contentLength) } @@ -76,7 +76,7 @@ func ReadMessage(r *bufio.Reader) (*Message, error) { return nil, fmt.Errorf("failed to read content: %w", err) } - if cnf.Debug { + if cnf.DebugLSP { logging.Debug("Received content", "content", string(content)) } @@ -95,7 +95,7 @@ func (c *Client) handleMessages() { for { msg, err := ReadMessage(c.stdout) if err != nil { - if cnf.Debug { + if cnf.DebugLSP { logging.Error("Error reading message", "error", err) } return @@ -103,7 +103,7 @@ func (c *Client) handleMessages() { // Handle server->client request (has both Method and ID) if msg.Method != "" && msg.ID != 0 { - if cnf.Debug { + if cnf.DebugLSP { logging.Debug("Received request from server", "method", msg.Method, "id", msg.ID) } @@ -157,11 +157,11 @@ func (c *Client) handleMessages() { c.notificationMu.RUnlock() if ok { - if cnf.Debug { + if cnf.DebugLSP { logging.Debug("Handling notification", "method", msg.Method) } go handler(msg.Params) - } else if cnf.Debug { + } else if cnf.DebugLSP { logging.Debug("No handler for notification", "method", msg.Method) } continue @@ -174,12 +174,12 @@ func (c *Client) handleMessages() { c.handlersMu.RUnlock() if ok { - if cnf.Debug { + if cnf.DebugLSP { logging.Debug("Received response for request", "id", msg.ID) } ch <- msg close(ch) - } else if cnf.Debug { + } else if cnf.DebugLSP { logging.Debug("No handler for response", "id", msg.ID) } } @@ -191,7 +191,7 @@ func (c *Client) Call(ctx context.Context, method string, params any, result any cnf := config.Get() id := c.nextID.Add(1) - if cnf.Debug { + if cnf.DebugLSP { logging.Debug("Making call", "method", method, "id", id) } @@ -217,14 +217,14 @@ func (c *Client) Call(ctx context.Context, method string, params any, result any return fmt.Errorf("failed to send request: %w", err) } - if cnf.Debug { + if cnf.DebugLSP { logging.Debug("Request sent", "method", method, "id", id) } // Wait for response resp := <-ch - if cnf.Debug { + if cnf.DebugLSP { logging.Debug("Received response", "id", id) } @@ -250,7 +250,7 @@ func (c *Client) Call(ctx context.Context, method string, params any, result any // Notify sends a notification (a request without an ID that doesn't expect a response) func (c *Client) Notify(ctx context.Context, method string, params any) error { cnf := config.Get() - if cnf.Debug { + if cnf.DebugLSP { logging.Debug("Sending notification", "method", method) } diff --git a/internal/lsp/watcher/watcher.go b/internal/lsp/watcher/watcher.go index b5ef157109e517c9ba8afc5d26c29ac76a527868..156f38e1aa897b5195921607cd322a7196bebbc2 100644 --- a/internal/lsp/watcher/watcher.go +++ b/internal/lsp/watcher/watcher.go @@ -50,7 +50,7 @@ func (w *WorkspaceWatcher) AddRegistrations(ctx context.Context, id string, watc w.registrations = append(w.registrations, watchers...) // Print detailed registration information for debugging - if cnf.Debug { + if cnf.DebugLSP { logging.Debug("Adding file watcher registrations", "id", id, "watchers", len(watchers), @@ -116,7 +116,7 @@ func (w *WorkspaceWatcher) AddRegistrations(ctx context.Context, id string, watc // Skip directories that should be excluded if d.IsDir() { if path != w.workspacePath && shouldExcludeDir(path) { - if cnf.Debug { + if cnf.DebugLSP { logging.Debug("Skipping excluded directory", "path", path) } return filepath.SkipDir @@ -136,7 +136,7 @@ func (w *WorkspaceWatcher) AddRegistrations(ctx context.Context, id string, watc }) elapsedTime := time.Since(startTime) - if cnf.Debug { + if cnf.DebugLSP { logging.Debug("Workspace scan complete", "filesOpened", filesOpened, "elapsedTime", elapsedTime.Seconds(), @@ -144,7 +144,7 @@ func (w *WorkspaceWatcher) AddRegistrations(ctx context.Context, id string, watc ) } - if err != nil && cnf.Debug { + if err != nil && cnf.DebugLSP { logging.Debug("Error scanning workspace for files to open", "error", err) } }() @@ -175,7 +175,7 @@ func (w *WorkspaceWatcher) WatchWorkspace(ctx context.Context, workspacePath str // Skip excluded directories (except workspace root) if d.IsDir() && path != workspacePath { if shouldExcludeDir(path) { - if cnf.Debug { + if cnf.DebugLSP { logging.Debug("Skipping excluded directory", "path", path) } return filepath.SkipDir @@ -228,7 +228,7 @@ func (w *WorkspaceWatcher) WatchWorkspace(ctx context.Context, workspacePath str } // Debug logging - if cnf.Debug { + if cnf.DebugLSP { matched, kind := w.isPathWatched(event.Name) logging.Debug("File event", "path", event.Name, @@ -491,7 +491,7 @@ func (w *WorkspaceWatcher) handleFileEvent(ctx context.Context, uri string, chan // notifyFileEvent sends a didChangeWatchedFiles notification for a file event func (w *WorkspaceWatcher) notifyFileEvent(ctx context.Context, uri string, changeType protocol.FileChangeType) error { cnf := config.Get() - if cnf.Debug { + if cnf.DebugLSP { logging.Debug("Notifying file event", "uri", uri, "changeType", changeType, @@ -615,7 +615,7 @@ func shouldExcludeFile(filePath string) bool { // Skip large files if info.Size() > maxFileSize { - if cnf.Debug { + if cnf.DebugLSP { logging.Debug("Skipping large file", "path", filePath, "size", info.Size(), @@ -648,7 +648,7 @@ func (w *WorkspaceWatcher) openMatchingFile(ctx context.Context, path string) { // Check if this path should be watched according to server registrations if watched, _ := w.isPathWatched(path); watched { // Don't need to check if it's already open - the client.OpenFile handles that - if err := w.client.OpenFile(ctx, path); err != nil && cnf.Debug { + if err := w.client.OpenFile(ctx, path); err != nil && cnf.DebugLSP { logging.Error("Error opening file", "path", path, "error", err) } } diff --git a/internal/message/content.go b/internal/message/content.go index 422c04f52ca0e5546986ea62b610f467502b08f1..f9e76b11c1a44fde90bedb8721eac44603ab8af3 100644 --- a/internal/message/content.go +++ b/internal/message/content.go @@ -2,6 +2,7 @@ package message import ( "encoding/base64" + "slices" "time" "github.com/kujtimiihoxha/termai/internal/llm/models" @@ -16,6 +17,20 @@ const ( Tool MessageRole = "tool" ) +type FinishReason string + +const ( + FinishReasonEndTurn FinishReason = "end_turn" + FinishReasonMaxTokens FinishReason = "max_tokens" + FinishReasonToolUse FinishReason = "tool_use" + FinishReasonCanceled FinishReason = "canceled" + FinishReasonError FinishReason = "error" + FinishReasonPermissionDenied FinishReason = "permission_denied" + + // Should never happen + FinishReasonUnknown FinishReason = "unknown" +) + type ContentPart interface { isPart() } @@ -83,8 +98,8 @@ type ToolResult struct { func (ToolResult) isPart() {} type Finish struct { - Reason string `json:"reason"` - Time int64 `json:"time"` + Reason FinishReason `json:"reason"` + Time int64 `json:"time"` } func (Finish) isPart() {} @@ -176,7 +191,7 @@ func (m *Message) FinishPart() *Finish { return nil } -func (m *Message) FinishReason() string { +func (m *Message) FinishReason() FinishReason { for _, part := range m.Parts { if c, ok := part.(Finish); ok { return c.Reason @@ -246,7 +261,14 @@ func (m *Message) SetToolResults(tr []ToolResult) { } } -func (m *Message) AddFinish(reason string) { +func (m *Message) AddFinish(reason FinishReason) { + // remove any existing finish part + for i, part := range m.Parts { + if _, ok := part.(Finish); ok { + m.Parts = slices.Delete(m.Parts, i, i+1) + break + } + } m.Parts = append(m.Parts, Finish{Reason: reason, Time: time.Now().Unix()}) } diff --git a/internal/pubsub/broker.go b/internal/pubsub/broker.go index 633a6d57f8cf66c246d25b7941c1bb4ad71944ff..3e70ae09525a60e143d710614cf3d68d84f75d03 100644 --- a/internal/pubsub/broker.go +++ b/internal/pubsub/broker.go @@ -5,7 +5,7 @@ import ( "sync" ) -const bufferSize = 1024 * 1024 +const bufferSize = 1024 type Logger interface { Debug(msg string, args ...any) diff --git a/internal/session/session.go b/internal/session/session.go index 9a16224c3b3c9cc84c50d0d4c829a03c8f7de5d8..019019df47d5ab46e5a2ccb22dff8f3ceadae29e 100644 --- a/internal/session/session.go +++ b/internal/session/session.go @@ -24,6 +24,7 @@ type Session struct { type Service interface { pubsub.Suscriber[Session] Create(ctx context.Context, title string) (Session, error) + CreateTitleSession(ctx context.Context, parentSessionID string) (Session, error) CreateTaskSession(ctx context.Context, toolCallID, parentSessionID, title string) (Session, error) Get(ctx context.Context, id string) (Session, error) List(ctx context.Context) ([]Session, error) @@ -63,6 +64,20 @@ func (s *service) CreateTaskSession(ctx context.Context, toolCallID, parentSessi return session, nil } +func (s *service) CreateTitleSession(ctx context.Context, parentSessionID string) (Session, error) { + dbSession, err := s.q.CreateSession(ctx, db.CreateSessionParams{ + ID: "title-" + parentSessionID, + ParentSessionID: sql.NullString{String: parentSessionID, Valid: true}, + Title: "Generate a title", + }) + if err != nil { + return Session{}, err + } + session := s.fromDBItem(dbSession) + s.Publish(pubsub.CreatedEvent, session) + return session, nil +} + func (s *service) Delete(ctx context.Context, id string) error { session, err := s.Get(ctx, id) if err != nil { diff --git a/internal/tui/components/chat/chat.go b/internal/tui/components/chat/chat.go index e893ec2f5f962643024bd6003d3e97351858a8ae..e98001efa8345b310829cb8c09c57a6869c57b20 100644 --- a/internal/tui/components/chat/chat.go +++ b/internal/tui/components/chat/chat.go @@ -19,8 +19,6 @@ type SessionSelectedMsg = session.Session type SessionClearedMsg struct{} -type AgentWorkingMsg bool - type EditorFocusMsg bool func lspsConfigured(width int) string { diff --git a/internal/tui/components/chat/editor.go b/internal/tui/components/chat/editor.go index e87f1ffae79914fcacccac5dd5bc336c31a0b980..e2f4da9e240802b6872c08aa93d78a192fccfa99 100644 --- a/internal/tui/components/chat/editor.go +++ b/internal/tui/components/chat/editor.go @@ -5,14 +5,17 @@ import ( "github.com/charmbracelet/bubbles/textarea" tea "github.com/charmbracelet/bubbletea" "github.com/charmbracelet/lipgloss" + "github.com/kujtimiihoxha/termai/internal/app" + "github.com/kujtimiihoxha/termai/internal/session" "github.com/kujtimiihoxha/termai/internal/tui/layout" "github.com/kujtimiihoxha/termai/internal/tui/styles" "github.com/kujtimiihoxha/termai/internal/tui/util" ) type editorCmp struct { - textarea textarea.Model - agentWorking bool + app *app.App + session session.Session + textarea textarea.Model } type focusedEditorKeyMaps struct { @@ -32,7 +35,7 @@ var focusedKeyMaps = focusedEditorKeyMaps{ ), Blur: key.NewBinding( key.WithKeys("esc"), - key.WithHelp("esc", "blur editor"), + key.WithHelp("esc", "focus messages"), ), } @@ -52,7 +55,7 @@ func (m *editorCmp) Init() tea.Cmd { } func (m *editorCmp) send() tea.Cmd { - if m.agentWorking { + if m.app.CoderAgent.IsSessionBusy(m.session.ID) { return util.ReportWarn("Agent is working, please wait...") } @@ -66,7 +69,6 @@ func (m *editorCmp) send() tea.Cmd { util.CmdHandler(SendMsg{ Text: value, }), - util.CmdHandler(AgentWorkingMsg(true)), util.CmdHandler(EditorFocusMsg(false)), ) } @@ -74,8 +76,11 @@ func (m *editorCmp) send() tea.Cmd { func (m *editorCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) { var cmd tea.Cmd switch msg := msg.(type) { - case AgentWorkingMsg: - m.agentWorking = bool(msg) + case SessionSelectedMsg: + if msg.ID != m.session.ID { + m.session = msg + } + return m, nil case tea.KeyMsg: // if the key does not match any binding, return if m.textarea.Focused() && key.Matches(msg, focusedKeyMaps.Send) { @@ -122,7 +127,7 @@ func (m *editorCmp) BindingKeys() []key.Binding { return bindings } -func NewEditorCmp() tea.Model { +func NewEditorCmp(app *app.App) tea.Model { ti := textarea.New() ti.Prompt = " " ti.ShowLineNumbers = false @@ -138,6 +143,7 @@ func NewEditorCmp() tea.Model { ti.CharLimit = -1 ti.Focus() return &editorCmp{ + app: app, textarea: ti, } } diff --git a/internal/tui/components/chat/messages.go b/internal/tui/components/chat/messages.go index dc21fca2916f1019dbafa313c6d010d096ca9b34..26a98970ee61b60eff0a7396d250dbae09f5a6e5 100644 --- a/internal/tui/components/chat/messages.go +++ b/internal/tui/components/chat/messages.go @@ -6,7 +6,9 @@ import ( "fmt" "math" "strings" + "time" + "github.com/charmbracelet/bubbles/key" "github.com/charmbracelet/bubbles/spinner" "github.com/charmbracelet/bubbles/viewport" tea "github.com/charmbracelet/bubbletea" @@ -17,9 +19,11 @@ import ( "github.com/kujtimiihoxha/termai/internal/llm/agent" "github.com/kujtimiihoxha/termai/internal/llm/models" "github.com/kujtimiihoxha/termai/internal/llm/tools" + "github.com/kujtimiihoxha/termai/internal/logging" "github.com/kujtimiihoxha/termai/internal/message" "github.com/kujtimiihoxha/termai/internal/pubsub" "github.com/kujtimiihoxha/termai/internal/session" + "github.com/kujtimiihoxha/termai/internal/tui/layout" "github.com/kujtimiihoxha/termai/internal/tui/styles" "github.com/kujtimiihoxha/termai/internal/tui/util" ) @@ -32,6 +36,9 @@ const ( toolMessageType ) +// messagesTickMsg is a message sent by the timer to refresh messages +type messagesTickMsg time.Time + type uiMessage struct { ID string messageType uiMessageType @@ -52,24 +59,34 @@ type messagesCmp struct { renderer *glamour.TermRenderer focusRenderer *glamour.TermRenderer cachedContent map[string]string - agentWorking bool spinner spinner.Model needsRerender bool - lastViewport string } func (m *messagesCmp) Init() tea.Cmd { - return tea.Batch(m.viewport.Init()) + return tea.Batch(m.viewport.Init(), m.spinner.Tick, m.tickMessages()) +} + +func (m *messagesCmp) tickMessages() tea.Cmd { + return tea.Tick(time.Second, func(t time.Time) tea.Msg { + return messagesTickMsg(t) + }) } func (m *messagesCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) { var cmds []tea.Cmd switch msg := msg.(type) { - case AgentWorkingMsg: - m.agentWorking = bool(msg) - if m.agentWorking { - cmds = append(cmds, m.spinner.Tick) + case messagesTickMsg: + // Refresh messages if we have an active session + if m.session.ID != "" { + messages, err := m.app.Messages.List(context.Background(), m.session.ID) + if err == nil { + m.messages = messages + m.needsRerender = true + } } + // Continue ticking + cmds = append(cmds, m.tickMessages()) case EditorFocusMsg: m.writingMode = bool(msg) case SessionSelectedMsg: @@ -84,6 +101,7 @@ func (m *messagesCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) { m.messages = make([]message.Message, 0) m.currentMsgID = "" m.needsRerender = true + m.cachedContent = make(map[string]string) return m, nil case tea.KeyMsg: @@ -104,6 +122,12 @@ func (m *messagesCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) { } if !messageExists { + // If we have messages, ensure the previous last message is not cached + if len(m.messages) > 0 { + lastMsgID := m.messages[len(m.messages)-1].ID + delete(m.cachedContent, lastMsgID) + } + m.messages = append(m.messages, msg.Payload) delete(m.cachedContent, m.currentMsgID) m.currentMsgID = msg.Payload.ID @@ -112,36 +136,40 @@ func (m *messagesCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) { } for _, v := range m.messages { for _, c := range v.ToolCalls() { - // the message is being added to the session of a tool called if c.ID == msg.Payload.SessionID { m.needsRerender = true } } } } else if msg.Type == pubsub.UpdatedEvent && msg.Payload.SessionID == m.session.ID { + logging.Debug("Message", "finish", msg.Payload.FinishReason()) for i, v := range m.messages { if v.ID == msg.Payload.ID { - if !m.messages[i].IsFinished() && msg.Payload.IsFinished() && msg.Payload.FinishReason() == "end_turn" || msg.Payload.FinishReason() == "canceled" { - cmds = append(cmds, util.CmdHandler(AgentWorkingMsg(false))) - } m.messages[i] = msg.Payload delete(m.cachedContent, msg.Payload.ID) + + // If this is the last message, ensure it's not cached + if i == len(m.messages)-1 { + delete(m.cachedContent, msg.Payload.ID) + } + m.needsRerender = true break } } } } - if m.agentWorking { - u, cmd := m.spinner.Update(msg) - m.spinner = u - cmds = append(cmds, cmd) - } + oldPos := m.viewport.YPosition u, cmd := m.viewport.Update(msg) m.viewport = u m.needsRerender = m.needsRerender || m.viewport.YPosition != oldPos cmds = append(cmds, cmd) + + spinner, cmd := m.spinner.Update(msg) + m.spinner = spinner + cmds = append(cmds, cmd) + if m.needsRerender { m.renderView() if len(m.messages) > 0 { @@ -157,10 +185,21 @@ func (m *messagesCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) { return m, tea.Batch(cmds...) } +func (m *messagesCmp) IsAgentWorking() bool { + return m.app.CoderAgent.IsSessionBusy(m.session.ID) +} + func (m *messagesCmp) renderSimpleMessage(msg message.Message, info ...string) string { - if v, ok := m.cachedContent[msg.ID]; ok { - return v + // Check if this is the last message in the list + isLastMessage := len(m.messages) > 0 && m.messages[len(m.messages)-1].ID == msg.ID + + // Only use cache for non-last messages + if !isLastMessage { + if v, ok := m.cachedContent[msg.ID]; ok { + return v + } } + style := styles.BaseStyle. Width(m.width). BorderLeft(true). @@ -191,7 +230,12 @@ func (m *messagesCmp) renderSimpleMessage(msg message.Message, info ...string) s parts..., ), ) - m.cachedContent[msg.ID] = rendered + + // Only cache if it's not the last message + if !isLastMessage { + m.cachedContent[msg.ID] = rendered + } + return rendered } @@ -207,32 +251,71 @@ func formatTimeDifference(unixTime1, unixTime2 int64) string { return fmt.Sprintf("%dm%ds", minutes, seconds) } +func (m *messagesCmp) findToolResponse(callID string) *message.ToolResult { + for _, v := range m.messages { + for _, c := range v.ToolResults() { + if c.ToolCallID == callID { + return &c + } + } + } + return nil +} + func (m *messagesCmp) renderToolCall(toolCall message.ToolCall, isNested bool) string { key := "" value := "" + result := styles.BaseStyle.Foreground(styles.PrimaryColor).Render(m.spinner.View() + " waiting for response...") + + response := m.findToolResponse(toolCall.ID) + if response != nil && response.IsError { + // Clean up error message for display by removing newlines + // This ensures error messages display properly in the UI + errMsg := strings.ReplaceAll(response.Content, "\n", " ") + result = styles.BaseStyle.Foreground(styles.Error).Render(ansi.Truncate(errMsg, 40, "...")) + } else if response != nil { + result = styles.BaseStyle.Foreground(styles.ForgroundMid).Render("Done") + } switch toolCall.Name { // TODO: add result data to the tools case agent.AgentToolName: key = "Task" var params agent.AgentParams json.Unmarshal([]byte(toolCall.Input), ¶ms) - value = params.Prompt - // TODO: handle nested calls + value = strings.ReplaceAll(params.Prompt, "\n", " ") + if response != nil && !response.IsError { + firstRow := strings.ReplaceAll(response.Content, "\n", " ") + result = styles.BaseStyle.Foreground(styles.ForgroundMid).Render(ansi.Truncate(firstRow, 40, "...")) + } case tools.BashToolName: key = "Bash" var params tools.BashParams json.Unmarshal([]byte(toolCall.Input), ¶ms) value = params.Command + if response != nil && !response.IsError { + metadata := tools.BashResponseMetadata{} + json.Unmarshal([]byte(response.Metadata), &metadata) + result = styles.BaseStyle.Foreground(styles.ForgroundMid).Render(fmt.Sprintf("Took %s", formatTimeDifference(metadata.StartTime, metadata.EndTime))) + } + case tools.EditToolName: key = "Edit" var params tools.EditParams json.Unmarshal([]byte(toolCall.Input), ¶ms) value = params.FilePath + if response != nil && !response.IsError { + metadata := tools.EditResponseMetadata{} + json.Unmarshal([]byte(response.Metadata), &metadata) + result = styles.BaseStyle.Foreground(styles.ForgroundMid).Render(fmt.Sprintf("%d Additions %d Removals", metadata.Additions, metadata.Removals)) + } case tools.FetchToolName: key = "Fetch" var params tools.FetchParams json.Unmarshal([]byte(toolCall.Input), ¶ms) value = params.URL + if response != nil && !response.IsError { + result = styles.BaseStyle.Foreground(styles.Error).Render(response.Content) + } case tools.GlobToolName: key = "Glob" var params tools.GlobParams @@ -241,6 +324,15 @@ func (m *messagesCmp) renderToolCall(toolCall message.ToolCall, isNested bool) s params.Path = "." } value = fmt.Sprintf("%s (%s)", params.Pattern, params.Path) + if response != nil && !response.IsError { + metadata := tools.GlobResponseMetadata{} + json.Unmarshal([]byte(response.Metadata), &metadata) + if metadata.Truncated { + result = styles.BaseStyle.Foreground(styles.ForgroundMid).Render(fmt.Sprintf("%d files found (truncated)", metadata.NumberOfFiles)) + } else { + result = styles.BaseStyle.Foreground(styles.ForgroundMid).Render(fmt.Sprintf("%d files found", metadata.NumberOfFiles)) + } + } case tools.GrepToolName: key = "Grep" var params tools.GrepParams @@ -249,19 +341,46 @@ func (m *messagesCmp) renderToolCall(toolCall message.ToolCall, isNested bool) s params.Path = "." } value = fmt.Sprintf("%s (%s)", params.Pattern, params.Path) + if response != nil && !response.IsError { + metadata := tools.GrepResponseMetadata{} + json.Unmarshal([]byte(response.Metadata), &metadata) + if metadata.Truncated { + result = styles.BaseStyle.Foreground(styles.ForgroundMid).Render(fmt.Sprintf("%d files found (truncated)", metadata.NumberOfMatches)) + } else { + result = styles.BaseStyle.Foreground(styles.ForgroundMid).Render(fmt.Sprintf("%d files found", metadata.NumberOfMatches)) + } + } case tools.LSToolName: - key = "Ls" + key = "ls" var params tools.LSParams json.Unmarshal([]byte(toolCall.Input), ¶ms) if params.Path == "" { params.Path = "." } value = params.Path + if response != nil && !response.IsError { + metadata := tools.LSResponseMetadata{} + json.Unmarshal([]byte(response.Metadata), &metadata) + if metadata.Truncated { + result = styles.BaseStyle.Foreground(styles.ForgroundMid).Render(fmt.Sprintf("%d files found (truncated)", metadata.NumberOfFiles)) + } else { + result = styles.BaseStyle.Foreground(styles.ForgroundMid).Render(fmt.Sprintf("%d files found", metadata.NumberOfFiles)) + } + } case tools.SourcegraphToolName: key = "Sourcegraph" var params tools.SourcegraphParams json.Unmarshal([]byte(toolCall.Input), ¶ms) value = params.Query + if response != nil && !response.IsError { + metadata := tools.SourcegraphResponseMetadata{} + json.Unmarshal([]byte(response.Metadata), &metadata) + if metadata.Truncated { + result = styles.BaseStyle.Foreground(styles.ForgroundMid).Render(fmt.Sprintf("%d matches found (truncated)", metadata.NumberOfMatches)) + } else { + result = styles.BaseStyle.Foreground(styles.ForgroundMid).Render(fmt.Sprintf("%d matches found", metadata.NumberOfMatches)) + } + } case tools.ViewToolName: key = "View" var params tools.ViewParams @@ -272,6 +391,12 @@ func (m *messagesCmp) renderToolCall(toolCall message.ToolCall, isNested bool) s var params tools.WriteParams json.Unmarshal([]byte(toolCall.Input), ¶ms) value = params.FilePath + if response != nil && !response.IsError { + metadata := tools.WriteResponseMetadata{} + json.Unmarshal([]byte(response.Metadata), &metadata) + + result = styles.BaseStyle.Foreground(styles.ForgroundMid).Render(fmt.Sprintf("%d Additions %d Removals", metadata.Additions, metadata.Removals)) + } default: key = toolCall.Name var params map[string]any @@ -300,14 +425,15 @@ func (m *messagesCmp) renderToolCall(toolCall message.ToolCall, isNested bool) s ) if !isNested { value = valyeStyle. - Width(m.width - lipgloss.Width(keyValye) - 2). Render( ansi.Truncate( - value, - m.width-lipgloss.Width(keyValye)-2, + value+" ", + m.width-lipgloss.Width(keyValye)-2-lipgloss.Width(result), "...", ), ) + value += result + } else { keyValye = keyStyle.Render( fmt.Sprintf(" └ %s: ", key), @@ -409,6 +535,27 @@ func (m *messagesCmp) renderView() { m.uiMessages = make([]uiMessage, 0) pos := 0 + // If we have messages, ensure the last message is not cached + // This ensures we always render the latest content for the most recent message + // which may be actively updating (e.g., during generation) + if len(m.messages) > 0 { + lastMsgID := m.messages[len(m.messages)-1].ID + delete(m.cachedContent, lastMsgID) + } + + // Limit cache to 10 messages + if len(m.cachedContent) > 15 { + // Create a list of keys to delete (oldest messages first) + keys := make([]string, 0, len(m.cachedContent)) + for k := range m.cachedContent { + keys = append(keys, k) + } + // Delete oldest messages until we have 10 or fewer + for i := 0; i < len(keys)-15; i++ { + delete(m.cachedContent, keys[i]) + } + } + for _, v := range m.messages { switch v.Role { case message.User: @@ -487,7 +634,7 @@ func (m *messagesCmp) View() string { func (m *messagesCmp) help() string { text := "" - if m.agentWorking { + if m.IsAgentWorking() { text += styles.BaseStyle.Foreground(styles.PrimaryColor).Bold(true).Render( fmt.Sprintf("%s %s ", m.spinner.View(), "Generating..."), ) @@ -562,9 +709,15 @@ func (m *messagesCmp) SetSession(session session.Session) tea.Cmd { m.messages = messages m.currentMsgID = m.messages[len(m.messages)-1].ID m.needsRerender = true + m.cachedContent = make(map[string]string) return nil } +func (m *messagesCmp) BindingKeys() []key.Binding { + bindings := layout.KeyMapToSlice(m.viewport.KeyMap) + return bindings +} + func NewMessagesCmp(app *app.App) tea.Model { focusRenderer, _ := glamour.NewTermRenderer( glamour.WithStyles(styles.MarkdownTheme(true)), diff --git a/internal/tui/components/chat/sidebar.go b/internal/tui/components/chat/sidebar.go index 51192cf9a6e7a58c94251ae1f608d834a79a49f2..b90269d1a643847a0c039f77055890aa802db3da 100644 --- a/internal/tui/components/chat/sidebar.go +++ b/internal/tui/components/chat/sidebar.go @@ -1,10 +1,15 @@ package chat import ( + "context" "fmt" + "strings" tea "github.com/charmbracelet/bubbletea" "github.com/charmbracelet/lipgloss" + "github.com/kujtimiihoxha/termai/internal/config" + "github.com/kujtimiihoxha/termai/internal/diff" + "github.com/kujtimiihoxha/termai/internal/history" "github.com/kujtimiihoxha/termai/internal/pubsub" "github.com/kujtimiihoxha/termai/internal/session" "github.com/kujtimiihoxha/termai/internal/tui/styles" @@ -13,9 +18,33 @@ import ( type sidebarCmp struct { width, height int session session.Session + history history.Service + modFiles map[string]struct { + additions int + removals int + } } func (m *sidebarCmp) Init() tea.Cmd { + if m.history != nil { + ctx := context.Background() + // Subscribe to file events + filesCh := m.history.Subscribe(ctx) + + // Initialize the modified files map + m.modFiles = make(map[string]struct { + additions int + removals int + }) + + // Load initial files and calculate diffs + m.loadModifiedFiles(ctx) + + // Return a command that will send file events to the Update method + return func() tea.Msg { + return <-filesCh + } + } return nil } @@ -27,6 +56,13 @@ func (m *sidebarCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) { m.session = msg.Payload } } + case pubsub.Event[history.File]: + if msg.Payload.SessionID == m.session.ID { + // When a file changes, reload all modified files + // This ensures we have the complete and accurate list + ctx := context.Background() + m.loadModifiedFiles(ctx) + } } return m, nil } @@ -86,18 +122,28 @@ func (m *sidebarCmp) modifiedFile(filePath string, additions, removals int) stri func (m *sidebarCmp) modifiedFiles() string { modifiedFiles := styles.BaseStyle.Width(m.width).Foreground(styles.PrimaryColor).Bold(true).Render("Modified Files:") - files := []struct { - path string - additions int - removals int - }{ - {"file1.txt", 10, 5}, - {"file2.txt", 20, 0}, - {"file3.txt", 0, 15}, + + // If no modified files, show a placeholder message + if m.modFiles == nil || len(m.modFiles) == 0 { + message := "No modified files" + remainingWidth := m.width - lipgloss.Width(modifiedFiles) + if remainingWidth > 0 { + message += strings.Repeat(" ", remainingWidth) + } + return styles.BaseStyle. + Width(m.width). + Render( + lipgloss.JoinVertical( + lipgloss.Top, + modifiedFiles, + styles.BaseStyle.Foreground(styles.ForgroundDim).Render(message), + ), + ) } + var fileViews []string - for _, file := range files { - fileViews = append(fileViews, m.modifiedFile(file.path, file.additions, file.removals)) + for path, stats := range m.modFiles { + fileViews = append(fileViews, m.modifiedFile(path, stats.additions, stats.removals)) } return styles.BaseStyle. @@ -123,8 +169,116 @@ func (m *sidebarCmp) GetSize() (int, int) { return m.width, m.height } -func NewSidebarCmp(session session.Session) tea.Model { +func NewSidebarCmp(session session.Session, history history.Service) tea.Model { return &sidebarCmp{ session: session, + history: history, + } +} + +func (m *sidebarCmp) loadModifiedFiles(ctx context.Context) { + if m.history == nil || m.session.ID == "" { + return + } + + // Get all latest files for this session + latestFiles, err := m.history.ListLatestSessionFiles(ctx, m.session.ID) + if err != nil { + return + } + + // Get all files for this session (to find initial versions) + allFiles, err := m.history.ListBySession(ctx, m.session.ID) + if err != nil { + return + } + + // Process each latest file + for _, file := range latestFiles { + // Skip if this is the initial version (no changes to show) + if file.Version == history.InitialVersion { + continue + } + + // Find the initial version for this specific file + var initialVersion history.File + for _, v := range allFiles { + if v.Path == file.Path && v.Version == history.InitialVersion { + initialVersion = v + break + } + } + + // Skip if we can't find the initial version + if initialVersion.ID == "" { + continue + } + + // Calculate diff between initial and latest version + _, additions, removals := diff.GenerateDiff(initialVersion.Content, file.Content, file.Path) + + // Only add to modified files if there are changes + if additions > 0 || removals > 0 { + // Remove working directory prefix from file path + displayPath := file.Path + workingDir := config.WorkingDirectory() + displayPath = strings.TrimPrefix(displayPath, workingDir) + displayPath = strings.TrimPrefix(displayPath, "/") + + m.modFiles[displayPath] = struct { + additions int + removals int + }{ + additions: additions, + removals: removals, + } + } + } +} + +func (m *sidebarCmp) processFileChanges(ctx context.Context, file history.File) { + // Skip if not the latest version + if file.Version == history.InitialVersion { + return + } + + // Get all versions of this file + fileVersions, err := m.history.ListBySession(ctx, m.session.ID) + if err != nil { + return + } + + // Find the initial version + var initialVersion history.File + for _, v := range fileVersions { + if v.Path == file.Path && v.Version == history.InitialVersion { + initialVersion = v + break + } + } + + // Skip if we can't find the initial version + if initialVersion.ID == "" { + return + } + + // Calculate diff between initial and latest version + _, additions, removals := diff.GenerateDiff(initialVersion.Content, file.Content, file.Path) + + // Only add to modified files if there are changes + if additions > 0 || removals > 0 { + // Remove working directory prefix from file path + displayPath := file.Path + workingDir := config.WorkingDirectory() + displayPath = strings.TrimPrefix(displayPath, workingDir) + displayPath = strings.TrimPrefix(displayPath, "/") + + m.modFiles[displayPath] = struct { + additions int + removals int + }{ + additions: additions, + removals: removals, + } } } diff --git a/internal/tui/components/core/dialog.go b/internal/tui/components/core/dialog.go deleted file mode 100644 index a8fef2e86e5359d88f40d08f3b8c761f1378a578..0000000000000000000000000000000000000000 --- a/internal/tui/components/core/dialog.go +++ /dev/null @@ -1,117 +0,0 @@ -package core - -import ( - "github.com/charmbracelet/bubbles/key" - tea "github.com/charmbracelet/bubbletea" - "github.com/charmbracelet/lipgloss" - "github.com/kujtimiihoxha/termai/internal/tui/layout" - "github.com/kujtimiihoxha/termai/internal/tui/util" -) - -type SizeableModel interface { - tea.Model - layout.Sizeable -} - -type DialogMsg struct { - Content SizeableModel - WidthRatio float64 - HeightRatio float64 - - MinWidth int - MinHeight int -} - -type DialogCloseMsg struct{} - -type KeyBindings struct { - Return key.Binding -} - -var keys = KeyBindings{ - Return: key.NewBinding( - key.WithKeys("esc"), - key.WithHelp("esc", "close"), - ), -} - -type DialogCmp interface { - tea.Model - layout.Bindings -} - -type dialogCmp struct { - content SizeableModel - screenWidth int - screenHeight int - - widthRatio float64 - heightRatio float64 - - minWidth int - minHeight int - - width int - height int -} - -func (d *dialogCmp) Init() tea.Cmd { - return nil -} - -func (d *dialogCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) { - switch msg := msg.(type) { - case tea.WindowSizeMsg: - d.screenWidth = msg.Width - d.screenHeight = msg.Height - d.width = max(int(float64(d.screenWidth)*d.widthRatio), d.minWidth) - d.height = max(int(float64(d.screenHeight)*d.heightRatio), d.minHeight) - if d.content != nil { - d.content.SetSize(d.width, d.height) - } - return d, nil - case DialogMsg: - d.content = msg.Content - d.widthRatio = msg.WidthRatio - d.heightRatio = msg.HeightRatio - d.minWidth = msg.MinWidth - d.minHeight = msg.MinHeight - d.width = max(int(float64(d.screenWidth)*d.widthRatio), d.minWidth) - d.height = max(int(float64(d.screenHeight)*d.heightRatio), d.minHeight) - if d.content != nil { - d.content.SetSize(d.width, d.height) - } - case DialogCloseMsg: - d.content = nil - return d, nil - case tea.KeyMsg: - if key.Matches(msg, keys.Return) { - return d, util.CmdHandler(DialogCloseMsg{}) - } - } - if d.content != nil { - u, cmd := d.content.Update(msg) - d.content = u.(SizeableModel) - return d, cmd - } - return d, nil -} - -func (d *dialogCmp) BindingKeys() []key.Binding { - bindings := []key.Binding{keys.Return} - if d.content == nil { - return bindings - } - if c, ok := d.content.(layout.Bindings); ok { - return append(bindings, c.BindingKeys()...) - } - return bindings -} - -func (d *dialogCmp) View() string { - return lipgloss.NewStyle().Width(d.width).Height(d.height).Render(d.content.View()) -} - -func NewDialogCmp() DialogCmp { - return &dialogCmp{} -} diff --git a/internal/tui/components/core/help.go b/internal/tui/components/core/help.go deleted file mode 100644 index 4ef857c78a0ae097b386647fa48826f45037dc14..0000000000000000000000000000000000000000 --- a/internal/tui/components/core/help.go +++ /dev/null @@ -1,119 +0,0 @@ -package core - -import ( - "strings" - - "github.com/charmbracelet/bubbles/key" - tea "github.com/charmbracelet/bubbletea" - "github.com/charmbracelet/lipgloss" - "github.com/kujtimiihoxha/termai/internal/tui/styles" -) - -type HelpCmp interface { - tea.Model - SetBindings(bindings []key.Binding) - Height() int -} - -const ( - helpWidgetHeight = 12 -) - -type helpCmp struct { - width int - bindings []key.Binding -} - -func (h *helpCmp) Init() tea.Cmd { - return nil -} - -func (h *helpCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) { - switch msg := msg.(type) { - case tea.WindowSizeMsg: - h.width = msg.Width - } - return h, nil -} - -func (h *helpCmp) View() string { - helpKeyStyle := styles.Bold.Foreground(styles.Rosewater).Margin(0, 1, 0, 0) - helpDescStyle := styles.Regular.Foreground(styles.Flamingo) - // Compile list of bindings to render - bindings := removeDuplicateBindings(h.bindings) - // Enumerate through each group of bindings, populating a series of - // pairs of columns, one for keys, one for descriptions - var ( - pairs []string - width int - rows = helpWidgetHeight - 2 - ) - for i := 0; i < len(bindings); i += rows { - var ( - keys []string - descs []string - ) - for j := i; j < min(i+rows, len(bindings)); j++ { - keys = append(keys, helpKeyStyle.Render(bindings[j].Help().Key)) - descs = append(descs, helpDescStyle.Render(bindings[j].Help().Desc)) - } - // Render pair of columns; beyond the first pair, render a three space - // left margin, in order to visually separate the pairs. - var cols []string - if len(pairs) > 0 { - cols = []string{" "} - } - cols = append(cols, - strings.Join(keys, "\n"), - strings.Join(descs, "\n"), - ) - - pair := lipgloss.JoinHorizontal(lipgloss.Top, cols...) - // check whether it exceeds the maximum width avail (the width of the - // terminal, subtracting 2 for the borders). - width += lipgloss.Width(pair) - if width > h.width-2 { - break - } - pairs = append(pairs, pair) - } - - // Join pairs of columns and enclose in a border - content := lipgloss.JoinHorizontal(lipgloss.Top, pairs...) - return styles.DoubleBorder.Height(rows).PaddingLeft(1).Width(h.width - 2).Render(content) -} - -func removeDuplicateBindings(bindings []key.Binding) []key.Binding { - seen := make(map[string]struct{}) - result := make([]key.Binding, 0, len(bindings)) - - // Process bindings in reverse order - for i := len(bindings) - 1; i >= 0; i-- { - b := bindings[i] - k := strings.Join(b.Keys(), " ") - if _, ok := seen[k]; ok { - // duplicate, skip - continue - } - seen[k] = struct{}{} - // Add to the beginning of result to maintain original order - result = append([]key.Binding{b}, result...) - } - - return result -} - -func (h *helpCmp) SetBindings(bindings []key.Binding) { - h.bindings = bindings -} - -func (h helpCmp) Height() int { - return helpWidgetHeight -} - -func NewHelpCmp() HelpCmp { - return &helpCmp{ - width: 0, - bindings: make([]key.Binding, 0), - } -} diff --git a/internal/tui/components/core/status.go b/internal/tui/components/core/status.go index 93ba345075e513e691cb0b27a6f5de870c5c2354..089dffa2c33fa3b9b27348a65a77bdac989b049d 100644 --- a/internal/tui/components/core/status.go +++ b/internal/tui/components/core/status.go @@ -1,21 +1,25 @@ package core import ( + "fmt" + "strings" "time" tea "github.com/charmbracelet/bubbletea" "github.com/charmbracelet/lipgloss" "github.com/kujtimiihoxha/termai/internal/config" "github.com/kujtimiihoxha/termai/internal/llm/models" + "github.com/kujtimiihoxha/termai/internal/lsp" + "github.com/kujtimiihoxha/termai/internal/lsp/protocol" "github.com/kujtimiihoxha/termai/internal/tui/styles" "github.com/kujtimiihoxha/termai/internal/tui/util" - "github.com/kujtimiihoxha/termai/internal/version" ) type statusCmp struct { info util.InfoMsg width int messageTTL time.Duration + lspClients map[string]*lsp.Client } // clearMessageCmd is a command that clears status messages after a timeout @@ -47,20 +51,18 @@ func (m statusCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) { return m, nil } -var ( - versionWidget = styles.Padded.Background(styles.DarkGrey).Foreground(styles.Text).Render(version.Version) - helpWidget = styles.Padded.Background(styles.Grey).Foreground(styles.Text).Render("? help") -) +var helpWidget = styles.Padded.Background(styles.ForgroundMid).Foreground(styles.BackgroundDarker).Bold(true).Render("ctrl+? help") func (m statusCmp) View() string { - status := styles.Padded.Background(styles.Grey).Foreground(styles.Text).Render("? help") + status := helpWidget + diagnostics := styles.Padded.Background(styles.BackgroundDarker).Render(m.projectDiagnostics()) if m.info.Msg != "" { infoStyle := styles.Padded. Foreground(styles.Base). - Width(m.availableFooterMsgWidth()) + Width(m.availableFooterMsgWidth(diagnostics)) switch m.info.Type { case util.InfoTypeInfo: - infoStyle = infoStyle.Background(styles.Blue) + infoStyle = infoStyle.Background(styles.BorderColor) case util.InfoTypeWarn: infoStyle = infoStyle.Background(styles.Peach) case util.InfoTypeError: @@ -68,7 +70,7 @@ func (m statusCmp) View() string { } // Truncate message if it's longer than available width msg := m.info.Msg - availWidth := m.availableFooterMsgWidth() - 10 + availWidth := m.availableFooterMsgWidth(diagnostics) - 10 if len(msg) > availWidth && availWidth > 0 { msg = msg[:availWidth] + "..." } @@ -76,27 +78,81 @@ func (m statusCmp) View() string { } else { status += styles.Padded. Foreground(styles.Base). - Background(styles.LightGrey). - Width(m.availableFooterMsgWidth()). + Background(styles.BackgroundDim). + Width(m.availableFooterMsgWidth(diagnostics)). Render("") } + status += diagnostics status += m.model() - status += versionWidget return status } -func (m statusCmp) availableFooterMsgWidth() int { - // -2 to accommodate padding - return max(0, m.width-lipgloss.Width(helpWidget)-lipgloss.Width(versionWidget)-lipgloss.Width(m.model())) +func (m *statusCmp) projectDiagnostics() string { + errorDiagnostics := []protocol.Diagnostic{} + warnDiagnostics := []protocol.Diagnostic{} + hintDiagnostics := []protocol.Diagnostic{} + infoDiagnostics := []protocol.Diagnostic{} + for _, client := range m.lspClients { + for _, d := range client.GetDiagnostics() { + for _, diag := range d { + switch diag.Severity { + case protocol.SeverityError: + errorDiagnostics = append(errorDiagnostics, diag) + case protocol.SeverityWarning: + warnDiagnostics = append(warnDiagnostics, diag) + case protocol.SeverityHint: + hintDiagnostics = append(hintDiagnostics, diag) + case protocol.SeverityInformation: + infoDiagnostics = append(infoDiagnostics, diag) + } + } + } + } + + if len(errorDiagnostics) == 0 && len(warnDiagnostics) == 0 && len(hintDiagnostics) == 0 && len(infoDiagnostics) == 0 { + return "No diagnostics" + } + + diagnostics := []string{} + + if len(errorDiagnostics) > 0 { + errStr := lipgloss.NewStyle().Foreground(styles.Error).Render(fmt.Sprintf("%s %d", styles.ErrorIcon, len(errorDiagnostics))) + diagnostics = append(diagnostics, errStr) + } + if len(warnDiagnostics) > 0 { + warnStr := lipgloss.NewStyle().Foreground(styles.Warning).Render(fmt.Sprintf("%s %d", styles.WarningIcon, len(warnDiagnostics))) + diagnostics = append(diagnostics, warnStr) + } + if len(hintDiagnostics) > 0 { + hintStr := lipgloss.NewStyle().Foreground(styles.Text).Render(fmt.Sprintf("%s %d", styles.HintIcon, len(hintDiagnostics))) + diagnostics = append(diagnostics, hintStr) + } + if len(infoDiagnostics) > 0 { + infoStr := lipgloss.NewStyle().Foreground(styles.Peach).Render(fmt.Sprintf("%s %d", styles.InfoIcon, len(infoDiagnostics))) + diagnostics = append(diagnostics, infoStr) + } + + return strings.Join(diagnostics, " ") +} + +func (m statusCmp) availableFooterMsgWidth(diagnostics string) int { + return max(0, m.width-lipgloss.Width(helpWidget)-lipgloss.Width(m.model())-lipgloss.Width(diagnostics)) } func (m statusCmp) model() string { - model := models.SupportedModels[config.Get().Model.Coder] + cfg := config.Get() + + coder, ok := cfg.Agents[config.AgentCoder] + if !ok { + return "Unknown" + } + model := models.SupportedModels[coder.Model] return styles.Padded.Background(styles.Grey).Foreground(styles.Text).Render(model.Name) } -func NewStatusCmp() tea.Model { +func NewStatusCmp(lspClients map[string]*lsp.Client) tea.Model { return &statusCmp{ messageTTL: 10 * time.Second, + lspClients: lspClients, } } diff --git a/internal/tui/components/dialog/help.go b/internal/tui/components/dialog/help.go new file mode 100644 index 0000000000000000000000000000000000000000..1d3c2b077b1b393c4bc1e2147af4cc0c8d519d9c --- /dev/null +++ b/internal/tui/components/dialog/help.go @@ -0,0 +1,182 @@ +package dialog + +import ( + "strings" + + "github.com/charmbracelet/bubbles/key" + tea "github.com/charmbracelet/bubbletea" + "github.com/charmbracelet/lipgloss" + "github.com/kujtimiihoxha/termai/internal/tui/styles" +) + +type helpCmp struct { + width int + height int + keys []key.Binding +} + +func (h *helpCmp) Init() tea.Cmd { + return nil +} + +func (h *helpCmp) SetBindings(k []key.Binding) { + h.keys = k +} + +func (h *helpCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) { + switch msg := msg.(type) { + case tea.WindowSizeMsg: + h.width = 80 + h.height = msg.Height + } + return h, nil +} + +func removeDuplicateBindings(bindings []key.Binding) []key.Binding { + seen := make(map[string]struct{}) + result := make([]key.Binding, 0, len(bindings)) + + // Process bindings in reverse order + for i := len(bindings) - 1; i >= 0; i-- { + b := bindings[i] + k := strings.Join(b.Keys(), " ") + if _, ok := seen[k]; ok { + // duplicate, skip + continue + } + seen[k] = struct{}{} + // Add to the beginning of result to maintain original order + result = append([]key.Binding{b}, result...) + } + + return result +} + +func (h *helpCmp) render() string { + helpKeyStyle := styles.Bold.Background(styles.Background).Foreground(styles.Forground).Padding(0, 1, 0, 0) + helpDescStyle := styles.Regular.Background(styles.Background).Foreground(styles.ForgroundMid) + // Compile list of bindings to render + bindings := removeDuplicateBindings(h.keys) + // Enumerate through each group of bindings, populating a series of + // pairs of columns, one for keys, one for descriptions + var ( + pairs []string + width int + rows = 12 - 2 + ) + for i := 0; i < len(bindings); i += rows { + var ( + keys []string + descs []string + ) + for j := i; j < min(i+rows, len(bindings)); j++ { + keys = append(keys, helpKeyStyle.Render(bindings[j].Help().Key)) + descs = append(descs, helpDescStyle.Render(bindings[j].Help().Desc)) + } + // Render pair of columns; beyond the first pair, render a three space + // left margin, in order to visually separate the pairs. + var cols []string + if len(pairs) > 0 { + cols = []string{styles.BaseStyle.Render(" ")} + } + + maxDescWidth := 0 + for _, desc := range descs { + if maxDescWidth < lipgloss.Width(desc) { + maxDescWidth = lipgloss.Width(desc) + } + } + for i := range descs { + remainingWidth := maxDescWidth - lipgloss.Width(descs[i]) + if remainingWidth > 0 { + descs[i] = descs[i] + styles.BaseStyle.Render(strings.Repeat(" ", remainingWidth)) + } + } + maxKeyWidth := 0 + for _, key := range keys { + if maxKeyWidth < lipgloss.Width(key) { + maxKeyWidth = lipgloss.Width(key) + } + } + for i := range keys { + remainingWidth := maxKeyWidth - lipgloss.Width(keys[i]) + if remainingWidth > 0 { + keys[i] = keys[i] + styles.BaseStyle.Render(strings.Repeat(" ", remainingWidth)) + } + } + + cols = append(cols, + strings.Join(keys, "\n"), + strings.Join(descs, "\n"), + ) + + pair := styles.BaseStyle.Render(lipgloss.JoinHorizontal(lipgloss.Top, cols...)) + // check whether it exceeds the maximum width avail (the width of the + // terminal, subtracting 2 for the borders). + width += lipgloss.Width(pair) + if width > h.width-2 { + break + } + pairs = append(pairs, pair) + } + + // https://github.com/charmbracelet/lipgloss/issues/209 + if len(pairs) > 1 { + prefix := pairs[:len(pairs)-1] + lastPair := pairs[len(pairs)-1] + prefix = append(prefix, lipgloss.Place( + lipgloss.Width(lastPair), // width + lipgloss.Height(prefix[0]), // height + lipgloss.Left, // x + lipgloss.Top, // y + lastPair, // content + lipgloss.WithWhitespaceBackground(styles.Background), // background + )) + content := styles.BaseStyle.Width(h.width).Render( + lipgloss.JoinHorizontal( + lipgloss.Top, + prefix..., + ), + ) + return content + } + // Join pairs of columns and enclose in a border + content := styles.BaseStyle.Width(h.width).Render( + lipgloss.JoinHorizontal( + lipgloss.Top, + pairs..., + ), + ) + return content +} + +func (h *helpCmp) View() string { + content := h.render() + header := styles.BaseStyle. + Bold(true). + Width(lipgloss.Width(content)). + Foreground(styles.PrimaryColor). + Render("Keyboard Shortcuts") + + return styles.BaseStyle.Padding(1). + Border(lipgloss.RoundedBorder()). + BorderForeground(styles.ForgroundDim). + Width(h.width). + BorderBackground(styles.Background). + Render( + lipgloss.JoinVertical(lipgloss.Center, + header, + styles.BaseStyle.Render(strings.Repeat(" ", lipgloss.Width(header))), + content, + ), + ) +} + +type HelpCmp interface { + tea.Model + SetBindings([]key.Binding) +} + +func NewHelpCmp() HelpCmp { + return &helpCmp{} +} diff --git a/internal/tui/components/dialog/permission.go b/internal/tui/components/dialog/permission.go index d147f89cd0b40095336f989d3d42f9a107822406..9c55effde1c3b18692eb4063cf56b560313329cb 100644 --- a/internal/tui/components/dialog/permission.go +++ b/internal/tui/components/dialog/permission.go @@ -12,12 +12,9 @@ import ( "github.com/kujtimiihoxha/termai/internal/diff" "github.com/kujtimiihoxha/termai/internal/llm/tools" "github.com/kujtimiihoxha/termai/internal/permission" - "github.com/kujtimiihoxha/termai/internal/tui/components/core" "github.com/kujtimiihoxha/termai/internal/tui/layout" "github.com/kujtimiihoxha/termai/internal/tui/styles" "github.com/kujtimiihoxha/termai/internal/tui/util" - - "github.com/charmbracelet/huh" ) type PermissionAction string @@ -35,69 +32,64 @@ type PermissionResponseMsg struct { Action PermissionAction } -// PermissionDialog interface for permission dialog component -type PermissionDialog interface { +// PermissionDialogCmp interface for permission dialog component +type PermissionDialogCmp interface { tea.Model - layout.Sizeable layout.Bindings + SetPermissions(permission permission.PermissionRequest) } -type keyMap struct { - ChangeFocus key.Binding +type permissionsMapping struct { + LeftRight key.Binding + EnterSpace key.Binding + Allow key.Binding + AllowSession key.Binding + Deny key.Binding + Tab key.Binding } -var keyMapValue = keyMap{ - ChangeFocus: key.NewBinding( +var permissionsKeys = permissionsMapping{ + LeftRight: key.NewBinding( + key.WithKeys("left", "right"), + key.WithHelp("←/→", "switch options"), + ), + EnterSpace: key.NewBinding( + key.WithKeys("enter", " "), + key.WithHelp("enter/space", "confirm"), + ), + Allow: key.NewBinding( + key.WithKeys("a"), + key.WithHelp("a", "allow"), + ), + AllowSession: key.NewBinding( + key.WithKeys("A"), + key.WithHelp("A", "allow for session"), + ), + Deny: key.NewBinding( + key.WithKeys("d"), + key.WithHelp("d", "deny"), + ), + Tab: key.NewBinding( key.WithKeys("tab"), - key.WithHelp("tab", "change focus"), + key.WithHelp("tab", "switch options"), ), } // permissionDialogCmp is the implementation of PermissionDialog type permissionDialogCmp struct { - form *huh.Form width int height int permission permission.PermissionRequest windowSize tea.WindowSizeMsg - r *glamour.TermRenderer contentViewPort viewport.Model - isViewportFocus bool - selectOption *huh.Select[string] -} + selectedOption int // 0: Allow, 1: Allow for session, 2: Deny -// formatDiff formats a diff string with colors for additions and deletions -func formatDiff(diffText string) string { - lines := strings.Split(diffText, "\n") - var formattedLines []string - - // Define styles for different line types - addStyle := lipgloss.NewStyle().Foreground(styles.Green) - removeStyle := lipgloss.NewStyle().Foreground(styles.Red) - headerStyle := lipgloss.NewStyle().Bold(true).Foreground(styles.Blue) - contextStyle := lipgloss.NewStyle().Foreground(styles.SubText0) - - // Process each line - for _, line := range lines { - if strings.HasPrefix(line, "+") { - formattedLines = append(formattedLines, addStyle.Render(line)) - } else if strings.HasPrefix(line, "-") { - formattedLines = append(formattedLines, removeStyle.Render(line)) - } else if strings.HasPrefix(line, "Changes:") || strings.HasPrefix(line, " ...") { - formattedLines = append(formattedLines, headerStyle.Render(line)) - } else if strings.HasPrefix(line, " ") { - formattedLines = append(formattedLines, contextStyle.Render(line)) - } else { - formattedLines = append(formattedLines, line) - } - } - - // Join all formatted lines - return strings.Join(formattedLines, "\n") + diffCache map[string]string + markdownCache map[string]string } func (p *permissionDialogCmp) Init() tea.Cmd { - return nil + return p.contentViewPort.Init() } func (p *permissionDialogCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) { @@ -106,373 +98,363 @@ func (p *permissionDialogCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) { switch msg := msg.(type) { case tea.WindowSizeMsg: p.windowSize = msg + p.SetSize() + p.markdownCache = make(map[string]string) + p.diffCache = make(map[string]string) case tea.KeyMsg: - if key.Matches(msg, keyMapValue.ChangeFocus) { - p.isViewportFocus = !p.isViewportFocus - if p.isViewportFocus { - p.selectOption.Blur() - // Add a visual indicator for focus change - cmds = append(cmds, tea.Batch( - util.ReportInfo("Viewing content - use arrow keys to scroll"), - )) - } else { - p.selectOption.Focus() - // Add a visual indicator for focus change - cmds = append(cmds, tea.Batch( - util.CmdHandler(util.ReportInfo("Select an action")), - )) - } - return p, tea.Batch(cmds...) - } - } - - if p.isViewportFocus { - viewPort, cmd := p.contentViewPort.Update(msg) - p.contentViewPort = viewPort - cmds = append(cmds, cmd) - } else { - form, cmd := p.form.Update(msg) - if f, ok := form.(*huh.Form); ok { - p.form = f + switch { + case key.Matches(msg, permissionsKeys.LeftRight) || key.Matches(msg, permissionsKeys.Tab): + // Change selected option + p.selectedOption = (p.selectedOption + 1) % 3 + return p, nil + case key.Matches(msg, permissionsKeys.EnterSpace): + // Select current option + return p, p.selectCurrentOption() + case key.Matches(msg, permissionsKeys.Allow): + // Select Allow + return p, util.CmdHandler(PermissionResponseMsg{Action: PermissionAllow, Permission: p.permission}) + case key.Matches(msg, permissionsKeys.AllowSession): + // Select Allow for session + return p, util.CmdHandler(PermissionResponseMsg{Action: PermissionAllowForSession, Permission: p.permission}) + case key.Matches(msg, permissionsKeys.Deny): + // Select Deny + return p, util.CmdHandler(PermissionResponseMsg{Action: PermissionDeny, Permission: p.permission}) + default: + // Pass other keys to viewport + viewPort, cmd := p.contentViewPort.Update(msg) + p.contentViewPort = viewPort cmds = append(cmds, cmd) } - - if p.form.State == huh.StateCompleted { - // Get the selected action - action := p.form.GetString("action") - - // Close the dialog and return the response - return p, tea.Batch( - util.CmdHandler(core.DialogCloseMsg{}), - util.CmdHandler(PermissionResponseMsg{Action: PermissionAction(action), Permission: p.permission}), - ) - } } + return p, tea.Batch(cmds...) } -func (p *permissionDialogCmp) render() string { - keyStyle := lipgloss.NewStyle().Bold(true).Foreground(styles.Rosewater) - valueStyle := lipgloss.NewStyle().Foreground(styles.Peach) +func (p *permissionDialogCmp) selectCurrentOption() tea.Cmd { + var action PermissionAction - form := p.form.View() - - headerParts := []string{ - lipgloss.JoinHorizontal(lipgloss.Left, keyStyle.Render("Tool:"), " ", valueStyle.Render(p.permission.ToolName)), - " ", - lipgloss.JoinHorizontal(lipgloss.Left, keyStyle.Render("Path:"), " ", valueStyle.Render(p.permission.Path)), - " ", + switch p.selectedOption { + case 0: + action = PermissionAllow + case 1: + action = PermissionAllowForSession + case 2: + action = PermissionDeny } - // Create the header content first so it can be used in all cases - headerContent := lipgloss.NewStyle().Padding(0, 1).Render(lipgloss.JoinVertical(lipgloss.Left, headerParts...)) - - r, _ := glamour.NewTermRenderer( - glamour.WithStyles(styles.CatppuccinMarkdownStyle()), - glamour.WithWordWrap(p.width-10), - glamour.WithEmoji(), - ) - - // Handle different tool types - switch p.permission.ToolName { - case tools.BashToolName: - pr := p.permission.Params.(tools.BashPermissionsParams) - headerParts = append(headerParts, keyStyle.Render("Command:")) - content := fmt.Sprintf("```bash\n%s\n```", pr.Command) - - renderedContent, _ := r.Render(content) - p.contentViewPort.Width = p.width - 2 - 2 - - // Calculate content height dynamically based on content - contentLines := len(strings.Split(renderedContent, "\n")) - // Set a reasonable min/max for the viewport height - minContentHeight := 3 - maxContentHeight := p.height - lipgloss.Height(headerContent) - lipgloss.Height(form) - 2 - 2 - 1 - - // Add some padding to the content lines - contentHeight := contentLines + 2 - contentHeight = max(contentHeight, minContentHeight) - contentHeight = min(contentHeight, maxContentHeight) - p.contentViewPort.Height = contentHeight - - p.contentViewPort.SetContent(renderedContent) + return util.CmdHandler(PermissionResponseMsg{Action: action, Permission: p.permission}) +} - // Style the viewport - var contentBorder lipgloss.Border - var borderColor lipgloss.TerminalColor +func (p *permissionDialogCmp) renderButtons() string { + allowStyle := styles.BaseStyle + allowSessionStyle := styles.BaseStyle + denyStyle := styles.BaseStyle + spacerStyle := styles.BaseStyle.Background(styles.Background) + + // Style the selected button + switch p.selectedOption { + case 0: + allowStyle = allowStyle.Background(styles.PrimaryColor).Foreground(styles.Background) + allowSessionStyle = allowSessionStyle.Background(styles.Background).Foreground(styles.PrimaryColor) + denyStyle = denyStyle.Background(styles.Background).Foreground(styles.PrimaryColor) + case 1: + allowStyle = allowStyle.Background(styles.Background).Foreground(styles.PrimaryColor) + allowSessionStyle = allowSessionStyle.Background(styles.PrimaryColor).Foreground(styles.Background) + denyStyle = denyStyle.Background(styles.Background).Foreground(styles.PrimaryColor) + case 2: + allowStyle = allowStyle.Background(styles.Background).Foreground(styles.PrimaryColor) + allowSessionStyle = allowSessionStyle.Background(styles.Background).Foreground(styles.PrimaryColor) + denyStyle = denyStyle.Background(styles.PrimaryColor).Foreground(styles.Background) + } - if p.isViewportFocus { - contentBorder = lipgloss.DoubleBorder() - borderColor = styles.Blue - } else { - contentBorder = lipgloss.RoundedBorder() - borderColor = styles.Flamingo - } + allowButton := allowStyle.Padding(0, 1).Render("Allow (a)") + allowSessionButton := allowSessionStyle.Padding(0, 1).Render("Allow for session (A)") + denyButton := denyStyle.Padding(0, 1).Render("Deny (d)") + + content := lipgloss.JoinHorizontal( + lipgloss.Left, + allowButton, + spacerStyle.Render(" "), + allowSessionButton, + spacerStyle.Render(" "), + denyButton, + spacerStyle.Render(" "), + ) - contentStyle := lipgloss.NewStyle(). - MarginTop(1). - Padding(0, 1). - Border(contentBorder). - BorderForeground(borderColor) + remainingWidth := p.width - lipgloss.Width(content) + if remainingWidth > 0 { + content = spacerStyle.Render(strings.Repeat(" ", remainingWidth)) + content + } + return content +} - if p.isViewportFocus { - contentStyle = contentStyle.BorderBackground(styles.Surface0) - } +func (p *permissionDialogCmp) renderHeader() string { + toolKey := styles.BaseStyle.Foreground(styles.ForgroundDim).Bold(true).Render("Tool") + toolValue := styles.BaseStyle. + Foreground(styles.Forground). + Width(p.width - lipgloss.Width(toolKey)). + Render(fmt.Sprintf(": %s", p.permission.ToolName)) - contentFinal := contentStyle.Render(p.contentViewPort.View()) + pathKey := styles.BaseStyle.Foreground(styles.ForgroundDim).Bold(true).Render("Path") + pathValue := styles.BaseStyle. + Foreground(styles.Forground). + Width(p.width - lipgloss.Width(pathKey)). + Render(fmt.Sprintf(": %s", p.permission.Path)) - return lipgloss.JoinVertical( - lipgloss.Top, - headerContent, - contentFinal, - form, - ) + headerParts := []string{ + lipgloss.JoinHorizontal( + lipgloss.Left, + toolKey, + toolValue, + ), + styles.BaseStyle.Render(strings.Repeat(" ", p.width)), + lipgloss.JoinHorizontal( + lipgloss.Left, + pathKey, + pathValue, + ), + styles.BaseStyle.Render(strings.Repeat(" ", p.width)), + } + // Add tool-specific header information + switch p.permission.ToolName { + case tools.BashToolName: + headerParts = append(headerParts, styles.BaseStyle.Foreground(styles.ForgroundDim).Width(p.width).Bold(true).Render("Command")) case tools.EditToolName: - pr := p.permission.Params.(tools.EditPermissionsParams) - headerParts = append(headerParts, keyStyle.Render("Update")) - // Recreate header content with the updated headerParts - headerContent = lipgloss.NewStyle().Padding(0, 1).Render(lipgloss.JoinVertical(lipgloss.Left, headerParts...)) - - // Format the diff with colors - - // Set up viewport for the diff content - p.contentViewPort.Width = p.width - 2 - 2 - - // Calculate content height dynamically based on window size - maxContentHeight := p.height - lipgloss.Height(headerContent) - lipgloss.Height(form) - 2 - 2 - 1 - p.contentViewPort.Height = maxContentHeight - diff, err := diff.FormatDiff(pr.Diff, diff.WithTotalWidth(p.contentViewPort.Width)) - if err != nil { - diff = fmt.Sprintf("Error formatting diff: %v", err) - } - p.contentViewPort.SetContent(diff) + headerParts = append(headerParts, styles.BaseStyle.Foreground(styles.ForgroundDim).Width(p.width).Bold(true).Render("Diff")) + case tools.WriteToolName: + headerParts = append(headerParts, styles.BaseStyle.Foreground(styles.ForgroundDim).Width(p.width).Bold(true).Render("Diff")) + case tools.FetchToolName: + headerParts = append(headerParts, styles.BaseStyle.Foreground(styles.ForgroundDim).Width(p.width).Bold(true).Render("URL")) + } - // Style the viewport - var contentBorder lipgloss.Border - var borderColor lipgloss.TerminalColor + return lipgloss.NewStyle().Render(lipgloss.JoinVertical(lipgloss.Left, headerParts...)) +} - if p.isViewportFocus { - contentBorder = lipgloss.DoubleBorder() - borderColor = styles.Blue - } else { - contentBorder = lipgloss.RoundedBorder() - borderColor = styles.Flamingo - } +func (p *permissionDialogCmp) renderBashContent() string { + if pr, ok := p.permission.Params.(tools.BashPermissionsParams); ok { + content := fmt.Sprintf("```bash\n%s\n```", pr.Command) - contentStyle := lipgloss.NewStyle(). - MarginTop(1). - Padding(0, 1). - Border(contentBorder). - BorderForeground(borderColor) + // Use the cache for markdown rendering + renderedContent := p.GetOrSetMarkdown(p.permission.ID, func() (string, error) { + r, _ := glamour.NewTermRenderer( + glamour.WithStyles(styles.MarkdownTheme(true)), + glamour.WithWordWrap(p.width-10), + ) + s, err := r.Render(content) + return styles.ForceReplaceBackgroundWithLipgloss(s, styles.Background), err + }) + + finalContent := styles.BaseStyle. + Width(p.contentViewPort.Width). + Render(renderedContent) + p.contentViewPort.SetContent(finalContent) + return p.styleViewport() + } + return "" +} - if p.isViewportFocus { - contentStyle = contentStyle.BorderBackground(styles.Surface0) - } +func (p *permissionDialogCmp) renderEditContent() string { + if pr, ok := p.permission.Params.(tools.EditPermissionsParams); ok { + diff := p.GetOrSetDiff(p.permission.ID, func() (string, error) { + return diff.FormatDiff(pr.Diff, diff.WithTotalWidth(p.contentViewPort.Width)) + }) - contentFinal := contentStyle.Render(p.contentViewPort.View()) + p.contentViewPort.SetContent(diff) + return p.styleViewport() + } + return "" +} - return lipgloss.JoinVertical( - lipgloss.Top, - headerContent, - contentFinal, - form, - ) +func (p *permissionDialogCmp) renderWriteContent() string { + if pr, ok := p.permission.Params.(tools.WritePermissionsParams); ok { + // Use the cache for diff rendering + diff := p.GetOrSetDiff(p.permission.ID, func() (string, error) { + return diff.FormatDiff(pr.Diff, diff.WithTotalWidth(p.contentViewPort.Width)) + }) - case tools.WriteToolName: - pr := p.permission.Params.(tools.WritePermissionsParams) - headerParts = append(headerParts, keyStyle.Render("Content")) - // Recreate header content with the updated headerParts - headerContent = lipgloss.NewStyle().Padding(0, 1).Render(lipgloss.JoinVertical(lipgloss.Left, headerParts...)) - - // Set up viewport for the content - p.contentViewPort.Width = p.width - 2 - 2 - - // Calculate content height dynamically based on window size - maxContentHeight := p.height - lipgloss.Height(headerContent) - lipgloss.Height(form) - 2 - 2 - 1 - p.contentViewPort.Height = maxContentHeight - diff, err := diff.FormatDiff(pr.Diff, diff.WithTotalWidth(p.contentViewPort.Width)) - if err != nil { - diff = fmt.Sprintf("Error formatting diff: %v", err) - } p.contentViewPort.SetContent(diff) + return p.styleViewport() + } + return "" +} - // Style the viewport - var contentBorder lipgloss.Border - var borderColor lipgloss.TerminalColor +func (p *permissionDialogCmp) renderFetchContent() string { + if pr, ok := p.permission.Params.(tools.FetchPermissionsParams); ok { + content := fmt.Sprintf("```bash\n%s\n```", pr.URL) - if p.isViewportFocus { - contentBorder = lipgloss.DoubleBorder() - borderColor = styles.Blue - } else { - contentBorder = lipgloss.RoundedBorder() - borderColor = styles.Flamingo - } - - contentStyle := lipgloss.NewStyle(). - MarginTop(1). - Padding(0, 1). - Border(contentBorder). - BorderForeground(borderColor) + // Use the cache for markdown rendering + renderedContent := p.GetOrSetMarkdown(p.permission.ID, func() (string, error) { + r, _ := glamour.NewTermRenderer( + glamour.WithStyles(styles.MarkdownTheme(true)), + glamour.WithWordWrap(p.width-10), + ) + s, err := r.Render(content) + return styles.ForceReplaceBackgroundWithLipgloss(s, styles.Background), err + }) - if p.isViewportFocus { - contentStyle = contentStyle.BorderBackground(styles.Surface0) - } + p.contentViewPort.SetContent(renderedContent) + return p.styleViewport() + } + return "" +} - contentFinal := contentStyle.Render(p.contentViewPort.View()) +func (p *permissionDialogCmp) renderDefaultContent() string { + content := p.permission.Description - return lipgloss.JoinVertical( - lipgloss.Top, - headerContent, - contentFinal, - form, + // Use the cache for markdown rendering + renderedContent := p.GetOrSetMarkdown(p.permission.ID, func() (string, error) { + r, _ := glamour.NewTermRenderer( + glamour.WithStyles(styles.CatppuccinMarkdownStyle()), + glamour.WithWordWrap(p.width-10), ) + s, err := r.Render(content) + return styles.ForceReplaceBackgroundWithLipgloss(s, styles.Background), err + }) - case tools.FetchToolName: - pr := p.permission.Params.(tools.FetchPermissionsParams) - headerParts = append(headerParts, keyStyle.Render("URL: "+pr.URL)) - content := p.permission.Description + p.contentViewPort.SetContent(renderedContent) - renderedContent, _ := r.Render(content) - p.contentViewPort.Width = p.width - 2 - 2 - p.contentViewPort.Height = p.height - lipgloss.Height(headerContent) - lipgloss.Height(form) - 2 - 2 - 1 - p.contentViewPort.SetContent(renderedContent) + if renderedContent == "" { + return "" + } - // Style the viewport - contentStyle := lipgloss.NewStyle(). - MarginTop(1). - Padding(0, 1). - Border(lipgloss.RoundedBorder()). - BorderForeground(styles.Flamingo) + return p.styleViewport() +} - contentFinal := contentStyle.Render(p.contentViewPort.View()) - if renderedContent == "" { - contentFinal = "" - } +func (p *permissionDialogCmp) styleViewport() string { + contentStyle := lipgloss.NewStyle(). + Background(styles.Background) - return lipgloss.JoinVertical( - lipgloss.Top, - headerContent, - contentFinal, - form, - ) + return contentStyle.Render(p.contentViewPort.View()) +} +func (p *permissionDialogCmp) render() string { + title := styles.BaseStyle. + Bold(true). + Width(p.width - 4). + Foreground(styles.PrimaryColor). + Render("Permission Required") + // Render header + headerContent := p.renderHeader() + // Render buttons + buttons := p.renderButtons() + + // Calculate content height dynamically based on window size + p.contentViewPort.Height = p.height - lipgloss.Height(headerContent) - lipgloss.Height(buttons) - 2 - lipgloss.Height(title) + p.contentViewPort.Width = p.width - 4 + + // Render content based on tool type + var contentFinal string + switch p.permission.ToolName { + case tools.BashToolName: + contentFinal = p.renderBashContent() + case tools.EditToolName: + contentFinal = p.renderEditContent() + case tools.WriteToolName: + contentFinal = p.renderWriteContent() + case tools.FetchToolName: + contentFinal = p.renderFetchContent() default: - content := p.permission.Description - - renderedContent, _ := r.Render(content) - p.contentViewPort.Width = p.width - 2 - 2 - p.contentViewPort.Height = p.height - lipgloss.Height(headerContent) - lipgloss.Height(form) - 2 - 2 - 1 - p.contentViewPort.SetContent(renderedContent) - - // Style the viewport - contentStyle := lipgloss.NewStyle(). - MarginTop(1). - Padding(0, 1). - Border(lipgloss.RoundedBorder()). - BorderForeground(styles.Flamingo) + contentFinal = p.renderDefaultContent() + } - contentFinal := contentStyle.Render(p.contentViewPort.View()) - if renderedContent == "" { - contentFinal = "" - } + content := lipgloss.JoinVertical( + lipgloss.Top, + title, + styles.BaseStyle.Render(strings.Repeat(" ", lipgloss.Width(title))), + headerContent, + contentFinal, + buttons, + ) - return lipgloss.JoinVertical( - lipgloss.Top, - headerContent, - contentFinal, - form, + return styles.BaseStyle. + Padding(1, 0, 0, 1). + Border(lipgloss.RoundedBorder()). + BorderBackground(styles.Background). + BorderForeground(styles.ForgroundDim). + Width(p.width). + Height(p.height). + Render( + content, ) - } } func (p *permissionDialogCmp) View() string { return p.render() } -func (p *permissionDialogCmp) GetSize() (int, int) { - return p.width, p.height +func (p *permissionDialogCmp) BindingKeys() []key.Binding { + return layout.KeyMapToSlice(helpKeys) } -func (p *permissionDialogCmp) SetSize(width int, height int) { - p.width = width - p.height = height - p.form = p.form.WithWidth(width) +func (p *permissionDialogCmp) SetSize() { + if p.permission.ID == "" { + return + } + switch p.permission.ToolName { + case tools.BashToolName: + p.width = int(float64(p.windowSize.Width) * 0.4) + p.height = int(float64(p.windowSize.Height) * 0.3) + case tools.EditToolName: + p.width = int(float64(p.windowSize.Width) * 0.8) + p.height = int(float64(p.windowSize.Height) * 0.8) + case tools.WriteToolName: + p.width = int(float64(p.windowSize.Width) * 0.8) + p.height = int(float64(p.windowSize.Height) * 0.8) + case tools.FetchToolName: + p.width = int(float64(p.windowSize.Width) * 0.4) + p.height = int(float64(p.windowSize.Height) * 0.3) + default: + p.width = int(float64(p.windowSize.Width) * 0.7) + p.height = int(float64(p.windowSize.Height) * 0.5) + } } -func (p *permissionDialogCmp) BindingKeys() []key.Binding { - return p.form.KeyBinds() +func (p *permissionDialogCmp) SetPermissions(permission permission.PermissionRequest) { + p.permission = permission + p.SetSize() } -func newPermissionDialogCmp(permission permission.PermissionRequest) PermissionDialog { - // Create a note field for displaying the content +// Helper to get or set cached diff content +func (c *permissionDialogCmp) GetOrSetDiff(key string, generator func() (string, error)) string { + if cached, ok := c.diffCache[key]; ok { + return cached + } - // Create select field for the permission options - selectOption := huh.NewSelect[string](). - Key("action"). - Options( - huh.NewOption("Allow", string(PermissionAllow)), - huh.NewOption("Allow for this session", string(PermissionAllowForSession)), - huh.NewOption("Deny", string(PermissionDeny)), - ). - Title("Select an action") + content, err := generator() + if err != nil { + return fmt.Sprintf("Error formatting diff: %v", err) + } - // Apply theme - theme := styles.HuhTheme() + c.diffCache[key] = content - // Setup form width and height - form := huh.NewForm(huh.NewGroup(selectOption)). - WithShowHelp(false). - WithTheme(theme). - WithShowErrors(false) + return content +} - // Focus the form for immediate interaction - selectOption.Focus() +// Helper to get or set cached markdown content +func (c *permissionDialogCmp) GetOrSetMarkdown(key string, generator func() (string, error)) string { + if cached, ok := c.markdownCache[key]; ok { + return cached + } - return &permissionDialogCmp{ - permission: permission, - form: form, - selectOption: selectOption, + content, err := generator() + if err != nil { + return fmt.Sprintf("Error rendering markdown: %v", err) } -} -// NewPermissionDialogCmd creates a new permission dialog command -func NewPermissionDialogCmd(permission permission.PermissionRequest) tea.Cmd { - permDialog := newPermissionDialogCmp(permission) - - // Create the dialog layout - dialogPane := layout.NewSinglePane( - permDialog.(*permissionDialogCmp), - layout.WithSinglePaneBordered(true), - layout.WithSinglePaneFocusable(true), - layout.WithSinglePaneActiveColor(styles.Warning), - layout.WithSinglePaneBorderText(map[layout.BorderPosition]string{ - layout.TopMiddleBorder: " Permission Required ", - }), - ) + c.markdownCache[key] = content - // Focus the dialog - dialogPane.Focus() - widthRatio := 0.7 - heightRatio := 0.6 - minWidth := 100 - minHeight := 30 + return content +} - // Make the dialog size more appropriate for different tools - switch permission.ToolName { - case tools.BashToolName: - // For bash commands, use a more compact dialog - widthRatio = 0.7 - heightRatio = 0.4 // Reduced from 0.5 - minWidth = 100 - minHeight = 20 // Reduced from 30 +func NewPermissionDialogCmp() PermissionDialogCmp { + // Create viewport for content + contentViewport := viewport.New(0, 0) + + return &permissionDialogCmp{ + contentViewPort: contentViewport, + selectedOption: 0, // Default to "Allow" + diffCache: make(map[string]string), + markdownCache: make(map[string]string), } - // Return the dialog command - return util.CmdHandler(core.DialogMsg{ - Content: dialogPane, - WidthRatio: widthRatio, - HeightRatio: heightRatio, - MinWidth: minWidth, - MinHeight: minHeight, - }) } diff --git a/internal/tui/components/dialog/quit.go b/internal/tui/components/dialog/quit.go index 60c1fc0d27503235a9444d52a48dd4c8bdfc66b0..10d9ba8a2c363168fc6bf0048451c899c367d9cc 100644 --- a/internal/tui/components/dialog/quit.go +++ b/internal/tui/components/dialog/quit.go @@ -1,28 +1,58 @@ package dialog import ( + "strings" + "github.com/charmbracelet/bubbles/key" tea "github.com/charmbracelet/bubbletea" - "github.com/kujtimiihoxha/termai/internal/tui/components/core" + "github.com/charmbracelet/lipgloss" "github.com/kujtimiihoxha/termai/internal/tui/layout" "github.com/kujtimiihoxha/termai/internal/tui/styles" "github.com/kujtimiihoxha/termai/internal/tui/util" - - "github.com/charmbracelet/huh" ) const question = "Are you sure you want to quit?" +type CloseQuitMsg struct{} + type QuitDialog interface { tea.Model - layout.Sizeable layout.Bindings } type quitDialogCmp struct { - form *huh.Form - width int - height int + selectedNo bool +} + +type helpMapping struct { + LeftRight key.Binding + EnterSpace key.Binding + Yes key.Binding + No key.Binding + Tab key.Binding +} + +var helpKeys = helpMapping{ + LeftRight: key.NewBinding( + key.WithKeys("left", "right"), + key.WithHelp("←/→", "switch options"), + ), + EnterSpace: key.NewBinding( + key.WithKeys("enter", " "), + key.WithHelp("enter/space", "confirm"), + ), + Yes: key.NewBinding( + key.WithKeys("y", "Y"), + key.WithHelp("y/Y", "yes"), + ), + No: key.NewBinding( + key.WithKeys("n", "N"), + key.WithHelp("n/N", "no"), + ), + Tab: key.NewBinding( + key.WithKeys("tab"), + key.WithHelp("tab", "switch options"), + ), } func (q *quitDialogCmp) Init() tea.Cmd { @@ -30,77 +60,73 @@ func (q *quitDialogCmp) Init() tea.Cmd { } func (q *quitDialogCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) { - var cmds []tea.Cmd - form, cmd := q.form.Update(msg) - if f, ok := form.(*huh.Form); ok { - q.form = f - cmds = append(cmds, cmd) - } - - if q.form.State == huh.StateCompleted { - v := q.form.GetBool("quit") - if v { + switch msg := msg.(type) { + case tea.KeyMsg: + switch { + case key.Matches(msg, helpKeys.LeftRight) || key.Matches(msg, helpKeys.Tab): + q.selectedNo = !q.selectedNo + return q, nil + case key.Matches(msg, helpKeys.EnterSpace): + if !q.selectedNo { + return q, tea.Quit + } + return q, util.CmdHandler(CloseQuitMsg{}) + case key.Matches(msg, helpKeys.Yes): return q, tea.Quit + case key.Matches(msg, helpKeys.No): + return q, util.CmdHandler(CloseQuitMsg{}) } - cmds = append(cmds, util.CmdHandler(core.DialogCloseMsg{})) } - - return q, tea.Batch(cmds...) + return q, nil } func (q *quitDialogCmp) View() string { - return q.form.View() -} + yesStyle := styles.BaseStyle + noStyle := styles.BaseStyle + spacerStyle := styles.BaseStyle.Background(styles.Background) + + if q.selectedNo { + noStyle = noStyle.Background(styles.PrimaryColor).Foreground(styles.Background) + yesStyle = yesStyle.Background(styles.Background).Foreground(styles.PrimaryColor) + } else { + yesStyle = yesStyle.Background(styles.PrimaryColor).Foreground(styles.Background) + noStyle = noStyle.Background(styles.Background).Foreground(styles.PrimaryColor) + } -func (q *quitDialogCmp) GetSize() (int, int) { - return q.width, q.height -} + yesButton := yesStyle.Padding(0, 1).Render("Yes") + noButton := noStyle.Padding(0, 1).Render("No") + + buttons := lipgloss.JoinHorizontal(lipgloss.Left, yesButton, spacerStyle.Render(" "), noButton) + + width := lipgloss.Width(question) + remainingWidth := width - lipgloss.Width(buttons) + if remainingWidth > 0 { + buttons = spacerStyle.Render(strings.Repeat(" ", remainingWidth)) + buttons + } -func (q *quitDialogCmp) SetSize(width int, height int) { - q.width = width - q.height = height - q.form = q.form.WithWidth(width).WithHeight(height) + content := styles.BaseStyle.Render( + lipgloss.JoinVertical( + lipgloss.Center, + question, + "", + buttons, + ), + ) + + return styles.BaseStyle.Padding(1, 2). + Border(lipgloss.RoundedBorder()). + BorderBackground(styles.Background). + BorderForeground(styles.ForgroundDim). + Width(lipgloss.Width(content) + 4). + Render(content) } func (q *quitDialogCmp) BindingKeys() []key.Binding { - return q.form.KeyBinds() + return layout.KeyMapToSlice(helpKeys) } -func newQuitDialogCmp() QuitDialog { - confirm := huh.NewConfirm(). - Title(question). - Affirmative("Yes!"). - Key("quit"). - Negative("No.") - - theme := styles.HuhTheme() - theme.Focused.FocusedButton = theme.Focused.FocusedButton.Background(styles.Warning) - theme.Blurred.FocusedButton = theme.Blurred.FocusedButton.Background(styles.Warning) - form := huh.NewForm(huh.NewGroup(confirm)). - WithShowHelp(false). - WithWidth(0). - WithHeight(0). - WithTheme(theme). - WithShowErrors(false) - confirm.Focus() +func NewQuitCmp() QuitDialog { return &quitDialogCmp{ - form: form, + selectedNo: true, } } - -func NewQuitDialogCmd() tea.Cmd { - content := layout.NewSinglePane( - newQuitDialogCmp().(*quitDialogCmp), - layout.WithSinglePaneBordered(true), - layout.WithSinglePaneFocusable(true), - layout.WithSinglePaneActiveColor(styles.Warning), - ) - content.Focus() - return util.CmdHandler(core.DialogMsg{ - Content: content, - WidthRatio: 0.2, - HeightRatio: 0.1, - MinWidth: 40, - MinHeight: 5, - }) -} diff --git a/internal/tui/components/logs/details.go b/internal/tui/components/logs/details.go index dbace5508d253220df39862ae0e1992ed6616d83..18eb1a526806779fa0ec20d2501ec122e8c1a0e4 100644 --- a/internal/tui/components/logs/details.go +++ b/internal/tui/components/logs/details.go @@ -16,10 +16,8 @@ import ( type DetailComponent interface { tea.Model - layout.Focusable layout.Sizeable layout.Bindings - layout.Bordered } type detailCmp struct { diff --git a/internal/tui/components/logs/table.go b/internal/tui/components/logs/table.go index 9500059b1e6052dd95c9b11bf58be4a21bd7ac25..6e8eb58b13772d294da3fb28eeb3fa4020c44402 100644 --- a/internal/tui/components/logs/table.go +++ b/internal/tui/components/logs/table.go @@ -16,22 +16,14 @@ import ( type TableComponent interface { tea.Model - layout.Focusable layout.Sizeable layout.Bindings - layout.Bordered } type tableCmp struct { table table.Model } -func (i *tableCmp) BorderText() map[layout.BorderPosition]string { - return map[layout.BorderPosition]string{ - layout.TopLeftBorder: "Logs", - } -} - type selectedLogMsg logging.LogMessage func (i *tableCmp) Init() tea.Cmd { @@ -74,20 +66,6 @@ func (i *tableCmp) View() string { return i.table.View() } -func (i *tableCmp) Blur() tea.Cmd { - i.table.Blur() - return nil -} - -func (i *tableCmp) Focus() tea.Cmd { - i.table.Focus() - return nil -} - -func (i *tableCmp) IsFocused() bool { - return i.table.Focused() -} - func (i *tableCmp) GetSize() (int, int) { return i.table.Width(), i.table.Height() } diff --git a/internal/tui/components/repl/editor.go b/internal/tui/components/repl/editor.go deleted file mode 100644 index b659775e0e6fc92676fa66dbd6a7da2a904a189b..0000000000000000000000000000000000000000 --- a/internal/tui/components/repl/editor.go +++ /dev/null @@ -1,201 +0,0 @@ -package repl - -import ( - "strings" - - "github.com/charmbracelet/bubbles/key" - tea "github.com/charmbracelet/bubbletea" - "github.com/charmbracelet/lipgloss" - "github.com/kujtimiihoxha/termai/internal/app" - "github.com/kujtimiihoxha/termai/internal/tui/layout" - "github.com/kujtimiihoxha/termai/internal/tui/styles" - "github.com/kujtimiihoxha/termai/internal/tui/util" - "github.com/kujtimiihoxha/vimtea" - "golang.org/x/net/context" -) - -type EditorCmp interface { - tea.Model - layout.Focusable - layout.Sizeable - layout.Bordered - layout.Bindings -} - -type editorCmp struct { - app *app.App - editor vimtea.Editor - editorMode vimtea.EditorMode - sessionID string - focused bool - width int - height int - cancelMessage context.CancelFunc -} - -type editorKeyMap struct { - SendMessage key.Binding - SendMessageI key.Binding - CancelMessage key.Binding - InsertMode key.Binding - NormaMode key.Binding - VisualMode key.Binding - VisualLineMode key.Binding -} - -var editorKeyMapValue = editorKeyMap{ - SendMessage: key.NewBinding( - key.WithKeys("enter"), - key.WithHelp("enter", "send message normal mode"), - ), - SendMessageI: key.NewBinding( - key.WithKeys("ctrl+s"), - key.WithHelp("ctrl+s", "send message insert mode"), - ), - CancelMessage: key.NewBinding( - key.WithKeys("ctrl+x"), - key.WithHelp("ctrl+x", "cancel current message"), - ), - InsertMode: key.NewBinding( - key.WithKeys("i"), - key.WithHelp("i", "insert mode"), - ), - NormaMode: key.NewBinding( - key.WithKeys("esc"), - key.WithHelp("esc", "normal mode"), - ), - VisualMode: key.NewBinding( - key.WithKeys("v"), - key.WithHelp("v", "visual mode"), - ), - VisualLineMode: key.NewBinding( - key.WithKeys("V"), - key.WithHelp("V", "visual line mode"), - ), -} - -func (m *editorCmp) Init() tea.Cmd { - return m.editor.Init() -} - -func (m *editorCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) { - switch msg := msg.(type) { - case vimtea.EditorModeMsg: - m.editorMode = msg.Mode - case SelectedSessionMsg: - if msg.SessionID != m.sessionID { - m.sessionID = msg.SessionID - } - } - if m.IsFocused() { - switch msg := msg.(type) { - case tea.KeyMsg: - switch { - case key.Matches(msg, editorKeyMapValue.SendMessage): - if m.editorMode == vimtea.ModeNormal { - return m, m.Send() - } - case key.Matches(msg, editorKeyMapValue.SendMessageI): - if m.editorMode == vimtea.ModeInsert { - return m, m.Send() - } - case key.Matches(msg, editorKeyMapValue.CancelMessage): - return m, m.Cancel() - } - } - u, cmd := m.editor.Update(msg) - m.editor = u.(vimtea.Editor) - return m, cmd - } - return m, nil -} - -func (m *editorCmp) Blur() tea.Cmd { - m.focused = false - return nil -} - -func (m *editorCmp) BorderText() map[layout.BorderPosition]string { - title := "New Message" - if m.focused { - title = lipgloss.NewStyle().Foreground(styles.Primary).Render(title) - } - return map[layout.BorderPosition]string{ - layout.BottomLeftBorder: title, - } -} - -func (m *editorCmp) Focus() tea.Cmd { - m.focused = true - return m.editor.Tick() -} - -func (m *editorCmp) GetSize() (int, int) { - return m.width, m.height -} - -func (m *editorCmp) IsFocused() bool { - return m.focused -} - -func (m *editorCmp) SetSize(width int, height int) { - m.width = width - m.height = height - m.editor.SetSize(width, height) -} - -func (m *editorCmp) Cancel() tea.Cmd { - if m.cancelMessage == nil { - return util.ReportWarn("No message to cancel") - } - - m.cancelMessage() - m.cancelMessage = nil - return util.ReportWarn("Message cancelled") -} - -func (m *editorCmp) Send() tea.Cmd { - if m.cancelMessage != nil { - return util.ReportWarn("Assistant is still working on the previous message") - } - - messages, err := m.app.Messages.List(context.Background(), m.sessionID) - if err != nil { - return util.ReportError(err) - } - if hasUnfinishedMessages(messages) { - return util.ReportWarn("Assistant is still working on the previous message") - } - - content := strings.Join(m.editor.GetBuffer().Lines(), "\n") - if len(content) == 0 { - return util.ReportWarn("Message is empty") - } - ctx, cancel := context.WithCancel(context.Background()) - m.cancelMessage = cancel - go func() { - defer cancel() - m.app.CoderAgent.Generate(ctx, m.sessionID, content) - m.cancelMessage = nil - }() - - return m.editor.Reset() -} - -func (m *editorCmp) View() string { - return m.editor.View() -} - -func (m *editorCmp) BindingKeys() []key.Binding { - return layout.KeyMapToSlice(editorKeyMapValue) -} - -func NewEditorCmp(app *app.App) EditorCmp { - editor := vimtea.NewEditor( - vimtea.WithFileName("message.md"), - ) - return &editorCmp{ - app: app, - editor: editor, - } -} diff --git a/internal/tui/components/repl/messages.go b/internal/tui/components/repl/messages.go deleted file mode 100644 index 260be220e82fc6d03928db950f2b244cb0ae7e8b..0000000000000000000000000000000000000000 --- a/internal/tui/components/repl/messages.go +++ /dev/null @@ -1,513 +0,0 @@ -package repl - -import ( - "context" - "encoding/json" - "fmt" - "sort" - "strings" - - "github.com/charmbracelet/bubbles/key" - "github.com/charmbracelet/bubbles/viewport" - tea "github.com/charmbracelet/bubbletea" - "github.com/charmbracelet/glamour" - "github.com/charmbracelet/lipgloss" - "github.com/kujtimiihoxha/termai/internal/app" - "github.com/kujtimiihoxha/termai/internal/llm/agent" - "github.com/kujtimiihoxha/termai/internal/lsp/protocol" - "github.com/kujtimiihoxha/termai/internal/message" - "github.com/kujtimiihoxha/termai/internal/pubsub" - "github.com/kujtimiihoxha/termai/internal/session" - "github.com/kujtimiihoxha/termai/internal/tui/layout" - "github.com/kujtimiihoxha/termai/internal/tui/styles" -) - -type MessagesCmp interface { - tea.Model - layout.Focusable - layout.Bordered - layout.Sizeable - layout.Bindings -} - -type messagesCmp struct { - app *app.App - messages []message.Message - selectedMsgIdx int // Index of the selected message - session session.Session - viewport viewport.Model - mdRenderer *glamour.TermRenderer - width int - height int - focused bool - cachedView string -} - -func (m *messagesCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) { - switch msg := msg.(type) { - case pubsub.Event[message.Message]: - if msg.Type == pubsub.CreatedEvent { - if msg.Payload.SessionID == m.session.ID { - m.messages = append(m.messages, msg.Payload) - m.renderView() - m.viewport.GotoBottom() - } - for _, v := range m.messages { - for _, c := range v.ToolCalls() { - // the message is being added to the session of a tool called - if c.ID == msg.Payload.SessionID { - m.renderView() - m.viewport.GotoBottom() - } - } - } - } else if msg.Type == pubsub.UpdatedEvent && msg.Payload.SessionID == m.session.ID { - for i, v := range m.messages { - if v.ID == msg.Payload.ID { - m.messages[i] = msg.Payload - m.renderView() - if i == len(m.messages)-1 { - m.viewport.GotoBottom() - } - break - } - } - } - case pubsub.Event[session.Session]: - if msg.Type == pubsub.UpdatedEvent && m.session.ID == msg.Payload.ID { - m.session = msg.Payload - } - case SelectedSessionMsg: - m.session, _ = m.app.Sessions.Get(context.Background(), msg.SessionID) - m.messages, _ = m.app.Messages.List(context.Background(), m.session.ID) - m.renderView() - m.viewport.GotoBottom() - } - if m.focused { - u, cmd := m.viewport.Update(msg) - m.viewport = u - return m, cmd - } - return m, nil -} - -func borderColor(role message.MessageRole) lipgloss.TerminalColor { - switch role { - case message.Assistant: - return styles.Mauve - case message.User: - return styles.Rosewater - } - return styles.Blue -} - -func borderText(msgRole message.MessageRole, currentMessage int) map[layout.BorderPosition]string { - role := "" - icon := "" - switch msgRole { - case message.Assistant: - role = "Assistant" - icon = styles.BotIcon - case message.User: - role = "User" - icon = styles.UserIcon - } - return map[layout.BorderPosition]string{ - layout.TopLeftBorder: lipgloss.NewStyle(). - Padding(0, 1). - Bold(true). - Foreground(styles.Crust). - Background(borderColor(msgRole)). - Render(fmt.Sprintf("%s %s ", role, icon)), - layout.TopRightBorder: lipgloss.NewStyle(). - Padding(0, 1). - Bold(true). - Foreground(styles.Crust). - Background(borderColor(msgRole)). - Render(fmt.Sprintf("#%d ", currentMessage)), - } -} - -func hasUnfinishedMessages(messages []message.Message) bool { - if len(messages) == 0 { - return false - } - for _, msg := range messages { - if !msg.IsFinished() { - return true - } - } - return false -} - -func (m *messagesCmp) renderMessageWithToolCall(content string, tools []message.ToolCall, futureMessages []message.Message) string { - allParts := []string{content} - - leftPaddingValue := 4 - connectorStyle := lipgloss.NewStyle(). - Foreground(styles.Peach). - Bold(true) - - toolCallStyle := lipgloss.NewStyle(). - Border(lipgloss.RoundedBorder()). - BorderForeground(styles.Peach). - Width(m.width-leftPaddingValue-5). - Padding(0, 1) - - toolResultStyle := lipgloss.NewStyle(). - Border(lipgloss.RoundedBorder()). - BorderForeground(styles.Green). - Width(m.width-leftPaddingValue-5). - Padding(0, 1) - - leftPadding := lipgloss.NewStyle().Padding(0, 0, 0, leftPaddingValue) - - runningStyle := lipgloss.NewStyle(). - Foreground(styles.Peach). - Bold(true) - - renderTool := func(toolCall message.ToolCall) string { - toolHeader := lipgloss.NewStyle(). - Bold(true). - Foreground(styles.Blue). - Render(fmt.Sprintf("%s %s", styles.ToolIcon, toolCall.Name)) - - var paramLines []string - var args map[string]interface{} - var paramOrder []string - - json.Unmarshal([]byte(toolCall.Input), &args) - - for key := range args { - paramOrder = append(paramOrder, key) - } - sort.Strings(paramOrder) - - for _, name := range paramOrder { - value := args[name] - paramName := lipgloss.NewStyle(). - Foreground(styles.Peach). - Bold(true). - Render(name) - - truncate := m.width - leftPaddingValue*2 - 10 - if len(fmt.Sprintf("%v", value)) > truncate { - value = fmt.Sprintf("%v", value)[:truncate] + lipgloss.NewStyle().Foreground(styles.Blue).Render("... (truncated)") - } - paramValue := fmt.Sprintf("%v", value) - paramLines = append(paramLines, fmt.Sprintf(" %s: %s", paramName, paramValue)) - } - - paramBlock := lipgloss.JoinVertical(lipgloss.Left, paramLines...) - - toolContent := lipgloss.JoinVertical(lipgloss.Left, toolHeader, paramBlock) - return toolCallStyle.Render(toolContent) - } - - findToolResult := func(toolCallID string, messages []message.Message) *message.ToolResult { - for _, msg := range messages { - if msg.Role == message.Tool { - for _, result := range msg.ToolResults() { - if result.ToolCallID == toolCallID { - return &result - } - } - } - } - return nil - } - - renderToolResult := func(result message.ToolResult) string { - resultHeader := lipgloss.NewStyle(). - Bold(true). - Foreground(styles.Green). - Render(fmt.Sprintf("%s Result", styles.CheckIcon)) - - // Use the same style for both header and border if it's an error - borderColor := styles.Green - if result.IsError { - resultHeader = lipgloss.NewStyle(). - Bold(true). - Foreground(styles.Red). - Render(fmt.Sprintf("%s Error", styles.ErrorIcon)) - borderColor = styles.Red - } - - truncate := 200 - content := result.Content - if len(content) > truncate { - content = content[:truncate] + lipgloss.NewStyle().Foreground(styles.Blue).Render("... (truncated)") - } - - resultContent := lipgloss.JoinVertical(lipgloss.Left, resultHeader, content) - return toolResultStyle.BorderForeground(borderColor).Render(resultContent) - } - - connector := connectorStyle.Render("└─> Tool Calls:") - allParts = append(allParts, connector) - - for _, toolCall := range tools { - toolOutput := renderTool(toolCall) - allParts = append(allParts, leftPadding.Render(toolOutput)) - - result := findToolResult(toolCall.ID, futureMessages) - if result != nil { - - resultOutput := renderToolResult(*result) - allParts = append(allParts, leftPadding.Render(resultOutput)) - - } else if toolCall.Name == agent.AgentToolName { - - runningIndicator := runningStyle.Render(fmt.Sprintf("%s Running...", styles.SpinnerIcon)) - allParts = append(allParts, leftPadding.Render(runningIndicator)) - taskSessionMessages, _ := m.app.Messages.List(context.Background(), toolCall.ID) - for _, msg := range taskSessionMessages { - if msg.Role == message.Assistant { - for _, toolCall := range msg.ToolCalls() { - toolHeader := lipgloss.NewStyle(). - Bold(true). - Foreground(styles.Blue). - Render(fmt.Sprintf("%s %s", styles.ToolIcon, toolCall.Name)) - - var paramLines []string - var args map[string]interface{} - var paramOrder []string - - json.Unmarshal([]byte(toolCall.Input), &args) - - for key := range args { - paramOrder = append(paramOrder, key) - } - sort.Strings(paramOrder) - - for _, name := range paramOrder { - value := args[name] - paramName := lipgloss.NewStyle(). - Foreground(styles.Peach). - Bold(true). - Render(name) - - truncate := 50 - if len(fmt.Sprintf("%v", value)) > truncate { - value = fmt.Sprintf("%v", value)[:truncate] + lipgloss.NewStyle().Foreground(styles.Blue).Render("... (truncated)") - } - paramValue := fmt.Sprintf("%v", value) - paramLines = append(paramLines, fmt.Sprintf(" %s: %s", paramName, paramValue)) - } - - paramBlock := lipgloss.JoinVertical(lipgloss.Left, paramLines...) - toolContent := lipgloss.JoinVertical(lipgloss.Left, toolHeader, paramBlock) - toolOutput := toolCallStyle.BorderForeground(styles.Teal).MaxWidth(m.width - leftPaddingValue*2 - 2).Render(toolContent) - allParts = append(allParts, lipgloss.NewStyle().Padding(0, 0, 0, leftPaddingValue*2).Render(toolOutput)) - } - } - } - - } else { - runningIndicator := runningStyle.Render(fmt.Sprintf("%s Running...", styles.SpinnerIcon)) - allParts = append(allParts, " "+runningIndicator) - } - } - - for _, msg := range futureMessages { - if msg.Content().String() != "" || msg.FinishReason() == "canceled" { - break - } - - for _, toolCall := range msg.ToolCalls() { - toolOutput := renderTool(toolCall) - allParts = append(allParts, " "+strings.ReplaceAll(toolOutput, "\n", "\n ")) - - result := findToolResult(toolCall.ID, futureMessages) - if result != nil { - resultOutput := renderToolResult(*result) - allParts = append(allParts, " "+strings.ReplaceAll(resultOutput, "\n", "\n ")) - } else { - runningIndicator := runningStyle.Render(fmt.Sprintf("%s Running...", styles.SpinnerIcon)) - allParts = append(allParts, " "+runningIndicator) - } - } - } - - return lipgloss.JoinVertical(lipgloss.Left, allParts...) -} - -func (m *messagesCmp) renderView() { - stringMessages := make([]string, 0) - r, _ := glamour.NewTermRenderer( - glamour.WithStyles(styles.CatppuccinMarkdownStyle()), - glamour.WithWordWrap(m.width-20), - glamour.WithEmoji(), - ) - textStyle := lipgloss.NewStyle().Width(m.width - 4) - currentMessage := 1 - displayedMsgCount := 0 // Track the actual displayed messages count - - prevMessageWasUser := false - for inx, msg := range m.messages { - content := msg.Content().String() - if content != "" || prevMessageWasUser || msg.FinishReason() == "canceled" { - if msg.ReasoningContent().String() != "" && content == "" { - content = msg.ReasoningContent().String() - } else if content == "" { - content = "..." - } - if msg.FinishReason() == "canceled" { - content, _ = r.Render(content) - content += lipgloss.NewStyle().Padding(1, 0, 0, 1).Foreground(styles.Error).Render(styles.ErrorIcon + " Canceled") - } else { - content, _ = r.Render(content) - } - - isSelected := inx == m.selectedMsgIdx - - border := lipgloss.DoubleBorder() - activeColor := borderColor(msg.Role) - - if isSelected { - activeColor = styles.Primary // Use primary color for selected message - } - - content = layout.Borderize( - textStyle.Render(content), - layout.BorderOptions{ - InactiveBorder: border, - ActiveBorder: border, - ActiveColor: activeColor, - InactiveColor: borderColor(msg.Role), - EmbeddedText: borderText(msg.Role, currentMessage), - }, - ) - if len(msg.ToolCalls()) > 0 { - content = m.renderMessageWithToolCall(content, msg.ToolCalls(), m.messages[inx+1:]) - } - stringMessages = append(stringMessages, content) - currentMessage++ - displayedMsgCount++ - } - if msg.Role == message.User && msg.Content().String() != "" { - prevMessageWasUser = true - } else { - prevMessageWasUser = false - } - } - m.viewport.SetContent(lipgloss.JoinVertical(lipgloss.Top, stringMessages...)) -} - -func (m *messagesCmp) View() string { - return lipgloss.NewStyle().Padding(1).Render(m.viewport.View()) -} - -func (m *messagesCmp) BindingKeys() []key.Binding { - keys := layout.KeyMapToSlice(m.viewport.KeyMap) - - return keys -} - -func (m *messagesCmp) Blur() tea.Cmd { - m.focused = false - return nil -} - -func (m *messagesCmp) projectDiagnostics() string { - errorDiagnostics := []protocol.Diagnostic{} - warnDiagnostics := []protocol.Diagnostic{} - hintDiagnostics := []protocol.Diagnostic{} - infoDiagnostics := []protocol.Diagnostic{} - for _, client := range m.app.LSPClients { - for _, d := range client.GetDiagnostics() { - for _, diag := range d { - switch diag.Severity { - case protocol.SeverityError: - errorDiagnostics = append(errorDiagnostics, diag) - case protocol.SeverityWarning: - warnDiagnostics = append(warnDiagnostics, diag) - case protocol.SeverityHint: - hintDiagnostics = append(hintDiagnostics, diag) - case protocol.SeverityInformation: - infoDiagnostics = append(infoDiagnostics, diag) - } - } - } - } - - if len(errorDiagnostics) == 0 && len(warnDiagnostics) == 0 && len(hintDiagnostics) == 0 && len(infoDiagnostics) == 0 { - return "No diagnostics" - } - - diagnostics := []string{} - - if len(errorDiagnostics) > 0 { - errStr := lipgloss.NewStyle().Foreground(styles.Error).Render(fmt.Sprintf("%s %d", styles.ErrorIcon, len(errorDiagnostics))) - diagnostics = append(diagnostics, errStr) - } - if len(warnDiagnostics) > 0 { - warnStr := lipgloss.NewStyle().Foreground(styles.Warning).Render(fmt.Sprintf("%s %d", styles.WarningIcon, len(warnDiagnostics))) - diagnostics = append(diagnostics, warnStr) - } - if len(hintDiagnostics) > 0 { - hintStr := lipgloss.NewStyle().Foreground(styles.Text).Render(fmt.Sprintf("%s %d", styles.HintIcon, len(hintDiagnostics))) - diagnostics = append(diagnostics, hintStr) - } - if len(infoDiagnostics) > 0 { - infoStr := lipgloss.NewStyle().Foreground(styles.Peach).Render(fmt.Sprintf("%s %d", styles.InfoIcon, len(infoDiagnostics))) - diagnostics = append(diagnostics, infoStr) - } - - return strings.Join(diagnostics, " ") -} - -func (m *messagesCmp) BorderText() map[layout.BorderPosition]string { - title := m.session.Title - titleWidth := m.width / 2 - if len(title) > titleWidth { - title = title[:titleWidth] + "..." - } - if m.focused { - title = lipgloss.NewStyle().Foreground(styles.Primary).Render(title) - } - borderTest := map[layout.BorderPosition]string{ - layout.TopLeftBorder: title, - layout.BottomRightBorder: m.projectDiagnostics(), - } - if hasUnfinishedMessages(m.messages) { - borderTest[layout.BottomLeftBorder] = lipgloss.NewStyle().Foreground(styles.Peach).Render("Thinking...") - } else { - borderTest[layout.BottomLeftBorder] = lipgloss.NewStyle().Foreground(styles.Text).Render("Sleeping " + styles.SleepIcon + " ") - } - - return borderTest -} - -func (m *messagesCmp) Focus() tea.Cmd { - m.focused = true - return nil -} - -func (m *messagesCmp) GetSize() (int, int) { - return m.width, m.height -} - -func (m *messagesCmp) IsFocused() bool { - return m.focused -} - -func (m *messagesCmp) SetSize(width int, height int) { - m.width = width - m.height = height - m.viewport.Width = width - 2 // padding - m.viewport.Height = height - 2 // padding - m.renderView() -} - -func (m *messagesCmp) Init() tea.Cmd { - return nil -} - -func NewMessagesCmp(app *app.App) MessagesCmp { - return &messagesCmp{ - app: app, - messages: []message.Message{}, - viewport: viewport.New(0, 0), - } -} diff --git a/internal/tui/components/repl/sessions.go b/internal/tui/components/repl/sessions.go deleted file mode 100644 index c83c4036728138675bc08313454743245723b359..0000000000000000000000000000000000000000 --- a/internal/tui/components/repl/sessions.go +++ /dev/null @@ -1,249 +0,0 @@ -package repl - -import ( - "context" - "fmt" - "strings" - - "github.com/charmbracelet/bubbles/key" - "github.com/charmbracelet/bubbles/list" - tea "github.com/charmbracelet/bubbletea" - "github.com/charmbracelet/lipgloss" - "github.com/kujtimiihoxha/termai/internal/app" - "github.com/kujtimiihoxha/termai/internal/pubsub" - "github.com/kujtimiihoxha/termai/internal/session" - "github.com/kujtimiihoxha/termai/internal/tui/layout" - "github.com/kujtimiihoxha/termai/internal/tui/styles" - "github.com/kujtimiihoxha/termai/internal/tui/util" -) - -type SessionsCmp interface { - tea.Model - layout.Sizeable - layout.Focusable - layout.Bordered - layout.Bindings -} -type sessionsCmp struct { - app *app.App - list list.Model - focused bool -} - -type listItem struct { - id, title, desc string -} - -func (i listItem) Title() string { return i.title } -func (i listItem) Description() string { return i.desc } -func (i listItem) FilterValue() string { return i.title } - -type InsertSessionsMsg struct { - sessions []session.Session -} - -type SelectedSessionMsg struct { - SessionID string -} - -type sessionsKeyMap struct { - Select key.Binding -} - -var sessionKeyMapValue = sessionsKeyMap{ - Select: key.NewBinding( - key.WithKeys("enter", " "), - key.WithHelp("enter/space", "select session"), - ), -} - -func (i *sessionsCmp) Init() tea.Cmd { - existing, err := i.app.Sessions.List(context.Background()) - if err != nil { - return util.ReportError(err) - } - if len(existing) == 0 || existing[0].MessageCount > 0 { - newSession, err := i.app.Sessions.Create( - context.Background(), - "New Session", - ) - if err != nil { - return util.ReportError(err) - } - existing = append([]session.Session{newSession}, existing...) - } - return tea.Batch( - util.CmdHandler(InsertSessionsMsg{existing}), - util.CmdHandler(SelectedSessionMsg{existing[0].ID}), - ) -} - -func (i *sessionsCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) { - switch msg := msg.(type) { - case InsertSessionsMsg: - items := make([]list.Item, len(msg.sessions)) - for i, s := range msg.sessions { - items[i] = listItem{ - id: s.ID, - title: s.Title, - desc: formatTokensAndCost(s.PromptTokens+s.CompletionTokens, s.Cost), - } - } - return i, i.list.SetItems(items) - case pubsub.Event[session.Session]: - if msg.Type == pubsub.CreatedEvent && msg.Payload.ParentSessionID == "" { - // Check if the session is already in the list - items := i.list.Items() - for _, item := range items { - s := item.(listItem) - if s.id == msg.Payload.ID { - return i, nil - } - } - // insert the new session at the top of the list - items = append([]list.Item{listItem{ - id: msg.Payload.ID, - title: msg.Payload.Title, - desc: formatTokensAndCost(msg.Payload.PromptTokens+msg.Payload.CompletionTokens, msg.Payload.Cost), - }}, items...) - return i, i.list.SetItems(items) - } else if msg.Type == pubsub.UpdatedEvent { - // update the session in the list - items := i.list.Items() - for idx, item := range items { - s := item.(listItem) - if s.id == msg.Payload.ID { - s.title = msg.Payload.Title - s.desc = formatTokensAndCost(msg.Payload.PromptTokens+msg.Payload.CompletionTokens, msg.Payload.Cost) - items[idx] = s - break - } - } - return i, i.list.SetItems(items) - } - - case tea.KeyMsg: - switch { - case key.Matches(msg, sessionKeyMapValue.Select): - selected := i.list.SelectedItem() - if selected == nil { - return i, nil - } - return i, util.CmdHandler(SelectedSessionMsg{selected.(listItem).id}) - } - } - if i.focused { - u, cmd := i.list.Update(msg) - i.list = u - return i, cmd - } - return i, nil -} - -func (i *sessionsCmp) View() string { - return i.list.View() -} - -func (i *sessionsCmp) Blur() tea.Cmd { - i.focused = false - return nil -} - -func (i *sessionsCmp) Focus() tea.Cmd { - i.focused = true - return nil -} - -func (i *sessionsCmp) GetSize() (int, int) { - return i.list.Width(), i.list.Height() -} - -func (i *sessionsCmp) IsFocused() bool { - return i.focused -} - -func (i *sessionsCmp) SetSize(width int, height int) { - i.list.SetSize(width, height) -} - -func (i *sessionsCmp) BorderText() map[layout.BorderPosition]string { - totalCount := len(i.list.Items()) - itemsPerPage := i.list.Paginator.PerPage - currentPage := i.list.Paginator.Page - - current := min(currentPage*itemsPerPage+itemsPerPage, totalCount) - - pageInfo := fmt.Sprintf( - "%d-%d of %d", - currentPage*itemsPerPage+1, - current, - totalCount, - ) - - title := "Sessions" - if i.focused { - title = lipgloss.NewStyle().Foreground(styles.Primary).Render(title) - } - return map[layout.BorderPosition]string{ - layout.TopMiddleBorder: title, - layout.BottomMiddleBorder: pageInfo, - } -} - -func (i *sessionsCmp) BindingKeys() []key.Binding { - return append(layout.KeyMapToSlice(i.list.KeyMap), sessionKeyMapValue.Select) -} - -func formatTokensAndCost(tokens int64, cost float64) string { - // Format tokens in human-readable format (e.g., 110K, 1.2M) - var formattedTokens string - switch { - case tokens >= 1_000_000: - formattedTokens = fmt.Sprintf("%.1fM", float64(tokens)/1_000_000) - case tokens >= 1_000: - formattedTokens = fmt.Sprintf("%.1fK", float64(tokens)/1_000) - default: - formattedTokens = fmt.Sprintf("%d", tokens) - } - - // Remove .0 suffix if present - if strings.HasSuffix(formattedTokens, ".0K") { - formattedTokens = strings.Replace(formattedTokens, ".0K", "K", 1) - } - if strings.HasSuffix(formattedTokens, ".0M") { - formattedTokens = strings.Replace(formattedTokens, ".0M", "M", 1) - } - - // Format cost with $ symbol and 2 decimal places - formattedCost := fmt.Sprintf("$%.2f", cost) - - return fmt.Sprintf("Tokens: %s, Cost: %s", formattedTokens, formattedCost) -} - -func NewSessionsCmp(app *app.App) SessionsCmp { - listDelegate := list.NewDefaultDelegate() - defaultItemStyle := list.NewDefaultItemStyles() - defaultItemStyle.SelectedTitle = defaultItemStyle.SelectedTitle.BorderForeground(styles.Secondary).Foreground(styles.Primary) - defaultItemStyle.SelectedDesc = defaultItemStyle.SelectedDesc.BorderForeground(styles.Secondary).Foreground(styles.Primary) - - defaultStyle := list.DefaultStyles() - defaultStyle.FilterPrompt = defaultStyle.FilterPrompt.Foreground(styles.Secondary) - defaultStyle.FilterCursor = defaultStyle.FilterCursor.Foreground(styles.Flamingo) - - listDelegate.Styles = defaultItemStyle - - listComponent := list.New([]list.Item{}, listDelegate, 0, 0) - listComponent.FilterInput.PromptStyle = defaultStyle.FilterPrompt - listComponent.FilterInput.Cursor.Style = defaultStyle.FilterCursor - listComponent.SetShowTitle(false) - listComponent.SetShowPagination(false) - listComponent.SetShowHelp(false) - listComponent.SetShowStatusBar(false) - listComponent.DisableQuitKeybindings() - - return &sessionsCmp{ - app: app, - list: listComponent, - focused: false, - } -} diff --git a/internal/tui/layout/overlay.go b/internal/tui/layout/overlay.go index 22f9e00fe0594b749b2b489b1831f4a6cb6571cf..4a1bcf661ad14ec7bfb186fce308dc1f3d255a97 100644 --- a/internal/tui/layout/overlay.go +++ b/internal/tui/layout/overlay.go @@ -5,6 +5,7 @@ import ( "strings" "github.com/charmbracelet/lipgloss" + "github.com/kujtimiihoxha/termai/internal/tui/styles" "github.com/kujtimiihoxha/termai/internal/tui/util" "github.com/mattn/go-runewidth" "github.com/muesli/ansi" @@ -45,13 +46,15 @@ func PlaceOverlay( if shadow { var shadowbg string = "" shadowchar := lipgloss.NewStyle(). - Foreground(lipgloss.Color("#333333")). + Background(styles.BackgroundDarker). + Foreground(styles.Background). Render("░") + bgchar := styles.BaseStyle.Render(" ") for i := 0; i <= fgHeight; i++ { if i == 0 { - shadowbg += " " + strings.Repeat(" ", fgWidth) + "\n" + shadowbg += bgchar + strings.Repeat(bgchar, fgWidth) + "\n" } else { - shadowbg += " " + strings.Repeat(shadowchar, fgWidth) + "\n" + shadowbg += bgchar + strings.Repeat(shadowchar, fgWidth) + "\n" } } @@ -159,8 +162,6 @@ func max(a, b int) int { return b } - - type whitespace struct { style termenv.Style chars string diff --git a/internal/tui/layout/split.go b/internal/tui/layout/split.go index 0ed85dd6fc20f8859b550c9f4ead13b5876b8eec..6482fc74cebc10149548543eccd4fed67e32db44 100644 --- a/internal/tui/layout/split.go +++ b/internal/tui/layout/split.go @@ -10,6 +10,7 @@ import ( type SplitPaneLayout interface { tea.Model Sizeable + Bindings SetLeftPanel(panel Container) SetRightPanel(panel Container) SetBottomPanel(panel Container) diff --git a/internal/tui/page/chat.go b/internal/tui/page/chat.go index 439c89e1f683cb6296266db465bdd5237548f692..cebc0e4610ca696c1fefb8b10c271462e0fcbbb3 100644 --- a/internal/tui/page/chat.go +++ b/internal/tui/page/chat.go @@ -37,7 +37,6 @@ var keyMap = ChatKeyMap{ } func (p *chatPage) Init() tea.Cmd { - // TODO: remove cmds := []tea.Cmd{ p.layout.Init(), } @@ -48,9 +47,7 @@ func (p *chatPage) Init() tea.Cmd { cmd := p.setSidebar() cmds = append(cmds, util.CmdHandler(chat.SessionSelectedMsg(p.session)), cmd) } - return tea.Batch( - cmds..., - ) + return tea.Batch(cmds...) } func (p *chatPage) Update(msg tea.Msg) (tea.Model, tea.Cmd) { @@ -68,6 +65,13 @@ func (p *chatPage) Update(msg tea.Msg) (tea.Model, tea.Cmd) { p.session = session.Session{} p.clearSidebar() return p, util.CmdHandler(chat.SessionClearedMsg{}) + case key.Matches(msg, keyMap.Cancel): + if p.session.ID != "" { + // Cancel the current session's generation process + // This allows users to interrupt long-running operations + p.app.CoderAgent.Cancel(p.session.ID) + return p, nil + } } } u, cmd := p.layout.Update(msg) @@ -80,7 +84,7 @@ func (p *chatPage) Update(msg tea.Msg) (tea.Model, tea.Cmd) { func (p *chatPage) setSidebar() tea.Cmd { sidebarContainer := layout.NewContainer( - chat.NewSidebarCmp(p.session), + chat.NewSidebarCmp(p.session, p.app.History), layout.WithPadding(1, 1, 1, 1), ) p.layout.SetRightPanel(sidebarContainer) @@ -111,14 +115,28 @@ func (p *chatPage) sendMessage(text string) tea.Cmd { cmds = append(cmds, util.CmdHandler(chat.SessionSelectedMsg(session))) } - p.app.CoderAgent.Generate(context.Background(), p.session.ID, text) + p.app.CoderAgent.Run(context.Background(), p.session.ID, text) return tea.Batch(cmds...) } +func (p *chatPage) SetSize(width, height int) { + p.layout.SetSize(width, height) +} + +func (p *chatPage) GetSize() (int, int) { + return p.layout.GetSize() +} + func (p *chatPage) View() string { return p.layout.View() } +func (p *chatPage) BindingKeys() []key.Binding { + bindings := layout.KeyMapToSlice(keyMap) + bindings = append(bindings, p.layout.BindingKeys()...) + return bindings +} + func NewChatPage(app *app.App) tea.Model { messagesContainer := layout.NewContainer( chat.NewMessagesCmp(app), @@ -126,7 +144,7 @@ func NewChatPage(app *app.App) tea.Model { ) editorContainer := layout.NewContainer( - chat.NewEditorCmp(), + chat.NewEditorCmp(app), layout.WithBorder(true, false, false, false), ) return &chatPage{ diff --git a/internal/tui/page/init.go b/internal/tui/page/init.go deleted file mode 100644 index 0a5c6f82a522300e52e0a33676ce9eee71796072..0000000000000000000000000000000000000000 --- a/internal/tui/page/init.go +++ /dev/null @@ -1,308 +0,0 @@ -package page - -import ( - "fmt" - "os" - "path/filepath" - "strconv" - - "github.com/charmbracelet/bubbles/key" - tea "github.com/charmbracelet/bubbletea" - "github.com/charmbracelet/huh" - "github.com/charmbracelet/lipgloss" - "github.com/kujtimiihoxha/termai/internal/llm/models" - "github.com/kujtimiihoxha/termai/internal/tui/layout" - "github.com/kujtimiihoxha/termai/internal/tui/styles" - "github.com/kujtimiihoxha/termai/internal/tui/util" - "github.com/spf13/viper" -) - -var InitPage PageID = "init" - -type configSaved struct{} - -type initPage struct { - form *huh.Form - width int - height int - saved bool - errorMsg string - statusMsg string - modelOpts []huh.Option[string] - bigModel string - smallModel string - openAIKey string - anthropicKey string - groqKey string - maxTokens string - dataDir string - agent string -} - -func (i *initPage) Init() tea.Cmd { - return i.form.Init() -} - -func (i *initPage) Update(msg tea.Msg) (tea.Model, tea.Cmd) { - var cmds []tea.Cmd - - switch msg := msg.(type) { - case tea.WindowSizeMsg: - i.width = msg.Width - 4 // Account for border - i.height = msg.Height - 4 - i.form = i.form.WithWidth(i.width).WithHeight(i.height) - return i, nil - - case configSaved: - i.saved = true - i.statusMsg = "Configuration saved successfully. Press any key to continue." - return i, nil - } - - if i.saved { - switch msg.(type) { - case tea.KeyMsg: - return i, util.CmdHandler(PageChangeMsg{ID: ReplPage}) - } - return i, nil - } - - // Process the form - form, cmd := i.form.Update(msg) - if f, ok := form.(*huh.Form); ok { - i.form = f - cmds = append(cmds, cmd) - } - - if i.form.State == huh.StateCompleted { - // Save configuration to file - configPath := filepath.Join(os.Getenv("HOME"), ".termai.yaml") - maxTokens, _ := strconv.Atoi(i.maxTokens) - config := map[string]any{ - "models": map[string]string{ - "big": i.bigModel, - "small": i.smallModel, - }, - "providers": map[string]any{ - "openai": map[string]string{ - "key": i.openAIKey, - }, - "anthropic": map[string]string{ - "key": i.anthropicKey, - }, - "groq": map[string]string{ - "key": i.groqKey, - }, - "common": map[string]int{ - "max_tokens": maxTokens, - }, - }, - "data": map[string]string{ - "dir": i.dataDir, - }, - "agents": map[string]string{ - "default": i.agent, - }, - "log": map[string]string{ - "level": "info", - }, - } - - // Write config to viper - for k, v := range config { - viper.Set(k, v) - } - - // Save configuration - err := viper.WriteConfigAs(configPath) - if err != nil { - i.errorMsg = fmt.Sprintf("Failed to save configuration: %s", err) - return i, nil - } - - // Return to main page - return i, util.CmdHandler(configSaved{}) - } - - return i, tea.Batch(cmds...) -} - -func (i *initPage) View() string { - if i.saved { - return lipgloss.NewStyle(). - Width(i.width). - Height(i.height). - Align(lipgloss.Center, lipgloss.Center). - Render(lipgloss.JoinVertical( - lipgloss.Center, - lipgloss.NewStyle().Foreground(styles.Green).Render("✓ Configuration Saved"), - "", - lipgloss.NewStyle().Foreground(styles.Blue).Render(i.statusMsg), - )) - } - - view := i.form.View() - if i.errorMsg != "" { - errorBox := lipgloss.NewStyle(). - Padding(1). - Border(lipgloss.RoundedBorder()). - BorderForeground(styles.Red). - Width(i.width - 4). - Render(i.errorMsg) - view = lipgloss.JoinVertical(lipgloss.Left, errorBox, view) - } - return view -} - -func (i *initPage) GetSize() (int, int) { - return i.width, i.height -} - -func (i *initPage) SetSize(width int, height int) { - i.width = width - i.height = height - i.form = i.form.WithWidth(width).WithHeight(height) -} - -func (i *initPage) BindingKeys() []key.Binding { - if i.saved { - return []key.Binding{ - key.NewBinding( - key.WithKeys("enter", "space", "esc"), - key.WithHelp("any key", "continue"), - ), - } - } - return i.form.KeyBinds() -} - -func NewInitPage() tea.Model { - // Create model options - var modelOpts []huh.Option[string] - for id, model := range models.SupportedModels { - modelOpts = append(modelOpts, huh.NewOption(model.Name, string(id))) - } - - // Create agent options - agentOpts := []huh.Option[string]{ - huh.NewOption("Coder", "coder"), - huh.NewOption("Assistant", "assistant"), - } - - // Init page with form - initModel := &initPage{ - modelOpts: modelOpts, - bigModel: string(models.Claude37Sonnet), - smallModel: string(models.Claude37Sonnet), - maxTokens: "4000", - dataDir: ".termai", - agent: "coder", - } - - // API Keys group - apiKeysGroup := huh.NewGroup( - huh.NewNote(). - Title("API Keys"). - Description("You need to provide at least one API key to use termai"), - - huh.NewInput(). - Title("OpenAI API Key"). - Placeholder("sk-..."). - Key("openai_key"). - Value(&initModel.openAIKey), - - huh.NewInput(). - Title("Anthropic API Key"). - Placeholder("sk-ant-..."). - Key("anthropic_key"). - Value(&initModel.anthropicKey), - - huh.NewInput(). - Title("Groq API Key"). - Placeholder("gsk_..."). - Key("groq_key"). - Value(&initModel.groqKey), - ) - - // Model configuration group - modelsGroup := huh.NewGroup( - huh.NewNote(). - Title("Model Configuration"). - Description("Select which models to use"), - - huh.NewSelect[string](). - Title("Big Model"). - Options(modelOpts...). - Key("big_model"). - Value(&initModel.bigModel), - - huh.NewSelect[string](). - Title("Small Model"). - Options(modelOpts...). - Key("small_model"). - Value(&initModel.smallModel), - - huh.NewInput(). - Title("Max Tokens"). - Placeholder("4000"). - Key("max_tokens"). - CharLimit(5). - Validate(func(s string) error { - var n int - _, err := fmt.Sscanf(s, "%d", &n) - if err != nil || n <= 0 { - return fmt.Errorf("must be a positive number") - } - initModel.maxTokens = s - return nil - }). - Value(&initModel.maxTokens), - ) - - // General settings group - generalGroup := huh.NewGroup( - huh.NewNote(). - Title("General Settings"). - Description("Configure general termai settings"), - - huh.NewInput(). - Title("Data Directory"). - Placeholder(".termai"). - Key("data_dir"). - Value(&initModel.dataDir), - - huh.NewSelect[string](). - Title("Default Agent"). - Options(agentOpts...). - Key("agent"). - Value(&initModel.agent), - - huh.NewConfirm(). - Title("Save Configuration"). - Affirmative("Save"). - Negative("Cancel"), - ) - - // Create form with theme - form := huh.NewForm( - apiKeysGroup, - modelsGroup, - generalGroup, - ).WithTheme(styles.HuhTheme()). - WithShowHelp(true). - WithShowErrors(true) - - // Set the form in the model - initModel.form = form - - return layout.NewSinglePane( - initModel, - layout.WithSinglePaneFocusable(true), - layout.WithSinglePaneBordered(true), - layout.WithSinglePaneBorderText( - map[layout.BorderPosition]string{ - layout.TopMiddleBorder: "Welcome to termai - Initial Setup", - }, - ), - ) -} diff --git a/internal/tui/page/logs.go b/internal/tui/page/logs.go index 12afaf6aa6130919b12ce380b46be67020784bdb..d1e557eab548673ac76ba59cdc716a7dc2410c23 100644 --- a/internal/tui/page/logs.go +++ b/internal/tui/page/logs.go @@ -8,6 +8,23 @@ import ( var LogsPage PageID = "logs" +type logsPage struct { + table logs.TableComponent + details logs.DetailComponent +} + +func (p *logsPage) Init() tea.Cmd { + return nil +} + +func (p *logsPage) Update(msg tea.Msg) (tea.Model, tea.Cmd) { + return p, nil +} + +func (p *logsPage) View() string { + return p.table.View() + "\n" + p.details.View() +} + func NewLogsPage() tea.Model { return layout.NewBentoLayout( layout.BentoPanes{ diff --git a/internal/tui/page/repl.go b/internal/tui/page/repl.go deleted file mode 100644 index 47a924b7b1ef22c107468b350bed47263971b2d1..0000000000000000000000000000000000000000 --- a/internal/tui/page/repl.go +++ /dev/null @@ -1,21 +0,0 @@ -package page - -import ( - tea "github.com/charmbracelet/bubbletea" - "github.com/kujtimiihoxha/termai/internal/app" - "github.com/kujtimiihoxha/termai/internal/tui/components/repl" - "github.com/kujtimiihoxha/termai/internal/tui/layout" -) - -var ReplPage PageID = "repl" - -func NewReplPage(app *app.App) tea.Model { - return layout.NewBentoLayout( - layout.BentoPanes{ - layout.BentoLeftPane: repl.NewSessionsCmp(app), - layout.BentoRightTopPane: repl.NewMessagesCmp(app), - layout.BentoRightBottomPane: repl.NewEditorCmp(app), - }, - layout.WithBentoLayoutCurrentPane(layout.BentoRightBottomPane), - ) -} diff --git a/internal/tui/tui.go b/internal/tui/tui.go index 1b1a1ed50f97943f2388326583bab9798b55b9c1..dff7ad63d11e61463b659bea70f61d79b57c4e1c 100644 --- a/internal/tui/tui.go +++ b/internal/tui/tui.go @@ -1,8 +1,6 @@ package tui import ( - "context" - "github.com/charmbracelet/bubbles/key" tea "github.com/charmbracelet/bubbletea" "github.com/charmbracelet/lipgloss" @@ -12,47 +10,41 @@ import ( "github.com/kujtimiihoxha/termai/internal/pubsub" "github.com/kujtimiihoxha/termai/internal/tui/components/core" "github.com/kujtimiihoxha/termai/internal/tui/components/dialog" - "github.com/kujtimiihoxha/termai/internal/tui/components/repl" "github.com/kujtimiihoxha/termai/internal/tui/layout" "github.com/kujtimiihoxha/termai/internal/tui/page" "github.com/kujtimiihoxha/termai/internal/tui/util" - "github.com/kujtimiihoxha/vimtea" ) type keyMap struct { - Logs key.Binding - Return key.Binding - Back key.Binding - Quit key.Binding - Help key.Binding + Logs key.Binding + Quit key.Binding + Help key.Binding } var keys = keyMap{ Logs: key.NewBinding( - key.WithKeys("L"), - key.WithHelp("L", "logs"), - ), - Return: key.NewBinding( - key.WithKeys("esc"), - key.WithHelp("esc", "close"), - ), - Back: key.NewBinding( - key.WithKeys("backspace"), - key.WithHelp("backspace", "back"), + key.WithKeys("ctrl+l"), + key.WithHelp("ctrl+L", "logs"), ), + Quit: key.NewBinding( - key.WithKeys("ctrl+c", "q"), - key.WithHelp("ctrl+c/q", "quit"), + key.WithKeys("ctrl+c"), + key.WithHelp("ctrl+c", "quit"), ), Help: key.NewBinding( - key.WithKeys("?"), - key.WithHelp("?", "toggle help"), + key.WithKeys("ctrl+_"), + key.WithHelp("ctrl+?", "toggle help"), ), } -var replKeyMap = key.NewBinding( - key.WithKeys("N"), - key.WithHelp("N", "new session"), +var returnKey = key.NewBinding( + key.WithKeys("esc"), + key.WithHelp("esc", "close"), +) + +var logsKeyReturnKey = key.NewBinding( + key.WithKeys("backspace"), + key.WithHelp("backspace", "go back"), ) type appModel struct { @@ -62,18 +54,30 @@ type appModel struct { pages map[page.PageID]tea.Model loadedPages map[page.PageID]bool status tea.Model - help core.HelpCmp - dialog core.DialogCmp app *app.App - dialogVisible bool - editorMode vimtea.EditorMode - showHelp bool + + showPermissions bool + permissions dialog.PermissionDialogCmp + + showHelp bool + help dialog.HelpCmp + + showQuit bool + quit dialog.QuitDialog } func (a appModel) Init() tea.Cmd { + var cmds []tea.Cmd cmd := a.pages[a.currentPage].Init() a.loadedPages[a.currentPage] = true - return cmd + cmds = append(cmds, cmd) + cmd = a.status.Init() + cmds = append(cmds, cmd) + cmd = a.quit.Init() + cmds = append(cmds, cmd) + cmd = a.help.Init() + cmds = append(cmds, cmd) + return tea.Batch(cmds...) } func (a appModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) { @@ -81,22 +85,20 @@ func (a appModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) { var cmd tea.Cmd switch msg := msg.(type) { case tea.WindowSizeMsg: - var cmds []tea.Cmd msg.Height -= 1 // Make space for the status bar a.width, a.height = msg.Width, msg.Height a.status, _ = a.status.Update(msg) - - uh, _ := a.help.Update(msg) - a.help = uh.(core.HelpCmp) - - p, cmd := a.pages[a.currentPage].Update(msg) + a.pages[a.currentPage], cmd = a.pages[a.currentPage].Update(msg) cmds = append(cmds, cmd) - a.pages[a.currentPage] = p - d, cmd := a.dialog.Update(msg) - cmds = append(cmds, cmd) - a.dialog = d.(core.DialogCmp) + prm, permCmd := a.permissions.Update(msg) + a.permissions = prm.(dialog.PermissionDialogCmp) + cmds = append(cmds, permCmd) + + help, helpCmd := a.help.Update(msg) + a.help = help.(dialog.HelpCmp) + cmds = append(cmds, helpCmd) return a, tea.Batch(cmds...) @@ -141,7 +143,9 @@ func (a appModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) { // Permission case pubsub.Event[permission.PermissionRequest]: - return a, dialog.NewPermissionDialogCmd(msg.Payload) + a.showPermissions = true + a.permissions.SetPermissions(msg.Payload) + return a, nil case dialog.PermissionResponseMsg: switch msg.Action { case dialog.PermissionAllow: @@ -151,91 +155,71 @@ func (a appModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) { case dialog.PermissionDeny: a.app.Permissions.Deny(msg.Permission) } - - // Dialog - case core.DialogMsg: - d, cmd := a.dialog.Update(msg) - a.dialog = d.(core.DialogCmp) - a.dialogVisible = true - return a, cmd - case core.DialogCloseMsg: - d, cmd := a.dialog.Update(msg) - a.dialog = d.(core.DialogCmp) - a.dialogVisible = false - return a, cmd - - // Editor - case vimtea.EditorModeMsg: - a.editorMode = msg.Mode + a.showPermissions = false + return a, nil case page.PageChangeMsg: return a, a.moveToPage(msg.ID) + + case dialog.CloseQuitMsg: + a.showQuit = false + return a, nil + case tea.KeyMsg: - if a.editorMode == vimtea.ModeNormal { - switch { - case key.Matches(msg, keys.Quit): - return a, dialog.NewQuitDialogCmd() - case key.Matches(msg, keys.Back): - if a.previousPage != "" { - return a, a.moveToPage(a.previousPage) - } - case key.Matches(msg, keys.Return): - if a.showHelp { - a.ToggleHelp() - return a, nil - } - case key.Matches(msg, replKeyMap): - if a.currentPage == page.ReplPage { - sessions, err := a.app.Sessions.List(context.Background()) - if err != nil { - return a, util.CmdHandler(util.ReportError(err)) - } - lastSession := sessions[0] - if lastSession.MessageCount == 0 { - return a, util.CmdHandler(repl.SelectedSessionMsg{SessionID: lastSession.ID}) - } - s, err := a.app.Sessions.Create(context.Background(), "New Session") - if err != nil { - return a, util.CmdHandler(util.ReportError(err)) - } - return a, util.CmdHandler(repl.SelectedSessionMsg{SessionID: s.ID}) - } - // case key.Matches(msg, keys.Logs): - // return a, a.moveToPage(page.LogsPage) - case msg.String() == "O": - return a, a.moveToPage(page.ReplPage) - case key.Matches(msg, keys.Help): - a.ToggleHelp() + switch { + case key.Matches(msg, keys.Quit): + a.showQuit = !a.showQuit + if a.showHelp { + a.showHelp = false + } + return a, nil + case key.Matches(msg, logsKeyReturnKey): + if a.currentPage == page.LogsPage { + return a, a.moveToPage(page.ChatPage) + } + case key.Matches(msg, returnKey): + if a.showQuit { + a.showQuit = !a.showQuit + return a, nil + } + if a.showHelp { + a.showHelp = !a.showHelp + return a, nil + } + case key.Matches(msg, keys.Logs): + return a, a.moveToPage(page.LogsPage) + case key.Matches(msg, keys.Help): + if a.showQuit { return a, nil } + a.showHelp = !a.showHelp + return a, nil } } - if a.dialogVisible { - d, cmd := a.dialog.Update(msg) - a.dialog = d.(core.DialogCmp) - cmds = append(cmds, cmd) - return a, tea.Batch(cmds...) + if a.showQuit { + q, quitCmd := a.quit.Update(msg) + a.quit = q.(dialog.QuitDialog) + cmds = append(cmds, quitCmd) + // Only block key messages send all other messages down + if _, ok := msg.(tea.KeyMsg); ok { + return a, tea.Batch(cmds...) + } + } + if a.showPermissions { + d, permissionsCmd := a.permissions.Update(msg) + a.permissions = d.(dialog.PermissionDialogCmp) + cmds = append(cmds, permissionsCmd) + // Only block key messages send all other messages down + if _, ok := msg.(tea.KeyMsg); ok { + return a, tea.Batch(cmds...) + } } a.pages[a.currentPage], cmd = a.pages[a.currentPage].Update(msg) cmds = append(cmds, cmd) return a, tea.Batch(cmds...) } -func (a *appModel) ToggleHelp() { - if a.showHelp { - a.showHelp = false - a.height += a.help.Height() - } else { - a.showHelp = true - a.height -= a.help.Height() - } - - if sizable, ok := a.pages[a.currentPage].(layout.Sizeable); ok { - sizable.SetSize(a.width, a.height) - } -} - func (a *appModel) moveToPage(pageID page.PageID) tea.Cmd { var cmd tea.Cmd if _, ok := a.loadedPages[pageID]; !ok { @@ -256,27 +240,55 @@ func (a appModel) View() string { a.pages[a.currentPage].View(), } + components = append(components, a.status.View()) + + appView := lipgloss.JoinVertical(lipgloss.Top, components...) + + if a.showPermissions { + overlay := a.permissions.View() + row := lipgloss.Height(appView) / 2 + row -= lipgloss.Height(overlay) / 2 + col := lipgloss.Width(appView) / 2 + col -= lipgloss.Width(overlay) / 2 + appView = layout.PlaceOverlay( + col, + row, + overlay, + appView, + true, + ) + } + if a.showHelp { bindings := layout.KeyMapToSlice(keys) if p, ok := a.pages[a.currentPage].(layout.Bindings); ok { bindings = append(bindings, p.BindingKeys()...) } - if a.dialogVisible { - bindings = append(bindings, a.dialog.BindingKeys()...) + if a.showPermissions { + bindings = append(bindings, a.permissions.BindingKeys()...) } - if a.currentPage == page.ReplPage { - bindings = append(bindings, replKeyMap) + if a.currentPage == page.LogsPage { + bindings = append(bindings, logsKeyReturnKey) } - a.help.SetBindings(bindings) - components = append(components, a.help.View()) - } - components = append(components, a.status.View()) + a.help.SetBindings(bindings) - appView := lipgloss.JoinVertical(lipgloss.Top, components...) + overlay := a.help.View() + row := lipgloss.Height(appView) / 2 + row -= lipgloss.Height(overlay) / 2 + col := lipgloss.Width(appView) / 2 + col -= lipgloss.Width(overlay) / 2 + appView = layout.PlaceOverlay( + col, + row, + overlay, + appView, + true, + ) + } - if a.dialogVisible { - overlay := a.dialog.View() + if a.showQuit { + overlay := a.quit.View() row := lipgloss.Height(appView) / 2 row -= lipgloss.Height(overlay) / 2 col := lipgloss.Width(appView) / 2 @@ -289,30 +301,23 @@ func (a appModel) View() string { true, ) } + return appView } func New(app *app.App) tea.Model { - // homedir, _ := os.UserHomeDir() - // configPath := filepath.Join(homedir, ".termai.yaml") - // startPage := page.ChatPage - // if _, err := os.Stat(configPath); os.IsNotExist(err) { - // startPage = page.InitPage - // } - return &appModel{ currentPage: startPage, loadedPages: make(map[page.PageID]bool), - status: core.NewStatusCmp(), - help: core.NewHelpCmp(), - dialog: core.NewDialogCmp(), + status: core.NewStatusCmp(app.LSPClients), + help: dialog.NewHelpCmp(), + quit: dialog.NewQuitCmp(), + permissions: dialog.NewPermissionDialogCmp(), app: app, pages: map[page.PageID]tea.Model{ page.ChatPage: page.NewChatPage(app), page.LogsPage: page.NewLogsPage(), - page.InitPage: page.NewInitPage(), - page.ReplPage: page.NewReplPage(app), }, } } diff --git a/main.go b/main.go index 4bc8a22f05d0791715a1c66a419bcfd9d1764f78..2e6954646852703947bab5ccf5783032cf56c169 100644 --- a/main.go +++ b/main.go @@ -2,8 +2,15 @@ package main import ( "github.com/kujtimiihoxha/termai/cmd" + "github.com/kujtimiihoxha/termai/internal/logging" ) func main() { + // Set up panic recovery for the main function + defer logging.RecoverPanic("main", func() { + // Perform any necessary cleanup before exit + logging.ErrorPersist("Application terminated due to unhandled panic") + }) + cmd.Execute() } From cc07f7a186995f428436bc1adc66a264a95171a4 Mon Sep 17 00:00:00 2001 From: Kujtim Hoxha Date: Wed, 16 Apr 2025 21:48:29 +0200 Subject: [PATCH 20/41] rename to opencode --- .opencode.json | 11 +++ cmd/root.go | 14 +-- go.mod | 2 +- internal/app/app.go | 18 ++-- internal/app/lsp.go | 8 +- internal/config/config.go | 4 +- internal/db/connect.go | 6 +- internal/diff/diff.go | 4 +- internal/history/file.go | 4 +- internal/llm/agent/agent-tool.go | 10 +- internal/llm/agent/agent.go | 18 ++-- internal/llm/agent/mcp-tools.go | 10 +- internal/llm/agent/tools.go | 12 +-- internal/llm/prompt/coder.go | 97 +++++++++----------- internal/llm/prompt/prompt.go | 4 +- internal/llm/prompt/task.go | 4 +- internal/llm/prompt/title.go | 2 +- internal/llm/provider/anthropic.go | 8 +- internal/llm/provider/bedrock.go | 4 +- internal/llm/provider/gemini.go | 8 +- internal/llm/provider/openai.go | 8 +- internal/llm/provider/provider.go | 6 +- internal/llm/tools/bash.go | 16 ++-- internal/llm/tools/diagnostics.go | 4 +- internal/llm/tools/edit.go | 10 +- internal/llm/tools/edit_test.go | 2 +- internal/llm/tools/fetch.go | 6 +- internal/llm/tools/glob.go | 2 +- internal/llm/tools/grep.go | 2 +- internal/llm/tools/ls.go | 2 +- internal/llm/tools/mocks_test.go | 6 +- internal/llm/tools/shell/shell.go | 8 +- internal/llm/tools/sourcegraph.go | 2 +- internal/llm/tools/view.go | 4 +- internal/llm/tools/write.go | 10 +- internal/llm/tools/write_test.go | 2 +- internal/logging/writer.go | 2 +- internal/lsp/client.go | 6 +- internal/lsp/handlers.go | 8 +- internal/lsp/language.go | 2 +- internal/lsp/methods.go | 2 +- internal/lsp/transport.go | 4 +- internal/lsp/util/edit.go | 2 +- internal/lsp/watcher/watcher.go | 8 +- internal/message/content.go | 2 +- internal/message/message.go | 6 +- internal/permission/permission.go | 2 +- internal/session/session.go | 4 +- internal/tui/components/chat/chat.go | 8 +- internal/tui/components/chat/editor.go | 10 +- internal/tui/components/chat/messages.go | 22 ++--- internal/tui/components/chat/sidebar.go | 12 +-- internal/tui/components/core/status.go | 12 +-- internal/tui/components/dialog/help.go | 2 +- internal/tui/components/dialog/permission.go | 12 +-- internal/tui/components/dialog/quit.go | 6 +- internal/tui/components/logs/details.go | 6 +- internal/tui/components/logs/table.go | 10 +- internal/tui/layout/border.go | 2 +- internal/tui/layout/container.go | 2 +- internal/tui/layout/overlay.go | 4 +- internal/tui/layout/split.go | 2 +- internal/tui/page/chat.go | 10 +- internal/tui/page/logs.go | 4 +- internal/tui/tui.go | 18 ++-- main.go | 4 +- 66 files changed, 266 insertions(+), 266 deletions(-) diff --git a/.opencode.json b/.opencode.json index b7fc19b524371cf7e4a625173f2fe305914694d3..4b2944f869b29804808361de59352957fc18ef81 100644 --- a/.opencode.json +++ b/.opencode.json @@ -3,5 +3,16 @@ "gopls": { "command": "gopls" } + }, + "agents": { + "coder": { + "model": "gpt-4.1" + }, + "task": { + "model": "gpt-4.1" + }, + "title": { + "model": "gpt-4.1" + } } } diff --git a/cmd/root.go b/cmd/root.go index ff71747d56458c6e8094e19db90a9f36acc2db42..f506e99404f2bdc2d0331f592ffe2ab69b560ef5 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -8,13 +8,13 @@ import ( "time" tea "github.com/charmbracelet/bubbletea" - "github.com/kujtimiihoxha/termai/internal/app" - "github.com/kujtimiihoxha/termai/internal/config" - "github.com/kujtimiihoxha/termai/internal/db" - "github.com/kujtimiihoxha/termai/internal/llm/agent" - "github.com/kujtimiihoxha/termai/internal/logging" - "github.com/kujtimiihoxha/termai/internal/pubsub" - "github.com/kujtimiihoxha/termai/internal/tui" + "github.com/kujtimiihoxha/opencode/internal/app" + "github.com/kujtimiihoxha/opencode/internal/config" + "github.com/kujtimiihoxha/opencode/internal/db" + "github.com/kujtimiihoxha/opencode/internal/llm/agent" + "github.com/kujtimiihoxha/opencode/internal/logging" + "github.com/kujtimiihoxha/opencode/internal/pubsub" + "github.com/kujtimiihoxha/opencode/internal/tui" zone "github.com/lrstanley/bubblezone" "github.com/spf13/cobra" ) diff --git a/go.mod b/go.mod index 16c88d3a61c83d913721aecdf21afb07d26dbe29..822e70dbd446d067ac572446928304322b06ff6a 100644 --- a/go.mod +++ b/go.mod @@ -1,4 +1,4 @@ -module github.com/kujtimiihoxha/termai +module github.com/kujtimiihoxha/opencode go 1.24.0 diff --git a/internal/app/app.go b/internal/app/app.go index 1c16ccc1186253961e01bd60f76a541bc891364b..748fdaa7f0e2ec6dc314defa7b5c8d0c3c1b8c96 100644 --- a/internal/app/app.go +++ b/internal/app/app.go @@ -7,15 +7,15 @@ import ( "sync" "time" - "github.com/kujtimiihoxha/termai/internal/config" - "github.com/kujtimiihoxha/termai/internal/db" - "github.com/kujtimiihoxha/termai/internal/history" - "github.com/kujtimiihoxha/termai/internal/llm/agent" - "github.com/kujtimiihoxha/termai/internal/logging" - "github.com/kujtimiihoxha/termai/internal/lsp" - "github.com/kujtimiihoxha/termai/internal/message" - "github.com/kujtimiihoxha/termai/internal/permission" - "github.com/kujtimiihoxha/termai/internal/session" + "github.com/kujtimiihoxha/opencode/internal/config" + "github.com/kujtimiihoxha/opencode/internal/db" + "github.com/kujtimiihoxha/opencode/internal/history" + "github.com/kujtimiihoxha/opencode/internal/llm/agent" + "github.com/kujtimiihoxha/opencode/internal/logging" + "github.com/kujtimiihoxha/opencode/internal/lsp" + "github.com/kujtimiihoxha/opencode/internal/message" + "github.com/kujtimiihoxha/opencode/internal/permission" + "github.com/kujtimiihoxha/opencode/internal/session" ) type App struct { diff --git a/internal/app/lsp.go b/internal/app/lsp.go index 4a762f1a156cc2677ffd3806723d7ea402d60005..d8a35c8b3a9646ab079bfa4fe6151fcb06aaf45f 100644 --- a/internal/app/lsp.go +++ b/internal/app/lsp.go @@ -4,10 +4,10 @@ import ( "context" "time" - "github.com/kujtimiihoxha/termai/internal/config" - "github.com/kujtimiihoxha/termai/internal/logging" - "github.com/kujtimiihoxha/termai/internal/lsp" - "github.com/kujtimiihoxha/termai/internal/lsp/watcher" + "github.com/kujtimiihoxha/opencode/internal/config" + "github.com/kujtimiihoxha/opencode/internal/logging" + "github.com/kujtimiihoxha/opencode/internal/lsp" + "github.com/kujtimiihoxha/opencode/internal/lsp/watcher" ) func (app *App) initLSPClients(ctx context.Context) { diff --git a/internal/config/config.go b/internal/config/config.go index 147d6c83a4bb77bcb7749f81dd6867410ecc8059..20a8bac9750e0d0cfd5e7a8c194527ac714bef89 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -7,8 +7,8 @@ import ( "os" "strings" - "github.com/kujtimiihoxha/termai/internal/llm/models" - "github.com/kujtimiihoxha/termai/internal/logging" + "github.com/kujtimiihoxha/opencode/internal/llm/models" + "github.com/kujtimiihoxha/opencode/internal/logging" "github.com/spf13/viper" ) diff --git a/internal/db/connect.go b/internal/db/connect.go index 8bba9cad806c73327ffa016bd5851875403236e4..e850bc8d02a4e9c685e1e5dc71260eacb44a0563 100644 --- a/internal/db/connect.go +++ b/internal/db/connect.go @@ -12,8 +12,8 @@ import ( "github.com/golang-migrate/migrate/v4/database/sqlite3" _ "github.com/mattn/go-sqlite3" - "github.com/kujtimiihoxha/termai/internal/config" - "github.com/kujtimiihoxha/termai/internal/logging" + "github.com/kujtimiihoxha/opencode/internal/config" + "github.com/kujtimiihoxha/opencode/internal/logging" ) func Connect() (*sql.DB, error) { @@ -24,7 +24,7 @@ func Connect() (*sql.DB, error) { if err := os.MkdirAll(dataDir, 0o700); err != nil { return nil, fmt.Errorf("failed to create data directory: %w", err) } - dbPath := filepath.Join(dataDir, "termai.db") + dbPath := filepath.Join(dataDir, "opencode.db") // Open the SQLite database db, err := sql.Open("sqlite3", dbPath) if err != nil { diff --git a/internal/diff/diff.go b/internal/diff/diff.go index 829554c7e1052bb95c02dc8b586634da21938b77..f48079c9c96c7e12349cac467875530200c505eb 100644 --- a/internal/diff/diff.go +++ b/internal/diff/diff.go @@ -19,8 +19,8 @@ import ( "github.com/charmbracelet/x/ansi" "github.com/go-git/go-git/v5" "github.com/go-git/go-git/v5/plumbing/object" - "github.com/kujtimiihoxha/termai/internal/config" - "github.com/kujtimiihoxha/termai/internal/logging" + "github.com/kujtimiihoxha/opencode/internal/config" + "github.com/kujtimiihoxha/opencode/internal/logging" "github.com/sergi/go-diff/diffmatchpatch" ) diff --git a/internal/history/file.go b/internal/history/file.go index 82017d4cf84c855158cf934fc3799a0e7c18762f..1e8bc50bb24d0563788a914d06adf824b86bd2c3 100644 --- a/internal/history/file.go +++ b/internal/history/file.go @@ -7,8 +7,8 @@ import ( "strings" "github.com/google/uuid" - "github.com/kujtimiihoxha/termai/internal/db" - "github.com/kujtimiihoxha/termai/internal/pubsub" + "github.com/kujtimiihoxha/opencode/internal/db" + "github.com/kujtimiihoxha/opencode/internal/pubsub" ) const ( diff --git a/internal/llm/agent/agent-tool.go b/internal/llm/agent/agent-tool.go index 308412bde86f8f6743997e8627b64082d2d8866f..be6e09a9b55819c32d990434f59b6342050a0026 100644 --- a/internal/llm/agent/agent-tool.go +++ b/internal/llm/agent/agent-tool.go @@ -5,11 +5,11 @@ import ( "encoding/json" "fmt" - "github.com/kujtimiihoxha/termai/internal/config" - "github.com/kujtimiihoxha/termai/internal/llm/tools" - "github.com/kujtimiihoxha/termai/internal/lsp" - "github.com/kujtimiihoxha/termai/internal/message" - "github.com/kujtimiihoxha/termai/internal/session" + "github.com/kujtimiihoxha/opencode/internal/config" + "github.com/kujtimiihoxha/opencode/internal/llm/tools" + "github.com/kujtimiihoxha/opencode/internal/lsp" + "github.com/kujtimiihoxha/opencode/internal/message" + "github.com/kujtimiihoxha/opencode/internal/session" ) type agentTool struct { diff --git a/internal/llm/agent/agent.go b/internal/llm/agent/agent.go index ab2742ec19b54b3c81f2c53df3d609c6f29e4a73..a5dadb89da77826f38dd15e22567b07b313f4d58 100644 --- a/internal/llm/agent/agent.go +++ b/internal/llm/agent/agent.go @@ -7,15 +7,15 @@ import ( "strings" "sync" - "github.com/kujtimiihoxha/termai/internal/config" - "github.com/kujtimiihoxha/termai/internal/llm/models" - "github.com/kujtimiihoxha/termai/internal/llm/prompt" - "github.com/kujtimiihoxha/termai/internal/llm/provider" - "github.com/kujtimiihoxha/termai/internal/llm/tools" - "github.com/kujtimiihoxha/termai/internal/logging" - "github.com/kujtimiihoxha/termai/internal/message" - "github.com/kujtimiihoxha/termai/internal/permission" - "github.com/kujtimiihoxha/termai/internal/session" + "github.com/kujtimiihoxha/opencode/internal/config" + "github.com/kujtimiihoxha/opencode/internal/llm/models" + "github.com/kujtimiihoxha/opencode/internal/llm/prompt" + "github.com/kujtimiihoxha/opencode/internal/llm/provider" + "github.com/kujtimiihoxha/opencode/internal/llm/tools" + "github.com/kujtimiihoxha/opencode/internal/logging" + "github.com/kujtimiihoxha/opencode/internal/message" + "github.com/kujtimiihoxha/opencode/internal/permission" + "github.com/kujtimiihoxha/opencode/internal/session" ) // Common errors diff --git a/internal/llm/agent/mcp-tools.go b/internal/llm/agent/mcp-tools.go index c7ea4916cad38a69df1e8bfde3fe78155a009496..16dddc1ba467dc3ae2dad62dab03bc981ff97226 100644 --- a/internal/llm/agent/mcp-tools.go +++ b/internal/llm/agent/mcp-tools.go @@ -5,11 +5,11 @@ import ( "encoding/json" "fmt" - "github.com/kujtimiihoxha/termai/internal/config" - "github.com/kujtimiihoxha/termai/internal/llm/tools" - "github.com/kujtimiihoxha/termai/internal/logging" - "github.com/kujtimiihoxha/termai/internal/permission" - "github.com/kujtimiihoxha/termai/internal/version" + "github.com/kujtimiihoxha/opencode/internal/config" + "github.com/kujtimiihoxha/opencode/internal/llm/tools" + "github.com/kujtimiihoxha/opencode/internal/logging" + "github.com/kujtimiihoxha/opencode/internal/permission" + "github.com/kujtimiihoxha/opencode/internal/version" "github.com/mark3labs/mcp-go/client" "github.com/mark3labs/mcp-go/mcp" diff --git a/internal/llm/agent/tools.go b/internal/llm/agent/tools.go index a37f1d65d0327ad51d8bfe2830e4c3661c324a30..409d1427313c2eb7797b11a6fdd953a2e0b3414a 100644 --- a/internal/llm/agent/tools.go +++ b/internal/llm/agent/tools.go @@ -3,12 +3,12 @@ package agent import ( "context" - "github.com/kujtimiihoxha/termai/internal/history" - "github.com/kujtimiihoxha/termai/internal/llm/tools" - "github.com/kujtimiihoxha/termai/internal/lsp" - "github.com/kujtimiihoxha/termai/internal/message" - "github.com/kujtimiihoxha/termai/internal/permission" - "github.com/kujtimiihoxha/termai/internal/session" + "github.com/kujtimiihoxha/opencode/internal/history" + "github.com/kujtimiihoxha/opencode/internal/llm/tools" + "github.com/kujtimiihoxha/opencode/internal/lsp" + "github.com/kujtimiihoxha/opencode/internal/message" + "github.com/kujtimiihoxha/opencode/internal/permission" + "github.com/kujtimiihoxha/opencode/internal/session" ) func CoderAgentTools( diff --git a/internal/llm/prompt/coder.go b/internal/llm/prompt/coder.go index 7439fd57064559e26592d1f23ea24abfc217f2c0..3a06911dadf3f9bb04aa0f66d8b136b836468f1f 100644 --- a/internal/llm/prompt/coder.go +++ b/internal/llm/prompt/coder.go @@ -8,9 +8,9 @@ import ( "runtime" "time" - "github.com/kujtimiihoxha/termai/internal/config" - "github.com/kujtimiihoxha/termai/internal/llm/models" - "github.com/kujtimiihoxha/termai/internal/llm/tools" + "github.com/kujtimiihoxha/opencode/internal/config" + "github.com/kujtimiihoxha/opencode/internal/llm/models" + "github.com/kujtimiihoxha/opencode/internal/llm/tools" ) func CoderPrompt(provider models.ModelProvider) string { @@ -24,69 +24,58 @@ func CoderPrompt(provider models.ModelProvider) string { return fmt.Sprintf("%s\n\n%s\n%s", basePrompt, envInfo, lspInformation()) } -const baseOpenAICoderPrompt = `You are termAI, an autonomous CLI-based software engineer. Your job is to reduce user effort by proactively reasoning, inferring context, and solving software engineering tasks end-to-end with minimal prompting. - -# Your mindset -Act like a competent, efficient software engineer who is familiar with large codebases. You should: -- Think critically about user requests. -- Proactively search the codebase for related information. -- Infer likely commands, tools, or conventions. -- Write and edit code with minimal user input. -- Anticipate next steps (tests, lints, etc.), but never commit unless explicitly told. - -# Context awareness -- Before acting, infer the purpose of a file from its name, directory, and neighboring files. -- If a file or function appears malicious, refuse to interact with it or discuss it. -- If a termai.md file exists, auto-load it as memory. Offer to update it only if new useful info appears (commands, preferences, structure). - -# CLI communication -- Use GitHub-flavored markdown in monospace font. -- Be concise. Never add preambles or postambles unless asked. Max 4 lines per response. -- Never explain your code unless asked. Do not narrate actions. -- Avoid unnecessary questions. Infer, search, act. - -# Behavior guidelines -- Follow project conventions: naming, formatting, libraries, frameworks. -- Before using any library or framework, confirm it’s already used. -- Always look at the surrounding code to match existing style. -- Do not add comments unless the code is complex or the user asks. - -# Autonomy rules -You are allowed and expected to: -- Search for commands, tools, or config files before asking the user. -- Run multiple search tool calls concurrently to gather relevant context. -- Choose test, lint, and typecheck commands based on package files or scripts. -- Offer to store these commands in termai.md if not already present. - -# Example behavior -user: write tests for new feature -assistant: [searches for existing test patterns, finds appropriate location, generates test code using existing style, optionally asks to add test command to termai.md] +const baseOpenAICoderPrompt = ` +You are **OpenCode**, an autonomous CLI assistant for software‑engineering tasks. + +### ── INTERNAL REFLECTION ── +• Silently think step‑by‑step about the user request, directory layout, and tool calls (never reveal this). +• Formulate a plan, then execute without further approval unless a blocker triggers the Ask‑Only‑If rules. + +### ── PUBLIC RESPONSE RULES ── +• Visible reply ≤ 4 lines; no fluff, preamble, or postamble. +• Use GitHub‑flavored Markdown. +• When running a non‑trivial shell command, add ≤ 1 brief purpose sentence. + +### ── CONTEXT & MEMORY ── +• Infer file intent from directory structure before editing. +• Auto‑load 'OpenCode.md'; ask once before writing new reusable commands or style notes. -user: how do I typecheck this codebase? -assistant: [searches for known commands, infers package manager, checks for scripts or config files] -tsc --noEmit +### ── AUTONOMY PRIORITY ── +**Ask‑Only‑If Decision Tree:** +1. **Safety risk?** (e.g., destructive command, secret exposure) → ask. +2. **Critical unknown?** (no docs/tests; cannot infer) → ask. +3. **Tool failure after two self‑attempts?** → ask. +Otherwise, proceed autonomously. -user: is X function used anywhere else? -assistant: [searches repo for references, returns file paths and lines] +### ── SAFETY & STYLE ── +• Mimic existing code style; verify libraries exist before import. +• Never commit unless explicitly told. +• After edits, run lint & type‑check (ask for commands once, then offer to store in 'OpenCode.md'). +• Protect secrets; follow standard security practices :contentReference[oaicite:2]{index=2}. -# Tool usage -- Use parallel calls when possible. -- Use file search and content tools before asking the user. -- Do not ask the user for information unless it cannot be determined via tools. +### ── TOOL USAGE ── +• Batch independent Agent search/file calls in one block for efficiency :contentReference[oaicite:3]{index=3}. +• Communicate with the user only via visible text; do not expose tool output or internal reasoning. -Never commit changes unless the user explicitly asks you to.` +### ── EXAMPLES ── +user: list files +assistant: ls + +user: write tests for new feature +assistant: [searches & edits autonomously, no extra chit‑chat] +` -const baseAnthropicCoderPrompt = `You are termAI, an interactive CLI tool that helps users with software engineering tasks. Use the instructions below and the tools available to you to assist the user. +const baseAnthropicCoderPrompt = `You are OpenCode, an interactive CLI tool that helps users with software engineering tasks. Use the instructions below and the tools available to you to assist the user. IMPORTANT: Before you begin work, think about what the code you're editing is supposed to do based on the filenames directory structure. # Memory -If the current working directory contains a file called termai.md, it will be automatically added to your context. This file serves multiple purposes: +If the current working directory contains a file called OpenCode.md, it will be automatically added to your context. This file serves multiple purposes: 1. Storing frequently used bash commands (build, test, lint, etc.) so you can use them without searching each time 2. Recording the user's code style preferences (naming conventions, preferred libraries, etc.) 3. Maintaining useful information about the codebase structure and organization -When you spend time searching for commands to typecheck, lint, build, or test, you should ask the user if it's okay to add those commands to termai.md. Similarly, when learning about code style preferences or important codebase information, ask if it's okay to add that to termai.md so you can remember it for next time. +When you spend time searching for commands to typecheck, lint, build, or test, you should ask the user if it's okay to add those commands to OpenCode.md. Similarly, when learning about code style preferences or important codebase information, ask if it's okay to add that to OpenCode.md so you can remember it for next time. # Tone and style You should be concise, direct, and to the point. When you run a non-trivial bash command, you should explain what the command does and why you are running it, to make sure the user understands what you are doing (this is especially important when you are running a command that will make changes to the user's system). @@ -161,7 +150,7 @@ The user will primarily request you perform software engineering tasks. This inc 1. Use the available search tools to understand the codebase and the user's query. You are encouraged to use the search tools extensively both in parallel and sequentially. 2. Implement the solution using all tools available to you 3. Verify the solution if possible with tests. NEVER assume specific test framework or test script. Check the README or search codebase to determine the testing approach. -4. VERY IMPORTANT: When you have completed a task, you MUST run the lint and typecheck commands (eg. npm run lint, npm run typecheck, ruff, etc.) if they were provided to you to ensure your code is correct. If you are unable to find the correct command, ask the user for the command to run and if they supply it, proactively suggest writing it to termai.md so that you will know to run it next time. +4. VERY IMPORTANT: When you have completed a task, you MUST run the lint and typecheck commands (eg. npm run lint, npm run typecheck, ruff, etc.) if they were provided to you to ensure your code is correct. If you are unable to find the correct command, ask the user for the command to run and if they supply it, proactively suggest writing it to opencode.md so that you will know to run it next time. NEVER commit changes unless the user explicitly asks you to. It is VERY IMPORTANT to only commit when explicitly asked, otherwise the user will feel that you are being too proactive. diff --git a/internal/llm/prompt/prompt.go b/internal/llm/prompt/prompt.go index 63fc2df7bcecd30f37ee04259c05e6b425cd75fd..cdc3560cefe9d0d464fe1786e12e92165503868f 100644 --- a/internal/llm/prompt/prompt.go +++ b/internal/llm/prompt/prompt.go @@ -1,8 +1,8 @@ package prompt import ( - "github.com/kujtimiihoxha/termai/internal/config" - "github.com/kujtimiihoxha/termai/internal/llm/models" + "github.com/kujtimiihoxha/opencode/internal/config" + "github.com/kujtimiihoxha/opencode/internal/llm/models" ) func GetAgentPrompt(agentName config.AgentName, provider models.ModelProvider) string { diff --git a/internal/llm/prompt/task.go b/internal/llm/prompt/task.go index 8bf604ad99750595505d334659a4f596d54edf0c..88cd1a0f46d98dd6c4637e87921c8de94644776a 100644 --- a/internal/llm/prompt/task.go +++ b/internal/llm/prompt/task.go @@ -3,11 +3,11 @@ package prompt import ( "fmt" - "github.com/kujtimiihoxha/termai/internal/llm/models" + "github.com/kujtimiihoxha/opencode/internal/llm/models" ) func TaskPrompt(_ models.ModelProvider) string { - agentPrompt := `You are an agent for termAI. Given the user's prompt, you should use the tools available to you to answer the user's question. + agentPrompt := `You are an agent for OpenCode. Given the user's prompt, you should use the tools available to you to answer the user's question. Notes: 1. IMPORTANT: You should be concise, direct, and to the point, since your responses will be displayed on a command line interface. Answer the user's question directly, without elaboration, explanation, or details. One word answers are best. Avoid introductions, conclusions, and explanations. You MUST avoid text before/after your response, such as "The answer is .", "Here is the content of the file..." or "Based on the information provided, the answer is..." or "Here is what I will do next...". 2. When relevant, share file names and code snippets relevant to the query diff --git a/internal/llm/prompt/title.go b/internal/llm/prompt/title.go index 3023a8550d18f7c0451d52f907150d76499b1552..6e5289b24e984f92d9dbb9b86bfd530dfb4ae441 100644 --- a/internal/llm/prompt/title.go +++ b/internal/llm/prompt/title.go @@ -1,6 +1,6 @@ package prompt -import "github.com/kujtimiihoxha/termai/internal/llm/models" +import "github.com/kujtimiihoxha/opencode/internal/llm/models" func TitlePrompt(_ models.ModelProvider) string { return `you will generate a short title based on the first message a user begins a conversation with diff --git a/internal/llm/provider/anthropic.go b/internal/llm/provider/anthropic.go index c3a4efc49bea916f55a48be32771a1f06d6a9617..7bbc02103df7d81b121d19244b4e6da8ce0fd600 100644 --- a/internal/llm/provider/anthropic.go +++ b/internal/llm/provider/anthropic.go @@ -12,10 +12,10 @@ import ( "github.com/anthropics/anthropic-sdk-go" "github.com/anthropics/anthropic-sdk-go/bedrock" "github.com/anthropics/anthropic-sdk-go/option" - "github.com/kujtimiihoxha/termai/internal/config" - "github.com/kujtimiihoxha/termai/internal/llm/tools" - "github.com/kujtimiihoxha/termai/internal/logging" - "github.com/kujtimiihoxha/termai/internal/message" + "github.com/kujtimiihoxha/opencode/internal/config" + "github.com/kujtimiihoxha/opencode/internal/llm/tools" + "github.com/kujtimiihoxha/opencode/internal/logging" + "github.com/kujtimiihoxha/opencode/internal/message" ) type anthropicOptions struct { diff --git a/internal/llm/provider/bedrock.go b/internal/llm/provider/bedrock.go index d76925ad10274bd6ded1401e00bbb7954035b122..9415b30feead1dac5f8bdf59bfe49e9a6992fc2e 100644 --- a/internal/llm/provider/bedrock.go +++ b/internal/llm/provider/bedrock.go @@ -7,8 +7,8 @@ import ( "os" "strings" - "github.com/kujtimiihoxha/termai/internal/llm/tools" - "github.com/kujtimiihoxha/termai/internal/message" + "github.com/kujtimiihoxha/opencode/internal/llm/tools" + "github.com/kujtimiihoxha/opencode/internal/message" ) type bedrockOptions struct { diff --git a/internal/llm/provider/gemini.go b/internal/llm/provider/gemini.go index 804baea281bdd9e609e113fed811a684f2049a81..384bff900aeafeb7b6e066c2e2dbb35616a4e808 100644 --- a/internal/llm/provider/gemini.go +++ b/internal/llm/provider/gemini.go @@ -11,10 +11,10 @@ import ( "github.com/google/generative-ai-go/genai" "github.com/google/uuid" - "github.com/kujtimiihoxha/termai/internal/config" - "github.com/kujtimiihoxha/termai/internal/llm/tools" - "github.com/kujtimiihoxha/termai/internal/logging" - "github.com/kujtimiihoxha/termai/internal/message" + "github.com/kujtimiihoxha/opencode/internal/config" + "github.com/kujtimiihoxha/opencode/internal/llm/tools" + "github.com/kujtimiihoxha/opencode/internal/logging" + "github.com/kujtimiihoxha/opencode/internal/message" "google.golang.org/api/iterator" "google.golang.org/api/option" ) diff --git a/internal/llm/provider/openai.go b/internal/llm/provider/openai.go index 9c2ad201263d0daf182d6feda8c6351d764f896e..13ce934f29fb3a543d2d13f7e3b626cf61f7f0aa 100644 --- a/internal/llm/provider/openai.go +++ b/internal/llm/provider/openai.go @@ -8,10 +8,10 @@ import ( "io" "time" - "github.com/kujtimiihoxha/termai/internal/config" - "github.com/kujtimiihoxha/termai/internal/llm/tools" - "github.com/kujtimiihoxha/termai/internal/logging" - "github.com/kujtimiihoxha/termai/internal/message" + "github.com/kujtimiihoxha/opencode/internal/config" + "github.com/kujtimiihoxha/opencode/internal/llm/tools" + "github.com/kujtimiihoxha/opencode/internal/logging" + "github.com/kujtimiihoxha/opencode/internal/message" "github.com/openai/openai-go" "github.com/openai/openai-go/option" ) diff --git a/internal/llm/provider/provider.go b/internal/llm/provider/provider.go index 1a5b3dc8ace7f2b363761c9defd37a53407f42d4..e04bee71bce87b642b6b1b227adc22d72dfcdb26 100644 --- a/internal/llm/provider/provider.go +++ b/internal/llm/provider/provider.go @@ -4,9 +4,9 @@ import ( "context" "fmt" - "github.com/kujtimiihoxha/termai/internal/llm/models" - "github.com/kujtimiihoxha/termai/internal/llm/tools" - "github.com/kujtimiihoxha/termai/internal/message" + "github.com/kujtimiihoxha/opencode/internal/llm/models" + "github.com/kujtimiihoxha/opencode/internal/llm/tools" + "github.com/kujtimiihoxha/opencode/internal/message" ) type EventType string diff --git a/internal/llm/tools/bash.go b/internal/llm/tools/bash.go index c7c970e5a1a9b10b3905d51701414bd464be4247..18533b761d0cfeca17f842df0db1e3756c38619d 100644 --- a/internal/llm/tools/bash.go +++ b/internal/llm/tools/bash.go @@ -7,9 +7,9 @@ import ( "strings" "time" - "github.com/kujtimiihoxha/termai/internal/config" - "github.com/kujtimiihoxha/termai/internal/llm/tools/shell" - "github.com/kujtimiihoxha/termai/internal/permission" + "github.com/kujtimiihoxha/opencode/internal/config" + "github.com/kujtimiihoxha/opencode/internal/llm/tools/shell" + "github.com/kujtimiihoxha/opencode/internal/permission" ) type BashParams struct { @@ -122,16 +122,16 @@ When the user asks you to create a new git commit, follow these steps carefully: 4. Create the commit with a message ending with: -🤖 Generated with termai -Co-Authored-By: termai +🤖 Generated with opencode +Co-Authored-By: opencode - In order to ensure good formatting, ALWAYS pass the commit message via a HEREDOC, a la this example: git commit -m "$(cat <<'EOF' Commit message here. - 🤖 Generated with termai - Co-Authored-By: termai + 🤖 Generated with opencode + Co-Authored-By: opencode EOF )" @@ -193,7 +193,7 @@ gh pr create --title "the pr title" --body "$(cat <<'EOF' ## Test plan [Checklist of TODOs for testing the pull request...] -🤖 Generated with termai +🤖 Generated with opencode EOF )" diff --git a/internal/llm/tools/diagnostics.go b/internal/llm/tools/diagnostics.go index b7b2bb8bab55b0ecf507c7bd6bd25b9454131cbc..82989c774206ec98bf2d88e8059b0acfdc5ca1b7 100644 --- a/internal/llm/tools/diagnostics.go +++ b/internal/llm/tools/diagnostics.go @@ -9,8 +9,8 @@ import ( "strings" "time" - "github.com/kujtimiihoxha/termai/internal/lsp" - "github.com/kujtimiihoxha/termai/internal/lsp/protocol" + "github.com/kujtimiihoxha/opencode/internal/lsp" + "github.com/kujtimiihoxha/opencode/internal/lsp/protocol" ) type DiagnosticsParams struct { diff --git a/internal/llm/tools/edit.go b/internal/llm/tools/edit.go index 148e7aba7a78b11df9eeea9b3ac57658e15dc070..6a16160109cffbdcc61edf66d00d84217dcc4868 100644 --- a/internal/llm/tools/edit.go +++ b/internal/llm/tools/edit.go @@ -9,11 +9,11 @@ import ( "strings" "time" - "github.com/kujtimiihoxha/termai/internal/config" - "github.com/kujtimiihoxha/termai/internal/diff" - "github.com/kujtimiihoxha/termai/internal/history" - "github.com/kujtimiihoxha/termai/internal/lsp" - "github.com/kujtimiihoxha/termai/internal/permission" + "github.com/kujtimiihoxha/opencode/internal/config" + "github.com/kujtimiihoxha/opencode/internal/diff" + "github.com/kujtimiihoxha/opencode/internal/history" + "github.com/kujtimiihoxha/opencode/internal/lsp" + "github.com/kujtimiihoxha/opencode/internal/permission" ) type EditParams struct { diff --git a/internal/llm/tools/edit_test.go b/internal/llm/tools/edit_test.go index 0971775ddb508f50e1bbf3a4cf4a8f372f419dbe..1b58a0d7d401a47a3002ebd05e081a87283ed94f 100644 --- a/internal/llm/tools/edit_test.go +++ b/internal/llm/tools/edit_test.go @@ -8,7 +8,7 @@ import ( "testing" "time" - "github.com/kujtimiihoxha/termai/internal/lsp" + "github.com/kujtimiihoxha/opencode/internal/lsp" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) diff --git a/internal/llm/tools/fetch.go b/internal/llm/tools/fetch.go index 91bcb36a0696e293658ea11cf333b18a9aa2d767..827755863d053ab9e3073c859523a5cc8048500b 100644 --- a/internal/llm/tools/fetch.go +++ b/internal/llm/tools/fetch.go @@ -11,8 +11,8 @@ import ( md "github.com/JohannesKaufmann/html-to-markdown" "github.com/PuerkitoBio/goquery" - "github.com/kujtimiihoxha/termai/internal/config" - "github.com/kujtimiihoxha/termai/internal/permission" + "github.com/kujtimiihoxha/opencode/internal/config" + "github.com/kujtimiihoxha/opencode/internal/permission" ) type FetchParams struct { @@ -146,7 +146,7 @@ func (t *fetchTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error return ToolResponse{}, fmt.Errorf("failed to create request: %w", err) } - req.Header.Set("User-Agent", "termai/1.0") + req.Header.Set("User-Agent", "opencode/1.0") resp, err := client.Do(req) if err != nil { diff --git a/internal/llm/tools/glob.go b/internal/llm/tools/glob.go index 7b4fb11870aa1f2c79755c5f8de30a7e49ead11b..40262ce2ba5bb43a218d28a7d24e5f45b1a89ea3 100644 --- a/internal/llm/tools/glob.go +++ b/internal/llm/tools/glob.go @@ -12,7 +12,7 @@ import ( "time" "github.com/bmatcuk/doublestar/v4" - "github.com/kujtimiihoxha/termai/internal/config" + "github.com/kujtimiihoxha/opencode/internal/config" ) const ( diff --git a/internal/llm/tools/grep.go b/internal/llm/tools/grep.go index 19333f50b9ac0820588fda90413488c3d99e3aa4..3436dd7eb6b21a6cfec6e80ff249823a23c0f5d2 100644 --- a/internal/llm/tools/grep.go +++ b/internal/llm/tools/grep.go @@ -13,7 +13,7 @@ import ( "strings" "time" - "github.com/kujtimiihoxha/termai/internal/config" + "github.com/kujtimiihoxha/opencode/internal/config" ) type GrepParams struct { diff --git a/internal/llm/tools/ls.go b/internal/llm/tools/ls.go index a63bf0eebfb98869e28b02b1e4bb1c31c81fbd3d..05f300c0e7885f1a4e4cf0d1fc02171a818572ab 100644 --- a/internal/llm/tools/ls.go +++ b/internal/llm/tools/ls.go @@ -8,7 +8,7 @@ import ( "path/filepath" "strings" - "github.com/kujtimiihoxha/termai/internal/config" + "github.com/kujtimiihoxha/opencode/internal/config" ) type LSParams struct { diff --git a/internal/llm/tools/mocks_test.go b/internal/llm/tools/mocks_test.go index 321f09ac1ab00b8c413176db72cba676c4c45dd1..81993160c0384c31a4cc49bd53c40fa94cefd1a3 100644 --- a/internal/llm/tools/mocks_test.go +++ b/internal/llm/tools/mocks_test.go @@ -9,9 +9,9 @@ import ( "time" "github.com/google/uuid" - "github.com/kujtimiihoxha/termai/internal/history" - "github.com/kujtimiihoxha/termai/internal/permission" - "github.com/kujtimiihoxha/termai/internal/pubsub" + "github.com/kujtimiihoxha/opencode/internal/history" + "github.com/kujtimiihoxha/opencode/internal/permission" + "github.com/kujtimiihoxha/opencode/internal/pubsub" ) // Mock permission service for testing diff --git a/internal/llm/tools/shell/shell.go b/internal/llm/tools/shell/shell.go index 4a776478ab67bb4031a698b56beaddfb9734fa1d..e25bdf3eafccc5d6331edec1746a86040beda2f5 100644 --- a/internal/llm/tools/shell/shell.go +++ b/internal/llm/tools/shell/shell.go @@ -126,10 +126,10 @@ func (s *PersistentShell) execCommand(command string, timeout time.Duration, ctx } tempDir := os.TempDir() - stdoutFile := filepath.Join(tempDir, fmt.Sprintf("termai-stdout-%d", time.Now().UnixNano())) - stderrFile := filepath.Join(tempDir, fmt.Sprintf("termai-stderr-%d", time.Now().UnixNano())) - statusFile := filepath.Join(tempDir, fmt.Sprintf("termai-status-%d", time.Now().UnixNano())) - cwdFile := filepath.Join(tempDir, fmt.Sprintf("termai-cwd-%d", time.Now().UnixNano())) + stdoutFile := filepath.Join(tempDir, fmt.Sprintf("opencode-stdout-%d", time.Now().UnixNano())) + stderrFile := filepath.Join(tempDir, fmt.Sprintf("opencode-stderr-%d", time.Now().UnixNano())) + statusFile := filepath.Join(tempDir, fmt.Sprintf("opencode-status-%d", time.Now().UnixNano())) + cwdFile := filepath.Join(tempDir, fmt.Sprintf("opencode-cwd-%d", time.Now().UnixNano())) defer func() { os.Remove(stdoutFile) diff --git a/internal/llm/tools/sourcegraph.go b/internal/llm/tools/sourcegraph.go index a6f2c8afb578c4fce7c29bcf1ffabfffdba75d71..0d38c975fbe202a8cd16f580586795bf2213fabb 100644 --- a/internal/llm/tools/sourcegraph.go +++ b/internal/llm/tools/sourcegraph.go @@ -218,7 +218,7 @@ func (t *sourcegraphTool) Run(ctx context.Context, call ToolCall) (ToolResponse, } req.Header.Set("Content-Type", "application/json") - req.Header.Set("User-Agent", "termai/1.0") + req.Header.Set("User-Agent", "opencode/1.0") resp, err := client.Do(req) if err != nil { diff --git a/internal/llm/tools/view.go b/internal/llm/tools/view.go index 7450a84bfb1bcedf4c4c265f202d35ca5ffb6259..3fa4ca11616eda5daf21a14861d524d782f74e5d 100644 --- a/internal/llm/tools/view.go +++ b/internal/llm/tools/view.go @@ -10,8 +10,8 @@ import ( "path/filepath" "strings" - "github.com/kujtimiihoxha/termai/internal/config" - "github.com/kujtimiihoxha/termai/internal/lsp" + "github.com/kujtimiihoxha/opencode/internal/config" + "github.com/kujtimiihoxha/opencode/internal/lsp" ) type ViewParams struct { diff --git a/internal/llm/tools/write.go b/internal/llm/tools/write.go index bb49381fd4423f70cad38dfdbc99197c765c3080..261865c398b061264a23c43aff9b3fcef0ee1283 100644 --- a/internal/llm/tools/write.go +++ b/internal/llm/tools/write.go @@ -8,11 +8,11 @@ import ( "path/filepath" "time" - "github.com/kujtimiihoxha/termai/internal/config" - "github.com/kujtimiihoxha/termai/internal/diff" - "github.com/kujtimiihoxha/termai/internal/history" - "github.com/kujtimiihoxha/termai/internal/lsp" - "github.com/kujtimiihoxha/termai/internal/permission" + "github.com/kujtimiihoxha/opencode/internal/config" + "github.com/kujtimiihoxha/opencode/internal/diff" + "github.com/kujtimiihoxha/opencode/internal/history" + "github.com/kujtimiihoxha/opencode/internal/lsp" + "github.com/kujtimiihoxha/opencode/internal/permission" ) type WriteParams struct { diff --git a/internal/llm/tools/write_test.go b/internal/llm/tools/write_test.go index 2264f36fb8aff7537095d751f5b847159d9743c2..b5ecb3fda17287c5b6cfda75e1309fcb62547c55 100644 --- a/internal/llm/tools/write_test.go +++ b/internal/llm/tools/write_test.go @@ -8,7 +8,7 @@ import ( "testing" "time" - "github.com/kujtimiihoxha/termai/internal/lsp" + "github.com/kujtimiihoxha/opencode/internal/lsp" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) diff --git a/internal/logging/writer.go b/internal/logging/writer.go index 9fe469c5e0953009413bc5171aba95d04936e92e..1dc07e8531e7c0ef658023dfb52d236b9b33a796 100644 --- a/internal/logging/writer.go +++ b/internal/logging/writer.go @@ -9,7 +9,7 @@ import ( "time" "github.com/go-logfmt/logfmt" - "github.com/kujtimiihoxha/termai/internal/pubsub" + "github.com/kujtimiihoxha/opencode/internal/pubsub" ) const ( diff --git a/internal/lsp/client.go b/internal/lsp/client.go index 0f03e7fcb12f8745accd0e8b33fb9af98d806d08..dad07f3c0e9d9a6bf816c8dbfb12361969ba8d68 100644 --- a/internal/lsp/client.go +++ b/internal/lsp/client.go @@ -13,9 +13,9 @@ import ( "sync/atomic" "time" - "github.com/kujtimiihoxha/termai/internal/config" - "github.com/kujtimiihoxha/termai/internal/logging" - "github.com/kujtimiihoxha/termai/internal/lsp/protocol" + "github.com/kujtimiihoxha/opencode/internal/config" + "github.com/kujtimiihoxha/opencode/internal/logging" + "github.com/kujtimiihoxha/opencode/internal/lsp/protocol" ) type Client struct { diff --git a/internal/lsp/handlers.go b/internal/lsp/handlers.go index c3088d6852061ee26ab94aa7e2b783cf3b52ca54..7a11286e602f25672f46c01db4d158858cde301a 100644 --- a/internal/lsp/handlers.go +++ b/internal/lsp/handlers.go @@ -3,10 +3,10 @@ package lsp import ( "encoding/json" - "github.com/kujtimiihoxha/termai/internal/config" - "github.com/kujtimiihoxha/termai/internal/logging" - "github.com/kujtimiihoxha/termai/internal/lsp/protocol" - "github.com/kujtimiihoxha/termai/internal/lsp/util" + "github.com/kujtimiihoxha/opencode/internal/config" + "github.com/kujtimiihoxha/opencode/internal/logging" + "github.com/kujtimiihoxha/opencode/internal/lsp/protocol" + "github.com/kujtimiihoxha/opencode/internal/lsp/util" ) // Requests diff --git a/internal/lsp/language.go b/internal/lsp/language.go index 2e276c464867680417dea33259c1ae30bca4150f..65ccd54f33b185afc9c8657b660b07b912641e70 100644 --- a/internal/lsp/language.go +++ b/internal/lsp/language.go @@ -4,7 +4,7 @@ import ( "path/filepath" "strings" - "github.com/kujtimiihoxha/termai/internal/lsp/protocol" + "github.com/kujtimiihoxha/opencode/internal/lsp/protocol" ) func DetectLanguageID(uri string) protocol.LanguageKind { diff --git a/internal/lsp/methods.go b/internal/lsp/methods.go index 079b3bfe365398e83e8ecb3b2c66b309ca97e5f3..ab33d7e1bbf196e48fc31f04013b22e25dab3c15 100644 --- a/internal/lsp/methods.go +++ b/internal/lsp/methods.go @@ -4,7 +4,7 @@ package lsp import ( "context" - "github.com/kujtimiihoxha/termai/internal/lsp/protocol" + "github.com/kujtimiihoxha/opencode/internal/lsp/protocol" ) // Implementation sends a textDocument/implementation request to the LSP server. diff --git a/internal/lsp/transport.go b/internal/lsp/transport.go index 89255fd78bfe356afc52cf4e82ac64b0a6553e10..fe59b0fbb22dc35d6444d7eed2458f4bf6def70e 100644 --- a/internal/lsp/transport.go +++ b/internal/lsp/transport.go @@ -8,8 +8,8 @@ import ( "io" "strings" - "github.com/kujtimiihoxha/termai/internal/config" - "github.com/kujtimiihoxha/termai/internal/logging" + "github.com/kujtimiihoxha/opencode/internal/config" + "github.com/kujtimiihoxha/opencode/internal/logging" ) // Write writes an LSP message to the given writer diff --git a/internal/lsp/util/edit.go b/internal/lsp/util/edit.go index 3b94fb39ff70afd463aac83bb7ffd698e43a3e6f..52f03ee772fa78cd835d4a6913538785343594ab 100644 --- a/internal/lsp/util/edit.go +++ b/internal/lsp/util/edit.go @@ -7,7 +7,7 @@ import ( "sort" "strings" - "github.com/kujtimiihoxha/termai/internal/lsp/protocol" + "github.com/kujtimiihoxha/opencode/internal/lsp/protocol" ) func applyTextEdits(uri protocol.DocumentUri, edits []protocol.TextEdit) error { diff --git a/internal/lsp/watcher/watcher.go b/internal/lsp/watcher/watcher.go index 156f38e1aa897b5195921607cd322a7196bebbc2..595c78db9154c3ba94d2e0a8dc4590750280c6f6 100644 --- a/internal/lsp/watcher/watcher.go +++ b/internal/lsp/watcher/watcher.go @@ -10,10 +10,10 @@ import ( "time" "github.com/fsnotify/fsnotify" - "github.com/kujtimiihoxha/termai/internal/config" - "github.com/kujtimiihoxha/termai/internal/logging" - "github.com/kujtimiihoxha/termai/internal/lsp" - "github.com/kujtimiihoxha/termai/internal/lsp/protocol" + "github.com/kujtimiihoxha/opencode/internal/config" + "github.com/kujtimiihoxha/opencode/internal/logging" + "github.com/kujtimiihoxha/opencode/internal/lsp" + "github.com/kujtimiihoxha/opencode/internal/lsp/protocol" ) // WorkspaceWatcher manages LSP file watching diff --git a/internal/message/content.go b/internal/message/content.go index f9e76b11c1a44fde90bedb8721eac44603ab8af3..f52449f4a394cdc4fdebcb8921f3f7cf9601d068 100644 --- a/internal/message/content.go +++ b/internal/message/content.go @@ -5,7 +5,7 @@ import ( "slices" "time" - "github.com/kujtimiihoxha/termai/internal/llm/models" + "github.com/kujtimiihoxha/opencode/internal/llm/models" ) type MessageRole string diff --git a/internal/message/message.go b/internal/message/message.go index 2871780a79f91cca018c6ad1a398c023123310c6..f165fcfc75a3f56f0fc454cb1aabcb5bb75d8c6f 100644 --- a/internal/message/message.go +++ b/internal/message/message.go @@ -7,9 +7,9 @@ import ( "fmt" "github.com/google/uuid" - "github.com/kujtimiihoxha/termai/internal/db" - "github.com/kujtimiihoxha/termai/internal/llm/models" - "github.com/kujtimiihoxha/termai/internal/pubsub" + "github.com/kujtimiihoxha/opencode/internal/db" + "github.com/kujtimiihoxha/opencode/internal/llm/models" + "github.com/kujtimiihoxha/opencode/internal/pubsub" ) type CreateMessageParams struct { diff --git a/internal/permission/permission.go b/internal/permission/permission.go index 8aa280906a397dd5766b4fcd52daf86b76506f8b..4cb379dea133a4b101f16d987f34c4647a8dd2b5 100644 --- a/internal/permission/permission.go +++ b/internal/permission/permission.go @@ -6,7 +6,7 @@ import ( "time" "github.com/google/uuid" - "github.com/kujtimiihoxha/termai/internal/pubsub" + "github.com/kujtimiihoxha/opencode/internal/pubsub" ) var ErrorPermissionDenied = errors.New("permission denied") diff --git a/internal/session/session.go b/internal/session/session.go index 019019df47d5ab46e5a2ccb22dff8f3ceadae29e..280da1ff0a36166ece694c7acf80707f17f0d7d1 100644 --- a/internal/session/session.go +++ b/internal/session/session.go @@ -5,8 +5,8 @@ import ( "database/sql" "github.com/google/uuid" - "github.com/kujtimiihoxha/termai/internal/db" - "github.com/kujtimiihoxha/termai/internal/pubsub" + "github.com/kujtimiihoxha/opencode/internal/db" + "github.com/kujtimiihoxha/opencode/internal/pubsub" ) type Session struct { diff --git a/internal/tui/components/chat/chat.go b/internal/tui/components/chat/chat.go index e98001efa8345b310829cb8c09c57a6869c57b20..52ff4c8bf3e28f500c9890487f9d7ba9e48eb62c 100644 --- a/internal/tui/components/chat/chat.go +++ b/internal/tui/components/chat/chat.go @@ -5,10 +5,10 @@ import ( "github.com/charmbracelet/lipgloss" "github.com/charmbracelet/x/ansi" - "github.com/kujtimiihoxha/termai/internal/config" - "github.com/kujtimiihoxha/termai/internal/session" - "github.com/kujtimiihoxha/termai/internal/tui/styles" - "github.com/kujtimiihoxha/termai/internal/version" + "github.com/kujtimiihoxha/opencode/internal/config" + "github.com/kujtimiihoxha/opencode/internal/session" + "github.com/kujtimiihoxha/opencode/internal/tui/styles" + "github.com/kujtimiihoxha/opencode/internal/version" ) type SendMsg struct { diff --git a/internal/tui/components/chat/editor.go b/internal/tui/components/chat/editor.go index e2f4da9e240802b6872c08aa93d78a192fccfa99..4d6ef5ca0a4b902dcd50c750659caeb02d6aeb18 100644 --- a/internal/tui/components/chat/editor.go +++ b/internal/tui/components/chat/editor.go @@ -5,11 +5,11 @@ import ( "github.com/charmbracelet/bubbles/textarea" tea "github.com/charmbracelet/bubbletea" "github.com/charmbracelet/lipgloss" - "github.com/kujtimiihoxha/termai/internal/app" - "github.com/kujtimiihoxha/termai/internal/session" - "github.com/kujtimiihoxha/termai/internal/tui/layout" - "github.com/kujtimiihoxha/termai/internal/tui/styles" - "github.com/kujtimiihoxha/termai/internal/tui/util" + "github.com/kujtimiihoxha/opencode/internal/app" + "github.com/kujtimiihoxha/opencode/internal/session" + "github.com/kujtimiihoxha/opencode/internal/tui/layout" + "github.com/kujtimiihoxha/opencode/internal/tui/styles" + "github.com/kujtimiihoxha/opencode/internal/tui/util" ) type editorCmp struct { diff --git a/internal/tui/components/chat/messages.go b/internal/tui/components/chat/messages.go index 26a98970ee61b60eff0a7396d250dbae09f5a6e5..c2ce7d88b13ea9858f7adf5cf343504525f35973 100644 --- a/internal/tui/components/chat/messages.go +++ b/internal/tui/components/chat/messages.go @@ -15,17 +15,17 @@ import ( "github.com/charmbracelet/glamour" "github.com/charmbracelet/lipgloss" "github.com/charmbracelet/x/ansi" - "github.com/kujtimiihoxha/termai/internal/app" - "github.com/kujtimiihoxha/termai/internal/llm/agent" - "github.com/kujtimiihoxha/termai/internal/llm/models" - "github.com/kujtimiihoxha/termai/internal/llm/tools" - "github.com/kujtimiihoxha/termai/internal/logging" - "github.com/kujtimiihoxha/termai/internal/message" - "github.com/kujtimiihoxha/termai/internal/pubsub" - "github.com/kujtimiihoxha/termai/internal/session" - "github.com/kujtimiihoxha/termai/internal/tui/layout" - "github.com/kujtimiihoxha/termai/internal/tui/styles" - "github.com/kujtimiihoxha/termai/internal/tui/util" + "github.com/kujtimiihoxha/opencode/internal/app" + "github.com/kujtimiihoxha/opencode/internal/llm/agent" + "github.com/kujtimiihoxha/opencode/internal/llm/models" + "github.com/kujtimiihoxha/opencode/internal/llm/tools" + "github.com/kujtimiihoxha/opencode/internal/logging" + "github.com/kujtimiihoxha/opencode/internal/message" + "github.com/kujtimiihoxha/opencode/internal/pubsub" + "github.com/kujtimiihoxha/opencode/internal/session" + "github.com/kujtimiihoxha/opencode/internal/tui/layout" + "github.com/kujtimiihoxha/opencode/internal/tui/styles" + "github.com/kujtimiihoxha/opencode/internal/tui/util" ) type uiMessageType int diff --git a/internal/tui/components/chat/sidebar.go b/internal/tui/components/chat/sidebar.go index b90269d1a643847a0c039f77055890aa802db3da..54b39f4a1468340a7d769eaf23a682839746d980 100644 --- a/internal/tui/components/chat/sidebar.go +++ b/internal/tui/components/chat/sidebar.go @@ -7,12 +7,12 @@ import ( tea "github.com/charmbracelet/bubbletea" "github.com/charmbracelet/lipgloss" - "github.com/kujtimiihoxha/termai/internal/config" - "github.com/kujtimiihoxha/termai/internal/diff" - "github.com/kujtimiihoxha/termai/internal/history" - "github.com/kujtimiihoxha/termai/internal/pubsub" - "github.com/kujtimiihoxha/termai/internal/session" - "github.com/kujtimiihoxha/termai/internal/tui/styles" + "github.com/kujtimiihoxha/opencode/internal/config" + "github.com/kujtimiihoxha/opencode/internal/diff" + "github.com/kujtimiihoxha/opencode/internal/history" + "github.com/kujtimiihoxha/opencode/internal/pubsub" + "github.com/kujtimiihoxha/opencode/internal/session" + "github.com/kujtimiihoxha/opencode/internal/tui/styles" ) type sidebarCmp struct { diff --git a/internal/tui/components/core/status.go b/internal/tui/components/core/status.go index 089dffa2c33fa3b9b27348a65a77bdac989b049d..411cac1c518c8a37c187f2b31500eb1f64e47988 100644 --- a/internal/tui/components/core/status.go +++ b/internal/tui/components/core/status.go @@ -7,12 +7,12 @@ import ( tea "github.com/charmbracelet/bubbletea" "github.com/charmbracelet/lipgloss" - "github.com/kujtimiihoxha/termai/internal/config" - "github.com/kujtimiihoxha/termai/internal/llm/models" - "github.com/kujtimiihoxha/termai/internal/lsp" - "github.com/kujtimiihoxha/termai/internal/lsp/protocol" - "github.com/kujtimiihoxha/termai/internal/tui/styles" - "github.com/kujtimiihoxha/termai/internal/tui/util" + "github.com/kujtimiihoxha/opencode/internal/config" + "github.com/kujtimiihoxha/opencode/internal/llm/models" + "github.com/kujtimiihoxha/opencode/internal/lsp" + "github.com/kujtimiihoxha/opencode/internal/lsp/protocol" + "github.com/kujtimiihoxha/opencode/internal/tui/styles" + "github.com/kujtimiihoxha/opencode/internal/tui/util" ) type statusCmp struct { diff --git a/internal/tui/components/dialog/help.go b/internal/tui/components/dialog/help.go index 1d3c2b077b1b393c4bc1e2147af4cc0c8d519d9c..6242017f100cd589314f8e12722f7234a1536f1d 100644 --- a/internal/tui/components/dialog/help.go +++ b/internal/tui/components/dialog/help.go @@ -6,7 +6,7 @@ import ( "github.com/charmbracelet/bubbles/key" tea "github.com/charmbracelet/bubbletea" "github.com/charmbracelet/lipgloss" - "github.com/kujtimiihoxha/termai/internal/tui/styles" + "github.com/kujtimiihoxha/opencode/internal/tui/styles" ) type helpCmp struct { diff --git a/internal/tui/components/dialog/permission.go b/internal/tui/components/dialog/permission.go index 9c55effde1c3b18692eb4063cf56b560313329cb..200a7970d95c4a494802f164774017704bd9868a 100644 --- a/internal/tui/components/dialog/permission.go +++ b/internal/tui/components/dialog/permission.go @@ -9,12 +9,12 @@ import ( tea "github.com/charmbracelet/bubbletea" "github.com/charmbracelet/glamour" "github.com/charmbracelet/lipgloss" - "github.com/kujtimiihoxha/termai/internal/diff" - "github.com/kujtimiihoxha/termai/internal/llm/tools" - "github.com/kujtimiihoxha/termai/internal/permission" - "github.com/kujtimiihoxha/termai/internal/tui/layout" - "github.com/kujtimiihoxha/termai/internal/tui/styles" - "github.com/kujtimiihoxha/termai/internal/tui/util" + "github.com/kujtimiihoxha/opencode/internal/diff" + "github.com/kujtimiihoxha/opencode/internal/llm/tools" + "github.com/kujtimiihoxha/opencode/internal/permission" + "github.com/kujtimiihoxha/opencode/internal/tui/layout" + "github.com/kujtimiihoxha/opencode/internal/tui/styles" + "github.com/kujtimiihoxha/opencode/internal/tui/util" ) type PermissionAction string diff --git a/internal/tui/components/dialog/quit.go b/internal/tui/components/dialog/quit.go index 10d9ba8a2c363168fc6bf0048451c899c367d9cc..5bbe6696cf9c73ac498ad98841e11c2f23c35df2 100644 --- a/internal/tui/components/dialog/quit.go +++ b/internal/tui/components/dialog/quit.go @@ -6,9 +6,9 @@ import ( "github.com/charmbracelet/bubbles/key" tea "github.com/charmbracelet/bubbletea" "github.com/charmbracelet/lipgloss" - "github.com/kujtimiihoxha/termai/internal/tui/layout" - "github.com/kujtimiihoxha/termai/internal/tui/styles" - "github.com/kujtimiihoxha/termai/internal/tui/util" + "github.com/kujtimiihoxha/opencode/internal/tui/layout" + "github.com/kujtimiihoxha/opencode/internal/tui/styles" + "github.com/kujtimiihoxha/opencode/internal/tui/util" ) const question = "Are you sure you want to quit?" diff --git a/internal/tui/components/logs/details.go b/internal/tui/components/logs/details.go index 18eb1a526806779fa0ec20d2501ec122e8c1a0e4..3a8f1799931ed2d16f65c6e572ae58520bd299e4 100644 --- a/internal/tui/components/logs/details.go +++ b/internal/tui/components/logs/details.go @@ -9,9 +9,9 @@ import ( "github.com/charmbracelet/bubbles/viewport" tea "github.com/charmbracelet/bubbletea" "github.com/charmbracelet/lipgloss" - "github.com/kujtimiihoxha/termai/internal/logging" - "github.com/kujtimiihoxha/termai/internal/tui/layout" - "github.com/kujtimiihoxha/termai/internal/tui/styles" + "github.com/kujtimiihoxha/opencode/internal/logging" + "github.com/kujtimiihoxha/opencode/internal/tui/layout" + "github.com/kujtimiihoxha/opencode/internal/tui/styles" ) type DetailComponent interface { diff --git a/internal/tui/components/logs/table.go b/internal/tui/components/logs/table.go index 6e8eb58b13772d294da3fb28eeb3fa4020c44402..dc6184e3df5f4509fc65805a7aaea9ecc65bcb31 100644 --- a/internal/tui/components/logs/table.go +++ b/internal/tui/components/logs/table.go @@ -7,11 +7,11 @@ import ( "github.com/charmbracelet/bubbles/key" "github.com/charmbracelet/bubbles/table" tea "github.com/charmbracelet/bubbletea" - "github.com/kujtimiihoxha/termai/internal/logging" - "github.com/kujtimiihoxha/termai/internal/pubsub" - "github.com/kujtimiihoxha/termai/internal/tui/layout" - "github.com/kujtimiihoxha/termai/internal/tui/styles" - "github.com/kujtimiihoxha/termai/internal/tui/util" + "github.com/kujtimiihoxha/opencode/internal/logging" + "github.com/kujtimiihoxha/opencode/internal/pubsub" + "github.com/kujtimiihoxha/opencode/internal/tui/layout" + "github.com/kujtimiihoxha/opencode/internal/tui/styles" + "github.com/kujtimiihoxha/opencode/internal/tui/util" ) type TableComponent interface { diff --git a/internal/tui/layout/border.go b/internal/tui/layout/border.go index 8fe5c430c6d7adc7ffab02c9f8c94a611f54d58e..ea9f5e0bc50c1d11710c8f784971f1037110c4da 100644 --- a/internal/tui/layout/border.go +++ b/internal/tui/layout/border.go @@ -5,7 +5,7 @@ import ( "strings" "github.com/charmbracelet/lipgloss" - "github.com/kujtimiihoxha/termai/internal/tui/styles" + "github.com/kujtimiihoxha/opencode/internal/tui/styles" ) type BorderPosition int diff --git a/internal/tui/layout/container.go b/internal/tui/layout/container.go index db07d49fb925ad829f3e05f907c0207e6e0dbe89..60369995591b757cd57416d129dce92c0a465d6f 100644 --- a/internal/tui/layout/container.go +++ b/internal/tui/layout/container.go @@ -4,7 +4,7 @@ import ( "github.com/charmbracelet/bubbles/key" tea "github.com/charmbracelet/bubbletea" "github.com/charmbracelet/lipgloss" - "github.com/kujtimiihoxha/termai/internal/tui/styles" + "github.com/kujtimiihoxha/opencode/internal/tui/styles" ) type Container interface { diff --git a/internal/tui/layout/overlay.go b/internal/tui/layout/overlay.go index 4a1bcf661ad14ec7bfb186fce308dc1f3d255a97..4c05e84629da4727f26ea39c4c3d98891e8b9181 100644 --- a/internal/tui/layout/overlay.go +++ b/internal/tui/layout/overlay.go @@ -5,8 +5,8 @@ import ( "strings" "github.com/charmbracelet/lipgloss" - "github.com/kujtimiihoxha/termai/internal/tui/styles" - "github.com/kujtimiihoxha/termai/internal/tui/util" + "github.com/kujtimiihoxha/opencode/internal/tui/styles" + "github.com/kujtimiihoxha/opencode/internal/tui/util" "github.com/mattn/go-runewidth" "github.com/muesli/ansi" "github.com/muesli/reflow/truncate" diff --git a/internal/tui/layout/split.go b/internal/tui/layout/split.go index 6482fc74cebc10149548543eccd4fed67e32db44..bfb616a5364da9e62fe5058ddf28bb87780527c5 100644 --- a/internal/tui/layout/split.go +++ b/internal/tui/layout/split.go @@ -4,7 +4,7 @@ import ( "github.com/charmbracelet/bubbles/key" tea "github.com/charmbracelet/bubbletea" "github.com/charmbracelet/lipgloss" - "github.com/kujtimiihoxha/termai/internal/tui/styles" + "github.com/kujtimiihoxha/opencode/internal/tui/styles" ) type SplitPaneLayout interface { diff --git a/internal/tui/page/chat.go b/internal/tui/page/chat.go index cebc0e4610ca696c1fefb8b10c271462e0fcbbb3..c268e677f4a49493d9bad69f86d70a4c9de71e0e 100644 --- a/internal/tui/page/chat.go +++ b/internal/tui/page/chat.go @@ -5,11 +5,11 @@ import ( "github.com/charmbracelet/bubbles/key" tea "github.com/charmbracelet/bubbletea" - "github.com/kujtimiihoxha/termai/internal/app" - "github.com/kujtimiihoxha/termai/internal/session" - "github.com/kujtimiihoxha/termai/internal/tui/components/chat" - "github.com/kujtimiihoxha/termai/internal/tui/layout" - "github.com/kujtimiihoxha/termai/internal/tui/util" + "github.com/kujtimiihoxha/opencode/internal/app" + "github.com/kujtimiihoxha/opencode/internal/session" + "github.com/kujtimiihoxha/opencode/internal/tui/components/chat" + "github.com/kujtimiihoxha/opencode/internal/tui/layout" + "github.com/kujtimiihoxha/opencode/internal/tui/util" ) var ChatPage PageID = "chat" diff --git a/internal/tui/page/logs.go b/internal/tui/page/logs.go index d1e557eab548673ac76ba59cdc716a7dc2410c23..c77a033f466d0b6912ea66c00c49ae2fe150199f 100644 --- a/internal/tui/page/logs.go +++ b/internal/tui/page/logs.go @@ -2,8 +2,8 @@ package page import ( tea "github.com/charmbracelet/bubbletea" - "github.com/kujtimiihoxha/termai/internal/tui/components/logs" - "github.com/kujtimiihoxha/termai/internal/tui/layout" + "github.com/kujtimiihoxha/opencode/internal/tui/components/logs" + "github.com/kujtimiihoxha/opencode/internal/tui/layout" ) var LogsPage PageID = "logs" diff --git a/internal/tui/tui.go b/internal/tui/tui.go index dff7ad63d11e61463b659bea70f61d79b57c4e1c..657de6b6e56d3bf739e8a4d3bdff447abd8227eb 100644 --- a/internal/tui/tui.go +++ b/internal/tui/tui.go @@ -4,15 +4,15 @@ import ( "github.com/charmbracelet/bubbles/key" tea "github.com/charmbracelet/bubbletea" "github.com/charmbracelet/lipgloss" - "github.com/kujtimiihoxha/termai/internal/app" - "github.com/kujtimiihoxha/termai/internal/logging" - "github.com/kujtimiihoxha/termai/internal/permission" - "github.com/kujtimiihoxha/termai/internal/pubsub" - "github.com/kujtimiihoxha/termai/internal/tui/components/core" - "github.com/kujtimiihoxha/termai/internal/tui/components/dialog" - "github.com/kujtimiihoxha/termai/internal/tui/layout" - "github.com/kujtimiihoxha/termai/internal/tui/page" - "github.com/kujtimiihoxha/termai/internal/tui/util" + "github.com/kujtimiihoxha/opencode/internal/app" + "github.com/kujtimiihoxha/opencode/internal/logging" + "github.com/kujtimiihoxha/opencode/internal/permission" + "github.com/kujtimiihoxha/opencode/internal/pubsub" + "github.com/kujtimiihoxha/opencode/internal/tui/components/core" + "github.com/kujtimiihoxha/opencode/internal/tui/components/dialog" + "github.com/kujtimiihoxha/opencode/internal/tui/layout" + "github.com/kujtimiihoxha/opencode/internal/tui/page" + "github.com/kujtimiihoxha/opencode/internal/tui/util" ) type keyMap struct { diff --git a/main.go b/main.go index 2e6954646852703947bab5ccf5783032cf56c169..06578c7efb105990b6196917757c653d1ca8bdca 100644 --- a/main.go +++ b/main.go @@ -1,8 +1,8 @@ package main import ( - "github.com/kujtimiihoxha/termai/cmd" - "github.com/kujtimiihoxha/termai/internal/logging" + "github.com/kujtimiihoxha/opencode/cmd" + "github.com/kujtimiihoxha/opencode/internal/logging" ) func main() { From 36172979b45facc8ccec6861f124193eaebc42e9 Mon Sep 17 00:00:00 2001 From: Kujtim Hoxha Date: Thu, 17 Apr 2025 00:00:19 +0200 Subject: [PATCH 21/41] Update agent prompt, improve TUI patch UI, remove obsolete tool tests MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Replace and expand agent coder prompt for clarity and safety - Add patch tool and TUI dialog support for patch diffs - Sort sidebar modified files by name - Remove Bash/Edit/Sourcegraph/Write tool tests 🤖 Generated with opencode Co-Authored-By: opencode --- internal/llm/agent/tools.go | 2 + internal/llm/prompt/coder.go | 95 ++-- internal/llm/tools/bash_test.go | 340 -------------- internal/llm/tools/edit_test.go | 461 ------------------- internal/llm/tools/mocks_test.go | 246 ---------- internal/llm/tools/patch.go | 300 ++++++++++++ internal/llm/tools/sourcegraph_test.go | 86 ---- internal/llm/tools/write_test.go | 307 ------------ internal/tui/components/chat/sidebar.go | 12 +- internal/tui/components/dialog/permission.go | 14 + main.go | 4 +- 11 files changed, 385 insertions(+), 1482 deletions(-) delete mode 100644 internal/llm/tools/bash_test.go delete mode 100644 internal/llm/tools/edit_test.go delete mode 100644 internal/llm/tools/mocks_test.go create mode 100644 internal/llm/tools/patch.go delete mode 100644 internal/llm/tools/sourcegraph_test.go delete mode 100644 internal/llm/tools/write_test.go diff --git a/internal/llm/agent/tools.go b/internal/llm/agent/tools.go index 409d1427313c2eb7797b11a6fdd953a2e0b3414a..9120809ffefa27a2a5c15968665bf559a266c4e3 100644 --- a/internal/llm/agent/tools.go +++ b/internal/llm/agent/tools.go @@ -31,6 +31,8 @@ func CoderAgentTools( tools.NewGlobTool(), tools.NewGrepTool(), tools.NewLsTool(), + // TODO: see if we want to use this tool + // tools.NewPatchTool(lspClients, permissions, history), tools.NewSourcegraphTool(), tools.NewViewTool(lspClients), tools.NewWriteTool(lspClients, permissions, history), diff --git a/internal/llm/prompt/coder.go b/internal/llm/prompt/coder.go index 3a06911dadf3f9bb04aa0f66d8b136b836468f1f..febdea4d250bdfc7b9bb2619a68d65354f0c6fa1 100644 --- a/internal/llm/prompt/coder.go +++ b/internal/llm/prompt/coder.go @@ -25,44 +25,63 @@ func CoderPrompt(provider models.ModelProvider) string { } const baseOpenAICoderPrompt = ` -You are **OpenCode**, an autonomous CLI assistant for software‑engineering tasks. - -### ── INTERNAL REFLECTION ── -• Silently think step‑by‑step about the user request, directory layout, and tool calls (never reveal this). -• Formulate a plan, then execute without further approval unless a blocker triggers the Ask‑Only‑If rules. - -### ── PUBLIC RESPONSE RULES ── -• Visible reply ≤ 4 lines; no fluff, preamble, or postamble. -• Use GitHub‑flavored Markdown. -• When running a non‑trivial shell command, add ≤ 1 brief purpose sentence. - -### ── CONTEXT & MEMORY ── -• Infer file intent from directory structure before editing. -• Auto‑load 'OpenCode.md'; ask once before writing new reusable commands or style notes. - -### ── AUTONOMY PRIORITY ── -**Ask‑Only‑If Decision Tree:** -1. **Safety risk?** (e.g., destructive command, secret exposure) → ask. -2. **Critical unknown?** (no docs/tests; cannot infer) → ask. -3. **Tool failure after two self‑attempts?** → ask. -Otherwise, proceed autonomously. - -### ── SAFETY & STYLE ── -• Mimic existing code style; verify libraries exist before import. -• Never commit unless explicitly told. -• After edits, run lint & type‑check (ask for commands once, then offer to store in 'OpenCode.md'). -• Protect secrets; follow standard security practices :contentReference[oaicite:2]{index=2}. - -### ── TOOL USAGE ── -• Batch independent Agent search/file calls in one block for efficiency :contentReference[oaicite:3]{index=3}. -• Communicate with the user only via visible text; do not expose tool output or internal reasoning. - -### ── EXAMPLES ── -user: list files -assistant: ls - -user: write tests for new feature -assistant: [searches & edits autonomously, no extra chit‑chat] +# OpenCode CLI Agent Prompt + +You are operating within the **OpenCode CLI**, a terminal-based, agentic coding assistant that interfaces with local codebases through natural language. Your primary objectives are to be precise, safe, and helpful. + +## Capabilities + +- Receive user prompts, project context, and files. +- Stream responses and emit function calls (e.g., shell commands, code edits). +- Apply patches, run commands, and manage user approvals based on policy. +- Operate within a sandboxed, git-backed workspace with rollback support. +- Log telemetry for session replay or inspection. +- Access detailed functionality via the help command. + +## Operational Guidelines + +### 1. Task Resolution + +- Continue processing until the user's query is fully resolved. +- Only conclude your turn when confident the problem is solved. +- If uncertain about file content or codebase structure, utilize available tools to gather necessary information—avoid assumptions. + +### 2. Code Modification & Testing + +- Edit and test code files within your current execution session. +- Work on the local repositories, even if proprietary. +- Analyze code for vulnerabilities when applicable. +- Display user code and tool call details transparently. + +### 3. Coding Guidelines + +- Address root causes rather than applying superficial fixes. +- Avoid unnecessary complexity; focus on the task at hand. +- Update documentation as needed. +- Maintain consistency with the existing codebase style. +- Utilize version control tools for additional context; note that internet access is disabled. +- Refrain from adding copyright or license headers unless explicitly requested. +- No need to perform commit operations; this will be handled automatically. +- If a pre-commit configuration file exists, run the appropriate checks to ensure changes pass. Do not fix pre-existing errors on untouched lines. +- If pre-commit checks fail after retries, inform the user that the setup may be broken. + +### 4. Post-Modification Checks + +- Use version control status commands to verify changes; revert any unintended modifications. +- Remove all added inline comments unless they are essential for understanding. +- Ensure no accidental addition of copyright or license headers. +- Attempt to run pre-commit checks if available. +- For smaller tasks, provide brief bullet points summarizing changes. +- For complex tasks, include a high-level description, bullet points, and relevant details for code reviewers. + +### 5. Non-Code Modification Tasks + +- Respond in a friendly, collaborative tone, akin to a knowledgeable remote teammate eager to assist with coding inquiries. + +### 6. File Handling + +- Do not instruct the user to save or copy code into files if modifications have already been made using the editing tools. +- Avoid displaying full contents of large files unless explicitly requested by the user. ` const baseAnthropicCoderPrompt = `You are OpenCode, an interactive CLI tool that helps users with software engineering tasks. Use the instructions below and the tools available to you to assist the user. diff --git a/internal/llm/tools/bash_test.go b/internal/llm/tools/bash_test.go deleted file mode 100644 index dafb0ccc5fe4b7f612c451755b09cb70f17ac388..0000000000000000000000000000000000000000 --- a/internal/llm/tools/bash_test.go +++ /dev/null @@ -1,340 +0,0 @@ -package tools - -import ( - "context" - "encoding/json" - "os" - "strings" - "testing" - "time" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -func TestBashTool_Info(t *testing.T) { - tool := NewBashTool(newMockPermissionService(true)) - info := tool.Info() - - assert.Equal(t, BashToolName, info.Name) - assert.NotEmpty(t, info.Description) - assert.Contains(t, info.Parameters, "command") - assert.Contains(t, info.Parameters, "timeout") - assert.Contains(t, info.Required, "command") -} - -func TestBashTool_Run(t *testing.T) { - // Save original working directory - origWd, err := os.Getwd() - require.NoError(t, err) - defer func() { - os.Chdir(origWd) - }() - - t.Run("executes command successfully", func(t *testing.T) { - tool := NewBashTool(newMockPermissionService(true)) - params := BashParams{ - Command: "echo 'Hello World'", - } - - paramsJSON, err := json.Marshal(params) - require.NoError(t, err) - - call := ToolCall{ - Name: BashToolName, - Input: string(paramsJSON), - } - - response, err := tool.Run(context.Background(), call) - require.NoError(t, err) - assert.Equal(t, "Hello World\n", response.Content) - }) - - t.Run("handles invalid parameters", func(t *testing.T) { - tool := NewBashTool(newMockPermissionService(true)) - call := ToolCall{ - Name: BashToolName, - Input: "invalid json", - } - - response, err := tool.Run(context.Background(), call) - require.NoError(t, err) - assert.Contains(t, response.Content, "invalid parameters") - }) - - t.Run("handles missing command", func(t *testing.T) { - tool := NewBashTool(newMockPermissionService(true)) - params := BashParams{ - Command: "", - } - - paramsJSON, err := json.Marshal(params) - require.NoError(t, err) - - call := ToolCall{ - Name: BashToolName, - Input: string(paramsJSON), - } - - response, err := tool.Run(context.Background(), call) - require.NoError(t, err) - assert.Contains(t, response.Content, "missing command") - }) - - t.Run("handles banned commands", func(t *testing.T) { - tool := NewBashTool(newMockPermissionService(true)) - - for _, bannedCmd := range bannedCommands { - params := BashParams{ - Command: bannedCmd + " arg1 arg2", - } - - paramsJSON, err := json.Marshal(params) - require.NoError(t, err) - - call := ToolCall{ - Name: BashToolName, - Input: string(paramsJSON), - } - - response, err := tool.Run(context.Background(), call) - require.NoError(t, err) - assert.Contains(t, response.Content, "not allowed", "Command %s should be blocked", bannedCmd) - } - }) - - t.Run("handles multi-word safe commands without permission check", func(t *testing.T) { - tool := NewBashTool(newMockPermissionService(false)) - - // Test with multi-word safe commands - multiWordCommands := []string{ - "go env", - } - - for _, cmd := range multiWordCommands { - params := BashParams{ - Command: cmd, - } - - paramsJSON, err := json.Marshal(params) - require.NoError(t, err) - - call := ToolCall{ - Name: BashToolName, - Input: string(paramsJSON), - } - - response, err := tool.Run(context.Background(), call) - require.NoError(t, err) - assert.NotContains(t, response.Content, "permission denied", - "Command %s should be allowed without permission", cmd) - } - }) - - t.Run("handles permission denied", func(t *testing.T) { - tool := NewBashTool(newMockPermissionService(false)) - - // Test with a command that requires permission - params := BashParams{ - Command: "mkdir test_dir", - } - - paramsJSON, err := json.Marshal(params) - require.NoError(t, err) - - call := ToolCall{ - Name: BashToolName, - Input: string(paramsJSON), - } - - response, err := tool.Run(context.Background(), call) - require.NoError(t, err) - assert.Contains(t, response.Content, "permission denied") - }) - - t.Run("handles command timeout", func(t *testing.T) { - tool := NewBashTool(newMockPermissionService(true)) - params := BashParams{ - Command: "sleep 2", - Timeout: 100, // 100ms timeout - } - - paramsJSON, err := json.Marshal(params) - require.NoError(t, err) - - call := ToolCall{ - Name: BashToolName, - Input: string(paramsJSON), - } - - response, err := tool.Run(context.Background(), call) - require.NoError(t, err) - assert.Contains(t, response.Content, "aborted") - }) - - t.Run("handles command with stderr output", func(t *testing.T) { - tool := NewBashTool(newMockPermissionService(true)) - params := BashParams{ - Command: "echo 'error message' >&2", - } - - paramsJSON, err := json.Marshal(params) - require.NoError(t, err) - - call := ToolCall{ - Name: BashToolName, - Input: string(paramsJSON), - } - - response, err := tool.Run(context.Background(), call) - require.NoError(t, err) - assert.Contains(t, response.Content, "error message") - }) - - t.Run("handles command with both stdout and stderr", func(t *testing.T) { - tool := NewBashTool(newMockPermissionService(true)) - params := BashParams{ - Command: "echo 'stdout message' && echo 'stderr message' >&2", - } - - paramsJSON, err := json.Marshal(params) - require.NoError(t, err) - - call := ToolCall{ - Name: BashToolName, - Input: string(paramsJSON), - } - - response, err := tool.Run(context.Background(), call) - require.NoError(t, err) - assert.Contains(t, response.Content, "stdout message") - assert.Contains(t, response.Content, "stderr message") - }) - - t.Run("handles context cancellation", func(t *testing.T) { - tool := NewBashTool(newMockPermissionService(true)) - params := BashParams{ - Command: "sleep 5", - } - - paramsJSON, err := json.Marshal(params) - require.NoError(t, err) - - call := ToolCall{ - Name: BashToolName, - Input: string(paramsJSON), - } - - ctx, cancel := context.WithCancel(context.Background()) - - // Cancel the context after a short delay - go func() { - time.Sleep(100 * time.Millisecond) - cancel() - }() - - response, err := tool.Run(ctx, call) - require.NoError(t, err) - assert.Contains(t, response.Content, "aborted") - }) - - t.Run("respects max timeout", func(t *testing.T) { - tool := NewBashTool(newMockPermissionService(true)) - params := BashParams{ - Command: "echo 'test'", - Timeout: MaxTimeout + 1000, // Exceeds max timeout - } - - paramsJSON, err := json.Marshal(params) - require.NoError(t, err) - - call := ToolCall{ - Name: BashToolName, - Input: string(paramsJSON), - } - - response, err := tool.Run(context.Background(), call) - require.NoError(t, err) - assert.Equal(t, "test\n", response.Content) - }) - - t.Run("uses default timeout for zero or negative timeout", func(t *testing.T) { - tool := NewBashTool(newMockPermissionService(true)) - params := BashParams{ - Command: "echo 'test'", - Timeout: -100, // Negative timeout - } - - paramsJSON, err := json.Marshal(params) - require.NoError(t, err) - - call := ToolCall{ - Name: BashToolName, - Input: string(paramsJSON), - } - - response, err := tool.Run(context.Background(), call) - require.NoError(t, err) - assert.Equal(t, "test\n", response.Content) - }) -} - -func TestTruncateOutput(t *testing.T) { - t.Run("does not truncate short output", func(t *testing.T) { - output := "short output" - result := truncateOutput(output) - assert.Equal(t, output, result) - }) - - t.Run("truncates long output", func(t *testing.T) { - // Create a string longer than MaxOutputLength - longOutput := strings.Repeat("a\n", MaxOutputLength) - result := truncateOutput(longOutput) - - // Check that the result is shorter than the original - assert.Less(t, len(result), len(longOutput)) - - // Check that the truncation message is included - assert.Contains(t, result, "lines truncated") - - // Check that we have the beginning and end of the original string - assert.True(t, strings.HasPrefix(result, "a\n")) - assert.True(t, strings.HasSuffix(result, "a\n")) - }) -} - -func TestCountLines(t *testing.T) { - testCases := []struct { - name string - input string - expected int - }{ - { - name: "empty string", - input: "", - expected: 0, - }, - { - name: "single line", - input: "line1", - expected: 1, - }, - { - name: "multiple lines", - input: "line1\nline2\nline3", - expected: 3, - }, - { - name: "trailing newline", - input: "line1\nline2\n", - expected: 3, // Empty string after last newline counts as a line - }, - } - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - result := countLines(tc.input) - assert.Equal(t, tc.expected, result) - }) - } -} diff --git a/internal/llm/tools/edit_test.go b/internal/llm/tools/edit_test.go deleted file mode 100644 index 1b58a0d7d401a47a3002ebd05e081a87283ed94f..0000000000000000000000000000000000000000 --- a/internal/llm/tools/edit_test.go +++ /dev/null @@ -1,461 +0,0 @@ -package tools - -import ( - "context" - "encoding/json" - "os" - "path/filepath" - "testing" - "time" - - "github.com/kujtimiihoxha/opencode/internal/lsp" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -func TestEditTool_Info(t *testing.T) { - tool := NewEditTool(make(map[string]*lsp.Client), newMockPermissionService(true), newMockFileHistoryService()) - info := tool.Info() - - assert.Equal(t, EditToolName, info.Name) - assert.NotEmpty(t, info.Description) - assert.Contains(t, info.Parameters, "file_path") - assert.Contains(t, info.Parameters, "old_string") - assert.Contains(t, info.Parameters, "new_string") - assert.Contains(t, info.Required, "file_path") - assert.Contains(t, info.Required, "old_string") - assert.Contains(t, info.Required, "new_string") -} - -func TestEditTool_Run(t *testing.T) { - // Create a temporary directory for testing - tempDir, err := os.MkdirTemp("", "edit_tool_test") - require.NoError(t, err) - defer os.RemoveAll(tempDir) - - t.Run("creates a new file successfully", func(t *testing.T) { - tool := NewEditTool(make(map[string]*lsp.Client), newMockPermissionService(true), newMockFileHistoryService()) - - filePath := filepath.Join(tempDir, "new_file.txt") - content := "This is a test content" - - params := EditParams{ - FilePath: filePath, - OldString: "", - NewString: content, - } - - paramsJSON, err := json.Marshal(params) - require.NoError(t, err) - - call := ToolCall{ - Name: EditToolName, - Input: string(paramsJSON), - } - - response, err := tool.Run(context.Background(), call) - require.NoError(t, err) - assert.Contains(t, response.Content, "File created") - - // Verify file was created with correct content - fileContent, err := os.ReadFile(filePath) - require.NoError(t, err) - assert.Equal(t, content, string(fileContent)) - }) - - t.Run("creates file with nested directories", func(t *testing.T) { - tool := NewEditTool(make(map[string]*lsp.Client), newMockPermissionService(true), newMockFileHistoryService()) - - filePath := filepath.Join(tempDir, "nested/dirs/new_file.txt") - content := "Content in nested directory" - - params := EditParams{ - FilePath: filePath, - OldString: "", - NewString: content, - } - - paramsJSON, err := json.Marshal(params) - require.NoError(t, err) - - call := ToolCall{ - Name: EditToolName, - Input: string(paramsJSON), - } - - response, err := tool.Run(context.Background(), call) - require.NoError(t, err) - assert.Contains(t, response.Content, "File created") - - // Verify file was created with correct content - fileContent, err := os.ReadFile(filePath) - require.NoError(t, err) - assert.Equal(t, content, string(fileContent)) - }) - - t.Run("fails to create file that already exists", func(t *testing.T) { - tool := NewEditTool(make(map[string]*lsp.Client), newMockPermissionService(true), newMockFileHistoryService()) - - // Create a file first - filePath := filepath.Join(tempDir, "existing_file.txt") - initialContent := "Initial content" - err := os.WriteFile(filePath, []byte(initialContent), 0o644) - require.NoError(t, err) - - // Try to create the same file - params := EditParams{ - FilePath: filePath, - OldString: "", - NewString: "New content", - } - - paramsJSON, err := json.Marshal(params) - require.NoError(t, err) - - call := ToolCall{ - Name: EditToolName, - Input: string(paramsJSON), - } - - response, err := tool.Run(context.Background(), call) - require.NoError(t, err) - assert.Contains(t, response.Content, "file already exists") - }) - - t.Run("fails to create file when path is a directory", func(t *testing.T) { - tool := NewEditTool(make(map[string]*lsp.Client), newMockPermissionService(true), newMockFileHistoryService()) - - // Create a directory - dirPath := filepath.Join(tempDir, "test_dir") - err := os.Mkdir(dirPath, 0o755) - require.NoError(t, err) - - // Try to create a file with the same path as the directory - params := EditParams{ - FilePath: dirPath, - OldString: "", - NewString: "Some content", - } - - paramsJSON, err := json.Marshal(params) - require.NoError(t, err) - - call := ToolCall{ - Name: EditToolName, - Input: string(paramsJSON), - } - - response, err := tool.Run(context.Background(), call) - require.NoError(t, err) - assert.Contains(t, response.Content, "path is a directory") - }) - - t.Run("replaces content successfully", func(t *testing.T) { - tool := NewEditTool(make(map[string]*lsp.Client), newMockPermissionService(true), newMockFileHistoryService()) - - // Create a file first - filePath := filepath.Join(tempDir, "replace_content.txt") - initialContent := "Line 1\nLine 2\nLine 3\nLine 4\nLine 5" - err := os.WriteFile(filePath, []byte(initialContent), 0o644) - require.NoError(t, err) - - // Record the file read to avoid modification time check failure - recordFileRead(filePath) - - // Replace content - oldString := "Line 2\nLine 3" - newString := "Line 2 modified\nLine 3 modified" - params := EditParams{ - FilePath: filePath, - OldString: oldString, - NewString: newString, - } - - paramsJSON, err := json.Marshal(params) - require.NoError(t, err) - - call := ToolCall{ - Name: EditToolName, - Input: string(paramsJSON), - } - - response, err := tool.Run(context.Background(), call) - require.NoError(t, err) - assert.Contains(t, response.Content, "Content replaced") - - // Verify file was updated with correct content - expectedContent := "Line 1\nLine 2 modified\nLine 3 modified\nLine 4\nLine 5" - fileContent, err := os.ReadFile(filePath) - require.NoError(t, err) - assert.Equal(t, expectedContent, string(fileContent)) - }) - - t.Run("deletes content successfully", func(t *testing.T) { - tool := NewEditTool(make(map[string]*lsp.Client), newMockPermissionService(true), newMockFileHistoryService()) - - // Create a file first - filePath := filepath.Join(tempDir, "delete_content.txt") - initialContent := "Line 1\nLine 2\nLine 3\nLine 4\nLine 5" - err := os.WriteFile(filePath, []byte(initialContent), 0o644) - require.NoError(t, err) - - // Record the file read to avoid modification time check failure - recordFileRead(filePath) - - // Delete content - oldString := "Line 2\nLine 3\n" - params := EditParams{ - FilePath: filePath, - OldString: oldString, - NewString: "", - } - - paramsJSON, err := json.Marshal(params) - require.NoError(t, err) - - call := ToolCall{ - Name: EditToolName, - Input: string(paramsJSON), - } - - response, err := tool.Run(context.Background(), call) - require.NoError(t, err) - assert.Contains(t, response.Content, "Content deleted") - - // Verify file was updated with correct content - expectedContent := "Line 1\nLine 4\nLine 5" - fileContent, err := os.ReadFile(filePath) - require.NoError(t, err) - assert.Equal(t, expectedContent, string(fileContent)) - }) - - t.Run("handles invalid parameters", func(t *testing.T) { - tool := NewEditTool(make(map[string]*lsp.Client), newMockPermissionService(true), newMockFileHistoryService()) - - call := ToolCall{ - Name: EditToolName, - Input: "invalid json", - } - - response, err := tool.Run(context.Background(), call) - require.NoError(t, err) - assert.Contains(t, response.Content, "invalid parameters") - }) - - t.Run("handles missing file_path", func(t *testing.T) { - tool := NewEditTool(make(map[string]*lsp.Client), newMockPermissionService(true), newMockFileHistoryService()) - - params := EditParams{ - FilePath: "", - OldString: "old", - NewString: "new", - } - - paramsJSON, err := json.Marshal(params) - require.NoError(t, err) - - call := ToolCall{ - Name: EditToolName, - Input: string(paramsJSON), - } - - response, err := tool.Run(context.Background(), call) - require.NoError(t, err) - assert.Contains(t, response.Content, "file_path is required") - }) - - t.Run("handles file not found", func(t *testing.T) { - tool := NewEditTool(make(map[string]*lsp.Client), newMockPermissionService(true), newMockFileHistoryService()) - - filePath := filepath.Join(tempDir, "non_existent_file.txt") - params := EditParams{ - FilePath: filePath, - OldString: "old content", - NewString: "new content", - } - - paramsJSON, err := json.Marshal(params) - require.NoError(t, err) - - call := ToolCall{ - Name: EditToolName, - Input: string(paramsJSON), - } - - response, err := tool.Run(context.Background(), call) - require.NoError(t, err) - assert.Contains(t, response.Content, "file not found") - }) - - t.Run("handles old_string not found in file", func(t *testing.T) { - tool := NewEditTool(make(map[string]*lsp.Client), newMockPermissionService(true), newMockFileHistoryService()) - - // Create a file first - filePath := filepath.Join(tempDir, "content_not_found.txt") - initialContent := "Line 1\nLine 2\nLine 3" - err := os.WriteFile(filePath, []byte(initialContent), 0o644) - require.NoError(t, err) - - // Record the file read to avoid modification time check failure - recordFileRead(filePath) - - // Try to replace content that doesn't exist - params := EditParams{ - FilePath: filePath, - OldString: "This content does not exist", - NewString: "new content", - } - - paramsJSON, err := json.Marshal(params) - require.NoError(t, err) - - call := ToolCall{ - Name: EditToolName, - Input: string(paramsJSON), - } - - response, err := tool.Run(context.Background(), call) - require.NoError(t, err) - assert.Contains(t, response.Content, "old_string not found in file") - }) - - t.Run("handles multiple occurrences of old_string", func(t *testing.T) { - tool := NewEditTool(make(map[string]*lsp.Client), newMockPermissionService(true), newMockFileHistoryService()) - - // Create a file with duplicate content - filePath := filepath.Join(tempDir, "duplicate_content.txt") - initialContent := "Line 1\nDuplicate\nLine 3\nDuplicate\nLine 5" - err := os.WriteFile(filePath, []byte(initialContent), 0o644) - require.NoError(t, err) - - // Record the file read to avoid modification time check failure - recordFileRead(filePath) - - // Try to replace content that appears multiple times - params := EditParams{ - FilePath: filePath, - OldString: "Duplicate", - NewString: "Replaced", - } - - paramsJSON, err := json.Marshal(params) - require.NoError(t, err) - - call := ToolCall{ - Name: EditToolName, - Input: string(paramsJSON), - } - - response, err := tool.Run(context.Background(), call) - require.NoError(t, err) - assert.Contains(t, response.Content, "appears multiple times") - }) - - t.Run("handles file modified since last read", func(t *testing.T) { - tool := NewEditTool(make(map[string]*lsp.Client), newMockPermissionService(true), newMockFileHistoryService()) - - // Create a file - filePath := filepath.Join(tempDir, "modified_file.txt") - initialContent := "Initial content" - err := os.WriteFile(filePath, []byte(initialContent), 0o644) - require.NoError(t, err) - - // Record an old read time - fileRecordMutex.Lock() - fileRecords[filePath] = fileRecord{ - path: filePath, - readTime: time.Now().Add(-1 * time.Hour), - } - fileRecordMutex.Unlock() - - // Try to update the file - params := EditParams{ - FilePath: filePath, - OldString: "Initial", - NewString: "Updated", - } - - paramsJSON, err := json.Marshal(params) - require.NoError(t, err) - - call := ToolCall{ - Name: EditToolName, - Input: string(paramsJSON), - } - - response, err := tool.Run(context.Background(), call) - require.NoError(t, err) - assert.Contains(t, response.Content, "has been modified since it was last read") - - // Verify file was not modified - fileContent, err := os.ReadFile(filePath) - require.NoError(t, err) - assert.Equal(t, initialContent, string(fileContent)) - }) - - t.Run("handles file not read before editing", func(t *testing.T) { - tool := NewEditTool(make(map[string]*lsp.Client), newMockPermissionService(true), newMockFileHistoryService()) - - // Create a file - filePath := filepath.Join(tempDir, "not_read_file.txt") - initialContent := "Initial content" - err := os.WriteFile(filePath, []byte(initialContent), 0o644) - require.NoError(t, err) - - // Try to update the file without reading it first - params := EditParams{ - FilePath: filePath, - OldString: "Initial", - NewString: "Updated", - } - - paramsJSON, err := json.Marshal(params) - require.NoError(t, err) - - call := ToolCall{ - Name: EditToolName, - Input: string(paramsJSON), - } - - response, err := tool.Run(context.Background(), call) - require.NoError(t, err) - assert.Contains(t, response.Content, "you must read the file before editing it") - }) - - t.Run("handles permission denied", func(t *testing.T) { - tool := NewEditTool(make(map[string]*lsp.Client), newMockPermissionService(false), newMockFileHistoryService()) - - // Create a file - filePath := filepath.Join(tempDir, "permission_denied.txt") - initialContent := "Initial content" - err := os.WriteFile(filePath, []byte(initialContent), 0o644) - require.NoError(t, err) - - // Record the file read to avoid modification time check failure - recordFileRead(filePath) - - // Try to update the file - params := EditParams{ - FilePath: filePath, - OldString: "Initial", - NewString: "Updated", - } - - paramsJSON, err := json.Marshal(params) - require.NoError(t, err) - - call := ToolCall{ - Name: EditToolName, - Input: string(paramsJSON), - } - - response, err := tool.Run(context.Background(), call) - require.NoError(t, err) - assert.Contains(t, response.Content, "permission denied") - - // Verify file was not modified - fileContent, err := os.ReadFile(filePath) - require.NoError(t, err) - assert.Equal(t, initialContent, string(fileContent)) - }) -} diff --git a/internal/llm/tools/mocks_test.go b/internal/llm/tools/mocks_test.go deleted file mode 100644 index 81993160c0384c31a4cc49bd53c40fa94cefd1a3..0000000000000000000000000000000000000000 --- a/internal/llm/tools/mocks_test.go +++ /dev/null @@ -1,246 +0,0 @@ -package tools - -import ( - "context" - "fmt" - "sort" - "strconv" - "strings" - "time" - - "github.com/google/uuid" - "github.com/kujtimiihoxha/opencode/internal/history" - "github.com/kujtimiihoxha/opencode/internal/permission" - "github.com/kujtimiihoxha/opencode/internal/pubsub" -) - -// Mock permission service for testing -type mockPermissionService struct { - *pubsub.Broker[permission.PermissionRequest] - allow bool -} - -func (m *mockPermissionService) GrantPersistant(permission permission.PermissionRequest) { - // Not needed for tests -} - -func (m *mockPermissionService) Grant(permission permission.PermissionRequest) { - // Not needed for tests -} - -func (m *mockPermissionService) Deny(permission permission.PermissionRequest) { - // Not needed for tests -} - -func (m *mockPermissionService) Request(opts permission.CreatePermissionRequest) bool { - return m.allow -} - -func newMockPermissionService(allow bool) permission.Service { - return &mockPermissionService{ - Broker: pubsub.NewBroker[permission.PermissionRequest](), - allow: allow, - } -} - -type mockFileHistoryService struct { - *pubsub.Broker[history.File] - files map[string]history.File // ID -> File - timeNow func() int64 -} - -// Create implements history.Service. -func (m *mockFileHistoryService) Create(ctx context.Context, sessionID string, path string, content string) (history.File, error) { - return m.createWithVersion(ctx, sessionID, path, content, history.InitialVersion) -} - -// CreateVersion implements history.Service. -func (m *mockFileHistoryService) CreateVersion(ctx context.Context, sessionID string, path string, content string) (history.File, error) { - var files []history.File - for _, file := range m.files { - if file.Path == path { - files = append(files, file) - } - } - - if len(files) == 0 { - // No previous versions, create initial - return m.Create(ctx, sessionID, path, content) - } - - // Sort files by CreatedAt in descending order - sort.Slice(files, func(i, j int) bool { - return files[i].CreatedAt > files[j].CreatedAt - }) - - // Get the latest version - latestFile := files[0] - latestVersion := latestFile.Version - - // Generate the next version - var nextVersion string - if latestVersion == history.InitialVersion { - nextVersion = "v1" - } else if strings.HasPrefix(latestVersion, "v") { - versionNum, err := strconv.Atoi(latestVersion[1:]) - if err != nil { - // If we can't parse the version, just use a timestamp-based version - nextVersion = fmt.Sprintf("v%d", latestFile.CreatedAt) - } else { - nextVersion = fmt.Sprintf("v%d", versionNum+1) - } - } else { - // If the version format is unexpected, use a timestamp-based version - nextVersion = fmt.Sprintf("v%d", latestFile.CreatedAt) - } - - return m.createWithVersion(ctx, sessionID, path, content, nextVersion) -} - -func (m *mockFileHistoryService) createWithVersion(_ context.Context, sessionID, path, content, version string) (history.File, error) { - now := m.timeNow() - file := history.File{ - ID: uuid.New().String(), - SessionID: sessionID, - Path: path, - Content: content, - Version: version, - CreatedAt: now, - UpdatedAt: now, - } - - m.files[file.ID] = file - m.Publish(pubsub.CreatedEvent, file) - return file, nil -} - -// Delete implements history.Service. -func (m *mockFileHistoryService) Delete(ctx context.Context, id string) error { - file, ok := m.files[id] - if !ok { - return fmt.Errorf("file not found: %s", id) - } - - delete(m.files, id) - m.Publish(pubsub.DeletedEvent, file) - return nil -} - -// DeleteSessionFiles implements history.Service. -func (m *mockFileHistoryService) DeleteSessionFiles(ctx context.Context, sessionID string) error { - files, err := m.ListBySession(ctx, sessionID) - if err != nil { - return err - } - - for _, file := range files { - err = m.Delete(ctx, file.ID) - if err != nil { - return err - } - } - - return nil -} - -// Get implements history.Service. -func (m *mockFileHistoryService) Get(ctx context.Context, id string) (history.File, error) { - file, ok := m.files[id] - if !ok { - return history.File{}, fmt.Errorf("file not found: %s", id) - } - return file, nil -} - -// GetByPathAndSession implements history.Service. -func (m *mockFileHistoryService) GetByPathAndSession(ctx context.Context, path string, sessionID string) (history.File, error) { - var latestFile history.File - var found bool - var latestTime int64 - - for _, file := range m.files { - if file.Path == path && file.SessionID == sessionID { - if !found || file.CreatedAt > latestTime { - latestFile = file - latestTime = file.CreatedAt - found = true - } - } - } - - if !found { - return history.File{}, fmt.Errorf("file not found: %s for session %s", path, sessionID) - } - return latestFile, nil -} - -// ListBySession implements history.Service. -func (m *mockFileHistoryService) ListBySession(ctx context.Context, sessionID string) ([]history.File, error) { - var files []history.File - for _, file := range m.files { - if file.SessionID == sessionID { - files = append(files, file) - } - } - - // Sort by CreatedAt in descending order - sort.Slice(files, func(i, j int) bool { - return files[i].CreatedAt > files[j].CreatedAt - }) - - return files, nil -} - -// ListLatestSessionFiles implements history.Service. -func (m *mockFileHistoryService) ListLatestSessionFiles(ctx context.Context, sessionID string) ([]history.File, error) { - // Map to track the latest file for each path - latestFiles := make(map[string]history.File) - - for _, file := range m.files { - if file.SessionID == sessionID { - existing, ok := latestFiles[file.Path] - if !ok || file.CreatedAt > existing.CreatedAt { - latestFiles[file.Path] = file - } - } - } - - // Convert map to slice - var result []history.File - for _, file := range latestFiles { - result = append(result, file) - } - - // Sort by CreatedAt in descending order - sort.Slice(result, func(i, j int) bool { - return result[i].CreatedAt > result[j].CreatedAt - }) - - return result, nil -} - -// Subscribe implements history.Service. -func (m *mockFileHistoryService) Subscribe(ctx context.Context) <-chan pubsub.Event[history.File] { - return m.Broker.Subscribe(ctx) -} - -// Update implements history.Service. -func (m *mockFileHistoryService) Update(ctx context.Context, file history.File) (history.File, error) { - _, ok := m.files[file.ID] - if !ok { - return history.File{}, fmt.Errorf("file not found: %s", file.ID) - } - - file.UpdatedAt = m.timeNow() - m.files[file.ID] = file - m.Publish(pubsub.UpdatedEvent, file) - return file, nil -} - -func newMockFileHistoryService() history.Service { - return &mockFileHistoryService{ - Broker: pubsub.NewBroker[history.File](), - files: make(map[string]history.File), - timeNow: func() int64 { return time.Now().Unix() }, - } -} diff --git a/internal/llm/tools/patch.go b/internal/llm/tools/patch.go new file mode 100644 index 0000000000000000000000000000000000000000..12060d72a363b16fffc8af2c43fa14cf496597dd --- /dev/null +++ b/internal/llm/tools/patch.go @@ -0,0 +1,300 @@ +package tools + +import ( + "context" + "encoding/json" + "fmt" + "os" + "path/filepath" + "strings" + "time" + + "github.com/kujtimiihoxha/opencode/internal/config" + "github.com/kujtimiihoxha/opencode/internal/diff" + "github.com/kujtimiihoxha/opencode/internal/history" + "github.com/kujtimiihoxha/opencode/internal/lsp" + "github.com/kujtimiihoxha/opencode/internal/permission" +) + +type PatchParams struct { + FilePath string `json:"file_path"` + Patch string `json:"patch"` +} + +type PatchPermissionsParams struct { + FilePath string `json:"file_path"` + Diff string `json:"diff"` +} + +type PatchResponseMetadata struct { + Diff string `json:"diff"` + Additions int `json:"additions"` + Removals int `json:"removals"` +} + +type patchTool struct { + lspClients map[string]*lsp.Client + permissions permission.Service + files history.Service +} + +const ( + // TODO: test if this works as expected + PatchToolName = "patch" + patchDescription = `Applies a patch to a file. This tool is similar to the edit tool but accepts a unified diff patch instead of old/new strings. + +Before using this tool: + +1. Use the FileRead tool to understand the file's contents and context + +2. Verify the directory path is correct: + - Use the LS tool to verify the parent directory exists and is the correct location + +To apply a patch, provide the following: +1. file_path: The absolute path to the file to modify (must be absolute, not relative) +2. patch: A unified diff patch to apply to the file + +The tool will apply the patch to the specified file. The patch must be in unified diff format. + +CRITICAL REQUIREMENTS FOR USING THIS TOOL: + +1. PATCH FORMAT: The patch must be in unified diff format, which includes: + - File headers (--- a/file_path, +++ b/file_path) + - Hunk headers (@@ -start,count +start,count @@) + - Added lines (prefixed with +) + - Removed lines (prefixed with -) + +2. CONTEXT: The patch must include sufficient context around the changes to ensure it applies correctly. + +3. VERIFICATION: Before using this tool: + - Ensure the patch applies cleanly to the current state of the file + - Check that the file exists and you have read it first + +WARNING: If you do not follow these requirements: + - The tool will fail if the patch doesn't apply cleanly + - You may change the wrong parts of the file if the context is insufficient + +When applying patches: + - Ensure the patch results in idiomatic, correct code + - Do not leave the code in a broken state + - Always use absolute file paths (starting with /) + +Remember: patches are a powerful way to make multiple related changes at once, but they require careful preparation.` +) + +func NewPatchTool(lspClients map[string]*lsp.Client, permissions permission.Service, files history.Service) BaseTool { + return &patchTool{ + lspClients: lspClients, + permissions: permissions, + files: files, + } +} + +func (p *patchTool) Info() ToolInfo { + return ToolInfo{ + Name: PatchToolName, + Description: patchDescription, + Parameters: map[string]any{ + "file_path": map[string]any{ + "type": "string", + "description": "The absolute path to the file to modify", + }, + "patch": map[string]any{ + "type": "string", + "description": "The unified diff patch to apply", + }, + }, + Required: []string{"file_path", "patch"}, + } +} + +func (p *patchTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error) { + var params PatchParams + if err := json.Unmarshal([]byte(call.Input), ¶ms); err != nil { + return NewTextErrorResponse("invalid parameters"), nil + } + + if params.FilePath == "" { + return NewTextErrorResponse("file_path is required"), nil + } + + if params.Patch == "" { + return NewTextErrorResponse("patch is required"), nil + } + + if !filepath.IsAbs(params.FilePath) { + wd := config.WorkingDirectory() + params.FilePath = filepath.Join(wd, params.FilePath) + } + + // Check if file exists + fileInfo, err := os.Stat(params.FilePath) + if err != nil { + if os.IsNotExist(err) { + return NewTextErrorResponse(fmt.Sprintf("file not found: %s", params.FilePath)), nil + } + return ToolResponse{}, fmt.Errorf("failed to access file: %w", err) + } + + if fileInfo.IsDir() { + return NewTextErrorResponse(fmt.Sprintf("path is a directory, not a file: %s", params.FilePath)), nil + } + + if getLastReadTime(params.FilePath).IsZero() { + return NewTextErrorResponse("you must read the file before patching it. Use the View tool first"), nil + } + + modTime := fileInfo.ModTime() + lastRead := getLastReadTime(params.FilePath) + if modTime.After(lastRead) { + return NewTextErrorResponse( + fmt.Sprintf("file %s has been modified since it was last read (mod time: %s, last read: %s)", + params.FilePath, modTime.Format(time.RFC3339), lastRead.Format(time.RFC3339), + )), nil + } + + // Read the current file content + content, err := os.ReadFile(params.FilePath) + if err != nil { + return ToolResponse{}, fmt.Errorf("failed to read file: %w", err) + } + + oldContent := string(content) + + // Parse and apply the patch + diffResult, err := diff.ParseUnifiedDiff(params.Patch) + if err != nil { + return NewTextErrorResponse(fmt.Sprintf("failed to parse patch: %v", err)), nil + } + + // Apply the patch to get the new content + newContent, err := applyPatch(oldContent, diffResult) + if err != nil { + return NewTextErrorResponse(fmt.Sprintf("failed to apply patch: %v", err)), nil + } + + if oldContent == newContent { + return NewTextErrorResponse("patch did not result in any changes to the file"), nil + } + + sessionID, messageID := GetContextValues(ctx) + if sessionID == "" || messageID == "" { + return ToolResponse{}, fmt.Errorf("session ID and message ID are required for patching a file") + } + + // Generate a diff for permission request and metadata + diffText, additions, removals := diff.GenerateDiff( + oldContent, + newContent, + params.FilePath, + ) + + // Request permission to apply the patch + p.permissions.Request( + permission.CreatePermissionRequest{ + Path: filepath.Dir(params.FilePath), + ToolName: PatchToolName, + Action: "patch", + Description: fmt.Sprintf("Apply patch to file %s", params.FilePath), + Params: PatchPermissionsParams{ + FilePath: params.FilePath, + Diff: diffText, + }, + }, + ) + + // Write the new content to the file + err = os.WriteFile(params.FilePath, []byte(newContent), 0o644) + if err != nil { + return ToolResponse{}, fmt.Errorf("failed to write file: %w", err) + } + + // Update file history + file, err := p.files.GetByPathAndSession(ctx, params.FilePath, sessionID) + if err != nil { + _, err = p.files.Create(ctx, sessionID, params.FilePath, oldContent) + if err != nil { + return ToolResponse{}, fmt.Errorf("error creating file history: %w", err) + } + } + if file.Content != oldContent { + // User manually changed the content, store an intermediate version + _, err = p.files.CreateVersion(ctx, sessionID, params.FilePath, oldContent) + if err != nil { + fmt.Printf("Error creating file history version: %v\n", err) + } + } + // Store the new version + _, err = p.files.CreateVersion(ctx, sessionID, params.FilePath, newContent) + if err != nil { + fmt.Printf("Error creating file history version: %v\n", err) + } + + recordFileWrite(params.FilePath) + recordFileRead(params.FilePath) + + // Wait for LSP diagnostics and include them in the response + waitForLspDiagnostics(ctx, params.FilePath, p.lspClients) + text := fmt.Sprintf("\nPatch applied to file: %s\n\n", params.FilePath) + text += getDiagnostics(params.FilePath, p.lspClients) + + return WithResponseMetadata( + NewTextResponse(text), + PatchResponseMetadata{ + Diff: diffText, + Additions: additions, + Removals: removals, + }), nil +} + +// applyPatch applies a parsed diff to a string and returns the resulting content +func applyPatch(content string, diffResult diff.DiffResult) (string, error) { + lines := strings.Split(content, "\n") + + // Process each hunk in the diff + for _, hunk := range diffResult.Hunks { + // Parse the hunk header to get line numbers + var oldStart, oldCount, newStart, newCount int + _, err := fmt.Sscanf(hunk.Header, "@@ -%d,%d +%d,%d @@", &oldStart, &oldCount, &newStart, &newCount) + if err != nil { + // Try alternative format with single line counts + _, err = fmt.Sscanf(hunk.Header, "@@ -%d +%d @@", &oldStart, &newStart) + if err != nil { + return "", fmt.Errorf("invalid hunk header format: %s", hunk.Header) + } + oldCount = 1 + newCount = 1 + } + + // Adjust for 0-based array indexing + oldStart-- + newStart-- + + // Apply the changes + newLines := make([]string, 0) + newLines = append(newLines, lines[:oldStart]...) + + // Process the hunk lines in order + currentOldLine := oldStart + for _, line := range hunk.Lines { + switch line.Kind { + case diff.LineContext: + newLines = append(newLines, line.Content) + currentOldLine++ + case diff.LineRemoved: + // Skip this line in the output (it's being removed) + currentOldLine++ + case diff.LineAdded: + // Add the new line + newLines = append(newLines, line.Content) + } + } + + // Append the rest of the file + newLines = append(newLines, lines[currentOldLine:]...) + lines = newLines + } + + return strings.Join(lines, "\n"), nil +} + diff --git a/internal/llm/tools/sourcegraph_test.go b/internal/llm/tools/sourcegraph_test.go deleted file mode 100644 index 89829aefcb1bffd399f7cc4f738a9c32cb7b6049..0000000000000000000000000000000000000000 --- a/internal/llm/tools/sourcegraph_test.go +++ /dev/null @@ -1,86 +0,0 @@ -package tools - -import ( - "context" - "encoding/json" - "testing" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -func TestSourcegraphTool_Info(t *testing.T) { - tool := NewSourcegraphTool() - info := tool.Info() - - assert.Equal(t, SourcegraphToolName, info.Name) - assert.NotEmpty(t, info.Description) - assert.Contains(t, info.Parameters, "query") - assert.Contains(t, info.Parameters, "count") - assert.Contains(t, info.Parameters, "timeout") - assert.Contains(t, info.Required, "query") -} - -func TestSourcegraphTool_Run(t *testing.T) { - t.Run("handles missing query parameter", func(t *testing.T) { - tool := NewSourcegraphTool() - params := SourcegraphParams{ - Query: "", - } - - paramsJSON, err := json.Marshal(params) - require.NoError(t, err) - - call := ToolCall{ - Name: SourcegraphToolName, - Input: string(paramsJSON), - } - - response, err := tool.Run(context.Background(), call) - require.NoError(t, err) - assert.Contains(t, response.Content, "Query parameter is required") - }) - - t.Run("handles invalid parameters", func(t *testing.T) { - tool := NewSourcegraphTool() - call := ToolCall{ - Name: SourcegraphToolName, - Input: "invalid json", - } - - response, err := tool.Run(context.Background(), call) - require.NoError(t, err) - assert.Contains(t, response.Content, "Failed to parse sourcegraph parameters") - }) - - t.Run("normalizes count parameter", func(t *testing.T) { - // Test cases for count normalization - testCases := []struct { - name string - inputCount int - expectedCount int - }{ - {"negative count", -5, 10}, // Should use default (10) - {"zero count", 0, 10}, // Should use default (10) - {"valid count", 50, 50}, // Should keep as is - {"excessive count", 150, 100}, // Should cap at 100 - } - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - // Verify count normalization logic directly - assert.NotPanics(t, func() { - // Apply the same normalization logic as in the tool - normalizedCount := tc.inputCount - if normalizedCount <= 0 { - normalizedCount = 10 - } else if normalizedCount > 100 { - normalizedCount = 100 - } - - assert.Equal(t, tc.expectedCount, normalizedCount) - }) - }) - } - }) -} diff --git a/internal/llm/tools/write_test.go b/internal/llm/tools/write_test.go deleted file mode 100644 index b5ecb3fda17287c5b6cfda75e1309fcb62547c55..0000000000000000000000000000000000000000 --- a/internal/llm/tools/write_test.go +++ /dev/null @@ -1,307 +0,0 @@ -package tools - -import ( - "context" - "encoding/json" - "os" - "path/filepath" - "testing" - "time" - - "github.com/kujtimiihoxha/opencode/internal/lsp" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -func TestWriteTool_Info(t *testing.T) { - tool := NewWriteTool(make(map[string]*lsp.Client), newMockPermissionService(true), newMockFileHistoryService()) - info := tool.Info() - - assert.Equal(t, WriteToolName, info.Name) - assert.NotEmpty(t, info.Description) - assert.Contains(t, info.Parameters, "file_path") - assert.Contains(t, info.Parameters, "content") - assert.Contains(t, info.Required, "file_path") - assert.Contains(t, info.Required, "content") -} - -func TestWriteTool_Run(t *testing.T) { - // Create a temporary directory for testing - tempDir, err := os.MkdirTemp("", "write_tool_test") - require.NoError(t, err) - defer os.RemoveAll(tempDir) - - t.Run("creates a new file successfully", func(t *testing.T) { - tool := NewWriteTool(make(map[string]*lsp.Client), newMockPermissionService(true), newMockFileHistoryService()) - - filePath := filepath.Join(tempDir, "new_file.txt") - content := "This is a test content" - - params := WriteParams{ - FilePath: filePath, - Content: content, - } - - paramsJSON, err := json.Marshal(params) - require.NoError(t, err) - - call := ToolCall{ - Name: WriteToolName, - Input: string(paramsJSON), - } - - response, err := tool.Run(context.Background(), call) - require.NoError(t, err) - assert.Contains(t, response.Content, "successfully written") - - // Verify file was created with correct content - fileContent, err := os.ReadFile(filePath) - require.NoError(t, err) - assert.Equal(t, content, string(fileContent)) - }) - - t.Run("creates file with nested directories", func(t *testing.T) { - tool := NewWriteTool(make(map[string]*lsp.Client), newMockPermissionService(true), newMockFileHistoryService()) - - filePath := filepath.Join(tempDir, "nested/dirs/new_file.txt") - content := "Content in nested directory" - - params := WriteParams{ - FilePath: filePath, - Content: content, - } - - paramsJSON, err := json.Marshal(params) - require.NoError(t, err) - - call := ToolCall{ - Name: WriteToolName, - Input: string(paramsJSON), - } - - response, err := tool.Run(context.Background(), call) - require.NoError(t, err) - assert.Contains(t, response.Content, "successfully written") - - // Verify file was created with correct content - fileContent, err := os.ReadFile(filePath) - require.NoError(t, err) - assert.Equal(t, content, string(fileContent)) - }) - - t.Run("updates existing file", func(t *testing.T) { - tool := NewWriteTool(make(map[string]*lsp.Client), newMockPermissionService(true), newMockFileHistoryService()) - - // Create a file first - filePath := filepath.Join(tempDir, "existing_file.txt") - initialContent := "Initial content" - err := os.WriteFile(filePath, []byte(initialContent), 0o644) - require.NoError(t, err) - - // Record the file read to avoid modification time check failure - recordFileRead(filePath) - - // Update the file - updatedContent := "Updated content" - params := WriteParams{ - FilePath: filePath, - Content: updatedContent, - } - - paramsJSON, err := json.Marshal(params) - require.NoError(t, err) - - call := ToolCall{ - Name: WriteToolName, - Input: string(paramsJSON), - } - - response, err := tool.Run(context.Background(), call) - require.NoError(t, err) - assert.Contains(t, response.Content, "successfully written") - - // Verify file was updated with correct content - fileContent, err := os.ReadFile(filePath) - require.NoError(t, err) - assert.Equal(t, updatedContent, string(fileContent)) - }) - - t.Run("handles invalid parameters", func(t *testing.T) { - tool := NewWriteTool(make(map[string]*lsp.Client), newMockPermissionService(true), newMockFileHistoryService()) - - call := ToolCall{ - Name: WriteToolName, - Input: "invalid json", - } - - response, err := tool.Run(context.Background(), call) - require.NoError(t, err) - assert.Contains(t, response.Content, "error parsing parameters") - }) - - t.Run("handles missing file_path", func(t *testing.T) { - tool := NewWriteTool(make(map[string]*lsp.Client), newMockPermissionService(true), newMockFileHistoryService()) - - params := WriteParams{ - FilePath: "", - Content: "Some content", - } - - paramsJSON, err := json.Marshal(params) - require.NoError(t, err) - - call := ToolCall{ - Name: WriteToolName, - Input: string(paramsJSON), - } - - response, err := tool.Run(context.Background(), call) - require.NoError(t, err) - assert.Contains(t, response.Content, "file_path is required") - }) - - t.Run("handles missing content", func(t *testing.T) { - tool := NewWriteTool(make(map[string]*lsp.Client), newMockPermissionService(true), newMockFileHistoryService()) - - params := WriteParams{ - FilePath: filepath.Join(tempDir, "file.txt"), - Content: "", - } - - paramsJSON, err := json.Marshal(params) - require.NoError(t, err) - - call := ToolCall{ - Name: WriteToolName, - Input: string(paramsJSON), - } - - response, err := tool.Run(context.Background(), call) - require.NoError(t, err) - assert.Contains(t, response.Content, "content is required") - }) - - t.Run("handles writing to a directory path", func(t *testing.T) { - tool := NewWriteTool(make(map[string]*lsp.Client), newMockPermissionService(true), newMockFileHistoryService()) - - // Create a directory - dirPath := filepath.Join(tempDir, "test_dir") - err := os.Mkdir(dirPath, 0o755) - require.NoError(t, err) - - params := WriteParams{ - FilePath: dirPath, - Content: "Some content", - } - - paramsJSON, err := json.Marshal(params) - require.NoError(t, err) - - call := ToolCall{ - Name: WriteToolName, - Input: string(paramsJSON), - } - - response, err := tool.Run(context.Background(), call) - require.NoError(t, err) - assert.Contains(t, response.Content, "Path is a directory") - }) - - t.Run("handles permission denied", func(t *testing.T) { - tool := NewWriteTool(make(map[string]*lsp.Client), newMockPermissionService(false), newMockFileHistoryService()) - - filePath := filepath.Join(tempDir, "permission_denied.txt") - params := WriteParams{ - FilePath: filePath, - Content: "Content that should not be written", - } - - paramsJSON, err := json.Marshal(params) - require.NoError(t, err) - - call := ToolCall{ - Name: WriteToolName, - Input: string(paramsJSON), - } - - response, err := tool.Run(context.Background(), call) - require.NoError(t, err) - assert.Contains(t, response.Content, "Permission denied") - - // Verify file was not created - _, err = os.Stat(filePath) - assert.True(t, os.IsNotExist(err)) - }) - - t.Run("detects file modified since last read", func(t *testing.T) { - tool := NewWriteTool(make(map[string]*lsp.Client), newMockPermissionService(true), newMockFileHistoryService()) - - // Create a file - filePath := filepath.Join(tempDir, "modified_file.txt") - initialContent := "Initial content" - err := os.WriteFile(filePath, []byte(initialContent), 0o644) - require.NoError(t, err) - - // Record an old read time - fileRecordMutex.Lock() - fileRecords[filePath] = fileRecord{ - path: filePath, - readTime: time.Now().Add(-1 * time.Hour), - } - fileRecordMutex.Unlock() - - // Try to update the file - params := WriteParams{ - FilePath: filePath, - Content: "Updated content", - } - - paramsJSON, err := json.Marshal(params) - require.NoError(t, err) - - call := ToolCall{ - Name: WriteToolName, - Input: string(paramsJSON), - } - - response, err := tool.Run(context.Background(), call) - require.NoError(t, err) - assert.Contains(t, response.Content, "has been modified since it was last read") - - // Verify file was not modified - fileContent, err := os.ReadFile(filePath) - require.NoError(t, err) - assert.Equal(t, initialContent, string(fileContent)) - }) - - t.Run("skips writing when content is identical", func(t *testing.T) { - tool := NewWriteTool(make(map[string]*lsp.Client), newMockPermissionService(true), newMockFileHistoryService()) - - // Create a file - filePath := filepath.Join(tempDir, "identical_content.txt") - content := "Content that won't change" - err := os.WriteFile(filePath, []byte(content), 0o644) - require.NoError(t, err) - - // Record a read time - recordFileRead(filePath) - - // Try to write the same content - params := WriteParams{ - FilePath: filePath, - Content: content, - } - - paramsJSON, err := json.Marshal(params) - require.NoError(t, err) - - call := ToolCall{ - Name: WriteToolName, - Input: string(paramsJSON), - } - - response, err := tool.Run(context.Background(), call) - require.NoError(t, err) - assert.Contains(t, response.Content, "already contains the exact content") - }) -} diff --git a/internal/tui/components/chat/sidebar.go b/internal/tui/components/chat/sidebar.go index 54b39f4a1468340a7d769eaf23a682839746d980..fe2845a08d6de1ec669f9a98a1df4e9cae6820ae 100644 --- a/internal/tui/components/chat/sidebar.go +++ b/internal/tui/components/chat/sidebar.go @@ -3,6 +3,7 @@ package chat import ( "context" "fmt" + "sort" "strings" tea "github.com/charmbracelet/bubbletea" @@ -141,8 +142,17 @@ func (m *sidebarCmp) modifiedFiles() string { ) } + // Sort file paths alphabetically for consistent ordering + var paths []string + for path := range m.modFiles { + paths = append(paths, path) + } + sort.Strings(paths) + + // Create views for each file in sorted order var fileViews []string - for path, stats := range m.modFiles { + for _, path := range paths { + stats := m.modFiles[path] fileViews = append(fileViews, m.modifiedFile(path, stats.additions, stats.removals)) } diff --git a/internal/tui/components/dialog/permission.go b/internal/tui/components/dialog/permission.go index 200a7970d95c4a494802f164774017704bd9868a..2958844320f5dec51447fdf5f88bf6a243d686eb 100644 --- a/internal/tui/components/dialog/permission.go +++ b/internal/tui/components/dialog/permission.go @@ -266,6 +266,18 @@ func (p *permissionDialogCmp) renderEditContent() string { return "" } +func (p *permissionDialogCmp) renderPatchContent() string { + if pr, ok := p.permission.Params.(tools.PatchPermissionsParams); ok { + diff := p.GetOrSetDiff(p.permission.ID, func() (string, error) { + return diff.FormatDiff(pr.Diff, diff.WithTotalWidth(p.contentViewPort.Width)) + }) + + p.contentViewPort.SetContent(diff) + return p.styleViewport() + } + return "" +} + func (p *permissionDialogCmp) renderWriteContent() string { if pr, ok := p.permission.Params.(tools.WritePermissionsParams); ok { // Use the cache for diff rendering @@ -350,6 +362,8 @@ func (p *permissionDialogCmp) render() string { contentFinal = p.renderBashContent() case tools.EditToolName: contentFinal = p.renderEditContent() + case tools.PatchToolName: + contentFinal = p.renderPatchContent() case tools.WriteToolName: contentFinal = p.renderWriteContent() case tools.FetchToolName: diff --git a/main.go b/main.go index 06578c7efb105990b6196917757c653d1ca8bdca..2b0761c69a7b8a2212ce57fef1a36b3e3008c8ca 100644 --- a/main.go +++ b/main.go @@ -6,11 +6,9 @@ import ( ) func main() { - // Set up panic recovery for the main function defer logging.RecoverPanic("main", func() { - // Perform any necessary cleanup before exit logging.ErrorPersist("Application terminated due to unhandled panic") }) - + cmd.Execute() } From caea29375994373f6027c8dc4d8aa536c4e135e7 Mon Sep 17 00:00:00 2001 From: Kujtim Hoxha Date: Thu, 17 Apr 2025 12:44:15 +0200 Subject: [PATCH 22/41] small fixes --- internal/config/config.go | 4 ++ internal/llm/prompt/coder.go | 95 +++++++++++++++--------------------- 2 files changed, 42 insertions(+), 57 deletions(-) diff --git a/internal/config/config.go b/internal/config/config.go index 20a8bac9750e0d0cfd5e7a8c194527ac714bef89..5b6d51efa9b2fa20b86d4e27a85619706fc57d97 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -156,6 +156,10 @@ func Load(workingDir string, debug bool) (*Config, error) { slog.SetDefault(logger) } + if cfg.Agents == nil { + cfg.Agents = make(map[AgentName]Agent) + } + // Override the max tokens for title agent cfg.Agents[AgentTitle] = Agent{ Model: cfg.Agents[AgentTitle].Model, diff --git a/internal/llm/prompt/coder.go b/internal/llm/prompt/coder.go index febdea4d250bdfc7b9bb2619a68d65354f0c6fa1..3a06911dadf3f9bb04aa0f66d8b136b836468f1f 100644 --- a/internal/llm/prompt/coder.go +++ b/internal/llm/prompt/coder.go @@ -25,63 +25,44 @@ func CoderPrompt(provider models.ModelProvider) string { } const baseOpenAICoderPrompt = ` -# OpenCode CLI Agent Prompt - -You are operating within the **OpenCode CLI**, a terminal-based, agentic coding assistant that interfaces with local codebases through natural language. Your primary objectives are to be precise, safe, and helpful. - -## Capabilities - -- Receive user prompts, project context, and files. -- Stream responses and emit function calls (e.g., shell commands, code edits). -- Apply patches, run commands, and manage user approvals based on policy. -- Operate within a sandboxed, git-backed workspace with rollback support. -- Log telemetry for session replay or inspection. -- Access detailed functionality via the help command. - -## Operational Guidelines - -### 1. Task Resolution - -- Continue processing until the user's query is fully resolved. -- Only conclude your turn when confident the problem is solved. -- If uncertain about file content or codebase structure, utilize available tools to gather necessary information—avoid assumptions. - -### 2. Code Modification & Testing - -- Edit and test code files within your current execution session. -- Work on the local repositories, even if proprietary. -- Analyze code for vulnerabilities when applicable. -- Display user code and tool call details transparently. - -### 3. Coding Guidelines - -- Address root causes rather than applying superficial fixes. -- Avoid unnecessary complexity; focus on the task at hand. -- Update documentation as needed. -- Maintain consistency with the existing codebase style. -- Utilize version control tools for additional context; note that internet access is disabled. -- Refrain from adding copyright or license headers unless explicitly requested. -- No need to perform commit operations; this will be handled automatically. -- If a pre-commit configuration file exists, run the appropriate checks to ensure changes pass. Do not fix pre-existing errors on untouched lines. -- If pre-commit checks fail after retries, inform the user that the setup may be broken. - -### 4. Post-Modification Checks - -- Use version control status commands to verify changes; revert any unintended modifications. -- Remove all added inline comments unless they are essential for understanding. -- Ensure no accidental addition of copyright or license headers. -- Attempt to run pre-commit checks if available. -- For smaller tasks, provide brief bullet points summarizing changes. -- For complex tasks, include a high-level description, bullet points, and relevant details for code reviewers. - -### 5. Non-Code Modification Tasks - -- Respond in a friendly, collaborative tone, akin to a knowledgeable remote teammate eager to assist with coding inquiries. - -### 6. File Handling - -- Do not instruct the user to save or copy code into files if modifications have already been made using the editing tools. -- Avoid displaying full contents of large files unless explicitly requested by the user. +You are **OpenCode**, an autonomous CLI assistant for software‑engineering tasks. + +### ── INTERNAL REFLECTION ── +• Silently think step‑by‑step about the user request, directory layout, and tool calls (never reveal this). +• Formulate a plan, then execute without further approval unless a blocker triggers the Ask‑Only‑If rules. + +### ── PUBLIC RESPONSE RULES ── +• Visible reply ≤ 4 lines; no fluff, preamble, or postamble. +• Use GitHub‑flavored Markdown. +• When running a non‑trivial shell command, add ≤ 1 brief purpose sentence. + +### ── CONTEXT & MEMORY ── +• Infer file intent from directory structure before editing. +• Auto‑load 'OpenCode.md'; ask once before writing new reusable commands or style notes. + +### ── AUTONOMY PRIORITY ── +**Ask‑Only‑If Decision Tree:** +1. **Safety risk?** (e.g., destructive command, secret exposure) → ask. +2. **Critical unknown?** (no docs/tests; cannot infer) → ask. +3. **Tool failure after two self‑attempts?** → ask. +Otherwise, proceed autonomously. + +### ── SAFETY & STYLE ── +• Mimic existing code style; verify libraries exist before import. +• Never commit unless explicitly told. +• After edits, run lint & type‑check (ask for commands once, then offer to store in 'OpenCode.md'). +• Protect secrets; follow standard security practices :contentReference[oaicite:2]{index=2}. + +### ── TOOL USAGE ── +• Batch independent Agent search/file calls in one block for efficiency :contentReference[oaicite:3]{index=3}. +• Communicate with the user only via visible text; do not expose tool output or internal reasoning. + +### ── EXAMPLES ── +user: list files +assistant: ls + +user: write tests for new feature +assistant: [searches & edits autonomously, no extra chit‑chat] ` const baseAnthropicCoderPrompt = `You are OpenCode, an interactive CLI tool that helps users with software engineering tasks. Use the instructions below and the tools available to you to assist the user. From c24e3c18e0177e2e059f622a63e87249c52db2d5 Mon Sep 17 00:00:00 2001 From: Kujtim Hoxha Date: Thu, 17 Apr 2025 13:45:26 +0200 Subject: [PATCH 23/41] small improvements --- .opencode.json | 11 --- internal/app/app.go | 2 +- internal/db/migrations/000001_initial.up.sql | 3 +- internal/history/file.go | 76 ++++++++++++++++---- internal/tui/components/chat/editor.go | 4 +- internal/tui/components/chat/sidebar.go | 5 +- internal/tui/components/core/status.go | 56 ++++++++++++++- internal/tui/components/dialog/help.go | 4 +- internal/tui/layout/container.go | 1 + internal/tui/page/chat.go | 22 ++++-- internal/tui/tui.go | 2 + 11 files changed, 149 insertions(+), 37 deletions(-) diff --git a/.opencode.json b/.opencode.json index 4b2944f869b29804808361de59352957fc18ef81..b7fc19b524371cf7e4a625173f2fe305914694d3 100644 --- a/.opencode.json +++ b/.opencode.json @@ -3,16 +3,5 @@ "gopls": { "command": "gopls" } - }, - "agents": { - "coder": { - "model": "gpt-4.1" - }, - "task": { - "model": "gpt-4.1" - }, - "title": { - "model": "gpt-4.1" - } } } diff --git a/internal/app/app.go b/internal/app/app.go index 748fdaa7f0e2ec6dc314defa7b5c8d0c3c1b8c96..8f4f5e098197530812a21c8d84973f570920458b 100644 --- a/internal/app/app.go +++ b/internal/app/app.go @@ -39,7 +39,7 @@ func New(ctx context.Context, conn *sql.DB) (*App, error) { q := db.New(conn) sessions := session.NewService(q) messages := message.NewService(q) - files := history.NewService(q) + files := history.NewService(q, conn) app := &App{ Sessions: sessions, diff --git a/internal/db/migrations/000001_initial.up.sql b/internal/db/migrations/000001_initial.up.sql index 4ac297dc5f1ab16f38eeb81a7b3135f64cbf9860..b846ec600e47663a63a998f75bc14be32fa08898 100644 --- a/internal/db/migrations/000001_initial.up.sql +++ b/internal/db/migrations/000001_initial.up.sql @@ -27,7 +27,8 @@ CREATE TABLE IF NOT EXISTS files ( version TEXT NOT NULL, created_at INTEGER NOT NULL, -- Unix timestamp in milliseconds updated_at INTEGER NOT NULL, -- Unix timestamp in milliseconds - FOREIGN KEY (session_id) REFERENCES sessions (id) ON DELETE CASCADE + FOREIGN KEY (session_id) REFERENCES sessions (id) ON DELETE CASCADE, + UNIQUE(path, session_id, version) ); CREATE INDEX IF NOT EXISTS idx_files_session_id ON files (session_id); diff --git a/internal/history/file.go b/internal/history/file.go index 1e8bc50bb24d0563788a914d06adf824b86bd2c3..8453ac272926761344c289942d66c05d44401d4e 100644 --- a/internal/history/file.go +++ b/internal/history/file.go @@ -2,9 +2,11 @@ package history import ( "context" + "database/sql" "fmt" "strconv" "strings" + "time" "github.com/google/uuid" "github.com/kujtimiihoxha/opencode/internal/db" @@ -40,10 +42,11 @@ type Service interface { type service struct { *pubsub.Broker[File] - q db.Querier + db *sql.DB + q *db.Queries } -func NewService(q db.Querier) Service { +func NewService(q *db.Queries, db *sql.DB) Service { return &service{ Broker: pubsub.NewBroker[File](), q: q, @@ -91,19 +94,64 @@ func (s *service) CreateVersion(ctx context.Context, sessionID, path, content st } func (s *service) createWithVersion(ctx context.Context, sessionID, path, content, version string) (File, error) { - dbFile, err := s.q.CreateFile(ctx, db.CreateFileParams{ - ID: uuid.New().String(), - SessionID: sessionID, - Path: path, - Content: content, - Version: version, - }) - if err != nil { - return File{}, err + // Maximum number of retries for transaction conflicts + const maxRetries = 3 + var file File + var err error + + // Retry loop for transaction conflicts + for attempt := 0; attempt < maxRetries; attempt++ { + // Start a transaction + tx, err := s.db.BeginTx(ctx, nil) + if err != nil { + return File{}, fmt.Errorf("failed to begin transaction: %w", err) + } + + // Create a new queries instance with the transaction + qtx := s.q.WithTx(tx) + + // Try to create the file within the transaction + dbFile, err := qtx.CreateFile(ctx, db.CreateFileParams{ + ID: uuid.New().String(), + SessionID: sessionID, + Path: path, + Content: content, + Version: version, + }) + if err != nil { + // Rollback the transaction + tx.Rollback() + + // Check if this is a uniqueness constraint violation + if strings.Contains(err.Error(), "UNIQUE constraint failed") { + if attempt < maxRetries-1 { + // If we have retries left, generate a new version and try again + if strings.HasPrefix(version, "v") { + versionNum, parseErr := strconv.Atoi(version[1:]) + if parseErr == nil { + version = fmt.Sprintf("v%d", versionNum+1) + continue + } + } + // If we can't parse the version, use a timestamp-based version + version = fmt.Sprintf("v%d", time.Now().Unix()) + continue + } + } + return File{}, err + } + + // Commit the transaction + if err = tx.Commit(); err != nil { + return File{}, fmt.Errorf("failed to commit transaction: %w", err) + } + + file = s.fromDBItem(dbFile) + s.Publish(pubsub.CreatedEvent, file) + return file, nil } - file := s.fromDBItem(dbFile) - s.Publish(pubsub.CreatedEvent, file) - return file, nil + + return file, err } func (s *service) Get(ctx context.Context, id string) (File, error) { diff --git a/internal/tui/components/chat/editor.go b/internal/tui/components/chat/editor.go index 4d6ef5ca0a4b902dcd50c750659caeb02d6aeb18..ded0639bb612b66a91fde5a6220c6fae2e67e23c 100644 --- a/internal/tui/components/chat/editor.go +++ b/internal/tui/components/chat/editor.go @@ -118,12 +118,14 @@ func (m *editorCmp) GetSize() (int, int) { } func (m *editorCmp) BindingKeys() []key.Binding { - bindings := layout.KeyMapToSlice(m.textarea.KeyMap) + bindings := []key.Binding{} if m.textarea.Focused() { bindings = append(bindings, layout.KeyMapToSlice(focusedKeyMaps)...) } else { bindings = append(bindings, layout.KeyMapToSlice(bluredKeyMaps)...) } + + bindings = append(bindings, layout.KeyMapToSlice(m.textarea.KeyMap)...) return bindings } diff --git a/internal/tui/components/chat/sidebar.go b/internal/tui/components/chat/sidebar.go index fe2845a08d6de1ec669f9a98a1df4e9cae6820ae..5a275c0cfd1d264571fff33dedb5dae1c282b8d3 100644 --- a/internal/tui/components/chat/sidebar.go +++ b/internal/tui/components/chat/sidebar.go @@ -127,7 +127,7 @@ func (m *sidebarCmp) modifiedFiles() string { // If no modified files, show a placeholder message if m.modFiles == nil || len(m.modFiles) == 0 { message := "No modified files" - remainingWidth := m.width - lipgloss.Width(modifiedFiles) + remainingWidth := m.width - lipgloss.Width(message) if remainingWidth > 0 { message += strings.Repeat(" ", remainingWidth) } @@ -223,6 +223,9 @@ func (m *sidebarCmp) loadModifiedFiles(ctx context.Context) { if initialVersion.ID == "" { continue } + if initialVersion.Content == file.Content { + continue + } // Calculate diff between initial and latest version _, additions, removals := diff.GenerateDiff(initialVersion.Content, file.Content, file.Path) diff --git a/internal/tui/components/core/status.go b/internal/tui/components/core/status.go index 411cac1c518c8a37c187f2b31500eb1f64e47988..e76ecde84e9575265ca06007a5b0697f543c6861 100644 --- a/internal/tui/components/core/status.go +++ b/internal/tui/components/core/status.go @@ -11,6 +11,9 @@ import ( "github.com/kujtimiihoxha/opencode/internal/llm/models" "github.com/kujtimiihoxha/opencode/internal/lsp" "github.com/kujtimiihoxha/opencode/internal/lsp/protocol" + "github.com/kujtimiihoxha/opencode/internal/pubsub" + "github.com/kujtimiihoxha/opencode/internal/session" + "github.com/kujtimiihoxha/opencode/internal/tui/components/chat" "github.com/kujtimiihoxha/opencode/internal/tui/styles" "github.com/kujtimiihoxha/opencode/internal/tui/util" ) @@ -20,6 +23,7 @@ type statusCmp struct { width int messageTTL time.Duration lspClients map[string]*lsp.Client + session session.Session } // clearMessageCmd is a command that clears status messages after a timeout @@ -38,6 +42,16 @@ func (m statusCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) { case tea.WindowSizeMsg: m.width = msg.Width return m, nil + case chat.SessionSelectedMsg: + m.session = msg + case chat.SessionClearedMsg: + m.session = session.Session{} + case pubsub.Event[session.Session]: + if msg.Type == pubsub.UpdatedEvent { + if m.session.ID == msg.Payload.ID { + m.session = msg.Payload + } + } case util.InfoMsg: m.info = msg ttl := msg.TTL @@ -53,8 +67,43 @@ func (m statusCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) { var helpWidget = styles.Padded.Background(styles.ForgroundMid).Foreground(styles.BackgroundDarker).Bold(true).Render("ctrl+? help") +func formatTokensAndCost(tokens int64, cost float64) string { + // Format tokens in human-readable format (e.g., 110K, 1.2M) + var formattedTokens string + switch { + case tokens >= 1_000_000: + formattedTokens = fmt.Sprintf("%.1fM", float64(tokens)/1_000_000) + case tokens >= 1_000: + formattedTokens = fmt.Sprintf("%.1fK", float64(tokens)/1_000) + default: + formattedTokens = fmt.Sprintf("%d", tokens) + } + + // Remove .0 suffix if present + if strings.HasSuffix(formattedTokens, ".0K") { + formattedTokens = strings.Replace(formattedTokens, ".0K", "K", 1) + } + if strings.HasSuffix(formattedTokens, ".0M") { + formattedTokens = strings.Replace(formattedTokens, ".0M", "M", 1) + } + + // Format cost with $ symbol and 2 decimal places + formattedCost := fmt.Sprintf("$%.2f", cost) + + return fmt.Sprintf("Tokens: %s, Cost: %s", formattedTokens, formattedCost) +} + func (m statusCmp) View() string { status := helpWidget + if m.session.ID != "" { + tokens := formatTokensAndCost(m.session.PromptTokens+m.session.CompletionTokens, m.session.Cost) + tokensStyle := styles.Padded. + Background(styles.Forground). + Foreground(styles.BackgroundDim). + Render(tokens) + status += tokensStyle + } + diagnostics := styles.Padded.Background(styles.BackgroundDarker).Render(m.projectDiagnostics()) if m.info.Msg != "" { infoStyle := styles.Padded. @@ -82,6 +131,7 @@ func (m statusCmp) View() string { Width(m.availableFooterMsgWidth(diagnostics)). Render("") } + status += diagnostics status += m.model() return status @@ -136,7 +186,11 @@ func (m *statusCmp) projectDiagnostics() string { } func (m statusCmp) availableFooterMsgWidth(diagnostics string) int { - return max(0, m.width-lipgloss.Width(helpWidget)-lipgloss.Width(m.model())-lipgloss.Width(diagnostics)) + tokens := "" + if m.session.ID != "" { + tokens = formatTokensAndCost(m.session.PromptTokens+m.session.CompletionTokens, m.session.Cost) + } + return max(0, m.width-lipgloss.Width(helpWidget)-lipgloss.Width(m.model())-lipgloss.Width(diagnostics)-lipgloss.Width(tokens)) } func (m statusCmp) model() string { diff --git a/internal/tui/components/dialog/help.go b/internal/tui/components/dialog/help.go index 6242017f100cd589314f8e12722f7234a1536f1d..644b294cb9e4330710249f5fd4b3570ba430d377 100644 --- a/internal/tui/components/dialog/help.go +++ b/internal/tui/components/dialog/help.go @@ -26,7 +26,7 @@ func (h *helpCmp) SetBindings(k []key.Binding) { func (h *helpCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) { switch msg := msg.(type) { case tea.WindowSizeMsg: - h.width = 80 + h.width = 90 h.height = msg.Height } return h, nil @@ -62,7 +62,7 @@ func (h *helpCmp) render() string { var ( pairs []string width int - rows = 12 - 2 + rows = 14 - 2 ) for i := 0; i < len(bindings); i += rows { var ( diff --git a/internal/tui/layout/container.go b/internal/tui/layout/container.go index 60369995591b757cd57416d129dce92c0a465d6f..c86d954ead7c3db2733d853c2ef208f45270ad67 100644 --- a/internal/tui/layout/container.go +++ b/internal/tui/layout/container.go @@ -10,6 +10,7 @@ import ( type Container interface { tea.Model Sizeable + Bindings } type container struct { width int diff --git a/internal/tui/page/chat.go b/internal/tui/page/chat.go index c268e677f4a49493d9bad69f86d70a4c9de71e0e..632e107641c0c7fb1fd2339ec4ea304f895948f5 100644 --- a/internal/tui/page/chat.go +++ b/internal/tui/page/chat.go @@ -15,9 +15,12 @@ import ( var ChatPage PageID = "chat" type chatPage struct { - app *app.App - layout layout.SplitPaneLayout - session session.Session + app *app.App + editor layout.Container + messages layout.Container + layout layout.SplitPaneLayout + session session.Session + editingMode bool } type ChatKeyMap struct { @@ -59,6 +62,8 @@ func (p *chatPage) Update(msg tea.Msg) (tea.Model, tea.Cmd) { if cmd != nil { return p, cmd } + case chat.EditorFocusMsg: + p.editingMode = bool(msg) case tea.KeyMsg: switch { case key.Matches(msg, keyMap.NewSession): @@ -133,7 +138,11 @@ func (p *chatPage) View() string { func (p *chatPage) BindingKeys() []key.Binding { bindings := layout.KeyMapToSlice(keyMap) - bindings = append(bindings, p.layout.BindingKeys()...) + if p.editingMode { + bindings = append(bindings, p.editor.BindingKeys()...) + } else { + bindings = append(bindings, p.messages.BindingKeys()...) + } return bindings } @@ -148,7 +157,10 @@ func NewChatPage(app *app.App) tea.Model { layout.WithBorder(true, false, false, false), ) return &chatPage{ - app: app, + app: app, + editor: editorContainer, + messages: messagesContainer, + editingMode: true, layout: layout.NewSplitPane( layout.WithLeftPanel(messagesContainer), layout.WithBottomPanel(editorContainer), diff --git a/internal/tui/tui.go b/internal/tui/tui.go index 657de6b6e56d3bf739e8a4d3bdff447abd8227eb..840ad4905875a31ae0a588fca378d875a5512105 100644 --- a/internal/tui/tui.go +++ b/internal/tui/tui.go @@ -215,6 +215,8 @@ func (a appModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) { return a, tea.Batch(cmds...) } } + + a.status, _ = a.status.Update(msg) a.pages[a.currentPage], cmd = a.pages[a.currentPage].Update(msg) cmds = append(cmds, cmd) return a, tea.Batch(cmds...) From e3a62736db3f16c4d2b55a9eeb6b080b2c625a83 Mon Sep 17 00:00:00 2001 From: Kujtim Hoxha Date: Thu, 17 Apr 2025 15:17:12 +0200 Subject: [PATCH 24/41] add license --- LICENSE | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) create mode 100644 LICENSE diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..e6208d7752eded10870a415286eb4cd3b1e28912 --- /dev/null +++ b/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2025 Kujtim Hoxha + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. From 05d0e86f10369fd0e51a924ac88029fb92591499 Mon Sep 17 00:00:00 2001 From: Kujtim Hoxha Date: Thu, 17 Apr 2025 16:03:34 +0200 Subject: [PATCH 25/41] update logs --- .gitignore | 4 +- internal/tui/components/logs/details.go | 29 +--------- internal/tui/components/logs/table.go | 45 +++++++-------- internal/tui/page/logs.go | 77 ++++++++++++++++++++----- 4 files changed, 86 insertions(+), 69 deletions(-) diff --git a/.gitignore b/.gitignore index b4d5d61ea39cc833b733b8dcb24f142367a9e85c..2603e630d2be36af272cbe05b7730d3ddc9dac5c 100644 --- a/.gitignore +++ b/.gitignore @@ -41,7 +41,5 @@ Thumbs.db .env .env.local -.opencode +.opencode/ -internal/assets/diff/index.mjs -cmd/test/* diff --git a/internal/tui/components/logs/details.go b/internal/tui/components/logs/details.go index 3a8f1799931ed2d16f65c6e572ae58520bd299e4..7c74da10497fa2e8882be95b10e4014247e7b302 100644 --- a/internal/tui/components/logs/details.go +++ b/internal/tui/components/logs/details.go @@ -22,7 +22,6 @@ type DetailComponent interface { type detailCmp struct { width, height int - focused bool currentLog logging.LogMessage viewport viewport.Model } @@ -37,11 +36,6 @@ func (i *detailCmp) Init() tea.Cmd { } func (i *detailCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) { - var ( - cmd tea.Cmd - cmds []tea.Cmd - ) - switch msg := msg.(type) { case selectedLogMsg: if msg.ID != i.currentLog.ID { @@ -50,12 +44,7 @@ func (i *detailCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) { } } - if i.focused { - i.viewport, cmd = i.viewport.Update(msg) - cmds = append(cmds, cmd) - } - - return i, tea.Batch(cmds...) + return i, nil } func (i *detailCmp) updateContent() { @@ -123,21 +112,7 @@ func getLevelStyle(level string) lipgloss.Style { } func (i *detailCmp) View() string { - return i.viewport.View() -} - -func (i *detailCmp) Blur() tea.Cmd { - i.focused = false - return nil -} - -func (i *detailCmp) Focus() tea.Cmd { - i.focused = true - return nil -} - -func (i *detailCmp) IsFocused() bool { - return i.focused + return styles.ForceReplaceBackgroundWithLipgloss(i.viewport.View(), styles.Background) } func (i *detailCmp) GetSize() (int, int) { diff --git a/internal/tui/components/logs/table.go b/internal/tui/components/logs/table.go index dc6184e3df5f4509fc65805a7aaea9ecc65bcb31..2d0f9c533da41eb58c16836278500d780831984e 100644 --- a/internal/tui/components/logs/table.go +++ b/internal/tui/components/logs/table.go @@ -33,37 +33,35 @@ func (i *tableCmp) Init() tea.Cmd { func (i *tableCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) { var cmds []tea.Cmd - if i.table.Focused() { - switch msg.(type) { - case pubsub.Event[logging.LogMessage]: - i.setRows() - return i, nil - } - prevSelectedRow := i.table.SelectedRow() - t, cmd := i.table.Update(msg) - cmds = append(cmds, cmd) - i.table = t - selectedRow := i.table.SelectedRow() - if selectedRow != nil { - if prevSelectedRow == nil || selectedRow[0] == prevSelectedRow[0] { - var log logging.LogMessage - for _, row := range logging.List() { - if row.ID == selectedRow[0] { - log = row - break - } - } - if log.ID != "" { - cmds = append(cmds, util.CmdHandler(selectedLogMsg(log))) + switch msg.(type) { + case pubsub.Event[logging.LogMessage]: + i.setRows() + return i, nil + } + prevSelectedRow := i.table.SelectedRow() + t, cmd := i.table.Update(msg) + cmds = append(cmds, cmd) + i.table = t + selectedRow := i.table.SelectedRow() + if selectedRow != nil { + if prevSelectedRow == nil || selectedRow[0] == prevSelectedRow[0] { + var log logging.LogMessage + for _, row := range logging.List() { + if row.ID == selectedRow[0] { + log = row + break } } + if log.ID != "" { + cmds = append(cmds, util.CmdHandler(selectedLogMsg(log))) + } } } return i, tea.Batch(cmds...) } func (i *tableCmp) View() string { - return i.table.View() + return styles.ForceReplaceBackgroundWithLipgloss(i.table.View(), styles.Background) } func (i *tableCmp) GetSize() (int, int) { @@ -128,6 +126,7 @@ func NewLogsTable() TableComponent { table.WithColumns(columns), table.WithStyles(defaultStyles), ) + tableModel.Focus() return &tableCmp{ table: tableModel, } diff --git a/internal/tui/page/logs.go b/internal/tui/page/logs.go index c77a033f466d0b6912ea66c00c49ae2fe150199f..0efc69e6e4f39bbcc8eb0db2b08a7b5ebae92255 100644 --- a/internal/tui/page/logs.go +++ b/internal/tui/page/logs.go @@ -1,37 +1,82 @@ package page import ( + "github.com/charmbracelet/bubbles/key" tea "github.com/charmbracelet/bubbletea" + "github.com/charmbracelet/lipgloss" "github.com/kujtimiihoxha/opencode/internal/tui/components/logs" "github.com/kujtimiihoxha/opencode/internal/tui/layout" + "github.com/kujtimiihoxha/opencode/internal/tui/styles" ) var LogsPage PageID = "logs" -type logsPage struct { - table logs.TableComponent - details logs.DetailComponent +type LogPage interface { + tea.Model + layout.Sizeable + layout.Bindings } - -func (p *logsPage) Init() tea.Cmd { - return nil +type logsPage struct { + width, height int + table layout.Container + details layout.Container } func (p *logsPage) Update(msg tea.Msg) (tea.Model, tea.Cmd) { - return p, nil + switch msg := msg.(type) { + case tea.WindowSizeMsg: + p.width = msg.Width + p.height = msg.Height + p.table.SetSize(msg.Width, msg.Height/2) + p.details.SetSize(msg.Width, msg.Height/2) + } + + var cmds []tea.Cmd + table, cmd := p.table.Update(msg) + cmds = append(cmds, cmd) + p.table = table.(layout.Container) + details, cmd := p.details.Update(msg) + cmds = append(cmds, cmd) + p.details = details.(layout.Container) + + return p, tea.Batch(cmds...) } func (p *logsPage) View() string { - return p.table.View() + "\n" + p.details.View() + style := styles.BaseStyle.Width(p.width).Height(p.height) + return style.Render(lipgloss.JoinVertical(lipgloss.Top, + p.table.View(), + p.details.View(), + )) +} + +func (p *logsPage) BindingKeys() []key.Binding { + return p.table.BindingKeys() } -func NewLogsPage() tea.Model { - return layout.NewBentoLayout( - layout.BentoPanes{ - layout.BentoRightTopPane: logs.NewLogsTable(), - layout.BentoRightBottomPane: logs.NewLogsDetails(), - }, - layout.WithBentoLayoutCurrentPane(layout.BentoRightTopPane), - layout.WithBentoLayoutRightTopHeightRatio(0.5), +// GetSize implements LogPage. +func (p *logsPage) GetSize() (int, int) { + return p.width, p.height +} + +// SetSize implements LogPage. +func (p *logsPage) SetSize(width int, height int) { + p.width = width + p.height = height + p.table.SetSize(width, height/2) + p.details.SetSize(width, height/2) +} + +func (p *logsPage) Init() tea.Cmd { + return tea.Batch( + p.table.Init(), + p.details.Init(), ) } + +func NewLogsPage() LogPage { + return &logsPage{ + table: layout.NewContainer(logs.NewLogsTable(), layout.WithBorderAll(), layout.WithBorderColor(styles.ForgroundDim)), + details: layout.NewContainer(logs.NewLogsDetails(), layout.WithBorderAll(), layout.WithBorderColor(styles.ForgroundDim)), + } +} From 333ea6ec4b2abfc2c1a9c3f6b0918ca5d296347f Mon Sep 17 00:00:00 2001 From: Kujtim Hoxha Date: Fri, 18 Apr 2025 20:17:38 +0200 Subject: [PATCH 26/41] implement patch, update ui, improve rendering --- README.md | 301 +++++++- internal/config/config.go | 16 +- internal/diff/diff.go | 28 +- internal/diff/patch.go | 739 ++++++++++++++++++ internal/history/file.go | 21 +- internal/llm/agent/agent.go | 53 +- internal/llm/agent/tools.go | 3 +- internal/llm/models/anthropic.go | 6 + internal/llm/models/models.go | 64 +- internal/llm/models/openai.go | 169 +++++ internal/llm/prompt/coder.go | 83 ++- internal/llm/provider/openai.go | 49 +- internal/llm/tools/glob.go | 38 +- internal/llm/tools/grep.go | 37 +- internal/llm/tools/patch.go | 450 ++++++----- internal/llm/tools/view.go | 13 +- internal/tui/components/chat/editor.go | 58 +- internal/tui/components/chat/list.go | 463 ++++++++++++ internal/tui/components/chat/message.go | 561 ++++++++++++++ internal/tui/components/chat/messages.go | 742 ------------------- internal/tui/components/chat/sidebar.go | 94 ++- internal/tui/components/core/status.go | 24 +- internal/tui/components/dialog/permission.go | 16 +- internal/tui/components/dialog/session.go | 224 ++++++ internal/tui/components/logs/details.go | 16 +- internal/tui/components/logs/table.go | 3 +- internal/tui/layout/bento.go | 392 ---------- internal/tui/layout/border.go | 121 --- internal/tui/layout/container.go | 5 +- internal/tui/layout/grid.go | 254 ------- internal/tui/layout/layout.go | 6 +- internal/tui/layout/single.go | 189 ----- internal/tui/layout/split.go | 37 +- internal/tui/page/chat.go | 31 +- internal/tui/page/logs.go | 13 +- internal/tui/styles/background.go | 114 ++- internal/tui/styles/icons.go | 14 +- internal/tui/tui.go | 111 ++- 38 files changed, 3304 insertions(+), 2254 deletions(-) create mode 100644 internal/diff/patch.go create mode 100644 internal/llm/models/openai.go create mode 100644 internal/tui/components/chat/list.go create mode 100644 internal/tui/components/chat/message.go delete mode 100644 internal/tui/components/chat/messages.go create mode 100644 internal/tui/components/dialog/session.go delete mode 100644 internal/tui/layout/bento.go delete mode 100644 internal/tui/layout/border.go delete mode 100644 internal/tui/layout/grid.go delete mode 100644 internal/tui/layout/single.go diff --git a/README.md b/README.md index 564284c7f138fdc54ebca74d64745730c496635e..ef55b69294670f2e6b4fc410204224f3a785dee2 100644 --- a/README.md +++ b/README.md @@ -4,8 +4,6 @@ A powerful terminal-based AI assistant for developers, providing intelligent coding assistance directly in your terminal. -[![OpenCode Demo](https://asciinema.org/a/dtc4nJyGSZX79HRUmFLY3gmoy.svg)](https://asciinema.org/a/dtc4nJyGSZX79HRUmFLY3gmoy) - ## Overview OpenCode is a Go-based CLI application that brings AI assistance to your terminal. It provides a TUI (Terminal User Interface) for interacting with various AI models to help with coding tasks, debugging, and more. @@ -13,11 +11,13 @@ OpenCode is a Go-based CLI application that brings AI assistance to your termina ## Features - **Interactive TUI**: Built with [Bubble Tea](https://github.com/charmbracelet/bubbletea) for a smooth terminal experience -- **Multiple AI Providers**: Support for OpenAI, Anthropic Claude, and Google Gemini models +- **Multiple AI Providers**: Support for OpenAI, Anthropic Claude, Google Gemini, AWS Bedrock, and Groq - **Session Management**: Save and manage multiple conversation sessions - **Tool Integration**: AI can execute commands, search files, and modify code -- **Vim-like Editor**: Integrated editor with Vim keybindings for text input +- **Vim-like Editor**: Integrated editor with text input capabilities - **Persistent Storage**: SQLite database for storing conversations and sessions +- **LSP Integration**: Language Server Protocol support for code intelligence +- **File Change Tracking**: Track and visualize file changes during sessions ## Installation @@ -34,11 +34,107 @@ OpenCode looks for configuration in the following locations: - `$XDG_CONFIG_HOME/opencode/.opencode.json` - `./.opencode.json` (local directory) -You can also use environment variables: +### Environment Variables + +You can configure OpenCode using environment variables: + +| Environment Variable | Purpose | +| ----------------------- | ------------------------ | +| `ANTHROPIC_API_KEY` | For Claude models | +| `OPENAI_API_KEY` | For OpenAI models | +| `GEMINI_API_KEY` | For Google Gemini models | +| `GROQ_API_KEY` | For Groq models | +| `AWS_ACCESS_KEY_ID` | For AWS Bedrock (Claude) | +| `AWS_SECRET_ACCESS_KEY` | For AWS Bedrock (Claude) | +| `AWS_REGION` | For AWS Bedrock (Claude) | + +### Configuration File Structure + +```json +{ + "data": { + "directory": ".opencode" + }, + "providers": { + "openai": { + "apiKey": "your-api-key", + "disabled": false + }, + "anthropic": { + "apiKey": "your-api-key", + "disabled": false + } + }, + "agents": { + "coder": { + "model": "claude-3.7-sonnet", + "maxTokens": 5000 + }, + "task": { + "model": "claude-3.7-sonnet", + "maxTokens": 5000 + }, + "title": { + "model": "claude-3.7-sonnet", + "maxTokens": 80 + } + }, + "mcpServers": { + "example": { + "type": "stdio", + "command": "path/to/mcp-server", + "env": [], + "args": [] + } + }, + "lsp": { + "go": { + "disabled": false, + "command": "gopls" + } + }, + "debug": false, + "debugLSP": false +} +``` -- `ANTHROPIC_API_KEY`: For Claude models -- `OPENAI_API_KEY`: For OpenAI models -- `GEMINI_API_KEY`: For Google Gemini models +## Supported AI Models + +### OpenAI Models + +| Model ID | Name | Context Window | +| ----------------- | --------------- | ---------------- | +| `gpt-4.1` | GPT 4.1 | 1,047,576 tokens | +| `gpt-4.1-mini` | GPT 4.1 Mini | 200,000 tokens | +| `gpt-4.1-nano` | GPT 4.1 Nano | 1,047,576 tokens | +| `gpt-4.5-preview` | GPT 4.5 Preview | 128,000 tokens | +| `gpt-4o` | GPT-4o | 128,000 tokens | +| `gpt-4o-mini` | GPT-4o Mini | 128,000 tokens | +| `o1` | O1 | 200,000 tokens | +| `o1-pro` | O1 Pro | 200,000 tokens | +| `o1-mini` | O1 Mini | 128,000 tokens | +| `o3` | O3 | 200,000 tokens | +| `o3-mini` | O3 Mini | 200,000 tokens | +| `o4-mini` | O4 Mini | 128,000 tokens | + +### Anthropic Models + +| Model ID | Name | Context Window | +| ------------------- | ----------------- | -------------- | +| `claude-3.5-sonnet` | Claude 3.5 Sonnet | 200,000 tokens | +| `claude-3-haiku` | Claude 3 Haiku | 200,000 tokens | +| `claude-3.7-sonnet` | Claude 3.7 Sonnet | 200,000 tokens | +| `claude-3.5-haiku` | Claude 3.5 Haiku | 200,000 tokens | +| `claude-3-opus` | Claude 3 Opus | 200,000 tokens | + +### Other Models + +| Model ID | Provider | Name | Context Window | +| --------------------------- | ----------- | ----------------- | -------------- | +| `gemini-2.5` | Google | Gemini 2.5 Pro | - | +| `gemini-2.0-flash` | Google | Gemini 2.0 Flash | - | +| `qwen-qwq` | Groq | Qwen Qwq | - | +| `bedrock.claude-3.7-sonnet` | AWS Bedrock | Claude 3.7 Sonnet | - | ## Usage @@ -48,36 +144,78 @@ opencode # Start with debug logging opencode -d + +# Start with a specific working directory +opencode -c /path/to/project ``` -### Keyboard Shortcuts +## Command-line Flags + +| Flag | Short | Description | +| --------- | ----- | ----------------------------- | +| `--help` | `-h` | Display help information | +| `--debug` | `-d` | Enable debug mode | +| `--cwd` | `-c` | Set current working directory | + +## Keyboard Shortcuts + +### Global Shortcuts + +| Shortcut | Action | +| -------- | ------------------------------------------------------- | +| `Ctrl+C` | Quit application | +| `Ctrl+?` | Toggle help dialog | +| `Ctrl+L` | View logs | +| `Esc` | Close current overlay/dialog or return to previous mode | + +### Chat Page Shortcuts + +| Shortcut | Action | +| -------- | --------------------------------------- | +| `Ctrl+N` | Create new session | +| `Ctrl+X` | Cancel current operation/generation | +| `i` | Focus editor (when not in writing mode) | +| `Esc` | Exit writing mode and focus messages | + +### Editor Shortcuts -#### Global Shortcuts +| Shortcut | Action | +| ------------------- | ----------------------------------------- | +| `Ctrl+S` | Send message (when editor is focused) | +| `Enter` or `Ctrl+S` | Send message (when editor is not focused) | +| `Esc` | Blur editor and focus messages | -- `?`: Toggle help panel -- `Ctrl+C` or `q`: Quit application -- `L`: View logs -- `Backspace`: Go back to previous page -- `Esc`: Close current view/dialog or return to normal mode +### Logs Page Shortcuts -#### Session Management +| Shortcut | Action | +| ----------- | ------------------- | +| `Backspace` | Return to chat page | -- `N`: Create new session -- `Enter` or `Space`: Select session (in sessions list) +## AI Assistant Tools -#### Editor Shortcuts (Vim-like) +OpenCode's AI assistant has access to various tools to help with coding tasks: -- `i`: Enter insert mode -- `Esc`: Enter normal mode -- `v`: Enter visual mode -- `V`: Enter visual line mode -- `Enter`: Send message (in normal mode) -- `Ctrl+S`: Send message (in insert mode) +### File and Code Tools -#### Navigation +| Tool | Description | Parameters | +| ------------- | --------------------------- | ---------------------------------------------------------------------------------------- | +| `glob` | Find files by pattern | `pattern` (required), `path` (optional) | +| `grep` | Search file contents | `pattern` (required), `path` (optional), `include` (optional), `literal_text` (optional) | +| `ls` | List directory contents | `path` (optional), `ignore` (optional array of patterns) | +| `view` | View file contents | `file_path` (required), `offset` (optional), `limit` (optional) | +| `write` | Write to files | `file_path` (required), `content` (required) | +| `edit` | Edit files | Various parameters for file editing | +| `patch` | Apply patches to files | `file_path` (required), `diff` (required) | +| `diagnostics` | Get diagnostics information | `file_path` (optional) | -- Arrow keys: Navigate through lists and content -- Page Up/Down: Scroll through content +### Other Tools + +| Tool | Description | Parameters | +| ------------- | -------------------------------------- | ----------------------------------------------------------------------------------------- | +| `bash` | Execute shell commands | `command` (required), `timeout` (optional) | +| `fetch` | Fetch data from URLs | `url` (required), `format` (required), `timeout` (optional) | +| `sourcegraph` | Search code across public repositories | `query` (required), `count` (optional), `context_window` (optional), `timeout` (optional) | +| `agent` | Run sub-tasks with the AI agent | `prompt` (required) | ## Architecture @@ -92,6 +230,101 @@ OpenCode is built with a modular architecture: - **internal/logging**: Logging infrastructure - **internal/message**: Message handling - **internal/session**: Session management +- **internal/lsp**: Language Server Protocol integration + +## MCP (Model Context Protocol) + +OpenCode implements the Model Context Protocol (MCP) to extend its capabilities through external tools. MCP provides a standardized way for the AI assistant to interact with external services and tools. + +### MCP Features + +- **External Tool Integration**: Connect to external tools and services via a standardized protocol +- **Tool Discovery**: Automatically discover available tools from MCP servers +- **Multiple Connection Types**: + - **Stdio**: Communicate with tools via standard input/output + - **SSE**: Communicate with tools via Server-Sent Events +- **Security**: Permission system for controlling access to MCP tools + +### Configuring MCP Servers + +MCP servers are defined in the configuration file under the `mcpServers` section: + +```json +{ + "mcpServers": { + "example": { + "type": "stdio", + "command": "path/to/mcp-server", + "env": [], + "args": [] + }, + "web-example": { + "type": "sse", + "url": "https://example.com/mcp", + "headers": { + "Authorization": "Bearer token" + } + } + } +} +``` + +### MCP Tool Usage + +Once configured, MCP tools are automatically available to the AI assistant alongside built-in tools. They follow the same permission model as other tools, requiring user approval before execution. + +## LSP (Language Server Protocol) + +OpenCode integrates with Language Server Protocol to provide rich code intelligence features across multiple programming languages. + +### LSP Features + +- **Multi-language Support**: Connect to language servers for different programming languages +- **Code Intelligence**: Get diagnostics, completions, and navigation assistance +- **File Watching**: Automatically notify language servers of file changes +- **Diagnostics**: Display errors, warnings, and hints in your code + +### Supported LSP Features + +| Feature | Description | +| ----------------- | ----------------------------------- | +| Diagnostics | Error checking and linting | +| Completions | Code suggestions and autocompletion | +| Hover | Documentation on hover | +| Definition | Go to definition | +| References | Find all references | +| Document Symbols | Navigate symbols in current file | +| Workspace Symbols | Search symbols across workspace | +| Formatting | Code formatting | +| Code Actions | Quick fixes and refactorings | + +### Configuring LSP + +Language servers are configured in the configuration file under the `lsp` section: + +```json +{ + "lsp": { + "go": { + "disabled": false, + "command": "gopls" + }, + "typescript": { + "disabled": false, + "command": "typescript-language-server", + "args": ["--stdio"] + } + } +} +``` + +### LSP Integration with AI + +The AI assistant can access LSP features through the `diagnostics` tool, allowing it to: + +- Check for errors in your code +- Suggest fixes based on diagnostics +- Provide intelligent code assistance ## Development @@ -124,8 +357,16 @@ OpenCode builds upon the work of several open source projects and developers: ## License -[License information coming soon] +OpenCode is licensed under the MIT License. See the [LICENSE](LICENSE) file for details. ## Contributing -[Contribution guidelines coming soon] +Contributions are welcome! Here's how you can contribute: + +1. Fork the repository +2. Create a feature branch (`git checkout -b feature/amazing-feature`) +3. Commit your changes (`git commit -m 'Add some amazing feature'`) +4. Push to the branch (`git push origin feature/amazing-feature`) +5. Open a Pull Request + +Please make sure to update tests as appropriate and follow the existing code style. diff --git a/internal/config/config.go b/internal/config/config.go index 5b6d51efa9b2fa20b86d4e27a85619706fc57d97..0cb727158aa5ff413caec01d9d990305ebb37572 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -41,8 +41,9 @@ const ( // Agent defines configuration for different LLM models and their token limits. type Agent struct { - Model models.ModelID `json:"model"` - MaxTokens int64 `json:"maxTokens"` + Model models.ModelID `json:"model"` + MaxTokens int64 `json:"maxTokens"` + ReasoningEffort string `json:"reasoningEffort"` // For openai models low,medium,heigh } // Provider defines configuration for an LLM provider. @@ -80,7 +81,6 @@ type Config struct { const ( defaultDataDirectory = ".opencode" defaultLogLevel = "info" - defaultMaxTokens = int64(5000) appName = "opencode" ) @@ -202,9 +202,7 @@ func setProviderDefaults() { if apiKey := os.Getenv("GROQ_API_KEY"); apiKey != "" { viper.SetDefault("providers.groq.apiKey", apiKey) viper.SetDefault("agents.coder.model", models.QWENQwq) - viper.SetDefault("agents.coder.maxTokens", defaultMaxTokens) viper.SetDefault("agents.task.model", models.QWENQwq) - viper.SetDefault("agents.task.maxTokens", defaultMaxTokens) viper.SetDefault("agents.title.model", models.QWENQwq) } @@ -212,9 +210,7 @@ func setProviderDefaults() { if apiKey := os.Getenv("GEMINI_API_KEY"); apiKey != "" { viper.SetDefault("providers.gemini.apiKey", apiKey) viper.SetDefault("agents.coder.model", models.GRMINI20Flash) - viper.SetDefault("agents.coder.maxTokens", defaultMaxTokens) viper.SetDefault("agents.task.model", models.GRMINI20Flash) - viper.SetDefault("agents.task.maxTokens", defaultMaxTokens) viper.SetDefault("agents.title.model", models.GRMINI20Flash) } @@ -222,9 +218,7 @@ func setProviderDefaults() { if apiKey := os.Getenv("OPENAI_API_KEY"); apiKey != "" { viper.SetDefault("providers.openai.apiKey", apiKey) viper.SetDefault("agents.coder.model", models.GPT4o) - viper.SetDefault("agents.coder.maxTokens", defaultMaxTokens) viper.SetDefault("agents.task.model", models.GPT4o) - viper.SetDefault("agents.task.maxTokens", defaultMaxTokens) viper.SetDefault("agents.title.model", models.GPT4o) } @@ -233,17 +227,13 @@ func setProviderDefaults() { if apiKey := os.Getenv("ANTHROPIC_API_KEY"); apiKey != "" { viper.SetDefault("providers.anthropic.apiKey", apiKey) viper.SetDefault("agents.coder.model", models.Claude37Sonnet) - viper.SetDefault("agents.coder.maxTokens", defaultMaxTokens) viper.SetDefault("agents.task.model", models.Claude37Sonnet) - viper.SetDefault("agents.task.maxTokens", defaultMaxTokens) viper.SetDefault("agents.title.model", models.Claude37Sonnet) } if hasAWSCredentials() { viper.SetDefault("agents.coder.model", models.BedrockClaude37Sonnet) - viper.SetDefault("agents.coder.maxTokens", defaultMaxTokens) viper.SetDefault("agents.task.model", models.BedrockClaude37Sonnet) - viper.SetDefault("agents.task.maxTokens", defaultMaxTokens) viper.SetDefault("agents.title.model", models.BedrockClaude37Sonnet) } } diff --git a/internal/diff/diff.go b/internal/diff/diff.go index f48079c9c96c7e12349cac467875530200c505eb..7b48de25f8426b05cdc8052ccef1eb1aa3e04ef7 100644 --- a/internal/diff/diff.go +++ b/internal/diff/diff.go @@ -79,8 +79,9 @@ type linePair struct { // StyleConfig defines styling for diff rendering type StyleConfig struct { - ShowHeader bool - FileNameFg lipgloss.Color + ShowHeader bool + ShowHunkHeader bool + FileNameFg lipgloss.Color // Background colors RemovedLineBg lipgloss.Color AddedLineBg lipgloss.Color @@ -111,7 +112,8 @@ func NewStyleConfig(opts ...StyleOption) StyleConfig { // Default color scheme config := StyleConfig{ ShowHeader: true, - FileNameFg: lipgloss.Color("#fab283"), + ShowHunkHeader: true, + FileNameFg: lipgloss.Color("#a0a0a0"), RemovedLineBg: lipgloss.Color("#3A3030"), AddedLineBg: lipgloss.Color("#303A30"), ContextLineBg: lipgloss.Color("#212121"), @@ -204,6 +206,10 @@ func WithShowHeader(show bool) StyleOption { return func(s *StyleConfig) { s.ShowHeader = show } } +func WithShowHunkHeader(show bool) StyleOption { + return func(s *StyleConfig) { s.ShowHunkHeader = show } +} + // ------------------------------------------------------------------------- // Parse Configuration // ------------------------------------------------------------------------- @@ -914,13 +920,15 @@ func FormatDiff(diffText string, opts ...SideBySideOption) (string, error) { for _, h := range diffResult.Hunks { // Render hunk header - sb.WriteString( - lipgloss.NewStyle(). - Background(config.Style.HunkLineBg). - Foreground(config.Style.HunkLineFg). - Width(config.TotalWidth). - Render(h.Header) + "\n", - ) + if config.Style.ShowHunkHeader { + sb.WriteString( + lipgloss.NewStyle(). + Background(config.Style.HunkLineBg). + Foreground(config.Style.HunkLineFg). + Width(config.TotalWidth). + Render(h.Header) + "\n", + ) + } sb.WriteString(RenderSideBySideHunk(diffResult.OldFile, h, opts...)) } diff --git a/internal/diff/patch.go b/internal/diff/patch.go new file mode 100644 index 0000000000000000000000000000000000000000..aab0f956dcdad92b6ef6c468940f544c51cb2106 --- /dev/null +++ b/internal/diff/patch.go @@ -0,0 +1,739 @@ +package diff + +import ( + "errors" + "fmt" + "os" + "path/filepath" + "strings" +) + +type ActionType string + +const ( + ActionAdd ActionType = "add" + ActionDelete ActionType = "delete" + ActionUpdate ActionType = "update" +) + +type FileChange struct { + Type ActionType + OldContent *string + NewContent *string + MovePath *string +} + +type Commit struct { + Changes map[string]FileChange +} + +type Chunk struct { + OrigIndex int // line index of the first line in the original file + DelLines []string // lines to delete + InsLines []string // lines to insert +} + +type PatchAction struct { + Type ActionType + NewFile *string + Chunks []Chunk + MovePath *string +} + +type Patch struct { + Actions map[string]PatchAction +} + +type DiffError struct { + message string +} + +func (e DiffError) Error() string { + return e.message +} + +// Helper functions for error handling +func NewDiffError(message string) DiffError { + return DiffError{message: message} +} + +func fileError(action, reason, path string) DiffError { + return NewDiffError(fmt.Sprintf("%s File Error: %s: %s", action, reason, path)) +} + +func contextError(index int, context string, isEOF bool) DiffError { + prefix := "Invalid Context" + if isEOF { + prefix = "Invalid EOF Context" + } + return NewDiffError(fmt.Sprintf("%s %d:\n%s", prefix, index, context)) +} + +type Parser struct { + currentFiles map[string]string + lines []string + index int + patch Patch + fuzz int +} + +func NewParser(currentFiles map[string]string, lines []string) *Parser { + return &Parser{ + currentFiles: currentFiles, + lines: lines, + index: 0, + patch: Patch{Actions: make(map[string]PatchAction, len(currentFiles))}, + fuzz: 0, + } +} + +func (p *Parser) isDone(prefixes []string) bool { + if p.index >= len(p.lines) { + return true + } + if prefixes != nil { + for _, prefix := range prefixes { + if strings.HasPrefix(p.lines[p.index], prefix) { + return true + } + } + } + return false +} + +func (p *Parser) startsWith(prefix any) bool { + var prefixes []string + switch v := prefix.(type) { + case string: + prefixes = []string{v} + case []string: + prefixes = v + } + + for _, pfx := range prefixes { + if strings.HasPrefix(p.lines[p.index], pfx) { + return true + } + } + return false +} + +func (p *Parser) readStr(prefix string, returnEverything bool) string { + if p.index >= len(p.lines) { + return "" // Changed from panic to return empty string for safer operation + } + if strings.HasPrefix(p.lines[p.index], prefix) { + var text string + if returnEverything { + text = p.lines[p.index] + } else { + text = p.lines[p.index][len(prefix):] + } + p.index++ + return text + } + return "" +} + +func (p *Parser) Parse() error { + endPatchPrefixes := []string{"*** End Patch"} + + for !p.isDone(endPatchPrefixes) { + path := p.readStr("*** Update File: ", false) + if path != "" { + if _, exists := p.patch.Actions[path]; exists { + return fileError("Update", "Duplicate Path", path) + } + moveTo := p.readStr("*** Move to: ", false) + if _, exists := p.currentFiles[path]; !exists { + return fileError("Update", "Missing File", path) + } + text := p.currentFiles[path] + action, err := p.parseUpdateFile(text) + if err != nil { + return err + } + if moveTo != "" { + action.MovePath = &moveTo + } + p.patch.Actions[path] = action + continue + } + + path = p.readStr("*** Delete File: ", false) + if path != "" { + if _, exists := p.patch.Actions[path]; exists { + return fileError("Delete", "Duplicate Path", path) + } + if _, exists := p.currentFiles[path]; !exists { + return fileError("Delete", "Missing File", path) + } + p.patch.Actions[path] = PatchAction{Type: ActionDelete, Chunks: []Chunk{}} + continue + } + + path = p.readStr("*** Add File: ", false) + if path != "" { + if _, exists := p.patch.Actions[path]; exists { + return fileError("Add", "Duplicate Path", path) + } + if _, exists := p.currentFiles[path]; exists { + return fileError("Add", "File already exists", path) + } + action, err := p.parseAddFile() + if err != nil { + return err + } + p.patch.Actions[path] = action + continue + } + + return NewDiffError(fmt.Sprintf("Unknown Line: %s", p.lines[p.index])) + } + + if !p.startsWith("*** End Patch") { + return NewDiffError("Missing End Patch") + } + p.index++ + + return nil +} + +func (p *Parser) parseUpdateFile(text string) (PatchAction, error) { + action := PatchAction{Type: ActionUpdate, Chunks: []Chunk{}} + fileLines := strings.Split(text, "\n") + index := 0 + + endPrefixes := []string{ + "*** End Patch", + "*** Update File:", + "*** Delete File:", + "*** Add File:", + "*** End of File", + } + + for !p.isDone(endPrefixes) { + defStr := p.readStr("@@ ", false) + sectionStr := "" + if defStr == "" && p.index < len(p.lines) && p.lines[p.index] == "@@" { + sectionStr = p.lines[p.index] + p.index++ + } + if !(defStr != "" || sectionStr != "" || index == 0) { + return action, NewDiffError(fmt.Sprintf("Invalid Line:\n%s", p.lines[p.index])) + } + if strings.TrimSpace(defStr) != "" { + found := false + for i := range fileLines[:index] { + if fileLines[i] == defStr { + found = true + break + } + } + + if !found { + for i := index; i < len(fileLines); i++ { + if fileLines[i] == defStr { + index = i + 1 + found = true + break + } + } + } + + if !found { + for i := range fileLines[:index] { + if strings.TrimSpace(fileLines[i]) == strings.TrimSpace(defStr) { + found = true + break + } + } + } + + if !found { + for i := index; i < len(fileLines); i++ { + if strings.TrimSpace(fileLines[i]) == strings.TrimSpace(defStr) { + index = i + 1 + p.fuzz++ + found = true + break + } + } + } + } + + nextChunkContext, chunks, endPatchIndex, eof := peekNextSection(p.lines, p.index) + newIndex, fuzz := findContext(fileLines, nextChunkContext, index, eof) + if newIndex == -1 { + ctxText := strings.Join(nextChunkContext, "\n") + return action, contextError(index, ctxText, eof) + } + p.fuzz += fuzz + + for _, ch := range chunks { + ch.OrigIndex += newIndex + action.Chunks = append(action.Chunks, ch) + } + index = newIndex + len(nextChunkContext) + p.index = endPatchIndex + } + return action, nil +} + +func (p *Parser) parseAddFile() (PatchAction, error) { + lines := make([]string, 0, 16) // Preallocate space for better performance + endPrefixes := []string{ + "*** End Patch", + "*** Update File:", + "*** Delete File:", + "*** Add File:", + } + + for !p.isDone(endPrefixes) { + s := p.readStr("", true) + if !strings.HasPrefix(s, "+") { + return PatchAction{}, NewDiffError(fmt.Sprintf("Invalid Add File Line: %s", s)) + } + lines = append(lines, s[1:]) + } + + newFile := strings.Join(lines, "\n") + return PatchAction{ + Type: ActionAdd, + NewFile: &newFile, + Chunks: []Chunk{}, + }, nil +} + +// Refactored to use a matcher function for each comparison type +func findContextCore(lines []string, context []string, start int) (int, int) { + if len(context) == 0 { + return start, 0 + } + + // Try exact match + if idx, fuzz := tryFindMatch(lines, context, start, func(a, b string) bool { + return a == b + }); idx >= 0 { + return idx, fuzz + } + + // Try trimming right whitespace + if idx, fuzz := tryFindMatch(lines, context, start, func(a, b string) bool { + return strings.TrimRight(a, " \t") == strings.TrimRight(b, " \t") + }); idx >= 0 { + return idx, fuzz + } + + // Try trimming all whitespace + if idx, fuzz := tryFindMatch(lines, context, start, func(a, b string) bool { + return strings.TrimSpace(a) == strings.TrimSpace(b) + }); idx >= 0 { + return idx, fuzz + } + + return -1, 0 +} + +// Helper function to DRY up the match logic +func tryFindMatch(lines []string, context []string, start int, + compareFunc func(string, string) bool, +) (int, int) { + for i := start; i < len(lines); i++ { + if i+len(context) <= len(lines) { + match := true + for j := range context { + if !compareFunc(lines[i+j], context[j]) { + match = false + break + } + } + if match { + // Return fuzz level: 0 for exact, 1 for trimRight, 100 for trimSpace + var fuzz int + if compareFunc("a ", "a") && !compareFunc("a", "b") { + fuzz = 1 + } else if compareFunc("a ", "a") { + fuzz = 100 + } + return i, fuzz + } + } + } + return -1, 0 +} + +func findContext(lines []string, context []string, start int, eof bool) (int, int) { + if eof { + newIndex, fuzz := findContextCore(lines, context, len(lines)-len(context)) + if newIndex != -1 { + return newIndex, fuzz + } + newIndex, fuzz = findContextCore(lines, context, start) + return newIndex, fuzz + 10000 + } + return findContextCore(lines, context, start) +} + +func peekNextSection(lines []string, initialIndex int) ([]string, []Chunk, int, bool) { + index := initialIndex + old := make([]string, 0, 32) // Preallocate for better performance + delLines := make([]string, 0, 8) + insLines := make([]string, 0, 8) + chunks := make([]Chunk, 0, 4) + mode := "keep" + + // End conditions for the section + endSectionConditions := func(s string) bool { + return strings.HasPrefix(s, "@@") || + strings.HasPrefix(s, "*** End Patch") || + strings.HasPrefix(s, "*** Update File:") || + strings.HasPrefix(s, "*** Delete File:") || + strings.HasPrefix(s, "*** Add File:") || + strings.HasPrefix(s, "*** End of File") || + s == "***" || + strings.HasPrefix(s, "***") + } + + for index < len(lines) { + s := lines[index] + if endSectionConditions(s) { + break + } + index++ + lastMode := mode + line := s + + if len(line) > 0 { + switch line[0] { + case '+': + mode = "add" + case '-': + mode = "delete" + case ' ': + mode = "keep" + default: + mode = "keep" + line = " " + line + } + } else { + mode = "keep" + line = " " + } + + line = line[1:] + if mode == "keep" && lastMode != mode { + if len(insLines) > 0 || len(delLines) > 0 { + chunks = append(chunks, Chunk{ + OrigIndex: len(old) - len(delLines), + DelLines: delLines, + InsLines: insLines, + }) + } + delLines = make([]string, 0, 8) + insLines = make([]string, 0, 8) + } + if mode == "delete" { + delLines = append(delLines, line) + old = append(old, line) + } else if mode == "add" { + insLines = append(insLines, line) + } else { + old = append(old, line) + } + } + + if len(insLines) > 0 || len(delLines) > 0 { + chunks = append(chunks, Chunk{ + OrigIndex: len(old) - len(delLines), + DelLines: delLines, + InsLines: insLines, + }) + } + + if index < len(lines) && lines[index] == "*** End of File" { + index++ + return old, chunks, index, true + } + return old, chunks, index, false +} + +func TextToPatch(text string, orig map[string]string) (Patch, int, error) { + text = strings.TrimSpace(text) + lines := strings.Split(text, "\n") + if len(lines) < 2 || !strings.HasPrefix(lines[0], "*** Begin Patch") || lines[len(lines)-1] != "*** End Patch" { + return Patch{}, 0, NewDiffError("Invalid patch text") + } + parser := NewParser(orig, lines) + parser.index = 1 + if err := parser.Parse(); err != nil { + return Patch{}, 0, err + } + return parser.patch, parser.fuzz, nil +} + +func IdentifyFilesNeeded(text string) []string { + text = strings.TrimSpace(text) + lines := strings.Split(text, "\n") + result := make(map[string]bool) + + for _, line := range lines { + if strings.HasPrefix(line, "*** Update File: ") { + result[line[len("*** Update File: "):]] = true + } + if strings.HasPrefix(line, "*** Delete File: ") { + result[line[len("*** Delete File: "):]] = true + } + } + + files := make([]string, 0, len(result)) + for file := range result { + files = append(files, file) + } + return files +} + +func IdentifyFilesAdded(text string) []string { + text = strings.TrimSpace(text) + lines := strings.Split(text, "\n") + result := make(map[string]bool) + + for _, line := range lines { + if strings.HasPrefix(line, "*** Add File: ") { + result[line[len("*** Add File: "):]] = true + } + } + + files := make([]string, 0, len(result)) + for file := range result { + files = append(files, file) + } + return files +} + +func getUpdatedFile(text string, action PatchAction, path string) (string, error) { + if action.Type != ActionUpdate { + return "", errors.New("Expected UPDATE action") + } + origLines := strings.Split(text, "\n") + destLines := make([]string, 0, len(origLines)) // Preallocate with capacity + origIndex := 0 + + for _, chunk := range action.Chunks { + if chunk.OrigIndex > len(origLines) { + return "", NewDiffError(fmt.Sprintf("%s: chunk.orig_index %d > len(lines) %d", path, chunk.OrigIndex, len(origLines))) + } + if origIndex > chunk.OrigIndex { + return "", NewDiffError(fmt.Sprintf("%s: orig_index %d > chunk.orig_index %d", path, origIndex, chunk.OrigIndex)) + } + destLines = append(destLines, origLines[origIndex:chunk.OrigIndex]...) + delta := chunk.OrigIndex - origIndex + origIndex += delta + + if len(chunk.InsLines) > 0 { + destLines = append(destLines, chunk.InsLines...) + } + origIndex += len(chunk.DelLines) + } + + destLines = append(destLines, origLines[origIndex:]...) + return strings.Join(destLines, "\n"), nil +} + +func PatchToCommit(patch Patch, orig map[string]string) (Commit, error) { + commit := Commit{Changes: make(map[string]FileChange, len(patch.Actions))} + for pathKey, action := range patch.Actions { + if action.Type == ActionDelete { + oldContent := orig[pathKey] + commit.Changes[pathKey] = FileChange{ + Type: ActionDelete, + OldContent: &oldContent, + } + } else if action.Type == ActionAdd { + commit.Changes[pathKey] = FileChange{ + Type: ActionAdd, + NewContent: action.NewFile, + } + } else if action.Type == ActionUpdate { + newContent, err := getUpdatedFile(orig[pathKey], action, pathKey) + if err != nil { + return Commit{}, err + } + oldContent := orig[pathKey] + fileChange := FileChange{ + Type: ActionUpdate, + OldContent: &oldContent, + NewContent: &newContent, + } + if action.MovePath != nil { + fileChange.MovePath = action.MovePath + } + commit.Changes[pathKey] = fileChange + } + } + return commit, nil +} + +func AssembleChanges(orig map[string]string, updatedFiles map[string]string) Commit { + commit := Commit{Changes: make(map[string]FileChange, len(updatedFiles))} + for p, newContent := range updatedFiles { + oldContent, exists := orig[p] + if exists && oldContent == newContent { + continue + } + + if exists && newContent != "" { + commit.Changes[p] = FileChange{ + Type: ActionUpdate, + OldContent: &oldContent, + NewContent: &newContent, + } + } else if newContent != "" { + commit.Changes[p] = FileChange{ + Type: ActionAdd, + NewContent: &newContent, + } + } else if exists { + commit.Changes[p] = FileChange{ + Type: ActionDelete, + OldContent: &oldContent, + } + } else { + return commit // Changed from panic to simply return current commit + } + } + return commit +} + +func LoadFiles(paths []string, openFn func(string) (string, error)) (map[string]string, error) { + orig := make(map[string]string, len(paths)) + for _, p := range paths { + content, err := openFn(p) + if err != nil { + return nil, fileError("Open", "File not found", p) + } + orig[p] = content + } + return orig, nil +} + +func ApplyCommit(commit Commit, writeFn func(string, string) error, removeFn func(string) error) error { + for p, change := range commit.Changes { + if change.Type == ActionDelete { + if err := removeFn(p); err != nil { + return err + } + } else if change.Type == ActionAdd { + if change.NewContent == nil { + return NewDiffError(fmt.Sprintf("Add action for %s has nil new_content", p)) + } + if err := writeFn(p, *change.NewContent); err != nil { + return err + } + } else if change.Type == ActionUpdate { + if change.NewContent == nil { + return NewDiffError(fmt.Sprintf("Update action for %s has nil new_content", p)) + } + if change.MovePath != nil { + if err := writeFn(*change.MovePath, *change.NewContent); err != nil { + return err + } + if err := removeFn(p); err != nil { + return err + } + } else { + if err := writeFn(p, *change.NewContent); err != nil { + return err + } + } + } + } + return nil +} + +func ProcessPatch(text string, openFn func(string) (string, error), writeFn func(string, string) error, removeFn func(string) error) (string, error) { + if !strings.HasPrefix(text, "*** Begin Patch") { + return "", NewDiffError("Patch must start with *** Begin Patch") + } + paths := IdentifyFilesNeeded(text) + orig, err := LoadFiles(paths, openFn) + if err != nil { + return "", err + } + + patch, fuzz, err := TextToPatch(text, orig) + if err != nil { + return "", err + } + + if fuzz > 0 { + return "", NewDiffError(fmt.Sprintf("Patch contains fuzzy matches (fuzz level: %d)", fuzz)) + } + + commit, err := PatchToCommit(patch, orig) + if err != nil { + return "", err + } + + if err := ApplyCommit(commit, writeFn, removeFn); err != nil { + return "", err + } + + return "Patch applied successfully", nil +} + +func OpenFile(p string) (string, error) { + data, err := os.ReadFile(p) + if err != nil { + return "", err + } + return string(data), nil +} + +func WriteFile(p string, content string) error { + if filepath.IsAbs(p) { + return NewDiffError("We do not support absolute paths.") + } + + dir := filepath.Dir(p) + if dir != "." { + if err := os.MkdirAll(dir, 0o755); err != nil { + return err + } + } + + return os.WriteFile(p, []byte(content), 0o644) +} + +func RemoveFile(p string) error { + return os.Remove(p) +} + +func ValidatePatch(patchText string, files map[string]string) (bool, string, error) { + if !strings.HasPrefix(patchText, "*** Begin Patch") { + return false, "Patch must start with *** Begin Patch", nil + } + + neededFiles := IdentifyFilesNeeded(patchText) + for _, filePath := range neededFiles { + if _, exists := files[filePath]; !exists { + return false, fmt.Sprintf("File not found: %s", filePath), nil + } + } + + patch, fuzz, err := TextToPatch(patchText, files) + if err != nil { + return false, err.Error(), nil + } + + if fuzz > 0 { + return false, fmt.Sprintf("Patch contains fuzzy matches (fuzz level: %d)", fuzz), nil + } + + _, err = PatchToCommit(patch, files) + if err != nil { + return false, err.Error(), nil + } + + return true, "Patch is valid", nil +} diff --git a/internal/history/file.go b/internal/history/file.go index 8453ac272926761344c289942d66c05d44401d4e..7e206a2d99cdc4a8d5066e30f86549fcf42d75fb 100644 --- a/internal/history/file.go +++ b/internal/history/file.go @@ -50,6 +50,7 @@ func NewService(q *db.Queries, db *sql.DB) Service { return &service{ Broker: pubsub.NewBroker[File](), q: q, + db: db, } } @@ -100,30 +101,30 @@ func (s *service) createWithVersion(ctx context.Context, sessionID, path, conten var err error // Retry loop for transaction conflicts - for attempt := 0; attempt < maxRetries; attempt++ { + for attempt := range maxRetries { // Start a transaction - tx, err := s.db.BeginTx(ctx, nil) - if err != nil { - return File{}, fmt.Errorf("failed to begin transaction: %w", err) + tx, txErr := s.db.Begin() + if txErr != nil { + return File{}, fmt.Errorf("failed to begin transaction: %w", txErr) } // Create a new queries instance with the transaction qtx := s.q.WithTx(tx) // Try to create the file within the transaction - dbFile, err := qtx.CreateFile(ctx, db.CreateFileParams{ + dbFile, txErr := qtx.CreateFile(ctx, db.CreateFileParams{ ID: uuid.New().String(), SessionID: sessionID, Path: path, Content: content, Version: version, }) - if err != nil { + if txErr != nil { // Rollback the transaction tx.Rollback() // Check if this is a uniqueness constraint violation - if strings.Contains(err.Error(), "UNIQUE constraint failed") { + if strings.Contains(txErr.Error(), "UNIQUE constraint failed") { if attempt < maxRetries-1 { // If we have retries left, generate a new version and try again if strings.HasPrefix(version, "v") { @@ -138,12 +139,12 @@ func (s *service) createWithVersion(ctx context.Context, sessionID, path, conten continue } } - return File{}, err + return File{}, txErr } // Commit the transaction - if err = tx.Commit(); err != nil { - return File{}, fmt.Errorf("failed to commit transaction: %w", err) + if txErr = tx.Commit(); txErr != nil { + return File{}, fmt.Errorf("failed to commit transaction: %w", txErr) } file = s.fromDBItem(dbFile) diff --git a/internal/llm/agent/agent.go b/internal/llm/agent/agent.go index a5dadb89da77826f38dd15e22567b07b313f4d58..5e9785991d311349b339d12a3a02d64dcdda76d8 100644 --- a/internal/llm/agent/agent.go +++ b/internal/llm/agent/agent.go @@ -41,6 +41,7 @@ type Service interface { Run(ctx context.Context, sessionID string, content string) (<-chan AgentEvent, error) Cancel(sessionID string) IsSessionBusy(sessionID string) bool + IsBusy() bool } type agent struct { @@ -95,6 +96,20 @@ func (a *agent) Cancel(sessionID string) { } } +func (a *agent) IsBusy() bool { + busy := false + a.activeRequests.Range(func(key, value interface{}) bool { + if cancelFunc, ok := value.(context.CancelFunc); ok { + if cancelFunc != nil { + busy = true + return false // Stop iterating + } + } + return true // Continue iterating + }) + return busy +} + func (a *agent) IsSessionBusy(sessionID string) bool { _, busy := a.activeRequests.Load(sessionID) return busy @@ -313,23 +328,8 @@ func (a *agent) streamAndHandleEvents(ctx context.Context, sessionID string, msg } } a.finishMessage(ctx, &assistantMsg, message.FinishReasonPermissionDenied) - } else { - toolResults[i] = message.ToolResult{ - ToolCallID: toolCall.ID, - Content: toolErr.Error(), - IsError: true, - } - for j := i; j < len(toolCalls); j++ { - toolResults[j] = message.ToolResult{ - ToolCallID: toolCalls[j].ID, - Content: "Previous tool failed", - IsError: true, - } - } - a.finishMessage(ctx, &assistantMsg, message.FinishReasonError) + break } - // If permission is denied or an error happens we cancel all the following tools - break } toolResults[i] = message.ToolResult{ ToolCallID: toolCall.ID, @@ -437,12 +437,27 @@ func createAgentProvider(agentName config.AgentName) (provider.Provider, error) if providerCfg.Disabled { return nil, fmt.Errorf("provider %s is not enabled", model.Provider) } - agentProvider, err := provider.NewProvider( - model.Provider, + maxTokens := model.DefaultMaxTokens + if agentConfig.MaxTokens > 0 { + maxTokens = agentConfig.MaxTokens + } + opts := []provider.ProviderClientOption{ provider.WithAPIKey(providerCfg.APIKey), provider.WithModel(model), provider.WithSystemMessage(prompt.GetAgentPrompt(agentName, model.Provider)), - provider.WithMaxTokens(agentConfig.MaxTokens), + provider.WithMaxTokens(maxTokens), + } + if model.Provider == models.ProviderOpenAI && model.CanReason { + opts = append( + opts, + provider.WithOpenAIOptions( + provider.WithReasoningEffort(agentConfig.ReasoningEffort), + ), + ) + } + agentProvider, err := provider.NewProvider( + model.Provider, + opts..., ) if err != nil { return nil, fmt.Errorf("could not create provider: %v", err) diff --git a/internal/llm/agent/tools.go b/internal/llm/agent/tools.go index 9120809ffefa27a2a5c15968665bf559a266c4e3..b2e6816d5f76ddc3fa3604c46cde21b80ef03899 100644 --- a/internal/llm/agent/tools.go +++ b/internal/llm/agent/tools.go @@ -31,10 +31,9 @@ func CoderAgentTools( tools.NewGlobTool(), tools.NewGrepTool(), tools.NewLsTool(), - // TODO: see if we want to use this tool - // tools.NewPatchTool(lspClients, permissions, history), tools.NewSourcegraphTool(), tools.NewViewTool(lspClients), + tools.NewPatchTool(lspClients, permissions, history), tools.NewWriteTool(lspClients, permissions, history), NewAgentTool(sessions, messages, lspClients), }, otherTools..., diff --git a/internal/llm/models/anthropic.go b/internal/llm/models/anthropic.go index 48307e6d3fe72af892944234767ca8dba723398f..87e9b4c89cd25b9f99f2c10c8612fd7aba4b40d2 100644 --- a/internal/llm/models/anthropic.go +++ b/internal/llm/models/anthropic.go @@ -23,6 +23,7 @@ var AnthropicModels = map[ModelID]Model{ CostPer1MOutCached: 0.30, CostPer1MOut: 15.0, ContextWindow: 200000, + DefaultMaxTokens: 5000, }, Claude3Haiku: { ID: Claude3Haiku, @@ -34,6 +35,7 @@ var AnthropicModels = map[ModelID]Model{ CostPer1MOutCached: 0.03, CostPer1MOut: 1.25, ContextWindow: 200000, + DefaultMaxTokens: 5000, }, Claude37Sonnet: { ID: Claude37Sonnet, @@ -45,6 +47,8 @@ var AnthropicModels = map[ModelID]Model{ CostPer1MOutCached: 0.30, CostPer1MOut: 15.0, ContextWindow: 200000, + DefaultMaxTokens: 50000, + CanReason: true, }, Claude35Haiku: { ID: Claude35Haiku, @@ -56,6 +60,7 @@ var AnthropicModels = map[ModelID]Model{ CostPer1MOutCached: 0.08, CostPer1MOut: 4.0, ContextWindow: 200000, + DefaultMaxTokens: 4096, }, Claude3Opus: { ID: Claude3Opus, @@ -67,5 +72,6 @@ var AnthropicModels = map[ModelID]Model{ CostPer1MOutCached: 1.50, CostPer1MOut: 75.0, ContextWindow: 200000, + DefaultMaxTokens: 4096, }, } diff --git a/internal/llm/models/models.go b/internal/llm/models/models.go index 4d4589bfdf15b777782633fa4c0c09b324cffa9a..bbce6130e244d1fd69dabd50ea7c806741e5ce24 100644 --- a/internal/llm/models/models.go +++ b/internal/llm/models/models.go @@ -17,15 +17,12 @@ type Model struct { CostPer1MInCached float64 `json:"cost_per_1m_in_cached"` CostPer1MOutCached float64 `json:"cost_per_1m_out_cached"` ContextWindow int64 `json:"context_window"` + DefaultMaxTokens int64 `json:"default_max_tokens"` + CanReason bool `json:"can_reason"` } // Model IDs -const ( - // OpenAI - GPT4o ModelID = "gpt-4o" - GPT41 ModelID = "gpt-4.1" - - // GEMINI +const ( // GEMINI GEMINI25 ModelID = "gemini-2.5" GRMINI20Flash ModelID = "gemini-2.0-flash" @@ -37,7 +34,6 @@ const ( ) const ( - ProviderOpenAI ModelProvider = "openai" ProviderBedrock ModelProvider = "bedrock" ProviderGemini ModelProvider = "gemini" ProviderGROQ ModelProvider = "groq" @@ -47,59 +43,6 @@ const ( ) var SupportedModels = map[ModelID]Model{ - // // Anthropic - // Claude35Sonnet: { - // ID: Claude35Sonnet, - // Name: "Claude 3.5 Sonnet", - // Provider: ProviderAnthropic, - // APIModel: "claude-3-5-sonnet-latest", - // CostPer1MIn: 3.0, - // CostPer1MInCached: 3.75, - // CostPer1MOutCached: 0.30, - // CostPer1MOut: 15.0, - // }, - // Claude3Haiku: { - // ID: Claude3Haiku, - // Name: "Claude 3 Haiku", - // Provider: ProviderAnthropic, - // APIModel: "claude-3-haiku-latest", - // CostPer1MIn: 0.80, - // CostPer1MInCached: 1, - // CostPer1MOutCached: 0.08, - // CostPer1MOut: 4, - // }, - // Claude37Sonnet: { - // ID: Claude37Sonnet, - // Name: "Claude 3.7 Sonnet", - // Provider: ProviderAnthropic, - // APIModel: "claude-3-7-sonnet-latest", - // CostPer1MIn: 3.0, - // CostPer1MInCached: 3.75, - // CostPer1MOutCached: 0.30, - // CostPer1MOut: 15.0, - // }, - // - // // OpenAI - GPT4o: { - ID: GPT4o, - Name: "GPT-4o", - Provider: ProviderOpenAI, - APIModel: "gpt-4.1", - CostPer1MIn: 2.00, - CostPer1MInCached: 0.50, - CostPer1MOutCached: 0, - CostPer1MOut: 8.00, - }, - GPT41: { - ID: GPT41, - Name: "GPT-4.1", - Provider: ProviderOpenAI, - APIModel: "gpt-4.1", - CostPer1MIn: 2.00, - CostPer1MInCached: 0.50, - CostPer1MOutCached: 0, - CostPer1MOut: 8.00, - }, // // // GEMINI // GEMINI25: { @@ -151,4 +94,5 @@ var SupportedModels = map[ModelID]Model{ func init() { maps.Copy(SupportedModels, AnthropicModels) + maps.Copy(SupportedModels, OpenAIModels) } diff --git a/internal/llm/models/openai.go b/internal/llm/models/openai.go new file mode 100644 index 0000000000000000000000000000000000000000..f0cbb298cb99b679fa7eb3b919633a49d81eea18 --- /dev/null +++ b/internal/llm/models/openai.go @@ -0,0 +1,169 @@ +package models + +const ( + ProviderOpenAI ModelProvider = "openai" + + GPT41 ModelID = "gpt-4.1" + GPT41Mini ModelID = "gpt-4.1-mini" + GPT41Nano ModelID = "gpt-4.1-nano" + GPT45Preview ModelID = "gpt-4.5-preview" + GPT4o ModelID = "gpt-4o" + GPT4oMini ModelID = "gpt-4o-mini" + O1 ModelID = "o1" + O1Pro ModelID = "o1-pro" + O1Mini ModelID = "o1-mini" + O3 ModelID = "o3" + O3Mini ModelID = "o3-mini" + O4Mini ModelID = "o4-mini" +) + +var OpenAIModels = map[ModelID]Model{ + GPT41: { + ID: GPT41, + Name: "GPT 4.1", + Provider: ProviderOpenAI, + APIModel: "gpt-4.1", + CostPer1MIn: 2.00, + CostPer1MInCached: 0.50, + CostPer1MOutCached: 0.0, + CostPer1MOut: 8.00, + ContextWindow: 1_047_576, + DefaultMaxTokens: 20000, + }, + GPT41Mini: { + ID: GPT41Mini, + Name: "GPT 4.1 mini", + Provider: ProviderOpenAI, + APIModel: "gpt-4.1", + CostPer1MIn: 0.40, + CostPer1MInCached: 0.10, + CostPer1MOutCached: 0.0, + CostPer1MOut: 1.60, + ContextWindow: 200_000, + DefaultMaxTokens: 20000, + }, + GPT41Nano: { + ID: GPT41Nano, + Name: "GPT 4.1 nano", + Provider: ProviderOpenAI, + APIModel: "gpt-4.1-nano", + CostPer1MIn: 0.10, + CostPer1MInCached: 0.025, + CostPer1MOutCached: 0.0, + CostPer1MOut: 0.40, + ContextWindow: 1_047_576, + DefaultMaxTokens: 20000, + }, + GPT45Preview: { + ID: GPT45Preview, + Name: "GPT 4.5 preview", + Provider: ProviderOpenAI, + APIModel: "gpt-4.5-preview", + CostPer1MIn: 75.00, + CostPer1MInCached: 37.50, + CostPer1MOutCached: 0.0, + CostPer1MOut: 150.00, + ContextWindow: 128_000, + DefaultMaxTokens: 15000, + }, + GPT4o: { + ID: GPT4o, + Name: "GPT 4o", + Provider: ProviderOpenAI, + APIModel: "gpt-4o", + CostPer1MIn: 2.50, + CostPer1MInCached: 1.25, + CostPer1MOutCached: 0.0, + CostPer1MOut: 10.00, + ContextWindow: 128_000, + DefaultMaxTokens: 4096, + }, + GPT4oMini: { + ID: GPT4oMini, + Name: "GPT 4o mini", + Provider: ProviderOpenAI, + APIModel: "gpt-4o-mini", + CostPer1MIn: 0.15, + CostPer1MInCached: 0.075, + CostPer1MOutCached: 0.0, + CostPer1MOut: 0.60, + ContextWindow: 128_000, + }, + O1: { + ID: O1, + Name: "O1", + Provider: ProviderOpenAI, + APIModel: "o1", + CostPer1MIn: 15.00, + CostPer1MInCached: 7.50, + CostPer1MOutCached: 0.0, + CostPer1MOut: 60.00, + ContextWindow: 200_000, + DefaultMaxTokens: 50000, + CanReason: true, + }, + O1Pro: { + ID: O1Pro, + Name: "o1 pro", + Provider: ProviderOpenAI, + APIModel: "o1-pro", + CostPer1MIn: 150.00, + CostPer1MInCached: 0.0, + CostPer1MOutCached: 0.0, + CostPer1MOut: 600.00, + ContextWindow: 200_000, + DefaultMaxTokens: 50000, + CanReason: true, + }, + O1Mini: { + ID: O1Mini, + Name: "o1 mini", + Provider: ProviderOpenAI, + APIModel: "o1-mini", + CostPer1MIn: 1.10, + CostPer1MInCached: 0.55, + CostPer1MOutCached: 0.0, + CostPer1MOut: 4.40, + ContextWindow: 128_000, + DefaultMaxTokens: 50000, + CanReason: true, + }, + O3: { + ID: O3, + Name: "o3", + Provider: ProviderOpenAI, + APIModel: "o3", + CostPer1MIn: 10.00, + CostPer1MInCached: 2.50, + CostPer1MOutCached: 0.0, + CostPer1MOut: 40.00, + ContextWindow: 200_000, + CanReason: true, + }, + O3Mini: { + ID: O3Mini, + Name: "o3 mini", + Provider: ProviderOpenAI, + APIModel: "o3-mini", + CostPer1MIn: 1.10, + CostPer1MInCached: 0.55, + CostPer1MOutCached: 0.0, + CostPer1MOut: 4.40, + ContextWindow: 200_000, + DefaultMaxTokens: 50000, + CanReason: true, + }, + O4Mini: { + ID: O4Mini, + Name: "o4 mini", + Provider: ProviderOpenAI, + APIModel: "o4-mini", + CostPer1MIn: 1.10, + CostPer1MInCached: 0.275, + CostPer1MOutCached: 0.0, + CostPer1MOut: 4.40, + ContextWindow: 128_000, + DefaultMaxTokens: 50000, + CanReason: true, + }, +} diff --git a/internal/llm/prompt/coder.go b/internal/llm/prompt/coder.go index 3a06911dadf3f9bb04aa0f66d8b136b836468f1f..d7ca7b2fde3629bfc77dc9105279d61d674dad3f 100644 --- a/internal/llm/prompt/coder.go +++ b/internal/llm/prompt/coder.go @@ -25,44 +25,49 @@ func CoderPrompt(provider models.ModelProvider) string { } const baseOpenAICoderPrompt = ` -You are **OpenCode**, an autonomous CLI assistant for software‑engineering tasks. - -### ── INTERNAL REFLECTION ── -• Silently think step‑by‑step about the user request, directory layout, and tool calls (never reveal this). -• Formulate a plan, then execute without further approval unless a blocker triggers the Ask‑Only‑If rules. - -### ── PUBLIC RESPONSE RULES ── -• Visible reply ≤ 4 lines; no fluff, preamble, or postamble. -• Use GitHub‑flavored Markdown. -• When running a non‑trivial shell command, add ≤ 1 brief purpose sentence. - -### ── CONTEXT & MEMORY ── -• Infer file intent from directory structure before editing. -• Auto‑load 'OpenCode.md'; ask once before writing new reusable commands or style notes. - -### ── AUTONOMY PRIORITY ── -**Ask‑Only‑If Decision Tree:** -1. **Safety risk?** (e.g., destructive command, secret exposure) → ask. -2. **Critical unknown?** (no docs/tests; cannot infer) → ask. -3. **Tool failure after two self‑attempts?** → ask. -Otherwise, proceed autonomously. - -### ── SAFETY & STYLE ── -• Mimic existing code style; verify libraries exist before import. -• Never commit unless explicitly told. -• After edits, run lint & type‑check (ask for commands once, then offer to store in 'OpenCode.md'). -• Protect secrets; follow standard security practices :contentReference[oaicite:2]{index=2}. - -### ── TOOL USAGE ── -• Batch independent Agent search/file calls in one block for efficiency :contentReference[oaicite:3]{index=3}. -• Communicate with the user only via visible text; do not expose tool output or internal reasoning. - -### ── EXAMPLES ── -user: list files -assistant: ls - -user: write tests for new feature -assistant: [searches & edits autonomously, no extra chit‑chat] +You are operating as and within the OpenCode CLI, a terminal-based agentic coding assistant built by OpenAI. It wraps OpenAI models to enable natural language interaction with a local codebase. You are expected to be precise, safe, and helpful. + +You can: +- Receive user prompts, project context, and files. +- Stream responses and emit function calls (e.g., shell commands, code edits). +- Apply patches, run commands, and manage user approvals based on policy. +- Work inside a sandboxed, git-backed workspace with rollback support. +- Log telemetry so sessions can be replayed or inspected later. +- More details on your functionality are available at "opencode --help" + + +You are an agent - please keep going until the user's query is completely resolved, before ending your turn and yielding back to the user. Only terminate your turn when you are sure that the problem is solved. If you are not sure about file content or codebase structure pertaining to the user's request, use your tools to read files and gather the relevant information: do NOT guess or make up an answer. + +Please resolve the user's task by editing and testing the code files in your current code execution session. You are a deployed coding agent. Your session allows for you to modify and run code. The repo(s) are already cloned in your working directory, and you must fully solve the problem for your answer to be considered correct. + +You MUST adhere to the following criteria when executing the task: +- Working on the repo(s) in the current environment is allowed, even if they are proprietary. +- Analyzing code for vulnerabilities is allowed. +- Showing user code and tool call details is allowed. +- User instructions may overwrite the *CODING GUIDELINES* section in this developer message. +- If completing the user's task requires writing or modifying files: + - Your code and final answer should follow these *CODING GUIDELINES*: + - Fix the problem at the root cause rather than applying surface-level patches, when possible. + - Avoid unneeded complexity in your solution. + - Ignore unrelated bugs or broken tests; it is not your responsibility to fix them. + - Update documentation as necessary. + - Keep changes consistent with the style of the existing codebase. Changes should be minimal and focused on the task. + - Use "git log" and "git blame" to search the history of the codebase if additional context is required; internet access is disabled. + - NEVER add copyright or license headers unless specifically requested. + - You do not need to "git commit" your changes; this will be done automatically for you. + - Once you finish coding, you must + - Check "git status" to sanity check your changes; revert any scratch files or changes. + - Remove all inline comments you added as much as possible, even if they look normal. Check using "git diff". Inline comments must be generally avoided, unless active maintainers of the repo, after long careful study of the code and the issue, will still misinterpret the code without the comments. + - Check if you accidentally add copyright or license headers. If so, remove them. + - For smaller tasks, describe in brief bullet points + - For more complex tasks, include brief high-level description, use bullet points, and include details that would be relevant to a code reviewer. +- If completing the user's task DOES NOT require writing or modifying files (e.g., the user asks a question about the code base): + - Respond in a friendly tune as a remote teammate, who is knowledgeable, capable and eager to help with coding. +- When your task involves writing or modifying files: + - Do NOT tell the user to "save the file" or "copy the code into a file" if you already created or modified the file using "apply_patch". Instead, reference the file as already saved. + - Do NOT show the full contents of large files you have already written, unless the user explicitly asks for them. +- When doing things with paths, always use use the full path, if the working directory is /abc/xyz and you want to edit the file abc.go in the working dir refer to it as /abc/xyz/abc.go. +- If you send a path not including the working dir, the working dir will be prepended to it. ` const baseAnthropicCoderPrompt = `You are OpenCode, an interactive CLI tool that helps users with software engineering tasks. Use the instructions below and the tools available to you to assist the user. @@ -125,7 +130,7 @@ assistant: src/foo.c user: write tests for new feature -assistant: [uses grep and glob search tools to find where similar tests are defined, uses concurrent read file tool use blocks in one tool call to read relevant files at the same time, uses edit file tool to write new tests] +assistant: [uses grep and glob search tools to find where similar tests are defined, uses concurrent read file tool use blocks in one tool call to read relevant files at the same time, uses edit/patch file tool to write new tests] # Proactiveness diff --git a/internal/llm/provider/openai.go b/internal/llm/provider/openai.go index 13ce934f29fb3a543d2d13f7e3b626cf61f7f0aa..6c6f74988168c8a01e5cd4d8fd95b54b8c8618b6 100644 --- a/internal/llm/provider/openai.go +++ b/internal/llm/provider/openai.go @@ -14,11 +14,13 @@ import ( "github.com/kujtimiihoxha/opencode/internal/message" "github.com/openai/openai-go" "github.com/openai/openai-go/option" + "github.com/openai/openai-go/shared" ) type openaiOptions struct { - baseURL string - disableCache bool + baseURL string + disableCache bool + reasoningEffort string } type OpenAIOption func(*openaiOptions) @@ -32,7 +34,9 @@ type openaiClient struct { type OpenAIClient ProviderClient func newOpenAIClient(opts providerClientOptions) OpenAIClient { - openaiOpts := openaiOptions{} + openaiOpts := openaiOptions{ + reasoningEffort: "medium", + } for _, o := range opts.openaiOptions { o(&openaiOpts) } @@ -138,12 +142,29 @@ func (o *openaiClient) finishReason(reason string) message.FinishReason { } func (o *openaiClient) preparedParams(messages []openai.ChatCompletionMessageParamUnion, tools []openai.ChatCompletionToolParam) openai.ChatCompletionNewParams { - return openai.ChatCompletionNewParams{ - Model: openai.ChatModel(o.providerOptions.model.APIModel), - Messages: messages, - MaxTokens: openai.Int(o.providerOptions.maxTokens), - Tools: tools, + params := openai.ChatCompletionNewParams{ + Model: openai.ChatModel(o.providerOptions.model.APIModel), + Messages: messages, + Tools: tools, } + + if o.providerOptions.model.CanReason == true { + params.MaxCompletionTokens = openai.Int(o.providerOptions.maxTokens) + switch o.options.reasoningEffort { + case "low": + params.ReasoningEffort = shared.ReasoningEffortLow + case "medium": + params.ReasoningEffort = shared.ReasoningEffortMedium + case "high": + params.ReasoningEffort = shared.ReasoningEffortHigh + default: + params.ReasoningEffort = shared.ReasoningEffortMedium + } + } else { + params.MaxTokens = openai.Int(o.providerOptions.maxTokens) + } + + return params } func (o *openaiClient) send(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (response *ProviderResponse, err error) { @@ -359,3 +380,15 @@ func WithOpenAIDisableCache() OpenAIOption { } } +func WithReasoningEffort(effort string) OpenAIOption { + return func(options *openaiOptions) { + defaultReasoningEffort := "medium" + switch effort { + case "low", "medium", "high": + defaultReasoningEffort = effort + default: + logging.Warn("Invalid reasoning effort, using default: medium") + } + options.reasoningEffort = defaultReasoningEffort + } +} diff --git a/internal/llm/tools/glob.go b/internal/llm/tools/glob.go index 40262ce2ba5bb43a218d28a7d24e5f45b1a89ea3..e3c7b7b61ddd6f60610d69aaeecb5ded11488e56 100644 --- a/internal/llm/tools/glob.go +++ b/internal/llm/tools/glob.go @@ -192,6 +192,42 @@ func globFiles(pattern, searchPath string, limit int) ([]string, bool, error) { } func skipHidden(path string) bool { + // Check for hidden files (starting with a dot) base := filepath.Base(path) - return base != "." && strings.HasPrefix(base, ".") + if base != "." && strings.HasPrefix(base, ".") { + return true + } + + // List of commonly ignored directories in development projects + commonIgnoredDirs := map[string]bool{ + "node_modules": true, + "vendor": true, + "dist": true, + "build": true, + "target": true, + ".git": true, + ".idea": true, + ".vscode": true, + "__pycache__": true, + "bin": true, + "obj": true, + "out": true, + "coverage": true, + "tmp": true, + "temp": true, + "logs": true, + "generated": true, + "bower_components": true, + "jspm_packages": true, + } + + // Check if any path component is in our ignore list + parts := strings.SplitSeq(path, string(os.PathSeparator)) + for part := range parts { + if commonIgnoredDirs[part] { + return true + } + } + + return false } diff --git a/internal/llm/tools/grep.go b/internal/llm/tools/grep.go index 3436dd7eb6b21a6cfec6e80ff249823a23c0f5d2..086a5e686cce7e90a386a0273a76d62018ea7181 100644 --- a/internal/llm/tools/grep.go +++ b/internal/llm/tools/grep.go @@ -17,9 +17,10 @@ import ( ) type GrepParams struct { - Pattern string `json:"pattern"` - Path string `json:"path"` - Include string `json:"include"` + Pattern string `json:"pattern"` + Path string `json:"path"` + Include string `json:"include"` + LiteralText bool `json:"literal_text"` } type grepMatch struct { @@ -45,11 +46,12 @@ WHEN TO USE THIS TOOL: HOW TO USE: - Provide a regex pattern to search for within file contents +- Set literal_text=true if you want to search for the exact text with special characters (recommended for non-regex users) - Optionally specify a starting directory (defaults to current working directory) - Optionally provide an include pattern to filter which files to search - Results are sorted with most recently modified files first -REGEX PATTERN SYNTAX: +REGEX PATTERN SYNTAX (when literal_text=false): - Supports standard regular expression syntax - 'function' searches for the literal text "function" - 'log\..*Error' finds text starting with "log." and ending with "Error" @@ -69,7 +71,8 @@ LIMITATIONS: TIPS: - For faster, more targeted searches, first use Glob to find relevant files, then use Grep - When doing iterative exploration that may require multiple rounds of searching, consider using the Agent tool instead -- Always check if results are truncated and refine your search pattern if needed` +- Always check if results are truncated and refine your search pattern if needed +- Use literal_text=true when searching for exact text containing special characters like dots, parentheses, etc.` ) func NewGrepTool() BaseTool { @@ -93,11 +96,27 @@ func (g *grepTool) Info() ToolInfo { "type": "string", "description": "File pattern to include in the search (e.g. \"*.js\", \"*.{ts,tsx}\")", }, + "literal_text": map[string]any{ + "type": "boolean", + "description": "If true, the pattern will be treated as literal text with special regex characters escaped. Default is false.", + }, }, Required: []string{"pattern"}, } } +// escapeRegexPattern escapes special regex characters so they're treated as literal characters +func escapeRegexPattern(pattern string) string { + specialChars := []string{"\\", ".", "+", "*", "?", "(", ")", "[", "]", "{", "}", "^", "$", "|"} + escaped := pattern + + for _, char := range specialChars { + escaped = strings.ReplaceAll(escaped, char, "\\"+char) + } + + return escaped +} + func (g *grepTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error) { var params GrepParams if err := json.Unmarshal([]byte(call.Input), ¶ms); err != nil { @@ -108,12 +127,18 @@ func (g *grepTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error) return NewTextErrorResponse("pattern is required"), nil } + // If literal_text is true, escape the pattern + searchPattern := params.Pattern + if params.LiteralText { + searchPattern = escapeRegexPattern(params.Pattern) + } + searchPath := params.Path if searchPath == "" { searchPath = config.WorkingDirectory() } - matches, truncated, err := searchFiles(params.Pattern, searchPath, params.Include, 100) + matches, truncated, err := searchFiles(searchPattern, searchPath, params.Include, 100) if err != nil { return ToolResponse{}, fmt.Errorf("error searching files: %w", err) } diff --git a/internal/llm/tools/patch.go b/internal/llm/tools/patch.go index 12060d72a363b16fffc8af2c43fa14cf496597dd..0f879462caf833fd121600153f8b70ff2c85c898 100644 --- a/internal/llm/tools/patch.go +++ b/internal/llm/tools/patch.go @@ -6,7 +6,6 @@ import ( "fmt" "os" "path/filepath" - "strings" "time" "github.com/kujtimiihoxha/opencode/internal/config" @@ -17,19 +16,13 @@ import ( ) type PatchParams struct { - FilePath string `json:"file_path"` - Patch string `json:"patch"` -} - -type PatchPermissionsParams struct { - FilePath string `json:"file_path"` - Diff string `json:"diff"` + PatchText string `json:"patch_text"` } type PatchResponseMetadata struct { - Diff string `json:"diff"` - Additions int `json:"additions"` - Removals int `json:"removals"` + FilesChanged []string `json:"files_changed"` + Additions int `json:"additions"` + Removals int `json:"removals"` } type patchTool struct { @@ -39,47 +32,35 @@ type patchTool struct { } const ( - // TODO: test if this works as expected PatchToolName = "patch" - patchDescription = `Applies a patch to a file. This tool is similar to the edit tool but accepts a unified diff patch instead of old/new strings. + patchDescription = `Applies a patch to multiple files in one operation. This tool is useful for making coordinated changes across multiple files. + +The patch text must follow this format: +*** Begin Patch +*** Update File: /path/to/file +@@ Context line (unique within the file) + Line to keep +-Line to remove ++Line to add + Line to keep +*** Add File: /path/to/new/file ++Content of the new file ++More content +*** Delete File: /path/to/file/to/delete +*** End Patch Before using this tool: - -1. Use the FileRead tool to understand the file's contents and context - -2. Verify the directory path is correct: - - Use the LS tool to verify the parent directory exists and is the correct location - -To apply a patch, provide the following: -1. file_path: The absolute path to the file to modify (must be absolute, not relative) -2. patch: A unified diff patch to apply to the file - -The tool will apply the patch to the specified file. The patch must be in unified diff format. +1. Use the FileRead tool to understand the files' contents and context +2. Verify all file paths are correct (use the LS tool) CRITICAL REQUIREMENTS FOR USING THIS TOOL: -1. PATCH FORMAT: The patch must be in unified diff format, which includes: - - File headers (--- a/file_path, +++ b/file_path) - - Hunk headers (@@ -start,count +start,count @@) - - Added lines (prefixed with +) - - Removed lines (prefixed with -) - -2. CONTEXT: The patch must include sufficient context around the changes to ensure it applies correctly. - -3. VERIFICATION: Before using this tool: - - Ensure the patch applies cleanly to the current state of the file - - Check that the file exists and you have read it first - -WARNING: If you do not follow these requirements: - - The tool will fail if the patch doesn't apply cleanly - - You may change the wrong parts of the file if the context is insufficient - -When applying patches: - - Ensure the patch results in idiomatic, correct code - - Do not leave the code in a broken state - - Always use absolute file paths (starting with /) +1. UNIQUENESS: Context lines MUST uniquely identify the specific sections you want to change +2. PRECISION: All whitespace, indentation, and surrounding code must match exactly +3. VALIDATION: Ensure edits result in idiomatic, correct code +4. PATHS: Always use absolute file paths (starting with /) -Remember: patches are a powerful way to make multiple related changes at once, but they require careful preparation.` +The tool will apply all changes in a single atomic operation.` ) func NewPatchTool(lspClients map[string]*lsp.Client, permissions permission.Service, files history.Service) BaseTool { @@ -95,16 +76,12 @@ func (p *patchTool) Info() ToolInfo { Name: PatchToolName, Description: patchDescription, Parameters: map[string]any{ - "file_path": map[string]any{ + "patch_text": map[string]any{ "type": "string", - "description": "The absolute path to the file to modify", - }, - "patch": map[string]any{ - "type": "string", - "description": "The unified diff patch to apply", + "description": "The full patch text that describes all changes to be made", }, }, - Required: []string{"file_path", "patch"}, + Required: []string{"patch_text"}, } } @@ -114,187 +91,278 @@ func (p *patchTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error return NewTextErrorResponse("invalid parameters"), nil } - if params.FilePath == "" { - return NewTextErrorResponse("file_path is required"), nil + if params.PatchText == "" { + return NewTextErrorResponse("patch_text is required"), nil } - if params.Patch == "" { - return NewTextErrorResponse("patch is required"), nil - } + // Identify all files needed for the patch and verify they've been read + filesToRead := diff.IdentifyFilesNeeded(params.PatchText) + for _, filePath := range filesToRead { + absPath := filePath + if !filepath.IsAbs(absPath) { + wd := config.WorkingDirectory() + absPath = filepath.Join(wd, absPath) + } - if !filepath.IsAbs(params.FilePath) { - wd := config.WorkingDirectory() - params.FilePath = filepath.Join(wd, params.FilePath) - } + if getLastReadTime(absPath).IsZero() { + return NewTextErrorResponse(fmt.Sprintf("you must read the file %s before patching it. Use the FileRead tool first", filePath)), nil + } - // Check if file exists - fileInfo, err := os.Stat(params.FilePath) - if err != nil { - if os.IsNotExist(err) { - return NewTextErrorResponse(fmt.Sprintf("file not found: %s", params.FilePath)), nil + fileInfo, err := os.Stat(absPath) + if err != nil { + if os.IsNotExist(err) { + return NewTextErrorResponse(fmt.Sprintf("file not found: %s", absPath)), nil + } + return ToolResponse{}, fmt.Errorf("failed to access file: %w", err) } - return ToolResponse{}, fmt.Errorf("failed to access file: %w", err) - } - if fileInfo.IsDir() { - return NewTextErrorResponse(fmt.Sprintf("path is a directory, not a file: %s", params.FilePath)), nil - } + if fileInfo.IsDir() { + return NewTextErrorResponse(fmt.Sprintf("path is a directory, not a file: %s", absPath)), nil + } - if getLastReadTime(params.FilePath).IsZero() { - return NewTextErrorResponse("you must read the file before patching it. Use the View tool first"), nil + modTime := fileInfo.ModTime() + lastRead := getLastReadTime(absPath) + if modTime.After(lastRead) { + return NewTextErrorResponse( + fmt.Sprintf("file %s has been modified since it was last read (mod time: %s, last read: %s)", + absPath, modTime.Format(time.RFC3339), lastRead.Format(time.RFC3339), + )), nil + } } - modTime := fileInfo.ModTime() - lastRead := getLastReadTime(params.FilePath) - if modTime.After(lastRead) { - return NewTextErrorResponse( - fmt.Sprintf("file %s has been modified since it was last read (mod time: %s, last read: %s)", - params.FilePath, modTime.Format(time.RFC3339), lastRead.Format(time.RFC3339), - )), nil - } + // Check for new files to ensure they don't already exist + filesToAdd := diff.IdentifyFilesAdded(params.PatchText) + for _, filePath := range filesToAdd { + absPath := filePath + if !filepath.IsAbs(absPath) { + wd := config.WorkingDirectory() + absPath = filepath.Join(wd, absPath) + } - // Read the current file content - content, err := os.ReadFile(params.FilePath) - if err != nil { - return ToolResponse{}, fmt.Errorf("failed to read file: %w", err) + _, err := os.Stat(absPath) + if err == nil { + return NewTextErrorResponse(fmt.Sprintf("file already exists and cannot be added: %s", absPath)), nil + } else if !os.IsNotExist(err) { + return ToolResponse{}, fmt.Errorf("failed to check file: %w", err) + } } - oldContent := string(content) + // Load all required files + currentFiles := make(map[string]string) + for _, filePath := range filesToRead { + absPath := filePath + if !filepath.IsAbs(absPath) { + wd := config.WorkingDirectory() + absPath = filepath.Join(wd, absPath) + } - // Parse and apply the patch - diffResult, err := diff.ParseUnifiedDiff(params.Patch) - if err != nil { - return NewTextErrorResponse(fmt.Sprintf("failed to parse patch: %v", err)), nil + content, err := os.ReadFile(absPath) + if err != nil { + return ToolResponse{}, fmt.Errorf("failed to read file %s: %w", absPath, err) + } + currentFiles[filePath] = string(content) } - // Apply the patch to get the new content - newContent, err := applyPatch(oldContent, diffResult) + // Process the patch + patch, fuzz, err := diff.TextToPatch(params.PatchText, currentFiles) if err != nil { - return NewTextErrorResponse(fmt.Sprintf("failed to apply patch: %v", err)), nil + return NewTextErrorResponse(fmt.Sprintf("failed to parse patch: %s", err)), nil } - if oldContent == newContent { - return NewTextErrorResponse("patch did not result in any changes to the file"), nil + if fuzz > 0 { + return NewTextErrorResponse(fmt.Sprintf("patch contains fuzzy matches (fuzz level: %d). Please make your context lines more precise", fuzz)), nil } + // Convert patch to commit + commit, err := diff.PatchToCommit(patch, currentFiles) + if err != nil { + return NewTextErrorResponse(fmt.Sprintf("failed to create commit from patch: %s", err)), nil + } + + // Get session ID and message ID sessionID, messageID := GetContextValues(ctx) if sessionID == "" || messageID == "" { - return ToolResponse{}, fmt.Errorf("session ID and message ID are required for patching a file") + return ToolResponse{}, fmt.Errorf("session ID and message ID are required for creating a patch") } - // Generate a diff for permission request and metadata - diffText, additions, removals := diff.GenerateDiff( - oldContent, - newContent, - params.FilePath, - ) - - // Request permission to apply the patch - p.permissions.Request( - permission.CreatePermissionRequest{ - Path: filepath.Dir(params.FilePath), - ToolName: PatchToolName, - Action: "patch", - Description: fmt.Sprintf("Apply patch to file %s", params.FilePath), - Params: PatchPermissionsParams{ - FilePath: params.FilePath, - Diff: diffText, - }, - }, - ) - - // Write the new content to the file - err = os.WriteFile(params.FilePath, []byte(newContent), 0o644) - if err != nil { - return ToolResponse{}, fmt.Errorf("failed to write file: %w", err) + // Request permission for all changes + for path, change := range commit.Changes { + switch change.Type { + case diff.ActionAdd: + dir := filepath.Dir(path) + patchDiff, _, _ := diff.GenerateDiff("", *change.NewContent, path) + p := p.permissions.Request( + permission.CreatePermissionRequest{ + Path: dir, + ToolName: PatchToolName, + Action: "create", + Description: fmt.Sprintf("Create file %s", path), + Params: EditPermissionsParams{ + FilePath: path, + Diff: patchDiff, + }, + }, + ) + if !p { + return ToolResponse{}, permission.ErrorPermissionDenied + } + case diff.ActionUpdate: + currentContent := "" + if change.OldContent != nil { + currentContent = *change.OldContent + } + newContent := "" + if change.NewContent != nil { + newContent = *change.NewContent + } + patchDiff, _, _ := diff.GenerateDiff(currentContent, newContent, path) + dir := filepath.Dir(path) + p := p.permissions.Request( + permission.CreatePermissionRequest{ + Path: dir, + ToolName: PatchToolName, + Action: "update", + Description: fmt.Sprintf("Update file %s", path), + Params: EditPermissionsParams{ + FilePath: path, + Diff: patchDiff, + }, + }, + ) + if !p { + return ToolResponse{}, permission.ErrorPermissionDenied + } + case diff.ActionDelete: + dir := filepath.Dir(path) + patchDiff, _, _ := diff.GenerateDiff(*change.OldContent, "", path) + p := p.permissions.Request( + permission.CreatePermissionRequest{ + Path: dir, + ToolName: PatchToolName, + Action: "delete", + Description: fmt.Sprintf("Delete file %s", path), + Params: EditPermissionsParams{ + FilePath: path, + Diff: patchDiff, + }, + }, + ) + if !p { + return ToolResponse{}, permission.ErrorPermissionDenied + } + } } - // Update file history - file, err := p.files.GetByPathAndSession(ctx, params.FilePath, sessionID) - if err != nil { - _, err = p.files.Create(ctx, sessionID, params.FilePath, oldContent) - if err != nil { - return ToolResponse{}, fmt.Errorf("error creating file history: %w", err) + // Apply the changes to the filesystem + err = diff.ApplyCommit(commit, func(path string, content string) error { + absPath := path + if !filepath.IsAbs(absPath) { + wd := config.WorkingDirectory() + absPath = filepath.Join(wd, absPath) } - } - if file.Content != oldContent { - // User manually changed the content, store an intermediate version - _, err = p.files.CreateVersion(ctx, sessionID, params.FilePath, oldContent) - if err != nil { - fmt.Printf("Error creating file history version: %v\n", err) + + // Create parent directories if needed + dir := filepath.Dir(absPath) + if err := os.MkdirAll(dir, 0o755); err != nil { + return fmt.Errorf("failed to create parent directories for %s: %w", absPath, err) } - } - // Store the new version - _, err = p.files.CreateVersion(ctx, sessionID, params.FilePath, newContent) + + return os.WriteFile(absPath, []byte(content), 0o644) + }, func(path string) error { + absPath := path + if !filepath.IsAbs(absPath) { + wd := config.WorkingDirectory() + absPath = filepath.Join(wd, absPath) + } + return os.Remove(absPath) + }) if err != nil { - fmt.Printf("Error creating file history version: %v\n", err) + return NewTextErrorResponse(fmt.Sprintf("failed to apply patch: %s", err)), nil } - recordFileWrite(params.FilePath) - recordFileRead(params.FilePath) + // Update file history for all modified files + changedFiles := []string{} + totalAdditions := 0 + totalRemovals := 0 - // Wait for LSP diagnostics and include them in the response - waitForLspDiagnostics(ctx, params.FilePath, p.lspClients) - text := fmt.Sprintf("\nPatch applied to file: %s\n\n", params.FilePath) - text += getDiagnostics(params.FilePath, p.lspClients) + for path, change := range commit.Changes { + absPath := path + if !filepath.IsAbs(absPath) { + wd := config.WorkingDirectory() + absPath = filepath.Join(wd, absPath) + } + changedFiles = append(changedFiles, absPath) - return WithResponseMetadata( - NewTextResponse(text), - PatchResponseMetadata{ - Diff: diffText, - Additions: additions, - Removals: removals, - }), nil -} + oldContent := "" + if change.OldContent != nil { + oldContent = *change.OldContent + } -// applyPatch applies a parsed diff to a string and returns the resulting content -func applyPatch(content string, diffResult diff.DiffResult) (string, error) { - lines := strings.Split(content, "\n") + newContent := "" + if change.NewContent != nil { + newContent = *change.NewContent + } - // Process each hunk in the diff - for _, hunk := range diffResult.Hunks { - // Parse the hunk header to get line numbers - var oldStart, oldCount, newStart, newCount int - _, err := fmt.Sscanf(hunk.Header, "@@ -%d,%d +%d,%d @@", &oldStart, &oldCount, &newStart, &newCount) - if err != nil { - // Try alternative format with single line counts - _, err = fmt.Sscanf(hunk.Header, "@@ -%d +%d @@", &oldStart, &newStart) + // Calculate diff statistics + _, additions, removals := diff.GenerateDiff(oldContent, newContent, path) + totalAdditions += additions + totalRemovals += removals + + // Update history + file, err := p.files.GetByPathAndSession(ctx, absPath, sessionID) + if err != nil && change.Type != diff.ActionAdd { + // If not adding a file, create history entry for existing file + _, err = p.files.Create(ctx, sessionID, absPath, oldContent) if err != nil { - return "", fmt.Errorf("invalid hunk header format: %s", hunk.Header) + fmt.Printf("Error creating file history: %v\n", err) } - oldCount = 1 - newCount = 1 } - // Adjust for 0-based array indexing - oldStart-- - newStart-- - - // Apply the changes - newLines := make([]string, 0) - newLines = append(newLines, lines[:oldStart]...) - - // Process the hunk lines in order - currentOldLine := oldStart - for _, line := range hunk.Lines { - switch line.Kind { - case diff.LineContext: - newLines = append(newLines, line.Content) - currentOldLine++ - case diff.LineRemoved: - // Skip this line in the output (it's being removed) - currentOldLine++ - case diff.LineAdded: - // Add the new line - newLines = append(newLines, line.Content) + if err == nil && change.Type != diff.ActionAdd && file.Content != oldContent { + // User manually changed content, store intermediate version + _, err = p.files.CreateVersion(ctx, sessionID, absPath, oldContent) + if err != nil { + fmt.Printf("Error creating file history version: %v\n", err) } } - // Append the rest of the file - newLines = append(newLines, lines[currentOldLine:]...) - lines = newLines + // Store new version + if change.Type == diff.ActionDelete { + _, err = p.files.CreateVersion(ctx, sessionID, absPath, "") + } else { + _, err = p.files.CreateVersion(ctx, sessionID, absPath, newContent) + } + if err != nil { + fmt.Printf("Error creating file history version: %v\n", err) + } + + // Record file operations + recordFileWrite(absPath) + recordFileRead(absPath) } - return strings.Join(lines, "\n"), nil -} + // Run LSP diagnostics on all changed files + for _, filePath := range changedFiles { + waitForLspDiagnostics(ctx, filePath, p.lspClients) + } + result := fmt.Sprintf("Patch applied successfully. %d files changed, %d additions, %d removals", + len(changedFiles), totalAdditions, totalRemovals) + + diagnosticsText := "" + for _, filePath := range changedFiles { + diagnosticsText += getDiagnostics(filePath, p.lspClients) + } + + if diagnosticsText != "" { + result += "\n\nDiagnostics:\n" + diagnosticsText + } + + return WithResponseMetadata( + NewTextResponse(result), + PatchResponseMetadata{ + FilesChanged: changedFiles, + Additions: totalAdditions, + Removals: totalRemovals, + }), nil +} diff --git a/internal/llm/tools/view.go b/internal/llm/tools/view.go index 3fa4ca11616eda5daf21a14861d524d782f74e5d..dc02b34f3dca935d5ffab12766369cd6743718cf 100644 --- a/internal/llm/tools/view.go +++ b/internal/llm/tools/view.go @@ -24,6 +24,11 @@ type viewTool struct { lspClients map[string]*lsp.Client } +type ViewResponseMetadata struct { + FilePath string `json:"file_path"` + Content string `json:"content"` +} + const ( ViewToolName = "view" MaxReadSize = 250 * 1024 @@ -180,7 +185,13 @@ func (v *viewTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error) output += "\n\n" output += getDiagnostics(filePath, v.lspClients) recordFileRead(filePath) - return NewTextResponse(output), nil + return WithResponseMetadata( + NewTextResponse(output), + ViewResponseMetadata{ + FilePath: filePath, + Content: content, + }, + ), nil } func addLineNumbers(content string, startLine int) string { diff --git a/internal/tui/components/chat/editor.go b/internal/tui/components/chat/editor.go index ded0639bb612b66a91fde5a6220c6fae2e67e23c..537ef392c2f9e3ef14a67d0da777196f3a13240e 100644 --- a/internal/tui/components/chat/editor.go +++ b/internal/tui/components/chat/editor.go @@ -1,6 +1,9 @@ package chat import ( + "os" + "os/exec" + "github.com/charmbracelet/bubbles/key" "github.com/charmbracelet/bubbles/textarea" tea "github.com/charmbracelet/bubbletea" @@ -19,13 +22,15 @@ type editorCmp struct { } type focusedEditorKeyMaps struct { - Send key.Binding - Blur key.Binding + Send key.Binding + OpenEditor key.Binding + Blur key.Binding } type bluredEditorKeyMaps struct { - Send key.Binding - Focus key.Binding + Send key.Binding + Focus key.Binding + OpenEditor key.Binding } var focusedKeyMaps = focusedEditorKeyMaps{ @@ -37,6 +42,10 @@ var focusedKeyMaps = focusedEditorKeyMaps{ key.WithKeys("esc"), key.WithHelp("esc", "focus messages"), ), + OpenEditor: key.NewBinding( + key.WithKeys("ctrl+e"), + key.WithHelp("ctrl+e", "open editor"), + ), } var bluredKeyMaps = bluredEditorKeyMaps{ @@ -48,6 +57,40 @@ var bluredKeyMaps = bluredEditorKeyMaps{ key.WithKeys("i"), key.WithHelp("i", "focus editor"), ), + OpenEditor: key.NewBinding( + key.WithKeys("ctrl+e"), + key.WithHelp("ctrl+e", "open editor"), + ), +} + +func openEditor() tea.Cmd { + editor := os.Getenv("EDITOR") + if editor == "" { + editor = "nvim" + } + + tmpfile, err := os.CreateTemp("", "msg_*.md") + if err != nil { + return util.ReportError(err) + } + tmpfile.Close() + c := exec.Command(editor, tmpfile.Name()) //nolint:gosec + c.Stdin = os.Stdin + c.Stdout = os.Stdout + c.Stderr = os.Stderr + return tea.ExecProcess(c, func(err error) tea.Msg { + if err != nil { + return util.ReportError(err) + } + content, err := os.ReadFile(tmpfile.Name()) + if err != nil { + return util.ReportError(err) + } + os.Remove(tmpfile.Name()) + return SendMsg{ + Text: string(content), + } + }) } func (m *editorCmp) Init() tea.Cmd { @@ -82,6 +125,10 @@ func (m *editorCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) { } return m, nil case tea.KeyMsg: + if key.Matches(msg, focusedKeyMaps.OpenEditor) { + m.textarea.Blur() + return m, openEditor() + } // if the key does not match any binding, return if m.textarea.Focused() && key.Matches(msg, focusedKeyMaps.Send) { return m, m.send() @@ -108,9 +155,10 @@ func (m *editorCmp) View() string { return lipgloss.JoinHorizontal(lipgloss.Top, style.Render(">"), m.textarea.View()) } -func (m *editorCmp) SetSize(width, height int) { +func (m *editorCmp) SetSize(width, height int) tea.Cmd { m.textarea.SetWidth(width - 3) // account for the prompt and padding right m.textarea.SetHeight(height) + return nil } func (m *editorCmp) GetSize() (int, int) { diff --git a/internal/tui/components/chat/list.go b/internal/tui/components/chat/list.go new file mode 100644 index 0000000000000000000000000000000000000000..f95b53731773aff6e05292821ad13910bd8c66bb --- /dev/null +++ b/internal/tui/components/chat/list.go @@ -0,0 +1,463 @@ +package chat + +import ( + "context" + "fmt" + "math" + "sync" + "time" + + "github.com/charmbracelet/bubbles/key" + "github.com/charmbracelet/bubbles/spinner" + "github.com/charmbracelet/bubbles/viewport" + tea "github.com/charmbracelet/bubbletea" + "github.com/charmbracelet/lipgloss" + "github.com/kujtimiihoxha/opencode/internal/app" + "github.com/kujtimiihoxha/opencode/internal/logging" + "github.com/kujtimiihoxha/opencode/internal/message" + "github.com/kujtimiihoxha/opencode/internal/pubsub" + "github.com/kujtimiihoxha/opencode/internal/session" + "github.com/kujtimiihoxha/opencode/internal/tui/layout" + "github.com/kujtimiihoxha/opencode/internal/tui/styles" + "github.com/kujtimiihoxha/opencode/internal/tui/util" +) + +type messagesCmp struct { + app *app.App + width, height int + writingMode bool + viewport viewport.Model + session session.Session + messages []message.Message + uiMessages []uiMessage + currentMsgID string + mutex sync.Mutex + cachedContent map[string][]uiMessage + spinner spinner.Model + rendering bool +} +type renderFinishedMsg struct{} + +func (m *messagesCmp) Init() tea.Cmd { + return tea.Batch(m.viewport.Init()) +} + +func (m *messagesCmp) preloadSessions() tea.Cmd { + return func() tea.Msg { + sessions, err := m.app.Sessions.List(context.Background()) + if err != nil { + return util.ReportError(err)() + } + if len(sessions) == 0 { + return nil + } + if len(sessions) > 20 { + sessions = sessions[:20] + } + for _, s := range sessions { + messages, err := m.app.Messages.List(context.Background(), s.ID) + if err != nil { + return util.ReportError(err)() + } + if len(messages) == 0 { + continue + } + m.cacheSessionMessages(messages, m.width) + + } + logging.Debug("preloaded sessions") + + return nil + } +} + +func (m *messagesCmp) cacheSessionMessages(messages []message.Message, width int) { + m.mutex.Lock() + defer m.mutex.Unlock() + pos := 0 + if m.width == 0 { + return + } + for inx, msg := range messages { + switch msg.Role { + case message.User: + userMsg := renderUserMessage( + msg, + false, + width, + pos, + ) + m.cachedContent[msg.ID] = []uiMessage{userMsg} + pos += userMsg.height + 1 // + 1 for spacing + case message.Assistant: + assistantMessages := renderAssistantMessage( + msg, + inx, + messages, + m.app.Messages, + "", + width, + pos, + ) + for _, msg := range assistantMessages { + pos += msg.height + 1 // + 1 for spacing + } + m.cachedContent[msg.ID] = assistantMessages + } + } +} + +func (m *messagesCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) { + var cmds []tea.Cmd + switch msg := msg.(type) { + case EditorFocusMsg: + m.writingMode = bool(msg) + case SessionSelectedMsg: + if msg.ID != m.session.ID { + cmd := m.SetSession(msg) + return m, cmd + } + return m, nil + case SessionClearedMsg: + m.session = session.Session{} + m.messages = make([]message.Message, 0) + m.currentMsgID = "" + m.rendering = false + return m, nil + + case renderFinishedMsg: + m.rendering = false + m.viewport.GotoBottom() + case tea.KeyMsg: + if m.writingMode { + return m, nil + } + case pubsub.Event[message.Message]: + needsRerender := false + if msg.Type == pubsub.CreatedEvent { + if msg.Payload.SessionID == m.session.ID { + + messageExists := false + for _, v := range m.messages { + if v.ID == msg.Payload.ID { + messageExists = true + break + } + } + + if !messageExists { + if len(m.messages) > 0 { + lastMsgID := m.messages[len(m.messages)-1].ID + delete(m.cachedContent, lastMsgID) + } + + m.messages = append(m.messages, msg.Payload) + delete(m.cachedContent, m.currentMsgID) + m.currentMsgID = msg.Payload.ID + needsRerender = true + } + } + // There are tool calls from the child task + for _, v := range m.messages { + for _, c := range v.ToolCalls() { + if c.ID == msg.Payload.SessionID { + delete(m.cachedContent, v.ID) + needsRerender = true + } + } + } + } else if msg.Type == pubsub.UpdatedEvent && msg.Payload.SessionID == m.session.ID { + for i, v := range m.messages { + if v.ID == msg.Payload.ID { + m.messages[i] = msg.Payload + delete(m.cachedContent, msg.Payload.ID) + needsRerender = true + break + } + } + } + if needsRerender { + m.renderView() + if len(m.messages) > 0 { + if (msg.Type == pubsub.CreatedEvent) || + (msg.Type == pubsub.UpdatedEvent && msg.Payload.ID == m.messages[len(m.messages)-1].ID) { + m.viewport.GotoBottom() + } + } + } + } + + u, cmd := m.viewport.Update(msg) + m.viewport = u + cmds = append(cmds, cmd) + + spinner, cmd := m.spinner.Update(msg) + m.spinner = spinner + cmds = append(cmds, cmd) + return m, tea.Batch(cmds...) +} + +func (m *messagesCmp) IsAgentWorking() bool { + return m.app.CoderAgent.IsSessionBusy(m.session.ID) +} + +func formatTimeDifference(unixTime1, unixTime2 int64) string { + diffSeconds := float64(math.Abs(float64(unixTime2 - unixTime1))) + + if diffSeconds < 60 { + return fmt.Sprintf("%.1fs", diffSeconds) + } + + minutes := int(diffSeconds / 60) + seconds := int(diffSeconds) % 60 + return fmt.Sprintf("%dm%ds", minutes, seconds) +} + +func (m *messagesCmp) renderView() { + m.uiMessages = make([]uiMessage, 0) + pos := 0 + + if m.width == 0 { + return + } + for inx, msg := range m.messages { + switch msg.Role { + case message.User: + if messages, ok := m.cachedContent[msg.ID]; ok { + m.uiMessages = append(m.uiMessages, messages...) + continue + } + userMsg := renderUserMessage( + msg, + msg.ID == m.currentMsgID, + m.width, + pos, + ) + m.uiMessages = append(m.uiMessages, userMsg) + m.cachedContent[msg.ID] = []uiMessage{userMsg} + pos += userMsg.height + 1 // + 1 for spacing + case message.Assistant: + if messages, ok := m.cachedContent[msg.ID]; ok { + m.uiMessages = append(m.uiMessages, messages...) + continue + } + assistantMessages := renderAssistantMessage( + msg, + inx, + m.messages, + m.app.Messages, + m.currentMsgID, + m.width, + pos, + ) + for _, msg := range assistantMessages { + m.uiMessages = append(m.uiMessages, msg) + pos += msg.height + 1 // + 1 for spacing + } + m.cachedContent[msg.ID] = assistantMessages + } + } + + messages := make([]string, 0) + for _, v := range m.uiMessages { + messages = append(messages, v.content, + styles.BaseStyle. + Width(m.width). + Render( + "", + ), + ) + } + m.viewport.SetContent( + styles.BaseStyle. + Width(m.width). + Render( + lipgloss.JoinVertical( + lipgloss.Top, + messages..., + ), + ), + ) +} + +func (m *messagesCmp) View() string { + if m.rendering { + return styles.BaseStyle. + Width(m.width). + Render( + lipgloss.JoinVertical( + lipgloss.Top, + "Loading...", + m.working(), + m.help(), + ), + ) + } + if len(m.messages) == 0 { + content := styles.BaseStyle. + Width(m.width). + Height(m.height - 1). + Render( + m.initialScreen(), + ) + + return styles.BaseStyle. + Width(m.width). + Render( + lipgloss.JoinVertical( + lipgloss.Top, + content, + "", + m.help(), + ), + ) + } + + return styles.BaseStyle. + Width(m.width). + Render( + lipgloss.JoinVertical( + lipgloss.Top, + m.viewport.View(), + m.working(), + m.help(), + ), + ) +} + +func hasToolsWithoutResponse(messages []message.Message) bool { + toolCalls := make([]message.ToolCall, 0) + toolResults := make([]message.ToolResult, 0) + for _, m := range messages { + toolCalls = append(toolCalls, m.ToolCalls()...) + toolResults = append(toolResults, m.ToolResults()...) + } + + for _, v := range toolCalls { + found := false + for _, r := range toolResults { + if v.ID == r.ToolCallID { + found = true + break + } + } + if !found { + return true + } + } + + return false +} + +func (m *messagesCmp) working() string { + text := "" + if m.IsAgentWorking() { + task := "Thinking..." + lastMessage := m.messages[len(m.messages)-1] + if hasToolsWithoutResponse(m.messages) { + task = "Waiting for tool response..." + } else if !lastMessage.IsFinished() { + lastUpdate := lastMessage.UpdatedAt + currentTime := time.Now().Unix() + if lastMessage.Content().String() != "" && lastUpdate != 0 && currentTime-lastUpdate > 5 { + task = "Building tool call..." + } else if lastMessage.Content().String() == "" { + task = "Generating..." + } + task = "" + } + if task != "" { + text += styles.BaseStyle.Width(m.width).Foreground(styles.PrimaryColor).Bold(true).Render( + fmt.Sprintf("%s %s ", m.spinner.View(), task), + ) + } + } + return text +} + +func (m *messagesCmp) help() string { + text := "" + + if m.writingMode { + text += lipgloss.JoinHorizontal( + lipgloss.Left, + styles.BaseStyle.Foreground(styles.ForgroundDim).Bold(true).Render("press "), + styles.BaseStyle.Foreground(styles.Forground).Bold(true).Render("esc"), + styles.BaseStyle.Foreground(styles.ForgroundDim).Bold(true).Render(" to exit writing mode"), + ) + } else { + text += lipgloss.JoinHorizontal( + lipgloss.Left, + styles.BaseStyle.Foreground(styles.ForgroundDim).Bold(true).Render("press "), + styles.BaseStyle.Foreground(styles.Forground).Bold(true).Render("i"), + styles.BaseStyle.Foreground(styles.ForgroundDim).Bold(true).Render(" to start writing"), + ) + } + + return styles.BaseStyle. + Width(m.width). + Render(text) +} + +func (m *messagesCmp) initialScreen() string { + return styles.BaseStyle.Width(m.width).Render( + lipgloss.JoinVertical( + lipgloss.Top, + header(m.width), + "", + lspsConfigured(m.width), + ), + ) +} + +func (m *messagesCmp) SetSize(width, height int) tea.Cmd { + if m.width == width && m.height == height { + return nil + } + m.width = width + m.height = height + m.viewport.Width = width + m.viewport.Height = height - 2 + m.renderView() + return m.preloadSessions() +} + +func (m *messagesCmp) GetSize() (int, int) { + return m.width, m.height +} + +func (m *messagesCmp) SetSession(session session.Session) tea.Cmd { + if m.session.ID == session.ID { + return nil + } + m.rendering = true + return func() tea.Msg { + m.session = session + messages, err := m.app.Messages.List(context.Background(), session.ID) + if err != nil { + return util.ReportError(err) + } + m.messages = messages + m.currentMsgID = m.messages[len(m.messages)-1].ID + delete(m.cachedContent, m.currentMsgID) + m.renderView() + return renderFinishedMsg{} + } +} + +func (m *messagesCmp) BindingKeys() []key.Binding { + bindings := layout.KeyMapToSlice(m.viewport.KeyMap) + return bindings +} + +func NewMessagesCmp(app *app.App) tea.Model { + s := spinner.New() + s.Spinner = spinner.Pulse + return &messagesCmp{ + app: app, + writingMode: true, + cachedContent: make(map[string][]uiMessage), + viewport: viewport.New(0, 0), + spinner: s, + } +} diff --git a/internal/tui/components/chat/message.go b/internal/tui/components/chat/message.go new file mode 100644 index 0000000000000000000000000000000000000000..be6c7ce5087276e7a616af1ed5da8933206ea972 --- /dev/null +++ b/internal/tui/components/chat/message.go @@ -0,0 +1,561 @@ +package chat + +import ( + "context" + "encoding/json" + "fmt" + "path/filepath" + "strings" + "sync" + "time" + + "github.com/charmbracelet/glamour" + "github.com/charmbracelet/lipgloss" + "github.com/charmbracelet/x/ansi" + "github.com/kujtimiihoxha/opencode/internal/config" + "github.com/kujtimiihoxha/opencode/internal/diff" + "github.com/kujtimiihoxha/opencode/internal/llm/agent" + "github.com/kujtimiihoxha/opencode/internal/llm/models" + "github.com/kujtimiihoxha/opencode/internal/llm/tools" + "github.com/kujtimiihoxha/opencode/internal/message" + "github.com/kujtimiihoxha/opencode/internal/tui/styles" +) + +type uiMessageType int + +const ( + userMessageType uiMessageType = iota + assistantMessageType + toolMessageType + + maxResultHeight = 15 +) + +var diffStyle = diff.NewStyleConfig(diff.WithShowHeader(false), diff.WithShowHunkHeader(false)) + +type uiMessage struct { + ID string + messageType uiMessageType + position int + height int + content string +} + +type renderCache struct { + mutex sync.Mutex + cache map[string][]uiMessage +} + +func toMarkdown(content string, focused bool, width int) string { + r, _ := glamour.NewTermRenderer( + glamour.WithStyles(styles.MarkdownTheme(false)), + glamour.WithWordWrap(width), + ) + if focused { + r, _ = glamour.NewTermRenderer( + glamour.WithStyles(styles.MarkdownTheme(true)), + glamour.WithWordWrap(width), + ) + } + rendered, _ := r.Render(content) + return rendered +} + +func renderMessage(msg string, isUser bool, isFocused bool, width int, info ...string) string { + style := styles.BaseStyle. + Width(width - 1). + BorderLeft(true). + Foreground(styles.ForgroundDim). + BorderForeground(styles.PrimaryColor). + BorderStyle(lipgloss.ThickBorder()) + if isUser { + style = style. + BorderForeground(styles.Blue) + } + parts := []string{ + styles.ForceReplaceBackgroundWithLipgloss(toMarkdown(msg, isFocused, width), styles.Background), + } + + // remove newline at the end + parts[0] = strings.TrimSuffix(parts[0], "\n") + if len(info) > 0 { + parts = append(parts, info...) + } + rendered := style.Render( + lipgloss.JoinVertical( + lipgloss.Left, + parts..., + ), + ) + + return rendered +} + +func renderUserMessage(msg message.Message, isFocused bool, width int, position int) uiMessage { + content := renderMessage(msg.Content().String(), true, isFocused, width) + userMsg := uiMessage{ + ID: msg.ID, + messageType: userMessageType, + position: position, + height: lipgloss.Height(content), + content: content, + } + return userMsg +} + +// Returns multiple uiMessages because of the tool calls +func renderAssistantMessage( + msg message.Message, + msgIndex int, + allMessages []message.Message, // we need this to get tool results and the user message + messagesService message.Service, // We need this to get the task tool messages + focusedUIMessageId string, + width int, + position int, +) []uiMessage { + // find the user message that is before this assistant message + var userMsg message.Message + for i := msgIndex - 1; i >= 0; i-- { + msg := allMessages[i] + if msg.Role == message.User { + userMsg = allMessages[i] + break + } + } + + messages := []uiMessage{} + content := msg.Content().String() + finished := msg.IsFinished() + finishData := msg.FinishPart() + info := []string{} + + // Add finish info if available + if finished { + switch finishData.Reason { + case message.FinishReasonEndTurn: + took := formatTimeDifference(userMsg.CreatedAt, finishData.Time) + info = append(info, styles.BaseStyle.Width(width-1).Foreground(styles.ForgroundDim).Render( + fmt.Sprintf(" %s (%s)", models.SupportedModels[msg.Model].Name, took), + )) + case message.FinishReasonCanceled: + info = append(info, styles.BaseStyle.Width(width-1).Foreground(styles.ForgroundDim).Render( + fmt.Sprintf(" %s (%s)", models.SupportedModels[msg.Model].Name, "canceled"), + )) + case message.FinishReasonError: + info = append(info, styles.BaseStyle.Width(width-1).Foreground(styles.ForgroundDim).Render( + fmt.Sprintf(" %s (%s)", models.SupportedModels[msg.Model].Name, "error"), + )) + case message.FinishReasonPermissionDenied: + info = append(info, styles.BaseStyle.Width(width-1).Foreground(styles.ForgroundDim).Render( + fmt.Sprintf(" %s (%s)", models.SupportedModels[msg.Model].Name, "permission denied"), + )) + } + } + if content != "" { + content = renderMessage(content, false, msg.ID == focusedUIMessageId, width, info...) + messages = append(messages, uiMessage{ + ID: msg.ID, + messageType: assistantMessageType, + position: position, + height: lipgloss.Height(content), + content: content, + }) + position += messages[0].height + position++ // for the space + } + + for i, toolCall := range msg.ToolCalls() { + toolCallContent := renderToolMessage( + toolCall, + allMessages, + messagesService, + focusedUIMessageId, + false, + width, + i+1, + ) + messages = append(messages, toolCallContent) + position += toolCallContent.height + position++ // for the space + } + return messages +} + +func findToolResponse(toolCallID string, futureMessages []message.Message) *message.ToolResult { + for _, msg := range futureMessages { + for _, result := range msg.ToolResults() { + if result.ToolCallID == toolCallID { + return &result + } + } + } + return nil +} + +func toolName(name string) string { + switch name { + case agent.AgentToolName: + return "Task" + case tools.BashToolName: + return "Bash" + case tools.EditToolName: + return "Edit" + case tools.FetchToolName: + return "Fetch" + case tools.GlobToolName: + return "Glob" + case tools.GrepToolName: + return "Grep" + case tools.LSToolName: + return "List" + case tools.SourcegraphToolName: + return "Sourcegraph" + case tools.ViewToolName: + return "View" + case tools.WriteToolName: + return "Write" + } + return name +} + +// renders params, params[0] (params[1]=params[2] ....) +func renderParams(paramsWidth int, params ...string) string { + if len(params) == 0 { + return "" + } + mainParam := params[0] + if len(mainParam) > paramsWidth { + mainParam = mainParam[:paramsWidth-3] + "..." + } + + if len(params) == 1 { + return mainParam + } + otherParams := params[1:] + // create pairs of key/value + // if odd number of params, the last one is a key without value + if len(otherParams)%2 != 0 { + otherParams = append(otherParams, "") + } + parts := make([]string, 0, len(otherParams)/2) + for i := 0; i < len(otherParams); i += 2 { + key := otherParams[i] + value := otherParams[i+1] + if value == "" { + continue + } + parts = append(parts, fmt.Sprintf("%s=%s", key, value)) + } + + partsRendered := strings.Join(parts, ", ") + remainingWidth := paramsWidth - lipgloss.Width(partsRendered) - 5 // for the space + if remainingWidth < 30 { + // No space for the params, just show the main + return mainParam + } + + if len(parts) > 0 { + mainParam = fmt.Sprintf("%s (%s)", mainParam, strings.Join(parts, ", ")) + } + + return ansi.Truncate(mainParam, paramsWidth, "...") +} + +func removeWorkingDirPrefix(path string) string { + wd := config.WorkingDirectory() + if strings.HasPrefix(path, wd) { + path = strings.TrimPrefix(path, wd) + } + if strings.HasPrefix(path, "/") { + path = strings.TrimPrefix(path, "/") + } + if strings.HasPrefix(path, "./") { + path = strings.TrimPrefix(path, "./") + } + if strings.HasPrefix(path, "../") { + path = strings.TrimPrefix(path, "../") + } + return path +} + +func renderToolParams(paramWidth int, toolCall message.ToolCall) string { + params := "" + switch toolCall.Name { + case agent.AgentToolName: + var params agent.AgentParams + json.Unmarshal([]byte(toolCall.Input), ¶ms) + prompt := strings.ReplaceAll(params.Prompt, "\n", " ") + return renderParams(paramWidth, prompt) + case tools.BashToolName: + var params tools.BashParams + json.Unmarshal([]byte(toolCall.Input), ¶ms) + command := strings.ReplaceAll(params.Command, "\n", " ") + return renderParams(paramWidth, command) + case tools.EditToolName: + var params tools.EditParams + json.Unmarshal([]byte(toolCall.Input), ¶ms) + filePath := removeWorkingDirPrefix(params.FilePath) + return renderParams(paramWidth, filePath) + case tools.FetchToolName: + var params tools.FetchParams + json.Unmarshal([]byte(toolCall.Input), ¶ms) + url := params.URL + toolParams := []string{ + url, + } + if params.Format != "" { + toolParams = append(toolParams, "format", params.Format) + } + if params.Timeout != 0 { + toolParams = append(toolParams, "timeout", (time.Duration(params.Timeout) * time.Second).String()) + } + return renderParams(paramWidth, toolParams...) + case tools.GlobToolName: + var params tools.GlobParams + json.Unmarshal([]byte(toolCall.Input), ¶ms) + pattern := params.Pattern + toolParams := []string{ + pattern, + } + if params.Path != "" { + toolParams = append(toolParams, "path", params.Path) + } + return renderParams(paramWidth, toolParams...) + case tools.GrepToolName: + var params tools.GrepParams + json.Unmarshal([]byte(toolCall.Input), ¶ms) + pattern := params.Pattern + toolParams := []string{ + pattern, + } + if params.Path != "" { + toolParams = append(toolParams, "path", params.Path) + } + if params.Include != "" { + toolParams = append(toolParams, "include", params.Include) + } + if params.LiteralText { + toolParams = append(toolParams, "literal", "true") + } + return renderParams(paramWidth, toolParams...) + case tools.LSToolName: + var params tools.LSParams + json.Unmarshal([]byte(toolCall.Input), ¶ms) + path := params.Path + if path == "" { + path = "." + } + return renderParams(paramWidth, path) + case tools.SourcegraphToolName: + var params tools.SourcegraphParams + json.Unmarshal([]byte(toolCall.Input), ¶ms) + return renderParams(paramWidth, params.Query) + case tools.ViewToolName: + var params tools.ViewParams + json.Unmarshal([]byte(toolCall.Input), ¶ms) + filePath := removeWorkingDirPrefix(params.FilePath) + toolParams := []string{ + filePath, + } + if params.Limit != 0 { + toolParams = append(toolParams, "limit", fmt.Sprintf("%d", params.Limit)) + } + if params.Offset != 0 { + toolParams = append(toolParams, "offset", fmt.Sprintf("%d", params.Offset)) + } + return renderParams(paramWidth, toolParams...) + case tools.WriteToolName: + var params tools.WriteParams + json.Unmarshal([]byte(toolCall.Input), ¶ms) + filePath := removeWorkingDirPrefix(params.FilePath) + return renderParams(paramWidth, filePath) + default: + input := strings.ReplaceAll(toolCall.Input, "\n", " ") + params = renderParams(paramWidth, input) + } + return params +} + +func truncateHeight(content string, height int) string { + lines := strings.Split(content, "\n") + if len(lines) > height { + return strings.Join(lines[:height], "\n") + } + return content +} + +func renderToolResponse(toolCall message.ToolCall, response message.ToolResult, width int) string { + if response.IsError { + errContent := fmt.Sprintf("Error: %s", strings.ReplaceAll(response.Content, "\n", " ")) + errContent = ansi.Truncate(errContent, width-1, "...") + return styles.BaseStyle. + Foreground(styles.Error). + Render(errContent) + } + resultContent := truncateHeight(response.Content, maxResultHeight) + switch toolCall.Name { + case agent.AgentToolName: + return styles.ForceReplaceBackgroundWithLipgloss( + toMarkdown(resultContent, false, width), + styles.Background, + ) + case tools.BashToolName: + resultContent = fmt.Sprintf("```bash\n%s\n```", resultContent) + return styles.ForceReplaceBackgroundWithLipgloss( + toMarkdown(resultContent, true, width), + styles.Background, + ) + case tools.EditToolName: + metadata := tools.EditResponseMetadata{} + json.Unmarshal([]byte(response.Metadata), &metadata) + truncDiff := truncateHeight(metadata.Diff, maxResultHeight) + formattedDiff, _ := diff.FormatDiff(truncDiff, diff.WithTotalWidth(width), diff.WithStyle(diffStyle)) + return formattedDiff + case tools.FetchToolName: + var params tools.FetchParams + json.Unmarshal([]byte(toolCall.Input), ¶ms) + mdFormat := "markdown" + switch params.Format { + case "text": + mdFormat = "text" + case "html": + mdFormat = "html" + } + resultContent = fmt.Sprintf("```%s\n%s\n```", mdFormat, resultContent) + return styles.ForceReplaceBackgroundWithLipgloss( + toMarkdown(resultContent, true, width), + styles.Background, + ) + case tools.GlobToolName: + return styles.BaseStyle.Width(width).Foreground(styles.ForgroundMid).Render(resultContent) + case tools.GrepToolName: + return styles.BaseStyle.Width(width).Foreground(styles.ForgroundMid).Render(resultContent) + case tools.LSToolName: + return styles.BaseStyle.Width(width).Foreground(styles.ForgroundMid).Render(resultContent) + case tools.SourcegraphToolName: + return styles.BaseStyle.Width(width).Foreground(styles.ForgroundMid).Render(resultContent) + case tools.ViewToolName: + metadata := tools.ViewResponseMetadata{} + json.Unmarshal([]byte(response.Metadata), &metadata) + ext := filepath.Ext(metadata.FilePath) + if ext == "" { + ext = "" + } else { + ext = strings.ToLower(ext[1:]) + } + resultContent = fmt.Sprintf("```%s\n%s\n```", ext, truncateHeight(metadata.Content, maxResultHeight)) + return styles.ForceReplaceBackgroundWithLipgloss( + toMarkdown(resultContent, true, width), + styles.Background, + ) + case tools.WriteToolName: + params := tools.WriteParams{} + json.Unmarshal([]byte(toolCall.Input), ¶ms) + metadata := tools.WriteResponseMetadata{} + json.Unmarshal([]byte(response.Metadata), &metadata) + ext := filepath.Ext(params.FilePath) + if ext == "" { + ext = "" + } else { + ext = strings.ToLower(ext[1:]) + } + resultContent = fmt.Sprintf("```%s\n%s\n```", ext, truncateHeight(params.Content, maxResultHeight)) + return styles.ForceReplaceBackgroundWithLipgloss( + toMarkdown(resultContent, true, width), + styles.Background, + ) + default: + resultContent = fmt.Sprintf("```text\n%s\n```", resultContent) + return styles.ForceReplaceBackgroundWithLipgloss( + toMarkdown(resultContent, true, width), + styles.Background, + ) + } +} + +func renderToolMessage( + toolCall message.ToolCall, + allMessages []message.Message, + messagesService message.Service, + focusedUIMessageId string, + nested bool, + width int, + position int, +) uiMessage { + if nested { + width = width - 3 + } + response := findToolResponse(toolCall.ID, allMessages) + toolName := styles.BaseStyle.Foreground(styles.ForgroundDim).Render(fmt.Sprintf("%s: ", toolName(toolCall.Name))) + params := renderToolParams(width-2-lipgloss.Width(toolName), toolCall) + responseContent := "" + if response != nil { + responseContent = renderToolResponse(toolCall, *response, width-2) + responseContent = strings.TrimSuffix(responseContent, "\n") + } else { + responseContent = styles.BaseStyle. + Italic(true). + Width(width - 2). + Foreground(styles.ForgroundDim). + Render("Waiting for response...") + } + style := styles.BaseStyle. + Width(width - 1). + BorderLeft(true). + BorderStyle(lipgloss.ThickBorder()). + PaddingLeft(1). + BorderForeground(styles.ForgroundDim) + + parts := []string{} + if !nested { + params := styles.BaseStyle. + Width(width - 2 - lipgloss.Width(toolName)). + Foreground(styles.ForgroundDim). + Render(params) + + parts = append(parts, lipgloss.JoinHorizontal(lipgloss.Left, toolName, params)) + } else { + prefix := styles.BaseStyle. + Foreground(styles.ForgroundDim). + Render(" └ ") + params := styles.BaseStyle. + Width(width - 2 - lipgloss.Width(toolName)). + Foreground(styles.ForgroundMid). + Render(params) + parts = append(parts, lipgloss.JoinHorizontal(lipgloss.Left, prefix, toolName, params)) + } + if toolCall.Name == agent.AgentToolName { + taskMessages, _ := messagesService.List(context.Background(), toolCall.ID) + toolCalls := []message.ToolCall{} + for _, v := range taskMessages { + toolCalls = append(toolCalls, v.ToolCalls()...) + } + for _, call := range toolCalls { + rendered := renderToolMessage(call, []message.Message{}, messagesService, focusedUIMessageId, true, width, 0) + parts = append(parts, rendered.content) + } + } + if responseContent != "" && !nested { + parts = append(parts, responseContent) + } + + content := style.Render( + lipgloss.JoinVertical( + lipgloss.Left, + parts..., + ), + ) + if nested { + content = lipgloss.JoinVertical( + lipgloss.Left, + parts..., + ) + } + toolMsg := uiMessage{ + messageType: toolMessageType, + position: position, + height: lipgloss.Height(content), + content: content, + } + return toolMsg +} diff --git a/internal/tui/components/chat/messages.go b/internal/tui/components/chat/messages.go deleted file mode 100644 index c2ce7d88b13ea9858f7adf5cf343504525f35973..0000000000000000000000000000000000000000 --- a/internal/tui/components/chat/messages.go +++ /dev/null @@ -1,742 +0,0 @@ -package chat - -import ( - "context" - "encoding/json" - "fmt" - "math" - "strings" - "time" - - "github.com/charmbracelet/bubbles/key" - "github.com/charmbracelet/bubbles/spinner" - "github.com/charmbracelet/bubbles/viewport" - tea "github.com/charmbracelet/bubbletea" - "github.com/charmbracelet/glamour" - "github.com/charmbracelet/lipgloss" - "github.com/charmbracelet/x/ansi" - "github.com/kujtimiihoxha/opencode/internal/app" - "github.com/kujtimiihoxha/opencode/internal/llm/agent" - "github.com/kujtimiihoxha/opencode/internal/llm/models" - "github.com/kujtimiihoxha/opencode/internal/llm/tools" - "github.com/kujtimiihoxha/opencode/internal/logging" - "github.com/kujtimiihoxha/opencode/internal/message" - "github.com/kujtimiihoxha/opencode/internal/pubsub" - "github.com/kujtimiihoxha/opencode/internal/session" - "github.com/kujtimiihoxha/opencode/internal/tui/layout" - "github.com/kujtimiihoxha/opencode/internal/tui/styles" - "github.com/kujtimiihoxha/opencode/internal/tui/util" -) - -type uiMessageType int - -const ( - userMessageType uiMessageType = iota - assistantMessageType - toolMessageType -) - -// messagesTickMsg is a message sent by the timer to refresh messages -type messagesTickMsg time.Time - -type uiMessage struct { - ID string - messageType uiMessageType - position int - height int - content string -} - -type messagesCmp struct { - app *app.App - width, height int - writingMode bool - viewport viewport.Model - session session.Session - messages []message.Message - uiMessages []uiMessage - currentMsgID string - renderer *glamour.TermRenderer - focusRenderer *glamour.TermRenderer - cachedContent map[string]string - spinner spinner.Model - needsRerender bool -} - -func (m *messagesCmp) Init() tea.Cmd { - return tea.Batch(m.viewport.Init(), m.spinner.Tick, m.tickMessages()) -} - -func (m *messagesCmp) tickMessages() tea.Cmd { - return tea.Tick(time.Second, func(t time.Time) tea.Msg { - return messagesTickMsg(t) - }) -} - -func (m *messagesCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) { - var cmds []tea.Cmd - switch msg := msg.(type) { - case messagesTickMsg: - // Refresh messages if we have an active session - if m.session.ID != "" { - messages, err := m.app.Messages.List(context.Background(), m.session.ID) - if err == nil { - m.messages = messages - m.needsRerender = true - } - } - // Continue ticking - cmds = append(cmds, m.tickMessages()) - case EditorFocusMsg: - m.writingMode = bool(msg) - case SessionSelectedMsg: - if msg.ID != m.session.ID { - cmd := m.SetSession(msg) - m.needsRerender = true - return m, cmd - } - return m, nil - case SessionClearedMsg: - m.session = session.Session{} - m.messages = make([]message.Message, 0) - m.currentMsgID = "" - m.needsRerender = true - m.cachedContent = make(map[string]string) - return m, nil - - case tea.KeyMsg: - if m.writingMode { - return m, nil - } - case pubsub.Event[message.Message]: - if msg.Type == pubsub.CreatedEvent { - if msg.Payload.SessionID == m.session.ID { - // check if message exists - - messageExists := false - for _, v := range m.messages { - if v.ID == msg.Payload.ID { - messageExists = true - break - } - } - - if !messageExists { - // If we have messages, ensure the previous last message is not cached - if len(m.messages) > 0 { - lastMsgID := m.messages[len(m.messages)-1].ID - delete(m.cachedContent, lastMsgID) - } - - m.messages = append(m.messages, msg.Payload) - delete(m.cachedContent, m.currentMsgID) - m.currentMsgID = msg.Payload.ID - m.needsRerender = true - } - } - for _, v := range m.messages { - for _, c := range v.ToolCalls() { - if c.ID == msg.Payload.SessionID { - m.needsRerender = true - } - } - } - } else if msg.Type == pubsub.UpdatedEvent && msg.Payload.SessionID == m.session.ID { - logging.Debug("Message", "finish", msg.Payload.FinishReason()) - for i, v := range m.messages { - if v.ID == msg.Payload.ID { - m.messages[i] = msg.Payload - delete(m.cachedContent, msg.Payload.ID) - - // If this is the last message, ensure it's not cached - if i == len(m.messages)-1 { - delete(m.cachedContent, msg.Payload.ID) - } - - m.needsRerender = true - break - } - } - } - } - - oldPos := m.viewport.YPosition - u, cmd := m.viewport.Update(msg) - m.viewport = u - m.needsRerender = m.needsRerender || m.viewport.YPosition != oldPos - cmds = append(cmds, cmd) - - spinner, cmd := m.spinner.Update(msg) - m.spinner = spinner - cmds = append(cmds, cmd) - - if m.needsRerender { - m.renderView() - if len(m.messages) > 0 { - if msg, ok := msg.(pubsub.Event[message.Message]); ok { - if (msg.Type == pubsub.CreatedEvent) || - (msg.Type == pubsub.UpdatedEvent && msg.Payload.ID == m.messages[len(m.messages)-1].ID) { - m.viewport.GotoBottom() - } - } - } - m.needsRerender = false - } - return m, tea.Batch(cmds...) -} - -func (m *messagesCmp) IsAgentWorking() bool { - return m.app.CoderAgent.IsSessionBusy(m.session.ID) -} - -func (m *messagesCmp) renderSimpleMessage(msg message.Message, info ...string) string { - // Check if this is the last message in the list - isLastMessage := len(m.messages) > 0 && m.messages[len(m.messages)-1].ID == msg.ID - - // Only use cache for non-last messages - if !isLastMessage { - if v, ok := m.cachedContent[msg.ID]; ok { - return v - } - } - - style := styles.BaseStyle. - Width(m.width). - BorderLeft(true). - Foreground(styles.ForgroundDim). - BorderForeground(styles.ForgroundDim). - BorderStyle(lipgloss.ThickBorder()) - - renderer := m.renderer - if msg.ID == m.currentMsgID { - style = style. - Foreground(styles.Forground). - BorderForeground(styles.Blue). - BorderStyle(lipgloss.ThickBorder()) - renderer = m.focusRenderer - } - c, _ := renderer.Render(msg.Content().String()) - parts := []string{ - styles.ForceReplaceBackgroundWithLipgloss(c, styles.Background), - } - // remove newline at the end - parts[0] = strings.TrimSuffix(parts[0], "\n") - if len(info) > 0 { - parts = append(parts, info...) - } - rendered := style.Render( - lipgloss.JoinVertical( - lipgloss.Left, - parts..., - ), - ) - - // Only cache if it's not the last message - if !isLastMessage { - m.cachedContent[msg.ID] = rendered - } - - return rendered -} - -func formatTimeDifference(unixTime1, unixTime2 int64) string { - diffSeconds := float64(math.Abs(float64(unixTime2 - unixTime1))) - - if diffSeconds < 60 { - return fmt.Sprintf("%.1fs", diffSeconds) - } - - minutes := int(diffSeconds / 60) - seconds := int(diffSeconds) % 60 - return fmt.Sprintf("%dm%ds", minutes, seconds) -} - -func (m *messagesCmp) findToolResponse(callID string) *message.ToolResult { - for _, v := range m.messages { - for _, c := range v.ToolResults() { - if c.ToolCallID == callID { - return &c - } - } - } - return nil -} - -func (m *messagesCmp) renderToolCall(toolCall message.ToolCall, isNested bool) string { - key := "" - value := "" - result := styles.BaseStyle.Foreground(styles.PrimaryColor).Render(m.spinner.View() + " waiting for response...") - - response := m.findToolResponse(toolCall.ID) - if response != nil && response.IsError { - // Clean up error message for display by removing newlines - // This ensures error messages display properly in the UI - errMsg := strings.ReplaceAll(response.Content, "\n", " ") - result = styles.BaseStyle.Foreground(styles.Error).Render(ansi.Truncate(errMsg, 40, "...")) - } else if response != nil { - result = styles.BaseStyle.Foreground(styles.ForgroundMid).Render("Done") - } - switch toolCall.Name { - // TODO: add result data to the tools - case agent.AgentToolName: - key = "Task" - var params agent.AgentParams - json.Unmarshal([]byte(toolCall.Input), ¶ms) - value = strings.ReplaceAll(params.Prompt, "\n", " ") - if response != nil && !response.IsError { - firstRow := strings.ReplaceAll(response.Content, "\n", " ") - result = styles.BaseStyle.Foreground(styles.ForgroundMid).Render(ansi.Truncate(firstRow, 40, "...")) - } - case tools.BashToolName: - key = "Bash" - var params tools.BashParams - json.Unmarshal([]byte(toolCall.Input), ¶ms) - value = params.Command - if response != nil && !response.IsError { - metadata := tools.BashResponseMetadata{} - json.Unmarshal([]byte(response.Metadata), &metadata) - result = styles.BaseStyle.Foreground(styles.ForgroundMid).Render(fmt.Sprintf("Took %s", formatTimeDifference(metadata.StartTime, metadata.EndTime))) - } - - case tools.EditToolName: - key = "Edit" - var params tools.EditParams - json.Unmarshal([]byte(toolCall.Input), ¶ms) - value = params.FilePath - if response != nil && !response.IsError { - metadata := tools.EditResponseMetadata{} - json.Unmarshal([]byte(response.Metadata), &metadata) - result = styles.BaseStyle.Foreground(styles.ForgroundMid).Render(fmt.Sprintf("%d Additions %d Removals", metadata.Additions, metadata.Removals)) - } - case tools.FetchToolName: - key = "Fetch" - var params tools.FetchParams - json.Unmarshal([]byte(toolCall.Input), ¶ms) - value = params.URL - if response != nil && !response.IsError { - result = styles.BaseStyle.Foreground(styles.Error).Render(response.Content) - } - case tools.GlobToolName: - key = "Glob" - var params tools.GlobParams - json.Unmarshal([]byte(toolCall.Input), ¶ms) - if params.Path == "" { - params.Path = "." - } - value = fmt.Sprintf("%s (%s)", params.Pattern, params.Path) - if response != nil && !response.IsError { - metadata := tools.GlobResponseMetadata{} - json.Unmarshal([]byte(response.Metadata), &metadata) - if metadata.Truncated { - result = styles.BaseStyle.Foreground(styles.ForgroundMid).Render(fmt.Sprintf("%d files found (truncated)", metadata.NumberOfFiles)) - } else { - result = styles.BaseStyle.Foreground(styles.ForgroundMid).Render(fmt.Sprintf("%d files found", metadata.NumberOfFiles)) - } - } - case tools.GrepToolName: - key = "Grep" - var params tools.GrepParams - json.Unmarshal([]byte(toolCall.Input), ¶ms) - if params.Path == "" { - params.Path = "." - } - value = fmt.Sprintf("%s (%s)", params.Pattern, params.Path) - if response != nil && !response.IsError { - metadata := tools.GrepResponseMetadata{} - json.Unmarshal([]byte(response.Metadata), &metadata) - if metadata.Truncated { - result = styles.BaseStyle.Foreground(styles.ForgroundMid).Render(fmt.Sprintf("%d files found (truncated)", metadata.NumberOfMatches)) - } else { - result = styles.BaseStyle.Foreground(styles.ForgroundMid).Render(fmt.Sprintf("%d files found", metadata.NumberOfMatches)) - } - } - case tools.LSToolName: - key = "ls" - var params tools.LSParams - json.Unmarshal([]byte(toolCall.Input), ¶ms) - if params.Path == "" { - params.Path = "." - } - value = params.Path - if response != nil && !response.IsError { - metadata := tools.LSResponseMetadata{} - json.Unmarshal([]byte(response.Metadata), &metadata) - if metadata.Truncated { - result = styles.BaseStyle.Foreground(styles.ForgroundMid).Render(fmt.Sprintf("%d files found (truncated)", metadata.NumberOfFiles)) - } else { - result = styles.BaseStyle.Foreground(styles.ForgroundMid).Render(fmt.Sprintf("%d files found", metadata.NumberOfFiles)) - } - } - case tools.SourcegraphToolName: - key = "Sourcegraph" - var params tools.SourcegraphParams - json.Unmarshal([]byte(toolCall.Input), ¶ms) - value = params.Query - if response != nil && !response.IsError { - metadata := tools.SourcegraphResponseMetadata{} - json.Unmarshal([]byte(response.Metadata), &metadata) - if metadata.Truncated { - result = styles.BaseStyle.Foreground(styles.ForgroundMid).Render(fmt.Sprintf("%d matches found (truncated)", metadata.NumberOfMatches)) - } else { - result = styles.BaseStyle.Foreground(styles.ForgroundMid).Render(fmt.Sprintf("%d matches found", metadata.NumberOfMatches)) - } - } - case tools.ViewToolName: - key = "View" - var params tools.ViewParams - json.Unmarshal([]byte(toolCall.Input), ¶ms) - value = params.FilePath - case tools.WriteToolName: - key = "Write" - var params tools.WriteParams - json.Unmarshal([]byte(toolCall.Input), ¶ms) - value = params.FilePath - if response != nil && !response.IsError { - metadata := tools.WriteResponseMetadata{} - json.Unmarshal([]byte(response.Metadata), &metadata) - - result = styles.BaseStyle.Foreground(styles.ForgroundMid).Render(fmt.Sprintf("%d Additions %d Removals", metadata.Additions, metadata.Removals)) - } - default: - key = toolCall.Name - var params map[string]any - json.Unmarshal([]byte(toolCall.Input), ¶ms) - jsonData, _ := json.Marshal(params) - value = string(jsonData) - } - - style := styles.BaseStyle. - Width(m.width). - BorderLeft(true). - BorderStyle(lipgloss.ThickBorder()). - PaddingLeft(1). - BorderForeground(styles.Yellow) - - keyStyle := styles.BaseStyle. - Foreground(styles.ForgroundDim) - valyeStyle := styles.BaseStyle. - Foreground(styles.Forground) - - if isNested { - valyeStyle = valyeStyle.Foreground(styles.ForgroundMid) - } - keyValye := keyStyle.Render( - fmt.Sprintf("%s: ", key), - ) - if !isNested { - value = valyeStyle. - Render( - ansi.Truncate( - value+" ", - m.width-lipgloss.Width(keyValye)-2-lipgloss.Width(result), - "...", - ), - ) - value += result - - } else { - keyValye = keyStyle.Render( - fmt.Sprintf(" └ %s: ", key), - ) - value = valyeStyle. - Width(m.width - lipgloss.Width(keyValye) - 2). - Render( - ansi.Truncate( - value, - m.width-lipgloss.Width(keyValye)-2, - "...", - ), - ) - } - - innerToolCalls := make([]string, 0) - if toolCall.Name == agent.AgentToolName { - messages, _ := m.app.Messages.List(context.Background(), toolCall.ID) - toolCalls := make([]message.ToolCall, 0) - for _, v := range messages { - toolCalls = append(toolCalls, v.ToolCalls()...) - } - for _, v := range toolCalls { - call := m.renderToolCall(v, true) - innerToolCalls = append(innerToolCalls, call) - } - } - - if isNested { - return lipgloss.JoinHorizontal( - lipgloss.Left, - keyValye, - value, - ) - } - callContent := lipgloss.JoinHorizontal( - lipgloss.Left, - keyValye, - value, - ) - callContent = strings.ReplaceAll(callContent, "\n", "") - if len(innerToolCalls) > 0 { - callContent = lipgloss.JoinVertical( - lipgloss.Left, - callContent, - lipgloss.JoinVertical( - lipgloss.Left, - innerToolCalls..., - ), - ) - } - return style.Render(callContent) -} - -func (m *messagesCmp) renderAssistantMessage(msg message.Message) []uiMessage { - // find the user message that is before this assistant message - var userMsg message.Message - for i := len(m.messages) - 1; i >= 0; i-- { - if m.messages[i].Role == message.User { - userMsg = m.messages[i] - break - } - } - messages := make([]uiMessage, 0) - if msg.Content().String() != "" { - info := make([]string, 0) - if msg.IsFinished() && msg.FinishReason() == "end_turn" { - finish := msg.FinishPart() - took := formatTimeDifference(userMsg.CreatedAt, finish.Time) - - info = append(info, styles.BaseStyle.Width(m.width-1).Foreground(styles.ForgroundDim).Render( - fmt.Sprintf(" %s (%s)", models.SupportedModels[msg.Model].Name, took), - )) - } - content := m.renderSimpleMessage(msg, info...) - messages = append(messages, uiMessage{ - messageType: assistantMessageType, - position: 0, // gets updated in renderView - height: lipgloss.Height(content), - content: content, - }) - } - for _, v := range msg.ToolCalls() { - content := m.renderToolCall(v, false) - messages = append(messages, - uiMessage{ - messageType: toolMessageType, - position: 0, // gets updated in renderView - height: lipgloss.Height(content), - content: content, - }, - ) - } - - return messages -} - -func (m *messagesCmp) renderView() { - m.uiMessages = make([]uiMessage, 0) - pos := 0 - - // If we have messages, ensure the last message is not cached - // This ensures we always render the latest content for the most recent message - // which may be actively updating (e.g., during generation) - if len(m.messages) > 0 { - lastMsgID := m.messages[len(m.messages)-1].ID - delete(m.cachedContent, lastMsgID) - } - - // Limit cache to 10 messages - if len(m.cachedContent) > 15 { - // Create a list of keys to delete (oldest messages first) - keys := make([]string, 0, len(m.cachedContent)) - for k := range m.cachedContent { - keys = append(keys, k) - } - // Delete oldest messages until we have 10 or fewer - for i := 0; i < len(keys)-15; i++ { - delete(m.cachedContent, keys[i]) - } - } - - for _, v := range m.messages { - switch v.Role { - case message.User: - content := m.renderSimpleMessage(v) - m.uiMessages = append(m.uiMessages, uiMessage{ - messageType: userMessageType, - position: pos, - height: lipgloss.Height(content), - content: content, - }) - pos += lipgloss.Height(content) + 1 // + 1 for spacing - case message.Assistant: - assistantMessages := m.renderAssistantMessage(v) - for _, msg := range assistantMessages { - msg.position = pos - m.uiMessages = append(m.uiMessages, msg) - pos += msg.height + 1 // + 1 for spacing - } - - } - } - - messages := make([]string, 0) - for _, v := range m.uiMessages { - messages = append(messages, v.content, - styles.BaseStyle. - Width(m.width). - Render( - "", - ), - ) - } - m.viewport.SetContent( - styles.BaseStyle. - Width(m.width). - Render( - lipgloss.JoinVertical( - lipgloss.Top, - messages..., - ), - ), - ) -} - -func (m *messagesCmp) View() string { - if len(m.messages) == 0 { - content := styles.BaseStyle. - Width(m.width). - Height(m.height - 1). - Render( - m.initialScreen(), - ) - - return styles.BaseStyle. - Width(m.width). - Render( - lipgloss.JoinVertical( - lipgloss.Top, - content, - m.help(), - ), - ) - } - - return styles.BaseStyle. - Width(m.width). - Render( - lipgloss.JoinVertical( - lipgloss.Top, - m.viewport.View(), - m.help(), - ), - ) -} - -func (m *messagesCmp) help() string { - text := "" - - if m.IsAgentWorking() { - text += styles.BaseStyle.Foreground(styles.PrimaryColor).Bold(true).Render( - fmt.Sprintf("%s %s ", m.spinner.View(), "Generating..."), - ) - } - if m.writingMode { - text += lipgloss.JoinHorizontal( - lipgloss.Left, - styles.BaseStyle.Foreground(styles.ForgroundDim).Bold(true).Render("press "), - styles.BaseStyle.Foreground(styles.Forground).Bold(true).Render("esc"), - styles.BaseStyle.Foreground(styles.ForgroundDim).Bold(true).Render(" to exit writing mode"), - ) - } else { - text += lipgloss.JoinHorizontal( - lipgloss.Left, - styles.BaseStyle.Foreground(styles.ForgroundDim).Bold(true).Render("press "), - styles.BaseStyle.Foreground(styles.Forground).Bold(true).Render("i"), - styles.BaseStyle.Foreground(styles.ForgroundDim).Bold(true).Render(" to start writing"), - ) - } - - return styles.BaseStyle. - Width(m.width). - Render(text) -} - -func (m *messagesCmp) initialScreen() string { - return styles.BaseStyle.Width(m.width).Render( - lipgloss.JoinVertical( - lipgloss.Top, - header(m.width), - "", - lspsConfigured(m.width), - ), - ) -} - -func (m *messagesCmp) SetSize(width, height int) { - m.width = width - m.height = height - m.viewport.Width = width - m.viewport.Height = height - 1 - focusRenderer, _ := glamour.NewTermRenderer( - glamour.WithStyles(styles.MarkdownTheme(true)), - glamour.WithWordWrap(width-1), - ) - renderer, _ := glamour.NewTermRenderer( - glamour.WithStyles(styles.MarkdownTheme(false)), - glamour.WithWordWrap(width-1), - ) - m.focusRenderer = focusRenderer - // clear the cached content - for k := range m.cachedContent { - delete(m.cachedContent, k) - } - m.renderer = renderer - if len(m.messages) > 0 { - m.renderView() - m.viewport.GotoBottom() - } -} - -func (m *messagesCmp) GetSize() (int, int) { - return m.width, m.height -} - -func (m *messagesCmp) SetSession(session session.Session) tea.Cmd { - m.session = session - messages, err := m.app.Messages.List(context.Background(), session.ID) - if err != nil { - return util.ReportError(err) - } - m.messages = messages - m.currentMsgID = m.messages[len(m.messages)-1].ID - m.needsRerender = true - m.cachedContent = make(map[string]string) - return nil -} - -func (m *messagesCmp) BindingKeys() []key.Binding { - bindings := layout.KeyMapToSlice(m.viewport.KeyMap) - return bindings -} - -func NewMessagesCmp(app *app.App) tea.Model { - focusRenderer, _ := glamour.NewTermRenderer( - glamour.WithStyles(styles.MarkdownTheme(true)), - glamour.WithWordWrap(80), - ) - renderer, _ := glamour.NewTermRenderer( - glamour.WithStyles(styles.MarkdownTheme(false)), - glamour.WithWordWrap(80), - ) - - s := spinner.New() - s.Spinner = spinner.Pulse - return &messagesCmp{ - app: app, - writingMode: true, - cachedContent: make(map[string]string), - viewport: viewport.New(0, 0), - focusRenderer: focusRenderer, - renderer: renderer, - spinner: s, - } -} diff --git a/internal/tui/components/chat/sidebar.go b/internal/tui/components/chat/sidebar.go index 5a275c0cfd1d264571fff33dedb5dae1c282b8d3..d330e592bc5be5ba302ee175d5dedccf4a0412eb 100644 --- a/internal/tui/components/chat/sidebar.go +++ b/internal/tui/components/chat/sidebar.go @@ -51,6 +51,12 @@ func (m *sidebarCmp) Init() tea.Cmd { func (m *sidebarCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) { switch msg := msg.(type) { + case SessionSelectedMsg: + if msg.ID != m.session.ID { + m.session = msg + ctx := context.Background() + m.loadModifiedFiles(ctx) + } case pubsub.Event[session.Session]: if msg.Type == pubsub.UpdatedEvent { if m.session.ID == msg.Payload.ID { @@ -59,10 +65,16 @@ func (m *sidebarCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) { } case pubsub.Event[history.File]: if msg.Payload.SessionID == m.session.ID { - // When a file changes, reload all modified files - // This ensures we have the complete and accurate list + // Process the individual file change instead of reloading all files ctx := context.Background() - m.loadModifiedFiles(ctx) + m.processFileChanges(ctx, msg.Payload) + + // Return a command to continue receiving events + return m, func() tea.Msg { + ctx := context.Background() + filesCh := m.history.Subscribe(ctx) + return <-filesCh + } } } return m, nil @@ -71,6 +83,8 @@ func (m *sidebarCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) { func (m *sidebarCmp) View() string { return styles.BaseStyle. Width(m.width). + PaddingLeft(4). + PaddingRight(2). Height(m.height - 1). Render( lipgloss.JoinVertical( @@ -79,9 +93,9 @@ func (m *sidebarCmp) View() string { " ", m.sessionSection(), " ", - m.modifiedFiles(), - " ", lspsConfigured(m.width), + " ", + m.modifiedFiles(), ), ) } @@ -170,9 +184,10 @@ func (m *sidebarCmp) modifiedFiles() string { ) } -func (m *sidebarCmp) SetSize(width, height int) { +func (m *sidebarCmp) SetSize(width, height int) tea.Cmd { m.width = width m.height = height + return nil } func (m *sidebarCmp) GetSize() (int, int) { @@ -203,6 +218,12 @@ func (m *sidebarCmp) loadModifiedFiles(ctx context.Context) { return } + // Clear the existing map to rebuild it + m.modFiles = make(map[string]struct { + additions int + removals int + }) + // Process each latest file for _, file := range latestFiles { // Skip if this is the initial version (no changes to show) @@ -250,28 +271,23 @@ func (m *sidebarCmp) loadModifiedFiles(ctx context.Context) { } func (m *sidebarCmp) processFileChanges(ctx context.Context, file history.File) { - // Skip if not the latest version + // Skip if this is the initial version (no changes to show) if file.Version == history.InitialVersion { return } - // Get all versions of this file - fileVersions, err := m.history.ListBySession(ctx, m.session.ID) - if err != nil { + // Find the initial version for this file + initialVersion, err := m.findInitialVersion(ctx, file.Path) + if err != nil || initialVersion.ID == "" { return } - // Find the initial version - var initialVersion history.File - for _, v := range fileVersions { - if v.Path == file.Path && v.Version == history.InitialVersion { - initialVersion = v - break - } - } - - // Skip if we can't find the initial version - if initialVersion.ID == "" { + // Skip if content hasn't changed + if initialVersion.Content == file.Content { + // If this file was previously modified but now matches the initial version, + // remove it from the modified files list + displayPath := getDisplayPath(file.Path) + delete(m.modFiles, displayPath) return } @@ -280,12 +296,7 @@ func (m *sidebarCmp) processFileChanges(ctx context.Context, file history.File) // Only add to modified files if there are changes if additions > 0 || removals > 0 { - // Remove working directory prefix from file path - displayPath := file.Path - workingDir := config.WorkingDirectory() - displayPath = strings.TrimPrefix(displayPath, workingDir) - displayPath = strings.TrimPrefix(displayPath, "/") - + displayPath := getDisplayPath(file.Path) m.modFiles[displayPath] = struct { additions int removals int @@ -293,5 +304,34 @@ func (m *sidebarCmp) processFileChanges(ctx context.Context, file history.File) additions: additions, removals: removals, } + } else { + // If no changes, remove from modified files + displayPath := getDisplayPath(file.Path) + delete(m.modFiles, displayPath) + } +} + +// Helper function to find the initial version of a file +func (m *sidebarCmp) findInitialVersion(ctx context.Context, path string) (history.File, error) { + // Get all versions of this file for the session + fileVersions, err := m.history.ListBySession(ctx, m.session.ID) + if err != nil { + return history.File{}, err } + + // Find the initial version + for _, v := range fileVersions { + if v.Path == path && v.Version == history.InitialVersion { + return v, nil + } + } + + return history.File{}, fmt.Errorf("initial version not found") +} + +// Helper function to get the display path for a file +func getDisplayPath(path string) string { + workingDir := config.WorkingDirectory() + displayPath := strings.TrimPrefix(path, workingDir) + return strings.TrimPrefix(displayPath, "/") } diff --git a/internal/tui/components/core/status.go b/internal/tui/components/core/status.go index e76ecde84e9575265ca06007a5b0697f543c6861..01c5358697906573a0bb52c4f8834fa7ee009df7 100644 --- a/internal/tui/components/core/status.go +++ b/internal/tui/components/core/status.go @@ -166,19 +166,31 @@ func (m *statusCmp) projectDiagnostics() string { diagnostics := []string{} if len(errorDiagnostics) > 0 { - errStr := lipgloss.NewStyle().Foreground(styles.Error).Render(fmt.Sprintf("%s %d", styles.ErrorIcon, len(errorDiagnostics))) + errStr := lipgloss.NewStyle(). + Background(styles.BackgroundDarker). + Foreground(styles.Error). + Render(fmt.Sprintf("%s %d", styles.ErrorIcon, len(errorDiagnostics))) diagnostics = append(diagnostics, errStr) } if len(warnDiagnostics) > 0 { - warnStr := lipgloss.NewStyle().Foreground(styles.Warning).Render(fmt.Sprintf("%s %d", styles.WarningIcon, len(warnDiagnostics))) + warnStr := lipgloss.NewStyle(). + Background(styles.BackgroundDarker). + Foreground(styles.Warning). + Render(fmt.Sprintf("%s %d", styles.WarningIcon, len(warnDiagnostics))) diagnostics = append(diagnostics, warnStr) } if len(hintDiagnostics) > 0 { - hintStr := lipgloss.NewStyle().Foreground(styles.Text).Render(fmt.Sprintf("%s %d", styles.HintIcon, len(hintDiagnostics))) + hintStr := lipgloss.NewStyle(). + Background(styles.BackgroundDarker). + Foreground(styles.Text). + Render(fmt.Sprintf("%s %d", styles.HintIcon, len(hintDiagnostics))) diagnostics = append(diagnostics, hintStr) } if len(infoDiagnostics) > 0 { - infoStr := lipgloss.NewStyle().Foreground(styles.Peach).Render(fmt.Sprintf("%s %d", styles.InfoIcon, len(infoDiagnostics))) + infoStr := lipgloss.NewStyle(). + Background(styles.BackgroundDarker). + Foreground(styles.Peach). + Render(fmt.Sprintf("%s %d", styles.InfoIcon, len(infoDiagnostics))) diagnostics = append(diagnostics, infoStr) } @@ -187,10 +199,12 @@ func (m *statusCmp) projectDiagnostics() string { func (m statusCmp) availableFooterMsgWidth(diagnostics string) int { tokens := "" + tokensWidth := 0 if m.session.ID != "" { tokens = formatTokensAndCost(m.session.PromptTokens+m.session.CompletionTokens, m.session.Cost) + tokensWidth = lipgloss.Width(tokens) + 2 } - return max(0, m.width-lipgloss.Width(helpWidget)-lipgloss.Width(m.model())-lipgloss.Width(diagnostics)-lipgloss.Width(tokens)) + return max(0, m.width-lipgloss.Width(helpWidget)-lipgloss.Width(m.model())-lipgloss.Width(diagnostics)-tokensWidth) } func (m statusCmp) model() string { diff --git a/internal/tui/components/dialog/permission.go b/internal/tui/components/dialog/permission.go index 2958844320f5dec51447fdf5f88bf6a243d686eb..f83472e68291b7deac8a243badb7ccab5db8c345 100644 --- a/internal/tui/components/dialog/permission.go +++ b/internal/tui/components/dialog/permission.go @@ -36,7 +36,7 @@ type PermissionResponseMsg struct { type PermissionDialogCmp interface { tea.Model layout.Bindings - SetPermissions(permission permission.PermissionRequest) + SetPermissions(permission permission.PermissionRequest) tea.Cmd } type permissionsMapping struct { @@ -98,7 +98,8 @@ func (p *permissionDialogCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) { switch msg := msg.(type) { case tea.WindowSizeMsg: p.windowSize = msg - p.SetSize() + cmd := p.SetSize() + cmds = append(cmds, cmd) p.markdownCache = make(map[string]string) p.diffCache = make(map[string]string) case tea.KeyMsg: @@ -267,7 +268,7 @@ func (p *permissionDialogCmp) renderEditContent() string { } func (p *permissionDialogCmp) renderPatchContent() string { - if pr, ok := p.permission.Params.(tools.PatchPermissionsParams); ok { + if pr, ok := p.permission.Params.(tools.EditPermissionsParams); ok { diff := p.GetOrSetDiff(p.permission.ID, func() (string, error) { return diff.FormatDiff(pr.Diff, diff.WithTotalWidth(p.contentViewPort.Width)) }) @@ -401,9 +402,9 @@ func (p *permissionDialogCmp) BindingKeys() []key.Binding { return layout.KeyMapToSlice(helpKeys) } -func (p *permissionDialogCmp) SetSize() { +func (p *permissionDialogCmp) SetSize() tea.Cmd { if p.permission.ID == "" { - return + return nil } switch p.permission.ToolName { case tools.BashToolName: @@ -422,11 +423,12 @@ func (p *permissionDialogCmp) SetSize() { p.width = int(float64(p.windowSize.Width) * 0.7) p.height = int(float64(p.windowSize.Height) * 0.5) } + return nil } -func (p *permissionDialogCmp) SetPermissions(permission permission.PermissionRequest) { +func (p *permissionDialogCmp) SetPermissions(permission permission.PermissionRequest) tea.Cmd { p.permission = permission - p.SetSize() + return p.SetSize() } // Helper to get or set cached diff content diff --git a/internal/tui/components/dialog/session.go b/internal/tui/components/dialog/session.go new file mode 100644 index 0000000000000000000000000000000000000000..d8c859c495345c6c7bb75c9f55070b21e4ed4b7a --- /dev/null +++ b/internal/tui/components/dialog/session.go @@ -0,0 +1,224 @@ +package dialog + +import ( + "github.com/charmbracelet/bubbles/key" + tea "github.com/charmbracelet/bubbletea" + "github.com/charmbracelet/lipgloss" + "github.com/kujtimiihoxha/opencode/internal/session" + "github.com/kujtimiihoxha/opencode/internal/tui/layout" + "github.com/kujtimiihoxha/opencode/internal/tui/styles" + "github.com/kujtimiihoxha/opencode/internal/tui/util" +) + +// SessionSelectedMsg is sent when a session is selected +type SessionSelectedMsg struct { + Session session.Session +} + +// CloseSessionDialogMsg is sent when the session dialog is closed +type CloseSessionDialogMsg struct{} + +// SessionDialog interface for the session switching dialog +type SessionDialog interface { + tea.Model + layout.Bindings + SetSessions(sessions []session.Session) + SetSelectedSession(sessionID string) +} + +type sessionDialogCmp struct { + sessions []session.Session + selectedIdx int + width int + height int + selectedSessionID string +} + +type sessionKeyMap struct { + Up key.Binding + Down key.Binding + Enter key.Binding + Escape key.Binding + J key.Binding + K key.Binding +} + +var sessionKeys = sessionKeyMap{ + Up: key.NewBinding( + key.WithKeys("up"), + key.WithHelp("↑", "previous session"), + ), + Down: key.NewBinding( + key.WithKeys("down"), + key.WithHelp("↓", "next session"), + ), + Enter: key.NewBinding( + key.WithKeys("enter"), + key.WithHelp("enter", "select session"), + ), + Escape: key.NewBinding( + key.WithKeys("esc"), + key.WithHelp("esc", "close"), + ), + J: key.NewBinding( + key.WithKeys("j"), + key.WithHelp("j", "next session"), + ), + K: key.NewBinding( + key.WithKeys("k"), + key.WithHelp("k", "previous session"), + ), +} + +func (s *sessionDialogCmp) Init() tea.Cmd { + return nil +} + +func (s *sessionDialogCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) { + switch msg := msg.(type) { + case tea.KeyMsg: + switch { + case key.Matches(msg, sessionKeys.Up) || key.Matches(msg, sessionKeys.K): + if s.selectedIdx > 0 { + s.selectedIdx-- + } + return s, nil + case key.Matches(msg, sessionKeys.Down) || key.Matches(msg, sessionKeys.J): + if s.selectedIdx < len(s.sessions)-1 { + s.selectedIdx++ + } + return s, nil + case key.Matches(msg, sessionKeys.Enter): + if len(s.sessions) > 0 { + return s, util.CmdHandler(SessionSelectedMsg{ + Session: s.sessions[s.selectedIdx], + }) + } + case key.Matches(msg, sessionKeys.Escape): + return s, util.CmdHandler(CloseSessionDialogMsg{}) + } + case tea.WindowSizeMsg: + s.width = msg.Width + s.height = msg.Height + } + return s, nil +} + +func (s *sessionDialogCmp) View() string { + if len(s.sessions) == 0 { + return styles.BaseStyle.Padding(1, 2). + Border(lipgloss.RoundedBorder()). + BorderBackground(styles.Background). + BorderForeground(styles.ForgroundDim). + Width(40). + Render("No sessions available") + } + + // Calculate max width needed for session titles + maxWidth := 40 // Minimum width + for _, sess := range s.sessions { + if len(sess.Title) > maxWidth-4 { // Account for padding + maxWidth = len(sess.Title) + 4 + } + } + + // Limit height to avoid taking up too much screen space + maxVisibleSessions := min(10, len(s.sessions)) + + // Build the session list + sessionItems := make([]string, 0, maxVisibleSessions) + startIdx := 0 + + // If we have more sessions than can be displayed, adjust the start index + if len(s.sessions) > maxVisibleSessions { + // Center the selected item when possible + halfVisible := maxVisibleSessions / 2 + if s.selectedIdx >= halfVisible && s.selectedIdx < len(s.sessions)-halfVisible { + startIdx = s.selectedIdx - halfVisible + } else if s.selectedIdx >= len(s.sessions)-halfVisible { + startIdx = len(s.sessions) - maxVisibleSessions + } + } + + endIdx := min(startIdx+maxVisibleSessions, len(s.sessions)) + + for i := startIdx; i < endIdx; i++ { + sess := s.sessions[i] + itemStyle := styles.BaseStyle.Width(maxWidth) + + if i == s.selectedIdx { + itemStyle = itemStyle. + Background(styles.PrimaryColor). + Foreground(styles.Background). + Bold(true) + } + + sessionItems = append(sessionItems, itemStyle.Padding(0, 1).Render(sess.Title)) + } + + title := styles.BaseStyle. + Foreground(styles.PrimaryColor). + Bold(true). + Padding(0, 1). + Render("Switch Session") + + content := lipgloss.JoinVertical( + lipgloss.Left, + title, + styles.BaseStyle.Render(""), + lipgloss.JoinVertical(lipgloss.Left, sessionItems...), + styles.BaseStyle.Render(""), + styles.BaseStyle.Foreground(styles.ForgroundDim).Render("↑/k: up ↓/j: down enter: select esc: cancel"), + ) + + return styles.BaseStyle.Padding(1, 2). + Border(lipgloss.RoundedBorder()). + BorderBackground(styles.Background). + BorderForeground(styles.ForgroundDim). + Width(lipgloss.Width(content) + 4). + Render(content) +} + +func (s *sessionDialogCmp) BindingKeys() []key.Binding { + return layout.KeyMapToSlice(sessionKeys) +} + +func (s *sessionDialogCmp) SetSessions(sessions []session.Session) { + s.sessions = sessions + + // If we have a selected session ID, find its index + if s.selectedSessionID != "" { + for i, sess := range sessions { + if sess.ID == s.selectedSessionID { + s.selectedIdx = i + return + } + } + } + + // Default to first session if selected not found + s.selectedIdx = 0 +} + +func (s *sessionDialogCmp) SetSelectedSession(sessionID string) { + s.selectedSessionID = sessionID + + // Update the selected index if sessions are already loaded + if len(s.sessions) > 0 { + for i, sess := range s.sessions { + if sess.ID == sessionID { + s.selectedIdx = i + return + } + } + } +} + +// NewSessionDialogCmp creates a new session switching dialog +func NewSessionDialogCmp() SessionDialog { + return &sessionDialogCmp{ + sessions: []session.Session{}, + selectedIdx: 0, + selectedSessionID: "", + } +} \ No newline at end of file diff --git a/internal/tui/components/logs/details.go b/internal/tui/components/logs/details.go index 7c74da10497fa2e8882be95b10e4014247e7b302..fa49adbbb2862a7b85e3974c120f2a16f331ed44 100644 --- a/internal/tui/components/logs/details.go +++ b/internal/tui/components/logs/details.go @@ -119,27 +119,17 @@ func (i *detailCmp) GetSize() (int, int) { return i.width, i.height } -func (i *detailCmp) SetSize(width int, height int) { +func (i *detailCmp) SetSize(width int, height int) tea.Cmd { i.width = width i.height = height i.viewport.Width = i.width i.viewport.Height = i.height i.updateContent() + return nil } func (i *detailCmp) BindingKeys() []key.Binding { - return []key.Binding{ - i.viewport.KeyMap.PageDown, - i.viewport.KeyMap.PageUp, - i.viewport.KeyMap.HalfPageDown, - i.viewport.KeyMap.HalfPageUp, - } -} - -func (i *detailCmp) BorderText() map[layout.BorderPosition]string { - return map[layout.BorderPosition]string{ - layout.TopLeftBorder: "Log Details", - } + return layout.KeyMapToSlice(i.viewport.KeyMap) } func NewLogsDetails() DetailComponent { diff --git a/internal/tui/components/logs/table.go b/internal/tui/components/logs/table.go index 2d0f9c533da41eb58c16836278500d780831984e..245714d0db1606eea469daadd8d45c30e8f9a991 100644 --- a/internal/tui/components/logs/table.go +++ b/internal/tui/components/logs/table.go @@ -68,7 +68,7 @@ func (i *tableCmp) GetSize() (int, int) { return i.table.Width(), i.table.Height() } -func (i *tableCmp) SetSize(width int, height int) { +func (i *tableCmp) SetSize(width int, height int) tea.Cmd { i.table.SetWidth(width) i.table.SetHeight(height) cloumns := i.table.Columns() @@ -77,6 +77,7 @@ func (i *tableCmp) SetSize(width int, height int) { cloumns[i] = col } i.table.SetColumns(cloumns) + return nil } func (i *tableCmp) BindingKeys() []key.Binding { diff --git a/internal/tui/layout/bento.go b/internal/tui/layout/bento.go deleted file mode 100644 index c47c4e0907d49a7b606bde8aaedceb16a9030f85..0000000000000000000000000000000000000000 --- a/internal/tui/layout/bento.go +++ /dev/null @@ -1,392 +0,0 @@ -package layout - -import ( - "github.com/charmbracelet/bubbles/key" - tea "github.com/charmbracelet/bubbletea" - "github.com/charmbracelet/lipgloss" -) - -type paneID string - -const ( - BentoLeftPane paneID = "left" - BentoRightTopPane paneID = "right-top" - BentoRightBottomPane paneID = "right-bottom" -) - -type BentoPanes map[paneID]tea.Model - -const ( - defaultLeftWidthRatio = 0.2 - defaultRightTopHeightRatio = 0.85 - - minLeftWidth = 10 - minRightBottomHeight = 10 -) - -type BentoLayout interface { - tea.Model - Sizeable - Bindings -} - -type BentoKeyBindings struct { - SwitchPane key.Binding - SwitchPaneBack key.Binding - HideCurrentPane key.Binding - ShowAllPanes key.Binding -} - -var defaultBentoKeyBindings = BentoKeyBindings{ - SwitchPane: key.NewBinding( - key.WithKeys("tab"), - key.WithHelp("tab", "switch pane"), - ), - SwitchPaneBack: key.NewBinding( - key.WithKeys("shift+tab"), - key.WithHelp("shift+tab", "switch pane back"), - ), - HideCurrentPane: key.NewBinding( - key.WithKeys("X"), - key.WithHelp("X", "hide current pane"), - ), - ShowAllPanes: key.NewBinding( - key.WithKeys("R"), - key.WithHelp("R", "show all panes"), - ), -} - -type bentoLayout struct { - width int - height int - - leftWidthRatio float64 - rightTopHeightRatio float64 - - currentPane paneID - panes map[paneID]SinglePaneLayout - hiddenPanes map[paneID]bool -} - -func (b *bentoLayout) GetSize() (int, int) { - return b.width, b.height -} - -func (b *bentoLayout) Init() tea.Cmd { - var cmds []tea.Cmd - for _, pane := range b.panes { - cmd := pane.Init() - if cmd != nil { - cmds = append(cmds, cmd) - } - } - return tea.Batch(cmds...) -} - -func (b *bentoLayout) Update(msg tea.Msg) (tea.Model, tea.Cmd) { - switch msg := msg.(type) { - case tea.WindowSizeMsg: - b.SetSize(msg.Width, msg.Height) - return b, nil - case tea.KeyMsg: - switch { - case key.Matches(msg, defaultBentoKeyBindings.SwitchPane): - return b, b.SwitchPane(false) - case key.Matches(msg, defaultBentoKeyBindings.SwitchPaneBack): - return b, b.SwitchPane(true) - case key.Matches(msg, defaultBentoKeyBindings.HideCurrentPane): - return b, b.HidePane(b.currentPane) - case key.Matches(msg, defaultBentoKeyBindings.ShowAllPanes): - for id := range b.hiddenPanes { - delete(b.hiddenPanes, id) - } - b.SetSize(b.width, b.height) - return b, nil - } - } - - var cmds []tea.Cmd - for id, pane := range b.panes { - u, cmd := pane.Update(msg) - b.panes[id] = u.(SinglePaneLayout) - if cmd != nil { - cmds = append(cmds, cmd) - } - } - return b, tea.Batch(cmds...) -} - -func (b *bentoLayout) View() string { - if b.width <= 0 || b.height <= 0 { - return "" - } - - for id, pane := range b.panes { - if b.currentPane == id { - pane.Focus() - } else { - pane.Blur() - } - } - - leftVisible := false - rightTopVisible := false - rightBottomVisible := false - - var leftPane, rightTopPane, rightBottomPane string - - if pane, ok := b.panes[BentoLeftPane]; ok && !b.hiddenPanes[BentoLeftPane] { - leftPane = pane.View() - leftVisible = true - } - - if pane, ok := b.panes[BentoRightTopPane]; ok && !b.hiddenPanes[BentoRightTopPane] { - rightTopPane = pane.View() - rightTopVisible = true - } - - if pane, ok := b.panes[BentoRightBottomPane]; ok && !b.hiddenPanes[BentoRightBottomPane] { - rightBottomPane = pane.View() - rightBottomVisible = true - } - - if leftVisible { - if rightTopVisible || rightBottomVisible { - rightSection := "" - if rightTopVisible && rightBottomVisible { - rightSection = lipgloss.JoinVertical(lipgloss.Top, rightTopPane, rightBottomPane) - } else if rightTopVisible { - rightSection = rightTopPane - } else { - rightSection = rightBottomPane - } - return lipgloss.NewStyle().Width(b.width).Height(b.height).Render( - lipgloss.JoinHorizontal(lipgloss.Left, leftPane, rightSection), - ) - } else { - return lipgloss.NewStyle().Width(b.width).Height(b.height).Render(leftPane) - } - } else if rightTopVisible || rightBottomVisible { - if rightTopVisible && rightBottomVisible { - return lipgloss.NewStyle().Width(b.width).Height(b.height).Render( - lipgloss.JoinVertical(lipgloss.Top, rightTopPane, rightBottomPane), - ) - } else if rightTopVisible { - return lipgloss.NewStyle().Width(b.width).Height(b.height).Render(rightTopPane) - } else { - return lipgloss.NewStyle().Width(b.width).Height(b.height).Render(rightBottomPane) - } - } - return "" -} - -func (b *bentoLayout) SetSize(width int, height int) { - if width < 0 || height < 0 { - return - } - b.width = width - b.height = height - - leftExists := false - rightTopExists := false - rightBottomExists := false - - if _, ok := b.panes[BentoLeftPane]; ok && !b.hiddenPanes[BentoLeftPane] { - leftExists = true - } - if _, ok := b.panes[BentoRightTopPane]; ok && !b.hiddenPanes[BentoRightTopPane] { - rightTopExists = true - } - if _, ok := b.panes[BentoRightBottomPane]; ok && !b.hiddenPanes[BentoRightBottomPane] { - rightBottomExists = true - } - - leftWidth := 0 - rightWidth := 0 - rightTopHeight := 0 - rightBottomHeight := 0 - - if leftExists && (rightTopExists || rightBottomExists) { - leftWidth = int(float64(width) * b.leftWidthRatio) - if leftWidth < minLeftWidth && width >= minLeftWidth { - leftWidth = minLeftWidth - } - rightWidth = width - leftWidth - - if rightTopExists && rightBottomExists { - rightTopHeight = int(float64(height) * b.rightTopHeightRatio) - rightBottomHeight = height - rightTopHeight - - if rightBottomHeight < minRightBottomHeight && height >= minRightBottomHeight { - rightBottomHeight = minRightBottomHeight - rightTopHeight = height - rightBottomHeight - } - } else if rightTopExists { - rightTopHeight = height - } else if rightBottomExists { - rightBottomHeight = height - } - } else if leftExists { - leftWidth = width - } else if rightTopExists || rightBottomExists { - rightWidth = width - - if rightTopExists && rightBottomExists { - rightTopHeight = int(float64(height) * b.rightTopHeightRatio) - rightBottomHeight = height - rightTopHeight - - if rightBottomHeight < minRightBottomHeight && height >= minRightBottomHeight { - rightBottomHeight = minRightBottomHeight - rightTopHeight = height - rightBottomHeight - } - } else if rightTopExists { - rightTopHeight = height - } else if rightBottomExists { - rightBottomHeight = height - } - } - - if pane, ok := b.panes[BentoLeftPane]; ok && !b.hiddenPanes[BentoLeftPane] { - pane.SetSize(leftWidth, height) - } - if pane, ok := b.panes[BentoRightTopPane]; ok && !b.hiddenPanes[BentoRightTopPane] { - pane.SetSize(rightWidth, rightTopHeight) - } - if pane, ok := b.panes[BentoRightBottomPane]; ok && !b.hiddenPanes[BentoRightBottomPane] { - pane.SetSize(rightWidth, rightBottomHeight) - } -} - -func (b *bentoLayout) HidePane(pane paneID) tea.Cmd { - if len(b.panes)-len(b.hiddenPanes) == 1 { - return nil - } - if _, ok := b.panes[pane]; ok { - b.hiddenPanes[pane] = true - } - b.SetSize(b.width, b.height) - return b.SwitchPane(false) -} - -func (b *bentoLayout) SwitchPane(back bool) tea.Cmd { - orderForward := []paneID{BentoLeftPane, BentoRightTopPane, BentoRightBottomPane} - orderBackward := []paneID{BentoLeftPane, BentoRightBottomPane, BentoRightTopPane} - - order := orderForward - if back { - order = orderBackward - } - - currentIdx := -1 - for i, id := range order { - if id == b.currentPane { - currentIdx = i - break - } - } - - if currentIdx == -1 { - for _, id := range order { - if _, exists := b.panes[id]; exists { - if _, hidden := b.hiddenPanes[id]; !hidden { - b.currentPane = id - break - } - } - } - } else { - startIdx := currentIdx - for { - currentIdx = (currentIdx + 1) % len(order) - - nextID := order[currentIdx] - if _, exists := b.panes[nextID]; exists { - if _, hidden := b.hiddenPanes[nextID]; !hidden { - b.currentPane = nextID - break - } - } - - if currentIdx == startIdx { - break - } - } - } - - var cmds []tea.Cmd - for id, pane := range b.panes { - if _, ok := b.hiddenPanes[id]; ok { - continue - } - if id == b.currentPane { - cmds = append(cmds, pane.Focus()) - } else { - cmds = append(cmds, pane.Blur()) - } - } - - return tea.Batch(cmds...) -} - -func (s *bentoLayout) BindingKeys() []key.Binding { - bindings := KeyMapToSlice(defaultBentoKeyBindings) - if b, ok := s.panes[s.currentPane].(Bindings); ok { - bindings = append(bindings, b.BindingKeys()...) - } - return bindings -} - -type BentoLayoutOption func(*bentoLayout) - -func NewBentoLayout(panes BentoPanes, opts ...BentoLayoutOption) BentoLayout { - p := make(map[paneID]SinglePaneLayout, len(panes)) - for id, pane := range panes { - if sp, ok := pane.(SinglePaneLayout); !ok { - p[id] = NewSinglePane( - pane, - WithSinglePaneFocusable(true), - WithSinglePaneBordered(true), - ) - } else { - p[id] = sp - } - } - if len(p) == 0 { - panic("no panes provided for BentoLayout") - } - layout := &bentoLayout{ - panes: p, - hiddenPanes: make(map[paneID]bool), - currentPane: BentoLeftPane, - leftWidthRatio: defaultLeftWidthRatio, - rightTopHeightRatio: defaultRightTopHeightRatio, - } - - for _, opt := range opts { - opt(layout) - } - - return layout -} - -func WithBentoLayoutLeftWidthRatio(ratio float64) BentoLayoutOption { - return func(b *bentoLayout) { - if ratio > 0 && ratio < 1 { - b.leftWidthRatio = ratio - } - } -} - -func WithBentoLayoutRightTopHeightRatio(ratio float64) BentoLayoutOption { - return func(b *bentoLayout) { - if ratio > 0 && ratio < 1 { - b.rightTopHeightRatio = ratio - } - } -} - -func WithBentoLayoutCurrentPane(pane paneID) BentoLayoutOption { - return func(b *bentoLayout) { - b.currentPane = pane - } -} diff --git a/internal/tui/layout/border.go b/internal/tui/layout/border.go deleted file mode 100644 index ea9f5e0bc50c1d11710c8f784971f1037110c4da..0000000000000000000000000000000000000000 --- a/internal/tui/layout/border.go +++ /dev/null @@ -1,121 +0,0 @@ -package layout - -import ( - "fmt" - "strings" - - "github.com/charmbracelet/lipgloss" - "github.com/kujtimiihoxha/opencode/internal/tui/styles" -) - -type BorderPosition int - -const ( - TopLeftBorder BorderPosition = iota - TopMiddleBorder - TopRightBorder - BottomLeftBorder - BottomMiddleBorder - BottomRightBorder -) - -var ( - ActiveBorder = styles.Blue - InactivePreviewBorder = styles.Grey -) - -type BorderOptions struct { - Active bool - EmbeddedText map[BorderPosition]string - ActiveColor lipgloss.TerminalColor - InactiveColor lipgloss.TerminalColor - ActiveBorder lipgloss.Border - InactiveBorder lipgloss.Border -} - -func Borderize(content string, opts BorderOptions) string { - if opts.EmbeddedText == nil { - opts.EmbeddedText = make(map[BorderPosition]string) - } - if opts.ActiveColor == nil { - opts.ActiveColor = ActiveBorder - } - if opts.InactiveColor == nil { - opts.InactiveColor = InactivePreviewBorder - } - if opts.ActiveBorder == (lipgloss.Border{}) { - opts.ActiveBorder = lipgloss.ThickBorder() - } - if opts.InactiveBorder == (lipgloss.Border{}) { - opts.InactiveBorder = lipgloss.NormalBorder() - } - - var ( - thickness = map[bool]lipgloss.Border{ - true: opts.ActiveBorder, - false: opts.InactiveBorder, - } - color = map[bool]lipgloss.TerminalColor{ - true: opts.ActiveColor, - false: opts.InactiveColor, - } - border = thickness[opts.Active] - style = lipgloss.NewStyle().Foreground(color[opts.Active]) - width = lipgloss.Width(content) - ) - - encloseInSquareBrackets := func(text string) string { - if text != "" { - return fmt.Sprintf("%s%s%s", - style.Render(border.TopRight), - text, - style.Render(border.TopLeft), - ) - } - return text - } - buildHorizontalBorder := func(leftText, middleText, rightText, leftCorner, inbetween, rightCorner string) string { - leftText = encloseInSquareBrackets(leftText) - middleText = encloseInSquareBrackets(middleText) - rightText = encloseInSquareBrackets(rightText) - // Calculate length of border between embedded texts - remaining := max(0, width-lipgloss.Width(leftText)-lipgloss.Width(middleText)-lipgloss.Width(rightText)) - leftBorderLen := max(0, (width/2)-lipgloss.Width(leftText)-(lipgloss.Width(middleText)/2)) - rightBorderLen := max(0, remaining-leftBorderLen) - // Then construct border string - s := leftText + - style.Render(strings.Repeat(inbetween, leftBorderLen)) + - middleText + - style.Render(strings.Repeat(inbetween, rightBorderLen)) + - rightText - // Make it fit in the space available between the two corners. - s = lipgloss.NewStyle(). - Inline(true). - MaxWidth(width). - Render(s) - // Add the corners - return style.Render(leftCorner) + s + style.Render(rightCorner) - } - // Stack top border, content and horizontal borders, and bottom border. - return strings.Join([]string{ - buildHorizontalBorder( - opts.EmbeddedText[TopLeftBorder], - opts.EmbeddedText[TopMiddleBorder], - opts.EmbeddedText[TopRightBorder], - border.TopLeft, - border.Top, - border.TopRight, - ), - lipgloss.NewStyle(). - BorderForeground(color[opts.Active]). - Border(border, false, true, false, true).Render(content), - buildHorizontalBorder( - opts.EmbeddedText[BottomLeftBorder], - opts.EmbeddedText[BottomMiddleBorder], - opts.EmbeddedText[BottomRightBorder], - border.BottomLeft, - border.Bottom, - border.BottomRight, - ), - }, "\n") -} diff --git a/internal/tui/layout/container.go b/internal/tui/layout/container.go index c86d954ead7c3db2733d853c2ef208f45270ad67..fdb9ab40362320b191e2af96775e9ac5857a2d8c 100644 --- a/internal/tui/layout/container.go +++ b/internal/tui/layout/container.go @@ -86,7 +86,7 @@ func (c *container) View() string { return style.Render(c.content.View()) } -func (c *container) SetSize(width, height int) { +func (c *container) SetSize(width, height int) tea.Cmd { c.width = width c.height = height @@ -113,8 +113,9 @@ func (c *container) SetSize(width, height int) { // Set content size with adjusted dimensions contentWidth := max(0, width-horizontalSpace) contentHeight := max(0, height-verticalSpace) - sizeable.SetSize(contentWidth, contentHeight) + return sizeable.SetSize(contentWidth, contentHeight) } + return nil } func (c *container) GetSize() (int, int) { diff --git a/internal/tui/layout/grid.go b/internal/tui/layout/grid.go deleted file mode 100644 index 6be493e2c2f7536612660350e7e017f7e9600444..0000000000000000000000000000000000000000 --- a/internal/tui/layout/grid.go +++ /dev/null @@ -1,254 +0,0 @@ -package layout - -import ( - "github.com/charmbracelet/bubbles/key" - tea "github.com/charmbracelet/bubbletea" - "github.com/charmbracelet/lipgloss" -) - -type GridLayout interface { - tea.Model - Sizeable - Bindings - Panes() [][]tea.Model -} - -type gridLayout struct { - width int - height int - - rows int - columns int - - panes [][]tea.Model - - gap int - bordered bool - focusable bool - - currentRow int - currentColumn int - - activeColor lipgloss.TerminalColor -} - -type GridOption func(*gridLayout) - -func (g *gridLayout) Init() tea.Cmd { - var cmds []tea.Cmd - for i := range g.panes { - for j := range g.panes[i] { - if g.panes[i][j] != nil { - cmds = append(cmds, g.panes[i][j].Init()) - } - } - } - return tea.Batch(cmds...) -} - -func (g *gridLayout) Update(msg tea.Msg) (tea.Model, tea.Cmd) { - var cmds []tea.Cmd - - switch msg := msg.(type) { - case tea.WindowSizeMsg: - g.SetSize(msg.Width, msg.Height) - return g, nil - case tea.KeyMsg: - if key.Matches(msg, g.nextPaneBinding()) { - return g.focusNextPane() - } - } - - // Update all panes - for i := range g.panes { - for j := range g.panes[i] { - if g.panes[i][j] != nil { - var cmd tea.Cmd - g.panes[i][j], cmd = g.panes[i][j].Update(msg) - if cmd != nil { - cmds = append(cmds, cmd) - } - } - } - } - - return g, tea.Batch(cmds...) -} - -func (g *gridLayout) focusNextPane() (tea.Model, tea.Cmd) { - if !g.focusable { - return g, nil - } - - var cmds []tea.Cmd - - // Blur current pane - if g.currentRow < len(g.panes) && g.currentColumn < len(g.panes[g.currentRow]) { - if currentPane, ok := g.panes[g.currentRow][g.currentColumn].(Focusable); ok { - cmds = append(cmds, currentPane.Blur()) - } - } - - // Find next valid pane - g.currentColumn++ - if g.currentColumn >= len(g.panes[g.currentRow]) { - g.currentColumn = 0 - g.currentRow++ - if g.currentRow >= len(g.panes) { - g.currentRow = 0 - } - } - - // Focus next pane - if g.currentRow < len(g.panes) && g.currentColumn < len(g.panes[g.currentRow]) { - if nextPane, ok := g.panes[g.currentRow][g.currentColumn].(Focusable); ok { - cmds = append(cmds, nextPane.Focus()) - } - } - - return g, tea.Batch(cmds...) -} - -func (g *gridLayout) nextPaneBinding() key.Binding { - return key.NewBinding( - key.WithKeys("tab"), - key.WithHelp("tab", "next pane"), - ) -} - -func (g *gridLayout) View() string { - if len(g.panes) == 0 { - return "" - } - - // Calculate dimensions for each cell - cellWidth := (g.width - (g.columns-1)*g.gap) / g.columns - cellHeight := (g.height - (g.rows-1)*g.gap) / g.rows - - // Render each row - rows := make([]string, g.rows) - for i := range g.rows { - // Render each column in this row - cols := make([]string, len(g.panes[i])) - for j := range g.panes[i] { - if g.panes[i][j] == nil { - cols[j] = "" - continue - } - - // Set size for each pane - if sizable, ok := g.panes[i][j].(Sizeable); ok { - effectiveWidth, effectiveHeight := cellWidth, cellHeight - if g.bordered { - effectiveWidth -= 2 - effectiveHeight -= 2 - } - sizable.SetSize(effectiveWidth, effectiveHeight) - } - - // Render the pane - content := g.panes[i][j].View() - - // Apply border if needed - if g.bordered { - isFocused := false - if focusable, ok := g.panes[i][j].(Focusable); ok { - isFocused = focusable.IsFocused() - } - - borderText := map[BorderPosition]string{} - if bordered, ok := g.panes[i][j].(Bordered); ok { - borderText = bordered.BorderText() - } - - content = Borderize(content, BorderOptions{ - Active: isFocused, - EmbeddedText: borderText, - }) - } - - cols[j] = content - } - - // Join columns with gap - rows[i] = lipgloss.JoinHorizontal(lipgloss.Top, cols...) - } - - // Join rows with gap - return lipgloss.JoinVertical(lipgloss.Left, rows...) -} - -func (g *gridLayout) SetSize(width, height int) { - g.width = width - g.height = height -} - -func (g *gridLayout) GetSize() (int, int) { - return g.width, g.height -} - -func (g *gridLayout) BindingKeys() []key.Binding { - var bindings []key.Binding - bindings = append(bindings, g.nextPaneBinding()) - - // Collect bindings from all panes - for i := range g.panes { - for j := range g.panes[i] { - if g.panes[i][j] != nil { - if bindable, ok := g.panes[i][j].(Bindings); ok { - bindings = append(bindings, bindable.BindingKeys()...) - } - } - } - } - - return bindings -} - -func (g *gridLayout) Panes() [][]tea.Model { - return g.panes -} - -// NewGridLayout creates a new grid layout with the given number of rows and columns -func NewGridLayout(rows, cols int, panes [][]tea.Model, opts ...GridOption) GridLayout { - grid := &gridLayout{ - rows: rows, - columns: cols, - panes: panes, - gap: 1, - } - - for _, opt := range opts { - opt(grid) - } - - return grid -} - -// WithGridGap sets the gap between cells -func WithGridGap(gap int) GridOption { - return func(g *gridLayout) { - g.gap = gap - } -} - -// WithGridBordered sets whether cells should have borders -func WithGridBordered(bordered bool) GridOption { - return func(g *gridLayout) { - g.bordered = bordered - } -} - -// WithGridFocusable sets whether the grid supports focus navigation -func WithGridFocusable(focusable bool) GridOption { - return func(g *gridLayout) { - g.focusable = focusable - } -} - -// WithGridActiveColor sets the active border color -func WithGridActiveColor(color lipgloss.TerminalColor) GridOption { - return func(g *gridLayout) { - g.activeColor = color - } -} diff --git a/internal/tui/layout/layout.go b/internal/tui/layout/layout.go index 2f17c4a0e967b1a5e77d2e67c8f5c0f33455169b..495a3fbc5140917b35c342e96672aa4dd8ee4b18 100644 --- a/internal/tui/layout/layout.go +++ b/internal/tui/layout/layout.go @@ -13,12 +13,8 @@ type Focusable interface { IsFocused() bool } -type Bordered interface { - BorderText() map[BorderPosition]string -} - type Sizeable interface { - SetSize(width, height int) + SetSize(width, height int) tea.Cmd GetSize() (int, int) } diff --git a/internal/tui/layout/single.go b/internal/tui/layout/single.go deleted file mode 100644 index c77fa0d78e4b73bb3a79e4164ef264bcdc38aa02..0000000000000000000000000000000000000000 --- a/internal/tui/layout/single.go +++ /dev/null @@ -1,189 +0,0 @@ -package layout - -import ( - "github.com/charmbracelet/bubbles/key" - tea "github.com/charmbracelet/bubbletea" - "github.com/charmbracelet/lipgloss" -) - -type SinglePaneLayout interface { - tea.Model - Focusable - Sizeable - Bindings - Pane() tea.Model -} - -type singlePaneLayout struct { - width int - height int - - focusable bool - focused bool - - bordered bool - borderText map[BorderPosition]string - - content tea.Model - - padding []int - - activeColor lipgloss.TerminalColor -} - -type SinglePaneOption func(*singlePaneLayout) - -func (s *singlePaneLayout) Init() tea.Cmd { - return s.content.Init() -} - -func (s *singlePaneLayout) Update(msg tea.Msg) (tea.Model, tea.Cmd) { - switch msg := msg.(type) { - case tea.WindowSizeMsg: - s.SetSize(msg.Width, msg.Height) - return s, nil - } - u, cmd := s.content.Update(msg) - s.content = u - return s, cmd -} - -func (s *singlePaneLayout) View() string { - style := lipgloss.NewStyle().Width(s.width).Height(s.height) - if s.bordered { - style = style.Width(s.width - 2).Height(s.height - 2) - } - if s.padding != nil { - style = style.Padding(s.padding...) - } - content := style.Render(s.content.View()) - if s.bordered { - if s.borderText == nil { - s.borderText = map[BorderPosition]string{} - } - if bordered, ok := s.content.(Bordered); ok { - s.borderText = bordered.BorderText() - } - return Borderize(content, BorderOptions{ - Active: s.focused, - EmbeddedText: s.borderText, - }) - } - return content -} - -func (s *singlePaneLayout) Blur() tea.Cmd { - if s.focusable { - s.focused = false - } - if blurable, ok := s.content.(Focusable); ok { - return blurable.Blur() - } - return nil -} - -func (s *singlePaneLayout) Focus() tea.Cmd { - if s.focusable { - s.focused = true - } - if focusable, ok := s.content.(Focusable); ok { - return focusable.Focus() - } - return nil -} - -func (s *singlePaneLayout) SetSize(width, height int) { - s.width = width - s.height = height - childWidth, childHeight := s.width, s.height - if s.bordered { - childWidth -= 2 - childHeight -= 2 - } - if s.padding != nil { - if len(s.padding) == 1 { - childWidth -= s.padding[0] * 2 - childHeight -= s.padding[0] * 2 - } else if len(s.padding) == 2 { - childWidth -= s.padding[0] * 2 - childHeight -= s.padding[1] * 2 - } else if len(s.padding) == 3 { - childWidth -= s.padding[0] * 2 - childHeight -= s.padding[1] + s.padding[2] - } else if len(s.padding) == 4 { - childWidth -= s.padding[0] + s.padding[2] - childHeight -= s.padding[1] + s.padding[3] - } - } - if s.content != nil { - if c, ok := s.content.(Sizeable); ok { - c.SetSize(childWidth, childHeight) - } - } -} - -func (s *singlePaneLayout) IsFocused() bool { - return s.focused -} - -func (s *singlePaneLayout) GetSize() (int, int) { - return s.width, s.height -} - -func (s *singlePaneLayout) BindingKeys() []key.Binding { - if b, ok := s.content.(Bindings); ok { - return b.BindingKeys() - } - return []key.Binding{} -} - -func (s *singlePaneLayout) Pane() tea.Model { - return s.content -} - -func NewSinglePane(content tea.Model, opts ...SinglePaneOption) SinglePaneLayout { - layout := &singlePaneLayout{ - content: content, - } - for _, opt := range opts { - opt(layout) - } - return layout -} - -func WithSinglePaneSize(width, height int) SinglePaneOption { - return func(opts *singlePaneLayout) { - opts.width = width - opts.height = height - } -} - -func WithSinglePaneFocusable(focusable bool) SinglePaneOption { - return func(opts *singlePaneLayout) { - opts.focusable = focusable - } -} - -func WithSinglePaneBordered(bordered bool) SinglePaneOption { - return func(opts *singlePaneLayout) { - opts.bordered = bordered - } -} - -func WithSinglePaneBorderText(borderText map[BorderPosition]string) SinglePaneOption { - return func(opts *singlePaneLayout) { - opts.borderText = borderText - } -} - -func WithSinglePanePadding(padding ...int) SinglePaneOption { - return func(opts *singlePaneLayout) { - opts.padding = padding - } -} - -func WithSinglePaneActiveColor(color lipgloss.TerminalColor) SinglePaneOption { - return func(opts *singlePaneLayout) { - opts.activeColor = color - } -} diff --git a/internal/tui/layout/split.go b/internal/tui/layout/split.go index bfb616a5364da9e62fe5058ddf28bb87780527c5..a41df6ab82199468e8627d4f96919630455b6b5c 100644 --- a/internal/tui/layout/split.go +++ b/internal/tui/layout/split.go @@ -11,9 +11,9 @@ type SplitPaneLayout interface { tea.Model Sizeable Bindings - SetLeftPanel(panel Container) - SetRightPanel(panel Container) - SetBottomPanel(panel Container) + SetLeftPanel(panel Container) tea.Cmd + SetRightPanel(panel Container) tea.Cmd + SetBottomPanel(panel Container) tea.Cmd } type splitPaneLayout struct { @@ -53,8 +53,7 @@ func (s *splitPaneLayout) Update(msg tea.Msg) (tea.Model, tea.Cmd) { var cmds []tea.Cmd switch msg := msg.(type) { case tea.WindowSizeMsg: - s.SetSize(msg.Width, msg.Height) - return s, nil + return s, s.SetSize(msg.Width, msg.Height) } if s.rightPanel != nil { @@ -122,7 +121,7 @@ func (s *splitPaneLayout) View() string { return finalView } -func (s *splitPaneLayout) SetSize(width, height int) { +func (s *splitPaneLayout) SetSize(width, height int) tea.Cmd { s.width = width s.height = height @@ -147,42 +146,50 @@ func (s *splitPaneLayout) SetSize(width, height int) { rightWidth = width } + var cmds []tea.Cmd if s.leftPanel != nil { - s.leftPanel.SetSize(leftWidth, topHeight) + cmd := s.leftPanel.SetSize(leftWidth, topHeight) + cmds = append(cmds, cmd) } if s.rightPanel != nil { - s.rightPanel.SetSize(rightWidth, topHeight) + cmd := s.rightPanel.SetSize(rightWidth, topHeight) + cmds = append(cmds, cmd) } if s.bottomPanel != nil { - s.bottomPanel.SetSize(width, bottomHeight) + cmd := s.bottomPanel.SetSize(width, bottomHeight) + cmds = append(cmds, cmd) } + return tea.Batch(cmds...) } func (s *splitPaneLayout) GetSize() (int, int) { return s.width, s.height } -func (s *splitPaneLayout) SetLeftPanel(panel Container) { +func (s *splitPaneLayout) SetLeftPanel(panel Container) tea.Cmd { s.leftPanel = panel if s.width > 0 && s.height > 0 { - s.SetSize(s.width, s.height) + return s.SetSize(s.width, s.height) } + return nil } -func (s *splitPaneLayout) SetRightPanel(panel Container) { +func (s *splitPaneLayout) SetRightPanel(panel Container) tea.Cmd { s.rightPanel = panel if s.width > 0 && s.height > 0 { - s.SetSize(s.width, s.height) + return s.SetSize(s.width, s.height) } + return nil } -func (s *splitPaneLayout) SetBottomPanel(panel Container) { +func (s *splitPaneLayout) SetBottomPanel(panel Container) tea.Cmd { s.bottomPanel = panel if s.width > 0 && s.height > 0 { - s.SetSize(s.width, s.height) + return s.SetSize(s.width, s.height) } + return nil } func (s *splitPaneLayout) BindingKeys() []key.Binding { diff --git a/internal/tui/page/chat.go b/internal/tui/page/chat.go index 632e107641c0c7fb1fd2339ec4ea304f895948f5..b99dc3dfe387002a735bd821db4d7717f544cf80 100644 --- a/internal/tui/page/chat.go +++ b/internal/tui/page/chat.go @@ -54,9 +54,11 @@ func (p *chatPage) Init() tea.Cmd { } func (p *chatPage) Update(msg tea.Msg) (tea.Model, tea.Cmd) { + var cmds []tea.Cmd switch msg := msg.(type) { case tea.WindowSizeMsg: - p.layout.SetSize(msg.Width, msg.Height) + cmd := p.layout.SetSize(msg.Width, msg.Height) + cmds = append(cmds, cmd) case chat.SendMsg: cmd := p.sendMessage(msg.Text) if cmd != nil { @@ -68,8 +70,10 @@ func (p *chatPage) Update(msg tea.Msg) (tea.Model, tea.Cmd) { switch { case key.Matches(msg, keyMap.NewSession): p.session = session.Session{} - p.clearSidebar() - return p, util.CmdHandler(chat.SessionClearedMsg{}) + return p, tea.Batch( + p.clearSidebar(), + util.CmdHandler(chat.SessionClearedMsg{}), + ) case key.Matches(msg, keyMap.Cancel): if p.session.ID != "" { // Cancel the current session's generation process @@ -80,11 +84,9 @@ func (p *chatPage) Update(msg tea.Msg) (tea.Model, tea.Cmd) { } } u, cmd := p.layout.Update(msg) + cmds = append(cmds, cmd) p.layout = u.(layout.SplitPaneLayout) - if cmd != nil { - return p, cmd - } - return p, nil + return p, tea.Batch(cmds...) } func (p *chatPage) setSidebar() tea.Cmd { @@ -92,16 +94,11 @@ func (p *chatPage) setSidebar() tea.Cmd { chat.NewSidebarCmp(p.session, p.app.History), layout.WithPadding(1, 1, 1, 1), ) - p.layout.SetRightPanel(sidebarContainer) - width, height := p.layout.GetSize() - p.layout.SetSize(width, height) - return sidebarContainer.Init() + return tea.Batch(p.layout.SetRightPanel(sidebarContainer), sidebarContainer.Init()) } -func (p *chatPage) clearSidebar() { - p.layout.SetRightPanel(nil) - width, height := p.layout.GetSize() - p.layout.SetSize(width, height) +func (p *chatPage) clearSidebar() tea.Cmd { + return p.layout.SetRightPanel(nil) } func (p *chatPage) sendMessage(text string) tea.Cmd { @@ -124,8 +121,8 @@ func (p *chatPage) sendMessage(text string) tea.Cmd { return tea.Batch(cmds...) } -func (p *chatPage) SetSize(width, height int) { - p.layout.SetSize(width, height) +func (p *chatPage) SetSize(width, height int) tea.Cmd { + return p.layout.SetSize(width, height) } func (p *chatPage) GetSize() (int, int) { diff --git a/internal/tui/page/logs.go b/internal/tui/page/logs.go index 0efc69e6e4f39bbcc8eb0db2b08a7b5ebae92255..f0d35fb7b1c9d07d63aeb37c09a4d6fbbc5854d8 100644 --- a/internal/tui/page/logs.go +++ b/internal/tui/page/logs.go @@ -23,15 +23,14 @@ type logsPage struct { } func (p *logsPage) Update(msg tea.Msg) (tea.Model, tea.Cmd) { + var cmds []tea.Cmd switch msg := msg.(type) { case tea.WindowSizeMsg: p.width = msg.Width p.height = msg.Height - p.table.SetSize(msg.Width, msg.Height/2) - p.details.SetSize(msg.Width, msg.Height/2) + return p, p.SetSize(msg.Width, msg.Height) } - var cmds []tea.Cmd table, cmd := p.table.Update(msg) cmds = append(cmds, cmd) p.table = table.(layout.Container) @@ -60,11 +59,13 @@ func (p *logsPage) GetSize() (int, int) { } // SetSize implements LogPage. -func (p *logsPage) SetSize(width int, height int) { +func (p *logsPage) SetSize(width int, height int) tea.Cmd { p.width = width p.height = height - p.table.SetSize(width, height/2) - p.details.SetSize(width, height/2) + return tea.Batch( + p.table.SetSize(width, height/2), + p.details.SetSize(width, height/2), + ) } func (p *logsPage) Init() tea.Cmd { diff --git a/internal/tui/styles/background.go b/internal/tui/styles/background.go index bf6cbc1059f81d54cda3e7c3de10b925ab11a160..2fbb34efbbe52ecd5e233c33ee32fbb2981fb8f1 100644 --- a/internal/tui/styles/background.go +++ b/internal/tui/styles/background.go @@ -3,7 +3,6 @@ package styles import ( "fmt" "regexp" - "strconv" "strings" "github.com/charmbracelet/lipgloss" @@ -25,57 +24,100 @@ func getColorRGB(c lipgloss.TerminalColor) (uint8, uint8, uint8) { return uint8(r >> 8), uint8(g >> 8), uint8(b >> 8) } +// ForceReplaceBackgroundWithLipgloss replaces any ANSI background color codes +// in `input` with a single 24‑bit background (48;2;R;G;B). func ForceReplaceBackgroundWithLipgloss(input string, newBgColor lipgloss.TerminalColor) string { + // Precompute our new-bg sequence once r, g, b := getColorRGB(newBgColor) - newBg := fmt.Sprintf("48;2;%d;%d;%d", r, g, b) return ansiEscape.ReplaceAllStringFunc(input, func(seq string) string { - // Extract content between "\x1b[" and "m" - content := seq[2 : len(seq)-1] - tokens := strings.Split(content, ";") - var newTokens []string - - // Skip background color tokens - for i := 0; i < len(tokens); i++ { - if tokens[i] == "" { - continue - } + const ( + escPrefixLen = 2 // "\x1b[" + escSuffixLen = 1 // "m" + ) + + raw := seq + start := escPrefixLen + end := len(raw) - escSuffixLen - val, err := strconv.Atoi(tokens[i]) - if err != nil { - newTokens = append(newTokens, tokens[i]) - continue + var sb strings.Builder + // reserve enough space: original content minus bg codes + our newBg + sb.Grow((end - start) + len(newBg) + 2) + + // scan from start..end, token by token + for i := start; i < end; { + // find the next ';' or end + j := i + for j < end && raw[j] != ';' { + j++ } + token := raw[i:j] - // Skip background color tokens - if val == 48 { - // Skip "48;5;N" or "48;2;R;G;B" sequences - if i+1 < len(tokens) { - if nextVal, err := strconv.Atoi(tokens[i+1]); err == nil { - switch nextVal { - case 5: - i += 2 // Skip "5" and color index - case 2: - i += 4 // Skip "2" and RGB components + // fast‑path: skip "48;5;N" or "48;2;R;G;B" + if len(token) == 2 && token[0] == '4' && token[1] == '8' { + k := j + 1 + if k < end { + // find next token + l := k + for l < end && raw[l] != ';' { + l++ + } + next := raw[k:l] + if next == "5" { + // skip "48;5;N" + m := l + 1 + for m < end && raw[m] != ';' { + m++ + } + i = m + 1 + continue + } else if next == "2" { + // skip "48;2;R;G;B" + m := l + 1 + for count := 0; count < 3 && m < end; count++ { + for m < end && raw[m] != ';' { + m++ + } + m++ } + i = m + continue } } - } else if (val < 40 || val > 47) && (val < 100 || val > 107) && val != 49 { - // Keep non-background tokens - newTokens = append(newTokens, tokens[i]) } - } - // Add new background if provided - if newBg != "" { - newTokens = append(newTokens, strings.Split(newBg, ";")...) + // decide whether to keep this token + // manually parse ASCII digits to int + isNum := true + val := 0 + for p := i; p < j; p++ { + c := raw[p] + if c < '0' || c > '9' { + isNum = false + break + } + val = val*10 + int(c-'0') + } + keep := !isNum || + ((val < 40 || val > 47) && (val < 100 || val > 107) && val != 49) + + if keep { + if sb.Len() > 0 { + sb.WriteByte(';') + } + sb.WriteString(token) + } + // advance past this token (and the semicolon) + i = j + 1 } - if len(newTokens) == 0 { - return "" + // append our new background + if sb.Len() > 0 { + sb.WriteByte(';') } + sb.WriteString(newBg) - return "\x1b[" + strings.Join(newTokens, ";") + "m" + return "\x1b[" + sb.String() + "m" }) } diff --git a/internal/tui/styles/icons.go b/internal/tui/styles/icons.go index aa0df1e31ca994cab9b294694851787fb66f2e02..dd5f4dc51e1e00201bfa0d8dd52a1d728ffc0ede 100644 --- a/internal/tui/styles/icons.go +++ b/internal/tui/styles/icons.go @@ -2,19 +2,11 @@ package styles const ( OpenCodeIcon string = "⌬" - SessionsIcon string = "󰧑" - ChatIcon string = "󰭹" - - BotIcon string = "󰚩" - ToolIcon string = "" - UserIcon string = "" CheckIcon string = "✓" - ErrorIcon string = "" - WarningIcon string = "" + ErrorIcon string = "✖" + WarningIcon string = "⚠" InfoIcon string = "" - HintIcon string = "" + HintIcon string = "i" SpinnerIcon string = "..." - BugIcon string = "" - SleepIcon string = "󰒲" ) diff --git a/internal/tui/tui.go b/internal/tui/tui.go index 840ad4905875a31ae0a588fca378d875a5512105..f3a7298cf0936b78c650a0024ffc25d6dff78cbe 100644 --- a/internal/tui/tui.go +++ b/internal/tui/tui.go @@ -1,6 +1,8 @@ package tui import ( + "context" + "github.com/charmbracelet/bubbles/key" tea "github.com/charmbracelet/bubbletea" "github.com/charmbracelet/lipgloss" @@ -8,6 +10,7 @@ import ( "github.com/kujtimiihoxha/opencode/internal/logging" "github.com/kujtimiihoxha/opencode/internal/permission" "github.com/kujtimiihoxha/opencode/internal/pubsub" + "github.com/kujtimiihoxha/opencode/internal/tui/components/chat" "github.com/kujtimiihoxha/opencode/internal/tui/components/core" "github.com/kujtimiihoxha/opencode/internal/tui/components/dialog" "github.com/kujtimiihoxha/opencode/internal/tui/layout" @@ -16,9 +19,10 @@ import ( ) type keyMap struct { - Logs key.Binding - Quit key.Binding - Help key.Binding + Logs key.Binding + Quit key.Binding + Help key.Binding + SwitchSession key.Binding } var keys = keyMap{ @@ -35,6 +39,10 @@ var keys = keyMap{ key.WithKeys("ctrl+_"), key.WithHelp("ctrl+?", "toggle help"), ), + SwitchSession: key.NewBinding( + key.WithKeys("ctrl+a"), + key.WithHelp("ctrl+a", "switch session"), + ), } var returnKey = key.NewBinding( @@ -64,6 +72,9 @@ type appModel struct { showQuit bool quit dialog.QuitDialog + + showSessionDialog bool + sessionDialog dialog.SessionDialog } func (a appModel) Init() tea.Cmd { @@ -77,6 +88,8 @@ func (a appModel) Init() tea.Cmd { cmds = append(cmds, cmd) cmd = a.help.Init() cmds = append(cmds, cmd) + cmd = a.sessionDialog.Init() + cmds = append(cmds, cmd) return tea.Batch(cmds...) } @@ -100,6 +113,10 @@ func (a appModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) { a.help = help.(dialog.HelpCmp) cmds = append(cmds, helpCmd) + session, sessionCmd := a.sessionDialog.Update(msg) + a.sessionDialog = session.(dialog.SessionDialog) + cmds = append(cmds, sessionCmd) + return a, tea.Batch(cmds...) // Status @@ -144,8 +161,7 @@ func (a appModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) { // Permission case pubsub.Event[permission.PermissionRequest]: a.showPermissions = true - a.permissions.SetPermissions(msg.Payload) - return a, nil + return a, a.permissions.SetPermissions(msg.Payload) case dialog.PermissionResponseMsg: switch msg.Action { case dialog.PermissionAllow: @@ -165,6 +181,19 @@ func (a appModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) { a.showQuit = false return a, nil + case dialog.CloseSessionDialogMsg: + a.showSessionDialog = false + return a, nil + + case chat.SessionSelectedMsg: + a.sessionDialog.SetSelectedSession(msg.ID) + case dialog.SessionSelectedMsg: + a.showSessionDialog = false + if a.currentPage == page.ChatPage { + return a, util.CmdHandler(chat.SessionSelectedMsg(msg.Session)) + } + return a, nil + case tea.KeyMsg: switch { case key.Matches(msg, keys.Quit): @@ -172,6 +201,24 @@ func (a appModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) { if a.showHelp { a.showHelp = false } + if a.showSessionDialog { + a.showSessionDialog = false + } + return a, nil + case key.Matches(msg, keys.SwitchSession): + if a.currentPage == page.ChatPage && !a.showQuit && !a.showPermissions { + // Load sessions and show the dialog + sessions, err := a.app.Sessions.List(context.Background()) + if err != nil { + return a, util.ReportError(err) + } + if len(sessions) == 0 { + return a, util.ReportWarn("No sessions available") + } + a.sessionDialog.SetSessions(sessions) + a.showSessionDialog = true + return a, nil + } return a, nil case key.Matches(msg, logsKeyReturnKey): if a.currentPage == page.LogsPage { @@ -216,6 +263,16 @@ func (a appModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) { } } + if a.showSessionDialog { + d, sessionCmd := a.sessionDialog.Update(msg) + a.sessionDialog = d.(dialog.SessionDialog) + cmds = append(cmds, sessionCmd) + // Only block key messages send all other messages down + if _, ok := msg.(tea.KeyMsg); ok { + return a, tea.Batch(cmds...) + } + } + a.status, _ = a.status.Update(msg) a.pages[a.currentPage], cmd = a.pages[a.currentPage].Update(msg) cmds = append(cmds, cmd) @@ -223,18 +280,24 @@ func (a appModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) { } func (a *appModel) moveToPage(pageID page.PageID) tea.Cmd { - var cmd tea.Cmd + if a.app.CoderAgent.IsBusy() { + // For now we don't move to any page if the agent is busy + return util.ReportWarn("Agent is busy, please wait...") + } + var cmds []tea.Cmd if _, ok := a.loadedPages[pageID]; !ok { - cmd = a.pages[pageID].Init() + cmd := a.pages[pageID].Init() + cmds = append(cmds, cmd) a.loadedPages[pageID] = true } a.previousPage = a.currentPage a.currentPage = pageID if sizable, ok := a.pages[a.currentPage].(layout.Sizeable); ok { - sizable.SetSize(a.width, a.height) + cmd := sizable.SetSize(a.width, a.height) + cmds = append(cmds, cmd) } - return cmd + return tea.Batch(cmds...) } func (a appModel) View() string { @@ -304,19 +367,35 @@ func (a appModel) View() string { ) } + if a.showSessionDialog { + overlay := a.sessionDialog.View() + row := lipgloss.Height(appView) / 2 + row -= lipgloss.Height(overlay) / 2 + col := lipgloss.Width(appView) / 2 + col -= lipgloss.Width(overlay) / 2 + appView = layout.PlaceOverlay( + col, + row, + overlay, + appView, + true, + ) + } + return appView } func New(app *app.App) tea.Model { startPage := page.ChatPage return &appModel{ - currentPage: startPage, - loadedPages: make(map[page.PageID]bool), - status: core.NewStatusCmp(app.LSPClients), - help: dialog.NewHelpCmp(), - quit: dialog.NewQuitCmp(), - permissions: dialog.NewPermissionDialogCmp(), - app: app, + currentPage: startPage, + loadedPages: make(map[page.PageID]bool), + status: core.NewStatusCmp(app.LSPClients), + help: dialog.NewHelpCmp(), + quit: dialog.NewQuitCmp(), + sessionDialog: dialog.NewSessionDialogCmp(), + permissions: dialog.NewPermissionDialogCmp(), + app: app, pages: map[page.PageID]tea.Model{ page.ChatPage: page.NewChatPage(app), page.LogsPage: page.NewLogsPage(), From 72afeb9f54cee8e248093a52ac0779441c79aea3 Mon Sep 17 00:00:00 2001 From: Kujtim Hoxha Date: Fri, 18 Apr 2025 21:24:35 +0200 Subject: [PATCH 27/41] small fixes --- internal/diff/patch.go | 33 ++++++------- internal/llm/agent/agent.go | 2 + internal/llm/tools/edit.go | 6 +-- internal/tui/components/chat/editor.go | 9 +++- internal/tui/components/chat/list.go | 49 ++++++++++++++------ internal/tui/components/chat/message.go | 1 + internal/tui/components/dialog/permission.go | 22 +++++---- internal/tui/components/dialog/session.go | 48 ++++++++++--------- internal/tui/page/chat.go | 7 --- internal/tui/tui.go | 4 +- 10 files changed, 106 insertions(+), 75 deletions(-) diff --git a/internal/diff/patch.go b/internal/diff/patch.go index aab0f956dcdad92b6ef6c468940f544c51cb2106..49242f7efc1c5c091f65a6d6de0827aef031e414 100644 --- a/internal/diff/patch.go +++ b/internal/diff/patch.go @@ -91,11 +91,9 @@ func (p *Parser) isDone(prefixes []string) bool { if p.index >= len(p.lines) { return true } - if prefixes != nil { - for _, prefix := range prefixes { - if strings.HasPrefix(p.lines[p.index], prefix) { - return true - } + for _, prefix := range prefixes { + if strings.HasPrefix(p.lines[p.index], prefix) { + return true } } return false @@ -219,7 +217,7 @@ func (p *Parser) parseUpdateFile(text string) (PatchAction, error) { sectionStr = p.lines[p.index] p.index++ } - if !(defStr != "" || sectionStr != "" || index == 0) { + if defStr == "" && sectionStr == "" && index != 0 { return action, NewDiffError(fmt.Sprintf("Invalid Line:\n%s", p.lines[p.index])) } if strings.TrimSpace(defStr) != "" { @@ -433,12 +431,13 @@ func peekNextSection(lines []string, initialIndex int) ([]string, []Chunk, int, delLines = make([]string, 0, 8) insLines = make([]string, 0, 8) } - if mode == "delete" { + switch mode { + case "delete": delLines = append(delLines, line) old = append(old, line) - } else if mode == "add" { + case "add": insLines = append(insLines, line) - } else { + default: old = append(old, line) } } @@ -513,7 +512,7 @@ func IdentifyFilesAdded(text string) []string { func getUpdatedFile(text string, action PatchAction, path string) (string, error) { if action.Type != ActionUpdate { - return "", errors.New("Expected UPDATE action") + return "", errors.New("expected UPDATE action") } origLines := strings.Split(text, "\n") destLines := make([]string, 0, len(origLines)) // Preallocate with capacity @@ -543,18 +542,19 @@ func getUpdatedFile(text string, action PatchAction, path string) (string, error func PatchToCommit(patch Patch, orig map[string]string) (Commit, error) { commit := Commit{Changes: make(map[string]FileChange, len(patch.Actions))} for pathKey, action := range patch.Actions { - if action.Type == ActionDelete { + switch action.Type { + case ActionDelete: oldContent := orig[pathKey] commit.Changes[pathKey] = FileChange{ Type: ActionDelete, OldContent: &oldContent, } - } else if action.Type == ActionAdd { + case ActionAdd: commit.Changes[pathKey] = FileChange{ Type: ActionAdd, NewContent: action.NewFile, } - } else if action.Type == ActionUpdate { + case ActionUpdate: newContent, err := getUpdatedFile(orig[pathKey], action, pathKey) if err != nil { return Commit{}, err @@ -619,18 +619,19 @@ func LoadFiles(paths []string, openFn func(string) (string, error)) (map[string] func ApplyCommit(commit Commit, writeFn func(string, string) error, removeFn func(string) error) error { for p, change := range commit.Changes { - if change.Type == ActionDelete { + switch change.Type { + case ActionDelete: if err := removeFn(p); err != nil { return err } - } else if change.Type == ActionAdd { + case ActionAdd: if change.NewContent == nil { return NewDiffError(fmt.Sprintf("Add action for %s has nil new_content", p)) } if err := writeFn(p, *change.NewContent); err != nil { return err } - } else if change.Type == ActionUpdate { + case ActionUpdate: if change.NewContent == nil { return NewDiffError(fmt.Sprintf("Update action for %s has nil new_content", p)) } diff --git a/internal/llm/agent/agent.go b/internal/llm/agent/agent.go index 5e9785991d311349b339d12a3a02d64dcdda76d8..7542d9adf436642aee9b5613e1cfc8e515b8aee7 100644 --- a/internal/llm/agent/agent.go +++ b/internal/llm/agent/agent.go @@ -221,6 +221,8 @@ func (a *agent) processGeneration(ctx context.Context, sessionID, content string agentMessage, toolResults, err := a.streamAndHandleEvents(ctx, sessionID, msgHistory) if err != nil { if errors.Is(err, context.Canceled) { + agentMessage.AddFinish(message.FinishReasonCanceled) + a.messages.Update(context.Background(), agentMessage) return a.err(ErrRequestCancelled) } return a.err(fmt.Errorf("failed to process events: %w", err)) diff --git a/internal/llm/tools/edit.go b/internal/llm/tools/edit.go index 6a16160109cffbdcc61edf66d00d84217dcc4868..83cec5dbafbca173f19ca06bd9a275dc5faa1b28 100644 --- a/internal/llm/tools/edit.go +++ b/internal/llm/tools/edit.go @@ -141,20 +141,20 @@ func (e *editTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error) if params.OldString == "" { response, err = e.createNewFile(ctx, params.FilePath, params.NewString) if err != nil { - return response, nil + return response, err } } if params.NewString == "" { response, err = e.deleteContent(ctx, params.FilePath, params.OldString) if err != nil { - return response, nil + return response, err } } response, err = e.replaceContent(ctx, params.FilePath, params.OldString, params.NewString) if err != nil { - return response, nil + return response, err } if response.IsError { // Return early if there was an error during content replacement diff --git a/internal/tui/components/chat/editor.go b/internal/tui/components/chat/editor.go index 537ef392c2f9e3ef14a67d0da777196f3a13240e..963fbbdbfce8d233d05df1fbd912af2eec2651f7 100644 --- a/internal/tui/components/chat/editor.go +++ b/internal/tui/components/chat/editor.go @@ -21,6 +21,8 @@ type editorCmp struct { textarea textarea.Model } +type FocusEditorMsg bool + type focusedEditorKeyMaps struct { Send key.Binding OpenEditor key.Binding @@ -112,7 +114,6 @@ func (m *editorCmp) send() tea.Cmd { util.CmdHandler(SendMsg{ Text: value, }), - util.CmdHandler(EditorFocusMsg(false)), ) } @@ -124,9 +125,13 @@ func (m *editorCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) { m.session = msg } return m, nil + case FocusEditorMsg: + if msg { + m.textarea.Focus() + return m, tea.Batch(textarea.Blink, util.CmdHandler(EditorFocusMsg(true))) + } case tea.KeyMsg: if key.Matches(msg, focusedKeyMaps.OpenEditor) { - m.textarea.Blur() return m, openEditor() } // if the key does not match any binding, return diff --git a/internal/tui/components/chat/list.go b/internal/tui/components/chat/list.go index f95b53731773aff6e05292821ad13910bd8c66bb..b7703e2cc28c2a949fc82e40327cf35de69ed0cd 100644 --- a/internal/tui/components/chat/list.go +++ b/internal/tui/components/chat/list.go @@ -22,6 +22,10 @@ import ( "github.com/kujtimiihoxha/opencode/internal/tui/util" ) +type cacheItem struct { + width int + content []uiMessage +} type messagesCmp struct { app *app.App width, height int @@ -32,8 +36,9 @@ type messagesCmp struct { uiMessages []uiMessage currentMsgID string mutex sync.Mutex - cachedContent map[string][]uiMessage + cachedContent map[string]cacheItem spinner spinner.Model + lastUpdate time.Time rendering bool } type renderFinishedMsg struct{} @@ -44,6 +49,8 @@ func (m *messagesCmp) Init() tea.Cmd { func (m *messagesCmp) preloadSessions() tea.Cmd { return func() tea.Msg { + m.mutex.Lock() + defer m.mutex.Unlock() sessions, err := m.app.Sessions.List(context.Background()) if err != nil { return util.ReportError(err)() @@ -67,13 +74,13 @@ func (m *messagesCmp) preloadSessions() tea.Cmd { } logging.Debug("preloaded sessions") - return nil + return func() tea.Msg { + return renderFinishedMsg{} + } } } func (m *messagesCmp) cacheSessionMessages(messages []message.Message, width int) { - m.mutex.Lock() - defer m.mutex.Unlock() pos := 0 if m.width == 0 { return @@ -87,7 +94,10 @@ func (m *messagesCmp) cacheSessionMessages(messages []message.Message, width int width, pos, ) - m.cachedContent[msg.ID] = []uiMessage{userMsg} + m.cachedContent[msg.ID] = cacheItem{ + width: width, + content: []uiMessage{userMsg}, + } pos += userMsg.height + 1 // + 1 for spacing case message.Assistant: assistantMessages := renderAssistantMessage( @@ -102,7 +112,10 @@ func (m *messagesCmp) cacheSessionMessages(messages []message.Message, width int for _, msg := range assistantMessages { pos += msg.height + 1 // + 1 for spacing } - m.cachedContent[msg.ID] = assistantMessages + m.cachedContent[msg.ID] = cacheItem{ + width: width, + content: assistantMessages, + } } } } @@ -223,8 +236,8 @@ func (m *messagesCmp) renderView() { for inx, msg := range m.messages { switch msg.Role { case message.User: - if messages, ok := m.cachedContent[msg.ID]; ok { - m.uiMessages = append(m.uiMessages, messages...) + if cache, ok := m.cachedContent[msg.ID]; ok && cache.width == m.width { + m.uiMessages = append(m.uiMessages, cache.content...) continue } userMsg := renderUserMessage( @@ -234,11 +247,14 @@ func (m *messagesCmp) renderView() { pos, ) m.uiMessages = append(m.uiMessages, userMsg) - m.cachedContent[msg.ID] = []uiMessage{userMsg} + m.cachedContent[msg.ID] = cacheItem{ + width: m.width, + content: []uiMessage{userMsg}, + } pos += userMsg.height + 1 // + 1 for spacing case message.Assistant: - if messages, ok := m.cachedContent[msg.ID]; ok { - m.uiMessages = append(m.uiMessages, messages...) + if cache, ok := m.cachedContent[msg.ID]; ok && cache.width == m.width { + m.uiMessages = append(m.uiMessages, cache.content...) continue } assistantMessages := renderAssistantMessage( @@ -254,7 +270,10 @@ func (m *messagesCmp) renderView() { m.uiMessages = append(m.uiMessages, msg) pos += msg.height + 1 // + 1 for spacing } - m.cachedContent[msg.ID] = assistantMessages + m.cachedContent[msg.ID] = cacheItem{ + width: m.width, + content: assistantMessages, + } } } @@ -418,6 +437,10 @@ func (m *messagesCmp) SetSize(width, height int) tea.Cmd { m.height = height m.viewport.Width = width m.viewport.Height = height - 2 + for _, msg := range m.messages { + delete(m.cachedContent, msg.ID) + } + m.uiMessages = make([]uiMessage, 0) m.renderView() return m.preloadSessions() } @@ -456,7 +479,7 @@ func NewMessagesCmp(app *app.App) tea.Model { return &messagesCmp{ app: app, writingMode: true, - cachedContent: make(map[string][]uiMessage), + cachedContent: make(map[string]cacheItem), viewport: viewport.New(0, 0), spinner: s, } diff --git a/internal/tui/components/chat/message.go b/internal/tui/components/chat/message.go index be6c7ce5087276e7a616af1ed5da8933206ea972..7a840b4ec56e9d5225f35abdb7dad758091427ff 100644 --- a/internal/tui/components/chat/message.go +++ b/internal/tui/components/chat/message.go @@ -389,6 +389,7 @@ func renderToolResponse(toolCall message.ToolCall, response message.ToolResult, errContent := fmt.Sprintf("Error: %s", strings.ReplaceAll(response.Content, "\n", " ")) errContent = ansi.Truncate(errContent, width-1, "...") return styles.BaseStyle. + Width(width). Foreground(styles.Error). Render(errContent) } diff --git a/internal/tui/components/dialog/permission.go b/internal/tui/components/dialog/permission.go index f83472e68291b7deac8a243badb7ccab5db8c345..1f8df21a0b2cdfdd319ae8fd43c05178e1ad28f8 100644 --- a/internal/tui/components/dialog/permission.go +++ b/internal/tui/components/dialog/permission.go @@ -40,7 +40,8 @@ type PermissionDialogCmp interface { } type permissionsMapping struct { - LeftRight key.Binding + Left key.Binding + Right key.Binding EnterSpace key.Binding Allow key.Binding AllowSession key.Binding @@ -49,9 +50,13 @@ type permissionsMapping struct { } var permissionsKeys = permissionsMapping{ - LeftRight: key.NewBinding( - key.WithKeys("left", "right"), - key.WithHelp("←/→", "switch options"), + Left: key.NewBinding( + key.WithKeys("left"), + key.WithHelp("←", "switch options"), + ), + Right: key.NewBinding( + key.WithKeys("right"), + key.WithHelp("→", "switch options"), ), EnterSpace: key.NewBinding( key.WithKeys("enter", " "), @@ -104,21 +109,18 @@ func (p *permissionDialogCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) { p.diffCache = make(map[string]string) case tea.KeyMsg: switch { - case key.Matches(msg, permissionsKeys.LeftRight) || key.Matches(msg, permissionsKeys.Tab): - // Change selected option + case key.Matches(msg, permissionsKeys.Right) || key.Matches(msg, permissionsKeys.Tab): p.selectedOption = (p.selectedOption + 1) % 3 return p, nil + case key.Matches(msg, permissionsKeys.Left): + p.selectedOption = (p.selectedOption + 2) % 3 case key.Matches(msg, permissionsKeys.EnterSpace): - // Select current option return p, p.selectCurrentOption() case key.Matches(msg, permissionsKeys.Allow): - // Select Allow return p, util.CmdHandler(PermissionResponseMsg{Action: PermissionAllow, Permission: p.permission}) case key.Matches(msg, permissionsKeys.AllowSession): - // Select Allow for session return p, util.CmdHandler(PermissionResponseMsg{Action: PermissionAllowForSession, Permission: p.permission}) case key.Matches(msg, permissionsKeys.Deny): - // Select Deny return p, util.CmdHandler(PermissionResponseMsg{Action: PermissionDeny, Permission: p.permission}) default: // Pass other keys to viewport diff --git a/internal/tui/components/dialog/session.go b/internal/tui/components/dialog/session.go index d8c859c495345c6c7bb75c9f55070b21e4ed4b7a..060875f91e612e5658c254447f04f7b83e955ac9 100644 --- a/internal/tui/components/dialog/session.go +++ b/internal/tui/components/dialog/session.go @@ -27,20 +27,20 @@ type SessionDialog interface { } type sessionDialogCmp struct { - sessions []session.Session - selectedIdx int - width int - height int + sessions []session.Session + selectedIdx int + width int + height int selectedSessionID string } type sessionKeyMap struct { - Up key.Binding - Down key.Binding - Enter key.Binding - Escape key.Binding - J key.Binding - K key.Binding + Up key.Binding + Down key.Binding + Enter key.Binding + Escape key.Binding + J key.Binding + K key.Binding } var sessionKeys = sessionKeyMap{ @@ -128,7 +128,7 @@ func (s *sessionDialogCmp) View() string { // Build the session list sessionItems := make([]string, 0, maxVisibleSessions) startIdx := 0 - + // If we have more sessions than can be displayed, adjust the start index if len(s.sessions) > maxVisibleSessions { // Center the selected item when possible @@ -145,30 +145,31 @@ func (s *sessionDialogCmp) View() string { for i := startIdx; i < endIdx; i++ { sess := s.sessions[i] itemStyle := styles.BaseStyle.Width(maxWidth) - + if i == s.selectedIdx { itemStyle = itemStyle. Background(styles.PrimaryColor). Foreground(styles.Background). Bold(true) } - + sessionItems = append(sessionItems, itemStyle.Padding(0, 1).Render(sess.Title)) } title := styles.BaseStyle. Foreground(styles.PrimaryColor). Bold(true). + Width(maxWidth). Padding(0, 1). Render("Switch Session") content := lipgloss.JoinVertical( lipgloss.Left, title, - styles.BaseStyle.Render(""), - lipgloss.JoinVertical(lipgloss.Left, sessionItems...), - styles.BaseStyle.Render(""), - styles.BaseStyle.Foreground(styles.ForgroundDim).Render("↑/k: up ↓/j: down enter: select esc: cancel"), + styles.BaseStyle.Width(maxWidth).Render(""), + styles.BaseStyle.Width(maxWidth).Render(lipgloss.JoinVertical(lipgloss.Left, sessionItems...)), + styles.BaseStyle.Width(maxWidth).Render(""), + styles.BaseStyle.Width(maxWidth).Padding(0, 1).Foreground(styles.ForgroundDim).Render("↑/k: up ↓/j: down enter: select esc: cancel"), ) return styles.BaseStyle.Padding(1, 2). @@ -185,7 +186,7 @@ func (s *sessionDialogCmp) BindingKeys() []key.Binding { func (s *sessionDialogCmp) SetSessions(sessions []session.Session) { s.sessions = sessions - + // If we have a selected session ID, find its index if s.selectedSessionID != "" { for i, sess := range sessions { @@ -195,14 +196,14 @@ func (s *sessionDialogCmp) SetSessions(sessions []session.Session) { } } } - + // Default to first session if selected not found s.selectedIdx = 0 } func (s *sessionDialogCmp) SetSelectedSession(sessionID string) { s.selectedSessionID = sessionID - + // Update the selected index if sessions are already loaded if len(s.sessions) > 0 { for i, sess := range s.sessions { @@ -217,8 +218,9 @@ func (s *sessionDialogCmp) SetSelectedSession(sessionID string) { // NewSessionDialogCmp creates a new session switching dialog func NewSessionDialogCmp() SessionDialog { return &sessionDialogCmp{ - sessions: []session.Session{}, - selectedIdx: 0, + sessions: []session.Session{}, + selectedIdx: 0, selectedSessionID: "", } -} \ No newline at end of file +} + diff --git a/internal/tui/page/chat.go b/internal/tui/page/chat.go index b99dc3dfe387002a735bd821db4d7717f544cf80..ef826e9a335e4df8592523212e345f435f0c1e37 100644 --- a/internal/tui/page/chat.go +++ b/internal/tui/page/chat.go @@ -43,13 +43,6 @@ func (p *chatPage) Init() tea.Cmd { cmds := []tea.Cmd{ p.layout.Init(), } - - sessions, _ := p.app.Sessions.List(context.Background()) - if len(sessions) > 0 { - p.session = sessions[0] - cmd := p.setSidebar() - cmds = append(cmds, util.CmdHandler(chat.SessionSelectedMsg(p.session)), cmd) - } return tea.Batch(cmds...) } diff --git a/internal/tui/tui.go b/internal/tui/tui.go index f3a7298cf0936b78c650a0024ffc25d6dff78cbe..2a9ed0d70d193f0fdb568cd372b8d31241f43b87 100644 --- a/internal/tui/tui.go +++ b/internal/tui/tui.go @@ -163,6 +163,7 @@ func (a appModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) { a.showPermissions = true return a, a.permissions.SetPermissions(msg.Payload) case dialog.PermissionResponseMsg: + var cmd tea.Cmd switch msg.Action { case dialog.PermissionAllow: a.app.Permissions.Grant(msg.Permission) @@ -170,9 +171,10 @@ func (a appModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) { a.app.Permissions.GrantPersistant(msg.Permission) case dialog.PermissionDeny: a.app.Permissions.Deny(msg.Permission) + cmd = util.CmdHandler(chat.FocusEditorMsg(true)) } a.showPermissions = false - return a, nil + return a, cmd case page.PageChangeMsg: return a, a.moveToPage(msg.ID) From bf8cd3bd128d97cf93fcd1481c37db0e46945fd2 Mon Sep 17 00:00:00 2001 From: Kujtim Hoxha Date: Sat, 19 Apr 2025 13:01:18 +0200 Subject: [PATCH 28/41] add bedrock model --- internal/llm/models/models.go | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/internal/llm/models/models.go b/internal/llm/models/models.go index bbce6130e244d1fd69dabd50ea7c806741e5ce24..aba4a10c34909d094cb83b1c947f4c87dce1fe02 100644 --- a/internal/llm/models/models.go +++ b/internal/llm/models/models.go @@ -80,16 +80,16 @@ var SupportedModels = map[ModelID]Model{ // }, // // // Bedrock - // BedrockClaude37Sonnet: { - // ID: BedrockClaude37Sonnet, - // Name: "Bedrock: Claude 3.7 Sonnet", - // Provider: ProviderBedrock, - // APIModel: "anthropic.claude-3-7-sonnet-20250219-v1:0", - // CostPer1MIn: 3.0, - // CostPer1MInCached: 3.75, - // CostPer1MOutCached: 0.30, - // CostPer1MOut: 15.0, - // }, + BedrockClaude37Sonnet: { + ID: BedrockClaude37Sonnet, + Name: "Bedrock: Claude 3.7 Sonnet", + Provider: ProviderBedrock, + APIModel: "anthropic.claude-3-7-sonnet-20250219-v1:0", + CostPer1MIn: 3.0, + CostPer1MInCached: 3.75, + CostPer1MOutCached: 0.30, + CostPer1MOut: 15.0, + }, } func init() { From 2b5a33e476ae3c6b5c6345777d20792786836dda Mon Sep 17 00:00:00 2001 From: Kujtim Hoxha Date: Sat, 19 Apr 2025 15:15:29 +0200 Subject: [PATCH 29/41] lsp improvements --- internal/app/app.go | 3 +- internal/app/lsp.go | 29 +- internal/config/config.go | 12 +- internal/llm/models/gemini.go | 63 ++++ internal/llm/models/models.go | 5 +- internal/llm/prompt/coder.go | 2 + internal/llm/provider/gemini.go | 1 - internal/llm/tools/grep.go | 91 +++-- internal/lsp/client.go | 364 +++++++++++++++++++- internal/lsp/watcher/watcher.go | 428 +++++++++++++++++++++--- internal/tui/components/chat/chat.go | 12 +- internal/tui/components/chat/list.go | 9 +- internal/tui/components/chat/message.go | 6 +- internal/tui/components/core/status.go | 17 + internal/tui/styles/icons.go | 5 +- 15 files changed, 921 insertions(+), 126 deletions(-) create mode 100644 internal/llm/models/gemini.go diff --git a/internal/app/app.go b/internal/app/app.go index 8f4f5e098197530812a21c8d84973f570920458b..36b1ca16f9b70ac1b46c4de01612162d63d16f86 100644 --- a/internal/app/app.go +++ b/internal/app/app.go @@ -49,7 +49,8 @@ func New(ctx context.Context, conn *sql.DB) (*App, error) { LSPClients: make(map[string]*lsp.Client), } - app.initLSPClients(ctx) + // Initialize LSP clients in the background + go app.initLSPClients(ctx) var err error app.CoderAgent, err = agent.NewAgent( diff --git a/internal/app/lsp.go b/internal/app/lsp.go index d8a35c8b3a9646ab079bfa4fe6151fcb06aaf45f..77feeb94359a87af028dfc3dbc50075c76975ade 100644 --- a/internal/app/lsp.go +++ b/internal/app/lsp.go @@ -15,24 +15,28 @@ func (app *App) initLSPClients(ctx context.Context) { // Initialize LSP clients for name, clientConfig := range cfg.LSP { - app.createAndStartLSPClient(ctx, name, clientConfig.Command, clientConfig.Args...) + // Start each client initialization in its own goroutine + go app.createAndStartLSPClient(ctx, name, clientConfig.Command, clientConfig.Args...) } + logging.Info("LSP clients initialization started in background") } // createAndStartLSPClient creates a new LSP client, initializes it, and starts its workspace watcher func (app *App) createAndStartLSPClient(ctx context.Context, name string, command string, args ...string) { // Create a specific context for initialization with a timeout - + logging.Info("Creating LSP client", "name", name, "command", command, "args", args) + // Create the LSP client lspClient, err := lsp.NewClient(ctx, command, args...) if err != nil { logging.Error("Failed to create LSP client for", name, err) return - } - initCtx, cancel := context.WithTimeout(ctx, 15*time.Second) + // Create a longer timeout for initialization (some servers take time to start) + initCtx, cancel := context.WithTimeout(ctx, 30*time.Second) defer cancel() + // Initialize with the initialization context _, err = lspClient.InitializeLSPClient(initCtx, config.WorkingDirectory()) if err != nil { @@ -42,8 +46,25 @@ func (app *App) createAndStartLSPClient(ctx context.Context, name string, comman return } + // Wait for the server to be ready + if err := lspClient.WaitForServerReady(initCtx); err != nil { + logging.Error("Server failed to become ready", "name", name, "error", err) + // We'll continue anyway, as some functionality might still work + lspClient.SetServerState(lsp.StateError) + } else { + logging.Info("LSP server is ready", "name", name) + lspClient.SetServerState(lsp.StateReady) + } + + logging.Info("LSP client initialized", "name", name) + // Create a child context that can be canceled when the app is shutting down watchCtx, cancelFunc := context.WithCancel(ctx) + + // Create a context with the server name for better identification + watchCtx = context.WithValue(watchCtx, "serverName", name) + + // Create the workspace watcher workspaceWatcher := watcher.NewWorkspaceWatcher(lspClient) // Store the cancel function to be called during cleanup diff --git a/internal/config/config.go b/internal/config/config.go index 0cb727158aa5ff413caec01d9d990305ebb37572..2dbbcc9ca5153d518e6b9592fef8503d6b38b5d0 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -209,17 +209,17 @@ func setProviderDefaults() { // Google Gemini configuration if apiKey := os.Getenv("GEMINI_API_KEY"); apiKey != "" { viper.SetDefault("providers.gemini.apiKey", apiKey) - viper.SetDefault("agents.coder.model", models.GRMINI20Flash) - viper.SetDefault("agents.task.model", models.GRMINI20Flash) - viper.SetDefault("agents.title.model", models.GRMINI20Flash) + viper.SetDefault("agents.coder.model", models.Gemini25) + viper.SetDefault("agents.task.model", models.Gemini25Flash) + viper.SetDefault("agents.title.model", models.Gemini25Flash) } // OpenAI configuration if apiKey := os.Getenv("OPENAI_API_KEY"); apiKey != "" { viper.SetDefault("providers.openai.apiKey", apiKey) - viper.SetDefault("agents.coder.model", models.GPT4o) - viper.SetDefault("agents.task.model", models.GPT4o) - viper.SetDefault("agents.title.model", models.GPT4o) + viper.SetDefault("agents.coder.model", models.GPT41) + viper.SetDefault("agents.task.model", models.GPT41Mini) + viper.SetDefault("agents.title.model", models.GPT41Mini) } diff --git a/internal/llm/models/gemini.go b/internal/llm/models/gemini.go new file mode 100644 index 0000000000000000000000000000000000000000..00bf7387f52fc93ba12538ddd49cfa96b88b4b25 --- /dev/null +++ b/internal/llm/models/gemini.go @@ -0,0 +1,63 @@ +package models + +const ( + ProviderGemini ModelProvider = "gemini" + + // Models + Gemini25Flash ModelID = "gemini-2.5-flash" + Gemini25 ModelID = "gemini-2.5" + Gemini20Flash ModelID = "gemini-2.0-flash" + Gemini20FlashLite ModelID = "gemini-2.0-flash-lite" +) + +var GeminiModels = map[ModelID]Model{ + Gemini25Flash: { + ID: Gemini25Flash, + Name: "Gemini 2.5 Flash", + Provider: ProviderGemini, + APIModel: "gemini-2.5-flash-preview-04-17", + CostPer1MIn: 0.15, + CostPer1MInCached: 0, + CostPer1MOutCached: 0, + CostPer1MOut: 0.60, + ContextWindow: 1000000, + DefaultMaxTokens: 50000, + }, + Gemini25: { + ID: Gemini25, + Name: "Gemini 2.5 Pro", + Provider: ProviderGemini, + APIModel: "gemini-2.5-pro-preview-03-25", + CostPer1MIn: 1.25, + CostPer1MInCached: 0, + CostPer1MOutCached: 0, + CostPer1MOut: 10, + ContextWindow: 1000000, + DefaultMaxTokens: 50000, + }, + + Gemini20Flash: { + ID: Gemini20Flash, + Name: "Gemini 2.0 Flash", + Provider: ProviderGemini, + APIModel: "gemini-2.0-flash", + CostPer1MIn: 0.10, + CostPer1MInCached: 0, + CostPer1MOutCached: 0, + CostPer1MOut: 0.40, + ContextWindow: 1000000, + DefaultMaxTokens: 6000, + }, + Gemini20FlashLite: { + ID: Gemini20FlashLite, + Name: "Gemini 2.0 Flash Lite", + Provider: ProviderGemini, + APIModel: "gemini-2.0-flash-lite", + CostPer1MIn: 0.05, + CostPer1MInCached: 0, + CostPer1MOutCached: 0, + CostPer1MOut: 0.30, + ContextWindow: 1000000, + DefaultMaxTokens: 6000, + }, +} diff --git a/internal/llm/models/models.go b/internal/llm/models/models.go index aba4a10c34909d094cb83b1c947f4c87dce1fe02..cccbd2765331ed02f6f5cd840b5d704b09d7030b 100644 --- a/internal/llm/models/models.go +++ b/internal/llm/models/models.go @@ -23,9 +23,6 @@ type Model struct { // Model IDs const ( // GEMINI - GEMINI25 ModelID = "gemini-2.5" - GRMINI20Flash ModelID = "gemini-2.0-flash" - // GROQ QWENQwq ModelID = "qwen-qwq" @@ -35,7 +32,6 @@ const ( // GEMINI const ( ProviderBedrock ModelProvider = "bedrock" - ProviderGemini ModelProvider = "gemini" ProviderGROQ ModelProvider = "groq" // ForTests @@ -95,4 +91,5 @@ var SupportedModels = map[ModelID]Model{ func init() { maps.Copy(SupportedModels, AnthropicModels) maps.Copy(SupportedModels, OpenAIModels) + maps.Copy(SupportedModels, GeminiModels) } diff --git a/internal/llm/prompt/coder.go b/internal/llm/prompt/coder.go index d7ca7b2fde3629bfc77dc9105279d61d674dad3f..cc0da03133f0814063a287facde9b656247c859b 100644 --- a/internal/llm/prompt/coder.go +++ b/internal/llm/prompt/coder.go @@ -68,6 +68,7 @@ You MUST adhere to the following criteria when executing the task: - Do NOT show the full contents of large files you have already written, unless the user explicitly asks for them. - When doing things with paths, always use use the full path, if the working directory is /abc/xyz and you want to edit the file abc.go in the working dir refer to it as /abc/xyz/abc.go. - If you send a path not including the working dir, the working dir will be prepended to it. +- Remember the user does not see the full output of tools ` const baseAnthropicCoderPrompt = `You are OpenCode, an interactive CLI tool that helps users with software engineering tasks. Use the instructions below and the tools available to you to assist the user. @@ -162,6 +163,7 @@ NEVER commit changes unless the user explicitly asks you to. It is VERY IMPORTAN # Tool usage policy - When doing file search, prefer to use the Agent tool in order to reduce context usage. - If you intend to call multiple tools and there are no dependencies between the calls, make all of the independent calls in the same function_calls block. +- IMPORTANT: The user does not see the full output of the tool responses, so if you need the output of the tool for the response make sure to summarize it for the user. You MUST answer concisely with fewer than 4 lines of text (not including tool use or code generation), unless user asks for detail.` diff --git a/internal/llm/provider/gemini.go b/internal/llm/provider/gemini.go index 384bff900aeafeb7b6e066c2e2dbb35616a4e808..a5e6ed8774271a35d10ef242a05aa8de464398b9 100644 --- a/internal/llm/provider/gemini.go +++ b/internal/llm/provider/gemini.go @@ -567,4 +567,3 @@ func contains(s string, substrs ...string) bool { } return false } - diff --git a/internal/llm/tools/grep.go b/internal/llm/tools/grep.go index 086a5e686cce7e90a386a0273a76d62018ea7181..475370ffb1ba01ec2a4e508c31efa13f4bead70a 100644 --- a/internal/llm/tools/grep.go +++ b/internal/llm/tools/grep.go @@ -10,6 +10,7 @@ import ( "path/filepath" "regexp" "sort" + "strconv" "strings" "time" @@ -24,8 +25,10 @@ type GrepParams struct { } type grepMatch struct { - path string - modTime time.Time + path string + modTime time.Time + lineNum int + lineText string } type GrepResponseMetadata struct { @@ -147,13 +150,26 @@ func (g *grepTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error) if len(matches) == 0 { output = "No files found" } else { - output = fmt.Sprintf("Found %d file%s\n%s", - len(matches), - pluralize(len(matches)), - strings.Join(matches, "\n")) + output = fmt.Sprintf("Found %d matches\n", len(matches)) + + currentFile := "" + for _, match := range matches { + if currentFile != match.path { + if currentFile != "" { + output += "\n" + } + currentFile = match.path + output += fmt.Sprintf("%s:\n", match.path) + } + if match.lineNum > 0 { + output += fmt.Sprintf(" Line %d: %s\n", match.lineNum, match.lineText) + } else { + output += fmt.Sprintf(" %s\n", match.path) + } + } if truncated { - output += "\n\n(Results are truncated. Consider using a more specific path or pattern.)" + output += "\n(Results are truncated. Consider using a more specific path or pattern.)" } } @@ -166,14 +182,7 @@ func (g *grepTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error) ), nil } -func pluralize(count int) string { - if count == 1 { - return "" - } - return "s" -} - -func searchFiles(pattern, rootPath, include string, limit int) ([]string, bool, error) { +func searchFiles(pattern, rootPath, include string, limit int) ([]grepMatch, bool, error) { matches, err := searchWithRipgrep(pattern, rootPath, include) if err != nil { matches, err = searchFilesWithRegex(pattern, rootPath, include) @@ -191,12 +200,7 @@ func searchFiles(pattern, rootPath, include string, limit int) ([]string, bool, matches = matches[:limit] } - results := make([]string, len(matches)) - for i, m := range matches { - results[i] = m.path - } - - return results, truncated, nil + return matches, truncated, nil } func searchWithRipgrep(pattern, path, include string) ([]grepMatch, error) { @@ -205,7 +209,8 @@ func searchWithRipgrep(pattern, path, include string) ([]grepMatch, error) { return nil, fmt.Errorf("ripgrep not found: %w", err) } - args := []string{"-l", pattern} + // Use -n to show line numbers and include the matched line + args := []string{"-n", pattern} if include != "" { args = append(args, "--glob", include) } @@ -228,14 +233,29 @@ func searchWithRipgrep(pattern, path, include string) ([]grepMatch, error) { continue } - fileInfo, err := os.Stat(line) + // Parse ripgrep output format: file:line:content + parts := strings.SplitN(line, ":", 3) + if len(parts) < 3 { + continue + } + + filePath := parts[0] + lineNum, err := strconv.Atoi(parts[1]) + if err != nil { + continue + } + lineText := parts[2] + + fileInfo, err := os.Stat(filePath) if err != nil { continue // Skip files we can't access } matches = append(matches, grepMatch{ - path: line, - modTime: fileInfo.ModTime(), + path: filePath, + modTime: fileInfo.ModTime(), + lineNum: lineNum, + lineText: lineText, }) } @@ -276,15 +296,17 @@ func searchFilesWithRegex(pattern, rootPath, include string) ([]grepMatch, error return nil } - match, err := fileContainsPattern(path, regex) + match, lineNum, lineText, err := fileContainsPattern(path, regex) if err != nil { return nil // Skip files we can't read } if match { matches = append(matches, grepMatch{ - path: path, - modTime: info.ModTime(), + path: path, + modTime: info.ModTime(), + lineNum: lineNum, + lineText: lineText, }) if len(matches) >= 200 { @@ -301,21 +323,24 @@ func searchFilesWithRegex(pattern, rootPath, include string) ([]grepMatch, error return matches, nil } -func fileContainsPattern(filePath string, pattern *regexp.Regexp) (bool, error) { +func fileContainsPattern(filePath string, pattern *regexp.Regexp) (bool, int, string, error) { file, err := os.Open(filePath) if err != nil { - return false, err + return false, 0, "", err } defer file.Close() scanner := bufio.NewScanner(file) + lineNum := 0 for scanner.Scan() { - if pattern.MatchString(scanner.Text()) { - return true, nil + lineNum++ + line := scanner.Text() + if pattern.MatchString(line) { + return true, lineNum, line, nil } } - return false, scanner.Err() + return false, 0, "", scanner.Err() } func globToRegex(glob string) string { diff --git a/internal/lsp/client.go b/internal/lsp/client.go index dad07f3c0e9d9a6bf816c8dbfb12361969ba8d68..932badc0b1957d38f5cfe9261b325ece6b9af8f5 100644 --- a/internal/lsp/client.go +++ b/internal/lsp/client.go @@ -8,6 +8,7 @@ import ( "io" "os" "os/exec" + "path/filepath" "strings" "sync" "sync/atomic" @@ -46,6 +47,9 @@ type Client struct { // Files are currently opened by the LSP openFiles map[string]*OpenFileInfo openFilesMu sync.RWMutex + + // Server state + serverState atomic.Value } func NewClient(ctx context.Context, command string, args ...string) (*Client, error) { @@ -80,6 +84,9 @@ func NewClient(ctx context.Context, command string, args ...string) (*Client, er openFiles: make(map[string]*OpenFileInfo), } + // Initialize server state + client.serverState.Store(StateStarting) + // Start the LSP server process if err := cmd.Start(); err != nil { return nil, fmt.Errorf("failed to start LSP server: %w", err) @@ -220,16 +227,6 @@ func (c *Client) InitializeLSPClient(ctx context.Context, workspaceDir string) ( return nil, fmt.Errorf("initialization failed: %w", err) } - // LSP sepecific Initialization - path := strings.ToLower(c.Cmd.Path) - switch { - case strings.Contains(path, "typescript-language-server"): - // err := initializeTypescriptLanguageServer(ctx, c, workspaceDir) - // if err != nil { - // return nil, err - // } - } - return &result, nil } @@ -273,10 +270,314 @@ const ( StateError ) +// GetServerState returns the current state of the LSP server +func (c *Client) GetServerState() ServerState { + if val := c.serverState.Load(); val != nil { + return val.(ServerState) + } + return StateStarting +} + +// SetServerState sets the current state of the LSP server +func (c *Client) SetServerState(state ServerState) { + c.serverState.Store(state) +} + +// WaitForServerReady waits for the server to be ready by polling the server +// with a simple request until it responds successfully or times out func (c *Client) WaitForServerReady(ctx context.Context) error { - // TODO: wait for specific messages or poll workspace/symbol - time.Sleep(time.Second * 1) - return nil + cnf := config.Get() + + // Set initial state + c.SetServerState(StateStarting) + + // Create a context with timeout + ctx, cancel := context.WithTimeout(ctx, 30*time.Second) + defer cancel() + + // Try to ping the server with a simple request + ticker := time.NewTicker(500 * time.Millisecond) + defer ticker.Stop() + + if cnf.DebugLSP { + logging.Debug("Waiting for LSP server to be ready...") + } + + // Determine server type for specialized initialization + serverType := c.detectServerType() + + // For TypeScript-like servers, we need to open some key files first + if serverType == ServerTypeTypeScript { + if cnf.DebugLSP { + logging.Debug("TypeScript-like server detected, opening key configuration files") + } + c.openKeyConfigFiles(ctx) + } + + for { + select { + case <-ctx.Done(): + c.SetServerState(StateError) + return fmt.Errorf("timeout waiting for LSP server to be ready") + case <-ticker.C: + // Try a ping method appropriate for this server type + err := c.pingServerByType(ctx, serverType) + if err == nil { + // Server responded successfully + c.SetServerState(StateReady) + if cnf.DebugLSP { + logging.Debug("LSP server is ready") + } + return nil + } else { + logging.Debug("LSP server not ready yet", "error", err, "serverType", serverType) + } + + if cnf.DebugLSP { + logging.Debug("LSP server not ready yet", "error", err, "serverType", serverType) + } + } + } +} + +// ServerType represents the type of LSP server +type ServerType int + +const ( + ServerTypeUnknown ServerType = iota + ServerTypeGo + ServerTypeTypeScript + ServerTypeRust + ServerTypePython + ServerTypeGeneric +) + +// detectServerType tries to determine what type of LSP server we're dealing with +func (c *Client) detectServerType() ServerType { + if c.Cmd == nil { + return ServerTypeUnknown + } + + cmdPath := strings.ToLower(c.Cmd.Path) + + switch { + case strings.Contains(cmdPath, "gopls"): + return ServerTypeGo + case strings.Contains(cmdPath, "typescript") || strings.Contains(cmdPath, "vtsls") || strings.Contains(cmdPath, "tsserver"): + return ServerTypeTypeScript + case strings.Contains(cmdPath, "rust-analyzer"): + return ServerTypeRust + case strings.Contains(cmdPath, "pyright") || strings.Contains(cmdPath, "pylsp") || strings.Contains(cmdPath, "python"): + return ServerTypePython + default: + return ServerTypeGeneric + } +} + +// openKeyConfigFiles opens important configuration files that help initialize the server +func (c *Client) openKeyConfigFiles(ctx context.Context) { + workDir := config.WorkingDirectory() + serverType := c.detectServerType() + + var filesToOpen []string + + switch serverType { + case ServerTypeTypeScript: + // TypeScript servers need these config files to properly initialize + filesToOpen = []string{ + filepath.Join(workDir, "tsconfig.json"), + filepath.Join(workDir, "package.json"), + filepath.Join(workDir, "jsconfig.json"), + } + + // Also find and open a few TypeScript files to help the server initialize + c.openTypeScriptFiles(ctx, workDir) + case ServerTypeGo: + filesToOpen = []string{ + filepath.Join(workDir, "go.mod"), + filepath.Join(workDir, "go.sum"), + } + case ServerTypeRust: + filesToOpen = []string{ + filepath.Join(workDir, "Cargo.toml"), + filepath.Join(workDir, "Cargo.lock"), + } + } + + // Try to open each file, ignoring errors if they don't exist + for _, file := range filesToOpen { + if _, err := os.Stat(file); err == nil { + // File exists, try to open it + if err := c.OpenFile(ctx, file); err != nil { + logging.Debug("Failed to open key config file", "file", file, "error", err) + } else { + logging.Debug("Opened key config file for initialization", "file", file) + } + } + } +} + +// pingServerByType sends a ping request appropriate for the server type +func (c *Client) pingServerByType(ctx context.Context, serverType ServerType) error { + switch serverType { + case ServerTypeTypeScript: + // For TypeScript, try a document symbol request on an open file + return c.pingTypeScriptServer(ctx) + case ServerTypeGo: + // For Go, workspace/symbol works well + return c.pingWithWorkspaceSymbol(ctx) + case ServerTypeRust: + // For Rust, workspace/symbol works well + return c.pingWithWorkspaceSymbol(ctx) + default: + // Default ping method + return c.pingWithWorkspaceSymbol(ctx) + } +} + +// pingTypeScriptServer tries to ping a TypeScript server with appropriate methods +func (c *Client) pingTypeScriptServer(ctx context.Context) error { + // First try workspace/symbol which works for many servers + if err := c.pingWithWorkspaceSymbol(ctx); err == nil { + return nil + } + + // If that fails, try to find an open file and request document symbols + c.openFilesMu.RLock() + defer c.openFilesMu.RUnlock() + + // If we have any open files, try to get document symbols for one + for uri := range c.openFiles { + filePath := strings.TrimPrefix(uri, "file://") + if strings.HasSuffix(filePath, ".ts") || strings.HasSuffix(filePath, ".js") || + strings.HasSuffix(filePath, ".tsx") || strings.HasSuffix(filePath, ".jsx") { + var symbols []protocol.DocumentSymbol + err := c.Call(ctx, "textDocument/documentSymbol", protocol.DocumentSymbolParams{ + TextDocument: protocol.TextDocumentIdentifier{ + URI: protocol.DocumentUri(uri), + }, + }, &symbols) + if err == nil { + return nil + } + } + } + + // If we have no open TypeScript files, try to find and open one + workDir := config.WorkingDirectory() + err := filepath.WalkDir(workDir, func(path string, d os.DirEntry, err error) error { + if err != nil { + return err + } + + // Skip directories and non-TypeScript files + if d.IsDir() { + return nil + } + + ext := filepath.Ext(path) + if ext == ".ts" || ext == ".js" || ext == ".tsx" || ext == ".jsx" { + // Found a TypeScript file, try to open it + if err := c.OpenFile(ctx, path); err == nil { + // Successfully opened, stop walking + return filepath.SkipAll + } + } + + return nil + }) + if err != nil { + logging.Debug("Error walking directory for TypeScript files", "error", err) + } + + // Final fallback - just try a generic capability + return c.pingWithServerCapabilities(ctx) +} + +// openTypeScriptFiles finds and opens TypeScript files to help initialize the server +func (c *Client) openTypeScriptFiles(ctx context.Context, workDir string) { + cnf := config.Get() + filesOpened := 0 + maxFilesToOpen := 5 // Limit to a reasonable number of files + + // Find and open TypeScript files + err := filepath.WalkDir(workDir, func(path string, d os.DirEntry, err error) error { + if err != nil { + return err + } + + // Skip directories and non-TypeScript files + if d.IsDir() { + // Skip common directories to avoid wasting time + if shouldSkipDir(path) { + return filepath.SkipDir + } + return nil + } + + // Check if we've opened enough files + if filesOpened >= maxFilesToOpen { + return filepath.SkipAll + } + + // Check file extension + ext := filepath.Ext(path) + if ext == ".ts" || ext == ".tsx" || ext == ".js" || ext == ".jsx" { + // Try to open the file + if err := c.OpenFile(ctx, path); err == nil { + filesOpened++ + if cnf.DebugLSP { + logging.Debug("Opened TypeScript file for initialization", "file", path) + } + } + } + + return nil + }) + + if err != nil && cnf.DebugLSP { + logging.Debug("Error walking directory for TypeScript files", "error", err) + } + + if cnf.DebugLSP { + logging.Debug("Opened TypeScript files for initialization", "count", filesOpened) + } +} + +// shouldSkipDir returns true if the directory should be skipped during file search +func shouldSkipDir(path string) bool { + dirName := filepath.Base(path) + + // Skip hidden directories + if strings.HasPrefix(dirName, ".") { + return true + } + + // Skip common directories that won't contain relevant source files + skipDirs := map[string]bool{ + "node_modules": true, + "dist": true, + "build": true, + "coverage": true, + "vendor": true, + "target": true, + } + + return skipDirs[dirName] +} + +// pingWithWorkspaceSymbol tries a workspace/symbol request +func (c *Client) pingWithWorkspaceSymbol(ctx context.Context) error { + var result []protocol.SymbolInformation + return c.Call(ctx, "workspace/symbol", protocol.WorkspaceSymbolParams{ + Query: "", + }, &result) +} + +// pingWithServerCapabilities tries to get server capabilities +func (c *Client) pingWithServerCapabilities(ctx context.Context) error { + // This is a very lightweight request that should work for most servers + return c.Notify(ctx, "$/cancelRequest", struct{ ID int }{ID: -1}) } type OpenFileInfo struct { @@ -435,6 +736,43 @@ func (c *Client) GetFileDiagnostics(uri protocol.DocumentUri) []protocol.Diagnos return c.diagnostics[uri] } +// GetDiagnostics returns all diagnostics for all files func (c *Client) GetDiagnostics() map[protocol.DocumentUri][]protocol.Diagnostic { return c.diagnostics } + +// OpenFileOnDemand opens a file only if it's not already open +// This is used for lazy-loading files when they're actually needed +func (c *Client) OpenFileOnDemand(ctx context.Context, filepath string) error { + // Check if the file is already open + if c.IsFileOpen(filepath) { + return nil + } + + // Open the file + return c.OpenFile(ctx, filepath) +} + +// GetDiagnosticsForFile ensures a file is open and returns its diagnostics +// This is useful for on-demand diagnostics when using lazy loading +func (c *Client) GetDiagnosticsForFile(ctx context.Context, filepath string) ([]protocol.Diagnostic, error) { + uri := fmt.Sprintf("file://%s", filepath) + documentUri := protocol.DocumentUri(uri) + + // Make sure the file is open + if !c.IsFileOpen(filepath) { + if err := c.OpenFile(ctx, filepath); err != nil { + return nil, fmt.Errorf("failed to open file for diagnostics: %w", err) + } + + // Give the LSP server a moment to process the file + time.Sleep(100 * time.Millisecond) + } + + // Get diagnostics + c.diagnosticsMu.RLock() + diagnostics := c.diagnostics[documentUri] + c.diagnosticsMu.RUnlock() + + return diagnostics, nil +} diff --git a/internal/lsp/watcher/watcher.go b/internal/lsp/watcher/watcher.go index 595c78db9154c3ba94d2e0a8dc4590750280c6f6..58dd01f709b2127412f728003406e1084b49ef21 100644 --- a/internal/lsp/watcher/watcher.go +++ b/internal/lsp/watcher/watcher.go @@ -9,6 +9,7 @@ import ( "sync" "time" + "github.com/bmatcuk/doublestar/v4" "github.com/fsnotify/fsnotify" "github.com/kujtimiihoxha/opencode/internal/config" "github.com/kujtimiihoxha/opencode/internal/logging" @@ -43,6 +44,8 @@ func NewWorkspaceWatcher(client *lsp.Client) *WorkspaceWatcher { // AddRegistrations adds file watchers to track func (w *WorkspaceWatcher) AddRegistrations(ctx context.Context, id string, watchers []protocol.FileSystemWatcher) { cnf := config.Get() + + logging.Debug("Adding file watcher registrations") w.registrationMu.Lock() defer w.registrationMu.Unlock() @@ -55,7 +58,6 @@ func (w *WorkspaceWatcher) AddRegistrations(ctx context.Context, id string, watc "id", id, "watchers", len(watchers), "total", len(w.registrations), - "watchers", watchers, ) for i, watcher := range watchers { @@ -88,66 +90,217 @@ func (w *WorkspaceWatcher) AddRegistrations(ctx context.Context, id string, watc } logging.Debug("WatchKind", "kind", watchKind) - - // Test match against some example paths - testPaths := []string{ - "/Users/phil/dev/mcp-language-server/internal/watcher/watcher.go", - "/Users/phil/dev/mcp-language-server/go.mod", - } - - for _, testPath := range testPaths { - isMatch := w.matchesPattern(testPath, watcher.GlobPattern) - logging.Debug("Test path", "path", testPath, "matches", isMatch) - } } } - // Find and open all existing files that match the newly registered patterns - // TODO: not all language servers require this, but typescript does. Make this configurable - go func() { - startTime := time.Now() - filesOpened := 0 - - err := filepath.WalkDir(w.workspacePath, func(path string, d os.DirEntry, err error) error { - if err != nil { - return err + // Determine server type for specialized handling + serverName := getServerNameFromContext(ctx) + logging.Debug("Server type detected", "serverName", serverName) + + // Check if this server has sent file watchers + hasFileWatchers := len(watchers) > 0 + + // For servers that need file preloading, we'll use a smart approach + if shouldPreloadFiles(serverName) || !hasFileWatchers { + go func() { + startTime := time.Now() + filesOpened := 0 + + // Determine max files to open based on server type + maxFilesToOpen := 50 // Default conservative limit + + switch serverName { + case "typescript", "typescript-language-server", "tsserver", "vtsls": + // TypeScript servers benefit from seeing more files + maxFilesToOpen = 100 + case "java", "jdtls": + // Java servers need to see many files for project model + maxFilesToOpen = 200 + } + + // First, open high-priority files + highPriorityFilesOpened := w.openHighPriorityFiles(ctx, serverName) + filesOpened += highPriorityFilesOpened + + if cnf.DebugLSP { + logging.Debug("Opened high-priority files", + "count", highPriorityFilesOpened, + "serverName", serverName) } + + // If we've already opened enough high-priority files, we might not need more + if filesOpened >= maxFilesToOpen { + if cnf.DebugLSP { + logging.Debug("Reached file limit with high-priority files", + "filesOpened", filesOpened, + "maxFiles", maxFilesToOpen) + } + return + } + + // For the remaining slots, walk the directory and open matching files + + err := filepath.WalkDir(w.workspacePath, func(path string, d os.DirEntry, err error) error { + if err != nil { + return err + } - // Skip directories that should be excluded - if d.IsDir() { - if path != w.workspacePath && shouldExcludeDir(path) { - if cnf.DebugLSP { - logging.Debug("Skipping excluded directory", "path", path) + // Skip directories that should be excluded + if d.IsDir() { + if path != w.workspacePath && shouldExcludeDir(path) { + if cnf.DebugLSP { + logging.Debug("Skipping excluded directory", "path", path) + } + return filepath.SkipDir + } + } else { + // Process files, but limit the total number + if filesOpened < maxFilesToOpen { + // Only process if it's not already open (high-priority files were opened earlier) + if !w.client.IsFileOpen(path) { + w.openMatchingFile(ctx, path) + filesOpened++ + + // Add a small delay after every 10 files to prevent overwhelming the server + if filesOpened%10 == 0 { + time.Sleep(50 * time.Millisecond) + } + } + } else { + // We've reached our limit, stop walking + return filepath.SkipAll } - return filepath.SkipDir } - } else { - // Process files - w.openMatchingFile(ctx, path) - filesOpened++ - // Add a small delay after every 100 files to prevent overwhelming the server - if filesOpened%100 == 0 { - time.Sleep(10 * time.Millisecond) - } + return nil + }) + + elapsedTime := time.Since(startTime) + if cnf.DebugLSP { + logging.Debug("Limited workspace scan complete", + "filesOpened", filesOpened, + "maxFiles", maxFilesToOpen, + "elapsedTime", elapsedTime.Seconds(), + "workspacePath", w.workspacePath, + ) } - return nil - }) + if err != nil && cnf.DebugLSP { + logging.Debug("Error scanning workspace for files to open", "error", err) + } + }() + } else if cnf.DebugLSP { + logging.Debug("Using on-demand file loading for server", "server", serverName) + } +} - elapsedTime := time.Since(startTime) - if cnf.DebugLSP { - logging.Debug("Workspace scan complete", - "filesOpened", filesOpened, - "elapsedTime", elapsedTime.Seconds(), - "workspacePath", w.workspacePath, - ) +// openHighPriorityFiles opens important files for the server type +// Returns the number of files opened +func (w *WorkspaceWatcher) openHighPriorityFiles(ctx context.Context, serverName string) int { + cnf := config.Get() + filesOpened := 0 + + // Define patterns for high-priority files based on server type + var patterns []string + + switch serverName { + case "typescript", "typescript-language-server", "tsserver", "vtsls": + patterns = []string{ + "**/tsconfig.json", + "**/package.json", + "**/jsconfig.json", + "**/index.ts", + "**/index.js", + "**/main.ts", + "**/main.js", } - - if err != nil && cnf.DebugLSP { - logging.Debug("Error scanning workspace for files to open", "error", err) + case "gopls": + patterns = []string{ + "**/go.mod", + "**/go.sum", + "**/main.go", + } + case "rust-analyzer": + patterns = []string{ + "**/Cargo.toml", + "**/Cargo.lock", + "**/src/lib.rs", + "**/src/main.rs", + } + case "python", "pyright", "pylsp": + patterns = []string{ + "**/pyproject.toml", + "**/setup.py", + "**/requirements.txt", + "**/__init__.py", + "**/__main__.py", + } + case "clangd": + patterns = []string{ + "**/CMakeLists.txt", + "**/Makefile", + "**/compile_commands.json", + } + case "java", "jdtls": + patterns = []string{ + "**/pom.xml", + "**/build.gradle", + "**/src/main/java/**/*.java", } - }() + default: + // For unknown servers, use common configuration files + patterns = []string{ + "**/package.json", + "**/Makefile", + "**/CMakeLists.txt", + "**/.editorconfig", + } + } + + // For each pattern, find and open matching files + for _, pattern := range patterns { + // Use doublestar.Glob to find files matching the pattern (supports ** patterns) + matches, err := doublestar.Glob(os.DirFS(w.workspacePath), pattern) + if err != nil { + if cnf.DebugLSP { + logging.Debug("Error finding high-priority files", "pattern", pattern, "error", err) + } + continue + } + + for _, match := range matches { + // Convert relative path to absolute + fullPath := filepath.Join(w.workspacePath, match) + + // Skip directories and excluded files + info, err := os.Stat(fullPath) + if err != nil || info.IsDir() || shouldExcludeFile(fullPath) { + continue + } + + // Open the file + if err := w.client.OpenFile(ctx, fullPath); err != nil { + if cnf.DebugLSP { + logging.Debug("Error opening high-priority file", "path", fullPath, "error", err) + } + } else { + filesOpened++ + if cnf.DebugLSP { + logging.Debug("Opened high-priority file", "path", fullPath) + } + } + + // Add a small delay to prevent overwhelming the server + time.Sleep(20 * time.Millisecond) + + // Limit the number of files opened per pattern + if filesOpened >= 5 && (serverName != "java" && serverName != "jdtls") { + break + } + } + } + + return filesOpened } // WatchWorkspace sets up file watching for a workspace @@ -155,6 +308,18 @@ func (w *WorkspaceWatcher) WatchWorkspace(ctx context.Context, workspacePath str cnf := config.Get() w.workspacePath = workspacePath + // Store the watcher in the context for later use + ctx = context.WithValue(ctx, "workspaceWatcher", w) + + // If the server name isn't already in the context, try to detect it + if _, ok := ctx.Value("serverName").(string); !ok { + serverName := getServerNameFromContext(ctx) + ctx = context.WithValue(ctx, "serverName", serverName) + } + + serverName := getServerNameFromContext(ctx) + logging.Debug("Starting workspace watcher", "workspacePath", workspacePath, "serverName", serverName) + // Register handler for file watcher registrations from the server lsp.RegisterFileWatchHandler(func(id string, watchers []protocol.FileSystemWatcher) { w.AddRegistrations(ctx, id, watchers) @@ -510,6 +675,57 @@ func (w *WorkspaceWatcher) notifyFileEvent(ctx context.Context, uri string, chan return w.client.DidChangeWatchedFiles(ctx, params) } +// getServerNameFromContext extracts the server name from the context +// This is a best-effort function that tries to identify which LSP server we're dealing with +func getServerNameFromContext(ctx context.Context) string { + // First check if the server name is directly stored in the context + if serverName, ok := ctx.Value("serverName").(string); ok && serverName != "" { + return strings.ToLower(serverName) + } + + // Otherwise, try to extract server name from the client command path + if w, ok := ctx.Value("workspaceWatcher").(*WorkspaceWatcher); ok && w != nil && w.client != nil && w.client.Cmd != nil { + path := strings.ToLower(w.client.Cmd.Path) + + // Extract server name from path + if strings.Contains(path, "typescript") || strings.Contains(path, "tsserver") || strings.Contains(path, "vtsls") { + return "typescript" + } else if strings.Contains(path, "gopls") { + return "gopls" + } else if strings.Contains(path, "rust-analyzer") { + return "rust-analyzer" + } else if strings.Contains(path, "pyright") || strings.Contains(path, "pylsp") || strings.Contains(path, "python") { + return "python" + } else if strings.Contains(path, "clangd") { + return "clangd" + } else if strings.Contains(path, "jdtls") || strings.Contains(path, "java") { + return "java" + } + + // Return the base name as fallback + return filepath.Base(path) + } + + return "unknown" +} + +// shouldPreloadFiles determines if we should preload files for a specific language server +// Some servers work better with preloaded files, others don't need it +func shouldPreloadFiles(serverName string) bool { + // TypeScript/JavaScript servers typically need some files preloaded + // to properly resolve imports and provide intellisense + switch serverName { + case "typescript", "typescript-language-server", "tsserver", "vtsls": + return true + case "java", "jdtls": + // Java servers often need to see source files to build the project model + return true + default: + // For most servers, we'll use lazy loading by default + return false + } +} + // Common patterns for directories and files to exclude // TODO: make configurable var ( @@ -647,9 +863,119 @@ func (w *WorkspaceWatcher) openMatchingFile(ctx context.Context, path string) { // Check if this path should be watched according to server registrations if watched, _ := w.isPathWatched(path); watched { - // Don't need to check if it's already open - the client.OpenFile handles that - if err := w.client.OpenFile(ctx, path); err != nil && cnf.DebugLSP { - logging.Error("Error opening file", "path", path, "error", err) + // Get server name for specialized handling + serverName := getServerNameFromContext(ctx) + + // Check if the file is a high-priority file that should be opened immediately + // This helps with project initialization for certain language servers + if isHighPriorityFile(path, serverName) { + if cnf.DebugLSP { + logging.Debug("Opening high-priority file", "path", path, "serverName", serverName) + } + if err := w.client.OpenFile(ctx, path); err != nil && cnf.DebugLSP { + logging.Error("Error opening high-priority file", "path", path, "error", err) + } + return + } + + // For non-high-priority files, we'll use different strategies based on server type + if shouldPreloadFiles(serverName) { + // For servers that benefit from preloading, open files but with limits + + // Check file size - for preloading we're more conservative + if info.Size() > (1 * 1024 * 1024) { // 1MB limit for preloaded files + if cnf.DebugLSP { + logging.Debug("Skipping large file for preloading", "path", path, "size", info.Size()) + } + return + } + + // Check file extension for common source files + ext := strings.ToLower(filepath.Ext(path)) + + // Only preload source files for the specific language + shouldOpen := false + + switch serverName { + case "typescript", "typescript-language-server", "tsserver", "vtsls": + shouldOpen = ext == ".ts" || ext == ".js" || ext == ".tsx" || ext == ".jsx" + case "gopls": + shouldOpen = ext == ".go" + case "rust-analyzer": + shouldOpen = ext == ".rs" + case "python", "pyright", "pylsp": + shouldOpen = ext == ".py" + case "clangd": + shouldOpen = ext == ".c" || ext == ".cpp" || ext == ".h" || ext == ".hpp" + case "java", "jdtls": + shouldOpen = ext == ".java" + default: + // For unknown servers, be conservative + shouldOpen = false + } + + if shouldOpen { + // Don't need to check if it's already open - the client.OpenFile handles that + if err := w.client.OpenFile(ctx, path); err != nil && cnf.DebugLSP { + logging.Error("Error opening file", "path", path, "error", err) + } + } } } } + +// isHighPriorityFile determines if a file should be opened immediately +// regardless of the preloading strategy +func isHighPriorityFile(path string, serverName string) bool { + fileName := filepath.Base(path) + ext := filepath.Ext(path) + + switch serverName { + case "typescript", "typescript-language-server", "tsserver", "vtsls": + // For TypeScript, we want to open configuration files immediately + return fileName == "tsconfig.json" || + fileName == "package.json" || + fileName == "jsconfig.json" || + // Also open main entry points + fileName == "index.ts" || + fileName == "index.js" || + fileName == "main.ts" || + fileName == "main.js" + case "gopls": + // For Go, we want to open go.mod files immediately + return fileName == "go.mod" || + fileName == "go.sum" || + // Also open main.go files + fileName == "main.go" + case "rust-analyzer": + // For Rust, we want to open Cargo.toml files immediately + return fileName == "Cargo.toml" || + fileName == "Cargo.lock" || + // Also open lib.rs and main.rs + fileName == "lib.rs" || + fileName == "main.rs" + case "python", "pyright", "pylsp": + // For Python, open key project files + return fileName == "pyproject.toml" || + fileName == "setup.py" || + fileName == "requirements.txt" || + fileName == "__init__.py" || + fileName == "__main__.py" + case "clangd": + // For C/C++, open key project files + return fileName == "CMakeLists.txt" || + fileName == "Makefile" || + fileName == "compile_commands.json" + case "java", "jdtls": + // For Java, open key project files + return fileName == "pom.xml" || + fileName == "build.gradle" || + ext == ".java" // Java servers often need to see source files + } + + // For unknown servers, prioritize common configuration files + return fileName == "package.json" || + fileName == "Makefile" || + fileName == "CMakeLists.txt" || + fileName == ".editorconfig" +} diff --git a/internal/tui/components/chat/chat.go b/internal/tui/components/chat/chat.go index 52ff4c8bf3e28f500c9890487f9d7ba9e48eb62c..b2b5a5c4aa10acd9c870aefca71bf0978970c66c 100644 --- a/internal/tui/components/chat/chat.go +++ b/internal/tui/components/chat/chat.go @@ -2,6 +2,7 @@ package chat import ( "fmt" + "sort" "github.com/charmbracelet/lipgloss" "github.com/charmbracelet/x/ansi" @@ -28,8 +29,16 @@ func lspsConfigured(width int) string { lsps := styles.BaseStyle.Width(width).Foreground(styles.PrimaryColor).Bold(true).Render(title) + // Get LSP names and sort them for consistent ordering + var lspNames []string + for name := range cfg.LSP { + lspNames = append(lspNames, name) + } + sort.Strings(lspNames) + var lspViews []string - for name, lsp := range cfg.LSP { + for _, name := range lspNames { + lsp := cfg.LSP[name] lspName := styles.BaseStyle.Foreground(styles.Forground).Render( fmt.Sprintf("• %s", name), ) @@ -49,7 +58,6 @@ func lspsConfigured(width int) string { ), ), ) - } return styles.BaseStyle. Width(width). diff --git a/internal/tui/components/chat/list.go b/internal/tui/components/chat/list.go index b7703e2cc28c2a949fc82e40327cf35de69ed0cd..994ddea036dbf9d0ca50840730cfc27e8cbc0b82 100644 --- a/internal/tui/components/chat/list.go +++ b/internal/tui/components/chat/list.go @@ -376,14 +376,7 @@ func (m *messagesCmp) working() string { if hasToolsWithoutResponse(m.messages) { task = "Waiting for tool response..." } else if !lastMessage.IsFinished() { - lastUpdate := lastMessage.UpdatedAt - currentTime := time.Now().Unix() - if lastMessage.Content().String() != "" && lastUpdate != 0 && currentTime-lastUpdate > 5 { - task = "Building tool call..." - } else if lastMessage.Content().String() == "" { - task = "Generating..." - } - task = "" + task = "Generating..." } if task != "" { text += styles.BaseStyle.Width(m.width).Foreground(styles.PrimaryColor).Bold(true).Render( diff --git a/internal/tui/components/chat/message.go b/internal/tui/components/chat/message.go index 7a840b4ec56e9d5225f35abdb7dad758091427ff..14b9e268e2de3aefece34898f5fbaeba319774b2 100644 --- a/internal/tui/components/chat/message.go +++ b/internal/tui/components/chat/message.go @@ -151,7 +151,11 @@ func renderAssistantMessage( )) } } - if content != "" { + if content != "" || (finished && finishData.Reason == message.FinishReasonEndTurn) { + if content == "" { + content = "*Finished without output*" + } + content = renderMessage(content, false, msg.ID == focusedUIMessageId, width, info...) messages = append(messages, uiMessage{ ID: msg.ID, diff --git a/internal/tui/components/core/status.go b/internal/tui/components/core/status.go index 01c5358697906573a0bb52c4f8834fa7ee009df7..5a2114e8363dbb424db41be18908bb50570a5c40 100644 --- a/internal/tui/components/core/status.go +++ b/internal/tui/components/core/status.go @@ -138,6 +138,23 @@ func (m statusCmp) View() string { } func (m *statusCmp) projectDiagnostics() string { + // Check if any LSP server is still initializing + initializing := false + for _, client := range m.lspClients { + if client.GetServerState() == lsp.StateStarting { + initializing = true + break + } + } + + // If any server is initializing, show that status + if initializing { + return lipgloss.NewStyle(). + Background(styles.BackgroundDarker). + Foreground(styles.Peach). + Render(fmt.Sprintf("%s Initializing LSP...", styles.SpinnerIcon)) + } + errorDiagnostics := []protocol.Diagnostic{} warnDiagnostics := []protocol.Diagnostic{} hintDiagnostics := []protocol.Diagnostic{} diff --git a/internal/tui/styles/icons.go b/internal/tui/styles/icons.go index dd5f4dc51e1e00201bfa0d8dd52a1d728ffc0ede..96d1b8976a96fc8c45d849e8e3f418b9a0587439 100644 --- a/internal/tui/styles/icons.go +++ b/internal/tui/styles/icons.go @@ -6,7 +6,8 @@ const ( CheckIcon string = "✓" ErrorIcon string = "✖" WarningIcon string = "⚠" - InfoIcon string = "" + InfoIcon string = "" HintIcon string = "i" SpinnerIcon string = "..." -) + LoadingIcon string = "⟳" +) \ No newline at end of file From 2de51274177432b559be3b7deb1f14b9539f2994 Mon Sep 17 00:00:00 2001 From: Kujtim Hoxha Date: Sat, 19 Apr 2025 16:35:45 +0200 Subject: [PATCH 30/41] initial tool call stream --- internal/llm/agent/agent.go | 22 +++++ internal/llm/provider/anthropic.go | 60 +++++++++--- internal/llm/provider/openai.go | 9 +- internal/llm/provider/provider.go | 7 +- internal/message/content.go | 43 +++++++++ internal/message/message.go | 2 + internal/pubsub/broker.go | 7 -- internal/tui/components/chat/list.go | 117 ++++++------------------ internal/tui/components/chat/message.go | 92 +++++++++++++++---- internal/tui/layout/split.go | 28 ++++++ internal/tui/page/chat.go | 10 +- 11 files changed, 261 insertions(+), 136 deletions(-) diff --git a/internal/llm/agent/agent.go b/internal/llm/agent/agent.go index 7542d9adf436642aee9b5613e1cfc8e515b8aee7..ae5bcb23178810a31dcc2c4b63e5cd8486390f85 100644 --- a/internal/llm/agent/agent.go +++ b/internal/llm/agent/agent.go @@ -380,6 +380,21 @@ func (a *agent) processEvent(ctx context.Context, sessionID string, assistantMsg case provider.EventContentDelta: assistantMsg.AppendContent(event.Content) return a.messages.Update(ctx, *assistantMsg) + case provider.EventToolUseStart: + assistantMsg.AddToolCall(*event.ToolCall) + return a.messages.Update(ctx, *assistantMsg) + // TODO: see how to handle this + // case provider.EventToolUseDelta: + // tm := time.Unix(assistantMsg.UpdatedAt, 0) + // assistantMsg.AppendToolCallInput(event.ToolCall.ID, event.ToolCall.Input) + // if time.Since(tm) > 1000*time.Millisecond { + // err := a.messages.Update(ctx, *assistantMsg) + // assistantMsg.UpdatedAt = time.Now().Unix() + // return err + // } + case provider.EventToolUseStop: + assistantMsg.FinishToolCall(event.ToolCall.ID) + return a.messages.Update(ctx, *assistantMsg) case provider.EventError: if errors.Is(event.Error, context.Canceled) { logging.InfoPersist(fmt.Sprintf("Event processing canceled for session: %s", sessionID)) @@ -456,6 +471,13 @@ func createAgentProvider(agentName config.AgentName) (provider.Provider, error) provider.WithReasoningEffort(agentConfig.ReasoningEffort), ), ) + } else if model.Provider == models.ProviderAnthropic && model.CanReason { + opts = append( + opts, + provider.WithAnthropicOptions( + provider.WithAnthropicShouldThinkFn(provider.DefaultShouldThinkFn), + ), + ) } agentProvider, err := provider.NewProvider( model.Provider, diff --git a/internal/llm/provider/anthropic.go b/internal/llm/provider/anthropic.go index 7bbc02103df7d81b121d19244b4e6da8ce0fd600..2c16a059357abef9529c04f4d4a9f40e935bffd4 100644 --- a/internal/llm/provider/anthropic.go +++ b/internal/llm/provider/anthropic.go @@ -93,8 +93,7 @@ func (a *anthropicClient) convertMessages(messages []message.Message) (anthropic } if len(blocks) == 0 { - logging.Warn("There is a message without content, investigate") - // This should never happend but we log this because we might have a bug in our cleanup method + logging.Warn("There is a message without content, investigate, this should not happen") continue } anthropicMessages = append(anthropicMessages, anthropic.NewAssistantMessage(blocks...)) @@ -196,8 +195,8 @@ func (a *anthropicClient) send(ctx context.Context, messages []message.Message, preparedMessages := a.preparedMessages(a.convertMessages(messages), a.convertTools(tools)) cfg := config.Get() if cfg.Debug { - jsonData, _ := json.Marshal(preparedMessages) - logging.Debug("Prepared messages", "messages", string(jsonData)) + // jsonData, _ := json.Marshal(preparedMessages) + // logging.Debug("Prepared messages", "messages", string(jsonData)) } attempts := 0 for { @@ -243,8 +242,8 @@ func (a *anthropicClient) stream(ctx context.Context, messages []message.Message preparedMessages := a.preparedMessages(a.convertMessages(messages), a.convertTools(tools)) cfg := config.Get() if cfg.Debug { - jsonData, _ := json.Marshal(preparedMessages) - logging.Debug("Prepared messages", "messages", string(jsonData)) + // jsonData, _ := json.Marshal(preparedMessages) + // logging.Debug("Prepared messages", "messages", string(jsonData)) } attempts := 0 eventChan := make(chan ProviderEvent) @@ -257,6 +256,7 @@ func (a *anthropicClient) stream(ctx context.Context, messages []message.Message ) accumulatedMessage := anthropic.Message{} + currentToolCallID := "" for anthropicStream.Next() { event := anthropicStream.Current() err := accumulatedMessage.Accumulate(event) @@ -267,7 +267,19 @@ func (a *anthropicClient) stream(ctx context.Context, messages []message.Message switch event := event.AsAny().(type) { case anthropic.ContentBlockStartEvent: - eventChan <- ProviderEvent{Type: EventContentStart} + if event.ContentBlock.Type == "text" { + eventChan <- ProviderEvent{Type: EventContentStart} + } else if event.ContentBlock.Type == "tool_use" { + currentToolCallID = event.ContentBlock.ID + eventChan <- ProviderEvent{ + Type: EventToolUseStart, + ToolCall: &message.ToolCall{ + ID: event.ContentBlock.ID, + Name: event.ContentBlock.Name, + Finished: false, + }, + } + } case anthropic.ContentBlockDeltaEvent: if event.Delta.Type == "thinking_delta" && event.Delta.Thinking != "" { @@ -280,11 +292,30 @@ func (a *anthropicClient) stream(ctx context.Context, messages []message.Message Type: EventContentDelta, Content: event.Delta.Text, } + } else if event.Delta.Type == "input_json_delta" { + if currentToolCallID != "" { + eventChan <- ProviderEvent{ + Type: EventToolUseDelta, + ToolCall: &message.ToolCall{ + ID: currentToolCallID, + Finished: false, + Input: event.Delta.JSON.PartialJSON.Raw(), + }, + } + } } - // TODO: check if we can somehow stream tool calls - case anthropic.ContentBlockStopEvent: - eventChan <- ProviderEvent{Type: EventContentStop} + if currentToolCallID != "" { + eventChan <- ProviderEvent{ + Type: EventToolUseStop, + ToolCall: &message.ToolCall{ + ID: currentToolCallID, + }, + } + currentToolCallID = "" + } else { + eventChan <- ProviderEvent{Type: EventContentStop} + } case anthropic.MessageStopEvent: content := "" @@ -378,10 +409,11 @@ func (a *anthropicClient) toolCalls(msg anthropic.Message) []message.ToolCall { switch variant := block.AsAny().(type) { case anthropic.ToolUseBlock: toolCall := message.ToolCall{ - ID: variant.ID, - Name: variant.Name, - Input: string(variant.Input), - Type: string(variant.Type), + ID: variant.ID, + Name: variant.Name, + Input: string(variant.Input), + Type: string(variant.Type), + Finished: true, } toolCalls = append(toolCalls, toolCall) } diff --git a/internal/llm/provider/openai.go b/internal/llm/provider/openai.go index 6c6f74988168c8a01e5cd4d8fd95b54b8c8618b6..40d2632420747c79e3cad96ad9a4126a7c0486fd 100644 --- a/internal/llm/provider/openai.go +++ b/internal/llm/provider/openai.go @@ -344,10 +344,11 @@ func (o *openaiClient) toolCalls(completion openai.ChatCompletion) []message.Too if len(completion.Choices) > 0 && len(completion.Choices[0].Message.ToolCalls) > 0 { for _, call := range completion.Choices[0].Message.ToolCalls { toolCall := message.ToolCall{ - ID: call.ID, - Name: call.Function.Name, - Input: call.Function.Arguments, - Type: "function", + ID: call.ID, + Name: call.Function.Name, + Input: call.Function.Arguments, + Type: "function", + Finished: true, } toolCalls = append(toolCalls, toolCall) } diff --git a/internal/llm/provider/provider.go b/internal/llm/provider/provider.go index e04bee71bce87b642b6b1b227adc22d72dfcdb26..283a0d983003349cf00bf4446e4a56954e651e58 100644 --- a/internal/llm/provider/provider.go +++ b/internal/llm/provider/provider.go @@ -15,6 +15,9 @@ const maxRetries = 8 const ( EventContentStart EventType = "content_start" + EventToolUseStart EventType = "tool_use_start" + EventToolUseDelta EventType = "tool_use_delta" + EventToolUseStop EventType = "tool_use_stop" EventContentDelta EventType = "content_delta" EventThinkingDelta EventType = "thinking_delta" EventContentStop EventType = "content_stop" @@ -43,8 +46,8 @@ type ProviderEvent struct { Content string Thinking string Response *ProviderResponse - - Error error + ToolCall *message.ToolCall + Error error } type Provider interface { SendMessages(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error) diff --git a/internal/message/content.go b/internal/message/content.go index f52449f4a394cdc4fdebcb8921f3f7cf9601d068..beebe354e6eabe1e27266384a6534fa40f3e7167 100644 --- a/internal/message/content.go +++ b/internal/message/content.go @@ -233,6 +233,40 @@ func (m *Message) AppendReasoningContent(delta string) { } } +func (m *Message) FinishToolCall(toolCallID string) { + for i, part := range m.Parts { + if c, ok := part.(ToolCall); ok { + if c.ID == toolCallID { + m.Parts[i] = ToolCall{ + ID: c.ID, + Name: c.Name, + Input: c.Input, + Type: c.Type, + Finished: true, + } + return + } + } + } +} + +func (m *Message) AppendToolCallInput(toolCallID string, inputDelta string) { + for i, part := range m.Parts { + if c, ok := part.(ToolCall); ok { + if c.ID == toolCallID { + m.Parts[i] = ToolCall{ + ID: c.ID, + Name: c.Name, + Input: c.Input + inputDelta, + Type: c.Type, + Finished: c.Finished, + } + return + } + } + } +} + func (m *Message) AddToolCall(tc ToolCall) { for i, part := range m.Parts { if c, ok := part.(ToolCall); ok { @@ -246,6 +280,15 @@ func (m *Message) AddToolCall(tc ToolCall) { } func (m *Message) SetToolCalls(tc []ToolCall) { + // remove any existing tool call part it could have multiple + parts := make([]ContentPart, 0) + for _, part := range m.Parts { + if _, ok := part.(ToolCall); ok { + continue + } + parts = append(parts, part) + } + m.Parts = parts for _, toolCall := range tc { m.Parts = append(m.Parts, toolCall) } diff --git a/internal/message/message.go b/internal/message/message.go index f165fcfc75a3f56f0fc454cb1aabcb5bb75d8c6f..20ace7b4166bd6e5c2f8ba7b3c04c8b20056f48a 100644 --- a/internal/message/message.go +++ b/internal/message/message.go @@ -5,6 +5,7 @@ import ( "database/sql" "encoding/json" "fmt" + "time" "github.com/google/uuid" "github.com/kujtimiihoxha/opencode/internal/db" @@ -116,6 +117,7 @@ func (s *service) Update(ctx context.Context, message Message) error { if err != nil { return err } + message.UpdatedAt = time.Now().Unix() s.Publish(pubsub.UpdatedEvent, message) return nil } diff --git a/internal/pubsub/broker.go b/internal/pubsub/broker.go index 3e70ae09525a60e143d710614cf3d68d84f75d03..d73accffba6ed7695cd845143ef7a611a8058682 100644 --- a/internal/pubsub/broker.go +++ b/internal/pubsub/broker.go @@ -7,13 +7,6 @@ import ( const bufferSize = 1024 -type Logger interface { - Debug(msg string, args ...any) - Info(msg string, args ...any) - Warn(msg string, args ...any) - Error(msg string, args ...any) -} - // Broker allows clients to publish events and subscribe to events type Broker[T any] struct { subs map[chan Event[T]]struct{} // subscriptions diff --git a/internal/tui/components/chat/list.go b/internal/tui/components/chat/list.go index 994ddea036dbf9d0ca50840730cfc27e8cbc0b82..b09cc449507881b6f5030fe4f1eaabc286784219 100644 --- a/internal/tui/components/chat/list.go +++ b/internal/tui/components/chat/list.go @@ -4,8 +4,6 @@ import ( "context" "fmt" "math" - "sync" - "time" "github.com/charmbracelet/bubbles/key" "github.com/charmbracelet/bubbles/spinner" @@ -13,7 +11,6 @@ import ( tea "github.com/charmbracelet/bubbletea" "github.com/charmbracelet/lipgloss" "github.com/kujtimiihoxha/opencode/internal/app" - "github.com/kujtimiihoxha/opencode/internal/logging" "github.com/kujtimiihoxha/opencode/internal/message" "github.com/kujtimiihoxha/opencode/internal/pubsub" "github.com/kujtimiihoxha/opencode/internal/session" @@ -35,89 +32,14 @@ type messagesCmp struct { messages []message.Message uiMessages []uiMessage currentMsgID string - mutex sync.Mutex cachedContent map[string]cacheItem spinner spinner.Model - lastUpdate time.Time rendering bool } type renderFinishedMsg struct{} func (m *messagesCmp) Init() tea.Cmd { - return tea.Batch(m.viewport.Init()) -} - -func (m *messagesCmp) preloadSessions() tea.Cmd { - return func() tea.Msg { - m.mutex.Lock() - defer m.mutex.Unlock() - sessions, err := m.app.Sessions.List(context.Background()) - if err != nil { - return util.ReportError(err)() - } - if len(sessions) == 0 { - return nil - } - if len(sessions) > 20 { - sessions = sessions[:20] - } - for _, s := range sessions { - messages, err := m.app.Messages.List(context.Background(), s.ID) - if err != nil { - return util.ReportError(err)() - } - if len(messages) == 0 { - continue - } - m.cacheSessionMessages(messages, m.width) - - } - logging.Debug("preloaded sessions") - - return func() tea.Msg { - return renderFinishedMsg{} - } - } -} - -func (m *messagesCmp) cacheSessionMessages(messages []message.Message, width int) { - pos := 0 - if m.width == 0 { - return - } - for inx, msg := range messages { - switch msg.Role { - case message.User: - userMsg := renderUserMessage( - msg, - false, - width, - pos, - ) - m.cachedContent[msg.ID] = cacheItem{ - width: width, - content: []uiMessage{userMsg}, - } - pos += userMsg.height + 1 // + 1 for spacing - case message.Assistant: - assistantMessages := renderAssistantMessage( - msg, - inx, - messages, - m.app.Messages, - "", - width, - pos, - ) - for _, msg := range assistantMessages { - pos += msg.height + 1 // + 1 for spacing - } - m.cachedContent[msg.ID] = cacheItem{ - width: width, - content: assistantMessages, - } - } - } + return tea.Batch(m.viewport.Init(), m.spinner.Tick) } func (m *messagesCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) { @@ -360,21 +282,35 @@ func hasToolsWithoutResponse(messages []message.Message) bool { break } } - if !found { + if !found && v.Finished { return true } } + return false +} +func hasUnfinishedToolCalls(messages []message.Message) bool { + toolCalls := make([]message.ToolCall, 0) + for _, m := range messages { + toolCalls = append(toolCalls, m.ToolCalls()...) + } + for _, v := range toolCalls { + if !v.Finished { + return true + } + } return false } func (m *messagesCmp) working() string { text := "" - if m.IsAgentWorking() { + if m.IsAgentWorking() && len(m.messages) > 0 { task := "Thinking..." lastMessage := m.messages[len(m.messages)-1] if hasToolsWithoutResponse(m.messages) { task = "Waiting for tool response..." + } else if hasUnfinishedToolCalls(m.messages) { + task = "Building tool call..." } else if !lastMessage.IsFinished() { task = "Generating..." } @@ -434,8 +370,7 @@ func (m *messagesCmp) SetSize(width, height int) tea.Cmd { delete(m.cachedContent, msg.ID) } m.uiMessages = make([]uiMessage, 0) - m.renderView() - return m.preloadSessions() + return nil } func (m *messagesCmp) GetSize() (int, int) { @@ -446,16 +381,16 @@ func (m *messagesCmp) SetSession(session session.Session) tea.Cmd { if m.session.ID == session.ID { return nil } + m.session = session + messages, err := m.app.Messages.List(context.Background(), session.ID) + if err != nil { + return util.ReportError(err) + } + m.messages = messages + m.currentMsgID = m.messages[len(m.messages)-1].ID + delete(m.cachedContent, m.currentMsgID) m.rendering = true return func() tea.Msg { - m.session = session - messages, err := m.app.Messages.List(context.Background(), session.ID) - if err != nil { - return util.ReportError(err) - } - m.messages = messages - m.currentMsgID = m.messages[len(m.messages)-1].ID - delete(m.cachedContent, m.currentMsgID) m.renderView() return renderFinishedMsg{} } diff --git a/internal/tui/components/chat/message.go b/internal/tui/components/chat/message.go index 14b9e268e2de3aefece34898f5fbaeba319774b2..b8e4500797851cdfe166e1f983b4e31d64c1b6aa 100644 --- a/internal/tui/components/chat/message.go +++ b/internal/tui/components/chat/message.go @@ -113,18 +113,10 @@ func renderAssistantMessage( width int, position int, ) []uiMessage { - // find the user message that is before this assistant message - var userMsg message.Message - for i := msgIndex - 1; i >= 0; i-- { - msg := allMessages[i] - if msg.Role == message.User { - userMsg = allMessages[i] - break - } - } - messages := []uiMessage{} content := msg.Content().String() + thinking := msg.IsThinking() + thinkingContent := msg.ReasoningContent().Thinking finished := msg.IsFinished() finishData := msg.FinishPart() info := []string{} @@ -133,7 +125,7 @@ func renderAssistantMessage( if finished { switch finishData.Reason { case message.FinishReasonEndTurn: - took := formatTimeDifference(userMsg.CreatedAt, finishData.Time) + took := formatTimeDifference(msg.CreatedAt, finishData.Time) info = append(info, styles.BaseStyle.Width(width-1).Foreground(styles.ForgroundDim).Render( fmt.Sprintf(" %s (%s)", models.SupportedModels[msg.Model].Name, took), )) @@ -166,6 +158,9 @@ func renderAssistantMessage( }) position += messages[0].height position++ // for the space + } else if thinking && thinkingContent != "" { + // Render the thinking content + content = renderMessage(thinkingContent, false, msg.ID == focusedUIMessageId, width) } for i, toolCall := range msg.ToolCalls() { @@ -218,10 +213,40 @@ func toolName(name string) string { return "View" case tools.WriteToolName: return "Write" + case tools.PatchToolName: + return "Patch" } return name } +func getToolAction(name string) string { + switch name { + case agent.AgentToolName: + return "Preparing prompt..." + case tools.BashToolName: + return "Building command..." + case tools.EditToolName: + return "Preparing edit..." + case tools.FetchToolName: + return "Writing fetch..." + case tools.GlobToolName: + return "Finding files..." + case tools.GrepToolName: + return "Searching content..." + case tools.LSToolName: + return "Listing directory..." + case tools.SourcegraphToolName: + return "Searching code..." + case tools.ViewToolName: + return "Reading file..." + case tools.WriteToolName: + return "Preparing write..." + case tools.PatchToolName: + return "Preparing patch..." + } + return "Working..." +} + // renders params, params[0] (params[1]=params[2] ....) func renderParams(paramsWidth int, params ...string) string { if len(params) == 0 { @@ -490,8 +515,47 @@ func renderToolMessage( if nested { width = width - 3 } + style := styles.BaseStyle. + Width(width - 1). + BorderLeft(true). + BorderStyle(lipgloss.ThickBorder()). + PaddingLeft(1). + BorderForeground(styles.ForgroundDim) + response := findToolResponse(toolCall.ID, allMessages) toolName := styles.BaseStyle.Foreground(styles.ForgroundDim).Render(fmt.Sprintf("%s: ", toolName(toolCall.Name))) + + if !toolCall.Finished { + // Get a brief description of what the tool is doing + toolAction := getToolAction(toolCall.Name) + + // toolInput := strings.ReplaceAll(toolCall.Input, "\n", " ") + // truncatedInput := toolInput + // if len(truncatedInput) > 10 { + // truncatedInput = truncatedInput[len(truncatedInput)-10:] + // } + // + // truncatedInput = styles.BaseStyle. + // Italic(true). + // Width(width - 2 - lipgloss.Width(toolName)). + // Background(styles.BackgroundDim). + // Foreground(styles.ForgroundMid). + // Render(truncatedInput) + + progressText := styles.BaseStyle. + Width(width - 2 - lipgloss.Width(toolName)). + Foreground(styles.ForgroundDim). + Render(fmt.Sprintf("%s", toolAction)) + + content := style.Render(lipgloss.JoinHorizontal(lipgloss.Left, toolName, progressText)) + toolMsg := uiMessage{ + messageType: toolMessageType, + position: position, + height: lipgloss.Height(content), + content: content, + } + return toolMsg + } params := renderToolParams(width-2-lipgloss.Width(toolName), toolCall) responseContent := "" if response != nil { @@ -504,12 +568,6 @@ func renderToolMessage( Foreground(styles.ForgroundDim). Render("Waiting for response...") } - style := styles.BaseStyle. - Width(width - 1). - BorderLeft(true). - BorderStyle(lipgloss.ThickBorder()). - PaddingLeft(1). - BorderForeground(styles.ForgroundDim) parts := []string{} if !nested { diff --git a/internal/tui/layout/split.go b/internal/tui/layout/split.go index a41df6ab82199468e8627d4f96919630455b6b5c..f3ab9247df9b04ac64a4b19d030b2302cc55b6b0 100644 --- a/internal/tui/layout/split.go +++ b/internal/tui/layout/split.go @@ -14,6 +14,10 @@ type SplitPaneLayout interface { SetLeftPanel(panel Container) tea.Cmd SetRightPanel(panel Container) tea.Cmd SetBottomPanel(panel Container) tea.Cmd + + ClearLeftPanel() tea.Cmd + ClearRightPanel() tea.Cmd + ClearBottomPanel() tea.Cmd } type splitPaneLayout struct { @@ -192,6 +196,30 @@ func (s *splitPaneLayout) SetBottomPanel(panel Container) tea.Cmd { return nil } +func (s *splitPaneLayout) ClearLeftPanel() tea.Cmd { + s.leftPanel = nil + if s.width > 0 && s.height > 0 { + return s.SetSize(s.width, s.height) + } + return nil +} + +func (s *splitPaneLayout) ClearRightPanel() tea.Cmd { + s.rightPanel = nil + if s.width > 0 && s.height > 0 { + return s.SetSize(s.width, s.height) + } + return nil +} + +func (s *splitPaneLayout) ClearBottomPanel() tea.Cmd { + s.bottomPanel = nil + if s.width > 0 && s.height > 0 { + return s.SetSize(s.width, s.height) + } + return nil +} + func (s *splitPaneLayout) BindingKeys() []key.Binding { keys := []key.Binding{} if s.leftPanel != nil { diff --git a/internal/tui/page/chat.go b/internal/tui/page/chat.go index ef826e9a335e4df8592523212e345f435f0c1e37..a5a656a222697a5df6ffe5737c758d40e0dbb195 100644 --- a/internal/tui/page/chat.go +++ b/internal/tui/page/chat.go @@ -57,6 +57,14 @@ func (p *chatPage) Update(msg tea.Msg) (tea.Model, tea.Cmd) { if cmd != nil { return p, cmd } + case chat.SessionSelectedMsg: + if p.session.ID == "" { + cmd := p.setSidebar() + if cmd != nil { + cmds = append(cmds, cmd) + } + } + p.session = msg case chat.EditorFocusMsg: p.editingMode = bool(msg) case tea.KeyMsg: @@ -91,7 +99,7 @@ func (p *chatPage) setSidebar() tea.Cmd { } func (p *chatPage) clearSidebar() tea.Cmd { - return p.layout.SetRightPanel(nil) + return p.layout.ClearRightPanel() } func (p *chatPage) sendMessage(text string) tea.Cmd { From 8e160488ff1aa29f6b2cb601145e9f3ff5410d07 Mon Sep 17 00:00:00 2001 From: Kujtim Hoxha Date: Sun, 20 Apr 2025 15:56:36 +0200 Subject: [PATCH 31/41] improve cache --- internal/llm/provider/anthropic.go | 13 +++++++------ internal/tui/components/dialog/permission.go | 7 ++++++- 2 files changed, 13 insertions(+), 7 deletions(-) diff --git a/internal/llm/provider/anthropic.go b/internal/llm/provider/anthropic.go index 2c16a059357abef9529c04f4d4a9f40e935bffd4..86f483f649a61dca2887ee7d656a73c0769647af 100644 --- a/internal/llm/provider/anthropic.go +++ b/internal/llm/provider/anthropic.go @@ -57,16 +57,18 @@ func newAnthropicClient(opts providerClientOptions) AnthropicClient { } func (a *anthropicClient) convertMessages(messages []message.Message) (anthropicMessages []anthropic.MessageParam) { - cachedBlocks := 0 - for _, msg := range messages { + for i, msg := range messages { + cache := false + if len(messages)-3 > i { + cache = true + } switch msg.Role { case message.User: content := anthropic.NewTextBlock(msg.Content().String()) - if cachedBlocks < 2 && !a.options.disableCache { + if cache && !a.options.disableCache { content.OfRequestTextBlock.CacheControl = anthropic.CacheControlEphemeralParam{ Type: "ephemeral", } - cachedBlocks++ } anthropicMessages = append(anthropicMessages, anthropic.NewUserMessage(content)) @@ -74,11 +76,10 @@ func (a *anthropicClient) convertMessages(messages []message.Message) (anthropic blocks := []anthropic.ContentBlockParamUnion{} if msg.Content().String() != "" { content := anthropic.NewTextBlock(msg.Content().String()) - if cachedBlocks < 2 && !a.options.disableCache { + if cache && !a.options.disableCache { content.OfRequestTextBlock.CacheControl = anthropic.CacheControlEphemeralParam{ Type: "ephemeral", } - cachedBlocks++ } blocks = append(blocks, content) } diff --git a/internal/tui/components/dialog/permission.go b/internal/tui/components/dialog/permission.go index 1f8df21a0b2cdfdd319ae8fd43c05178e1ad28f8..16b63815cf98e23cbc9709e02012ec009b3c57c3 100644 --- a/internal/tui/components/dialog/permission.go +++ b/internal/tui/components/dialog/permission.go @@ -375,6 +375,9 @@ func (p *permissionDialogCmp) render() string { contentFinal = p.renderDefaultContent() } + // Add help text + helpText := styles.BaseStyle.Width(p.width - 4).Padding(0, 1).Foreground(styles.ForgroundDim).Render("←/→/tab: switch options a: allow A: allow for session d: deny enter/space: confirm") + content := lipgloss.JoinVertical( lipgloss.Top, title, @@ -382,6 +385,8 @@ func (p *permissionDialogCmp) render() string { headerContent, contentFinal, buttons, + styles.BaseStyle.Render(strings.Repeat(" ", p.width - 4)), + helpText, ) return styles.BaseStyle. @@ -401,7 +406,7 @@ func (p *permissionDialogCmp) View() string { } func (p *permissionDialogCmp) BindingKeys() []key.Binding { - return layout.KeyMapToSlice(helpKeys) + return layout.KeyMapToSlice(permissionsKeys) } func (p *permissionDialogCmp) SetSize() tea.Cmd { From c40e68496d4ed5a7db47879376dec293a9e82856 Mon Sep 17 00:00:00 2001 From: Kujtim Hoxha Date: Sun, 20 Apr 2025 16:52:36 +0200 Subject: [PATCH 32/41] add context to the prompt --- internal/llm/prompt/prompt.go | 46 ++++++++++++++++++++++++++++++++--- 1 file changed, 42 insertions(+), 4 deletions(-) diff --git a/internal/llm/prompt/prompt.go b/internal/llm/prompt/prompt.go index cdc3560cefe9d0d464fe1786e12e92165503868f..cf4d9a7e723a285d3529b230ad184893215cdd96 100644 --- a/internal/llm/prompt/prompt.go +++ b/internal/llm/prompt/prompt.go @@ -1,19 +1,57 @@ package prompt import ( + "fmt" + "os" + "path/filepath" + "github.com/kujtimiihoxha/opencode/internal/config" "github.com/kujtimiihoxha/opencode/internal/llm/models" ) +// contextFiles is a list of potential context files to check for +var contextFiles = []string{ + ".github/copilot-instructions.md", + ".cursorrules", + "CLAUDE.md", + "opencode.md", + "OpenCode.md", +} + func GetAgentPrompt(agentName config.AgentName, provider models.ModelProvider) string { + basePrompt := "" switch agentName { case config.AgentCoder: - return CoderPrompt(provider) + basePrompt = CoderPrompt(provider) case config.AgentTitle: - return TitlePrompt(provider) + basePrompt = TitlePrompt(provider) case config.AgentTask: - return TaskPrompt(provider) + basePrompt = TaskPrompt(provider) default: - return "You are a helpful assistant" + basePrompt = "You are a helpful assistant" } + + // Add context from project-specific instruction files if they exist + contextContent := getContextFromFiles() + if contextContent != "" { + return fmt.Sprintf("%s\n\n# Project-Specific Context\n%s", basePrompt, contextContent) + } + + return basePrompt +} + +// getContextFromFiles checks for the existence of context files and returns their content +func getContextFromFiles() string { + workDir := config.WorkingDirectory() + var contextContent string + + for _, file := range contextFiles { + filePath := filepath.Join(workDir, file) + content, err := os.ReadFile(filePath) + if err == nil { + contextContent += fmt.Sprintf("\n%s\n", string(content)) + } + } + + return contextContent } From 1da298e7554bab0f7a631a44fed12692d668c024 Mon Sep 17 00:00:00 2001 From: Kujtim Hoxha Date: Mon, 21 Apr 2025 07:20:09 +0200 Subject: [PATCH 33/41] fix anthropic --- internal/llm/provider/anthropic.go | 2 +- internal/llm/tools/patch.go | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/internal/llm/provider/anthropic.go b/internal/llm/provider/anthropic.go index 86f483f649a61dca2887ee7d656a73c0769647af..03d96fb249785f721a1c6b42c304ebf17e9d61c4 100644 --- a/internal/llm/provider/anthropic.go +++ b/internal/llm/provider/anthropic.go @@ -59,7 +59,7 @@ func newAnthropicClient(opts providerClientOptions) AnthropicClient { func (a *anthropicClient) convertMessages(messages []message.Message) (anthropicMessages []anthropic.MessageParam) { for i, msg := range messages { cache := false - if len(messages)-3 > i { + if i > len(messages)-3 { cache = true } switch msg.Role { diff --git a/internal/llm/tools/patch.go b/internal/llm/tools/patch.go index 0f879462caf833fd121600153f8b70ff2c85c898..92eab69297a2513530a84a99cf411dc50cc4e359 100644 --- a/internal/llm/tools/patch.go +++ b/internal/llm/tools/patch.go @@ -169,7 +169,7 @@ func (p *patchTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error return NewTextErrorResponse(fmt.Sprintf("failed to parse patch: %s", err)), nil } - if fuzz > 0 { + if fuzz > 3 { return NewTextErrorResponse(fmt.Sprintf("patch contains fuzzy matches (fuzz level: %d). Please make your context lines more precise", fuzz)), nil } From e7bb99baab5e6968ce0351d6ad219ed21ceec4df Mon Sep 17 00:00:00 2001 From: Kujtim Hoxha Date: Mon, 21 Apr 2025 13:33:51 +0200 Subject: [PATCH 34/41] fix the memory bug --- README.md | 7 ++- cmd/root.go | 24 ++++----- internal/pubsub/broker.go | 72 +++++++++++++++++--------- internal/tui/components/chat/list.go | 1 + internal/tui/components/core/status.go | 15 ++++-- internal/tui/tui.go | 63 +++++++++++++++++----- 6 files changed, 127 insertions(+), 55 deletions(-) diff --git a/README.md b/README.md index ef55b69294670f2e6b4fc410204224f3a785dee2..075114fc335e9cc9ccfdf035478bce73a125f590 100644 --- a/README.md +++ b/README.md @@ -351,9 +351,12 @@ go build -o opencode ## Acknowledgments -OpenCode builds upon the work of several open source projects and developers: +OpenCode gratefully acknowledges the contributions and support from these key individuals: -- [@isaacphi](https://github.com/isaacphi) - LSP client implementation +- [@isaacphi](https://github.com/isaacphi) - For the [mcp-language-server](https://github.com/isaacphi/mcp-language-server) project which provided the foundation for our LSP client implementation +- [@adamdottv](https://github.com/adamdottv) - For the design direction and UI/UX architecture + +Special thanks to the broader open source community whose tools and libraries have made this project possible. ## License diff --git a/cmd/root.go b/cmd/root.go index f506e99404f2bdc2d0331f592ffe2ab69b560ef5..54280ecaaab7a8dbab3cd5555c53a526ae177134 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -79,7 +79,7 @@ var rootCmd = &cobra.Command{ initMCPTools(ctx, app) // Setup the subscriptions, this will send services events to the TUI - ch, cancelSubs := setupSubscriptions(app) + ch, cancelSubs := setupSubscriptions(app, ctx) // Create a context for the TUI message handler tuiCtx, tuiCancel := context.WithCancel(ctx) @@ -174,21 +174,21 @@ func setupSubscriber[T any]( defer wg.Done() defer logging.RecoverPanic(fmt.Sprintf("subscription-%s", name), nil) + subCh := subscriber(ctx) + for { select { - case event, ok := <-subscriber(ctx): + case event, ok := <-subCh: if !ok { logging.Info("%s subscription channel closed", name) return } - // Convert generic event to tea.Msg if needed var msg tea.Msg = event - // Non-blocking send with timeout to prevent deadlocks select { case outputCh <- msg: - case <-time.After(500 * time.Millisecond): + case <-time.After(2 * time.Second): logging.Warn("%s message dropped due to slow consumer", name) case <-ctx.Done(): logging.Info("%s subscription cancelled", name) @@ -202,23 +202,21 @@ func setupSubscriber[T any]( }() } -func setupSubscriptions(app *app.App) (chan tea.Msg, func()) { +func setupSubscriptions(app *app.App, parentCtx context.Context) (chan tea.Msg, func()) { ch := make(chan tea.Msg, 100) - // Add a buffer to prevent blocking + wg := sync.WaitGroup{} - ctx, cancel := context.WithCancel(context.Background()) - // Setup each subscription using the helper + ctx, cancel := context.WithCancel(parentCtx) // Inherit from parent context + setupSubscriber(ctx, &wg, "logging", logging.Subscribe, ch) setupSubscriber(ctx, &wg, "sessions", app.Sessions.Subscribe, ch) setupSubscriber(ctx, &wg, "messages", app.Messages.Subscribe, ch) setupSubscriber(ctx, &wg, "permissions", app.Permissions.Subscribe, ch) - // Return channel and a cleanup function cleanupFunc := func() { logging.Info("Cancelling all subscriptions") cancel() // Signal all goroutines to stop - // Wait with a timeout for all goroutines to complete waitCh := make(chan struct{}) go func() { defer logging.RecoverPanic("subscription-cleanup", nil) @@ -229,11 +227,11 @@ func setupSubscriptions(app *app.App) (chan tea.Msg, func()) { select { case <-waitCh: logging.Info("All subscription goroutines completed successfully") + close(ch) // Only close after all writers are confirmed done case <-time.After(5 * time.Second): logging.Warn("Timed out waiting for some subscription goroutines to complete") + close(ch) } - - close(ch) // Safe to close after all writers are done or timed out } return ch, cleanupFunc } diff --git a/internal/pubsub/broker.go b/internal/pubsub/broker.go index d73accffba6ed7695cd845143ef7a611a8058682..0de1be063b05e522c951ee9fe25c9358cf44ef52 100644 --- a/internal/pubsub/broker.go +++ b/internal/pubsub/broker.go @@ -5,47 +5,53 @@ import ( "sync" ) -const bufferSize = 1024 +const bufferSize = 64 -// Broker allows clients to publish events and subscribe to events type Broker[T any] struct { - subs map[chan Event[T]]struct{} // subscriptions - mu sync.Mutex // sync access to map - done chan struct{} // close when broker is shutting down + subs map[chan Event[T]]struct{} + mu sync.RWMutex + done chan struct{} + subCount int + maxEvents int } -// NewBroker constructs a pub/sub broker. func NewBroker[T any]() *Broker[T] { + return NewBrokerWithOptions[T](bufferSize, 1000) +} + +func NewBrokerWithOptions[T any](channelBufferSize, maxEvents int) *Broker[T] { b := &Broker[T]{ - subs: make(map[chan Event[T]]struct{}), - done: make(chan struct{}), + subs: make(map[chan Event[T]]struct{}), + done: make(chan struct{}), + subCount: 0, + maxEvents: maxEvents, } return b } -// Shutdown the broker, terminating any subscriptions. func (b *Broker[T]) Shutdown() { - close(b.done) + select { + case <-b.done: // Already closed + return + default: + close(b.done) + } b.mu.Lock() defer b.mu.Unlock() - // Remove each subscriber entry, so Publish() cannot send any further - // messages, and close each subscriber's channel, so the subscriber cannot - // consume any more messages. for ch := range b.subs { delete(b.subs, ch) close(ch) } + + b.subCount = 0 } -// Subscribe subscribes the caller to a stream of events. The returned channel -// is closed when the broker is shutdown. func (b *Broker[T]) Subscribe(ctx context.Context) <-chan Event[T] { b.mu.Lock() defer b.mu.Unlock() - // Check if broker has shutdown and if so return closed channel select { case <-b.done: ch := make(chan Event[T]) @@ -54,18 +60,16 @@ func (b *Broker[T]) Subscribe(ctx context.Context) <-chan Event[T] { default: } - // Subscribe sub := make(chan Event[T], bufferSize) b.subs[sub] = struct{}{} + b.subCount++ - // Unsubscribe when context is done. go func() { <-ctx.Done() b.mu.Lock() defer b.mu.Unlock() - // Check if broker has shutdown and if so do nothing select { case <-b.done: return @@ -74,21 +78,39 @@ func (b *Broker[T]) Subscribe(ctx context.Context) <-chan Event[T] { delete(b.subs, sub) close(sub) + b.subCount-- }() return sub } -// Publish an event to subscribers. +func (b *Broker[T]) GetSubscriberCount() int { + b.mu.RLock() + defer b.mu.RUnlock() + return b.subCount +} + func (b *Broker[T]) Publish(t EventType, payload T) { - b.mu.Lock() - defer b.mu.Unlock() + b.mu.RLock() + select { + case <-b.done: + b.mu.RUnlock() + return + default: + } + subscribers := make([]chan Event[T], 0, len(b.subs)) for sub := range b.subs { + subscribers = append(subscribers, sub) + } + b.mu.RUnlock() + + event := Event[T]{Type: t, Payload: payload} + + for _, sub := range subscribers { select { - case sub <- Event[T]{Type: t, Payload: payload}: - case <-b.done: - return + case sub <- event: + default: } } } diff --git a/internal/tui/components/chat/list.go b/internal/tui/components/chat/list.go index b09cc449507881b6f5030fe4f1eaabc286784219..03a50541e4da1a8f2bdd7a73a9e626f5e9c098da 100644 --- a/internal/tui/components/chat/list.go +++ b/internal/tui/components/chat/list.go @@ -370,6 +370,7 @@ func (m *messagesCmp) SetSize(width, height int) tea.Cmd { delete(m.cachedContent, msg.ID) } m.uiMessages = make([]uiMessage, 0) + m.renderView() return nil } diff --git a/internal/tui/components/core/status.go b/internal/tui/components/core/status.go index 5a2114e8363dbb424db41be18908bb50570a5c40..8bf3e516614ea92b1c57b7efc3a2f24980c9d1dc 100644 --- a/internal/tui/components/core/status.go +++ b/internal/tui/components/core/status.go @@ -18,6 +18,11 @@ import ( "github.com/kujtimiihoxha/opencode/internal/tui/util" ) +type StatusCmp interface { + tea.Model + SetHelpMsg(string) +} + type statusCmp struct { info util.InfoMsg width int @@ -146,7 +151,7 @@ func (m *statusCmp) projectDiagnostics() string { break } } - + // If any server is initializing, show that status if initializing { return lipgloss.NewStyle(). @@ -154,7 +159,7 @@ func (m *statusCmp) projectDiagnostics() string { Foreground(styles.Peach). Render(fmt.Sprintf("%s Initializing LSP...", styles.SpinnerIcon)) } - + errorDiagnostics := []protocol.Diagnostic{} warnDiagnostics := []protocol.Diagnostic{} hintDiagnostics := []protocol.Diagnostic{} @@ -235,7 +240,11 @@ func (m statusCmp) model() string { return styles.Padded.Background(styles.Grey).Foreground(styles.Text).Render(model.Name) } -func NewStatusCmp(lspClients map[string]*lsp.Client) tea.Model { +func (m statusCmp) SetHelpMsg(s string) { + helpWidget = styles.Padded.Background(styles.Forground).Foreground(styles.BackgroundDarker).Bold(true).Render(s) +} + +func NewStatusCmp(lspClients map[string]*lsp.Client) StatusCmp { return &statusCmp{ messageTTL: 10 * time.Second, lspClients: lspClients, diff --git a/internal/tui/tui.go b/internal/tui/tui.go index 2a9ed0d70d193f0fdb568cd372b8d31241f43b87..dec43f7c074a71d435c0f442bec40823e6b0fe2f 100644 --- a/internal/tui/tui.go +++ b/internal/tui/tui.go @@ -39,12 +39,18 @@ var keys = keyMap{ key.WithKeys("ctrl+_"), key.WithHelp("ctrl+?", "toggle help"), ), + SwitchSession: key.NewBinding( key.WithKeys("ctrl+a"), key.WithHelp("ctrl+a", "switch session"), ), } +var helpEsc = key.NewBinding( + key.WithKeys("?"), + key.WithHelp("?", "toggle help"), +) + var returnKey = key.NewBinding( key.WithKeys("esc"), key.WithHelp("esc", "close"), @@ -61,7 +67,7 @@ type appModel struct { previousPage page.PageID pages map[page.PageID]tea.Model loadedPages map[page.PageID]bool - status tea.Model + status core.StatusCmp app *app.App showPermissions bool @@ -75,6 +81,8 @@ type appModel struct { showSessionDialog bool sessionDialog dialog.SessionDialog + + editingMode bool } func (a appModel) Init() tea.Cmd { @@ -101,7 +109,8 @@ func (a appModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) { msg.Height -= 1 // Make space for the status bar a.width, a.height = msg.Width, msg.Height - a.status, _ = a.status.Update(msg) + s, _ := a.status.Update(msg) + a.status = s.(core.StatusCmp) a.pages[a.currentPage], cmd = a.pages[a.currentPage].Update(msg) cmds = append(cmds, cmd) @@ -118,45 +127,56 @@ func (a appModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) { cmds = append(cmds, sessionCmd) return a, tea.Batch(cmds...) - + case chat.EditorFocusMsg: + a.editingMode = bool(msg) // Status case util.InfoMsg: - a.status, cmd = a.status.Update(msg) + s, cmd := a.status.Update(msg) + a.status = s.(core.StatusCmp) cmds = append(cmds, cmd) return a, tea.Batch(cmds...) case pubsub.Event[logging.LogMessage]: if msg.Payload.Persist { switch msg.Payload.Level { case "error": - a.status, cmd = a.status.Update(util.InfoMsg{ + s, cmd := a.status.Update(util.InfoMsg{ Type: util.InfoTypeError, Msg: msg.Payload.Message, TTL: msg.Payload.PersistTime, }) + a.status = s.(core.StatusCmp) + cmds = append(cmds, cmd) case "info": - a.status, cmd = a.status.Update(util.InfoMsg{ + s, cmd := a.status.Update(util.InfoMsg{ Type: util.InfoTypeInfo, Msg: msg.Payload.Message, TTL: msg.Payload.PersistTime, }) + a.status = s.(core.StatusCmp) + cmds = append(cmds, cmd) + case "warn": - a.status, cmd = a.status.Update(util.InfoMsg{ + s, cmd := a.status.Update(util.InfoMsg{ Type: util.InfoTypeWarn, Msg: msg.Payload.Message, TTL: msg.Payload.PersistTime, }) + a.status = s.(core.StatusCmp) + cmds = append(cmds, cmd) default: - a.status, cmd = a.status.Update(util.InfoMsg{ + s, cmd := a.status.Update(util.InfoMsg{ Type: util.InfoTypeInfo, Msg: msg.Payload.Message, TTL: msg.Payload.PersistTime, }) + a.status = s.(core.StatusCmp) + cmds = append(cmds, cmd) } - cmds = append(cmds, cmd) } case util.ClearStatusMsg: - a.status, _ = a.status.Update(msg) + s, _ := a.status.Update(msg) + a.status = s.(core.StatusCmp) // Permission case pubsub.Event[permission.PermissionRequest]: @@ -243,7 +263,16 @@ func (a appModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) { } a.showHelp = !a.showHelp return a, nil + case key.Matches(msg, helpEsc): + if !a.editingMode { + if a.showQuit { + return a, nil + } + a.showHelp = !a.showHelp + return a, nil + } } + } if a.showQuit { @@ -275,7 +304,8 @@ func (a appModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) { } } - a.status, _ = a.status.Update(msg) + s, _ := a.status.Update(msg) + a.status = s.(core.StatusCmp) a.pages[a.currentPage], cmd = a.pages[a.currentPage].Update(msg) cmds = append(cmds, cmd) return a, tea.Batch(cmds...) @@ -326,6 +356,12 @@ func (a appModel) View() string { ) } + if a.editingMode { + a.status.SetHelpMsg("ctrl+? help") + } else { + a.status.SetHelpMsg("? help") + } + if a.showHelp { bindings := layout.KeyMapToSlice(keys) if p, ok := a.pages[a.currentPage].(layout.Bindings); ok { @@ -337,7 +373,9 @@ func (a appModel) View() string { if a.currentPage == page.LogsPage { bindings = append(bindings, logsKeyReturnKey) } - + if !a.editingMode { + bindings = append(bindings, helpEsc) + } a.help.SetBindings(bindings) overlay := a.help.View() @@ -398,6 +436,7 @@ func New(app *app.App) tea.Model { sessionDialog: dialog.NewSessionDialogCmp(), permissions: dialog.NewPermissionDialogCmp(), app: app, + editingMode: true, pages: map[page.PageID]tea.Model{ page.ChatPage: page.NewChatPage(app), page.LogsPage: page.NewLogsPage(), From 9ae6af8856ca6a13d575ec6a8989a5f6ee4297b1 Mon Sep 17 00:00:00 2001 From: Kujtim Hoxha Date: Mon, 21 Apr 2025 13:46:32 +0200 Subject: [PATCH 35/41] remove old logs --- internal/llm/tools/edit.go | 11 ++++++----- internal/llm/tools/patch.go | 7 ++++--- internal/llm/tools/write.go | 5 +++-- 3 files changed, 13 insertions(+), 10 deletions(-) diff --git a/internal/llm/tools/edit.go b/internal/llm/tools/edit.go index 83cec5dbafbca173f19ca06bd9a275dc5faa1b28..b7b813ca7a5045784e12fbf5b8d4033f41049caa 100644 --- a/internal/llm/tools/edit.go +++ b/internal/llm/tools/edit.go @@ -12,6 +12,7 @@ import ( "github.com/kujtimiihoxha/opencode/internal/config" "github.com/kujtimiihoxha/opencode/internal/diff" "github.com/kujtimiihoxha/opencode/internal/history" + "github.com/kujtimiihoxha/opencode/internal/logging" "github.com/kujtimiihoxha/opencode/internal/lsp" "github.com/kujtimiihoxha/opencode/internal/permission" ) @@ -227,7 +228,7 @@ func (e *editTool) createNewFile(ctx context.Context, filePath, content string) _, err = e.files.CreateVersion(ctx, sessionID, filePath, content) if err != nil { // Log error but don't fail the operation - fmt.Printf("Error creating file history version: %v\n", err) + logging.Debug("Error creating file history version", "error", err) } recordFileWrite(filePath) @@ -334,13 +335,13 @@ func (e *editTool) deleteContent(ctx context.Context, filePath, oldString string // User Manually changed the content store an intermediate version _, err = e.files.CreateVersion(ctx, sessionID, filePath, oldContent) if err != nil { - fmt.Printf("Error creating file history version: %v\n", err) + logging.Debug("Error creating file history version", "error", err) } } // Store the new version _, err = e.files.CreateVersion(ctx, sessionID, filePath, "") if err != nil { - fmt.Printf("Error creating file history version: %v\n", err) + logging.Debug("Error creating file history version", "error", err) } recordFileWrite(filePath) @@ -448,13 +449,13 @@ func (e *editTool) replaceContent(ctx context.Context, filePath, oldString, newS // User Manually changed the content store an intermediate version _, err = e.files.CreateVersion(ctx, sessionID, filePath, oldContent) if err != nil { - fmt.Printf("Error creating file history version: %v\n", err) + logging.Debug("Error creating file history version", "error", err) } } // Store the new version _, err = e.files.CreateVersion(ctx, sessionID, filePath, newContent) if err != nil { - fmt.Printf("Error creating file history version: %v\n", err) + logging.Debug("Error creating file history version", "error", err) } recordFileWrite(filePath) diff --git a/internal/llm/tools/patch.go b/internal/llm/tools/patch.go index 92eab69297a2513530a84a99cf411dc50cc4e359..903404497bff31e174f47127435634d892139498 100644 --- a/internal/llm/tools/patch.go +++ b/internal/llm/tools/patch.go @@ -11,6 +11,7 @@ import ( "github.com/kujtimiihoxha/opencode/internal/config" "github.com/kujtimiihoxha/opencode/internal/diff" "github.com/kujtimiihoxha/opencode/internal/history" + "github.com/kujtimiihoxha/opencode/internal/logging" "github.com/kujtimiihoxha/opencode/internal/lsp" "github.com/kujtimiihoxha/opencode/internal/permission" ) @@ -314,7 +315,7 @@ func (p *patchTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error // If not adding a file, create history entry for existing file _, err = p.files.Create(ctx, sessionID, absPath, oldContent) if err != nil { - fmt.Printf("Error creating file history: %v\n", err) + logging.Debug("Error creating file history", "error", err) } } @@ -322,7 +323,7 @@ func (p *patchTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error // User manually changed content, store intermediate version _, err = p.files.CreateVersion(ctx, sessionID, absPath, oldContent) if err != nil { - fmt.Printf("Error creating file history version: %v\n", err) + logging.Debug("Error creating file history version", "error", err) } } @@ -333,7 +334,7 @@ func (p *patchTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error _, err = p.files.CreateVersion(ctx, sessionID, absPath, newContent) } if err != nil { - fmt.Printf("Error creating file history version: %v\n", err) + logging.Debug("Error creating file history version", "error", err) } // Record file operations diff --git a/internal/llm/tools/write.go b/internal/llm/tools/write.go index 261865c398b061264a23c43aff9b3fcef0ee1283..2b3fa3dd077bc5a087bb8af3fecfa60b43298a4d 100644 --- a/internal/llm/tools/write.go +++ b/internal/llm/tools/write.go @@ -11,6 +11,7 @@ import ( "github.com/kujtimiihoxha/opencode/internal/config" "github.com/kujtimiihoxha/opencode/internal/diff" "github.com/kujtimiihoxha/opencode/internal/history" + "github.com/kujtimiihoxha/opencode/internal/logging" "github.com/kujtimiihoxha/opencode/internal/lsp" "github.com/kujtimiihoxha/opencode/internal/permission" ) @@ -192,13 +193,13 @@ func (w *writeTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error // User Manually changed the content store an intermediate version _, err = w.files.CreateVersion(ctx, sessionID, filePath, oldContent) if err != nil { - fmt.Printf("Error creating file history version: %v\n", err) + logging.Debug("Error creating file history version", "error", err) } } // Store the new version _, err = w.files.CreateVersion(ctx, sessionID, filePath, params.Content) if err != nil { - fmt.Printf("Error creating file history version: %v\n", err) + logging.Debug("Error creating file history version", "error", err) } recordFileWrite(filePath) From a8d5787e8ef561037f73b669128f46ae1b1e8553 Mon Sep 17 00:00:00 2001 From: Kujtim Hoxha Date: Mon, 21 Apr 2025 14:29:03 +0200 Subject: [PATCH 36/41] config validation --- .opencode.json | 1 + cmd/schema/README.md | 64 +++++++ cmd/schema/main.go | 262 ++++++++++++++++++++++++++++ internal/config/config.go | 276 +++++++++++++++++++++++++++++- internal/llm/agent/agent.go | 2 +- internal/llm/tools/edit.go | 27 ++- internal/llm/tools/write.go | 11 +- internal/permission/permission.go | 10 +- internal/tui/tui.go | 4 +- internal/version/version.go | 2 +- opencode-schema.json | 269 +++++++++++++++++++++++++++++ 11 files changed, 911 insertions(+), 17 deletions(-) create mode 100644 cmd/schema/README.md create mode 100644 cmd/schema/main.go create mode 100644 opencode-schema.json diff --git a/.opencode.json b/.opencode.json index b7fc19b524371cf7e4a625173f2fe305914694d3..c4d1547a0c62aad24a470af1d503c225a5b5955b 100644 --- a/.opencode.json +++ b/.opencode.json @@ -1,4 +1,5 @@ { + "$schema": "./opencode-schema.json", "lsp": { "gopls": { "command": "gopls" diff --git a/cmd/schema/README.md b/cmd/schema/README.md new file mode 100644 index 0000000000000000000000000000000000000000..93ebe9f03bd81bc7f679328b1b2171b7ce831954 --- /dev/null +++ b/cmd/schema/README.md @@ -0,0 +1,64 @@ +# OpenCode Configuration Schema Generator + +This tool generates a JSON Schema for the OpenCode configuration file. The schema can be used to validate configuration files and provide autocompletion in editors that support JSON Schema. + +## Usage + +```bash +go run cmd/schema/main.go > opencode-schema.json +``` + +This will generate a JSON Schema file that can be used to validate configuration files. + +## Schema Features + +The generated schema includes: + +- All configuration options with descriptions +- Default values where applicable +- Validation for enum values (e.g., model IDs, provider types) +- Required fields +- Type checking + +## Using the Schema + +You can use the generated schema in several ways: + +1. **Editor Integration**: Many editors (VS Code, JetBrains IDEs, etc.) support JSON Schema for validation and autocompletion. You can configure your editor to use the generated schema for `.opencode.json` files. + +2. **Validation Tools**: You can use tools like [jsonschema](https://github.com/Julian/jsonschema) to validate your configuration files against the schema. + +3. **Documentation**: The schema serves as documentation for the configuration options. + +## Example Configuration + +Here's an example configuration that conforms to the schema: + +```json +{ + "data": { + "directory": ".opencode" + }, + "debug": false, + "providers": { + "anthropic": { + "apiKey": "your-api-key" + } + }, + "agents": { + "coder": { + "model": "claude-3.7-sonnet", + "maxTokens": 5000, + "reasoningEffort": "medium" + }, + "task": { + "model": "claude-3.7-sonnet", + "maxTokens": 5000 + }, + "title": { + "model": "claude-3.7-sonnet", + "maxTokens": 80 + } + } +} +``` \ No newline at end of file diff --git a/cmd/schema/main.go b/cmd/schema/main.go new file mode 100644 index 0000000000000000000000000000000000000000..030c0907e1866066201729d19acf081355ed651e --- /dev/null +++ b/cmd/schema/main.go @@ -0,0 +1,262 @@ +package main + +import ( + "encoding/json" + "fmt" + "os" + + "github.com/kujtimiihoxha/opencode/internal/config" + "github.com/kujtimiihoxha/opencode/internal/llm/models" +) + +// JSONSchemaType represents a JSON Schema type +type JSONSchemaType struct { + Type string `json:"type,omitempty"` + Description string `json:"description,omitempty"` + Properties map[string]any `json:"properties,omitempty"` + Required []string `json:"required,omitempty"` + AdditionalProperties any `json:"additionalProperties,omitempty"` + Enum []any `json:"enum,omitempty"` + Items map[string]any `json:"items,omitempty"` + OneOf []map[string]any `json:"oneOf,omitempty"` + AnyOf []map[string]any `json:"anyOf,omitempty"` + Default any `json:"default,omitempty"` +} + +func main() { + schema := generateSchema() + + // Pretty print the schema + encoder := json.NewEncoder(os.Stdout) + encoder.SetIndent("", " ") + if err := encoder.Encode(schema); err != nil { + fmt.Fprintf(os.Stderr, "Error encoding schema: %v\n", err) + os.Exit(1) + } +} + +func generateSchema() map[string]any { + schema := map[string]any{ + "$schema": "http://json-schema.org/draft-07/schema#", + "title": "OpenCode Configuration", + "description": "Configuration schema for the OpenCode application", + "type": "object", + "properties": map[string]any{}, + } + + // Add Data configuration + schema["properties"].(map[string]any)["data"] = map[string]any{ + "type": "object", + "description": "Storage configuration", + "properties": map[string]any{ + "directory": map[string]any{ + "type": "string", + "description": "Directory where application data is stored", + "default": ".opencode", + }, + }, + "required": []string{"directory"}, + } + + // Add working directory + schema["properties"].(map[string]any)["wd"] = map[string]any{ + "type": "string", + "description": "Working directory for the application", + } + + // Add debug flags + schema["properties"].(map[string]any)["debug"] = map[string]any{ + "type": "boolean", + "description": "Enable debug mode", + "default": false, + } + + schema["properties"].(map[string]any)["debugLSP"] = map[string]any{ + "type": "boolean", + "description": "Enable LSP debug mode", + "default": false, + } + + // Add MCP servers + schema["properties"].(map[string]any)["mcpServers"] = map[string]any{ + "type": "object", + "description": "Model Control Protocol server configurations", + "additionalProperties": map[string]any{ + "type": "object", + "description": "MCP server configuration", + "properties": map[string]any{ + "command": map[string]any{ + "type": "string", + "description": "Command to execute for the MCP server", + }, + "env": map[string]any{ + "type": "array", + "description": "Environment variables for the MCP server", + "items": map[string]any{ + "type": "string", + }, + }, + "args": map[string]any{ + "type": "array", + "description": "Command arguments for the MCP server", + "items": map[string]any{ + "type": "string", + }, + }, + "type": map[string]any{ + "type": "string", + "description": "Type of MCP server", + "enum": []string{"stdio", "sse"}, + "default": "stdio", + }, + "url": map[string]any{ + "type": "string", + "description": "URL for SSE type MCP servers", + }, + "headers": map[string]any{ + "type": "object", + "description": "HTTP headers for SSE type MCP servers", + "additionalProperties": map[string]any{ + "type": "string", + }, + }, + }, + "required": []string{"command"}, + }, + } + + // Add providers + providerSchema := map[string]any{ + "type": "object", + "description": "LLM provider configurations", + "additionalProperties": map[string]any{ + "type": "object", + "description": "Provider configuration", + "properties": map[string]any{ + "apiKey": map[string]any{ + "type": "string", + "description": "API key for the provider", + }, + "disabled": map[string]any{ + "type": "boolean", + "description": "Whether the provider is disabled", + "default": false, + }, + }, + }, + } + + // Add known providers + knownProviders := []string{ + string(models.ProviderAnthropic), + string(models.ProviderOpenAI), + string(models.ProviderGemini), + string(models.ProviderGROQ), + string(models.ProviderBedrock), + } + + providerSchema["additionalProperties"].(map[string]any)["properties"].(map[string]any)["provider"] = map[string]any{ + "type": "string", + "description": "Provider type", + "enum": knownProviders, + } + + schema["properties"].(map[string]any)["providers"] = providerSchema + + // Add agents + agentSchema := map[string]any{ + "type": "object", + "description": "Agent configurations", + "additionalProperties": map[string]any{ + "type": "object", + "description": "Agent configuration", + "properties": map[string]any{ + "model": map[string]any{ + "type": "string", + "description": "Model ID for the agent", + }, + "maxTokens": map[string]any{ + "type": "integer", + "description": "Maximum tokens for the agent", + "minimum": 1, + }, + "reasoningEffort": map[string]any{ + "type": "string", + "description": "Reasoning effort for models that support it (OpenAI, Anthropic)", + "enum": []string{"low", "medium", "high"}, + }, + }, + "required": []string{"model"}, + }, + } + + // Add model enum + modelEnum := []string{} + for modelID := range models.SupportedModels { + modelEnum = append(modelEnum, string(modelID)) + } + agentSchema["additionalProperties"].(map[string]any)["properties"].(map[string]any)["model"].(map[string]any)["enum"] = modelEnum + + // Add specific agent properties + agentProperties := map[string]any{} + knownAgents := []string{ + string(config.AgentCoder), + string(config.AgentTask), + string(config.AgentTitle), + } + + for _, agentName := range knownAgents { + agentProperties[agentName] = map[string]any{ + "$ref": "#/definitions/agent", + } + } + + // Create a combined schema that allows both specific agents and additional ones + combinedAgentSchema := map[string]any{ + "type": "object", + "description": "Agent configurations", + "properties": agentProperties, + "additionalProperties": agentSchema["additionalProperties"], + } + + schema["properties"].(map[string]any)["agents"] = combinedAgentSchema + schema["definitions"] = map[string]any{ + "agent": agentSchema["additionalProperties"], + } + + // Add LSP configuration + schema["properties"].(map[string]any)["lsp"] = map[string]any{ + "type": "object", + "description": "Language Server Protocol configurations", + "additionalProperties": map[string]any{ + "type": "object", + "description": "LSP configuration for a language", + "properties": map[string]any{ + "disabled": map[string]any{ + "type": "boolean", + "description": "Whether the LSP is disabled", + "default": false, + }, + "command": map[string]any{ + "type": "string", + "description": "Command to execute for the LSP server", + }, + "args": map[string]any{ + "type": "array", + "description": "Command arguments for the LSP server", + "items": map[string]any{ + "type": "string", + }, + }, + "options": map[string]any{ + "type": "object", + "description": "Additional options for the LSP server", + }, + }, + "required": []string{"command"}, + }, + } + + return schema +} + diff --git a/internal/config/config.go b/internal/config/config.go index 2dbbcc9ca5153d518e6b9592fef8503d6b38b5d0..13c7d13284f6e3b7e187227b442a3af654cf4443 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -120,13 +120,11 @@ func Load(workingDir string, debug bool) (*Config, error) { } applyDefaultValues() - defaultLevel := slog.LevelInfo if cfg.Debug { defaultLevel = slog.LevelDebug } - // if we are in debug mode make the writer a file - if cfg.Debug { + if os.Getenv("OPENCODE_DEV_DEBUG") == "true" { loggingFile := fmt.Sprintf("%s/%s", cfg.Data.Directory, "debug.log") // if file does not exist create it @@ -156,6 +154,11 @@ func Load(workingDir string, debug bool) (*Config, error) { slog.SetDefault(logger) } + // Validate configuration + if err := Validate(); err != nil { + return cfg, fmt.Errorf("config validation failed: %w", err) + } + if cfg.Agents == nil { cfg.Agents = make(map[AgentName]Agent) } @@ -302,6 +305,273 @@ func applyDefaultValues() { } } +// Validate checks if the configuration is valid and applies defaults where needed. +// It validates model IDs and providers, ensuring they are supported. +func Validate() error { + if cfg == nil { + return fmt.Errorf("config not loaded") + } + + // Validate agent models + for name, agent := range cfg.Agents { + // Check if model exists + model, modelExists := models.SupportedModels[agent.Model] + if !modelExists { + logging.Warn("unsupported model configured, reverting to default", + "agent", name, + "configured_model", agent.Model) + + // Set default model based on available providers + if setDefaultModelForAgent(name) { + logging.Info("set default model for agent", "agent", name, "model", cfg.Agents[name].Model) + } else { + return fmt.Errorf("no valid provider available for agent %s", name) + } + continue + } + + // Check if provider for the model is configured + provider := model.Provider + providerCfg, providerExists := cfg.Providers[provider] + + if !providerExists { + // Provider not configured, check if we have environment variables + apiKey := getProviderAPIKey(provider) + if apiKey == "" { + logging.Warn("provider not configured for model, reverting to default", + "agent", name, + "model", agent.Model, + "provider", provider) + + // Set default model based on available providers + if setDefaultModelForAgent(name) { + logging.Info("set default model for agent", "agent", name, "model", cfg.Agents[name].Model) + } else { + return fmt.Errorf("no valid provider available for agent %s", name) + } + } else { + // Add provider with API key from environment + cfg.Providers[provider] = Provider{ + APIKey: apiKey, + } + logging.Info("added provider from environment", "provider", provider) + } + } else if providerCfg.Disabled || providerCfg.APIKey == "" { + // Provider is disabled or has no API key + logging.Warn("provider is disabled or has no API key, reverting to default", + "agent", name, + "model", agent.Model, + "provider", provider) + + // Set default model based on available providers + if setDefaultModelForAgent(name) { + logging.Info("set default model for agent", "agent", name, "model", cfg.Agents[name].Model) + } else { + return fmt.Errorf("no valid provider available for agent %s", name) + } + } + + // Validate max tokens + if agent.MaxTokens <= 0 { + logging.Warn("invalid max tokens, setting to default", + "agent", name, + "model", agent.Model, + "max_tokens", agent.MaxTokens) + + // Update the agent with default max tokens + updatedAgent := cfg.Agents[name] + if model.DefaultMaxTokens > 0 { + updatedAgent.MaxTokens = model.DefaultMaxTokens + } else { + updatedAgent.MaxTokens = 4096 // Fallback default + } + cfg.Agents[name] = updatedAgent + } else if model.ContextWindow > 0 && agent.MaxTokens > model.ContextWindow/2 { + // Ensure max tokens doesn't exceed half the context window (reasonable limit) + logging.Warn("max tokens exceeds half the context window, adjusting", + "agent", name, + "model", agent.Model, + "max_tokens", agent.MaxTokens, + "context_window", model.ContextWindow) + + // Update the agent with adjusted max tokens + updatedAgent := cfg.Agents[name] + updatedAgent.MaxTokens = model.ContextWindow / 2 + cfg.Agents[name] = updatedAgent + } + + // Validate reasoning effort for models that support reasoning + if model.CanReason && provider == models.ProviderOpenAI { + if agent.ReasoningEffort == "" { + // Set default reasoning effort for models that support it + logging.Info("setting default reasoning effort for model that supports reasoning", + "agent", name, + "model", agent.Model) + + // Update the agent with default reasoning effort + updatedAgent := cfg.Agents[name] + updatedAgent.ReasoningEffort = "medium" + cfg.Agents[name] = updatedAgent + } else { + // Check if reasoning effort is valid (low, medium, high) + effort := strings.ToLower(agent.ReasoningEffort) + if effort != "low" && effort != "medium" && effort != "high" { + logging.Warn("invalid reasoning effort, setting to medium", + "agent", name, + "model", agent.Model, + "reasoning_effort", agent.ReasoningEffort) + + // Update the agent with valid reasoning effort + updatedAgent := cfg.Agents[name] + updatedAgent.ReasoningEffort = "medium" + cfg.Agents[name] = updatedAgent + } + } + } else if !model.CanReason && agent.ReasoningEffort != "" { + // Model doesn't support reasoning but reasoning effort is set + logging.Warn("model doesn't support reasoning but reasoning effort is set, ignoring", + "agent", name, + "model", agent.Model, + "reasoning_effort", agent.ReasoningEffort) + + // Update the agent to remove reasoning effort + updatedAgent := cfg.Agents[name] + updatedAgent.ReasoningEffort = "" + cfg.Agents[name] = updatedAgent + } + } + + // Validate providers + for provider, providerCfg := range cfg.Providers { + if providerCfg.APIKey == "" && !providerCfg.Disabled { + logging.Warn("provider has no API key, marking as disabled", "provider", provider) + providerCfg.Disabled = true + cfg.Providers[provider] = providerCfg + } + } + + // Validate LSP configurations + for language, lspConfig := range cfg.LSP { + if lspConfig.Command == "" && !lspConfig.Disabled { + logging.Warn("LSP configuration has no command, marking as disabled", "language", language) + lspConfig.Disabled = true + cfg.LSP[language] = lspConfig + } + } + + return nil +} + +// getProviderAPIKey gets the API key for a provider from environment variables +func getProviderAPIKey(provider models.ModelProvider) string { + switch provider { + case models.ProviderAnthropic: + return os.Getenv("ANTHROPIC_API_KEY") + case models.ProviderOpenAI: + return os.Getenv("OPENAI_API_KEY") + case models.ProviderGemini: + return os.Getenv("GEMINI_API_KEY") + case models.ProviderGROQ: + return os.Getenv("GROQ_API_KEY") + case models.ProviderBedrock: + if hasAWSCredentials() { + return "aws-credentials-available" + } + } + return "" +} + +// setDefaultModelForAgent sets a default model for an agent based on available providers +func setDefaultModelForAgent(agent AgentName) bool { + // Check providers in order of preference + if apiKey := os.Getenv("ANTHROPIC_API_KEY"); apiKey != "" { + maxTokens := int64(5000) + if agent == AgentTitle { + maxTokens = 80 + } + cfg.Agents[agent] = Agent{ + Model: models.Claude37Sonnet, + MaxTokens: maxTokens, + } + return true + } + + if apiKey := os.Getenv("OPENAI_API_KEY"); apiKey != "" { + var model models.ModelID + maxTokens := int64(5000) + reasoningEffort := "" + + switch agent { + case AgentTitle: + model = models.GPT41Mini + maxTokens = 80 + case AgentTask: + model = models.GPT41Mini + default: + model = models.GPT41 + } + + // Check if model supports reasoning + if modelInfo, ok := models.SupportedModels[model]; ok && modelInfo.CanReason { + reasoningEffort = "medium" + } + + cfg.Agents[agent] = Agent{ + Model: model, + MaxTokens: maxTokens, + ReasoningEffort: reasoningEffort, + } + return true + } + + if apiKey := os.Getenv("GEMINI_API_KEY"); apiKey != "" { + var model models.ModelID + maxTokens := int64(5000) + + if agent == AgentTitle { + model = models.Gemini25Flash + maxTokens = 80 + } else { + model = models.Gemini25 + } + + cfg.Agents[agent] = Agent{ + Model: model, + MaxTokens: maxTokens, + } + return true + } + + if apiKey := os.Getenv("GROQ_API_KEY"); apiKey != "" { + maxTokens := int64(5000) + if agent == AgentTitle { + maxTokens = 80 + } + + cfg.Agents[agent] = Agent{ + Model: models.QWENQwq, + MaxTokens: maxTokens, + } + return true + } + + if hasAWSCredentials() { + maxTokens := int64(5000) + if agent == AgentTitle { + maxTokens = 80 + } + + cfg.Agents[agent] = Agent{ + Model: models.BedrockClaude37Sonnet, + MaxTokens: maxTokens, + ReasoningEffort: "medium", // Claude models support reasoning + } + return true + } + + return false +} + // Get returns the current configuration. // It's safe to call this function multiple times. func Get() *Config { diff --git a/internal/llm/agent/agent.go b/internal/llm/agent/agent.go index ae5bcb23178810a31dcc2c4b63e5cd8486390f85..6c5808eabcc15eab4cda5337bfce99885c8974f9 100644 --- a/internal/llm/agent/agent.go +++ b/internal/llm/agent/agent.go @@ -471,7 +471,7 @@ func createAgentProvider(agentName config.AgentName) (provider.Provider, error) provider.WithReasoningEffort(agentConfig.ReasoningEffort), ), ) - } else if model.Provider == models.ProviderAnthropic && model.CanReason { + } else if model.Provider == models.ProviderAnthropic && model.CanReason && agentName == config.AgentCoder { opts = append( opts, provider.WithAnthropicOptions( diff --git a/internal/llm/tools/edit.go b/internal/llm/tools/edit.go index b7b813ca7a5045784e12fbf5b8d4033f41049caa..23c44399b2fcf39e4fae14c751a90958246da729 100644 --- a/internal/llm/tools/edit.go +++ b/internal/llm/tools/edit.go @@ -196,11 +196,16 @@ func (e *editTool) createNewFile(ctx context.Context, filePath, content string) content, filePath, ) + rootDir := config.WorkingDirectory() + permissionPath := filepath.Dir(filePath) + if strings.HasPrefix(filePath, rootDir) { + permissionPath = rootDir + } p := e.permissions.Request( permission.CreatePermissionRequest{ - Path: filepath.Dir(filePath), + Path: permissionPath, ToolName: EditToolName, - Action: "create", + Action: "write", Description: fmt.Sprintf("Create file %s", filePath), Params: EditPermissionsParams{ FilePath: filePath, @@ -301,11 +306,16 @@ func (e *editTool) deleteContent(ctx context.Context, filePath, oldString string filePath, ) + rootDir := config.WorkingDirectory() + permissionPath := filepath.Dir(filePath) + if strings.HasPrefix(filePath, rootDir) { + permissionPath = rootDir + } p := e.permissions.Request( permission.CreatePermissionRequest{ - Path: filepath.Dir(filePath), + Path: permissionPath, ToolName: EditToolName, - Action: "delete", + Action: "write", Description: fmt.Sprintf("Delete content from file %s", filePath), Params: EditPermissionsParams{ FilePath: filePath, @@ -415,11 +425,16 @@ func (e *editTool) replaceContent(ctx context.Context, filePath, oldString, newS newContent, filePath, ) + rootDir := config.WorkingDirectory() + permissionPath := filepath.Dir(filePath) + if strings.HasPrefix(filePath, rootDir) { + permissionPath = rootDir + } p := e.permissions.Request( permission.CreatePermissionRequest{ - Path: filepath.Dir(filePath), + Path: permissionPath, ToolName: EditToolName, - Action: "replace", + Action: "write", Description: fmt.Sprintf("Replace content in file %s", filePath), Params: EditPermissionsParams{ FilePath: filePath, diff --git a/internal/llm/tools/write.go b/internal/llm/tools/write.go index 2b3fa3dd077bc5a087bb8af3fecfa60b43298a4d..3a94b47b6437e7b36197864fe0ee2e223b1f0bf3 100644 --- a/internal/llm/tools/write.go +++ b/internal/llm/tools/write.go @@ -6,6 +6,7 @@ import ( "fmt" "os" "path/filepath" + "strings" "time" "github.com/kujtimiihoxha/opencode/internal/config" @@ -159,11 +160,17 @@ func (w *writeTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error params.Content, filePath, ) + + rootDir := config.WorkingDirectory() + permissionPath := filepath.Dir(filePath) + if strings.HasPrefix(filePath, rootDir) { + permissionPath = rootDir + } p := w.permissions.Request( permission.CreatePermissionRequest{ - Path: filePath, + Path: permissionPath, ToolName: WriteToolName, - Action: "create", + Action: "write", Description: fmt.Sprintf("Create file %s", filePath), Params: WritePermissionsParams{ FilePath: filePath, diff --git a/internal/permission/permission.go b/internal/permission/permission.go index 4cb379dea133a4b101f16d987f34c4647a8dd2b5..06f69a33dfcc73f3afa6ec60b60eb495bd37e46b 100644 --- a/internal/permission/permission.go +++ b/internal/permission/permission.go @@ -2,10 +2,12 @@ package permission import ( "errors" + "path/filepath" "sync" "time" "github.com/google/uuid" + "github.com/kujtimiihoxha/opencode/internal/config" "github.com/kujtimiihoxha/opencode/internal/pubsub" ) @@ -67,9 +69,13 @@ func (s *permissionService) Deny(permission PermissionRequest) { } func (s *permissionService) Request(opts CreatePermissionRequest) bool { + dir := filepath.Dir(opts.Path) + if dir == "." { + dir = config.WorkingDirectory() + } permission := PermissionRequest{ ID: uuid.New().String(), - Path: opts.Path, + Path: dir, ToolName: opts.ToolName, Description: opts.Description, Action: opts.Action, @@ -77,7 +83,7 @@ func (s *permissionService) Request(opts CreatePermissionRequest) bool { } for _, p := range s.sessionPermissions { - if p.ToolName == permission.ToolName && p.Action == permission.Action { + if p.ToolName == permission.ToolName && p.Action == permission.Action && p.SessionID == permission.SessionID && p.Path == permission.Path { return true } } diff --git a/internal/tui/tui.go b/internal/tui/tui.go index dec43f7c074a71d435c0f442bec40823e6b0fe2f..392b9ec41f636864df0b783926be76468f743367 100644 --- a/internal/tui/tui.go +++ b/internal/tui/tui.go @@ -57,8 +57,8 @@ var returnKey = key.NewBinding( ) var logsKeyReturnKey = key.NewBinding( - key.WithKeys("backspace"), - key.WithHelp("backspace", "go back"), + key.WithKeys("backspace", "q"), + key.WithHelp("backspace/q", "go back"), ) type appModel struct { diff --git a/internal/version/version.go b/internal/version/version.go index 54c576f6c2605f8c678212d776d22f3a103c5656..1e19bea3883db14820ad6fef5e0879848a23395c 100644 --- a/internal/version/version.go +++ b/internal/version/version.go @@ -5,7 +5,7 @@ import "runtime/debug" // Build-time parameters set via -ldflags var Version = "unknown" -// A user may install pug using `go install github.com/leg100/pug@latest` +// A user may install pug using `go install github.com/kujtimiihoxha/opencode@latest`. // without -ldflags, in which case the version above is unset. As a workaround // we use the embedded build version that *is* set when using `go install` (and // is only set for `go install` and not for `go build`). diff --git a/opencode-schema.json b/opencode-schema.json new file mode 100644 index 0000000000000000000000000000000000000000..452790cdfb9c86dcf496a1e9e62d113b5b3577ab --- /dev/null +++ b/opencode-schema.json @@ -0,0 +1,269 @@ +{ + "$schema": "http://json-schema.org/draft-07/schema#", + "definitions": { + "agent": { + "description": "Agent configuration", + "properties": { + "maxTokens": { + "description": "Maximum tokens for the agent", + "minimum": 1, + "type": "integer" + }, + "model": { + "description": "Model ID for the agent", + "enum": [ + "gemini-2.0-flash", + "bedrock.claude-3.7-sonnet", + "claude-3-opus", + "claude-3.5-sonnet", + "gpt-4o-mini", + "o1", + "o3-mini", + "o1-pro", + "o4-mini", + "claude-3-haiku", + "gpt-4o", + "o3", + "gpt-4.1-mini", + "gpt-4.5-preview", + "gemini-2.5-flash", + "claude-3.5-haiku", + "gpt-4.1", + "gemini-2.0-flash-lite", + "claude-3.7-sonnet", + "o1-mini", + "gpt-4.1-nano", + "gemini-2.5" + ], + "type": "string" + }, + "reasoningEffort": { + "description": "Reasoning effort for models that support it (OpenAI, Anthropic)", + "enum": [ + "low", + "medium", + "high" + ], + "type": "string" + } + }, + "required": [ + "model" + ], + "type": "object" + } + }, + "description": "Configuration schema for the OpenCode application", + "properties": { + "agents": { + "additionalProperties": { + "description": "Agent configuration", + "properties": { + "maxTokens": { + "description": "Maximum tokens for the agent", + "minimum": 1, + "type": "integer" + }, + "model": { + "description": "Model ID for the agent", + "enum": [ + "gemini-2.0-flash", + "bedrock.claude-3.7-sonnet", + "claude-3-opus", + "claude-3.5-sonnet", + "gpt-4o-mini", + "o1", + "o3-mini", + "o1-pro", + "o4-mini", + "claude-3-haiku", + "gpt-4o", + "o3", + "gpt-4.1-mini", + "gpt-4.5-preview", + "gemini-2.5-flash", + "claude-3.5-haiku", + "gpt-4.1", + "gemini-2.0-flash-lite", + "claude-3.7-sonnet", + "o1-mini", + "gpt-4.1-nano", + "gemini-2.5" + ], + "type": "string" + }, + "reasoningEffort": { + "description": "Reasoning effort for models that support it (OpenAI, Anthropic)", + "enum": [ + "low", + "medium", + "high" + ], + "type": "string" + } + }, + "required": [ + "model" + ], + "type": "object" + }, + "description": "Agent configurations", + "properties": { + "coder": { + "$ref": "#/definitions/agent" + }, + "task": { + "$ref": "#/definitions/agent" + }, + "title": { + "$ref": "#/definitions/agent" + } + }, + "type": "object" + }, + "data": { + "description": "Storage configuration", + "properties": { + "directory": { + "default": ".opencode", + "description": "Directory where application data is stored", + "type": "string" + } + }, + "required": [ + "directory" + ], + "type": "object" + }, + "debug": { + "default": false, + "description": "Enable debug mode", + "type": "boolean" + }, + "debugLSP": { + "default": false, + "description": "Enable LSP debug mode", + "type": "boolean" + }, + "lsp": { + "additionalProperties": { + "description": "LSP configuration for a language", + "properties": { + "args": { + "description": "Command arguments for the LSP server", + "items": { + "type": "string" + }, + "type": "array" + }, + "command": { + "description": "Command to execute for the LSP server", + "type": "string" + }, + "disabled": { + "default": false, + "description": "Whether the LSP is disabled", + "type": "boolean" + }, + "options": { + "description": "Additional options for the LSP server", + "type": "object" + } + }, + "required": [ + "command" + ], + "type": "object" + }, + "description": "Language Server Protocol configurations", + "type": "object" + }, + "mcpServers": { + "additionalProperties": { + "description": "MCP server configuration", + "properties": { + "args": { + "description": "Command arguments for the MCP server", + "items": { + "type": "string" + }, + "type": "array" + }, + "command": { + "description": "Command to execute for the MCP server", + "type": "string" + }, + "env": { + "description": "Environment variables for the MCP server", + "items": { + "type": "string" + }, + "type": "array" + }, + "headers": { + "additionalProperties": { + "type": "string" + }, + "description": "HTTP headers for SSE type MCP servers", + "type": "object" + }, + "type": { + "default": "stdio", + "description": "Type of MCP server", + "enum": [ + "stdio", + "sse" + ], + "type": "string" + }, + "url": { + "description": "URL for SSE type MCP servers", + "type": "string" + } + }, + "required": [ + "command" + ], + "type": "object" + }, + "description": "Model Control Protocol server configurations", + "type": "object" + }, + "providers": { + "additionalProperties": { + "description": "Provider configuration", + "properties": { + "apiKey": { + "description": "API key for the provider", + "type": "string" + }, + "disabled": { + "default": false, + "description": "Whether the provider is disabled", + "type": "boolean" + }, + "provider": { + "description": "Provider type", + "enum": [ + "anthropic", + "openai", + "gemini", + "groq", + "bedrock" + ], + "type": "string" + } + }, + "type": "object" + }, + "description": "LLM provider configurations", + "type": "object" + }, + "wd": { + "description": "Working directory for the application", + "type": "string" + } + }, + "title": "OpenCode Configuration", + "type": "object" +} From 1e11805efc9f3feaf9b9696bcaa8a8dd599db0b1 Mon Sep 17 00:00:00 2001 From: Kujtim Hoxha Date: Mon, 21 Apr 2025 15:52:32 +0200 Subject: [PATCH 37/41] add description --- cmd/root.go | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/cmd/root.go b/cmd/root.go index 54280ecaaab7a8dbab3cd5555c53a526ae177134..545652a7a9e938e2b7a2598e253287146b98f543 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -21,8 +21,10 @@ import ( var rootCmd = &cobra.Command{ Use: "OpenCode", - Short: "A terminal ai assistant", - Long: `A terminal ai assistant`, + Short: "A terminal AI assistant for software development", + Long: `OpenCode is a powerful terminal-based AI assistant that helps with software development tasks. +It provides an interactive chat interface with AI capabilities, code analysis, and LSP integration +to assist developers in writing, debugging, and understanding code directly from the terminal.`, RunE: func(cmd *cobra.Command, args []string) error { // If the help flag is set, show the help message if cmd.Flag("help").Changed { From d03a73a8d36565cf00ccdee0b1689f295999ad51 Mon Sep 17 00:00:00 2001 From: Kujtim Hoxha Date: Mon, 21 Apr 2025 16:22:52 +0200 Subject: [PATCH 38/41] update readme --- README.md | 131 ++++++++++++++++++++++++++---------------------------- 1 file changed, 64 insertions(+), 67 deletions(-) diff --git a/README.md b/README.md index 075114fc335e9cc9ccfdf035478bce73a125f590..7e8c2791b6f5a5c8a57f400d807b2c54f3038967 100644 --- a/README.md +++ b/README.md @@ -18,6 +18,7 @@ OpenCode is a Go-based CLI application that brings AI assistance to your termina - **Persistent Storage**: SQLite database for storing conversations and sessions - **LSP Integration**: Language Server Protocol support for code intelligence - **File Change Tracking**: Track and visualize file changes during sessions +- **External Editor Support**: Open your preferred editor for composing messages ## Installation @@ -100,41 +101,31 @@ You can configure OpenCode using environment variables: ## Supported AI Models -### OpenAI Models - -| Model ID | Name | Context Window | -| ----------------- | --------------- | ---------------- | -| `gpt-4.1` | GPT 4.1 | 1,047,576 tokens | -| `gpt-4.1-mini` | GPT 4.1 Mini | 200,000 tokens | -| `gpt-4.1-nano` | GPT 4.1 Nano | 1,047,576 tokens | -| `gpt-4.5-preview` | GPT 4.5 Preview | 128,000 tokens | -| `gpt-4o` | GPT-4o | 128,000 tokens | -| `gpt-4o-mini` | GPT-4o Mini | 128,000 tokens | -| `o1` | O1 | 200,000 tokens | -| `o1-pro` | O1 Pro | 200,000 tokens | -| `o1-mini` | O1 Mini | 128,000 tokens | -| `o3` | O3 | 200,000 tokens | -| `o3-mini` | O3 Mini | 200,000 tokens | -| `o4-mini` | O4 Mini | 128,000 tokens | - -### Anthropic Models - -| Model ID | Name | Context Window | -| ------------------- | ----------------- | -------------- | -| `claude-3.5-sonnet` | Claude 3.5 Sonnet | 200,000 tokens | -| `claude-3-haiku` | Claude 3 Haiku | 200,000 tokens | -| `claude-3.7-sonnet` | Claude 3.7 Sonnet | 200,000 tokens | -| `claude-3.5-haiku` | Claude 3.5 Haiku | 200,000 tokens | -| `claude-3-opus` | Claude 3 Opus | 200,000 tokens | - -### Other Models - -| Model ID | Provider | Name | Context Window | -| --------------------------- | ----------- | ----------------- | -------------- | -| `gemini-2.5` | Google | Gemini 2.5 Pro | - | -| `gemini-2.0-flash` | Google | Gemini 2.0 Flash | - | -| `qwen-qwq` | Groq | Qwen Qwq | - | -| `bedrock.claude-3.7-sonnet` | AWS Bedrock | Claude 3.7 Sonnet | - | +OpenCode supports a variety of AI models from different providers: + +### OpenAI +- GPT-4.1 family (gpt-4.1, gpt-4.1-mini, gpt-4.1-nano) +- GPT-4.5 Preview +- GPT-4o family (gpt-4o, gpt-4o-mini) +- O1 family (o1, o1-pro, o1-mini) +- O3 family (o3, o3-mini) +- O4 Mini + +### Anthropic +- Claude 3.5 Sonnet +- Claude 3.5 Haiku +- Claude 3.7 Sonnet +- Claude 3 Haiku +- Claude 3 Opus + +### Google +- Gemini 2.5 +- Gemini 2.5 Flash +- Gemini 2.0 Flash +- Gemini 2.0 Flash Lite + +### AWS Bedrock +- Claude 3.7 Sonnet ## Usage @@ -161,12 +152,14 @@ opencode -c /path/to/project ### Global Shortcuts -| Shortcut | Action | -| -------- | ------------------------------------------------------- | -| `Ctrl+C` | Quit application | -| `Ctrl+?` | Toggle help dialog | -| `Ctrl+L` | View logs | -| `Esc` | Close current overlay/dialog or return to previous mode | +| Shortcut | Action | +| --------- | ------------------------------------------------------- | +| `Ctrl+C` | Quit application | +| `Ctrl+?` | Toggle help dialog | +| `?` | Toggle help dialog (when not in editing mode) | +| `Ctrl+L` | View logs | +| `Ctrl+A` | Switch session | +| `Esc` | Close current overlay/dialog or return to previous mode | ### Chat Page Shortcuts @@ -183,13 +176,34 @@ opencode -c /path/to/project | ------------------- | ----------------------------------------- | | `Ctrl+S` | Send message (when editor is focused) | | `Enter` or `Ctrl+S` | Send message (when editor is not focused) | +| `Ctrl+E` | Open external editor | | `Esc` | Blur editor and focus messages | +### Session Dialog Shortcuts + +| Shortcut | Action | +| ------------- | ---------------- | +| `↑` or `k` | Previous session | +| `↓` or `j` | Next session | +| `Enter` | Select session | +| `Esc` | Close dialog | + +### Permission Dialog Shortcuts + +| Shortcut | Action | +| ------------------------- | ----------------------- | +| `←` or `left` | Switch options left | +| `→` or `right` or `tab` | Switch options right | +| `Enter` or `space` | Confirm selection | +| `a` | Allow permission | +| `A` | Allow permission for session | +| `d` | Deny permission | + ### Logs Page Shortcuts -| Shortcut | Action | -| ----------- | ------------------- | -| `Backspace` | Return to chat page | +| Shortcut | Action | +| ---------------- | ------------------- | +| `Backspace` or `q` | Return to chat page | ## AI Assistant Tools @@ -275,28 +289,13 @@ Once configured, MCP tools are automatically available to the AI assistant along ## LSP (Language Server Protocol) -OpenCode integrates with Language Server Protocol to provide rich code intelligence features across multiple programming languages. +OpenCode integrates with Language Server Protocol to provide code intelligence features across multiple programming languages. ### LSP Features - **Multi-language Support**: Connect to language servers for different programming languages -- **Code Intelligence**: Get diagnostics, completions, and navigation assistance +- **Diagnostics**: Receive error checking and linting information - **File Watching**: Automatically notify language servers of file changes -- **Diagnostics**: Display errors, warnings, and hints in your code - -### Supported LSP Features - -| Feature | Description | -| ----------------- | ----------------------------------- | -| Diagnostics | Error checking and linting | -| Completions | Code suggestions and autocompletion | -| Hover | Documentation on hover | -| Definition | Go to definition | -| References | Find all references | -| Document Symbols | Navigate symbols in current file | -| Workspace Symbols | Search symbols across workspace | -| Formatting | Code formatting | -| Code Actions | Quick fixes and refactorings | ### Configuring LSP @@ -324,13 +323,14 @@ The AI assistant can access LSP features through the `diagnostics` tool, allowin - Check for errors in your code - Suggest fixes based on diagnostics -- Provide intelligent code assistance + +While the LSP client implementation supports the full LSP protocol (including completions, hover, definition, etc.), currently only diagnostics are exposed to the AI assistant. ## Development ### Prerequisites -- Go 1.23.5 or higher +- Go 1.24.0 or higher ### Building from Source @@ -339,9 +339,6 @@ The AI assistant can access LSP features through the `diagnostics` tool, allowin git clone https://github.com/kujtimiihoxha/opencode.git cd opencode -# Build the diff script first -go run cmd/diff/main.go - # Build go build -o opencode @@ -372,4 +369,4 @@ Contributions are welcome! Here's how you can contribute: 4. Push to the branch (`git push origin feature/amazing-feature`) 5. Open a Pull Request -Please make sure to update tests as appropriate and follow the existing code style. +Please make sure to update tests as appropriate and follow the existing code style. \ No newline at end of file From ed3518d0755cb5cae25d9d8f1690ab2e60702588 Mon Sep 17 00:00:00 2001 From: Kujtim Hoxha Date: Mon, 21 Apr 2025 16:24:38 +0200 Subject: [PATCH 39/41] small things --- cmd/root.go | 8 ++++---- internal/db/connect.go | 4 ++-- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/cmd/root.go b/cmd/root.go index 545652a7a9e938e2b7a2598e253287146b98f543..8777acb823229ad82eed591fa1206fcab9ccc27c 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -182,7 +182,7 @@ func setupSubscriber[T any]( select { case event, ok := <-subCh: if !ok { - logging.Info("%s subscription channel closed", name) + logging.Info("subscription channel closed", "name", name) return } @@ -191,13 +191,13 @@ func setupSubscriber[T any]( select { case outputCh <- msg: case <-time.After(2 * time.Second): - logging.Warn("%s message dropped due to slow consumer", name) + logging.Warn("message dropped due to slow consumer", "name", name) case <-ctx.Done(): - logging.Info("%s subscription cancelled", name) + logging.Info("subscription cancelled", "name", name) return } case <-ctx.Done(): - logging.Info("%s subscription cancelled", name) + logging.Info("subscription cancelled", "name", name) return } } diff --git a/internal/db/connect.go b/internal/db/connect.go index e850bc8d02a4e9c685e1e5dc71260eacb44a0563..9335bfc263c288f4ca95cffb3bb212a99b0bcec5 100644 --- a/internal/db/connect.go +++ b/internal/db/connect.go @@ -48,9 +48,9 @@ func Connect() (*sql.DB, error) { for _, pragma := range pragmas { if _, err = db.Exec(pragma); err != nil { - logging.Warn("Failed to set pragma", pragma, err) + logging.Error("Failed to set pragma", pragma, err) } else { - logging.Warn("Set pragma", "pragma", pragma) + logging.Debug("Set pragma", "pragma", pragma) } } From d7569d79c6da1437fe46343ed13810df6c8cae1f Mon Sep 17 00:00:00 2001 From: Kujtim Hoxha Date: Mon, 21 Apr 2025 16:50:10 +0200 Subject: [PATCH 40/41] prevent the editor when agent busy --- internal/tui/components/chat/editor.go | 3 +++ 1 file changed, 3 insertions(+) diff --git a/internal/tui/components/chat/editor.go b/internal/tui/components/chat/editor.go index 963fbbdbfce8d233d05df1fbd912af2eec2651f7..4f6937039b7a79f0687d47c846380cc311252ce4 100644 --- a/internal/tui/components/chat/editor.go +++ b/internal/tui/components/chat/editor.go @@ -132,6 +132,9 @@ func (m *editorCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) { } case tea.KeyMsg: if key.Matches(msg, focusedKeyMaps.OpenEditor) { + if m.app.CoderAgent.IsSessionBusy(m.session.ID) { + return m, util.ReportWarn("Agent is working, please wait...") + } return m, openEditor() } // if the key does not match any binding, return From 3a6a26981a8074b6ab0eaadb520db986e04799ff Mon Sep 17 00:00:00 2001 From: Kujtim Hoxha Date: Mon, 21 Apr 2025 19:48:36 +0200 Subject: [PATCH 41/41] init command --- README.md | 55 ++--- internal/config/init.go | 61 +++++ internal/llm/agent/mcp-tools.go | 5 + internal/llm/prompt/prompt.go | 16 +- internal/llm/prompt/title.go | 1 + internal/llm/tools/bash.go | 8 +- internal/llm/tools/edit.go | 3 + internal/llm/tools/fetch.go | 6 + internal/llm/tools/patch.go | 3 + internal/llm/tools/write.go | 1 + internal/permission/permission.go | 16 +- internal/tui/components/dialog/commands.go | 247 +++++++++++++++++++++ internal/tui/components/dialog/init.go | 191 ++++++++++++++++ internal/tui/tui.go | 175 ++++++++++++++- 14 files changed, 753 insertions(+), 35 deletions(-) create mode 100644 internal/config/init.go create mode 100644 internal/tui/components/dialog/commands.go create mode 100644 internal/tui/components/dialog/init.go diff --git a/README.md b/README.md index 7e8c2791b6f5a5c8a57f400d807b2c54f3038967..145a881f58c7fdc633bc27b2015f66ddf59ee298 100644 --- a/README.md +++ b/README.md @@ -104,6 +104,7 @@ You can configure OpenCode using environment variables: OpenCode supports a variety of AI models from different providers: ### OpenAI + - GPT-4.1 family (gpt-4.1, gpt-4.1-mini, gpt-4.1-nano) - GPT-4.5 Preview - GPT-4o family (gpt-4o, gpt-4o-mini) @@ -112,6 +113,7 @@ OpenCode supports a variety of AI models from different providers: - O4 Mini ### Anthropic + - Claude 3.5 Sonnet - Claude 3.5 Haiku - Claude 3.7 Sonnet @@ -119,12 +121,14 @@ OpenCode supports a variety of AI models from different providers: - Claude 3 Opus ### Google + - Gemini 2.5 - Gemini 2.5 Flash - Gemini 2.0 Flash - Gemini 2.0 Flash Lite ### AWS Bedrock + - Claude 3.7 Sonnet ## Usage @@ -152,14 +156,15 @@ opencode -c /path/to/project ### Global Shortcuts -| Shortcut | Action | -| --------- | ------------------------------------------------------- | -| `Ctrl+C` | Quit application | -| `Ctrl+?` | Toggle help dialog | -| `?` | Toggle help dialog (when not in editing mode) | -| `Ctrl+L` | View logs | -| `Ctrl+A` | Switch session | -| `Esc` | Close current overlay/dialog or return to previous mode | +| Shortcut | Action | +| -------- | ------------------------------------------------------- | +| `Ctrl+C` | Quit application | +| `Ctrl+?` | Toggle help dialog | +| `?` | Toggle help dialog (when not in editing mode) | +| `Ctrl+L` | View logs | +| `Ctrl+A` | Switch session | +| `Ctrl+K` | Command dialog | +| `Esc` | Close current overlay/dialog or return to previous mode | ### Chat Page Shortcuts @@ -181,28 +186,28 @@ opencode -c /path/to/project ### Session Dialog Shortcuts -| Shortcut | Action | -| ------------- | ---------------- | -| `↑` or `k` | Previous session | -| `↓` or `j` | Next session | -| `Enter` | Select session | -| `Esc` | Close dialog | +| Shortcut | Action | +| ---------- | ---------------- | +| `↑` or `k` | Previous session | +| `↓` or `j` | Next session | +| `Enter` | Select session | +| `Esc` | Close dialog | ### Permission Dialog Shortcuts -| Shortcut | Action | -| ------------------------- | ----------------------- | -| `←` or `left` | Switch options left | -| `→` or `right` or `tab` | Switch options right | -| `Enter` or `space` | Confirm selection | -| `a` | Allow permission | -| `A` | Allow permission for session | -| `d` | Deny permission | +| Shortcut | Action | +| ----------------------- | ---------------------------- | +| `←` or `left` | Switch options left | +| `→` or `right` or `tab` | Switch options right | +| `Enter` or `space` | Confirm selection | +| `a` | Allow permission | +| `A` | Allow permission for session | +| `d` | Deny permission | ### Logs Page Shortcuts -| Shortcut | Action | -| ---------------- | ------------------- | +| Shortcut | Action | +| ------------------ | ------------------- | | `Backspace` or `q` | Return to chat page | ## AI Assistant Tools @@ -369,4 +374,4 @@ Contributions are welcome! Here's how you can contribute: 4. Push to the branch (`git push origin feature/amazing-feature`) 5. Open a Pull Request -Please make sure to update tests as appropriate and follow the existing code style. \ No newline at end of file +Please make sure to update tests as appropriate and follow the existing code style. diff --git a/internal/config/init.go b/internal/config/init.go new file mode 100644 index 0000000000000000000000000000000000000000..e0a1c6da7372fb3c66656d18bdf565357b6b1b07 --- /dev/null +++ b/internal/config/init.go @@ -0,0 +1,61 @@ +package config + +import ( + "fmt" + "os" + "path/filepath" +) + +const ( + // InitFlagFilename is the name of the file that indicates whether the project has been initialized + InitFlagFilename = "init" +) + +// ProjectInitFlag represents the initialization status for a project directory +type ProjectInitFlag struct { + Initialized bool `json:"initialized"` +} + +// ShouldShowInitDialog checks if the initialization dialog should be shown for the current directory +func ShouldShowInitDialog() (bool, error) { + if cfg == nil { + return false, fmt.Errorf("config not loaded") + } + + // Create the flag file path + flagFilePath := filepath.Join(cfg.Data.Directory, InitFlagFilename) + + // Check if the flag file exists + _, err := os.Stat(flagFilePath) + if err == nil { + // File exists, don't show the dialog + return false, nil + } + + // If the error is not "file not found", return the error + if !os.IsNotExist(err) { + return false, fmt.Errorf("failed to check init flag file: %w", err) + } + + // File doesn't exist, show the dialog + return true, nil +} + +// MarkProjectInitialized marks the current project as initialized +func MarkProjectInitialized() error { + if cfg == nil { + return fmt.Errorf("config not loaded") + } + // Create the flag file path + flagFilePath := filepath.Join(cfg.Data.Directory, InitFlagFilename) + + // Create an empty file to mark the project as initialized + file, err := os.Create(flagFilePath) + if err != nil { + return fmt.Errorf("failed to create init flag file: %w", err) + } + defer file.Close() + + return nil +} + diff --git a/internal/llm/agent/mcp-tools.go b/internal/llm/agent/mcp-tools.go index 16dddc1ba467dc3ae2dad62dab03bc981ff97226..53aada33fb91351e77d78295b94e31c0c566b8a9 100644 --- a/internal/llm/agent/mcp-tools.go +++ b/internal/llm/agent/mcp-tools.go @@ -80,9 +80,14 @@ func runTool(ctx context.Context, c MCPClient, toolName string, input string) (t } func (b *mcpTool) Run(ctx context.Context, params tools.ToolCall) (tools.ToolResponse, error) { + sessionID, messageID := tools.GetContextValues(ctx) + if sessionID == "" || messageID == "" { + return tools.ToolResponse{}, fmt.Errorf("session ID and message ID are required for creating a new file") + } permissionDescription := fmt.Sprintf("execute %s with the following parameters: %s", b.Info().Name, params.Input) p := b.permissions.Request( permission.CreatePermissionRequest{ + SessionID: sessionID, Path: config.WorkingDirectory(), ToolName: b.Info().Name, Action: "execute", diff --git a/internal/llm/prompt/prompt.go b/internal/llm/prompt/prompt.go index cf4d9a7e723a285d3529b230ad184893215cdd96..a6b4c03fb31a9ae3b472fd1c0c42c58dc879e146 100644 --- a/internal/llm/prompt/prompt.go +++ b/internal/llm/prompt/prompt.go @@ -14,8 +14,13 @@ var contextFiles = []string{ ".github/copilot-instructions.md", ".cursorrules", "CLAUDE.md", + "CLAUDE.local.md", "opencode.md", + "opencode.local.md", "OpenCode.md", + "OpenCode.local.md", + "OPENCODE.md", + "OPENCODE.local.md", } func GetAgentPrompt(agentName config.AgentName, provider models.ModelProvider) string { @@ -31,12 +36,13 @@ func GetAgentPrompt(agentName config.AgentName, provider models.ModelProvider) s basePrompt = "You are a helpful assistant" } - // Add context from project-specific instruction files if they exist - contextContent := getContextFromFiles() - if contextContent != "" { - return fmt.Sprintf("%s\n\n# Project-Specific Context\n%s", basePrompt, contextContent) + if agentName == config.AgentCoder || agentName == config.AgentTask { + // Add context from project-specific instruction files if they exist + contextContent := getContextFromFiles() + if contextContent != "" { + return fmt.Sprintf("%s\n\n# Project-Specific Context\n%s", basePrompt, contextContent) + } } - return basePrompt } diff --git a/internal/llm/prompt/title.go b/internal/llm/prompt/title.go index 6e5289b24e984f92d9dbb9b86bfd530dfb4ae441..5656360da516b6be35f1e10917ea067048676eaf 100644 --- a/internal/llm/prompt/title.go +++ b/internal/llm/prompt/title.go @@ -6,6 +6,7 @@ func TitlePrompt(_ models.ModelProvider) string { return `you will generate a short title based on the first message a user begins a conversation with - ensure it is not more than 50 characters long - the title should be a summary of the user's message +- it should be one line long - do not use quotes or colons - the entire text you return will be used as the title` } diff --git a/internal/llm/tools/bash.go b/internal/llm/tools/bash.go index 18533b761d0cfeca17f842df0db1e3756c38619d..a175061972506d252434a0069586dbf28fae645f 100644 --- a/internal/llm/tools/bash.go +++ b/internal/llm/tools/bash.go @@ -51,7 +51,7 @@ var safeReadOnlyCommands = []string{ "git status", "git log", "git diff", "git show", "git branch", "git tag", "git remote", "git ls-files", "git ls-remote", "git rev-parse", "git config --get", "git config --list", "git describe", "git blame", "git grep", "git shortlog", - "go version", "go list", "go env", "go doc", "go vet", "go fmt", "go mod", "go test", "go build", "go run", "go install", "go clean", + "go version", "go help", "go list", "go env", "go doc", "go vet", "go fmt", "go mod", "go test", "go build", "go run", "go install", "go clean", } func bashDescription() string { @@ -261,9 +261,15 @@ func (b *bashTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error) } } } + + sessionID, messageID := GetContextValues(ctx) + if sessionID == "" || messageID == "" { + return ToolResponse{}, fmt.Errorf("session ID and message ID are required for creating a new file") + } if !isSafeReadOnly { p := b.permissions.Request( permission.CreatePermissionRequest{ + SessionID: sessionID, Path: config.WorkingDirectory(), ToolName: BashToolName, Action: "execute", diff --git a/internal/llm/tools/edit.go b/internal/llm/tools/edit.go index 23c44399b2fcf39e4fae14c751a90958246da729..e2e2578757c05aa1505d9a1ab9c84f44d7408aff 100644 --- a/internal/llm/tools/edit.go +++ b/internal/llm/tools/edit.go @@ -203,6 +203,7 @@ func (e *editTool) createNewFile(ctx context.Context, filePath, content string) } p := e.permissions.Request( permission.CreatePermissionRequest{ + SessionID: sessionID, Path: permissionPath, ToolName: EditToolName, Action: "write", @@ -313,6 +314,7 @@ func (e *editTool) deleteContent(ctx context.Context, filePath, oldString string } p := e.permissions.Request( permission.CreatePermissionRequest{ + SessionID: sessionID, Path: permissionPath, ToolName: EditToolName, Action: "write", @@ -432,6 +434,7 @@ func (e *editTool) replaceContent(ctx context.Context, filePath, oldString, newS } p := e.permissions.Request( permission.CreatePermissionRequest{ + SessionID: sessionID, Path: permissionPath, ToolName: EditToolName, Action: "write", diff --git a/internal/llm/tools/fetch.go b/internal/llm/tools/fetch.go index 827755863d053ab9e3073c859523a5cc8048500b..47ff03e5740f8c896bf97b5f975ee1e4d20da183 100644 --- a/internal/llm/tools/fetch.go +++ b/internal/llm/tools/fetch.go @@ -116,8 +116,14 @@ func (t *fetchTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error return NewTextErrorResponse("URL must start with http:// or https://"), nil } + sessionID, messageID := GetContextValues(ctx) + if sessionID == "" || messageID == "" { + return ToolResponse{}, fmt.Errorf("session ID and message ID are required for creating a new file") + } + p := t.permissions.Request( permission.CreatePermissionRequest{ + SessionID: sessionID, Path: config.WorkingDirectory(), ToolName: FetchToolName, Action: "fetch", diff --git a/internal/llm/tools/patch.go b/internal/llm/tools/patch.go index 903404497bff31e174f47127435634d892139498..7e20e378e28924ed4c2b1ca91eb5478520c772ed 100644 --- a/internal/llm/tools/patch.go +++ b/internal/llm/tools/patch.go @@ -194,6 +194,7 @@ func (p *patchTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error patchDiff, _, _ := diff.GenerateDiff("", *change.NewContent, path) p := p.permissions.Request( permission.CreatePermissionRequest{ + SessionID: sessionID, Path: dir, ToolName: PatchToolName, Action: "create", @@ -220,6 +221,7 @@ func (p *patchTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error dir := filepath.Dir(path) p := p.permissions.Request( permission.CreatePermissionRequest{ + SessionID: sessionID, Path: dir, ToolName: PatchToolName, Action: "update", @@ -238,6 +240,7 @@ func (p *patchTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error patchDiff, _, _ := diff.GenerateDiff(*change.OldContent, "", path) p := p.permissions.Request( permission.CreatePermissionRequest{ + SessionID: sessionID, Path: dir, ToolName: PatchToolName, Action: "delete", diff --git a/internal/llm/tools/write.go b/internal/llm/tools/write.go index 3a94b47b6437e7b36197864fe0ee2e223b1f0bf3..ec6fc1dc4f0db310056119a2a2a2cdb02c3bf916 100644 --- a/internal/llm/tools/write.go +++ b/internal/llm/tools/write.go @@ -168,6 +168,7 @@ func (w *writeTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error } p := w.permissions.Request( permission.CreatePermissionRequest{ + SessionID: sessionID, Path: permissionPath, ToolName: WriteToolName, Action: "write", diff --git a/internal/permission/permission.go b/internal/permission/permission.go index 06f69a33dfcc73f3afa6ec60b60eb495bd37e46b..f36efea652fc728ddfd22b64e3f2e71de8549d96 100644 --- a/internal/permission/permission.go +++ b/internal/permission/permission.go @@ -3,6 +3,7 @@ package permission import ( "errors" "path/filepath" + "slices" "sync" "time" @@ -14,6 +15,7 @@ import ( var ErrorPermissionDenied = errors.New("permission denied") type CreatePermissionRequest struct { + SessionID string `json:"session_id"` ToolName string `json:"tool_name"` Description string `json:"description"` Action string `json:"action"` @@ -37,13 +39,15 @@ type Service interface { Grant(permission PermissionRequest) Deny(permission PermissionRequest) Request(opts CreatePermissionRequest) bool + AutoApproveSession(sessionID string) } type permissionService struct { *pubsub.Broker[PermissionRequest] - sessionPermissions []PermissionRequest - pendingRequests sync.Map + sessionPermissions []PermissionRequest + pendingRequests sync.Map + autoApproveSessions []string } func (s *permissionService) GrantPersistant(permission PermissionRequest) { @@ -69,6 +73,9 @@ func (s *permissionService) Deny(permission PermissionRequest) { } func (s *permissionService) Request(opts CreatePermissionRequest) bool { + if slices.Contains(s.autoApproveSessions, opts.SessionID) { + return true + } dir := filepath.Dir(opts.Path) if dir == "." { dir = config.WorkingDirectory() @@ -76,6 +83,7 @@ func (s *permissionService) Request(opts CreatePermissionRequest) bool { permission := PermissionRequest{ ID: uuid.New().String(), Path: dir, + SessionID: opts.SessionID, ToolName: opts.ToolName, Description: opts.Description, Action: opts.Action, @@ -104,6 +112,10 @@ func (s *permissionService) Request(opts CreatePermissionRequest) bool { } } +func (s *permissionService) AutoApproveSession(sessionID string) { + s.autoApproveSessions = append(s.autoApproveSessions, sessionID) +} + func NewPermissionService() Service { return &permissionService{ Broker: pubsub.NewBroker[PermissionRequest](), diff --git a/internal/tui/components/dialog/commands.go b/internal/tui/components/dialog/commands.go new file mode 100644 index 0000000000000000000000000000000000000000..7b25caeb04f4e1129b3fa550a4939c7d7d8ad9c8 --- /dev/null +++ b/internal/tui/components/dialog/commands.go @@ -0,0 +1,247 @@ +package dialog + +import ( + "github.com/charmbracelet/bubbles/key" + tea "github.com/charmbracelet/bubbletea" + "github.com/charmbracelet/lipgloss" + "github.com/kujtimiihoxha/opencode/internal/tui/layout" + "github.com/kujtimiihoxha/opencode/internal/tui/styles" + "github.com/kujtimiihoxha/opencode/internal/tui/util" +) + +// Command represents a command that can be executed +type Command struct { + ID string + Title string + Description string + Handler func(cmd Command) tea.Cmd +} + +// CommandSelectedMsg is sent when a command is selected +type CommandSelectedMsg struct { + Command Command +} + +// CloseCommandDialogMsg is sent when the command dialog is closed +type CloseCommandDialogMsg struct{} + +// CommandDialog interface for the command selection dialog +type CommandDialog interface { + tea.Model + layout.Bindings + SetCommands(commands []Command) + SetSelectedCommand(commandID string) +} + +type commandDialogCmp struct { + commands []Command + selectedIdx int + width int + height int + selectedCommandID string +} + +type commandKeyMap struct { + Up key.Binding + Down key.Binding + Enter key.Binding + Escape key.Binding + J key.Binding + K key.Binding +} + +var commandKeys = commandKeyMap{ + Up: key.NewBinding( + key.WithKeys("up"), + key.WithHelp("↑", "previous command"), + ), + Down: key.NewBinding( + key.WithKeys("down"), + key.WithHelp("↓", "next command"), + ), + Enter: key.NewBinding( + key.WithKeys("enter"), + key.WithHelp("enter", "select command"), + ), + Escape: key.NewBinding( + key.WithKeys("esc"), + key.WithHelp("esc", "close"), + ), + J: key.NewBinding( + key.WithKeys("j"), + key.WithHelp("j", "next command"), + ), + K: key.NewBinding( + key.WithKeys("k"), + key.WithHelp("k", "previous command"), + ), +} + +func (c *commandDialogCmp) Init() tea.Cmd { + return nil +} + +func (c *commandDialogCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) { + switch msg := msg.(type) { + case tea.KeyMsg: + switch { + case key.Matches(msg, commandKeys.Up) || key.Matches(msg, commandKeys.K): + if c.selectedIdx > 0 { + c.selectedIdx-- + } + return c, nil + case key.Matches(msg, commandKeys.Down) || key.Matches(msg, commandKeys.J): + if c.selectedIdx < len(c.commands)-1 { + c.selectedIdx++ + } + return c, nil + case key.Matches(msg, commandKeys.Enter): + if len(c.commands) > 0 { + return c, util.CmdHandler(CommandSelectedMsg{ + Command: c.commands[c.selectedIdx], + }) + } + case key.Matches(msg, commandKeys.Escape): + return c, util.CmdHandler(CloseCommandDialogMsg{}) + } + case tea.WindowSizeMsg: + c.width = msg.Width + c.height = msg.Height + } + return c, nil +} + +func (c *commandDialogCmp) View() string { + if len(c.commands) == 0 { + return styles.BaseStyle.Padding(1, 2). + Border(lipgloss.RoundedBorder()). + BorderBackground(styles.Background). + BorderForeground(styles.ForgroundDim). + Width(40). + Render("No commands available") + } + + // Calculate max width needed for command titles + maxWidth := 40 // Minimum width + for _, cmd := range c.commands { + if len(cmd.Title) > maxWidth-4 { // Account for padding + maxWidth = len(cmd.Title) + 4 + } + if len(cmd.Description) > maxWidth-4 { + maxWidth = len(cmd.Description) + 4 + } + } + + // Limit height to avoid taking up too much screen space + maxVisibleCommands := min(10, len(c.commands)) + + // Build the command list + commandItems := make([]string, 0, maxVisibleCommands) + startIdx := 0 + + // If we have more commands than can be displayed, adjust the start index + if len(c.commands) > maxVisibleCommands { + // Center the selected item when possible + halfVisible := maxVisibleCommands / 2 + if c.selectedIdx >= halfVisible && c.selectedIdx < len(c.commands)-halfVisible { + startIdx = c.selectedIdx - halfVisible + } else if c.selectedIdx >= len(c.commands)-halfVisible { + startIdx = len(c.commands) - maxVisibleCommands + } + } + + endIdx := min(startIdx+maxVisibleCommands, len(c.commands)) + + for i := startIdx; i < endIdx; i++ { + cmd := c.commands[i] + itemStyle := styles.BaseStyle.Width(maxWidth) + descStyle := styles.BaseStyle.Width(maxWidth).Foreground(styles.ForgroundDim) + + if i == c.selectedIdx { + itemStyle = itemStyle. + Background(styles.PrimaryColor). + Foreground(styles.Background). + Bold(true) + descStyle = descStyle. + Background(styles.PrimaryColor). + Foreground(styles.Background) + } + + title := itemStyle.Padding(0, 1).Render(cmd.Title) + description := "" + if cmd.Description != "" { + description = descStyle.Padding(0, 1).Render(cmd.Description) + commandItems = append(commandItems, lipgloss.JoinVertical(lipgloss.Left, title, description)) + } else { + commandItems = append(commandItems, title) + } + } + + title := styles.BaseStyle. + Foreground(styles.PrimaryColor). + Bold(true). + Width(maxWidth). + Padding(0, 1). + Render("Commands") + + content := lipgloss.JoinVertical( + lipgloss.Left, + title, + styles.BaseStyle.Width(maxWidth).Render(""), + styles.BaseStyle.Width(maxWidth).Render(lipgloss.JoinVertical(lipgloss.Left, commandItems...)), + styles.BaseStyle.Width(maxWidth).Render(""), + styles.BaseStyle.Width(maxWidth).Padding(0, 1).Foreground(styles.ForgroundDim).Render("↑/k: up ↓/j: down enter: select esc: cancel"), + ) + + return styles.BaseStyle.Padding(1, 2). + Border(lipgloss.RoundedBorder()). + BorderBackground(styles.Background). + BorderForeground(styles.ForgroundDim). + Width(lipgloss.Width(content) + 4). + Render(content) +} + +func (c *commandDialogCmp) BindingKeys() []key.Binding { + return layout.KeyMapToSlice(commandKeys) +} + +func (c *commandDialogCmp) SetCommands(commands []Command) { + c.commands = commands + + // If we have a selected command ID, find its index + if c.selectedCommandID != "" { + for i, cmd := range commands { + if cmd.ID == c.selectedCommandID { + c.selectedIdx = i + return + } + } + } + + // Default to first command if selected not found + c.selectedIdx = 0 +} + +func (c *commandDialogCmp) SetSelectedCommand(commandID string) { + c.selectedCommandID = commandID + + // Update the selected index if commands are already loaded + if len(c.commands) > 0 { + for i, cmd := range c.commands { + if cmd.ID == commandID { + c.selectedIdx = i + return + } + } + } +} + +// NewCommandDialogCmp creates a new command selection dialog +func NewCommandDialogCmp() CommandDialog { + return &commandDialogCmp{ + commands: []Command{}, + selectedIdx: 0, + selectedCommandID: "", + } +} + diff --git a/internal/tui/components/dialog/init.go b/internal/tui/components/dialog/init.go new file mode 100644 index 0000000000000000000000000000000000000000..6098ca755627d06ce5dab26d34ddba8ab281eeef --- /dev/null +++ b/internal/tui/components/dialog/init.go @@ -0,0 +1,191 @@ +package dialog + +import ( + "github.com/charmbracelet/bubbles/key" + tea "github.com/charmbracelet/bubbletea" + "github.com/charmbracelet/lipgloss" + + "github.com/kujtimiihoxha/opencode/internal/tui/styles" + "github.com/kujtimiihoxha/opencode/internal/tui/util" +) + +// InitDialogCmp is a component that asks the user if they want to initialize the project. +type InitDialogCmp struct { + width, height int + selected int + keys initDialogKeyMap +} + +// NewInitDialogCmp creates a new InitDialogCmp. +func NewInitDialogCmp() InitDialogCmp { + return InitDialogCmp{ + selected: 0, + keys: initDialogKeyMap{}, + } +} + +type initDialogKeyMap struct { + Tab key.Binding + Left key.Binding + Right key.Binding + Enter key.Binding + Escape key.Binding + Y key.Binding + N key.Binding +} + +// ShortHelp implements key.Map. +func (k initDialogKeyMap) ShortHelp() []key.Binding { + return []key.Binding{ + key.NewBinding( + key.WithKeys("tab", "left", "right"), + key.WithHelp("tab/←/→", "toggle selection"), + ), + key.NewBinding( + key.WithKeys("enter"), + key.WithHelp("enter", "confirm"), + ), + key.NewBinding( + key.WithKeys("esc"), + key.WithHelp("esc", "cancel"), + ), + key.NewBinding( + key.WithKeys("y", "n"), + key.WithHelp("y/n", "yes/no"), + ), + } +} + +// FullHelp implements key.Map. +func (k initDialogKeyMap) FullHelp() [][]key.Binding { + return [][]key.Binding{k.ShortHelp()} +} + +// Init implements tea.Model. +func (m InitDialogCmp) Init() tea.Cmd { + return nil +} + +// Update implements tea.Model. +func (m InitDialogCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) { + switch msg := msg.(type) { + case tea.KeyMsg: + switch { + case key.Matches(msg, key.NewBinding(key.WithKeys("esc"))): + return m, util.CmdHandler(CloseInitDialogMsg{Initialize: false}) + case key.Matches(msg, key.NewBinding(key.WithKeys("tab", "left", "right", "h", "l"))): + m.selected = (m.selected + 1) % 2 + return m, nil + case key.Matches(msg, key.NewBinding(key.WithKeys("enter"))): + return m, util.CmdHandler(CloseInitDialogMsg{Initialize: m.selected == 0}) + case key.Matches(msg, key.NewBinding(key.WithKeys("y"))): + return m, util.CmdHandler(CloseInitDialogMsg{Initialize: true}) + case key.Matches(msg, key.NewBinding(key.WithKeys("n"))): + return m, util.CmdHandler(CloseInitDialogMsg{Initialize: false}) + } + case tea.WindowSizeMsg: + m.width = msg.Width + m.height = msg.Height + } + return m, nil +} + +// View implements tea.Model. +func (m InitDialogCmp) View() string { + // Calculate width needed for content + maxWidth := 60 // Width for explanation text + + title := styles.BaseStyle. + Foreground(styles.PrimaryColor). + Bold(true). + Width(maxWidth). + Padding(0, 1). + Render("Initialize Project") + + explanation := styles.BaseStyle. + Foreground(styles.Forground). + Width(maxWidth). + Padding(0, 1). + Render("Initialization generates a new OpenCode.md file that contains information about your codebase, this file serves as memory for each project, you can freely add to it to help the agents be better at their job.") + + question := styles.BaseStyle. + Foreground(styles.Forground). + Width(maxWidth). + Padding(1, 1). + Render("Would you like to initialize this project?") + + yesStyle := styles.BaseStyle + noStyle := styles.BaseStyle + + if m.selected == 0 { + yesStyle = yesStyle. + Background(styles.PrimaryColor). + Foreground(styles.Background). + Bold(true) + noStyle = noStyle. + Background(styles.Background). + Foreground(styles.PrimaryColor) + } else { + noStyle = noStyle. + Background(styles.PrimaryColor). + Foreground(styles.Background). + Bold(true) + yesStyle = yesStyle. + Background(styles.Background). + Foreground(styles.PrimaryColor) + } + + yes := yesStyle.Padding(0, 3).Render("Yes") + no := noStyle.Padding(0, 3).Render("No") + + buttons := lipgloss.JoinHorizontal(lipgloss.Center, yes, styles.BaseStyle.Render(" "), no) + buttons = styles.BaseStyle. + Width(maxWidth). + Padding(1, 0). + Render(buttons) + + help := styles.BaseStyle. + Width(maxWidth). + Padding(0, 1). + Foreground(styles.ForgroundDim). + Render("tab/←/→: toggle y/n: yes/no enter: confirm esc: cancel") + + content := lipgloss.JoinVertical( + lipgloss.Left, + title, + styles.BaseStyle.Width(maxWidth).Render(""), + explanation, + question, + buttons, + styles.BaseStyle.Width(maxWidth).Render(""), + help, + ) + + return styles.BaseStyle.Padding(1, 2). + Border(lipgloss.RoundedBorder()). + BorderBackground(styles.Background). + BorderForeground(styles.ForgroundDim). + Width(lipgloss.Width(content) + 4). + Render(content) +} + +// SetSize sets the size of the component. +func (m *InitDialogCmp) SetSize(width, height int) { + m.width = width + m.height = height +} + +// Bindings implements layout.Bindings. +func (m InitDialogCmp) Bindings() []key.Binding { + return m.keys.ShortHelp() +} + +// CloseInitDialogMsg is a message that is sent when the init dialog is closed. +type CloseInitDialogMsg struct { + Initialize bool +} + +// ShowInitDialogMsg is a message that is sent to show the init dialog. +type ShowInitDialogMsg struct { + Show bool +} diff --git a/internal/tui/tui.go b/internal/tui/tui.go index 392b9ec41f636864df0b783926be76468f743367..4a723d40d5cff700a8cd24ea6e44f7c993429482 100644 --- a/internal/tui/tui.go +++ b/internal/tui/tui.go @@ -7,6 +7,7 @@ import ( tea "github.com/charmbracelet/bubbletea" "github.com/charmbracelet/lipgloss" "github.com/kujtimiihoxha/opencode/internal/app" + "github.com/kujtimiihoxha/opencode/internal/config" "github.com/kujtimiihoxha/opencode/internal/logging" "github.com/kujtimiihoxha/opencode/internal/permission" "github.com/kujtimiihoxha/opencode/internal/pubsub" @@ -23,6 +24,7 @@ type keyMap struct { Quit key.Binding Help key.Binding SwitchSession key.Binding + Commands key.Binding } var keys = keyMap{ @@ -44,6 +46,11 @@ var keys = keyMap{ key.WithKeys("ctrl+a"), key.WithHelp("ctrl+a", "switch session"), ), + + Commands: key.NewBinding( + key.WithKeys("ctrl+k"), + key.WithHelp("ctrl+K", "commands"), + ), } var helpEsc = key.NewBinding( @@ -82,6 +89,13 @@ type appModel struct { showSessionDialog bool sessionDialog dialog.SessionDialog + showCommandDialog bool + commandDialog dialog.CommandDialog + commands []dialog.Command + + showInitDialog bool + initDialog dialog.InitDialogCmp + editingMode bool } @@ -98,6 +112,23 @@ func (a appModel) Init() tea.Cmd { cmds = append(cmds, cmd) cmd = a.sessionDialog.Init() cmds = append(cmds, cmd) + cmd = a.commandDialog.Init() + cmds = append(cmds, cmd) + cmd = a.initDialog.Init() + cmds = append(cmds, cmd) + + // Check if we should show the init dialog + cmds = append(cmds, func() tea.Msg { + shouldShow, err := config.ShouldShowInitDialog() + if err != nil { + return util.InfoMsg{ + Type: util.InfoTypeError, + Msg: "Failed to check init status: " + err.Error(), + } + } + return dialog.ShowInitDialogMsg{Show: shouldShow} + }) + return tea.Batch(cmds...) } @@ -126,6 +157,12 @@ func (a appModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) { a.sessionDialog = session.(dialog.SessionDialog) cmds = append(cmds, sessionCmd) + command, commandCmd := a.commandDialog.Update(msg) + a.commandDialog = command.(dialog.CommandDialog) + cmds = append(cmds, commandCmd) + + a.initDialog.SetSize(msg.Width, msg.Height) + return a, tea.Batch(cmds...) case chat.EditorFocusMsg: a.editingMode = bool(msg) @@ -207,6 +244,35 @@ func (a appModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) { a.showSessionDialog = false return a, nil + case dialog.CloseCommandDialogMsg: + a.showCommandDialog = false + return a, nil + + case dialog.ShowInitDialogMsg: + a.showInitDialog = msg.Show + return a, nil + + case dialog.CloseInitDialogMsg: + a.showInitDialog = false + if msg.Initialize { + // Run the initialization command + for _, cmd := range a.commands { + if cmd.ID == "init" { + // Mark the project as initialized + if err := config.MarkProjectInitialized(); err != nil { + return a, util.ReportError(err) + } + return a, cmd.Handler(cmd) + } + } + } else { + // Mark the project as initialized without running the command + if err := config.MarkProjectInitialized(); err != nil { + return a, util.ReportError(err) + } + } + return a, nil + case chat.SessionSelectedMsg: a.sessionDialog.SetSelectedSession(msg.ID) case dialog.SessionSelectedMsg: @@ -216,6 +282,14 @@ func (a appModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) { } return a, nil + case dialog.CommandSelectedMsg: + a.showCommandDialog = false + // Execute the command handler if available + if msg.Command.Handler != nil { + return a, msg.Command.Handler(msg.Command) + } + return a, util.ReportInfo("Command selected: " + msg.Command.Title) + case tea.KeyMsg: switch { case key.Matches(msg, keys.Quit): @@ -226,9 +300,12 @@ func (a appModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) { if a.showSessionDialog { a.showSessionDialog = false } + if a.showCommandDialog { + a.showCommandDialog = false + } return a, nil case key.Matches(msg, keys.SwitchSession): - if a.currentPage == page.ChatPage && !a.showQuit && !a.showPermissions { + if a.currentPage == page.ChatPage && !a.showQuit && !a.showPermissions && !a.showCommandDialog { // Load sessions and show the dialog sessions, err := a.app.Sessions.List(context.Background()) if err != nil { @@ -242,6 +319,17 @@ func (a appModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) { return a, nil } return a, nil + case key.Matches(msg, keys.Commands): + if a.currentPage == page.ChatPage && !a.showQuit && !a.showPermissions && !a.showSessionDialog { + // Show commands dialog + if len(a.commands) == 0 { + return a, util.ReportWarn("No commands available") + } + a.commandDialog.SetCommands(a.commands) + a.showCommandDialog = true + return a, nil + } + return a, nil case key.Matches(msg, logsKeyReturnKey): if a.currentPage == page.LogsPage { return a, a.moveToPage(page.ChatPage) @@ -255,6 +343,14 @@ func (a appModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) { a.showHelp = !a.showHelp return a, nil } + if a.showInitDialog { + a.showInitDialog = false + // Mark the project as initialized without running the command + if err := config.MarkProjectInitialized(); err != nil { + return a, util.ReportError(err) + } + return a, nil + } case key.Matches(msg, keys.Logs): return a, a.moveToPage(page.LogsPage) case key.Matches(msg, keys.Help): @@ -304,6 +400,26 @@ func (a appModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) { } } + if a.showCommandDialog { + d, commandCmd := a.commandDialog.Update(msg) + a.commandDialog = d.(dialog.CommandDialog) + cmds = append(cmds, commandCmd) + // Only block key messages send all other messages down + if _, ok := msg.(tea.KeyMsg); ok { + return a, tea.Batch(cmds...) + } + } + + if a.showInitDialog { + d, initCmd := a.initDialog.Update(msg) + a.initDialog = d.(dialog.InitDialogCmp) + cmds = append(cmds, initCmd) + // Only block key messages send all other messages down + if _, ok := msg.(tea.KeyMsg); ok { + return a, tea.Batch(cmds...) + } + } + s, _ := a.status.Update(msg) a.status = s.(core.StatusCmp) a.pages[a.currentPage], cmd = a.pages[a.currentPage].Update(msg) @@ -311,6 +427,11 @@ func (a appModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) { return a, tea.Batch(cmds...) } +// RegisterCommand adds a command to the command dialog +func (a *appModel) RegisterCommand(cmd dialog.Command) { + a.commands = append(a.commands, cmd) +} + func (a *appModel) moveToPage(pageID page.PageID) tea.Cmd { if a.app.CoderAgent.IsBusy() { // For now we don't move to any page if the agent is busy @@ -422,24 +543,74 @@ func (a appModel) View() string { ) } + if a.showCommandDialog { + overlay := a.commandDialog.View() + row := lipgloss.Height(appView) / 2 + row -= lipgloss.Height(overlay) / 2 + col := lipgloss.Width(appView) / 2 + col -= lipgloss.Width(overlay) / 2 + appView = layout.PlaceOverlay( + col, + row, + overlay, + appView, + true, + ) + } + + if a.showInitDialog { + overlay := a.initDialog.View() + appView = layout.PlaceOverlay( + a.width/2-lipgloss.Width(overlay)/2, + a.height/2-lipgloss.Height(overlay)/2, + overlay, + appView, + true, + ) + } + return appView } func New(app *app.App) tea.Model { startPage := page.ChatPage - return &appModel{ + model := &appModel{ currentPage: startPage, loadedPages: make(map[page.PageID]bool), status: core.NewStatusCmp(app.LSPClients), help: dialog.NewHelpCmp(), quit: dialog.NewQuitCmp(), sessionDialog: dialog.NewSessionDialogCmp(), + commandDialog: dialog.NewCommandDialogCmp(), permissions: dialog.NewPermissionDialogCmp(), + initDialog: dialog.NewInitDialogCmp(), app: app, editingMode: true, + commands: []dialog.Command{}, pages: map[page.PageID]tea.Model{ page.ChatPage: page.NewChatPage(app), page.LogsPage: page.NewLogsPage(), }, } + + model.RegisterCommand(dialog.Command{ + ID: "init", + Title: "Initialize Project", + Description: "Create/Update the OpenCode.md memory file", + Handler: func(cmd dialog.Command) tea.Cmd { + prompt := `Please analyze this codebase and create a OpenCode.md file containing: +1. Build/lint/test commands - especially for running a single test +2. Code style guidelines including imports, formatting, types, naming conventions, error handling, etc. + +The file you create will be given to agentic coding agents (such as yourself) that operate in this repository. Make it about 20 lines long. +If there's already a opencode.md, improve it. +If there are Cursor rules (in .cursor/rules/ or .cursorrules) or Copilot rules (in .github/copilot-instructions.md), make sure to include them.` + return tea.Batch( + util.CmdHandler(chat.SendMsg{ + Text: prompt, + }), + ) + }, + }) + return model }