models_list.go

  1package dialog
  2
  3import (
  4	"fmt"
  5	"slices"
  6	"sort"
  7	"strings"
  8
  9	"github.com/charmbracelet/crush/internal/ui/list"
 10	"github.com/charmbracelet/crush/internal/ui/styles"
 11	"github.com/sahilm/fuzzy"
 12)
 13
// ModelsList is a list specifically for model items and groups. It embeds
// *list.List and keeps the groups and the filter query from which
// VisibleItems rebuilds the displayed entries.
type ModelsList struct {
	*list.List
	groups []ModelGroup   // model groups shown in the list
	query  string         // filter text; stored by SetFilter, read by VisibleItems
	t      *styles.Styles // styles passed to NewModelsList
}
 21
 22// NewModelsList creates a new list suitable for model items and groups.
 23func NewModelsList(sty *styles.Styles, groups ...ModelGroup) *ModelsList {
 24	f := &ModelsList{
 25		List:   list.NewList(),
 26		groups: groups,
 27		t:      sty,
 28	}
 29	f.RegisterRenderCallback(list.FocusedRenderCallback(f.List))
 30	return f
 31}
 32
 33// Len returns the number of model items across all groups.
 34func (f *ModelsList) Len() int {
 35	n := 0
 36	for _, g := range f.groups {
 37		n += len(g.Items)
 38	}
 39	return n
 40}
 41
 42// SetGroups sets the model groups and updates the list items.
 43func (f *ModelsList) SetGroups(groups ...ModelGroup) {
 44	f.groups = groups
 45	items := []list.Item{}
 46	for _, g := range f.groups {
 47		items = append(items, &g)
 48		for _, item := range g.Items {
 49			items = append(items, item)
 50		}
 51		// Add a space separator after each provider section
 52		items = append(items, list.NewSpacerItem(1))
 53	}
 54	f.SetItems(items...)
 55}
 56
