1package dialog
2
3import (
4 "fmt"
5 "slices"
6 "sort"
7 "strings"
8
9 "github.com/charmbracelet/crush/internal/ui/list"
10 "github.com/charmbracelet/crush/internal/ui/styles"
11 "github.com/sahilm/fuzzy"
12)
13
// ModelsList is a list specifically for model items and groups.
//
// It embeds *list.List and overrides several of its selection methods so that
// non-model rows (group headers and spacers) are skipped when moving the
// selection.
type ModelsList struct {
	*list.List
	// groups are the model groups the list items are built from.
	groups []ModelGroup
	// query is the filter text stored by SetFilter; it is applied by
	// VisibleItems.
	query string
	// t holds the styles passed to NewModelsList.
	t *styles.Styles
}
21
22// NewModelsList creates a new list suitable for model items and groups.
23func NewModelsList(sty *styles.Styles, groups ...ModelGroup) *ModelsList {
24 f := &ModelsList{
25 List: list.NewList(),
26 groups: groups,
27 t: sty,
28 }
29 return f
30}
31
32// Len returns the number of model items across all groups.
33func (f *ModelsList) Len() int {
34 n := 0
35 for _, g := range f.groups {
36 n += len(g.Items)
37 }
38 return n
39}
40
41// SetGroups sets the model groups and updates the list items.
42func (f *ModelsList) SetGroups(groups ...ModelGroup) {
43 f.groups = groups
44 items := []list.Item{}
45 for _, g := range f.groups {
46 items = append(items, &g)
47 for _, item := range g.Items {
48 items = append(items, item)
49 }
50 // Add a space separator after each provider section
51 items = append(items, list.NewSpacerItem(1))
52 }
53 f.SetItems(items...)
54}
55
56// SetFilter sets the filter query and updates the list items.
57func (f *ModelsList) SetFilter(q string) {
58 f.query = q
59}
60
61// SetSelected sets the selected item index. It overrides the base method to
62// skip non-model items.
63func (f *ModelsList) SetSelected(index int) {
64 if index < 0 || index >= f.Len() {
65 f.List.SetSelected(index)
66 return
67 }
68
69 f.List.SetSelected(index)
70 for {
71 selectedItem := f.SelectedItem()
72 if _, ok := selectedItem.(*ModelItem); ok {
73 return
74 }
75 f.List.SetSelected(index + 1)
76 index++
77 if index >= f.Len() {
78 return
79 }
80 }
81}
82
83// SetSelectedItem sets the selected item in the list by item ID.
84func (f *ModelsList) SetSelectedItem(itemID string) {
85 if itemID == "" {
86 f.SetSelected(0)
87 return
88 }
89
90 count := 0
91 for _, g := range f.groups {
92 for _, item := range g.Items {
93 if item.ID() == itemID {
94 f.SetSelected(count)
95 return
96 }
97 count++
98 }
99 }
100}
101
102// SelectNext selects the next model item, skipping any non-focusable items
103// like group headers and spacers.
104func (f *ModelsList) SelectNext() (v bool) {
105 for {
106 v = f.List.SelectNext()
107 selectedItem := f.SelectedItem()
108 if _, ok := selectedItem.(*ModelItem); ok {
109 return v
110 }
111 }
112}
113
114// SelectPrev selects the previous model item, skipping any non-focusable items
115// like group headers and spacers.
116func (f *ModelsList) SelectPrev() (v bool) {
117 for {
118 v = f.List.SelectPrev()
119 selectedItem := f.SelectedItem()
120 if _, ok := selectedItem.(*ModelItem); ok {
121 return v
122 }
123 }
124}
125
126// SelectFirst selects the first model item in the list.
127func (f *ModelsList) SelectFirst() (v bool) {
128 v = f.List.SelectFirst()
129 for {
130 selectedItem := f.SelectedItem()
131 if _, ok := selectedItem.(*ModelItem); ok {
132 return v
133 }
134 v = f.List.SelectNext()
135 }
136}
137
138// SelectLast selects the last model item in the list.
139func (f *ModelsList) SelectLast() (v bool) {
140 v = f.List.SelectLast()
141 for {
142 selectedItem := f.SelectedItem()
143 if _, ok := selectedItem.(*ModelItem); ok {
144 return v
145 }
146 v = f.List.SelectPrev()
147 }
148}
149
150// IsSelectedFirst checks if the selected item is the first model item.
151func (f *ModelsList) IsSelectedFirst() bool {
152 originalIndex := f.Selected()
153 f.SelectFirst()
154 isFirst := f.Selected() == originalIndex
155 f.List.SetSelected(originalIndex)
156 return isFirst
157}
158
159// IsSelectedLast checks if the selected item is the last model item.
160func (f *ModelsList) IsSelectedLast() bool {
161 originalIndex := f.Selected()
162 f.SelectLast()
163 isLast := f.Selected() == originalIndex
164 f.List.SetSelected(originalIndex)
165 return isLast
166}
167
168// VisibleItems returns the visible items after filtering.
169func (f *ModelsList) VisibleItems() []list.Item {
170 query := strings.ToLower(strings.ReplaceAll(f.query, " ", ""))
171
172 if query == "" {
173 // No filter, return all items with group headers
174 items := []list.Item{}
175 for _, g := range f.groups {
176 items = append(items, &g)
177 for _, item := range g.Items {
178 item.SetMatch(fuzzy.Match{})
179 items = append(items, item)
180 }
181 // Add a space separator after each provider section
182 items = append(items, list.NewSpacerItem(1))
183 }
184 return items
185 }
186
187 filterableItems := make([]list.FilterableItem, 0, f.Len())
188 for _, g := range f.groups {
189 for _, item := range g.Items {
190 filterableItems = append(filterableItems, item)
191 }
192 }
193
194 items := []list.Item{}
195 visitedGroups := map[int]bool{}
196
197 // Reconstruct groups with matched items
198 // Find which group this item belongs to
199 for gi, g := range f.groups {
200 addedCount := 0
201 name := strings.ToLower(g.Title) + " "
202
203 names := make([]string, len(filterableItems))
204 for i, item := range filterableItems {
205 ms := item.(*ModelItem)
206 names[i] = fmt.Sprintf("%s%s", name, ms.Filter())
207 }
208
209 matches := fuzzy.Find(query, names)
210 sort.SliceStable(matches, func(i, j int) bool {
211 return matches[i].Score > matches[j].Score
212 })
213
214 for _, match := range matches {
215 item := filterableItems[match.Index].(*ModelItem)
216 idxs := []int{}
217 for _, idx := range match.MatchedIndexes {
218 // Adjusts removing provider name highlights
219 if idx < len(name) {
220 continue
221 }
222 idxs = append(idxs, idx-len(name))
223 }
224
225 match.MatchedIndexes = idxs
226 if slices.Contains(g.Items, item) {
227 if !visitedGroups[gi] {
228 // Add section header
229 items = append(items, &g)
230 visitedGroups[gi] = true
231 }
232 // Add the matched item
233 item.SetMatch(match)
234 items = append(items, item)
235 addedCount++
236 }
237 }
238 if addedCount > 0 {
239 // Add a space separator after each provider section
240 items = append(items, list.NewSpacerItem(1))
241 }
242 }
243
244 return items
245}
246
247// Render renders the filterable list.
248func (f *ModelsList) Render() string {
249 f.SetItems(f.VisibleItems()...)
250 return f.List.Render()
251}
252
253type modelGroups []ModelGroup
254
255func (m modelGroups) Len() int {
256 n := 0
257 for _, g := range m {
258 n += len(g.Items)
259 }
260 return n
261}
262
263func (m modelGroups) String(i int) string {
264 count := 0
265 for _, g := range m {
266 if i < count+len(g.Items) {
267 return g.Items[i-count].Filter()
268 }
269 count += len(g.Items)
270 }
271 return ""
272}