1package models
2
3import (
4 "slices"
5
6 tea "github.com/charmbracelet/bubbletea/v2"
7 "github.com/charmbracelet/crush/internal/config"
8 "github.com/charmbracelet/crush/internal/fur/provider"
9 "github.com/charmbracelet/crush/internal/tui/components/completions"
10 "github.com/charmbracelet/crush/internal/tui/components/core/list"
11 "github.com/charmbracelet/crush/internal/tui/components/dialogs/commands"
12 "github.com/charmbracelet/crush/internal/tui/util"
13 "github.com/charmbracelet/lipgloss/v2"
14)
15
16type ModelListComponent struct {
17 list list.ListModel
18 modelType int
19}
20
21func NewModelListComponent(keyMap list.KeyMap, inputStyle lipgloss.Style) *ModelListComponent {
22 modelList := list.New(
23 list.WithFilterable(true),
24 list.WithKeyMap(keyMap),
25 list.WithInputStyle(inputStyle),
26 list.WithWrapNavigation(true),
27 )
28
29 return &ModelListComponent{
30 list: modelList,
31 modelType: LargeModelType,
32 }
33}
34
35func (m *ModelListComponent) Init() tea.Cmd {
36 return tea.Batch(m.list.Init(), m.SetModelType(m.modelType))
37}
38
39func (m *ModelListComponent) Update(msg tea.Msg) (*ModelListComponent, tea.Cmd) {
40 u, cmd := m.list.Update(msg)
41 m.list = u.(list.ListModel)
42 return m, cmd
43}
44
45func (m *ModelListComponent) View() tea.View {
46 return m.list.View()
47}
48
49func (m *ModelListComponent) SetSize(width, height int) tea.Cmd {
50 return m.list.SetSize(width, height)
51}
52
53func (m *ModelListComponent) Items() []util.Model {
54 return m.list.Items()
55}
56
57func (m *ModelListComponent) SelectedIndex() int {
58 return m.list.SelectedIndex()
59}
60
61func (m *ModelListComponent) SetModelType(modelType int) tea.Cmd {
62 m.modelType = modelType
63
64 providers := config.Providers()
65 modelItems := []util.Model{}
66 selectIndex := 0
67
68 cfg := config.Get()
69 var currentModel config.PreferredModel
70 if m.modelType == LargeModelType {
71 currentModel = cfg.Models.Large
72 } else {
73 currentModel = cfg.Models.Small
74 }
75
76 addedProviders := make(map[provider.InferenceProvider]bool)
77
78 knownProviders := provider.KnownProviders()
79 for providerID, providerConfig := range cfg.Providers {
80 if providerConfig.Disabled {
81 continue
82 }
83
84 // Check if this provider is not in the known providers list
85 if !slices.Contains(knownProviders, providerID) {
86 configProvider := provider.Provider{
87 Name: string(providerID),
88 ID: providerID,
89 Models: make([]provider.Model, len(providerConfig.Models)),
90 }
91
92 for i, model := range providerConfig.Models {
93 configProvider.Models[i] = provider.Model{
94 ID: model.ID,
95 Name: model.Name,
96 CostPer1MIn: model.CostPer1MIn,
97 CostPer1MOut: model.CostPer1MOut,
98 CostPer1MInCached: model.CostPer1MInCached,
99 CostPer1MOutCached: model.CostPer1MOutCached,
100 ContextWindow: model.ContextWindow,
101 DefaultMaxTokens: model.DefaultMaxTokens,
102 CanReason: model.CanReason,
103 HasReasoningEffort: model.HasReasoningEffort,
104 DefaultReasoningEffort: model.ReasoningEffort,
105 SupportsImages: model.SupportsImages,
106 }
107 }
108
109 name := configProvider.Name
110 if name == "" {
111 name = string(configProvider.ID)
112 }
113 modelItems = append(modelItems, commands.NewItemSection(name))
114 for _, model := range configProvider.Models {
115 modelItems = append(modelItems, completions.NewCompletionItem(model.Name, ModelOption{
116 Provider: configProvider,
117 Model: model,
118 }))
119 if model.ID == currentModel.ModelID && configProvider.ID == currentModel.Provider {
120 selectIndex = len(modelItems) - 1
121 }
122 }
123 addedProviders[providerID] = true
124 }
125 }
126
127 for _, provider := range providers {
128 if addedProviders[provider.ID] {
129 continue
130 }
131
132 if providerConfig, exists := cfg.Providers[provider.ID]; exists && providerConfig.Disabled {
133 continue
134 }
135
136 name := provider.Name
137 if name == "" {
138 name = string(provider.ID)
139 }
140 modelItems = append(modelItems, commands.NewItemSection(name))
141 for _, model := range provider.Models {
142 modelItems = append(modelItems, completions.NewCompletionItem(model.Name, ModelOption{
143 Provider: provider,
144 Model: model,
145 }))
146 if model.ID == currentModel.ModelID && provider.ID == currentModel.Provider {
147 selectIndex = len(modelItems) - 1
148 }
149 }
150 }
151
152 return tea.Sequence(m.list.SetItems(modelItems), m.list.SetSelected(selectIndex))
153}
154
155// GetModelType returns the current model type
156func (m *ModelListComponent) GetModelType() int {
157 return m.modelType
158}
159