diff --git a/internal/tui/components/chat/splash/keys.go b/internal/tui/components/chat/splash/keys.go index df715c89e86971a0f788915737bf41a212c65b5a..2a90441da52924f49f187c88744f9ea01c80e745 100644 --- a/internal/tui/components/chat/splash/keys.go +++ b/internal/tui/components/chat/splash/keys.go @@ -5,14 +5,24 @@ import ( ) type KeyMap struct { - Cancel key.Binding + Select, + Next, + Previous key.Binding } func DefaultKeyMap() KeyMap { return KeyMap{ - Cancel: key.NewBinding( - key.WithKeys("esc"), - key.WithHelp("esc", "cancel"), + Select: key.NewBinding( + key.WithKeys("enter", "ctrl+y"), + key.WithHelp("enter", "confirm"), + ), + Next: key.NewBinding( + key.WithKeys("down", "ctrl+n"), + key.WithHelp("↓", "next item"), + ), + Previous: key.NewBinding( + key.WithKeys("up", "ctrl+p"), + key.WithHelp("↑", "previous item"), ), } } diff --git a/internal/tui/components/chat/splash/splash.go b/internal/tui/components/chat/splash/splash.go index 34828baf4c0b91cb71f4495f9e436a13d75ffe46..7cbc7b5fd371aa9db9cec3613ca40377a70c5d77 100644 --- a/internal/tui/components/chat/splash/splash.go +++ b/internal/tui/components/chat/splash/splash.go @@ -3,7 +3,10 @@ package splash import ( "github.com/charmbracelet/bubbles/v2/key" tea "github.com/charmbracelet/bubbletea/v2" + "github.com/charmbracelet/crush/internal/config" "github.com/charmbracelet/crush/internal/tui/components/core/layout" + "github.com/charmbracelet/crush/internal/tui/components/core/list" + "github.com/charmbracelet/crush/internal/tui/components/dialogs/models" "github.com/charmbracelet/crush/internal/tui/components/logo" "github.com/charmbracelet/crush/internal/tui/styles" "github.com/charmbracelet/crush/internal/tui/util" @@ -22,18 +25,48 @@ const ( SplashScreenPaddingY = 1 // Padding Y for the splash screen ) +type SplashScreenState string + +const ( + SplashScreenStateOnboarding SplashScreenState = "onboarding" + SplashScreenStateInitialize SplashScreenState = "initialize" + SplashScreenStateReady SplashScreenState = "ready" +) + +// OnboardingCompleteMsg is sent when onboarding is complete +type OnboardingCompleteMsg struct{} + type splashCmp struct { - width, height int - keyMap KeyMap - logoRendered string + width, height int + keyMap KeyMap + logoRendered string + state SplashScreenState + modelList *models.ModelListComponent + cursorRow, cursorCol int } func New() Splash { + keyMap := DefaultKeyMap() + listKeyMap := list.DefaultKeyMap() + listKeyMap.Down.SetEnabled(false) + listKeyMap.Up.SetEnabled(false) + listKeyMap.HalfPageDown.SetEnabled(false) + listKeyMap.HalfPageUp.SetEnabled(false) + listKeyMap.Home.SetEnabled(false) + listKeyMap.End.SetEnabled(false) + listKeyMap.DownOneItem = keyMap.Next + listKeyMap.UpOneItem = keyMap.Previous + + t := styles.CurrentTheme() + inputStyle := t.S().Base.Padding(0, 1, 0, 1) + modelList := models.NewModelListComponent(listKeyMap, inputStyle) return &splashCmp{ width: 0, height: 0, - keyMap: DefaultKeyMap(), + keyMap: keyMap, + state: SplashScreenStateOnboarding, logoRendered: "", + modelList: modelList, } } @@ -44,7 +77,14 @@ func (s *splashCmp) GetSize() (int, int) { // Init implements SplashPage. func (s *splashCmp) Init() tea.Cmd { - return nil + if config.HasInitialDataConfig() { + if b, _ := config.ProjectNeedsInitialization(); b { + s.state = SplashScreenStateInitialize + } else { + s.state = SplashScreenStateReady + } + } + return s.modelList.Init() } // SetSize implements SplashPage. @@ -52,7 +92,12 @@ func (s *splashCmp) SetSize(width int, height int) tea.Cmd { s.width = width s.height = height s.logoRendered = s.logoBlock() - return nil + listHeigh := min(40, height-(SplashScreenPaddingY*2)-lipgloss.Height(s.logoRendered)-2) // -1 for the title + listWidth := min(60, width-(SplashScreenPaddingX*2)) + + // Calculate the cursor position based on the height and logo size + s.cursorRow = height - listHeigh + return s.modelList.SetSize(listWidth, listHeigh) } // Update implements SplashPage. @@ -60,6 +105,13 @@ func (s *splashCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) { switch msg := msg.(type) { case tea.WindowSizeMsg: return s, s.SetSize(msg.Width, msg.Height) + case tea.KeyPressMsg: + switch { + default: + u, cmd := s.modelList.Update(msg) + s.modelList = u + return s, cmd + } } return s, nil } @@ -67,8 +119,34 @@ func (s *splashCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) { // View implements SplashPage. func (s *splashCmp) View() tea.View { t := styles.CurrentTheme() - content := lipgloss.JoinVertical(lipgloss.Left, s.logoRendered) - return tea.NewView( + var cursor *tea.Cursor + + var content string + switch s.state { + case SplashScreenStateOnboarding: + // Show logo and model selector + remainingHeight := s.height - lipgloss.Height(s.logoRendered) - (SplashScreenPaddingY * 2) + modelListView := s.modelList.View() + cursor = s.moveCursor(modelListView.Cursor()) + modelSelector := t.S().Base.AlignVertical(lipgloss.Bottom).Height(remainingHeight).Render( + lipgloss.JoinVertical( + lipgloss.Left, + t.S().Base.PaddingLeft(1).Foreground(t.Primary).Render("Choose a Model"), + "", + modelListView.String(), + ), + ) + content = lipgloss.JoinVertical( + lipgloss.Left, + s.logoRendered, + modelSelector, + ) + default: + // Show just the logo for other states + content = s.logoRendered + } + + view := tea.NewView( t.S().Base. Width(s.width). Height(s.height). @@ -76,10 +154,11 @@ func (s *splashCmp) View() tea.View { PaddingLeft(SplashScreenPaddingX). PaddingRight(SplashScreenPaddingX). PaddingBottom(SplashScreenPaddingY). - Render( - content, - ), + Render(content), ) + + view.SetCursor(cursor) + return view } func (s *splashCmp) logoBlock() string { @@ -95,9 +174,24 @@ func (s *splashCmp) logoBlock() string { }) } +func (m *splashCmp) moveCursor(cursor *tea.Cursor) *tea.Cursor { + if cursor == nil { + return nil + } + offset := m.cursorRow + cursor.Y += offset + cursor.X = cursor.X + 3 // 3 for padding + return cursor +} + // Bindings implements SplashPage. func (s *splashCmp) Bindings() []key.Binding { - return []key.Binding{ - s.keyMap.Cancel, + if s.state == SplashScreenStateOnboarding { + return []key.Binding{ + s.keyMap.Select, + s.keyMap.Next, + s.keyMap.Previous, + } } + return []key.Binding{} } diff --git a/internal/tui/components/dialogs/models/list.go b/internal/tui/components/dialogs/models/list.go new file mode 100644 index 0000000000000000000000000000000000000000..e9da21de7725b6cb903f6d3ff7777e1b90e01a69 --- /dev/null +++ b/internal/tui/components/dialogs/models/list.go @@ -0,0 +1,159 @@ +package models + +import ( + "slices" + + tea "github.com/charmbracelet/bubbletea/v2" + "github.com/charmbracelet/crush/internal/config" + "github.com/charmbracelet/crush/internal/fur/provider" + "github.com/charmbracelet/crush/internal/tui/components/completions" + "github.com/charmbracelet/crush/internal/tui/components/core/list" + "github.com/charmbracelet/crush/internal/tui/components/dialogs/commands" + "github.com/charmbracelet/crush/internal/tui/util" + "github.com/charmbracelet/lipgloss/v2" +) + +type ModelListComponent struct { + list list.ListModel + modelType int +} + +func NewModelListComponent(keyMap list.KeyMap, inputStyle lipgloss.Style) *ModelListComponent { + modelList := list.New( + list.WithFilterable(true), + list.WithKeyMap(keyMap), + list.WithInputStyle(inputStyle), + list.WithWrapNavigation(true), + ) + + return &ModelListComponent{ + list: modelList, + modelType: LargeModelType, + } +} + +func (m *ModelListComponent) Init() tea.Cmd { + return tea.Batch(m.list.Init(), m.SetModelType(m.modelType)) +} + +func (m *ModelListComponent) Update(msg tea.Msg) (*ModelListComponent, tea.Cmd) { + u, cmd := m.list.Update(msg) + m.list = u.(list.ListModel) + return m, cmd +} + +func (m *ModelListComponent) View() tea.View { + return m.list.View() +} + +func (m *ModelListComponent) SetSize(width, height int) tea.Cmd { + return m.list.SetSize(width, height) +} + +func (m *ModelListComponent) Items() []util.Model { + return m.list.Items() +} + +func (m *ModelListComponent) SelectedIndex() int { + return m.list.SelectedIndex() +} + +func (m *ModelListComponent) SetModelType(modelType int) tea.Cmd { + m.modelType = modelType + + providers := config.Providers() + modelItems := []util.Model{} + selectIndex := 0 + + cfg := config.Get() + var currentModel config.PreferredModel + if m.modelType == LargeModelType { + currentModel = cfg.Models.Large + } else { + currentModel = cfg.Models.Small + } + + addedProviders := make(map[provider.InferenceProvider]bool) + + knownProviders := provider.KnownProviders() + for providerID, providerConfig := range cfg.Providers { + if providerConfig.Disabled { + continue + } + + // Check if this provider is not in the known providers list + if !slices.Contains(knownProviders, providerID) { + configProvider := provider.Provider{ + Name: string(providerID), + ID: providerID, + Models: make([]provider.Model, len(providerConfig.Models)), + } + + for i, model := range providerConfig.Models { + configProvider.Models[i] = provider.Model{ + ID: model.ID, + Name: model.Name, + CostPer1MIn: model.CostPer1MIn, + CostPer1MOut: model.CostPer1MOut, + CostPer1MInCached: model.CostPer1MInCached, + CostPer1MOutCached: model.CostPer1MOutCached, + ContextWindow: model.ContextWindow, + DefaultMaxTokens: model.DefaultMaxTokens, + CanReason: model.CanReason, + HasReasoningEffort: model.HasReasoningEffort, + DefaultReasoningEffort: model.ReasoningEffort, + SupportsImages: model.SupportsImages, + } + } + + name := configProvider.Name + if name == "" { + name = string(configProvider.ID) + } + modelItems = append(modelItems, commands.NewItemSection(name)) + for _, model := range configProvider.Models { + modelItems = append(modelItems, completions.NewCompletionItem(model.Name, ModelOption{ + Provider: configProvider, + Model: model, + })) + if model.ID == currentModel.ModelID && configProvider.ID == currentModel.Provider { + selectIndex = len(modelItems) - 1 + } + } + addedProviders[providerID] = true + } + } + + for _, provider := range providers { + if addedProviders[provider.ID] { + continue + } + + if providerConfig, exists := cfg.Providers[provider.ID]; exists && providerConfig.Disabled { + continue + } + + name := provider.Name + if name == "" { + name = string(provider.ID) + } + modelItems = append(modelItems, commands.NewItemSection(name)) + for _, model := range provider.Models { + modelItems = append(modelItems, completions.NewCompletionItem(model.Name, ModelOption{ + Provider: provider, + Model: model, + })) + if model.ID == currentModel.ModelID && provider.ID == currentModel.Provider { + selectIndex = len(modelItems) - 1 + } + } + } + + return tea.Sequence(m.list.SetItems(modelItems), m.list.SetSelected(selectIndex)) +} + +// GetModelType returns the current model type +func (m *ModelListComponent) GetModelType() int { + return m.modelType +} + diff --git a/internal/tui/components/dialogs/models/models.go b/internal/tui/components/dialogs/models/models.go index dc82e2fa1c745fc46f14895680c93d30864f317a..cab864ef1dafd67fe2b5f6933e480f37fdb19498 100644 --- a/internal/tui/components/dialogs/models/models.go +++ b/internal/tui/components/dialogs/models/models.go @@ -1,8 +1,6 @@ package models import ( - "slices" - "github.com/charmbracelet/bubbles/v2/help" "github.com/charmbracelet/bubbles/v2/key" tea "github.com/charmbracelet/bubbletea/v2" @@ -12,7 +10,6 @@ import ( "github.com/charmbracelet/crush/internal/tui/components/core" "github.com/charmbracelet/crush/internal/tui/components/core/list" "github.com/charmbracelet/crush/internal/tui/components/dialogs" - "github.com/charmbracelet/crush/internal/tui/components/dialogs/commands" "github.com/charmbracelet/crush/internal/tui/styles" "github.com/charmbracelet/crush/internal/tui/util" "github.com/charmbracelet/lipgloss/v2" @@ -53,10 +50,9 @@ type modelDialogCmp struct { wWidth int wHeight int - modelList list.ListModel + modelList *ModelListComponent keyMap KeyMap help help.Model - modelType int } func NewModelDialogCmp() ModelDialog { @@ -75,12 +71,7 @@ func NewModelDialogCmp() ModelDialog { t := styles.CurrentTheme() inputStyle := t.S().Base.Padding(0, 1, 0, 1) - modelList := list.New( - list.WithFilterable(true), - list.WithKeyMap(listKeyMap), - list.WithInputStyle(inputStyle), - list.WithWrapNavigation(true), - ) + modelList := NewModelListComponent(listKeyMap, inputStyle) help := help.New() help.Styles = t.S().Help @@ -89,12 +80,10 @@ func NewModelDialogCmp() ModelDialog { width: defaultWidth, keyMap: DefaultKeyMap(), help: help, - modelType: LargeModelType, } } func (m *modelDialogCmp) Init() tea.Cmd { - m.SetModelType(m.modelType) return m.modelList.Init() } @@ -103,7 +92,6 @@ func (m *modelDialogCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) { case tea.WindowSizeMsg: m.wWidth = msg.Width m.wHeight = msg.Height - m.SetModelType(m.modelType) return m, m.modelList.SetSize(m.listWidth(), m.listHeight()) case tea.KeyPressMsg: switch { @@ -116,8 +104,8 @@ func (m *modelDialogCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) { selectedItem := items[selectedItemInx].(completions.CompletionItem).Value().(ModelOption) var modelType config.SelectedModelType - if m.modelType == LargeModelType { - modelType = config.SelectedModelTypeLarge + if m.modelList.GetModelType() == LargeModelType { + modelType = config.LargeModel } else { modelType = config.SelectedModelTypeSmall } @@ -133,16 +121,16 @@ func (m *modelDialogCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) { }), ) case key.Matches(msg, m.keyMap.Tab): - if m.modelType == LargeModelType { - return m, m.SetModelType(SmallModelType) + if m.modelList.GetModelType() == LargeModelType { + return m, m.modelList.SetModelType(SmallModelType) } else { - return m, m.SetModelType(LargeModelType) + return m, m.modelList.SetModelType(LargeModelType) } case key.Matches(msg, m.keyMap.Close): return m, util.CmdHandler(dialogs.CloseDialogMsg{}) default: u, cmd := m.modelList.Update(msg) - m.modelList = u.(list.ListModel) + m.modelList = u return m, cmd } } @@ -181,7 +169,8 @@ func (m *modelDialogCmp) listWidth() int { } func (m *modelDialogCmp) listHeight() int { - listHeigh := len(m.modelList.Items()) + 2 + 4 // height based on items + 2 for the input + 4 for the sections + items := m.modelList.Items() + listHeigh := len(items) + 2 + 4 return min(listHeigh, m.wHeight/2) } @@ -209,115 +198,8 @@ func (m *modelDialogCmp) modelTypeRadio() string { choices := []string{"Large Task", "Small Task"} iconSelected := "◉" iconUnselected := "○" - if m.modelType == LargeModelType { + if m.modelList.GetModelType() == LargeModelType { return t.S().Base.Foreground(t.FgHalfMuted).Render(iconSelected + " " + choices[0] + " " + iconUnselected + " " + choices[1]) } return t.S().Base.Foreground(t.FgHalfMuted).Render(iconUnselected + " " + choices[0] + " " + iconSelected + " " + choices[1]) } - -func (m *modelDialogCmp) SetModelType(modelType int) tea.Cmd { - m.modelType = modelType - - providers, err := config.Providers() - if err != nil { - return util.ReportError(err) - } - - modelItems := []util.Model{} - selectIndex := 0 - - cfg := config.Get() - var currentModel config.SelectedModel - if m.modelType == LargeModelType { - currentModel = cfg.Models[config.SelectedModelTypeLarge] - } else { - currentModel = cfg.Models[config.SelectedModelTypeSmall] - } - - // Create a map to track which providers we've already added - addedProviders := make(map[string]bool) - - // First, add any configured providers that are not in the known providers list - // These should appear at the top of the list - knownProviders := provider.KnownProviders() - for providerID, providerConfig := range cfg.Providers { - if providerConfig.Disable { - continue - } - - // Check if this provider is not in the known providers list - if !slices.Contains(knownProviders, provider.InferenceProvider(providerID)) { - // Convert config provider to provider.Provider format - configProvider := provider.Provider{ - Name: string(providerID), // Use provider ID as name for unknown providers - ID: provider.InferenceProvider(providerID), - Models: make([]provider.Model, len(providerConfig.Models)), - } - - // Convert models - for i, model := range providerConfig.Models { - configProvider.Models[i] = provider.Model{ - ID: model.ID, - Name: model.Name, - CostPer1MIn: model.CostPer1MIn, - CostPer1MOut: model.CostPer1MOut, - CostPer1MInCached: model.CostPer1MInCached, - CostPer1MOutCached: model.CostPer1MOutCached, - ContextWindow: model.ContextWindow, - DefaultMaxTokens: model.DefaultMaxTokens, - CanReason: model.CanReason, - HasReasoningEffort: model.HasReasoningEffort, - DefaultReasoningEffort: model.DefaultReasoningEffort, - SupportsImages: model.SupportsImages, - } - } - - // Add this unknown provider to the list - name := configProvider.Name - if name == "" { - name = string(configProvider.ID) - } - modelItems = append(modelItems, commands.NewItemSection(name)) - for _, model := range configProvider.Models { - modelItems = append(modelItems, completions.NewCompletionItem(model.Name, ModelOption{ - Provider: configProvider, - Model: model, - })) - if model.ID == currentModel.Model && string(configProvider.ID) == currentModel.Provider { - selectIndex = len(modelItems) - 1 // Set the selected index to the current model - } - } - addedProviders[providerID] = true - } - } - - // Then add the known providers from the predefined list - for _, provider := range providers { - // Skip if we already added this provider as an unknown provider - if addedProviders[string(provider.ID)] { - continue - } - - // Check if this provider is configured and not disabled - if providerConfig, exists := cfg.Providers[string(provider.ID)]; exists && providerConfig.Disable { - continue - } - - name := provider.Name - if name == "" { - name = string(provider.ID) - } - modelItems = append(modelItems, commands.NewItemSection(name)) - for _, model := range provider.Models { - modelItems = append(modelItems, completions.NewCompletionItem(model.Name, ModelOption{ - Provider: provider, - Model: model, - })) - if model.ID == currentModel.Model && string(provider.ID) == currentModel.Provider { - selectIndex = len(modelItems) - 1 // Set the selected index to the current model - } - } - } - - return tea.Sequence(m.modelList.SetItems(modelItems), m.modelList.SetSelected(selectIndex)) -} diff --git a/internal/tui/page/chat/chat.go b/internal/tui/page/chat/chat.go index 39fee8e09b731a315338656b30bac692e58c9af9..daf6881426a6d1b59020ee98e06fb44d6c28f17e 100644 --- a/internal/tui/page/chat/chat.go +++ b/internal/tui/page/chat/chat.go @@ -109,17 +109,18 @@ func New(app *app.App) ChatPage { keyMap: DefaultKeyMap(), - header: header.New(app.LSPClients), - sidebar: sidebar.New(app.History, app.LSPClients, false), - chat: chat.New(app), - editor: editor.New(app), - splash: splash.New(), + header: header.New(app.LSPClients), + sidebar: sidebar.New(app.History, app.LSPClients, false), + chat: chat.New(app), + editor: editor.New(app), + splash: splash.New(), + focusedPane: PanelTypeSplash, } } func (p *chatPage) Init() tea.Cmd { cfg := config.Get() - if cfg.IsReady() { + if config.HasInitialDataConfig() { if b, _ := config.ProjectNeedsInitialization(); b { p.state = ChatStateInitProject } else { @@ -248,6 +249,10 @@ func (p *chatPage) Update(msg tea.Msg) (tea.Model, tea.Cmd) { u, cmd := p.editor.Update(msg) p.editor = u.(editor.Editor) cmds = append(cmds, cmd) + case PanelTypeSplash: + u, cmd := p.splash.Update(msg) + p.splash = u.(splash.Splash) + cmds = append(cmds, cmd) } } return p, tea.Batch(cmds...) @@ -258,11 +263,7 @@ func (p *chatPage) View() tea.View { t := styles.CurrentTheme() switch p.state { case ChatStateOnboarding, ChatStateInitProject: - chatView = tea.NewView( - t.S().Base.Render( - p.splash.View().String(), - ), - ) + chatView = p.splash.View() case ChatStateNewMessage: editorView := p.editor.View() chatView = tea.NewView(