Detailed changes
@@ -5,14 +5,24 @@ import (
)
type KeyMap struct {
- Cancel key.Binding
+ Select,
+ Next,
+ Previous key.Binding
}
func DefaultKeyMap() KeyMap {
return KeyMap{
- Cancel: key.NewBinding(
- key.WithKeys("esc"),
- key.WithHelp("esc", "cancel"),
+ Select: key.NewBinding(
+ key.WithKeys("enter", "ctrl+y"),
+ key.WithHelp("enter", "confirm"),
+ ),
+ Next: key.NewBinding(
+ key.WithKeys("down", "ctrl+n"),
+ key.WithHelp("↓", "next item"),
+ ),
+ Previous: key.NewBinding(
+ key.WithKeys("up", "ctrl+p"),
+ key.WithHelp("↑", "previous item"),
),
}
}
@@ -3,7 +3,10 @@ package splash
import (
"github.com/charmbracelet/bubbles/v2/key"
tea "github.com/charmbracelet/bubbletea/v2"
+ "github.com/charmbracelet/crush/internal/config"
"github.com/charmbracelet/crush/internal/tui/components/core/layout"
+ "github.com/charmbracelet/crush/internal/tui/components/core/list"
+ "github.com/charmbracelet/crush/internal/tui/components/dialogs/models"
"github.com/charmbracelet/crush/internal/tui/components/logo"
"github.com/charmbracelet/crush/internal/tui/styles"
"github.com/charmbracelet/crush/internal/tui/util"
@@ -22,18 +25,48 @@ const (
SplashScreenPaddingY = 1 // Padding Y for the splash screen
)
+type SplashScreenState string
+
+const (
+ SplashScreenStateOnboarding SplashScreenState = "onboarding"
+ SplashScreenStateInitialize SplashScreenState = "initialize"
+ SplashScreenStateReady SplashScreenState = "ready"
+)
+
+// OnboardingCompleteMsg is sent when onboarding is complete
+type OnboardingCompleteMsg struct{}
+
type splashCmp struct {
- width, height int
- keyMap KeyMap
- logoRendered string
+ width, height int
+ keyMap KeyMap
+ logoRendered string
+ state SplashScreenState
+ modelList *models.ModelListComponent
+ cursorRow, cursorCol int
}
func New() Splash {
+ keyMap := DefaultKeyMap()
+ listKeyMap := list.DefaultKeyMap()
+ listKeyMap.Down.SetEnabled(false)
+ listKeyMap.Up.SetEnabled(false)
+ listKeyMap.HalfPageDown.SetEnabled(false)
+ listKeyMap.HalfPageUp.SetEnabled(false)
+ listKeyMap.Home.SetEnabled(false)
+ listKeyMap.End.SetEnabled(false)
+ listKeyMap.DownOneItem = keyMap.Next
+ listKeyMap.UpOneItem = keyMap.Previous
+
+ t := styles.CurrentTheme()
+ inputStyle := t.S().Base.Padding(0, 1, 0, 1)
+ modelList := models.NewModelListComponent(listKeyMap, inputStyle)
return &splashCmp{
width: 0,
height: 0,
- keyMap: DefaultKeyMap(),
+ keyMap: keyMap,
+ state: SplashScreenStateOnboarding,
logoRendered: "",
+ modelList: modelList,
}
}
@@ -44,7 +77,14 @@ func (s *splashCmp) GetSize() (int, int) {
// Init implements SplashPage.
func (s *splashCmp) Init() tea.Cmd {
- return nil
+ if config.HasInitialDataConfig() {
+ if b, _ := config.ProjectNeedsInitialization(); b {
+ s.state = SplashScreenStateInitialize
+ } else {
+ s.state = SplashScreenStateReady
+ }
+ }
+ return s.modelList.Init()
}
// SetSize implements SplashPage.
@@ -52,7 +92,12 @@ func (s *splashCmp) SetSize(width int, height int) tea.Cmd {
s.width = width
s.height = height
s.logoRendered = s.logoBlock()
- return nil
+ listHeigh := min(40, height-(SplashScreenPaddingY*2)-lipgloss.Height(s.logoRendered)-2) // -1 for the title
+ listWidth := min(60, width-(SplashScreenPaddingX*2))
+
+ // Calculate the cursor position based on the height and logo size
+ s.cursorRow = height - listHeigh
+ return s.modelList.SetSize(listWidth, listHeigh)
}
// Update implements SplashPage.
@@ -60,6 +105,13 @@ func (s *splashCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
switch msg := msg.(type) {
case tea.WindowSizeMsg:
return s, s.SetSize(msg.Width, msg.Height)
+ case tea.KeyPressMsg:
+ switch {
+ default:
+ u, cmd := s.modelList.Update(msg)
+ s.modelList = u
+ return s, cmd
+ }
}
return s, nil
}
@@ -67,8 +119,34 @@ func (s *splashCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
// View implements SplashPage.
func (s *splashCmp) View() tea.View {
t := styles.CurrentTheme()
- content := lipgloss.JoinVertical(lipgloss.Left, s.logoRendered)
- return tea.NewView(
+ var cursor *tea.Cursor
+
+ var content string
+ switch s.state {
+ case SplashScreenStateOnboarding:
+ // Show logo and model selector
+ remainingHeight := s.height - lipgloss.Height(s.logoRendered) - (SplashScreenPaddingY * 2)
+ modelListView := s.modelList.View()
+ cursor = s.moveCursor(modelListView.Cursor())
+ modelSelector := t.S().Base.AlignVertical(lipgloss.Bottom).Height(remainingHeight).Render(
+ lipgloss.JoinVertical(
+ lipgloss.Left,
+ t.S().Base.PaddingLeft(1).Foreground(t.Primary).Render("Choose a Model"),
+ "",
+ modelListView.String(),
+ ),
+ )
+ content = lipgloss.JoinVertical(
+ lipgloss.Left,
+ s.logoRendered,
+ modelSelector,
+ )
+ default:
+ // Show just the logo for other states
+ content = s.logoRendered
+ }
+
+ view := tea.NewView(
t.S().Base.
Width(s.width).
Height(s.height).
@@ -76,10 +154,11 @@ func (s *splashCmp) View() tea.View {
PaddingLeft(SplashScreenPaddingX).
PaddingRight(SplashScreenPaddingX).
PaddingBottom(SplashScreenPaddingY).
- Render(
- content,
- ),
+ Render(content),
)
+
+ view.SetCursor(cursor)
+ return view
}
func (s *splashCmp) logoBlock() string {
@@ -95,9 +174,24 @@ func (s *splashCmp) logoBlock() string {
})
}
+func (m *splashCmp) moveCursor(cursor *tea.Cursor) *tea.Cursor {
+ if cursor == nil {
+ return nil
+ }
+ offset := m.cursorRow
+ cursor.Y += offset
+ cursor.X = cursor.X + 3 // 3 for padding
+ return cursor
+}
+
// Bindings implements SplashPage.
func (s *splashCmp) Bindings() []key.Binding {
- return []key.Binding{
- s.keyMap.Cancel,
+ if s.state == SplashScreenStateOnboarding {
+ return []key.Binding{
+ s.keyMap.Select,
+ s.keyMap.Next,
+ s.keyMap.Previous,
+ }
}
+ return []key.Binding{}
}
@@ -0,0 +1,159 @@
+package models
+
+import (
+ "slices"
+
+ tea "github.com/charmbracelet/bubbletea/v2"
+ "github.com/charmbracelet/crush/internal/config"
+ "github.com/charmbracelet/crush/internal/fur/provider"
+ "github.com/charmbracelet/crush/internal/tui/components/completions"
+ "github.com/charmbracelet/crush/internal/tui/components/core/list"
+ "github.com/charmbracelet/crush/internal/tui/components/dialogs/commands"
+ "github.com/charmbracelet/crush/internal/tui/util"
+ "github.com/charmbracelet/lipgloss/v2"
+)
+
+type ModelListComponent struct {
+ list list.ListModel
+ modelType int
+}
+
+func NewModelListComponent(keyMap list.KeyMap, inputStyle lipgloss.Style) *ModelListComponent {
+ modelList := list.New(
+ list.WithFilterable(true),
+ list.WithKeyMap(keyMap),
+ list.WithInputStyle(inputStyle),
+ list.WithWrapNavigation(true),
+ )
+
+ return &ModelListComponent{
+ list: modelList,
+ modelType: LargeModelType,
+ }
+}
+
+func (m *ModelListComponent) Init() tea.Cmd {
+ return tea.Batch(m.list.Init(), m.SetModelType(m.modelType))
+}
+
+func (m *ModelListComponent) Update(msg tea.Msg) (*ModelListComponent, tea.Cmd) {
+ u, cmd := m.list.Update(msg)
+ m.list = u.(list.ListModel)
+ return m, cmd
+}
+
+func (m *ModelListComponent) View() tea.View {
+ return m.list.View()
+}
+
+func (m *ModelListComponent) SetSize(width, height int) tea.Cmd {
+ return m.list.SetSize(width, height)
+}
+
+func (m *ModelListComponent) Items() []util.Model {
+ return m.list.Items()
+}
+
+func (m *ModelListComponent) SelectedIndex() int {
+ return m.list.SelectedIndex()
+}
+
+func (m *ModelListComponent) SetModelType(modelType int) tea.Cmd {
+ m.modelType = modelType
+
+ providers := config.Providers()
+ modelItems := []util.Model{}
+ selectIndex := 0
+
+ cfg := config.Get()
+ var currentModel config.PreferredModel
+ if m.modelType == LargeModelType {
+ currentModel = cfg.Models.Large
+ } else {
+ currentModel = cfg.Models.Small
+ }
+
+ addedProviders := make(map[provider.InferenceProvider]bool)
+
+ knownProviders := provider.KnownProviders()
+ for providerID, providerConfig := range cfg.Providers {
+ if providerConfig.Disabled {
+ continue
+ }
+
+ // Check if this provider is not in the known providers list
+ if !slices.Contains(knownProviders, providerID) {
+ configProvider := provider.Provider{
+ Name: string(providerID),
+ ID: providerID,
+ Models: make([]provider.Model, len(providerConfig.Models)),
+ }
+
+ for i, model := range providerConfig.Models {
+ configProvider.Models[i] = provider.Model{
+ ID: model.ID,
+ Name: model.Name,
+ CostPer1MIn: model.CostPer1MIn,
+ CostPer1MOut: model.CostPer1MOut,
+ CostPer1MInCached: model.CostPer1MInCached,
+ CostPer1MOutCached: model.CostPer1MOutCached,
+ ContextWindow: model.ContextWindow,
+ DefaultMaxTokens: model.DefaultMaxTokens,
+ CanReason: model.CanReason,
+ HasReasoningEffort: model.HasReasoningEffort,
+ DefaultReasoningEffort: model.ReasoningEffort,
+ SupportsImages: model.SupportsImages,
+ }
+ }
+
+ name := configProvider.Name
+ if name == "" {
+ name = string(configProvider.ID)
+ }
+ modelItems = append(modelItems, commands.NewItemSection(name))
+ for _, model := range configProvider.Models {
+ modelItems = append(modelItems, completions.NewCompletionItem(model.Name, ModelOption{
+ Provider: configProvider,
+ Model: model,
+ }))
+ if model.ID == currentModel.ModelID && configProvider.ID == currentModel.Provider {
+ selectIndex = len(modelItems) - 1
+ }
+ }
+ addedProviders[providerID] = true
+ }
+ }
+
+ for _, provider := range providers {
+ if addedProviders[provider.ID] {
+ continue
+ }
+
+ if providerConfig, exists := cfg.Providers[provider.ID]; exists && providerConfig.Disabled {
+ continue
+ }
+
+ name := provider.Name
+ if name == "" {
+ name = string(provider.ID)
+ }
+ modelItems = append(modelItems, commands.NewItemSection(name))
+ for _, model := range provider.Models {
+ modelItems = append(modelItems, completions.NewCompletionItem(model.Name, ModelOption{
+ Provider: provider,
+ Model: model,
+ }))
+ if model.ID == currentModel.ModelID && provider.ID == currentModel.Provider {
+ selectIndex = len(modelItems) - 1
+ }
+ }
+ }
+
+ return tea.Sequence(m.list.SetItems(modelItems), m.list.SetSelected(selectIndex))
+}
+
+// GetModelType returns the current model type
+func (m *ModelListComponent) GetModelType() int {
+ return m.modelType
+}
+
@@ -1,8 +1,6 @@
package models
import (
- "slices"
-
"github.com/charmbracelet/bubbles/v2/help"
"github.com/charmbracelet/bubbles/v2/key"
tea "github.com/charmbracelet/bubbletea/v2"
@@ -12,7 +10,6 @@ import (
"github.com/charmbracelet/crush/internal/tui/components/core"
"github.com/charmbracelet/crush/internal/tui/components/core/list"
"github.com/charmbracelet/crush/internal/tui/components/dialogs"
- "github.com/charmbracelet/crush/internal/tui/components/dialogs/commands"
"github.com/charmbracelet/crush/internal/tui/styles"
"github.com/charmbracelet/crush/internal/tui/util"
"github.com/charmbracelet/lipgloss/v2"
@@ -53,10 +50,9 @@ type modelDialogCmp struct {
wWidth int
wHeight int
- modelList list.ListModel
+ modelList *ModelListComponent
keyMap KeyMap
help help.Model
- modelType int
}
func NewModelDialogCmp() ModelDialog {
@@ -75,12 +71,7 @@ func NewModelDialogCmp() ModelDialog {
t := styles.CurrentTheme()
inputStyle := t.S().Base.Padding(0, 1, 0, 1)
- modelList := list.New(
- list.WithFilterable(true),
- list.WithKeyMap(listKeyMap),
- list.WithInputStyle(inputStyle),
- list.WithWrapNavigation(true),
- )
+ modelList := NewModelListComponent(listKeyMap, inputStyle)
help := help.New()
help.Styles = t.S().Help
@@ -89,12 +80,10 @@ func NewModelDialogCmp() ModelDialog {
width: defaultWidth,
keyMap: DefaultKeyMap(),
help: help,
- modelType: LargeModelType,
}
}
func (m *modelDialogCmp) Init() tea.Cmd {
- m.SetModelType(m.modelType)
return m.modelList.Init()
}
@@ -103,7 +92,6 @@ func (m *modelDialogCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
case tea.WindowSizeMsg:
m.wWidth = msg.Width
m.wHeight = msg.Height
- m.SetModelType(m.modelType)
return m, m.modelList.SetSize(m.listWidth(), m.listHeight())
case tea.KeyPressMsg:
switch {
@@ -116,8 +104,8 @@ func (m *modelDialogCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
selectedItem := items[selectedItemInx].(completions.CompletionItem).Value().(ModelOption)
var modelType config.SelectedModelType
- if m.modelType == LargeModelType {
- modelType = config.SelectedModelTypeLarge
+ if m.modelList.GetModelType() == LargeModelType {
+ modelType = config.LargeModel
} else {
modelType = config.SelectedModelTypeSmall
}
@@ -133,16 +121,16 @@ func (m *modelDialogCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
}),
)
case key.Matches(msg, m.keyMap.Tab):
- if m.modelType == LargeModelType {
- return m, m.SetModelType(SmallModelType)
+ if m.modelList.GetModelType() == LargeModelType {
+ return m, m.modelList.SetModelType(SmallModelType)
} else {
- return m, m.SetModelType(LargeModelType)
+ return m, m.modelList.SetModelType(LargeModelType)
}
case key.Matches(msg, m.keyMap.Close):
return m, util.CmdHandler(dialogs.CloseDialogMsg{})
default:
u, cmd := m.modelList.Update(msg)
- m.modelList = u.(list.ListModel)
+ m.modelList = u
return m, cmd
}
}
@@ -181,7 +169,8 @@ func (m *modelDialogCmp) listWidth() int {
}
func (m *modelDialogCmp) listHeight() int {
- listHeigh := len(m.modelList.Items()) + 2 + 4 // height based on items + 2 for the input + 4 for the sections
+ items := m.modelList.Items()
+ listHeigh := len(items) + 2 + 4
return min(listHeigh, m.wHeight/2)
}
@@ -209,115 +198,8 @@ func (m *modelDialogCmp) modelTypeRadio() string {
choices := []string{"Large Task", "Small Task"}
iconSelected := "◉"
iconUnselected := "○"
- if m.modelType == LargeModelType {
+ if m.modelList.GetModelType() == LargeModelType {
return t.S().Base.Foreground(t.FgHalfMuted).Render(iconSelected + " " + choices[0] + " " + iconUnselected + " " + choices[1])
}
return t.S().Base.Foreground(t.FgHalfMuted).Render(iconUnselected + " " + choices[0] + " " + iconSelected + " " + choices[1])
}
-
-func (m *modelDialogCmp) SetModelType(modelType int) tea.Cmd {
- m.modelType = modelType
-
- providers, err := config.Providers()
- if err != nil {
- return util.ReportError(err)
- }
-
- modelItems := []util.Model{}
- selectIndex := 0
-
- cfg := config.Get()
- var currentModel config.SelectedModel
- if m.modelType == LargeModelType {
- currentModel = cfg.Models[config.SelectedModelTypeLarge]
- } else {
- currentModel = cfg.Models[config.SelectedModelTypeSmall]
- }
-
- // Create a map to track which providers we've already added
- addedProviders := make(map[string]bool)
-
- // First, add any configured providers that are not in the known providers list
- // These should appear at the top of the list
- knownProviders := provider.KnownProviders()
- for providerID, providerConfig := range cfg.Providers {
- if providerConfig.Disable {
- continue
- }
-
- // Check if this provider is not in the known providers list
- if !slices.Contains(knownProviders, provider.InferenceProvider(providerID)) {
- // Convert config provider to provider.Provider format
- configProvider := provider.Provider{
- Name: string(providerID), // Use provider ID as name for unknown providers
- ID: provider.InferenceProvider(providerID),
- Models: make([]provider.Model, len(providerConfig.Models)),
- }
-
- // Convert models
- for i, model := range providerConfig.Models {
- configProvider.Models[i] = provider.Model{
- ID: model.ID,
- Name: model.Name,
- CostPer1MIn: model.CostPer1MIn,
- CostPer1MOut: model.CostPer1MOut,
- CostPer1MInCached: model.CostPer1MInCached,
- CostPer1MOutCached: model.CostPer1MOutCached,
- ContextWindow: model.ContextWindow,
- DefaultMaxTokens: model.DefaultMaxTokens,
- CanReason: model.CanReason,
- HasReasoningEffort: model.HasReasoningEffort,
- DefaultReasoningEffort: model.DefaultReasoningEffort,
- SupportsImages: model.SupportsImages,
- }
- }
-
- // Add this unknown provider to the list
- name := configProvider.Name
- if name == "" {
- name = string(configProvider.ID)
- }
- modelItems = append(modelItems, commands.NewItemSection(name))
- for _, model := range configProvider.Models {
- modelItems = append(modelItems, completions.NewCompletionItem(model.Name, ModelOption{
- Provider: configProvider,
- Model: model,
- }))
- if model.ID == currentModel.Model && string(configProvider.ID) == currentModel.Provider {
- selectIndex = len(modelItems) - 1 // Set the selected index to the current model
- }
- }
- addedProviders[providerID] = true
- }
- }
-
- // Then add the known providers from the predefined list
- for _, provider := range providers {
- // Skip if we already added this provider as an unknown provider
- if addedProviders[string(provider.ID)] {
- continue
- }
-
- // Check if this provider is configured and not disabled
- if providerConfig, exists := cfg.Providers[string(provider.ID)]; exists && providerConfig.Disable {
- continue
- }
-
- name := provider.Name
- if name == "" {
- name = string(provider.ID)
- }
- modelItems = append(modelItems, commands.NewItemSection(name))
- for _, model := range provider.Models {
- modelItems = append(modelItems, completions.NewCompletionItem(model.Name, ModelOption{
- Provider: provider,
- Model: model,
- }))
- if model.ID == currentModel.Model && string(provider.ID) == currentModel.Provider {
- selectIndex = len(modelItems) - 1 // Set the selected index to the current model
- }
- }
- }
-
- return tea.Sequence(m.modelList.SetItems(modelItems), m.modelList.SetSelected(selectIndex))
-}
@@ -109,17 +109,18 @@ func New(app *app.App) ChatPage {
keyMap: DefaultKeyMap(),
- header: header.New(app.LSPClients),
- sidebar: sidebar.New(app.History, app.LSPClients, false),
- chat: chat.New(app),
- editor: editor.New(app),
- splash: splash.New(),
+ header: header.New(app.LSPClients),
+ sidebar: sidebar.New(app.History, app.LSPClients, false),
+ chat: chat.New(app),
+ editor: editor.New(app),
+ splash: splash.New(),
+ focusedPane: PanelTypeSplash,
}
}
func (p *chatPage) Init() tea.Cmd {
cfg := config.Get()
- if cfg.IsReady() {
+ if config.HasInitialDataConfig() {
if b, _ := config.ProjectNeedsInitialization(); b {
p.state = ChatStateInitProject
} else {
@@ -248,6 +249,10 @@ func (p *chatPage) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
u, cmd := p.editor.Update(msg)
p.editor = u.(editor.Editor)
cmds = append(cmds, cmd)
+ case PanelTypeSplash:
+ u, cmd := p.splash.Update(msg)
+ p.splash = u.(splash.Splash)
+ cmds = append(cmds, cmd)
}
}
return p, tea.Batch(cmds...)
@@ -258,11 +263,7 @@ func (p *chatPage) View() tea.View {
t := styles.CurrentTheme()
switch p.state {
case ChatStateOnboarding, ChatStateInitProject:
- chatView = tea.NewView(
- t.S().Base.Render(
- p.splash.View().String(),
- ),
- )
+ chatView = p.splash.View()
case ChatStateNewMessage:
editorView := p.editor.View()
chatView = tea.NewView(