// SetFilter stores the filter query. The list items are not rebuilt here:
// VisibleItems applies the query, and Render calls VisibleItems.
func (f *ModelsList) SetFilter(q string) {
	f.query = q
}
 61
 62// SetSelected sets the selected item index. It overrides the base method to
 63// skip non-model items.
 64func (f *ModelsList) SetSelected(index int) {
 65	if index < 0 || index >= f.Len() {
 66		f.List.SetSelected(index)
 67		return
 68	}
 69
 70	f.List.SetSelected(index)
 71	for {
 72		selectedItem := f.SelectedItem()
 73		if _, ok := selectedItem.(*ModelItem); ok {
 74			return
 75		}
 76		f.List.SetSelected(index + 1)
 77		index++
 78		if index >= f.Len() {
 79			return
 80		}
 81	}
 82}
 83
 84// SetSelectedItem sets the selected item in the list by item ID.
 85func (f *ModelsList) SetSelectedItem(itemID string) {
 86	if itemID == "" {
 87		f.SetSelected(0)
 88		return
 89	}
 90
 91	count := 0
 92	for _, g := range f.groups {
 93		for _, item := range g.Items {
 94			if item.ID() == itemID {
 95				f.SetSelected(count)
 96				return
 97			}
 98			count++
 99		}
100	}
101}
102
103// SelectNext selects the next model item, skipping any non-focusable items
104// like group headers and spacers.
105func (f *ModelsList) SelectNext() (v bool) {
106	for {
107		v = f.List.SelectNext()
108		selectedItem := f.SelectedItem()
109		if _, ok := selectedItem.(*ModelItem); ok {
110			return v
111		}
112	}
113}
114
115// SelectPrev selects the previous model item, skipping any non-focusable items
116// like group headers and spacers.
117func (f *ModelsList) SelectPrev() (v bool) {
118	for {
119		v = f.List.SelectPrev()
120		selectedItem := f.SelectedItem()
121		if _, ok := selectedItem.(*ModelItem); ok {
122			return v
123		}
124	}
125}
126
127// SelectFirst selects the first model item in the list.
128func (f *ModelsList) SelectFirst() (v bool) {
129	v = f.List.SelectFirst()
130	for {
131		selectedItem := f.SelectedItem()
132		if _, ok := selectedItem.(*ModelItem); ok {
133			return v
134		}
135		v = f.List.SelectNext()
136	}
137}
138
139// SelectLast selects the last model item in the list.
140func (f *ModelsList) SelectLast() (v bool) {
141	v = f.List.SelectLast()
142	for {
143		selectedItem := f.SelectedItem()
144		if _, ok := selectedItem.(*ModelItem); ok {
145			return v
146		}
147		v = f.List.SelectPrev()
148	}
149}
150
151// IsSelectedFirst checks if the selected item is the first model item.
152func (f *ModelsList) IsSelectedFirst() bool {
153	originalIndex := f.Selected()
154	f.SelectFirst()
155	isFirst := f.Selected() == originalIndex
156	f.List.SetSelected(originalIndex)
157	return isFirst
158}
159
160// IsSelectedLast checks if the selected item is the last model item.
161func (f *ModelsList) IsSelectedLast() bool {
162	originalIndex := f.Selected()
163	f.SelectLast()
164	isLast := f.Selected() == originalIndex
165	f.List.SetSelected(originalIndex)
166	return isLast
167}
168
169// VisibleItems returns the visible items after filtering.
170func (f *ModelsList) VisibleItems() []list.Item {
171	query := strings.ToLower(strings.ReplaceAll(f.query, " ", ""))
172
173	if query == "" {
174		// No filter, return all items with group headers
175		items := []list.Item{}
176		for _, g := range f.groups {
177			items = append(items, &g)
178			for _, item := range g.Items {
179				item.SetMatch(fuzzy.Match{})
180				items = append(items, item)
181			}
182			// Add a space separator after each provider section
183			items = append(items, list.NewSpacerItem(1))
184		}
185		return items
186	}
187
188	filterableItems := make([]list.FilterableItem, 0, f.Len())
189	for _, g := range f.groups {
190		for _, item := range g.Items {
191			filterableItems = append(filterableItems, item)
192		}
193	}
194
195	items := []list.Item{}
196	visitedGroups := map[int]bool{}
197
198	// Reconstruct groups with matched items
199	// Find which group this item belongs to
200	for gi, g := range f.groups {
201		addedCount := 0
202		name := strings.ToLower(g.Title) + " "
203
204		names := make([]string, len(filterableItems))
205		for i, item := range filterableItems {
206			ms := item.(*ModelItem)
207			names[i] = fmt.Sprintf("%s%s", name, ms.Filter())
208		}
209
210		matches := fuzzy.Find(query, names)
211		sort.SliceStable(matches, func(i, j int) bool {
212			return matches[i].Score > matches[j].Score
213		})
214
215		for _, match := range matches {
216			item := filterableItems[match.Index].(*ModelItem)
217			idxs := []int{}
218			for _, idx := range match.MatchedIndexes {
219				// Adjusts removing provider name highlights
220				if idx < len(name) {
221					continue
222				}
223				idxs = append(idxs, idx-len(name))
224			}
225
226			match.MatchedIndexes = idxs
227			if slices.Contains(g.Items, item) {
228				if !visitedGroups[gi] {
229					// Add section header
230					items = append(items, &g)
231					visitedGroups[gi] = true
232				}
233				// Add the matched item
234				item.SetMatch(match)
235				items = append(items, item)
236				addedCount++
237			}
238		}
239		if addedCount > 0 {
240			// Add a space separator after each provider section
241			items = append(items, list.NewSpacerItem(1))
242		}
243	}
244
245	return items
246}
247
248// Render renders the filterable list.
249func (f *ModelsList) Render() string {
250	f.SetItems(f.VisibleItems()...)
251	return f.List.Render()
252}
253
// modelGroups is a slice of ModelGroup whose Len and String methods address
// the model items of all groups through one flat index.
// NOTE(review): the method set looks intended to satisfy fuzzy.Source, but
// the type is not referenced elsewhere in this file — confirm it is still used.
type modelGroups []ModelGroup
255
256func (m modelGroups) Len() int {
257	n := 0
258	for _, g := range m {
259		n += len(g.Items)
260	}
261	return n
262}
263
264func (m modelGroups) String(i int) string {
265	count := 0
266	for _, g := range m {
267		if i < count+len(g.Items) {
268			return g.Items[i-count].Filter()
269		}
270		count += len(g.Items)
271	}
272	return ""
273}