wip: oboarding splashscreen

Kujtim Hoxha created

Change summary

internal/tui/components/chat/splash/keys.go      |  18 +
internal/tui/components/chat/splash/splash.go    | 120 ++++++++++++-
internal/tui/components/dialogs/models/list.go   | 159 ++++++++++++++++++
internal/tui/components/dialogs/models/models.go | 140 +--------------
internal/tui/page/chat/chat.go                   |  23 +-
5 files changed, 303 insertions(+), 157 deletions(-)

Detailed changes

internal/tui/components/chat/splash/keys.go 🔗

@@ -5,14 +5,24 @@ import (
 )
 
 type KeyMap struct {
-	Cancel key.Binding
+	Select,
+	Next,
+	Previous key.Binding
 }
 
 func DefaultKeyMap() KeyMap {
 	return KeyMap{
-		Cancel: key.NewBinding(
-			key.WithKeys("esc"),
-			key.WithHelp("esc", "cancel"),
+		Select: key.NewBinding(
+			key.WithKeys("enter", "ctrl+y"),
+			key.WithHelp("enter", "confirm"),
+		),
+		Next: key.NewBinding(
+			key.WithKeys("down", "ctrl+n"),
+			key.WithHelp("↓", "next item"),
+		),
+		Previous: key.NewBinding(
+			key.WithKeys("up", "ctrl+p"),
+			key.WithHelp("↑", "previous item"),
 		),
 	}
 }

internal/tui/components/chat/splash/splash.go 🔗

@@ -3,7 +3,10 @@ package splash
 import (
 	"github.com/charmbracelet/bubbles/v2/key"
 	tea "github.com/charmbracelet/bubbletea/v2"
+	"github.com/charmbracelet/crush/internal/config"
 	"github.com/charmbracelet/crush/internal/tui/components/core/layout"
+	"github.com/charmbracelet/crush/internal/tui/components/core/list"
+	"github.com/charmbracelet/crush/internal/tui/components/dialogs/models"
 	"github.com/charmbracelet/crush/internal/tui/components/logo"
 	"github.com/charmbracelet/crush/internal/tui/styles"
 	"github.com/charmbracelet/crush/internal/tui/util"
@@ -22,18 +25,48 @@ const (
 	SplashScreenPaddingY = 1 // Padding Y for the splash screen
 )
 
+type SplashScreenState string
+
+const (
+	SplashScreenStateOnboarding SplashScreenState = "onboarding"
+	SplashScreenStateInitialize SplashScreenState = "initialize"
+	SplashScreenStateReady      SplashScreenState = "ready"
+)
+
+// OnboardingCompleteMsg is sent when onboarding is complete
+type OnboardingCompleteMsg struct{}
+
 type splashCmp struct {
-	width, height int
-	keyMap        KeyMap
-	logoRendered  string
+	width, height        int
+	keyMap               KeyMap
+	logoRendered         string
+	state                SplashScreenState
+	modelList            *models.ModelListComponent
+	cursorRow, cursorCol int
 }
 
 func New() Splash {
+	keyMap := DefaultKeyMap()
+	listKeyMap := list.DefaultKeyMap()
+	listKeyMap.Down.SetEnabled(false)
+	listKeyMap.Up.SetEnabled(false)
+	listKeyMap.HalfPageDown.SetEnabled(false)
+	listKeyMap.HalfPageUp.SetEnabled(false)
+	listKeyMap.Home.SetEnabled(false)
+	listKeyMap.End.SetEnabled(false)
+	listKeyMap.DownOneItem = keyMap.Next
+	listKeyMap.UpOneItem = keyMap.Previous
+
+	t := styles.CurrentTheme()
+	inputStyle := t.S().Base.Padding(0, 1, 0, 1)
+	modelList := models.NewModelListComponent(listKeyMap, inputStyle)
 	return &splashCmp{
 		width:        0,
 		height:       0,
-		keyMap:       DefaultKeyMap(),
+		keyMap:       keyMap,
+		state:        SplashScreenStateOnboarding,
 		logoRendered: "",
+		modelList:    modelList,
 	}
 }
 
@@ -44,7 +77,14 @@ func (s *splashCmp) GetSize() (int, int) {
 
 // Init implements SplashPage.
 func (s *splashCmp) Init() tea.Cmd {
-	return nil
+	if config.HasInitialDataConfig() {
+		if b, _ := config.ProjectNeedsInitialization(); b {
+			s.state = SplashScreenStateInitialize
+		} else {
+			s.state = SplashScreenStateReady
+		}
+	}
+	return s.modelList.Init()
 }
 
 // SetSize implements SplashPage.
@@ -52,7 +92,12 @@ func (s *splashCmp) SetSize(width int, height int) tea.Cmd {
 	s.width = width
 	s.height = height
 	s.logoRendered = s.logoBlock()
-	return nil
+	listHeigh := min(40, height-(SplashScreenPaddingY*2)-lipgloss.Height(s.logoRendered)-2) // -1 for the title
+	listWidth := min(60, width-(SplashScreenPaddingX*2))
+
+	// Calculate the cursor position based on the height and logo size
+	s.cursorRow = height - listHeigh
+	return s.modelList.SetSize(listWidth, listHeigh)
 }
 
 // Update implements SplashPage.
@@ -60,6 +105,13 @@ func (s *splashCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
 	switch msg := msg.(type) {
 	case tea.WindowSizeMsg:
 		return s, s.SetSize(msg.Width, msg.Height)
+	case tea.KeyPressMsg:
+		switch {
+		default:
+			u, cmd := s.modelList.Update(msg)
+			s.modelList = u
+			return s, cmd
+		}
 	}
 	return s, nil
 }
@@ -67,8 +119,34 @@ func (s *splashCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
 // View implements SplashPage.
 func (s *splashCmp) View() tea.View {
 	t := styles.CurrentTheme()
-	content := lipgloss.JoinVertical(lipgloss.Left, s.logoRendered)
-	return tea.NewView(
+	var cursor *tea.Cursor
+
+	var content string
+	switch s.state {
+	case SplashScreenStateOnboarding:
+		// Show logo and model selector
+		remainingHeight := s.height - lipgloss.Height(s.logoRendered) - (SplashScreenPaddingY * 2)
+		modelListView := s.modelList.View()
+		cursor = s.moveCursor(modelListView.Cursor())
+		modelSelector := t.S().Base.AlignVertical(lipgloss.Bottom).Height(remainingHeight).Render(
+			lipgloss.JoinVertical(
+				lipgloss.Left,
+				t.S().Base.PaddingLeft(1).Foreground(t.Primary).Render("Choose a Model"),
+				"",
+				modelListView.String(),
+			),
+		)
+		content = lipgloss.JoinVertical(
+			lipgloss.Left,
+			s.logoRendered,
+			modelSelector,
+		)
+	default:
+		// Show just the logo for other states
+		content = s.logoRendered
+	}
+
+	view := tea.NewView(
 		t.S().Base.
 			Width(s.width).
 			Height(s.height).
@@ -76,10 +154,11 @@ func (s *splashCmp) View() tea.View {
 			PaddingLeft(SplashScreenPaddingX).
 			PaddingRight(SplashScreenPaddingX).
 			PaddingBottom(SplashScreenPaddingY).
-			Render(
-				content,
-			),
+			Render(content),
 	)
+
+	view.SetCursor(cursor)
+	return view
 }
 
 func (s *splashCmp) logoBlock() string {
@@ -95,9 +174,24 @@ func (s *splashCmp) logoBlock() string {
 	})
 }
 
+func (m *splashCmp) moveCursor(cursor *tea.Cursor) *tea.Cursor {
+	if cursor == nil {
+		return nil
+	}
+	offset := m.cursorRow
+	cursor.Y += offset
+	cursor.X = cursor.X + 3 // 3 for padding
+	return cursor
+}
+
 // Bindings implements SplashPage.
 func (s *splashCmp) Bindings() []key.Binding {
-	return []key.Binding{
-		s.keyMap.Cancel,
+	if s.state == SplashScreenStateOnboarding {
+		return []key.Binding{
+			s.keyMap.Select,
+			s.keyMap.Next,
+			s.keyMap.Previous,
+		}
 	}
+	return []key.Binding{}
 }

internal/tui/components/dialogs/models/list.go 🔗

@@ -0,0 +1,159 @@
+package models
+
+import (
+	"slices"
+
+	tea "github.com/charmbracelet/bubbletea/v2"
+	"github.com/charmbracelet/crush/internal/config"
+	"github.com/charmbracelet/crush/internal/fur/provider"
+	"github.com/charmbracelet/crush/internal/tui/components/completions"
+	"github.com/charmbracelet/crush/internal/tui/components/core/list"
+	"github.com/charmbracelet/crush/internal/tui/components/dialogs/commands"
+	"github.com/charmbracelet/crush/internal/tui/util"
+	"github.com/charmbracelet/lipgloss/v2"
+)
+
+type ModelListComponent struct {
+	list      list.ListModel
+	modelType int
+}
+
+func NewModelListComponent(keyMap list.KeyMap, inputStyle lipgloss.Style) *ModelListComponent {
+	modelList := list.New(
+		list.WithFilterable(true),
+		list.WithKeyMap(keyMap),
+		list.WithInputStyle(inputStyle),
+		list.WithWrapNavigation(true),
+	)
+
+	return &ModelListComponent{
+		list:      modelList,
+		modelType: LargeModelType,
+	}
+}
+
+func (m *ModelListComponent) Init() tea.Cmd {
+	return tea.Batch(m.list.Init(), m.SetModelType(m.modelType))
+}
+
+func (m *ModelListComponent) Update(msg tea.Msg) (*ModelListComponent, tea.Cmd) {
+	u, cmd := m.list.Update(msg)
+	m.list = u.(list.ListModel)
+	return m, cmd
+}
+
+func (m *ModelListComponent) View() tea.View {
+	return m.list.View()
+}
+
+func (m *ModelListComponent) SetSize(width, height int) tea.Cmd {
+	return m.list.SetSize(width, height)
+}
+
+func (m *ModelListComponent) Items() []util.Model {
+	return m.list.Items()
+}
+
+func (m *ModelListComponent) SelectedIndex() int {
+	return m.list.SelectedIndex()
+}
+
+func (m *ModelListComponent) SetModelType(modelType int) tea.Cmd {
+	m.modelType = modelType
+
+	providers := config.Providers()
+	modelItems := []util.Model{}
+	selectIndex := 0
+
+	cfg := config.Get()
+	var currentModel config.PreferredModel
+	if m.modelType == LargeModelType {
+		currentModel = cfg.Models.Large
+	} else {
+		currentModel = cfg.Models.Small
+	}
+
+	addedProviders := make(map[provider.InferenceProvider]bool)
+
+	knownProviders := provider.KnownProviders()
+	for providerID, providerConfig := range cfg.Providers {
+		if providerConfig.Disabled {
+			continue
+		}
+
+		// Check if this provider is not in the known providers list
+		if !slices.Contains(knownProviders, providerID) {
+			configProvider := provider.Provider{
+				Name:   string(providerID),
+				ID:     providerID,
+				Models: make([]provider.Model, len(providerConfig.Models)),
+			}
+
+			for i, model := range providerConfig.Models {
+				configProvider.Models[i] = provider.Model{
+					ID:                     model.ID,
+					Name:                   model.Name,
+					CostPer1MIn:            model.CostPer1MIn,
+					CostPer1MOut:           model.CostPer1MOut,
+					CostPer1MInCached:      model.CostPer1MInCached,
+					CostPer1MOutCached:     model.CostPer1MOutCached,
+					ContextWindow:          model.ContextWindow,
+					DefaultMaxTokens:       model.DefaultMaxTokens,
+					CanReason:              model.CanReason,
+					HasReasoningEffort:     model.HasReasoningEffort,
+					DefaultReasoningEffort: model.ReasoningEffort,
+					SupportsImages:         model.SupportsImages,
+				}
+			}
+
+			name := configProvider.Name
+			if name == "" {
+				name = string(configProvider.ID)
+			}
+			modelItems = append(modelItems, commands.NewItemSection(name))
+			for _, model := range configProvider.Models {
+				modelItems = append(modelItems, completions.NewCompletionItem(model.Name, ModelOption{
+					Provider: configProvider,
+					Model:    model,
+				}))
+				if model.ID == currentModel.ModelID && configProvider.ID == currentModel.Provider {
+					selectIndex = len(modelItems) - 1
+				}
+			}
+			addedProviders[providerID] = true
+		}
+	}
+
+	for _, provider := range providers {
+		if addedProviders[provider.ID] {
+			continue
+		}
+
+		if providerConfig, exists := cfg.Providers[provider.ID]; exists && providerConfig.Disabled {
+			continue
+		}
+
+		name := provider.Name
+		if name == "" {
+			name = string(provider.ID)
+		}
+		modelItems = append(modelItems, commands.NewItemSection(name))
+		for _, model := range provider.Models {
+			modelItems = append(modelItems, completions.NewCompletionItem(model.Name, ModelOption{
+				Provider: provider,
+				Model:    model,
+			}))
+			if model.ID == currentModel.ModelID && provider.ID == currentModel.Provider {
+				selectIndex = len(modelItems) - 1
+			}
+		}
+	}
+
+	return tea.Sequence(m.list.SetItems(modelItems), m.list.SetSelected(selectIndex))
+}
+
+// GetModelType returns the current model type
+func (m *ModelListComponent) GetModelType() int {
+	return m.modelType
+}
+

internal/tui/components/dialogs/models/models.go 🔗

@@ -1,8 +1,6 @@
 package models
 
 import (
-	"slices"
-
 	"github.com/charmbracelet/bubbles/v2/help"
 	"github.com/charmbracelet/bubbles/v2/key"
 	tea "github.com/charmbracelet/bubbletea/v2"
@@ -12,7 +10,6 @@ import (
 	"github.com/charmbracelet/crush/internal/tui/components/core"
 	"github.com/charmbracelet/crush/internal/tui/components/core/list"
 	"github.com/charmbracelet/crush/internal/tui/components/dialogs"
-	"github.com/charmbracelet/crush/internal/tui/components/dialogs/commands"
 	"github.com/charmbracelet/crush/internal/tui/styles"
 	"github.com/charmbracelet/crush/internal/tui/util"
 	"github.com/charmbracelet/lipgloss/v2"
@@ -53,10 +50,9 @@ type modelDialogCmp struct {
 	wWidth  int
 	wHeight int
 
-	modelList list.ListModel
+	modelList *ModelListComponent
 	keyMap    KeyMap
 	help      help.Model
-	modelType int
 }
 
 func NewModelDialogCmp() ModelDialog {
@@ -75,12 +71,7 @@ func NewModelDialogCmp() ModelDialog {
 
 	t := styles.CurrentTheme()
 	inputStyle := t.S().Base.Padding(0, 1, 0, 1)
-	modelList := list.New(
-		list.WithFilterable(true),
-		list.WithKeyMap(listKeyMap),
-		list.WithInputStyle(inputStyle),
-		list.WithWrapNavigation(true),
-	)
+	modelList := NewModelListComponent(listKeyMap, inputStyle)
 	help := help.New()
 	help.Styles = t.S().Help
 
@@ -89,12 +80,10 @@ func NewModelDialogCmp() ModelDialog {
 		width:     defaultWidth,
 		keyMap:    DefaultKeyMap(),
 		help:      help,
-		modelType: LargeModelType,
 	}
 }
 
 func (m *modelDialogCmp) Init() tea.Cmd {
-	m.SetModelType(m.modelType)
 	return m.modelList.Init()
 }
 
@@ -103,7 +92,6 @@ func (m *modelDialogCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
 	case tea.WindowSizeMsg:
 		m.wWidth = msg.Width
 		m.wHeight = msg.Height
-		m.SetModelType(m.modelType)
 		return m, m.modelList.SetSize(m.listWidth(), m.listHeight())
 	case tea.KeyPressMsg:
 		switch {
@@ -116,8 +104,8 @@ func (m *modelDialogCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
 			selectedItem := items[selectedItemInx].(completions.CompletionItem).Value().(ModelOption)
 
 			var modelType config.SelectedModelType
-			if m.modelType == LargeModelType {
-				modelType = config.SelectedModelTypeLarge
+			if m.modelList.GetModelType() == LargeModelType {
+				modelType = config.LargeModel
 			} else {
 				modelType = config.SelectedModelTypeSmall
 			}
@@ -133,16 +121,16 @@ func (m *modelDialogCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
 				}),
 			)
 		case key.Matches(msg, m.keyMap.Tab):
-			if m.modelType == LargeModelType {
-				return m, m.SetModelType(SmallModelType)
+			if m.modelList.GetModelType() == LargeModelType {
+				return m, m.modelList.SetModelType(SmallModelType)
 			} else {
-				return m, m.SetModelType(LargeModelType)
+				return m, m.modelList.SetModelType(LargeModelType)
 			}
 		case key.Matches(msg, m.keyMap.Close):
 			return m, util.CmdHandler(dialogs.CloseDialogMsg{})
 		default:
 			u, cmd := m.modelList.Update(msg)
-			m.modelList = u.(list.ListModel)
+			m.modelList = u
 			return m, cmd
 		}
 	}
@@ -181,7 +169,8 @@ func (m *modelDialogCmp) listWidth() int {
 }
 
 func (m *modelDialogCmp) listHeight() int {
-	listHeigh := len(m.modelList.Items()) + 2 + 4 // height based on items + 2 for the input + 4 for the sections
+	items := m.modelList.Items()
+	listHeigh := len(items) + 2 + 4
 	return min(listHeigh, m.wHeight/2)
 }
 
@@ -209,115 +198,8 @@ func (m *modelDialogCmp) modelTypeRadio() string {
 	choices := []string{"Large Task", "Small Task"}
 	iconSelected := "◉"
 	iconUnselected := "○"
-	if m.modelType == LargeModelType {
+	if m.modelList.GetModelType() == LargeModelType {
 		return t.S().Base.Foreground(t.FgHalfMuted).Render(iconSelected + " " + choices[0] + "  " + iconUnselected + " " + choices[1])
 	}
 	return t.S().Base.Foreground(t.FgHalfMuted).Render(iconUnselected + " " + choices[0] + "  " + iconSelected + " " + choices[1])
 }
-
-func (m *modelDialogCmp) SetModelType(modelType int) tea.Cmd {
-	m.modelType = modelType
-
-	providers, err := config.Providers()
-	if err != nil {
-		return util.ReportError(err)
-	}
-
-	modelItems := []util.Model{}
-	selectIndex := 0
-
-	cfg := config.Get()
-	var currentModel config.SelectedModel
-	if m.modelType == LargeModelType {
-		currentModel = cfg.Models[config.SelectedModelTypeLarge]
-	} else {
-		currentModel = cfg.Models[config.SelectedModelTypeSmall]
-	}
-
-	// Create a map to track which providers we've already added
-	addedProviders := make(map[string]bool)
-
-	// First, add any configured providers that are not in the known providers list
-	// These should appear at the top of the list
-	knownProviders := provider.KnownProviders()
-	for providerID, providerConfig := range cfg.Providers {
-		if providerConfig.Disable {
-			continue
-		}
-
-		// Check if this provider is not in the known providers list
-		if !slices.Contains(knownProviders, provider.InferenceProvider(providerID)) {
-			// Convert config provider to provider.Provider format
-			configProvider := provider.Provider{
-				Name:   string(providerID), // Use provider ID as name for unknown providers
-				ID:     provider.InferenceProvider(providerID),
-				Models: make([]provider.Model, len(providerConfig.Models)),
-			}
-
-			// Convert models
-			for i, model := range providerConfig.Models {
-				configProvider.Models[i] = provider.Model{
-					ID:                     model.ID,
-					Name:                   model.Name,
-					CostPer1MIn:            model.CostPer1MIn,
-					CostPer1MOut:           model.CostPer1MOut,
-					CostPer1MInCached:      model.CostPer1MInCached,
-					CostPer1MOutCached:     model.CostPer1MOutCached,
-					ContextWindow:          model.ContextWindow,
-					DefaultMaxTokens:       model.DefaultMaxTokens,
-					CanReason:              model.CanReason,
-					HasReasoningEffort:     model.HasReasoningEffort,
-					DefaultReasoningEffort: model.DefaultReasoningEffort,
-					SupportsImages:         model.SupportsImages,
-				}
-			}
-
-			// Add this unknown provider to the list
-			name := configProvider.Name
-			if name == "" {
-				name = string(configProvider.ID)
-			}
-			modelItems = append(modelItems, commands.NewItemSection(name))
-			for _, model := range configProvider.Models {
-				modelItems = append(modelItems, completions.NewCompletionItem(model.Name, ModelOption{
-					Provider: configProvider,
-					Model:    model,
-				}))
-				if model.ID == currentModel.Model && string(configProvider.ID) == currentModel.Provider {
-					selectIndex = len(modelItems) - 1 // Set the selected index to the current model
-				}
-			}
-			addedProviders[providerID] = true
-		}
-	}
-
-	// Then add the known providers from the predefined list
-	for _, provider := range providers {
-		// Skip if we already added this provider as an unknown provider
-		if addedProviders[string(provider.ID)] {
-			continue
-		}
-
-		// Check if this provider is configured and not disabled
-		if providerConfig, exists := cfg.Providers[string(provider.ID)]; exists && providerConfig.Disable {
-			continue
-		}
-
-		name := provider.Name
-		if name == "" {
-			name = string(provider.ID)
-		}
-		modelItems = append(modelItems, commands.NewItemSection(name))
-		for _, model := range provider.Models {
-			modelItems = append(modelItems, completions.NewCompletionItem(model.Name, ModelOption{
-				Provider: provider,
-				Model:    model,
-			}))
-			if model.ID == currentModel.Model && string(provider.ID) == currentModel.Provider {
-				selectIndex = len(modelItems) - 1 // Set the selected index to the current model
-			}
-		}
-	}
-
-	return tea.Sequence(m.modelList.SetItems(modelItems), m.modelList.SetSelected(selectIndex))
-}

internal/tui/page/chat/chat.go 🔗

@@ -109,17 +109,18 @@ func New(app *app.App) ChatPage {
 
 		keyMap: DefaultKeyMap(),
 
-		header:  header.New(app.LSPClients),
-		sidebar: sidebar.New(app.History, app.LSPClients, false),
-		chat:    chat.New(app),
-		editor:  editor.New(app),
-		splash:  splash.New(),
+		header:      header.New(app.LSPClients),
+		sidebar:     sidebar.New(app.History, app.LSPClients, false),
+		chat:        chat.New(app),
+		editor:      editor.New(app),
+		splash:      splash.New(),
+		focusedPane: PanelTypeSplash,
 	}
 }
 
 func (p *chatPage) Init() tea.Cmd {
 	cfg := config.Get()
-	if cfg.IsReady() {
+	if config.HasInitialDataConfig() {
 		if b, _ := config.ProjectNeedsInitialization(); b {
 			p.state = ChatStateInitProject
 		} else {
@@ -248,6 +249,10 @@ func (p *chatPage) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
 			u, cmd := p.editor.Update(msg)
 			p.editor = u.(editor.Editor)
 			cmds = append(cmds, cmd)
+		case PanelTypeSplash:
+			u, cmd := p.splash.Update(msg)
+			p.splash = u.(splash.Splash)
+			cmds = append(cmds, cmd)
 		}
 	}
 	return p, tea.Batch(cmds...)
@@ -258,11 +263,7 @@ func (p *chatPage) View() tea.View {
 	t := styles.CurrentTheme()
 	switch p.state {
 	case ChatStateOnboarding, ChatStateInitProject:
-		chatView = tea.NewView(
-			t.S().Base.Render(
-				p.splash.View().String(),
-			),
-		)
+		chatView = p.splash.View()
 	case ChatStateNewMessage:
 		editorView := p.editor.View()
 		chatView = tea.NewView(