1package dialog
2
3import (
4 "fmt"
5 "slices"
6 "sort"
7 "strings"
8
9 "github.com/charmbracelet/crush/internal/ui/list"
10 "github.com/charmbracelet/crush/internal/ui/styles"
11 "github.com/sahilm/fuzzy"
12)
13
// ModelsList is a list specifically for model items and groups.
//
// It embeds list.List for rendering and selection and keeps the source groups
// and the current filter query. The flat item list (group headers, model items
// and spacers) is rebuilt from these fields by VisibleItems.
type ModelsList struct {
	*list.List
	// groups are the model groups (provider sections) the items are built from.
	groups []ModelGroup
	// query is the raw filter text last passed to SetFilter.
	query string
	// t is the style set passed to NewModelsList.
	t *styles.Styles
}
21
22// NewModelsList creates a new list suitable for model items and groups.
23func NewModelsList(sty *styles.Styles, groups ...ModelGroup) *ModelsList {
24 f := &ModelsList{
25 List: list.NewList(),
26 groups: groups,
27 t: sty,
28 }
29 f.RegisterRenderCallback(list.FocusedRenderCallback(f.List))
30 return f
31}
32
33// Len returns the number of model items across all groups.
34func (f *ModelsList) Len() int {
35 n := 0
36 for _, g := range f.groups {
37 n += len(g.Items)
38 }
39 return n
40}
41
42// SetGroups sets the model groups and updates the list items.
43func (f *ModelsList) SetGroups(groups ...ModelGroup) {
44 f.groups = groups
45 items := []list.Item{}
46 for _, g := range f.groups {
47 items = append(items, &g)
48 for _, item := range g.Items {
49 items = append(items, item)
50 }
51 // Add a space separator after each provider section
52 items = append(items, list.NewSpacerItem(1))
53 }
54 f.SetItems(items...)
55}
56
57// SetFilter sets the filter query and updates the list items.
58func (f *ModelsList) SetFilter(q string) {
59 f.query = q
60 f.SetItems(f.VisibleItems()...)
61}
62
63// SetSelected sets the selected item index. It overrides the base method to
64// skip non-model items.
65func (f *ModelsList) SetSelected(index int) {
66 if index < 0 || index >= f.Len() {
67 f.List.SetSelected(index)
68 return
69 }
70
71 f.List.SetSelected(index)
72 for {
73 selectedItem := f.SelectedItem()
74 if _, ok := selectedItem.(*ModelItem); ok {
75 return
76 }
77 f.List.SetSelected(index + 1)
78 index++
79 if index >= f.Len() {
80 return
81 }
82 }
83}
84
85// SetSelectedItem sets the selected item in the list by item ID.
86func (f *ModelsList) SetSelectedItem(itemID string) {
87 if itemID == "" {
88 f.SetSelected(0)
89 return
90 }
91
92 count := 0
93 for _, g := range f.groups {
94 for _, item := range g.Items {
95 if item.ID() == itemID {
96 f.SetSelected(count)
97 return
98 }
99 count++
100 }
101 }
102}
103
104// SelectNext selects the next model item, skipping any non-focusable items
105// like group headers and spacers.
106func (f *ModelsList) SelectNext() (v bool) {
107 v = f.List.SelectNext()
108 for v {
109 selectedItem := f.SelectedItem()
110 if _, ok := selectedItem.(*ModelItem); ok {
111 return v
112 }
113 v = f.List.SelectNext()
114 }
115 return v
116}
117
118// SelectPrev selects the previous model item, skipping any non-focusable items
119// like group headers and spacers.
120func (f *ModelsList) SelectPrev() (v bool) {
121 v = f.List.SelectPrev()
122 for v {
123 selectedItem := f.SelectedItem()
124 if _, ok := selectedItem.(*ModelItem); ok {
125 return v
126 }
127 v = f.List.SelectPrev()
128 }
129 return v
130}
131
132// SelectFirst selects the first model item in the list.
133func (f *ModelsList) SelectFirst() (v bool) {
134 v = f.List.SelectFirst()
135 for v {
136 selectedItem := f.SelectedItem()
137 _, ok := selectedItem.(*ModelItem)
138 if ok {
139 return v
140 }
141 v = f.List.SelectNext()
142 }
143 return v
144}
145
146// SelectLast selects the last model item in the list.
147func (f *ModelsList) SelectLast() (v bool) {
148 v = f.List.SelectLast()
149 for v {
150 selectedItem := f.SelectedItem()
151 if _, ok := selectedItem.(*ModelItem); ok {
152 return v
153 }
154 v = f.List.SelectPrev()
155 }
156 return v
157}
158
159// IsSelectedFirst checks if the selected item is the first model item.
160func (f *ModelsList) IsSelectedFirst() bool {
161 originalIndex := f.Selected()
162 f.SelectFirst()
163 isFirst := f.Selected() == originalIndex
164 f.List.SetSelected(originalIndex)
165 return isFirst
166}
167
168// IsSelectedLast checks if the selected item is the last model item.
169func (f *ModelsList) IsSelectedLast() bool {
170 originalIndex := f.Selected()
171 f.SelectLast()
172 isLast := f.Selected() == originalIndex
173 f.List.SetSelected(originalIndex)
174 return isLast
175}
176
177// VisibleItems returns the visible items after filtering.
178func (f *ModelsList) VisibleItems() []list.Item {
179 query := strings.ToLower(strings.ReplaceAll(f.query, " ", ""))
180
181 if query == "" {
182 // No filter, return all items with group headers
183 items := []list.Item{}
184 for _, g := range f.groups {
185 items = append(items, &g)
186 for _, item := range g.Items {
187 item.SetMatch(fuzzy.Match{})
188 items = append(items, item)
189 }
190 // Add a space separator after each provider section
191 items = append(items, list.NewSpacerItem(1))
192 }
193 return items
194 }
195
196 filterableItems := make([]list.FilterableItem, 0, f.Len())
197 for _, g := range f.groups {
198 for _, item := range g.Items {
199 filterableItems = append(filterableItems, item)
200 }
201 }
202
203 items := []list.Item{}
204 visitedGroups := map[int]bool{}
205
206 // Reconstruct groups with matched items
207 // Find which group this item belongs to
208 for gi, g := range f.groups {
209 addedCount := 0
210 name := strings.ToLower(g.Title) + " "
211
212 names := make([]string, len(filterableItems))
213 for i, item := range filterableItems {
214 ms := item.(*ModelItem)
215 names[i] = fmt.Sprintf("%s%s", name, ms.Filter())
216 }
217
218 matches := fuzzy.Find(query, names)
219 sort.SliceStable(matches, func(i, j int) bool {
220 return matches[i].Score > matches[j].Score
221 })
222
223 for _, match := range matches {
224 item := filterableItems[match.Index].(*ModelItem)
225 idxs := []int{}
226 for _, idx := range match.MatchedIndexes {
227 // Adjusts removing provider name highlights
228 if idx < len(name) {
229 continue
230 }
231 idxs = append(idxs, idx-len(name))
232 }
233
234 match.MatchedIndexes = idxs
235 if slices.Contains(g.Items, item) {
236 if !visitedGroups[gi] {
237 // Add section header
238 items = append(items, &g)
239 visitedGroups[gi] = true
240 }
241 // Add the matched item
242 item.SetMatch(match)
243 items = append(items, item)
244 addedCount++
245 }
246 }
247 if addedCount > 0 {
248 // Add a space separator after each provider section
249 items = append(items, list.NewSpacerItem(1))
250 }
251 }
252
253 return items
254}
255
256// Render renders the filterable list.
257func (f *ModelsList) Render() string {
258 f.SetItems(f.VisibleItems()...)
259 return f.List.Render()
260}
261
262type modelGroups []ModelGroup
263
264func (m modelGroups) Len() int {
265 n := 0
266 for _, g := range m {
267 n += len(g.Items)
268 }
269 return n
270}
271
272func (m modelGroups) String(i int) string {
273 count := 0
274 for _, g := range m {
275 if i < count+len(g.Items) {
276 return g.Items[i-count].Filter()
277 }
278 count += len(g.Items)
279 }
280 return ""
281}