1package dialog
2
3import (
4 "fmt"
5 "slices"
6 "sort"
7 "strings"
8
9 "github.com/charmbracelet/crush/internal/ui/list"
10 "github.com/charmbracelet/crush/internal/ui/styles"
11 "github.com/sahilm/fuzzy"
12)
13
14// ModelsList is a list specifically for model items and groups.
15type ModelsList struct {
16 *list.List
17 groups []ModelGroup
18 query string
19 t *styles.Styles
20}
21
22// NewModelsList creates a new list suitable for model items and groups.
23func NewModelsList(sty *styles.Styles, groups ...ModelGroup) *ModelsList {
24 f := &ModelsList{
25 List: list.NewList(),
26 groups: groups,
27 t: sty,
28 }
29 f.RegisterRenderCallback(list.FocusedRenderCallback(f.List))
30 return f
31}
32
33// Len returns the number of model items across all groups.
34func (f *ModelsList) Len() int {
35 n := 0
36 for _, g := range f.groups {
37 n += len(g.Items)
38 }
39 return n
40}
41
42// SetGroups sets the model groups and updates the list items.
43func (f *ModelsList) SetGroups(groups ...ModelGroup) {
44 f.groups = groups
45 items := []list.Item{}
46 for _, g := range f.groups {
47 items = append(items, &g)
48 for _, item := range g.Items {
49 items = append(items, item)
50 }
51 // Add a space separator after each provider section
52 items = append(items, list.NewSpacerItem(1))
53 }
54 f.SetItems(items...)
55}
56
57// SetFilter sets the filter query and updates the list items.
58func (f *ModelsList) SetFilter(q string) {
59 f.query = q
60}
61
62// SetSelected sets the selected item index. It overrides the base method to
63// skip non-model items.
64func (f *ModelsList) SetSelected(index int) {
65 if index < 0 || index >= f.Len() {
66 f.List.SetSelected(index)
67 return
68 }
69
70 f.List.SetSelected(index)
71 for {
72 selectedItem := f.SelectedItem()
73 if _, ok := selectedItem.(*ModelItem); ok {
74 return
75 }
76 f.List.SetSelected(index + 1)
77 index++
78 if index >= f.Len() {
79 return
80 }
81 }
82}
83
84// SetSelectedItem sets the selected item in the list by item ID.
85func (f *ModelsList) SetSelectedItem(itemID string) {
86 if itemID == "" {
87 f.SetSelected(0)
88 return
89 }
90
91 count := 0
92 for _, g := range f.groups {
93 for _, item := range g.Items {
94 if item.ID() == itemID {
95 f.SetSelected(count)
96 return
97 }
98 count++
99 }
100 }
101}
102
103// SelectNext selects the next model item, skipping any non-focusable items
104// like group headers and spacers.
105func (f *ModelsList) SelectNext() (v bool) {
106 for {
107 v = f.List.SelectNext()
108 selectedItem := f.SelectedItem()
109 if _, ok := selectedItem.(*ModelItem); ok {
110 return v
111 }
112 }
113}
114
115// SelectPrev selects the previous model item, skipping any non-focusable items
116// like group headers and spacers.
117func (f *ModelsList) SelectPrev() (v bool) {
118 for {
119 v = f.List.SelectPrev()
120 selectedItem := f.SelectedItem()
121 if _, ok := selectedItem.(*ModelItem); ok {
122 return v
123 }
124 }
125}
126
127// SelectFirst selects the first model item in the list.
128func (f *ModelsList) SelectFirst() (v bool) {
129 v = f.List.SelectFirst()
130 for {
131 selectedItem := f.SelectedItem()
132 if _, ok := selectedItem.(*ModelItem); ok {
133 return v
134 }
135 v = f.List.SelectNext()
136 }
137}
138
139// SelectLast selects the last model item in the list.
140func (f *ModelsList) SelectLast() (v bool) {
141 v = f.List.SelectLast()
142 for {
143 selectedItem := f.SelectedItem()
144 if _, ok := selectedItem.(*ModelItem); ok {
145 return v
146 }
147 v = f.List.SelectPrev()
148 }
149}
150
151// IsSelectedFirst checks if the selected item is the first model item.
152func (f *ModelsList) IsSelectedFirst() bool {
153 originalIndex := f.Selected()
154 f.SelectFirst()
155 isFirst := f.Selected() == originalIndex
156 f.List.SetSelected(originalIndex)
157 return isFirst
158}
159
160// IsSelectedLast checks if the selected item is the last model item.
161func (f *ModelsList) IsSelectedLast() bool {
162 originalIndex := f.Selected()
163 f.SelectLast()
164 isLast := f.Selected() == originalIndex
165 f.List.SetSelected(originalIndex)
166 return isLast
167}
168
169// VisibleItems returns the visible items after filtering.
170func (f *ModelsList) VisibleItems() []list.Item {
171 query := strings.ToLower(strings.ReplaceAll(f.query, " ", ""))
172
173 if query == "" {
174 // No filter, return all items with group headers
175 items := []list.Item{}
176 for _, g := range f.groups {
177 items = append(items, &g)
178 for _, item := range g.Items {
179 item.SetMatch(fuzzy.Match{})
180 items = append(items, item)
181 }
182 // Add a space separator after each provider section
183 items = append(items, list.NewSpacerItem(1))
184 }
185 return items
186 }
187
188 filterableItems := make([]list.FilterableItem, 0, f.Len())
189 for _, g := range f.groups {
190 for _, item := range g.Items {
191 filterableItems = append(filterableItems, item)
192 }
193 }
194
195 items := []list.Item{}
196 visitedGroups := map[int]bool{}
197
198 // Reconstruct groups with matched items
199 // Find which group this item belongs to
200 for gi, g := range f.groups {
201 addedCount := 0
202 name := strings.ToLower(g.Title) + " "
203
204 names := make([]string, len(filterableItems))
205 for i, item := range filterableItems {
206 ms := item.(*ModelItem)
207 names[i] = fmt.Sprintf("%s%s", name, ms.Filter())
208 }
209
210 matches := fuzzy.Find(query, names)
211 sort.SliceStable(matches, func(i, j int) bool {
212 return matches[i].Score > matches[j].Score
213 })
214
215 for _, match := range matches {
216 item := filterableItems[match.Index].(*ModelItem)
217 idxs := []int{}
218 for _, idx := range match.MatchedIndexes {
219 // Adjusts removing provider name highlights
220 if idx < len(name) {
221 continue
222 }
223 idxs = append(idxs, idx-len(name))
224 }
225
226 match.MatchedIndexes = idxs
227 if slices.Contains(g.Items, item) {
228 if !visitedGroups[gi] {
229 // Add section header
230 items = append(items, &g)
231 visitedGroups[gi] = true
232 }
233 // Add the matched item
234 item.SetMatch(match)
235 items = append(items, item)
236 addedCount++
237 }
238 }
239 if addedCount > 0 {
240 // Add a space separator after each provider section
241 items = append(items, list.NewSpacerItem(1))
242 }
243 }
244
245 return items
246}
247
248// Render renders the filterable list.
249func (f *ModelsList) Render() string {
250 f.SetItems(f.VisibleItems()...)
251 return f.List.Render()
252}
253
254type modelGroups []ModelGroup
255
256func (m modelGroups) Len() int {
257 n := 0
258 for _, g := range m {
259 n += len(g.Items)
260 }
261 return n
262}
263
264func (m modelGroups) String(i int) string {
265 count := 0
266 for _, g := range m {
267 if i < count+len(g.Items) {
268 return g.Items[i-count].Filter()
269 }
270 count += len(g.Items)
271 }
272 return ""
273}