1package models
2
3import (
4 "fmt"
5 "slices"
6
7 tea "github.com/charmbracelet/bubbletea/v2"
8 "github.com/charmbracelet/crush/internal/config"
9 "github.com/charmbracelet/crush/internal/fur/provider"
10 "github.com/charmbracelet/crush/internal/tui/components/completions"
11 "github.com/charmbracelet/crush/internal/tui/components/core/list"
12 "github.com/charmbracelet/crush/internal/tui/components/dialogs/commands"
13 "github.com/charmbracelet/crush/internal/tui/styles"
14 "github.com/charmbracelet/crush/internal/tui/util"
15 "github.com/charmbracelet/lipgloss/v2"
16)
17
18type ModelListComponent struct {
19 list list.ListModel
20 modelType int
21 providers []provider.Provider
22}
23
24func NewModelListComponent(keyMap list.KeyMap, inputStyle lipgloss.Style, inputPlaceholder string) *ModelListComponent {
25 modelList := list.New(
26 list.WithFilterable(true),
27 list.WithKeyMap(keyMap),
28 list.WithInputStyle(inputStyle),
29 list.WithFilterPlaceholder(inputPlaceholder),
30 list.WithWrapNavigation(true),
31 )
32
33 return &ModelListComponent{
34 list: modelList,
35 modelType: LargeModelType,
36 }
37}
38
39func (m *ModelListComponent) Init() tea.Cmd {
40 var cmds []tea.Cmd
41 if len(m.providers) == 0 {
42 providers, err := config.Providers()
43 m.providers = providers
44 if err != nil {
45 cmds = append(cmds, util.ReportError(err))
46 }
47 }
48 cmds = append(cmds, m.list.Init(), m.SetModelType(m.modelType))
49 return tea.Batch(cmds...)
50}
51
52func (m *ModelListComponent) Update(msg tea.Msg) (*ModelListComponent, tea.Cmd) {
53 u, cmd := m.list.Update(msg)
54 m.list = u.(list.ListModel)
55 return m, cmd
56}
57
58func (m *ModelListComponent) View() string {
59 return m.list.View()
60}
61
62func (m *ModelListComponent) Cursor() *tea.Cursor {
63 return m.list.Cursor()
64}
65
66func (m *ModelListComponent) SetSize(width, height int) tea.Cmd {
67 return m.list.SetSize(width, height)
68}
69
70func (m *ModelListComponent) Items() []util.Model {
71 return m.list.Items()
72}
73
74func (m *ModelListComponent) SelectedIndex() int {
75 return m.list.SelectedIndex()
76}
77
78func (m *ModelListComponent) SetModelType(modelType int) tea.Cmd {
79 t := styles.CurrentTheme()
80 m.modelType = modelType
81
82 modelItems := []util.Model{}
83 selectIndex := 0
84
85 cfg := config.Get()
86 var currentModel config.SelectedModel
87 if m.modelType == LargeModelType {
88 currentModel = cfg.Models[config.SelectedModelTypeLarge]
89 } else {
90 currentModel = cfg.Models[config.SelectedModelTypeSmall]
91 }
92
93 configuredIcon := t.S().Base.Foreground(t.Success).Render(styles.CheckIcon)
94 configured := fmt.Sprintf("%s %s", configuredIcon, t.S().Subtle.Render("Configured"))
95
96 // Create a map to track which providers we've already added
97 addedProviders := make(map[string]bool)
98
99 // First, add any configured providers that are not in the known providers list
100 // These should appear at the top of the list
101 knownProviders, err := config.Providers()
102 if err != nil {
103 return util.ReportError(err)
104 }
105 for providerID, providerConfig := range cfg.Providers {
106 if providerConfig.Disable {
107 continue
108 }
109
110 // Check if this provider is not in the known providers list
111 if !slices.ContainsFunc(knownProviders, func(p provider.Provider) bool { return p.ID == provider.InferenceProvider(providerID) }) {
112 // Convert config provider to provider.Provider format
113 configProvider := provider.Provider{
114 Name: providerConfig.Name,
115 ID: provider.InferenceProvider(providerID),
116 Models: make([]provider.Model, len(providerConfig.Models)),
117 }
118
119 // Convert models
120 for i, model := range providerConfig.Models {
121 configProvider.Models[i] = provider.Model{
122 ID: model.ID,
123 Model: model.Model,
124 CostPer1MIn: model.CostPer1MIn,
125 CostPer1MOut: model.CostPer1MOut,
126 CostPer1MInCached: model.CostPer1MInCached,
127 CostPer1MOutCached: model.CostPer1MOutCached,
128 ContextWindow: model.ContextWindow,
129 DefaultMaxTokens: model.DefaultMaxTokens,
130 CanReason: model.CanReason,
131 HasReasoningEffort: model.HasReasoningEffort,
132 DefaultReasoningEffort: model.DefaultReasoningEffort,
133 SupportsImages: model.SupportsImages,
134 }
135 }
136
137 // Add this unknown provider to the list
138 name := configProvider.Name
139 if name == "" {
140 name = string(configProvider.ID)
141 }
142 section := commands.NewItemSection(name)
143 section.SetInfo(configured)
144 modelItems = append(modelItems, section)
145 for _, model := range configProvider.Models {
146 modelItems = append(modelItems, completions.NewCompletionItem(model.Model, ModelOption{
147 Provider: configProvider,
148 Model: model,
149 }))
150 if model.ID == currentModel.Model && string(configProvider.ID) == currentModel.Provider {
151 selectIndex = len(modelItems) - 1 // Set the selected index to the current model
152 }
153 }
154 addedProviders[providerID] = true
155 }
156 }
157
158 // Then add the known providers from the predefined list
159 for _, provider := range m.providers {
160 // Skip if we already added this provider as an unknown provider
161 if addedProviders[string(provider.ID)] {
162 continue
163 }
164
165 // Check if this provider is configured and not disabled
166 if providerConfig, exists := cfg.Providers[string(provider.ID)]; exists && providerConfig.Disable {
167 continue
168 }
169
170 name := provider.Name
171 if name == "" {
172 name = string(provider.ID)
173 }
174
175 section := commands.NewItemSection(name)
176 if _, ok := cfg.Providers[string(provider.ID)]; ok {
177 section.SetInfo(configured)
178 }
179 modelItems = append(modelItems, section)
180 for _, model := range provider.Models {
181 modelItems = append(modelItems, completions.NewCompletionItem(model.Model, ModelOption{
182 Provider: provider,
183 Model: model,
184 }))
185 if model.ID == currentModel.Model && string(provider.ID) == currentModel.Provider {
186 selectIndex = len(modelItems) - 1 // Set the selected index to the current model
187 }
188 }
189 }
190
191 return tea.Sequence(m.list.SetItems(modelItems), m.list.SetSelected(selectIndex))
192}
193
194// GetModelType returns the current model type
195func (m *ModelListComponent) GetModelType() int {
196 return m.modelType
197}
198
199func (m *ModelListComponent) SetInputPlaceholder(placeholder string) {
200 m.list.SetFilterPlaceholder(placeholder)
201}
202
203func (m *ModelListComponent) SetProviders(providers []provider.Provider) {
204 m.providers = providers
205}