splash.go

  1package splash
  2
  3import (
  4	"fmt"
  5	"os"
  6	"slices"
  7	"strings"
  8
  9	"github.com/charmbracelet/bubbles/v2/key"
 10	tea "github.com/charmbracelet/bubbletea/v2"
 11	"github.com/charmbracelet/crush/internal/config"
 12	"github.com/charmbracelet/crush/internal/fur/provider"
 13	"github.com/charmbracelet/crush/internal/llm/prompt"
 14	"github.com/charmbracelet/crush/internal/tui/components/chat"
 15	"github.com/charmbracelet/crush/internal/tui/components/completions"
 16	"github.com/charmbracelet/crush/internal/tui/components/core"
 17	"github.com/charmbracelet/crush/internal/tui/components/core/layout"
 18	"github.com/charmbracelet/crush/internal/tui/components/core/list"
 19	"github.com/charmbracelet/crush/internal/tui/components/dialogs/models"
 20	"github.com/charmbracelet/crush/internal/tui/components/logo"
 21	"github.com/charmbracelet/crush/internal/tui/styles"
 22	"github.com/charmbracelet/crush/internal/tui/util"
 23	"github.com/charmbracelet/crush/internal/version"
 24	"github.com/charmbracelet/lipgloss/v2"
 25)
 26
 27type Splash interface {
 28	util.Model
 29	layout.Sizeable
 30	layout.Help
 31	Cursor() *tea.Cursor
 32	// SetOnboarding controls whether the splash shows model selection UI
 33	SetOnboarding(bool)
 34	// SetProjectInit controls whether the splash shows project initialization prompt
 35	SetProjectInit(bool)
 36
 37	// Showing API key input
 38	IsShowingAPIKey() bool
 39}
 40
 41const (
 42	SplashScreenPaddingY = 1 // Padding Y for the splash screen
 43
 44	LogoGap = 6
 45)
 46
