1package models
2
3import (
4 "fmt"
5 "slices"
6
7 tea "github.com/charmbracelet/bubbletea/v2"
8 "github.com/charmbracelet/crush/internal/config"
9 "github.com/charmbracelet/crush/internal/fur/provider"
10 "github.com/charmbracelet/crush/internal/tui/components/completions"
11 "github.com/charmbracelet/crush/internal/tui/components/core/list"
12 "github.com/charmbracelet/crush/internal/tui/components/dialogs/commands"
13 "github.com/charmbracelet/crush/internal/tui/styles"
14 "github.com/charmbracelet/crush/internal/tui/util"
15 "github.com/charmbracelet/lipgloss/v2"
16)
17
// ModelListComponent is a filterable list of models grouped into one section
// per provider. It wraps a list.ListModel and remembers which model slot
// (compared against LargeModelType) it is currently populated for.
type ModelListComponent struct {
	// list holds the section headers and model items built by SetModelType.
	list list.ListModel
	// modelType is the slot being edited; SetModelType treats LargeModelType
	// as the large slot and any other value as the small slot.
	modelType int
}
22
23func NewModelListComponent(keyMap list.KeyMap, inputStyle lipgloss.Style, inputPlaceholder string) *ModelListComponent {
24 modelList := list.New(
25 list.WithFilterable(true),
26 list.WithKeyMap(keyMap),
27 list.WithInputStyle(inputStyle),
28 list.WithFilterPlaceholder(inputPlaceholder),
29 list.WithWrapNavigation(true),
30 )
31
32 return &ModelListComponent{
33 list: modelList,
34 modelType: LargeModelType,
35 }
36}
37
38func (m *ModelListComponent) Init() tea.Cmd {
39 return tea.Batch(m.list.Init(), m.SetModelType(m.modelType))
40}
41
42func (m *ModelListComponent) Update(msg tea.Msg) (*ModelListComponent, tea.Cmd) {
43 u, cmd := m.list.Update(msg)
44 m.list = u.(list.ListModel)
45 return m, cmd
46}
47
48func (m *ModelListComponent) View() tea.View {
49 return m.list.View()
50}
51
52func (m *ModelListComponent) SetSize(width, height int) tea.Cmd {
53 return m.list.SetSize(width, height)
54}
55
56func (m *ModelListComponent) Items() []util.Model {
57 return m.list.Items()
58}
59
60func (m *ModelListComponent) SelectedIndex() int {
61 return m.list.SelectedIndex()
62}
63
64func (m *ModelListComponent) SetModelType(modelType int) tea.Cmd {
65 t := styles.CurrentTheme()
66 m.modelType = modelType
67
68 providers, err := config.Providers()
69 if err != nil {
70 return util.ReportError(err)
71 }
72
73 modelItems := []util.Model{}
74 selectIndex := 0
75
76 cfg := config.Get()
77 var currentModel config.SelectedModel
78 if m.modelType == LargeModelType {
79 currentModel = cfg.Models[config.SelectedModelTypeLarge]
80 } else {
81 currentModel = cfg.Models[config.SelectedModelTypeSmall]
82 }
83
84 configuredIcon := t.S().Base.Foreground(t.Success).Render(styles.CheckIcon)
85 configured := fmt.Sprintf("%s %s", configuredIcon, t.S().Subtle.Render("Configured"))
86
87 // Create a map to track which providers we've already added
88 addedProviders := make(map[string]bool)
89
90 // First, add any configured providers that are not in the known providers list
91 // These should appear at the top of the list
92 knownProviders := provider.KnownProviders()
93 for providerID, providerConfig := range cfg.Providers {
94 if providerConfig.Disable {
95 continue
96 }
97
98 // Check if this provider is not in the known providers list
99 if !slices.Contains(knownProviders, provider.InferenceProvider(providerID)) {
100 // Convert config provider to provider.Provider format
101 configProvider := provider.Provider{
102 Name: string(providerID), // Use provider ID as name for unknown providers
103 ID: provider.InferenceProvider(providerID),
104 Models: make([]provider.Model, len(providerConfig.Models)),
105 }
106
107 // Convert models
108 for i, model := range providerConfig.Models {
109 configProvider.Models[i] = provider.Model{
110 ID: model.ID,
111 Name: model.Name,
112 CostPer1MIn: model.CostPer1MIn,
113 CostPer1MOut: model.CostPer1MOut,
114 CostPer1MInCached: model.CostPer1MInCached,
115 CostPer1MOutCached: model.CostPer1MOutCached,
116 ContextWindow: model.ContextWindow,
117 DefaultMaxTokens: model.DefaultMaxTokens,
118 CanReason: model.CanReason,
119 HasReasoningEffort: model.HasReasoningEffort,
120 DefaultReasoningEffort: model.DefaultReasoningEffort,
121 SupportsImages: model.SupportsImages,
122 }
123 }
124
125 // Add this unknown provider to the list
126 name := configProvider.Name
127 if name == "" {
128 name = string(configProvider.ID)
129 }
130 section := commands.NewItemSection(name)
131 section.SetInfo(configured)
132 modelItems = append(modelItems, section)
133 for _, model := range configProvider.Models {
134 modelItems = append(modelItems, completions.NewCompletionItem(model.Name, ModelOption{
135 Provider: configProvider,
136 Model: model,
137 }))
138 if model.ID == currentModel.Model && string(configProvider.ID) == currentModel.Provider {
139 selectIndex = len(modelItems) - 1 // Set the selected index to the current model
140 }
141 }
142 addedProviders[providerID] = true
143 }
144 }
145
146 // Then add the known providers from the predefined list
147 for _, provider := range providers {
148 // Skip if we already added this provider as an unknown provider
149 if addedProviders[string(provider.ID)] {
150 continue
151 }
152
153 // Check if this provider is configured and not disabled
154 if providerConfig, exists := cfg.Providers[string(provider.ID)]; exists && providerConfig.Disable {
155 continue
156 }
157
158 name := provider.Name
159 if name == "" {
160 name = string(provider.ID)
161 }
162
163 section := commands.NewItemSection(name)
164 if _, ok := cfg.Providers[string(provider.ID)]; ok {
165 section.SetInfo(configured)
166 }
167 modelItems = append(modelItems, section)
168 for _, model := range provider.Models {
169 modelItems = append(modelItems, completions.NewCompletionItem(model.Name, ModelOption{
170 Provider: provider,
171 Model: model,
172 }))
173 if model.ID == currentModel.Model && string(provider.ID) == currentModel.Provider {
174 selectIndex = len(modelItems) - 1 // Set the selected index to the current model
175 }
176 }
177 }
178
179 return tea.Sequence(m.list.SetItems(modelItems), m.list.SetSelected(selectIndex))
180}
181
182// GetModelType returns the current model type
183func (m *ModelListComponent) GetModelType() int {
184 return m.modelType
185}
186
// SetInputPlaceholder replaces the placeholder text shown in the list's
// filter input.
func (m *ModelListComponent) SetInputPlaceholder(placeholder string) {
	m.list.SetFilterPlaceholder(placeholder)
}