models_list.go

  1package dialog
  2
  3import (
  4	"slices"
  5
  6	"github.com/charmbracelet/crush/internal/ui/list"
  7	"github.com/charmbracelet/crush/internal/ui/styles"
  8	"github.com/sahilm/fuzzy"
  9)
 10
 11// ModelsList is a list specifically for model items and groups.
 12type ModelsList struct {
 13	*list.List
 14	groups []ModelGroup
 15	query  string
 16	t      *styles.Styles
 17}
 18
 19// NewModelsList creates a new list suitable for model items and groups.
 20func NewModelsList(sty *styles.Styles, groups ...ModelGroup) *ModelsList {
 21	f := &ModelsList{
 22		List:   list.NewList(),
 23		groups: groups,
 24		t:      sty,
 25	}
 26	return f
 27}
 28
 29// Len returns the number of model items across all groups.
 30func (f *ModelsList) Len() int {
 31	n := 0
 32	for _, g := range f.groups {
 33		n += len(g.Items)
 34	}
 35	return n
 36}
 37
 38// SetGroups sets the model groups and updates the list items.
 39func (f *ModelsList) SetGroups(groups ...ModelGroup) {
 40	f.groups = groups
 41	items := []list.Item{}
 42	for _, g := range f.groups {
 43		items = append(items, &g)
 44		for _, item := range g.Items {
 45			items = append(items, item)
 46		}
 47		// Add a space separator after each provider section
 48		items = append(items, list.NewSpacerItem(1))
 49	}
 50	f.List.SetItems(items...)
 51}
 52
 53// SetFilter sets the filter query and updates the list items.
 54func (f *ModelsList) SetFilter(q string) {
 55	f.query = q
 56}
 57
 58// SetSelectedItem sets the selected item in the list by item ID.
 59func (f *ModelsList) SetSelectedItem(itemID string) {
 60	if itemID == "" {
 61		f.SetSelected(0)
 62		return
 63	}
 64
 65	count := 0
 66	for _, g := range f.groups {
 67		for _, item := range g.Items {
 68			if item.ID() == itemID {
 69				f.List.SetSelected(count)
 70				return
 71			}
 72			count++
 73		}
 74	}
 75}
 76
 77// VisibleItems returns the visible items after filtering.
 78func (f *ModelsList) VisibleItems() []list.Item {
 79	if len(f.query) == 0 {
 80		// No filter, return all items with group headers
 81		items := []list.Item{}
 82		for _, g := range f.groups {
 83			items = append(items, &g)
 84			for _, item := range g.Items {
 85				items = append(items, item)
 86			}
 87			// Add a space separator after each provider section
 88			items = append(items, list.NewSpacerItem(1))
 89		}
 90		return items
 91	}
 92
 93	filterableItems := make([]list.FilterableItem, 0, f.Len())
 94	for _, g := range f.groups {
 95		for _, item := range g.Items {
 96			filterableItems = append(filterableItems, item)
 97		}
 98	}
 99
100	matches := fuzzy.FindFrom(f.query, list.FilterableItemsSource(filterableItems))
101	for _, match := range matches {
102		item := filterableItems[match.Index]
103		if ms, ok := item.(list.MatchSettable); ok {
104			ms.SetMatch(match)
105			item = ms.(list.FilterableItem)
106		}
107		filterableItems = append(filterableItems, item)
108	}
109
110	items := []list.Item{}
111	visitedGroups := map[int]bool{}
112
113	// Reconstruct groups with matched items
114	// Find which group this item belongs to
115	for gi, g := range f.groups {
116		addedCount := 0
117		for _, match := range matches {
118			item := filterableItems[match.Index]
119			if slices.Contains(g.Items, item.(*ModelItem)) {
120				if !visitedGroups[gi] {
121					// Add section header
122					items = append(items, &g)
123					visitedGroups[gi] = true
124				}
125				// Add the matched item
126				if ms, ok := item.(list.MatchSettable); ok {
127					ms.SetMatch(match)
128					item = ms.(list.FilterableItem)
129				}
130				items = append(items, item)
131				addedCount++
132			}
133		}
134		if addedCount > 0 {
135			// Add a space separator after each provider section
136			items = append(items, list.NewSpacerItem(1))
137		}
138	}
139
140	return items
141}
142
143// Render renders the filterable list.
144func (f *ModelsList) Render() string {
145	f.List.SetItems(f.VisibleItems()...)
146	return f.List.Render()
147}
148
149type modelGroups []ModelGroup
150
151func (m modelGroups) Len() int {
152	n := 0
153	for _, g := range m {
154		n += len(g.Items)
155	}
156	return n
157}
158
159func (m modelGroups) String(i int) string {
160	count := 0
161	for _, g := range m {
162		if i < count+len(g.Items) {
163			return g.Items[i-count].Filter()
164		}
165		count += len(g.Items)
166	}
167	return ""
168}