1package dialog
2
3import (
4 "slices"
5
6 "github.com/charmbracelet/crush/internal/ui/list"
7 "github.com/charmbracelet/crush/internal/ui/styles"
8 "github.com/sahilm/fuzzy"
9)
10
// ModelsList is a list specifically for model items and groups.
//
// It embeds *list.List, whose items and selection it drives, and keeps the
// source groups plus the current filter query so the list's items can be
// rebuilt from them (see VisibleItems and Render).
type ModelsList struct {
	*list.List
	// groups holds the model groups in display order; each group is shown as
	// a header entry followed by its Items.
	groups []ModelGroup
	// query is the fuzzy-filter text; an empty string means no filtering.
	query string
	// t holds the styles passed to NewModelsList.
	t *styles.Styles
}
18
19// NewModelsList creates a new list suitable for model items and groups.
20func NewModelsList(sty *styles.Styles, groups ...ModelGroup) *ModelsList {
21 f := &ModelsList{
22 List: list.NewList(),
23 groups: groups,
24 t: sty,
25 }
26 return f
27}
28
29// Len returns the number of model items across all groups.
30func (f *ModelsList) Len() int {
31 n := 0
32 for _, g := range f.groups {
33 n += len(g.Items)
34 }
35 return n
36}
37
38// SetGroups sets the model groups and updates the list items.
39func (f *ModelsList) SetGroups(groups ...ModelGroup) {
40 f.groups = groups
41 items := []list.Item{}
42 for _, g := range f.groups {
43 items = append(items, &g)
44 for _, item := range g.Items {
45 items = append(items, item)
46 }
47 // Add a space separator after each provider section
48 items = append(items, list.NewSpacerItem(1))
49 }
50 f.List.SetItems(items...)
51}
52
53// SetFilter sets the filter query and updates the list items.
54func (f *ModelsList) SetFilter(q string) {
55 f.query = q
56}
57
58// SetSelectedItem sets the selected item in the list by item ID.
59func (f *ModelsList) SetSelectedItem(itemID string) {
60 if itemID == "" {
61 f.SetSelected(0)
62 return
63 }
64
65 count := 0
66 for _, g := range f.groups {
67 for _, item := range g.Items {
68 if item.ID() == itemID {
69 f.List.SetSelected(count)
70 return
71 }
72 count++
73 }
74 }
75}
76
77// VisibleItems returns the visible items after filtering.
78func (f *ModelsList) VisibleItems() []list.Item {
79 if len(f.query) == 0 {
80 // No filter, return all items with group headers
81 items := []list.Item{}
82 for _, g := range f.groups {
83 items = append(items, &g)
84 for _, item := range g.Items {
85 items = append(items, item)
86 }
87 // Add a space separator after each provider section
88 items = append(items, list.NewSpacerItem(1))
89 }
90 return items
91 }
92
93 filterableItems := make([]list.FilterableItem, 0, f.Len())
94 for _, g := range f.groups {
95 for _, item := range g.Items {
96 filterableItems = append(filterableItems, item)
97 }
98 }
99
100 matches := fuzzy.FindFrom(f.query, list.FilterableItemsSource(filterableItems))
101 for _, match := range matches {
102 item := filterableItems[match.Index]
103 if ms, ok := item.(list.MatchSettable); ok {
104 ms.SetMatch(match)
105 item = ms.(list.FilterableItem)
106 }
107 filterableItems = append(filterableItems, item)
108 }
109
110 items := []list.Item{}
111 visitedGroups := map[int]bool{}
112
113 // Reconstruct groups with matched items
114 // Find which group this item belongs to
115 for gi, g := range f.groups {
116 addedCount := 0
117 for _, match := range matches {
118 item := filterableItems[match.Index]
119 if slices.Contains(g.Items, item.(*ModelItem)) {
120 if !visitedGroups[gi] {
121 // Add section header
122 items = append(items, &g)
123 visitedGroups[gi] = true
124 }
125 // Add the matched item
126 if ms, ok := item.(list.MatchSettable); ok {
127 ms.SetMatch(match)
128 item = ms.(list.FilterableItem)
129 }
130 items = append(items, item)
131 addedCount++
132 }
133 }
134 if addedCount > 0 {
135 // Add a space separator after each provider section
136 items = append(items, list.NewSpacerItem(1))
137 }
138 }
139
140 return items
141}
142
143// Render renders the filterable list.
144func (f *ModelsList) Render() string {
145 f.List.SetItems(f.VisibleItems()...)
146 return f.List.Render()
147}
148
149type modelGroups []ModelGroup
150
151func (m modelGroups) Len() int {
152 n := 0
153 for _, g := range m {
154 n += len(g.Items)
155 }
156 return n
157}
158
159func (m modelGroups) String(i int) string {
160 count := 0
161 for _, g := range m {
162 if i < count+len(g.Items) {
163 return g.Items[i-count].Filter()
164 }
165 count += len(g.Items)
166 }
167 return ""
168}