models_list.go

  1package dialog
  2
  3import (
  4	"slices"
  5
  6	"github.com/charmbracelet/crush/internal/ui/list"
  7	"github.com/charmbracelet/crush/internal/ui/styles"
  8	"github.com/sahilm/fuzzy"
  9)
 10
// ModelsList is a list specifically for model items and groups.
//
// It embeds *list.List and keeps the grouped source data plus the current
// filter query, from which the embedded list's items are rebuilt on Render.
type ModelsList struct {
	*list.List
	// groups is the grouped source data; each group contributes a header
	// followed by its items.
	groups []ModelGroup
	// items is not referenced by any method visible in this file.
	items []list.Item
	// query is the current filter text; an empty query disables filtering.
	query string
	// t holds the styles passed to NewModelsList; it is not read by the
	// methods visible in this file.
	t *styles.Styles
}
 19
 20// NewModelsList creates a new list suitable for model items and groups.
 21func NewModelsList(sty *styles.Styles, groups ...ModelGroup) *ModelsList {
 22	f := &ModelsList{
 23		List:   list.NewList(),
 24		groups: groups,
 25		t:      sty,
 26	}
 27	return f
 28}
 29
// SetGroups replaces the model groups.
//
// It only stores the new groups; the embedded list's items are rebuilt from
// them the next time Render is called.
func (f *ModelsList) SetGroups(groups ...ModelGroup) {
	f.groups = groups
}
 34
// SetFilter sets the filter query. An empty query disables filtering.
//
// It only stores the query; the embedded list's items are rebuilt with the
// filter applied the next time Render is called.
func (f *ModelsList) SetFilter(q string) {
	f.query = q
}
 39
 40// SetSelectedItem sets the selected item in the list by item ID.
 41func (f *ModelsList) SetSelectedItem(itemID string) {
 42	count := 0
 43	for _, g := range f.groups {
 44		for _, item := range g.Items {
 45			if item.ID() == itemID {
 46				f.List.SetSelected(count)
 47				return
 48			}
 49			count++
 50		}
 51	}
 52}
 53
 54// SelectNext selects the next selectable item in the list.
 55func (f *ModelsList) SelectNext() bool {
 56	for f.List.SelectNext() {
 57		if _, ok := f.List.SelectedItem().(*ModelItem); ok {
 58			return true
 59		}
 60	}
 61	return false
 62}
 63
 64// SelectPrev selects the previous selectable item in the list.
 65func (f *ModelsList) SelectPrev() bool {
 66	for f.List.SelectPrev() {
 67		if _, ok := f.List.SelectedItem().(*ModelItem); ok {
 68			return true
 69		}
 70	}
 71	return false
 72}
 73
 74// VisibleItems returns the visible items after filtering.
 75func (f *ModelsList) VisibleItems() []list.Item {
 76	if len(f.query) == 0 {
 77		// No filter, return all items with group headers
 78		items := []list.Item{}
 79		for _, g := range f.groups {
 80			items = append(items, &g)
 81			for _, item := range g.Items {
 82				items = append(items, item)
 83			}
 84			// Add a space separator after each provider section
 85			items = append(items, list.NewSpacerItem(1))
 86		}
 87		return items
 88	}
 89
 90	groupItems := map[int][]*ModelItem{}
 91	filterableItems := []list.FilterableItem{}
 92	for i, g := range f.groups {
 93		for _, item := range g.Items {
 94			filterableItems = append(filterableItems, item)
 95			groupItems[i] = append(groupItems[i], item)
 96		}
 97	}
 98
 99	matches := fuzzy.FindFrom(f.query, list.FilterableItemsSource(filterableItems))
100	for _, match := range matches {
101		item := filterableItems[match.Index]
102		if ms, ok := item.(list.MatchSettable); ok {
103			ms.SetMatch(match)
104			item = ms.(list.FilterableItem)
105		}
106		filterableItems = append(filterableItems, item)
107	}
108
109	items := []list.Item{}
110	visitedGroups := map[int]bool{}
111
112	// Reconstruct groups with matched items
113	for _, match := range matches {
114		item := filterableItems[match.Index]
115		// Find which group this item belongs to
116		for gi, g := range f.groups {
117			if slices.Contains(groupItems[gi], item.(*ModelItem)) {
118				if !visitedGroups[gi] {
119					// Add section header
120					items = append(items, &g)
121					visitedGroups[gi] = true
122				}
123				// Add the matched item
124				if ms, ok := item.(list.MatchSettable); ok {
125					ms.SetMatch(match)
126					item = ms.(list.FilterableItem)
127				}
128				// Add a space separator after each provider section
129				items = append(items, item, list.NewSpacerItem(1))
130				break
131			}
132		}
133	}
134
135	return items
136}
137
138// Render renders the filterable list.
139func (f *ModelsList) Render() string {
140	f.List.SetItems(f.VisibleItems()...)
141	return f.List.Render()
142}
143
144type modelGroups []ModelGroup
145
146func (m modelGroups) Len() int {
147	n := 0
148	for _, g := range m {
149		n += len(g.Items)
150	}
151	return n
152}
153
154func (m modelGroups) String(i int) string {
155	count := 0
156	for _, g := range m {
157		if i < count+len(g.Items) {
158			return g.Items[i-count].Filter()
159		}
160		count += len(g.Items)
161	}
162	return ""
163}