models.go

  1package models
  2
  3import (
  4	"github.com/charmbracelet/bubbles/v2/help"
  5	"github.com/charmbracelet/bubbles/v2/key"
  6	tea "github.com/charmbracelet/bubbletea/v2"
  7	"github.com/charmbracelet/crush/internal/config"
  8	"github.com/charmbracelet/crush/internal/fur/provider"
  9	"github.com/charmbracelet/crush/internal/tui/components/completions"
 10	"github.com/charmbracelet/crush/internal/tui/components/core"
 11	"github.com/charmbracelet/crush/internal/tui/components/core/list"
 12	"github.com/charmbracelet/crush/internal/tui/components/dialogs"
 13	"github.com/charmbracelet/crush/internal/tui/styles"
 14	"github.com/charmbracelet/crush/internal/tui/util"
 15	"github.com/charmbracelet/lipgloss/v2"
 16)
 17
 18const (
 19	ModelsDialogID dialogs.DialogID = "models"
 20
 21	defaultWidth = 60
 22)
 23
 24const (
 25	LargeModelType int = iota
 26	SmallModelType
 27
 28	largeModelInputPlaceholder = "Choose a model for large, complex tasks"
 29	smallModelInputPlaceholder = "Choose a model for small, simple tasks"
 30)
 31
 32// ModelSelectedMsg is sent when a model is selected
 33type ModelSelectedMsg struct {
 34	Model     config.SelectedModel
 35	ModelType config.SelectedModelType
 36}
 37
 38// CloseModelDialogMsg is sent when a model is selected
 39type CloseModelDialogMsg struct{}
 40
 41// ModelDialog interface for the model selection dialog
 42type ModelDialog interface {
 43	dialogs.DialogModel
 44}
 45
 46type ModelOption struct {
 47	Provider provider.Provider
 48	Model    provider.Model
 49}
 50
 51type modelDialogCmp struct {
 52	width   int
 53	wWidth  int
 54	wHeight int
 55
 56	modelList *ModelListComponent
 57	keyMap    KeyMap
 58	help      help.Model
 59}
 60
 61func NewModelDialogCmp() ModelDialog {
 62	listKeyMap := list.DefaultKeyMap()
 63	keyMap := DefaultKeyMap()
 64
 65	listKeyMap.Down.SetEnabled(false)
 66	listKeyMap.Up.SetEnabled(false)
 67	listKeyMap.HalfPageDown.SetEnabled(false)
 68	listKeyMap.HalfPageUp.SetEnabled(false)
 69	listKeyMap.Home.SetEnabled(false)
 70	listKeyMap.End.SetEnabled(false)
 71
 72	listKeyMap.DownOneItem = keyMap.Next
 73	listKeyMap.UpOneItem = keyMap.Previous
 74
 75	t := styles.CurrentTheme()
 76	inputStyle := t.S().Base.Padding(0, 1, 0, 1)
 77	modelList := NewModelListComponent(listKeyMap, inputStyle, "Choose a model for large, complex tasks")
 78	help := help.New()
 79	help.Styles = t.S().Help
 80
 81	return &modelDialogCmp{
 82		modelList: modelList,
 83		width:     defaultWidth,
 84		keyMap:    DefaultKeyMap(),
 85		help:      help,
 86	}
 87}
 88
 89func (m *modelDialogCmp) Init() tea.Cmd {
 90	return m.modelList.Init()
 91}
 92
 93func (m *modelDialogCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
 94	switch msg := msg.(type) {
 95	case tea.WindowSizeMsg:
 96		m.wWidth = msg.Width
 97		m.wHeight = msg.Height
 98		return m, m.modelList.SetSize(m.listWidth(), m.listHeight())
 99	case tea.KeyPressMsg:
100		switch {
101		case key.Matches(msg, m.keyMap.Select):
102			selectedItemInx := m.modelList.SelectedIndex()
103			if selectedItemInx == list.NoSelection {
104				return m, nil
105			}
106			items := m.modelList.Items()
107			selectedItem := items[selectedItemInx].(completions.CompletionItem).Value().(ModelOption)
108
109			var modelType config.SelectedModelType
110			if m.modelList.GetModelType() == LargeModelType {
111				modelType = config.SelectedModelTypeLarge
112			} else {
113				modelType = config.SelectedModelTypeSmall
114			}
115
116			return m, tea.Sequence(
117				util.CmdHandler(dialogs.CloseDialogMsg{}),
118				util.CmdHandler(ModelSelectedMsg{
119					Model: config.SelectedModel{
120						Model:    selectedItem.Model.ID,
121						Provider: string(selectedItem.Provider.ID),
122					},
123					ModelType: modelType,
124				}),
125			)
126		case key.Matches(msg, m.keyMap.Tab):
127			if m.modelList.GetModelType() == LargeModelType {
128				m.modelList.SetInputPlaceholder(smallModelInputPlaceholder)
129				return m, m.modelList.SetModelType(SmallModelType)
130			} else {
131				m.modelList.SetInputPlaceholder(largeModelInputPlaceholder)
132				return m, m.modelList.SetModelType(LargeModelType)
133			}
134		case key.Matches(msg, m.keyMap.Close):
135			return m, util.CmdHandler(dialogs.CloseDialogMsg{})
136		default:
137			u, cmd := m.modelList.Update(msg)
138			m.modelList = u
139			return m, cmd
140		}
141	}
142	return m, nil
143}
144
145func (m *modelDialogCmp) View() tea.View {
146	t := styles.CurrentTheme()
147	listView := m.modelList.View()
148	radio := m.modelTypeRadio()
149	content := lipgloss.JoinVertical(
150		lipgloss.Left,
151		t.S().Base.Padding(0, 1, 1, 1).Render(core.Title("Switch Model", m.width-lipgloss.Width(radio)-5)+" "+radio),
152		listView.String(),
153		"",
154		t.S().Base.Width(m.width-2).PaddingLeft(1).AlignHorizontal(lipgloss.Left).Render(m.help.View(m.keyMap)),
155	)
156	v := tea.NewView(m.style().Render(content))
157	if listView.Cursor() != nil {
158		c := m.moveCursor(listView.Cursor())
159		v.SetCursor(c)
160	}
161	return v
162}
163
164func (m *modelDialogCmp) style() lipgloss.Style {
165	t := styles.CurrentTheme()
166	return t.S().Base.
167		Width(m.width).
168		Border(lipgloss.RoundedBorder()).
169		BorderForeground(t.BorderFocus)
170}
171
172func (m *modelDialogCmp) listWidth() int {
173	return defaultWidth - 2 // 4 for padding
174}
175
176func (m *modelDialogCmp) listHeight() int {
177	items := m.modelList.Items()
178	listHeigh := len(items) + 2 + 4
179	return min(listHeigh, m.wHeight/2)
180}
181
182func (m *modelDialogCmp) Position() (int, int) {
183	row := m.wHeight/4 - 2 // just a bit above the center
184	col := m.wWidth / 2
185	col -= m.width / 2
186	return row, col
187}
188
189func (m *modelDialogCmp) moveCursor(cursor *tea.Cursor) *tea.Cursor {
190	row, col := m.Position()
191	offset := row + 3 // Border + title
192	cursor.Y += offset
193	cursor.X = cursor.X + col + 2
194	return cursor
195}
196
197func (m *modelDialogCmp) ID() dialogs.DialogID {
198	return ModelsDialogID
199}
200
201func (m *modelDialogCmp) modelTypeRadio() string {
202	t := styles.CurrentTheme()
203	choices := []string{"Large Task", "Small Task"}
204	iconSelected := "◉"
205	iconUnselected := "○"
206	if m.modelList.GetModelType() == LargeModelType {
207		return t.S().Base.Foreground(t.FgHalfMuted).Render(iconSelected + " " + choices[0] + "  " + iconUnselected + " " + choices[1])
208	}
209	return t.S().Base.Foreground(t.FgHalfMuted).Render(iconUnselected + " " + choices[0] + "  " + iconSelected + " " + choices[1])
210}