From 56b8f43f6a57ad7dbe88a97f6726a07e29115f60 Mon Sep 17 00:00:00 2001 From: Ayman Bagabas Date: Wed, 23 Jul 2025 15:11:20 -0400 Subject: [PATCH 1/9] feat(tui): completions: dynamically adjust width based on items This will dynamically adjust the width of the completions popup based on the width of the last 10 items in the list, ensuring that the popup fits the content better and avoids unnecessary horizontal scrolling. --- .../tui/components/completions/completions.go | 38 +++++++++++++++---- 1 file changed, 30 insertions(+), 8 deletions(-) diff --git a/internal/tui/components/completions/completions.go b/internal/tui/components/completions/completions.go index 6c63afd22e982e5ba40f5d175fc71449bcd0879e..ed1e90557bc98e87cf799abfb3c29fb28c94007c 100644 --- a/internal/tui/components/completions/completions.go +++ b/internal/tui/components/completions/completions.go @@ -102,7 +102,7 @@ func (c *completionsCmp) Init() tea.Cmd { func (c *completionsCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) { switch msg := msg.(type) { case tea.WindowSizeMsg: - c.width = min(msg.Width-c.x, maxCompletionsWidth) + c.width = min(listWidth(c.list.Items()), maxCompletionsWidth) c.height = min(msg.Height-c.y, 15) return c, nil case tea.KeyPressMsg: @@ -168,10 +168,11 @@ func (c *completionsCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) { item := NewCompletionItem(completion.Title, completion.Value, WithBackgroundColor(t.BgSubtle)) items = append(items, item) } + c.width = listWidth(msg.Completions) c.height = max(min(c.height, len(items)), 1) // Ensure at least 1 item height return c, tea.Batch( - c.list.SetSize(c.width, c.height), c.list.SetItems(items), + c.list.SetSize(c.width, c.height), util.CmdHandler(CompletionsOpenedMsg{}), ) case FilterCompletionsMsg: @@ -195,7 +196,9 @@ func (c *completionsCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) { c.query = msg.Query var cmds []tea.Cmd cmds = append(cmds, c.list.Filter(msg.Query)) - itemsLen := len(c.list.Items()) + items := c.list.Items() + itemsLen := len(items) + c.width = listWidth(items) c.height = max(min(maxCompletionsHeight, itemsLen), 1) cmds = append(cmds, c.list.SetSize(c.width, c.height)) if itemsLen == 0 { @@ -215,15 +218,34 @@ func (c *completionsCmp) View() string { return "" } - return c.style().Render(c.list.View()) -} - -func (c *completionsCmp) style() lipgloss.Style { t := styles.CurrentTheme() - return t.S().Base. + style := t.S().Base. Width(c.width). Height(c.height). Background(t.BgSubtle) + + return style.Render(c.list.View()) +} + +// listWidth returns the width of the last 10 items in the list, which is used +// to determine the width of the completions popup. +// Note this only works for [completionItemCmp] items. +func listWidth[T any](items []T) int { + var width int + if len(items) == 0 { + return width + } + + for i := len(items) - 1; i >= 0 && i >= len(items)-10; i-- { + item, ok := any(items[i]).(*completionItemCmp) + if !ok { + continue + } + itemWidth := lipgloss.Width(item.text) + 2 // +2 for padding + width = max(width, itemWidth) + } + + return width } func (c *completionsCmp) Open() bool { From 6f1f7a20cbeef21454699e48844a4cddb222059a Mon Sep 17 00:00:00 2001 From: Ayman Bagabas Date: Wed, 23 Jul 2025 17:52:12 -0400 Subject: [PATCH 2/9] fix(tui): completions: keep track of the popup position --- .../tui/components/completions/completions.go | 38 ++++++++++++++----- internal/tui/tui.go | 10 ----- 2 files changed, 28 insertions(+), 20 deletions(-) diff --git a/internal/tui/components/completions/completions.go b/internal/tui/components/completions/completions.go index ed1e90557bc98e87cf799abfb3c29fb28c94007c..e8670cec13b7545dbb0bc72d77bbbeface24a920 100644 --- a/internal/tui/components/completions/completions.go +++ b/internal/tui/components/completions/completions.go @@ -51,18 +51,22 @@ type Completions interface { } type completionsCmp struct { - width int - height int // Height of the completions component` - x int // X position for the completions popup - y int // Y position for the completions popup - open bool // Indicates if the completions are open - keyMap KeyMap + wWidth int // The window width + width int + height int // Height of the completions component` + x, xorig int // X position for the completions popup + y int // Y position for the completions popup + open bool // Indicates if the completions are open + keyMap KeyMap list list.ListModel query string // The current filter query } -const maxCompletionsWidth = 80 // Maximum width for the completions popup +const ( + maxCompletionsWidth = 80 // Maximum width for the completions popup + minCompletionsWidth = 20 // Minimum width for the completions popup +) func New() Completions { completionsKeyMap := DefaultKeyMap() @@ -102,6 +106,7 @@ func (c *completionsCmp) Init() tea.Cmd { func (c *completionsCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) { switch msg := msg.(type) { case tea.WindowSizeMsg: + c.wWidth = msg.Width c.width = min(listWidth(c.list.Items()), maxCompletionsWidth) c.height = min(msg.Height-c.y, 15) return c, nil @@ -160,7 +165,7 @@ func (c *completionsCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) { case OpenCompletionsMsg: c.open = true c.query = "" - c.x = msg.X + c.x, c.xorig = msg.X, msg.X c.y = msg.Y items := []util.Model{} t := styles.CurrentTheme() @@ -168,7 +173,14 @@ func (c *completionsCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) { item := NewCompletionItem(completion.Title, completion.Value, WithBackgroundColor(t.BgSubtle)) items = append(items, item) } - c.width = listWidth(msg.Completions) + width := listWidth(items) + if len(items) == 0 { + width = listWidth(c.list.Items()) + } + if c.x+width >= c.wWidth { + c.x = c.wWidth - width - 1 + } + c.width = max(width, c.wWidth-minCompletionsWidth-1) c.height = max(min(c.height, len(items)), 1) // Ensure at least 1 item height return c, tea.Batch( c.list.SetItems(items), @@ -198,7 +210,13 @@ func (c *completionsCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) { cmds = append(cmds, c.list.Filter(msg.Query)) items := c.list.Items() itemsLen := len(items) - c.width = listWidth(items) + width := listWidth(items) + if c.x < 0 { + c.x = c.xorig + } else if c.x+width >= c.wWidth { + c.x = c.wWidth - width - 1 + } + c.width = width c.height = max(min(maxCompletionsHeight, itemsLen), 1) cmds = append(cmds, c.list.SetSize(c.width, c.height)) if itemsLen == 0 { diff --git a/internal/tui/tui.go b/internal/tui/tui.go index 22a2a52b92c52200d5ecc843c107d2ef33634a1b..16112401290f2e8e6765d7f7ee55b54672190bd7 100644 --- a/internal/tui/tui.go +++ b/internal/tui/tui.go @@ -119,16 +119,6 @@ func (a *appModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) { case completions.OpenCompletionsMsg, completions.FilterCompletionsMsg, completions.CloseCompletionsMsg: u, completionCmd := a.completions.Update(msg) a.completions = u.(completions.Completions) - switch msg := msg.(type) { - case completions.OpenCompletionsMsg: - x, _ := a.completions.Position() - if a.completions.Width()+x >= a.wWidth { - // Adjust X position to fit in the window. - msg.X = a.wWidth - a.completions.Width() - 1 - u, completionCmd = a.completions.Update(msg) - a.completions = u.(completions.Completions) - } - } return a, completionCmd // Dialog messages From fd1adf07e6152a5c28fb105980bb36b76244f3f2 Mon Sep 17 00:00:00 2001 From: Ayman Bagabas Date: Wed, 23 Jul 2025 17:58:10 -0400 Subject: [PATCH 3/9] fix(tui): completions: don't set initial width --- internal/tui/components/completions/completions.go | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/internal/tui/components/completions/completions.go b/internal/tui/components/completions/completions.go index e8670cec13b7545dbb0bc72d77bbbeface24a920..c034d0da4fefaa4731b8eb5899543134f13d1e52 100644 --- a/internal/tui/components/completions/completions.go +++ b/internal/tui/components/completions/completions.go @@ -107,7 +107,6 @@ func (c *completionsCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) { switch msg := msg.(type) { case tea.WindowSizeMsg: c.wWidth = msg.Width - c.width = min(listWidth(c.list.Items()), maxCompletionsWidth) c.height = min(msg.Height-c.y, 15) return c, nil case tea.KeyPressMsg: @@ -180,7 +179,7 @@ func (c *completionsCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) { if c.x+width >= c.wWidth { c.x = c.wWidth - width - 1 } - c.width = max(width, c.wWidth-minCompletionsWidth-1) + c.width = width c.height = max(min(c.height, len(items)), 1) // Ensure at least 1 item height return c, tea.Batch( c.list.SetItems(items), From 6dc3cf2f3d0db7e99d28ee7d6d457a345c604915 Mon Sep 17 00:00:00 2001 From: Ayman Bagabas Date: Wed, 23 Jul 2025 18:01:11 -0400 Subject: [PATCH 4/9] fix(tui): completions: readjust position on filter change --- .../tui/components/completions/completions.go | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/internal/tui/components/completions/completions.go b/internal/tui/components/completions/completions.go index c034d0da4fefaa4731b8eb5899543134f13d1e52..e82dff1816f53e1467d3c18478eeeee9dbfbb856 100644 --- a/internal/tui/components/completions/completions.go +++ b/internal/tui/components/completions/completions.go @@ -51,13 +51,14 @@ type Completions interface { } type completionsCmp struct { - wWidth int // The window width - width int - height int // Height of the completions component` - x, xorig int // X position for the completions popup - y int // Y position for the completions popup - open bool // Indicates if the completions are open - keyMap KeyMap + wWidth int // The window width + width int + lastWidth int + height int // Height of the completions component` + x, xorig int // X position for the completions popup + y int // Y position for the completions popup + open bool // Indicates if the completions are open + keyMap KeyMap list list.ListModel query string // The current filter query @@ -210,7 +211,8 @@ func (c *completionsCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) { items := c.list.Items() itemsLen := len(items) width := listWidth(items) - if c.x < 0 { + c.lastWidth = c.width + if c.x < 0 || width < c.lastWidth { c.x = c.xorig } else if c.x+width >= c.wWidth { c.x = c.wWidth - width - 1 From 28aac45067eeab82a5abe887d285d2b2f915ce28 Mon Sep 17 00:00:00 2001 From: Ayman Bagabas Date: Wed, 23 Jul 2025 18:23:37 -0400 Subject: [PATCH 5/9] fix(tui): completions: reposition popup on window resize --- internal/tui/components/chat/editor/editor.go | 9 +++++++++ internal/tui/components/completions/completions.go | 12 +++++++++--- internal/tui/page/chat/chat.go | 4 +++- internal/tui/tui.go | 2 +- 4 files changed, 22 insertions(+), 5 deletions(-) diff --git a/internal/tui/components/chat/editor/editor.go b/internal/tui/components/chat/editor/editor.go index 55a5e7525a430039b314cd810cb94856185cf5af..7f06e69a388e10a49cbb792d0f4b8d231613eb3b 100644 --- a/internal/tui/components/chat/editor/editor.go +++ b/internal/tui/components/chat/editor/editor.go @@ -161,10 +161,19 @@ func (m *editorCmp) send() tea.Cmd { ) } +func (m *editorCmp) repositionCompletions() tea.Msg { + cur := m.textarea.Cursor() + x := cur.X + m.x // adjust for padding + y := cur.Y + m.y + 1 + return completions.RepositionCompletionsMsg{X: x, Y: y} +} + func (m *editorCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) { var cmd tea.Cmd var cmds []tea.Cmd switch msg := msg.(type) { + case tea.WindowSizeMsg: + return m, m.repositionCompletions case filepicker.FilePickedMsg: if len(m.attachments) >= maxAttachments { return m, util.ReportError(fmt.Errorf("cannot add more than %d images", maxAttachments)) diff --git a/internal/tui/components/completions/completions.go b/internal/tui/components/completions/completions.go index e82dff1816f53e1467d3c18478eeeee9dbfbb856..fae46d70d806f6847eeb40ea9b727da1671145d9 100644 --- a/internal/tui/components/completions/completions.go +++ b/internal/tui/components/completions/completions.go @@ -29,6 +29,10 @@ type FilterCompletionsMsg struct { Reopen bool } +type RepositionCompletionsMsg struct { + X, Y int +} + type CompletionsClosedMsg struct{} type CompletionsOpenedMsg struct{} @@ -52,6 +56,7 @@ type Completions interface { type completionsCmp struct { wWidth int // The window width + wHeight int // The window height width int lastWidth int height int // Height of the completions component` @@ -88,7 +93,7 @@ func New() Completions { ) return &completionsCmp{ width: 0, - height: 0, + height: maxCompletionsHeight, list: l, query: "", keyMap: completionsKeyMap, @@ -107,8 +112,7 @@ func (c *completionsCmp) Init() tea.Cmd { func (c *completionsCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) { switch msg := msg.(type) { case tea.WindowSizeMsg: - c.wWidth = msg.Width - c.height = min(msg.Height-c.y, 15) + c.wWidth, c.wHeight = msg.Width, msg.Height return c, nil case tea.KeyPressMsg: switch { @@ -159,6 +163,8 @@ func (c *completionsCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) { case key.Matches(msg, c.keyMap.Cancel): return c, util.CmdHandler(CloseCompletionsMsg{}) } + case RepositionCompletionsMsg: + c.x, c.y = msg.X, msg.Y case CloseCompletionsMsg: c.open = false return c, util.CmdHandler(CompletionsClosedMsg{}) diff --git a/internal/tui/page/chat/chat.go b/internal/tui/page/chat/chat.go index 0d28f13f3ca0a42c9ae15612f21678cdeb8f4bf2..5cb6c31f9f5be5554280ef7deba16bd82eb2395d 100644 --- a/internal/tui/page/chat/chat.go +++ b/internal/tui/page/chat/chat.go @@ -165,7 +165,9 @@ func (p *chatPage) Update(msg tea.Msg) (tea.Model, tea.Cmd) { p.keyboardEnhancements = msg return p, nil case tea.WindowSizeMsg: - return p, p.SetSize(msg.Width, msg.Height) + u, cmd := p.editor.Update(msg) + p.editor = u.(editor.Editor) + return p, tea.Batch(p.SetSize(msg.Width, msg.Height), cmd) case CancelTimerExpiredMsg: p.isCanceling = false return p, nil diff --git a/internal/tui/tui.go b/internal/tui/tui.go index 16112401290f2e8e6765d7f7ee55b54672190bd7..1cdc0c38243da39b2bd8c8eb276beea78f1dd37f 100644 --- a/internal/tui/tui.go +++ b/internal/tui/tui.go @@ -116,7 +116,7 @@ func (a *appModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) { return a, a.handleWindowResize(msg.Width, msg.Height) // Completions messages - case completions.OpenCompletionsMsg, completions.FilterCompletionsMsg, completions.CloseCompletionsMsg: + case completions.OpenCompletionsMsg, completions.FilterCompletionsMsg, completions.CloseCompletionsMsg, completions.RepositionCompletionsMsg: u, completionCmd := a.completions.Update(msg) a.completions = u.(completions.Completions) return a, completionCmd From b4468381b8557a6f5cf439fd3308e327b529f3fa Mon Sep 17 00:00:00 2001 From: Ayman Bagabas Date: Thu, 24 Jul 2025 09:45:58 -0400 Subject: [PATCH 6/9] fix(tui): completions: ensure minimum height for completions list --- internal/tui/components/completions/completions.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/internal/tui/components/completions/completions.go b/internal/tui/components/completions/completions.go index fae46d70d806f6847eeb40ea9b727da1671145d9..ab29d900010bb2a80e4e2b7d6135e44f6486769c 100644 --- a/internal/tui/components/completions/completions.go +++ b/internal/tui/components/completions/completions.go @@ -187,7 +187,7 @@ func (c *completionsCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) { c.x = c.wWidth - width - 1 } c.width = width - c.height = max(min(c.height, len(items)), 1) // Ensure at least 1 item height + c.height = max(min(maxCompletionsHeight, len(items)), 1) // Ensure at least 1 item height return c, tea.Batch( c.list.SetItems(items), c.list.SetSize(c.width, c.height), From 97290c82f2e9d72808fe1634226da51b5a06b216 Mon Sep 17 00:00:00 2001 From: Ayman Bagabas Date: Thu, 24 Jul 2025 13:59:43 -0400 Subject: [PATCH 7/9] fix(tui): completions: improve positioning and handling completions With this, the completions popup will now reposition itself on fitlering, resizing, and when the cursor moves. It also ensures that the completions are correctly positioned relative to the textarea cursor position. --- internal/tui/components/chat/editor/editor.go | 54 ++++++++++++------- .../tui/components/completions/completions.go | 29 ++++++---- internal/tui/tui.go | 3 +- 3 files changed, 58 insertions(+), 28 deletions(-) diff --git a/internal/tui/components/chat/editor/editor.go b/internal/tui/components/chat/editor/editor.go index 7f06e69a388e10a49cbb792d0f4b8d231613eb3b..4e5f0bc431eb466cea5c6c7d436234c7a5e8531b 100644 --- a/internal/tui/components/chat/editor/editor.go +++ b/internal/tui/components/chat/editor/editor.go @@ -162,9 +162,7 @@ func (m *editorCmp) send() tea.Cmd { } func (m *editorCmp) repositionCompletions() tea.Msg { - cur := m.textarea.Cursor() - x := cur.X + m.x // adjust for padding - y := cur.Y + m.y + 1 + x, y := m.completionsPosition() return completions.RepositionCompletionsMsg{X: x, Y: y} } @@ -191,32 +189,37 @@ func (m *editorCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) { return m, nil } if item, ok := msg.Value.(FileCompletionItem); ok { + word := m.textarea.Word() // If the selected item is a file, insert its path into the textarea value := m.textarea.Value() - value = value[:m.completionsStartIndex] - value += item.Path + value = value[:m.completionsStartIndex] + // Remove the current query + item.Path + // Insert the file path + value[m.completionsStartIndex+len(word):] // Append the rest of the value + // XXX: This will always move the cursor to the end of the textarea. m.textarea.SetValue(value) + m.textarea.MoveToEnd() if !msg.Insert { m.isCompletionsOpen = false m.currentQuery = "" m.completionsStartIndex = 0 } - return m, nil } case openEditorMsg: m.textarea.SetValue(msg.Text) m.textarea.MoveToEnd() case tea.KeyPressMsg: + cur := m.textarea.Cursor() + curIdx := m.textarea.Width()*cur.Y + cur.X switch { // Completions case msg.String() == "/" && !m.isCompletionsOpen && - // only show if beginning of prompt, or if previous char is a space: - (len(m.textarea.Value()) == 0 || m.textarea.Value()[len(m.textarea.Value())-1] == ' '): + // only show if beginning of prompt, or if previous char is a space or newline: + (len(m.textarea.Value()) == 0 || unicode.IsSpace(rune(m.textarea.Value()[len(m.textarea.Value())-1]))): m.isCompletionsOpen = true m.currentQuery = "" - m.completionsStartIndex = len(m.textarea.Value()) + m.completionsStartIndex = curIdx cmds = append(cmds, m.startCompletions) - case m.isCompletionsOpen && m.textarea.Cursor().X <= m.completionsStartIndex: + case m.isCompletionsOpen && curIdx <= m.completionsStartIndex: cmds = append(cmds, util.CmdHandler(completions.CloseCompletionsMsg{})) } if key.Matches(msg, DeleteKeyMaps.AttachmentDeleteMode) { @@ -253,6 +256,7 @@ func (m *editorCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) { } if key.Matches(msg, m.keyMap.Newline) { m.textarea.InsertRune('\n') + cmds = append(cmds, util.CmdHandler(completions.CloseCompletionsMsg{})) } // Handle Enter key if m.textarea.Focused() && key.Matches(msg, m.keyMap.SendMessage) { @@ -284,12 +288,18 @@ func (m *editorCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) { // XXX: wont' work if editing in the middle of the field. m.completionsStartIndex = strings.LastIndex(m.textarea.Value(), word) m.currentQuery = word[1:] + x, y := m.completionsPosition() + x -= len(m.currentQuery) m.isCompletionsOpen = true - cmds = append(cmds, util.CmdHandler(completions.FilterCompletionsMsg{ - Query: m.currentQuery, - Reopen: m.isCompletionsOpen, - })) - } else { + cmds = append(cmds, + util.CmdHandler(completions.FilterCompletionsMsg{ + Query: m.currentQuery, + Reopen: m.isCompletionsOpen, + X: x, + Y: y, + }), + ) + } else if m.isCompletionsOpen { m.isCompletionsOpen = false m.currentQuery = "" m.completionsStartIndex = 0 @@ -302,6 +312,16 @@ func (m *editorCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) { return m, tea.Batch(cmds...) } +func (m *editorCmp) completionsPosition() (int, int) { + cur := m.textarea.Cursor() + if cur == nil { + return m.x, m.y + 1 // adjust for padding + } + x := cur.X + m.x + y := cur.Y + m.y + 1 // adjust for padding + return x, y +} + func (m *editorCmp) Cursor() *tea.Cursor { cursor := m.textarea.Cursor() if cursor != nil { @@ -382,9 +402,7 @@ func (m *editorCmp) startCompletions() tea.Msg { }) } - cur := m.textarea.Cursor() - x := cur.X + m.x // adjust for padding - y := cur.Y + m.y + 1 + x, y := m.completionsPosition() return completions.OpenCompletionsMsg{ Completions: completionItems, X: x, diff --git a/internal/tui/components/completions/completions.go b/internal/tui/components/completions/completions.go index ab29d900010bb2a80e4e2b7d6135e44f6486769c..aad5dc8c83c163712a4d9b56e7a6442ce2380f25 100644 --- a/internal/tui/components/completions/completions.go +++ b/internal/tui/components/completions/completions.go @@ -27,6 +27,8 @@ type OpenCompletionsMsg struct { type FilterCompletionsMsg struct { Query string // The query to filter completions Reopen bool + X int // X position for the completions popup + Y int // Y position for the completions popup } type RepositionCompletionsMsg struct { @@ -165,6 +167,7 @@ func (c *completionsCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) { } case RepositionCompletionsMsg: c.x, c.y = msg.X, msg.Y + c.adjustPosition() case CloseCompletionsMsg: c.open = false return c, util.CmdHandler(CompletionsClosedMsg{}) @@ -216,15 +219,9 @@ func (c *completionsCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) { cmds = append(cmds, c.list.Filter(msg.Query)) items := c.list.Items() itemsLen := len(items) - width := listWidth(items) - c.lastWidth = c.width - if c.x < 0 || width < c.lastWidth { - c.x = c.xorig - } else if c.x+width >= c.wWidth { - c.x = c.wWidth - width - 1 - } - c.width = width - c.height = max(min(maxCompletionsHeight, itemsLen), 1) + c.xorig = msg.X + c.x, c.y = msg.X, msg.Y + c.adjustPosition() cmds = append(cmds, c.list.SetSize(c.width, c.height)) if itemsLen == 0 { cmds = append(cmds, util.CmdHandler(CloseCompletionsMsg{})) @@ -237,6 +234,20 @@ func (c *completionsCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) { return c, nil } +func (c *completionsCmp) adjustPosition() { + items := c.list.Items() + itemsLen := len(items) + width := listWidth(items) + c.lastWidth = c.width + if c.x < 0 || width < c.lastWidth { + c.x = c.xorig + } else if c.x+width >= c.wWidth { + c.x = c.wWidth - width - 1 + } + c.width = width + c.height = max(min(maxCompletionsHeight, itemsLen), 1) +} + // View implements Completions. func (c *completionsCmp) View() string { if !c.open || len(c.list.Items()) == 0 { diff --git a/internal/tui/tui.go b/internal/tui/tui.go index 1cdc0c38243da39b2bd8c8eb276beea78f1dd37f..770e7b26945e9bf7109f3076e1ad95a1f24aa51a 100644 --- a/internal/tui/tui.go +++ b/internal/tui/tui.go @@ -116,7 +116,8 @@ func (a *appModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) { return a, a.handleWindowResize(msg.Width, msg.Height) // Completions messages - case completions.OpenCompletionsMsg, completions.FilterCompletionsMsg, completions.CloseCompletionsMsg, completions.RepositionCompletionsMsg: + case completions.OpenCompletionsMsg, completions.FilterCompletionsMsg, + completions.CloseCompletionsMsg, completions.RepositionCompletionsMsg: u, completionCmd := a.completions.Update(msg) a.completions = u.(completions.Completions) return a, completionCmd From 8c874293c93e5eee319e9f75a5e2735a9205fbd2 Mon Sep 17 00:00:00 2001 From: Kujtim Hoxha Date: Fri, 25 Jul 2025 11:52:00 +0200 Subject: [PATCH 8/9] Taciturnaxolotl/custom anthropic providers (#300) * feat: support anthropic provider type in custom provider configs * docs: fix provider configuration field name and add anthropic example - Change `provider_type` to `type` in documentation to match actual struct field - Add comprehensive examples for both OpenAI and Anthropic custom providers - Include missing `api_key` field in examples for completeness * feat: resolve headers to allow for custom scripts and such in headers * feat: allow headers in the anthropic client * feat: if api_key has "Bearer " in front then using it as an Authorization header and skip the X-API-Key header in the anthropic client * feat: add support for templating in the config resolve.go something like `Bearer $(echo $ENVVAR)-$(bash ~/.config/crush/script.sh)` would work now; also added some tests since the first iteration of this broke stuff majorly lol * feat: add a system prompt prefix option to the config --------- Co-authored-by: Kieran Klukas Co-authored-by: Kieran Klukas --- README.md | 44 ++++++- internal/config/config.go | 3 + internal/config/load.go | 2 +- internal/config/load_test.go | 29 +++++ internal/config/resolve.go | 114 ++++++++++++++++--- internal/config/resolve_test.go | 177 +++++++++++++++++++++++++++-- internal/llm/provider/anthropic.go | 55 +++++++-- internal/llm/provider/gemini.go | 12 +- internal/llm/provider/openai.go | 6 +- internal/llm/provider/provider.go | 44 ++++--- 10 files changed, 427 insertions(+), 59 deletions(-) diff --git a/README.md b/README.md index 26cb7308bb7a614603c61c3f4f4f5d1cee3fe40f..5fed716c8c6bf437e75ca65401c15e5be64441d5 100644 --- a/README.md +++ b/README.md @@ -135,7 +135,6 @@ Crush supports Model Context Protocol (MCP) servers through three transport type ### Logging Enable debug logging with the `-d` flag or in config. View logs with `crush logs`. Logs are stored in `.crush/logs/crush.log`. - ```bash # Run with debug logging crush -d @@ -186,16 +185,21 @@ The `allowed_tools` array accepts: You can also skip all permission prompts entirely by running Crush with the `--yolo` flag. -### OpenAI-Compatible APIs +### Custom Providers + +Crush supports custom provider configurations for both OpenAI-compatible and Anthropic-compatible APIs. + +#### OpenAI-Compatible APIs -Crush supports all OpenAI-compatible APIs. Here's an example configuration for Deepseek, which uses an OpenAI-compatible API. Don't forget to set `DEEPSEEK_API_KEY` in your environment. +Here's an example configuration for Deepseek, which uses an OpenAI-compatible API. Don't forget to set `DEEPSEEK_API_KEY` in your environment. ```json { "providers": { "deepseek": { - "provider_type": "openai", + "type": "openai", "base_url": "https://api.deepseek.com/v1", + "api_key": "$DEEPSEEK_API_KEY", "models": [ { "id": "deepseek-chat", @@ -213,6 +217,38 @@ Crush supports all OpenAI-compatible APIs. Here's an example configuration for D } ``` +#### Anthropic-Compatible APIs + +You can also configure custom Anthropic-compatible providers: + +```json +{ + "providers": { + "custom-anthropic": { + "type": "anthropic", + "base_url": "https://api.anthropic.com/v1", + "api_key": "$ANTHROPIC_API_KEY", + "extra_headers": { + "anthropic-version": "2023-06-01" + }, + "models": [ + { + "id": "claude-3-sonnet", + "model": "Claude 3 Sonnet", + "cost_per_1m_in": 3000, + "cost_per_1m_out": 15000, + "cost_per_1m_in_cached": 300, + "cost_per_1m_out_cached": 15000, + "context_window": 200000, + "default_max_tokens": 4096, + "supports_attachments": true + } + ] + } + } +} +``` + ## Whatcha think? We’d love to hear your thoughts on this project. Feel free to drop us a note! diff --git a/internal/config/config.go b/internal/config/config.go index 9709c11a0636d91cb492b7735b63e46e5e843c74..9a0da2a376abc88c5e584d7d39744da6f1890ce3 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -77,6 +77,9 @@ type ProviderConfig struct { // Marks the provider as disabled. Disable bool `json:"disable,omitempty"` + // Custom system prompt prefix. + SystemPromptPrefix string `json:"system_prompt_prefix,omitempty"` + // Extra headers to send with each request to the provider. ExtraHeaders map[string]string `json:"extra_headers,omitempty"` // Extra body diff --git a/internal/config/load.go b/internal/config/load.go index 98569d41be810dd0b9382c4df56cfb3e9c1c5842..44bcf8e3ce87953b9c3589cacaf2fe8a248e97aa 100644 --- a/internal/config/load.go +++ b/internal/config/load.go @@ -232,7 +232,7 @@ func (c *Config) configureProviders(env env.Env, resolver VariableResolver, know c.Providers.Del(id) continue } - if providerConfig.Type != catwalk.TypeOpenAI { + if providerConfig.Type != catwalk.TypeOpenAI && providerConfig.Type != catwalk.TypeAnthropic { slog.Warn("Skipping custom provider because the provider type is not supported", "provider", id, "type", providerConfig.Type) c.Providers.Del(id) continue diff --git a/internal/config/load_test.go b/internal/config/load_test.go index 5a52426f51ace9ee9e26bb42208511a72009dc3b..3f4ff51db2b04b7e4d8f9f5e86306eb3b2f7dc91 100644 --- a/internal/config/load_test.go +++ b/internal/config/load_test.go @@ -613,6 +613,35 @@ func TestConfig_configureProvidersCustomProviderValidation(t *testing.T) { assert.Equal(t, "https://api.custom.com/v1", customProvider.BaseURL) }) + t.Run("custom anthropic provider is supported", func(t *testing.T) { + cfg := &Config{ + Providers: map[string]ProviderConfig{ + "custom-anthropic": { + APIKey: "test-key", + BaseURL: "https://api.anthropic.com/v1", + Type: catwalk.TypeAnthropic, + Models: []catwalk.Model{{ + ID: "claude-3-sonnet", + }}, + }, + }, + } + cfg.setDefaults("/tmp") + + env := env.NewFromMap(map[string]string{}) + resolver := NewEnvironmentVariableResolver(env) + err := cfg.configureProviders(env, resolver, []catwalk.Provider{}) + assert.NoError(t, err) + + assert.Len(t, cfg.Providers, 1) + customProvider, exists := cfg.Providers["custom-anthropic"] + assert.True(t, exists) + assert.Equal(t, "custom-anthropic", customProvider.ID) + assert.Equal(t, "test-key", customProvider.APIKey) + assert.Equal(t, "https://api.anthropic.com/v1", customProvider.BaseURL) + assert.Equal(t, catwalk.TypeAnthropic, customProvider.Type) + }) + t.Run("disabled custom provider is removed", func(t *testing.T) { cfg := &Config{ Providers: csync.NewMapFrom(map[string]ProviderConfig{ diff --git a/internal/config/resolve.go b/internal/config/resolve.go index 3c97a6456cf7fe5968311746d62b2772b21d6aaa..3ef3522b09e504d3c57105e8bbe393b0f7c38b2b 100644 --- a/internal/config/resolve.go +++ b/internal/config/resolve.go @@ -35,34 +35,120 @@ func NewShellVariableResolver(env env.Env) VariableResolver { } // ResolveValue is a method for resolving values, such as environment variables. -// it will expect strings that start with `$` to be resolved as environment variables or shell commands. -// if the string does not start with `$`, it will return the string as is. +// it will resolve shell-like variable substitution anywhere in the string, including: +// - $(command) for command substitution +// - $VAR or ${VAR} for environment variables func (r *shellVariableResolver) ResolveValue(value string) (string, error) { - if !strings.HasPrefix(value, "$") { + // Special case: lone $ is an error (backward compatibility) + if value == "$" { + return "", fmt.Errorf("invalid value format: %s", value) + } + + // If no $ found, return as-is + if !strings.Contains(value, "$") { return value, nil } - if strings.HasPrefix(value, "$(") && strings.HasSuffix(value, ")") { - command := strings.TrimSuffix(strings.TrimPrefix(value, "$("), ")") + result := value + + // Handle command substitution: $(command) + for { + start := strings.Index(result, "$(") + if start == -1 { + break + } + + // Find matching closing parenthesis + depth := 0 + end := -1 + for i := start + 2; i < len(result); i++ { + if result[i] == '(' { + depth++ + } else if result[i] == ')' { + if depth == 0 { + end = i + break + } + depth-- + } + } + + if end == -1 { + return "", fmt.Errorf("unmatched $( in value: %s", value) + } + + command := result[start+2 : end] ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute) - defer cancel() stdout, _, err := r.shell.Exec(ctx, command) + cancel() if err != nil { - return "", fmt.Errorf("command execution failed: %w", err) + return "", fmt.Errorf("command execution failed for '%s': %w", command, err) } - return strings.TrimSpace(stdout), nil + + // Replace the $(command) with the output + replacement := strings.TrimSpace(stdout) + result = result[:start] + replacement + result[end+1:] } - if after, ok := strings.CutPrefix(value, "$"); ok { - varName := after - value = r.env.Get(varName) - if value == "" { + // Handle environment variables: $VAR and ${VAR} + searchStart := 0 + for { + start := strings.Index(result[searchStart:], "$") + if start == -1 { + break + } + start += searchStart // Adjust for the offset + + // Skip if this is part of $( which we already handled + if start+1 < len(result) && result[start+1] == '(' { + // Skip past this $(...) + searchStart = start + 1 + continue + } + var varName string + var end int + + if start+1 < len(result) && result[start+1] == '{' { + // Handle ${VAR} format + closeIdx := strings.Index(result[start+2:], "}") + if closeIdx == -1 { + return "", fmt.Errorf("unmatched ${ in value: %s", value) + } + varName = result[start+2 : start+2+closeIdx] + end = start + 2 + closeIdx + 1 + } else { + // Handle $VAR format - variable names must start with letter or underscore + if start+1 >= len(result) { + return "", fmt.Errorf("incomplete variable reference at end of string: %s", value) + } + + if result[start+1] != '_' && + (result[start+1] < 'a' || result[start+1] > 'z') && + (result[start+1] < 'A' || result[start+1] > 'Z') { + return "", fmt.Errorf("invalid variable name starting with '%c' in: %s", result[start+1], value) + } + + end = start + 1 + for end < len(result) && (result[end] == '_' || + (result[end] >= 'a' && result[end] <= 'z') || + (result[end] >= 'A' && result[end] <= 'Z') || + (result[end] >= '0' && result[end] <= '9')) { + end++ + } + varName = result[start+1 : end] + } + + envValue := r.env.Get(varName) + if envValue == "" { return "", fmt.Errorf("environment variable %q not set", varName) } - return value, nil + + result = result[:start] + envValue + result[end:] + searchStart = start + len(envValue) // Continue searching after the replacement } - return "", fmt.Errorf("invalid value format: %s", value) + + return result, nil } type environmentVariableResolver struct { diff --git a/internal/config/resolve_test.go b/internal/config/resolve_test.go index 7cdcd2a7913cb581e5312f787791e8e89e699281..26ab184b26f82e70bf95320492b900a080f3e015 100644 --- a/internal/config/resolve_test.go +++ b/internal/config/resolve_test.go @@ -47,17 +47,7 @@ func TestShellVariableResolver_ResolveValue(t *testing.T) { envVars: map[string]string{}, expectError: true, }, - { - name: "shell command execution", - value: "$(echo hello)", - shellFunc: func(ctx context.Context, command string) (stdout, stderr string, err error) { - if command == "echo hello" { - return "hello\n", "", nil - } - return "", "", errors.New("unexpected command") - }, - expected: "hello", - }, + { name: "shell command with whitespace trimming", value: "$(echo ' spaced ')", @@ -104,6 +94,171 @@ func TestShellVariableResolver_ResolveValue(t *testing.T) { } } +func TestShellVariableResolver_EnhancedResolveValue(t *testing.T) { + tests := []struct { + name string + value string + envVars map[string]string + shellFunc func(ctx context.Context, command string) (stdout, stderr string, err error) + expected string + expectError bool + }{ + { + name: "command substitution within string", + value: "Bearer $(echo token123)", + shellFunc: func(ctx context.Context, command string) (stdout, stderr string, err error) { + if command == "echo token123" { + return "token123\n", "", nil + } + return "", "", errors.New("unexpected command") + }, + expected: "Bearer token123", + }, + { + name: "environment variable within string", + value: "Bearer $TOKEN", + envVars: map[string]string{"TOKEN": "sk-ant-123"}, + expected: "Bearer sk-ant-123", + }, + { + name: "environment variable with braces within string", + value: "Bearer ${TOKEN}", + envVars: map[string]string{"TOKEN": "sk-ant-456"}, + expected: "Bearer sk-ant-456", + }, + { + name: "mixed command and environment substitution", + value: "$USER-$(date +%Y)-$HOST", + envVars: map[string]string{ + "USER": "testuser", + "HOST": "localhost", + }, + shellFunc: func(ctx context.Context, command string) (stdout, stderr string, err error) { + if command == "date +%Y" { + return "2024\n", "", nil + } + return "", "", errors.New("unexpected command") + }, + expected: "testuser-2024-localhost", + }, + { + name: "multiple command substitutions", + value: "$(echo hello) $(echo world)", + shellFunc: func(ctx context.Context, command string) (stdout, stderr string, err error) { + switch command { + case "echo hello": + return "hello\n", "", nil + case "echo world": + return "world\n", "", nil + } + return "", "", errors.New("unexpected command") + }, + expected: "hello world", + }, + { + name: "nested parentheses in command", + value: "$(echo $(echo inner))", + shellFunc: func(ctx context.Context, command string) (stdout, stderr string, err error) { + if command == "echo $(echo inner)" { + return "nested\n", "", nil + } + return "", "", errors.New("unexpected command") + }, + expected: "nested", + }, + { + name: "lone dollar with non-variable chars", + value: "prefix$123suffix", // Numbers can't start variable names + expectError: true, + }, + { + name: "dollar with special chars", + value: "a$@b$#c", // Special chars aren't valid in variable names + expectError: true, + }, + { + name: "empty environment variable substitution", + value: "Bearer $EMPTY_VAR", + envVars: map[string]string{}, + expectError: true, + }, + { + name: "unmatched command substitution opening", + value: "Bearer $(echo test", + expectError: true, + }, + { + name: "unmatched environment variable braces", + value: "Bearer ${TOKEN", + expectError: true, + }, + { + name: "command substitution with error", + value: "Bearer $(false)", + shellFunc: func(ctx context.Context, command string) (stdout, stderr string, err error) { + return "", "", errors.New("command failed") + }, + expectError: true, + }, + { + name: "complex real-world example", + value: "Bearer $(cat /tmp/token.txt | base64 -w 0)", + shellFunc: func(ctx context.Context, command string) (stdout, stderr string, err error) { + if command == "cat /tmp/token.txt | base64 -w 0" { + return "c2stYW50LXRlc3Q=\n", "", nil + } + return "", "", errors.New("unexpected command") + }, + expected: "Bearer c2stYW50LXRlc3Q=", + }, + { + name: "environment variable with underscores and numbers", + value: "Bearer $API_KEY_V2", + envVars: map[string]string{"API_KEY_V2": "sk-test-123"}, + expected: "Bearer sk-test-123", + }, + { + name: "no substitution needed", + value: "Bearer sk-ant-static-token", + expected: "Bearer sk-ant-static-token", + }, + { + name: "incomplete variable at end", + value: "Bearer $", + expectError: true, + }, + { + name: "variable with invalid character", + value: "Bearer $VAR-NAME", // Hyphen not allowed in variable names + expectError: true, + }, + { + name: "multiple invalid variables", + value: "$1$2$3", + expectError: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + testEnv := env.NewFromMap(tt.envVars) + resolver := &shellVariableResolver{ + shell: &mockShell{execFunc: tt.shellFunc}, + env: testEnv, + } + + result, err := resolver.ResolveValue(tt.value) + + if tt.expectError { + assert.Error(t, err) + } else { + assert.NoError(t, err) + assert.Equal(t, tt.expected, result) + } + }) + } +} + func TestEnvironmentVariableResolver_ResolveValue(t *testing.T) { tests := []struct { name string diff --git a/internal/llm/provider/anthropic.go b/internal/llm/provider/anthropic.go index 0765389a05ecaf33c6c521770e1880a24210d35f..3de8c805b3f0cfa08b1b2bb6b60577742ce8cc1d 100644 --- a/internal/llm/provider/anthropic.go +++ b/internal/llm/provider/anthropic.go @@ -39,8 +39,30 @@ func newAnthropicClient(opts providerClientOptions, useBedrock bool) AnthropicCl func createAnthropicClient(opts providerClientOptions, useBedrock bool) anthropic.Client { anthropicClientOptions := []option.RequestOption{} - if opts.apiKey != "" { - anthropicClientOptions = append(anthropicClientOptions, option.WithAPIKey(opts.apiKey)) + + // Check if Authorization header is provided in extra headers + hasBearerAuth := false + if opts.extraHeaders != nil { + for key := range opts.extraHeaders { + if strings.ToLower(key) == "authorization" { + hasBearerAuth = true + break + } + } + } + + isBearerToken := strings.HasPrefix(opts.apiKey, "Bearer ") + + if opts.apiKey != "" && !hasBearerAuth { + if isBearerToken { + slog.Debug("API key starts with 'Bearer ', using as Authorization header") + anthropicClientOptions = append(anthropicClientOptions, option.WithHeader("Authorization", opts.apiKey)) + } else { + // Use standard X-Api-Key header + anthropicClientOptions = append(anthropicClientOptions, option.WithAPIKey(opts.apiKey)) + } + } else if hasBearerAuth { + slog.Debug("Skipping X-Api-Key header because Authorization header is provided") } if useBedrock { anthropicClientOptions = append(anthropicClientOptions, bedrock.WithLoadDefaultConfig(context.Background())) @@ -200,6 +222,25 @@ func (a *anthropicClient) preparedMessages(messages []anthropic.MessageParam, to maxTokens = int64(a.adjustedMaxTokens) } + systemBlocks := []anthropic.TextBlockParam{} + + // Add custom system prompt prefix if configured + if a.providerOptions.systemPromptPrefix != "" { + systemBlocks = append(systemBlocks, anthropic.TextBlockParam{ + Text: a.providerOptions.systemPromptPrefix, + CacheControl: anthropic.CacheControlEphemeralParam{ + Type: "ephemeral", + }, + }) + } + + systemBlocks = append(systemBlocks, anthropic.TextBlockParam{ + Text: a.providerOptions.systemMessage, + CacheControl: anthropic.CacheControlEphemeralParam{ + Type: "ephemeral", + }, + }) + return anthropic.MessageNewParams{ Model: anthropic.Model(model.ID), MaxTokens: maxTokens, @@ -207,14 +248,7 @@ func (a *anthropicClient) preparedMessages(messages []anthropic.MessageParam, to Messages: messages, Tools: tools, Thinking: thinkingParam, - System: []anthropic.TextBlockParam{ - { - Text: a.providerOptions.systemMessage, - CacheControl: anthropic.CacheControlEphemeralParam{ - Type: "ephemeral", - }, - }, - }, + System: systemBlocks, } } @@ -393,6 +427,7 @@ func (a *anthropicClient) stream(ctx context.Context, messages []message.Message close(eventChan) return } + // If there is an error we are going to see if we can retry the call retry, after, retryErr := a.shouldRetry(attempts, err) if retryErr != nil { diff --git a/internal/llm/provider/gemini.go b/internal/llm/provider/gemini.go index b2d1da11148e74362e7b529b9ec78dc1810d0f0d..0070d246012547a691f8c6a8cbd8de2234cd93ec 100644 --- a/internal/llm/provider/gemini.go +++ b/internal/llm/provider/gemini.go @@ -180,12 +180,16 @@ func (g *geminiClient) send(ctx context.Context, messages []message.Message, too if modelConfig.MaxTokens > 0 { maxTokens = modelConfig.MaxTokens } + systemMessage := g.providerOptions.systemMessage + if g.providerOptions.systemPromptPrefix != "" { + systemMessage = g.providerOptions.systemPromptPrefix + "\n" + systemMessage + } history := geminiMessages[:len(geminiMessages)-1] // All but last message lastMsg := geminiMessages[len(geminiMessages)-1] config := &genai.GenerateContentConfig{ MaxOutputTokens: int32(maxTokens), SystemInstruction: &genai.Content{ - Parts: []*genai.Part{{Text: g.providerOptions.systemMessage}}, + Parts: []*genai.Part{{Text: systemMessage}}, }, } config.Tools = g.convertTools(tools) @@ -280,12 +284,16 @@ func (g *geminiClient) stream(ctx context.Context, messages []message.Message, t if g.providerOptions.maxTokens > 0 { maxTokens = g.providerOptions.maxTokens } + systemMessage := g.providerOptions.systemMessage + if g.providerOptions.systemPromptPrefix != "" { + systemMessage = g.providerOptions.systemPromptPrefix + "\n" + systemMessage + } history := geminiMessages[:len(geminiMessages)-1] // All but last message lastMsg := geminiMessages[len(geminiMessages)-1] config := &genai.GenerateContentConfig{ MaxOutputTokens: int32(maxTokens), SystemInstruction: &genai.Content{ - Parts: []*genai.Part{{Text: g.providerOptions.systemMessage}}, + Parts: []*genai.Part{{Text: systemMessage}}, }, } config.Tools = g.convertTools(tools) diff --git a/internal/llm/provider/openai.go b/internal/llm/provider/openai.go index 397d6954d0a5c8f3dbe25f4a34115ade4c242012..23e247830a48ba1860ba7bde5059da69fab6d3ac 100644 --- a/internal/llm/provider/openai.go +++ b/internal/llm/provider/openai.go @@ -57,7 +57,11 @@ func createOpenAIClient(opts providerClientOptions) openai.Client { func (o *openaiClient) convertMessages(messages []message.Message) (openaiMessages []openai.ChatCompletionMessageParamUnion) { // Add system message first - openaiMessages = append(openaiMessages, openai.SystemMessage(o.providerOptions.systemMessage)) + systemMessage := o.providerOptions.systemMessage + if o.providerOptions.systemPromptPrefix != "" { + systemMessage = o.providerOptions.systemPromptPrefix + "\n" + systemMessage + } + openaiMessages = append(openaiMessages, openai.SystemMessage(systemMessage)) for _, msg := range messages { switch msg.Role { diff --git a/internal/llm/provider/provider.go b/internal/llm/provider/provider.go index 062c2aa977c6ff101d1d8ab6f32809845bd48ff3..c236c10f0b0e9bf9b4db50544ca664291ef13b65 100644 --- a/internal/llm/provider/provider.go +++ b/internal/llm/provider/provider.go @@ -61,17 +61,18 @@ type Provider interface { } type providerClientOptions struct { - baseURL string - config config.ProviderConfig - apiKey string - modelType config.SelectedModelType - model func(config.SelectedModelType) catwalk.Model - disableCache bool - systemMessage string - maxTokens int64 - extraHeaders map[string]string - extraBody map[string]any - extraParams map[string]string + baseURL string + config config.ProviderConfig + apiKey string + modelType config.SelectedModelType + model func(config.SelectedModelType) catwalk.Model + disableCache bool + systemMessage string + systemPromptPrefix string + maxTokens int64 + extraHeaders map[string]string + extraBody map[string]any + extraParams map[string]string } type ProviderClientOption func(*providerClientOptions) @@ -143,12 +144,23 @@ func NewProvider(cfg config.ProviderConfig, opts ...ProviderClientOption) (Provi return nil, fmt.Errorf("failed to resolve API key for provider %s: %w", cfg.ID, err) } + // Resolve extra headers + resolvedExtraHeaders := make(map[string]string) + for key, value := range cfg.ExtraHeaders { + resolvedValue, err := config.Get().Resolve(value) + if err != nil { + return nil, fmt.Errorf("failed to resolve extra header %s for provider %s: %w", key, cfg.ID, err) + } + resolvedExtraHeaders[key] = resolvedValue + } + clientOptions := providerClientOptions{ - baseURL: cfg.BaseURL, - config: cfg, - apiKey: resolvedAPIKey, - extraHeaders: cfg.ExtraHeaders, - extraBody: cfg.ExtraBody, + baseURL: cfg.BaseURL, + config: cfg, + apiKey: resolvedAPIKey, + extraHeaders: resolvedExtraHeaders, + extraBody: cfg.ExtraBody, + systemPromptPrefix: cfg.SystemPromptPrefix, model: func(tp config.SelectedModelType) catwalk.Model { return *config.Get().GetModelByType(tp) }, From 05a3e9a0303569feb3796e0e3c8eb5b7cfcf34f2 Mon Sep 17 00:00:00 2001 From: Kujtim Hoxha Date: Fri, 25 Jul 2025 12:04:51 +0200 Subject: [PATCH 9/9] chore: fix tests --- internal/config/load_test.go | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/internal/config/load_test.go b/internal/config/load_test.go index 3f4ff51db2b04b7e4d8f9f5e86306eb3b2f7dc91..8c2735bd15fb3b52fe0c87401f57534e9b007e5b 100644 --- a/internal/config/load_test.go +++ b/internal/config/load_test.go @@ -615,7 +615,7 @@ func TestConfig_configureProvidersCustomProviderValidation(t *testing.T) { t.Run("custom anthropic provider is supported", func(t *testing.T) { cfg := &Config{ - Providers: map[string]ProviderConfig{ + Providers: csync.NewMapFrom(map[string]ProviderConfig{ "custom-anthropic": { APIKey: "test-key", BaseURL: "https://api.anthropic.com/v1", @@ -624,7 +624,7 @@ func TestConfig_configureProvidersCustomProviderValidation(t *testing.T) { ID: "claude-3-sonnet", }}, }, - }, + }), } cfg.setDefaults("/tmp") @@ -633,8 +633,8 @@ func TestConfig_configureProvidersCustomProviderValidation(t *testing.T) { err := cfg.configureProviders(env, resolver, []catwalk.Provider{}) assert.NoError(t, err) - assert.Len(t, cfg.Providers, 1) - customProvider, exists := cfg.Providers["custom-anthropic"] + assert.Equal(t, cfg.Providers.Len(), 1) + customProvider, exists := cfg.Providers.Get("custom-anthropic") assert.True(t, exists) assert.Equal(t, "custom-anthropic", customProvider.ID) assert.Equal(t, "test-key", customProvider.APIKey)