// OnboardingCompleteMsg is sent when onboarding is complete: after a model
// has been chosen (and its API key saved when one was required), or after the
// project initialization prompt has been answered.
type OnboardingCompleteMsg struct{}
 49
 50type splashCmp struct {
 51	width, height int
 52	keyMap        KeyMap
 53	logoRendered  string
 54
 55	// State
 56	isOnboarding     bool
 57	needsProjectInit bool
 58	needsAPIKey      bool
 59	selectedNo       bool
 60
 61	listHeight    int
 62	modelList     *models.ModelListComponent
 63	apiKeyInput   *models.APIKeyInput
 64	selectedModel *models.ModelOption
 65}
 66
 67func New() Splash {
 68	keyMap := DefaultKeyMap()
 69	listKeyMap := list.DefaultKeyMap()
 70	listKeyMap.Down.SetEnabled(false)
 71	listKeyMap.Up.SetEnabled(false)
 72	listKeyMap.HalfPageDown.SetEnabled(false)
 73	listKeyMap.HalfPageUp.SetEnabled(false)
 74	listKeyMap.Home.SetEnabled(false)
 75	listKeyMap.End.SetEnabled(false)
 76	listKeyMap.DownOneItem = keyMap.Next
 77	listKeyMap.UpOneItem = keyMap.Previous
 78
 79	t := styles.CurrentTheme()
 80	inputStyle := t.S().Base.Padding(0, 1, 0, 1)
 81	modelList := models.NewModelListComponent(listKeyMap, inputStyle, "Find your fave")
 82	apiKeyInput := models.NewAPIKeyInput()
 83
 84	return &splashCmp{
 85		width:        0,
 86		height:       0,
 87		keyMap:       keyMap,
 88		logoRendered: "",
 89		modelList:    modelList,
 90		apiKeyInput:  apiKeyInput,
 91		selectedNo:   false,
 92	}
 93}
 94
 95func (s *splashCmp) SetOnboarding(onboarding bool) {
 96	s.isOnboarding = onboarding
 97	if onboarding {
 98		providers, err := config.Providers()
 99		if err != nil {
100			return
101		}
102		filteredProviders := []provider.Provider{}
103		simpleProviders := []string{
104			"anthropic",
105			"openai",
106			"gemini",
107			"xai",
108			"groq",
109			"openrouter",
110		}
111		for _, p := range providers {
112			if slices.Contains(simpleProviders, string(p.ID)) {
113				filteredProviders = append(filteredProviders, p)
114			}
115		}
116		s.modelList.SetProviders(filteredProviders)
117	}
118}
119
// SetProjectInit records whether the project initialization prompt still has
// to be shown and answered.
func (s *splashCmp) SetProjectInit(needsInit bool) {
	s.needsProjectInit = needsInit
}
123
// GetSize returns the width and height last stored by SetSize.
func (s *splashCmp) GetSize() (int, int) {
	return s.width, s.height
}
128
129// Init implements SplashPage.
130func (s *splashCmp) Init() tea.Cmd {
131	return tea.Batch(s.modelList.Init(), s.apiKeyInput.Init())
132}
133
134// SetSize implements SplashPage.
135func (s *splashCmp) SetSize(width int, height int) tea.Cmd {
136	s.height = height
137	s.width = width
138	s.logoRendered = s.logoBlock()
139	// remove padding, logo height, gap, title space
140	s.listHeight = s.height - lipgloss.Height(s.logoRendered) - (SplashScreenPaddingY * 2) - s.logoGap() - 2
141	listWidth := min(60, width)
142	return s.modelList.SetSize(listWidth, s.listHeight)
143}
144
145// Update implements SplashPage.
146func (s *splashCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
147	switch msg := msg.(type) {
148	case tea.WindowSizeMsg:
149		return s, s.SetSize(msg.Width, msg.Height)
150	case tea.KeyPressMsg:
151		switch {
152		case key.Matches(msg, s.keyMap.Back):
153			if s.needsAPIKey {
154				// Go back to model selection
155				s.needsAPIKey = false
156				s.selectedModel = nil
157				return s, nil
158			}
159		case key.Matches(msg, s.keyMap.Select):
160			if s.isOnboarding && !s.needsAPIKey {
161				modelInx := s.modelList.SelectedIndex()
162				items := s.modelList.Items()
163				selectedItem := items[modelInx].(completions.CompletionItem).Value().(models.ModelOption)
164				if s.isProviderConfigured(string(selectedItem.Provider.ID)) {
165					cmd := s.setPreferredModel(selectedItem)
166					s.isOnboarding = false
167					return s, tea.Batch(cmd, util.CmdHandler(OnboardingCompleteMsg{}))
168				} else {
169					// Provider not configured, show API key input
170					s.needsAPIKey = true
171					s.selectedModel = &selectedItem
172					s.apiKeyInput.SetProviderName(selectedItem.Provider.Name)
173					return s, nil
174				}
175			} else if s.needsAPIKey {
176				// Handle API key submission
177				apiKey := s.apiKeyInput.Value()
178				if apiKey != "" {
179					return s, s.saveAPIKeyAndContinue(apiKey)
180				}
181			} else if s.needsProjectInit {
182				return s, s.initializeProject()
183			}
184		case key.Matches(msg, s.keyMap.Tab, s.keyMap.LeftRight):
185			if s.needsProjectInit {
186				s.selectedNo = !s.selectedNo
187				return s, nil
188			}
189		case key.Matches(msg, s.keyMap.Yes):
190			if s.needsProjectInit {
191				return s, s.initializeProject()
192			}
193		case key.Matches(msg, s.keyMap.No):
194			s.selectedNo = true
195			return s, s.initializeProject()
196		default:
197			if s.needsAPIKey {
198				u, cmd := s.apiKeyInput.Update(msg)
199				s.apiKeyInput = u.(*models.APIKeyInput)
200				return s, cmd
201			} else if s.isOnboarding {
202				u, cmd := s.modelList.Update(msg)
203				s.modelList = u
204				return s, cmd
205			}
206		}
207	case tea.PasteMsg:
208		if s.needsAPIKey {
209			u, cmd := s.apiKeyInput.Update(msg)
210			s.apiKeyInput = u.(*models.APIKeyInput)
211			return s, cmd
212		} else if s.isOnboarding {
213			var cmd tea.Cmd
214			s.modelList, cmd = s.modelList.Update(msg)
215			return s, cmd
216		}
217	}
218	return s, nil
219}
220
221func (s *splashCmp) saveAPIKeyAndContinue(apiKey string) tea.Cmd {
222	if s.selectedModel == nil {
223		return util.ReportError(fmt.Errorf("no model selected"))
224	}
225
226	cfg := config.Get()
227	err := cfg.SetProviderAPIKey(string(s.selectedModel.Provider.ID), apiKey)
228	if err != nil {
229		return util.ReportError(fmt.Errorf("failed to save API key: %w", err))
230	}
231
232	// Reset API key state and continue with model selection
233	s.needsAPIKey = false
234	cmd := s.setPreferredModel(*s.selectedModel)
235	s.isOnboarding = false
236	s.selectedModel = nil
237
238	return tea.Batch(cmd, util.CmdHandler(OnboardingCompleteMsg{}))
239}
240
241func (s *splashCmp) initializeProject() tea.Cmd {
242	s.needsProjectInit = false
243
244	if err := config.MarkProjectInitialized(); err != nil {
245		return util.ReportError(err)
246	}
247	var cmds []tea.Cmd
248
249	cmds = append(cmds, util.CmdHandler(OnboardingCompleteMsg{}))
250	if !s.selectedNo {
251		cmds = append(cmds,
252			util.CmdHandler(chat.SessionClearedMsg{}),
253			util.CmdHandler(chat.SendMsg{
254				Text: prompt.Initialize(),
255			}),
256		)
257	}
258	return tea.Sequence(cmds...)
259}
260
261func (s *splashCmp) setPreferredModel(selectedItem models.ModelOption) tea.Cmd {
262	cfg := config.Get()
263	model := cfg.GetModel(string(selectedItem.Provider.ID), selectedItem.Model.ID)
264	if model == nil {
265		return util.ReportError(fmt.Errorf("model %s not found for provider %s", selectedItem.Model.ID, selectedItem.Provider.ID))
266	}
267
268	selectedModel := config.SelectedModel{
269		Model:           selectedItem.Model.ID,
270		Provider:        string(selectedItem.Provider.ID),
271		ReasoningEffort: model.DefaultReasoningEffort,
272		MaxTokens:       model.DefaultMaxTokens,
273	}
274
275	err := cfg.UpdatePreferredModel(config.SelectedModelTypeLarge, selectedModel)
276	if err != nil {
277		return util.ReportError(err)
278	}
279
280	// Now lets automatically setup the small model
281	knownProvider, err := s.getProvider(selectedItem.Provider.ID)
282	if err != nil {
283		return util.ReportError(err)
284	}
285	if knownProvider == nil {
286		// for local provider we just use the same model
287		err = cfg.UpdatePreferredModel(config.SelectedModelTypeSmall, selectedModel)
288		if err != nil {
289			return util.ReportError(err)
290		}
291	} else {
292		smallModel := knownProvider.DefaultSmallModelID
293		model := cfg.GetModel(string(selectedItem.Provider.ID), smallModel)
294		// should never happen
295		if model == nil {
296			err = cfg.UpdatePreferredModel(config.SelectedModelTypeSmall, selectedModel)
297			if err != nil {
298				return util.ReportError(err)
299			}
300			return nil
301		}
302		smallSelectedModel := config.SelectedModel{
303			Model:           smallModel,
304			Provider:        string(selectedItem.Provider.ID),
305			ReasoningEffort: model.DefaultReasoningEffort,
306			MaxTokens:       model.DefaultMaxTokens,
307		}
308		err = cfg.UpdatePreferredModel(config.SelectedModelTypeSmall, smallSelectedModel)
309		if err != nil {
310			return util.ReportError(err)
311		}
312	}
313	cfg.SetupAgents()
314	return nil
315}
316
317func (s *splashCmp) getProvider(providerID provider.InferenceProvider) (*provider.Provider, error) {
318	providers, err := config.Providers()
319	if err != nil {
320		return nil, err
321	}
322	for _, p := range providers {
323		if p.ID == providerID {
324			return &p, nil
325		}
326	}
327	return nil, nil
328}
329
330func (s *splashCmp) isProviderConfigured(providerID string) bool {
331	cfg := config.Get()
332	if _, ok := cfg.Providers[providerID]; ok {
333		return true
334	}
335	return false
336}
337
338func (s *splashCmp) View() string {
339	t := styles.CurrentTheme()
340	var content string
341	if s.needsAPIKey {
342		remainingHeight := s.height - lipgloss.Height(s.logoRendered) - (SplashScreenPaddingY * 2)
343		apiKeyView := t.S().Base.PaddingLeft(1).Render(s.apiKeyInput.View())
344		apiKeySelector := t.S().Base.AlignVertical(lipgloss.Bottom).Height(remainingHeight).Render(
345			lipgloss.JoinVertical(
346				lipgloss.Left,
347				apiKeyView,
348			),
349		)
350		content = lipgloss.JoinVertical(
351			lipgloss.Left,
352			s.logoRendered,
353			apiKeySelector,
354		)
355	} else if s.isOnboarding {
356		modelListView := s.modelList.View()
357		remainingHeight := s.height - lipgloss.Height(s.logoRendered) - (SplashScreenPaddingY * 2)
358		modelSelector := t.S().Base.AlignVertical(lipgloss.Bottom).Height(remainingHeight).Render(
359			lipgloss.JoinVertical(
360				lipgloss.Left,
361				t.S().Base.PaddingLeft(1).Foreground(t.Primary).Render("Choose a Model"),
362				"",
363				modelListView,
364			),
365		)
366		content = lipgloss.JoinVertical(
367			lipgloss.Left,
368			s.logoRendered,
369			modelSelector,
370		)
371	} else if s.needsProjectInit {
372		titleStyle := t.S().Base.Foreground(t.FgBase)
373		bodyStyle := t.S().Base.Foreground(t.FgMuted)
374		shortcutStyle := t.S().Base.Foreground(t.Success)
375
376		initText := lipgloss.JoinVertical(
377			lipgloss.Left,
378			titleStyle.Render("Would you like to initialize this project?"),
379			"",
380			bodyStyle.Render("When I initialize your codebase I examine the project and put the"),
381			bodyStyle.Render("result into a CRUSH.md file which serves as general context."),
382			"",
383			bodyStyle.Render("You can also initialize anytime via ")+shortcutStyle.Render("ctrl+p")+bodyStyle.Render("."),
384			"",
385			bodyStyle.Render("Would you like to initialize now?"),
386		)
387
388		yesButton := core.SelectableButton(core.ButtonOpts{
389			Text:           "Yep!",
390			UnderlineIndex: 0,
391			Selected:       !s.selectedNo,
392		})
393
394		noButton := core.SelectableButton(core.ButtonOpts{
395			Text:           "Nope",
396			UnderlineIndex: 0,
397			Selected:       s.selectedNo,
398		})
399
400		buttons := lipgloss.JoinHorizontal(lipgloss.Left, yesButton, "  ", noButton)
401		infoSection := s.infoSection()
402
403		remainingHeight := s.height - lipgloss.Height(s.logoRendered) - (SplashScreenPaddingY * 2) - lipgloss.Height(infoSection)
404
405		initContent := t.S().Base.AlignVertical(lipgloss.Bottom).PaddingLeft(1).Height(remainingHeight).Render(
406			lipgloss.JoinVertical(
407				lipgloss.Left,
408				initText,
409				"",
410				buttons,
411			),
412		)
413
414		content = lipgloss.JoinVertical(
415			lipgloss.Left,
416			s.logoRendered,
417			infoSection,
418			initContent,
419		)
420	} else {
421		parts := []string{
422			s.logoRendered,
423			s.infoSection(),
424		}
425		content = lipgloss.JoinVertical(lipgloss.Left, parts...)
426	}
427
428	return t.S().Base.
429		Width(s.width).
430		Height(s.height).
431		PaddingTop(SplashScreenPaddingY).
432		PaddingBottom(SplashScreenPaddingY).
433		Render(content)
434}
435
436func (s *splashCmp) Cursor() *tea.Cursor {
437	if s.needsAPIKey {
438		cursor := s.apiKeyInput.Cursor()
439		if cursor != nil {
440			return s.moveCursor(cursor)
441		}
442	} else if s.isOnboarding {
443		cursor := s.modelList.Cursor()
444		if cursor != nil {
445			return s.moveCursor(cursor)
446		}
447	} else {
448		return nil
449	}
450	return nil
451}
452
453func (s *splashCmp) infoSection() string {
454	t := styles.CurrentTheme()
455	return t.S().Base.PaddingLeft(2).Render(
456		lipgloss.JoinVertical(
457			lipgloss.Left,
458			s.cwd(),
459			"",
460			lipgloss.JoinHorizontal(lipgloss.Left, s.lspBlock(), s.mcpBlock()),
461			"",
462		),
463	)
464}
465
466func (s *splashCmp) logoBlock() string {
467	t := styles.CurrentTheme()
468	logoStyle := t.S().Base.Padding(0, 2).Width(s.width)
469	if s.width < 40 || s.height < 20 {
470		// If the width is too small, render a smaller version of the logo
471		// NOTE: 20 is not correct because [splashCmp.height] is not the
472		// *actual* window height, instead, it is the height of the splash
473		// component and that depends on other variables like compact mode and
474		// the height of the editor.
475		return logoStyle.Render(
476			logo.SmallRender(s.width - logoStyle.GetHorizontalFrameSize()),
477		)
478	}
479	return logoStyle.Render(
480		logo.Render(version.Version, false, logo.Opts{
481			FieldColor:   t.Primary,
482			TitleColorA:  t.Secondary,
483			TitleColorB:  t.Primary,
484			CharmColor:   t.Secondary,
485			VersionColor: t.Primary,
486			Width:        s.width - logoStyle.GetHorizontalFrameSize(),
487		}),
488	)
489}
490
491func (s *splashCmp) moveCursor(cursor *tea.Cursor) *tea.Cursor {
492	if cursor == nil {
493		return nil
494	}
495	// Calculate the correct Y offset based on current state
496	logoHeight := lipgloss.Height(s.logoRendered)
497	if s.needsAPIKey {
498		infoSectionHeight := lipgloss.Height(s.infoSection())
499		baseOffset := logoHeight + SplashScreenPaddingY + infoSectionHeight
500		remainingHeight := s.height - baseOffset - lipgloss.Height(s.apiKeyInput.View()) - SplashScreenPaddingY
501		offset := baseOffset + remainingHeight
502		cursor.Y += offset
503		cursor.X = cursor.X + 1
504	} else if s.isOnboarding {
505		offset := logoHeight + SplashScreenPaddingY + s.logoGap() + 3
506		cursor.Y += offset
507		cursor.X = cursor.X + 1
508	}
509
510	return cursor
511}
512
513func (s *splashCmp) logoGap() int {
514	if s.height > 35 {
515		return LogoGap
516	}
517	return 0
518}
519
520// Bindings implements SplashPage.
521func (s *splashCmp) Bindings() []key.Binding {
522	if s.needsAPIKey {
523		return []key.Binding{
524			s.keyMap.Select,
525			s.keyMap.Back,
526		}
527	} else if s.isOnboarding {
528		return []key.Binding{
529			s.keyMap.Select,
530			s.keyMap.Next,
531			s.keyMap.Previous,
532		}
533	} else if s.needsProjectInit {
534		return []key.Binding{
535			s.keyMap.Select,
536			s.keyMap.Yes,
537			s.keyMap.No,
538			s.keyMap.Tab,
539			s.keyMap.LeftRight,
540		}
541	}
542	return []key.Binding{}
543}
544
545func (s *splashCmp) getMaxInfoWidth() int {
546	return min(s.width-2, 40) // 2 for left padding
547}
548
549func (s *splashCmp) cwd() string {
550	cwd := config.Get().WorkingDir()
551	t := styles.CurrentTheme()
552	homeDir, err := os.UserHomeDir()
553	if err == nil && cwd != homeDir {
554		cwd = strings.ReplaceAll(cwd, homeDir, "~")
555	}
556	maxWidth := s.getMaxInfoWidth()
557	return t.S().Muted.Width(maxWidth).Render(cwd)
558}
559
560func LSPList(maxWidth int) []string {
561	t := styles.CurrentTheme()
562	lspList := []string{}
563	lsp := config.Get().LSP.Sorted()
564	if len(lsp) == 0 {
565		return []string{t.S().Base.Foreground(t.Border).Render("None")}
566	}
567	for _, l := range lsp {
568		iconColor := t.Success
569		if l.LSP.Disabled {
570			iconColor = t.FgMuted
571		}
572		lspList = append(lspList,
573			core.Status(
574				core.StatusOpts{
575					IconColor:   iconColor,
576					Title:       l.Name,
577					Description: l.LSP.Command,
578				},
579				maxWidth,
580			),
581		)
582	}
583	return lspList
584}
585
586func (s *splashCmp) lspBlock() string {
587	t := styles.CurrentTheme()
588	maxWidth := s.getMaxInfoWidth() / 2
589	section := t.S().Subtle.Render("LSPs")
590	lspList := append([]string{section, ""}, LSPList(maxWidth-1)...)
591	return t.S().Base.Width(maxWidth).PaddingRight(1).Render(
592		lipgloss.JoinVertical(
593			lipgloss.Left,
594			lspList...,
595		),
596	)
597}
598
599func MCPList(maxWidth int) []string {
600	t := styles.CurrentTheme()
601	mcpList := []string{}
602	mcps := config.Get().MCP.Sorted()
603	if len(mcps) == 0 {
604		return []string{t.S().Base.Foreground(t.Border).Render("None")}
605	}
606	for _, l := range mcps {
607		iconColor := t.Success
608		if l.MCP.Disabled {
609			iconColor = t.FgMuted
610		}
611		mcpList = append(mcpList,
612			core.Status(
613				core.StatusOpts{
614					IconColor:   iconColor,
615					Title:       l.Name,
616					Description: l.MCP.Command,
617				},
618				maxWidth,
619			),
620		)
621	}
622	return mcpList
623}
624
625func (s *splashCmp) mcpBlock() string {
626	t := styles.CurrentTheme()
627	maxWidth := s.getMaxInfoWidth() / 2
628	section := t.S().Subtle.Render("MCPs")
629	mcpList := append([]string{section, ""}, MCPList(maxWidth-1)...)
630	return t.S().Base.Width(maxWidth).PaddingRight(1).Render(
631		lipgloss.JoinVertical(
632			lipgloss.Left,
633			mcpList...,
634		),
635	)
636}
637
// IsShowingAPIKey reports whether the splash is currently asking for an API
// key, i.e. a model of an unconfigured provider was selected and the key has
// not been submitted yet.
func (s *splashCmp) IsShowingAPIKey() bool {
	return s.needsAPIKey
}