models_list.go

  1package dialog
  2
  3import (
  4	"fmt"
  5	"slices"
  6	"sort"
  7	"strings"
  8
  9	"github.com/charmbracelet/crush/internal/ui/list"
 10	"github.com/charmbracelet/crush/internal/ui/styles"
 11	"github.com/sahilm/fuzzy"
 12)
 13
 14// ModelsList is a list specifically for model items and groups.
 15type ModelsList struct {
 16	*list.List
 17	groups []ModelGroup
 18	query  string
 19	t      *styles.Styles
 20}
 21
 22// NewModelsList creates a new list suitable for model items and groups.
 23func NewModelsList(sty *styles.Styles, groups ...ModelGroup) *ModelsList {
 24	f := &ModelsList{
 25		List:   list.NewList(),
 26		groups: groups,
 27		t:      sty,
 28	}
 29	f.RegisterRenderCallback(list.FocusedRenderCallback(f.List))
 30	return f
 31}
 32
 33// Len returns the number of model items across all groups.
 34func (f *ModelsList) Len() int {
 35	n := 0
 36	for _, g := range f.groups {
 37		n += len(g.Items)
 38	}
 39	return n
 40}
 41
 42// SetGroups sets the model groups and updates the list items.
 43func (f *ModelsList) SetGroups(groups ...ModelGroup) {
 44	f.groups = groups
 45	items := []list.Item{}
 46	for _, g := range f.groups {
 47		items = append(items, &g)
 48		for _, item := range g.Items {
 49			items = append(items, item)
 50		}
 51		// Add a space separator after each provider section
 52		items = append(items, list.NewSpacerItem(1))
 53	}
 54	f.SetItems(items...)
 55}
 56
 57// SetFilter sets the filter query and updates the list items.
 58func (f *ModelsList) SetFilter(q string) {
 59	f.query = q
 60	f.SetItems(f.VisibleItems()...)
 61}
 62
 63// SetSelected sets the selected item index. It overrides the base method to
 64// skip non-model items.
 65func (f *ModelsList) SetSelected(index int) {
 66	if index < 0 || index >= f.Len() {
 67		f.List.SetSelected(index)
 68		return
 69	}
 70
 71	f.List.SetSelected(index)
 72	for {
 73		selectedItem := f.SelectedItem()
 74		if _, ok := selectedItem.(*ModelItem); ok {
 75			return
 76		}
 77		f.List.SetSelected(index + 1)
 78		index++
 79		if index >= f.Len() {
 80			return
 81		}
 82	}
 83}
 84
 85// SetSelectedItem sets the selected item in the list by item ID.
 86func (f *ModelsList) SetSelectedItem(itemID string) {
 87	if itemID == "" {
 88		f.SetSelected(0)
 89		return
 90	}
 91
 92	count := 0
 93	for _, g := range f.groups {
 94		for _, item := range g.Items {
 95			if item.ID() == itemID {
 96				f.SetSelected(count)
 97				return
 98			}
 99			count++
100		}
101	}
102}
103
104// SelectNext selects the next model item, skipping any non-focusable items
105// like group headers and spacers.
106func (f *ModelsList) SelectNext() (v bool) {
107	v = f.List.SelectNext()
108	for v {
109		selectedItem := f.SelectedItem()
110		if _, ok := selectedItem.(*ModelItem); ok {
111			return v
112		}
113		v = f.List.SelectNext()
114	}
115	return v
116}
117
118// SelectPrev selects the previous model item, skipping any non-focusable items
119// like group headers and spacers.
120func (f *ModelsList) SelectPrev() (v bool) {
121	v = f.List.SelectPrev()
122	for v {
123		selectedItem := f.SelectedItem()
124		if _, ok := selectedItem.(*ModelItem); ok {
125			return v
126		}
127		v = f.List.SelectPrev()
128	}
129	return v
130}
131
132// SelectFirst selects the first model item in the list.
133func (f *ModelsList) SelectFirst() (v bool) {
134	v = f.List.SelectFirst()
135	for v {
136		selectedItem := f.SelectedItem()
137		_, ok := selectedItem.(*ModelItem)
138		if ok {
139			return v
140		}
141		v = f.List.SelectNext()
142	}
143	return v
144}
145
146// SelectLast selects the last model item in the list.
147func (f *ModelsList) SelectLast() (v bool) {
148	v = f.List.SelectLast()
149	for v {
150		selectedItem := f.SelectedItem()
151		if _, ok := selectedItem.(*ModelItem); ok {
152			return v
153		}
154		v = f.List.SelectPrev()
155	}
156	return v
157}
158
159// IsSelectedFirst checks if the selected item is the first model item.
160func (f *ModelsList) IsSelectedFirst() bool {
161	originalIndex := f.Selected()
162	f.SelectFirst()
163	isFirst := f.Selected() == originalIndex
164	f.List.SetSelected(originalIndex)
165	return isFirst
166}
167
168// IsSelectedLast checks if the selected item is the last model item.
169func (f *ModelsList) IsSelectedLast() bool {
170	originalIndex := f.Selected()
171	f.SelectLast()
172	isLast := f.Selected() == originalIndex
173	f.List.SetSelected(originalIndex)
174	return isLast
175}
176
177// VisibleItems returns the visible items after filtering.
178func (f *ModelsList) VisibleItems() []list.Item {
179	query := strings.ToLower(strings.ReplaceAll(f.query, " ", ""))
180
181	if query == "" {
182		// No filter, return all items with group headers
183		items := []list.Item{}
184		for _, g := range f.groups {
185			items = append(items, &g)
186			for _, item := range g.Items {
187				item.SetMatch(fuzzy.Match{})
188				items = append(items, item)
189			}
190			// Add a space separator after each provider section
191			items = append(items, list.NewSpacerItem(1))
192		}
193		return items
194	}
195
196	filterableItems := make([]list.FilterableItem, 0, f.Len())
197	for _, g := range f.groups {
198		for _, item := range g.Items {
199			filterableItems = append(filterableItems, item)
200		}
201	}
202
203	items := []list.Item{}
204	visitedGroups := map[int]bool{}
205
206	// Reconstruct groups with matched items
207	// Find which group this item belongs to
208	for gi, g := range f.groups {
209		addedCount := 0
210		name := strings.ToLower(g.Title) + " "
211
212		names := make([]string, len(filterableItems))
213		for i, item := range filterableItems {
214			ms := item.(*ModelItem)
215			names[i] = fmt.Sprintf("%s%s", name, ms.Filter())
216		}
217
218		matches := fuzzy.Find(query, names)
219		sort.SliceStable(matches, func(i, j int) bool {
220			return matches[i].Score > matches[j].Score
221		})
222
223		for _, match := range matches {
224			item := filterableItems[match.Index].(*ModelItem)
225			idxs := []int{}
226			for _, idx := range match.MatchedIndexes {
227				// Adjusts removing provider name highlights
228				if idx < len(name) {
229					continue
230				}
231				idxs = append(idxs, idx-len(name))
232			}
233
234			match.MatchedIndexes = idxs
235			if slices.Contains(g.Items, item) {
236				if !visitedGroups[gi] {
237					// Add section header
238					items = append(items, &g)
239					visitedGroups[gi] = true
240				}
241				// Add the matched item
242				item.SetMatch(match)
243				items = append(items, item)
244				addedCount++
245			}
246		}
247		if addedCount > 0 {
248			// Add a space separator after each provider section
249			items = append(items, list.NewSpacerItem(1))
250		}
251	}
252
253	return items
254}
255
256// Render renders the filterable list.
257func (f *ModelsList) Render() string {
258	f.SetItems(f.VisibleItems()...)
259	return f.List.Render()
260}
261
262type modelGroups []ModelGroup
263
264func (m modelGroups) Len() int {
265	n := 0
266	for _, g := range m {
267		n += len(g.Items)
268	}
269	return n
270}
271
272func (m modelGroups) String(i int) string {
273	count := 0
274	for _, g := range m {
275		if i < count+len(g.Items) {
276			return g.Items[i-count].Filter()
277		}
278		count += len(g.Items)
279	}
280	return ""
281}