splash.go

  1package splash
  2
  3import (
  4	"fmt"
  5	"os"
  6	"slices"
  7	"strings"
  8
  9	"github.com/charmbracelet/bubbles/v2/key"
 10	tea "github.com/charmbracelet/bubbletea/v2"
 11	"github.com/charmbracelet/crush/internal/config"
 12	"github.com/charmbracelet/crush/internal/fur/provider"
 13	"github.com/charmbracelet/crush/internal/llm/prompt"
 14	"github.com/charmbracelet/crush/internal/tui/components/chat"
 15	"github.com/charmbracelet/crush/internal/tui/components/completions"
 16	"github.com/charmbracelet/crush/internal/tui/components/core"
 17	"github.com/charmbracelet/crush/internal/tui/components/core/layout"
 18	"github.com/charmbracelet/crush/internal/tui/components/core/list"
 19	"github.com/charmbracelet/crush/internal/tui/components/dialogs/models"
 20	"github.com/charmbracelet/crush/internal/tui/components/logo"
 21	"github.com/charmbracelet/crush/internal/tui/styles"
 22	"github.com/charmbracelet/crush/internal/tui/util"
 23	"github.com/charmbracelet/crush/internal/version"
 24	"github.com/charmbracelet/lipgloss/v2"
 25)
 26
 27type Splash interface {
 28	util.Model
 29	layout.Sizeable
 30	layout.Help
 31	Cursor() *tea.Cursor
 32	// SetOnboarding controls whether the splash shows model selection UI
 33	SetOnboarding(bool)
 34	// SetProjectInit controls whether the splash shows project initialization prompt
 35	SetProjectInit(bool)
 36
 37	// Showing API key input
 38	IsShowingAPIKey() bool
 39}
 40
 41const (
 42	SplashScreenPaddingY = 1 // Padding Y for the splash screen
 43
 44	LogoGap = 6
 45)
 46
 47// OnboardingCompleteMsg is sent when onboarding is complete
 48type OnboardingCompleteMsg struct{}
 49
 50type splashCmp struct {
 51	width, height int
 52	keyMap        KeyMap
 53	logoRendered  string
 54
 55	// State
 56	isOnboarding     bool
 57	needsProjectInit bool
 58	needsAPIKey      bool
 59	selectedNo       bool
 60
 61	listHeight    int
 62	modelList     *models.ModelListComponent
 63	apiKeyInput   *models.APIKeyInput
 64	selectedModel *models.ModelOption
 65}
 66
 67func New() Splash {
 68	keyMap := DefaultKeyMap()
 69	listKeyMap := list.DefaultKeyMap()
 70	listKeyMap.Down.SetEnabled(false)
 71	listKeyMap.Up.SetEnabled(false)
 72	listKeyMap.HalfPageDown.SetEnabled(false)
 73	listKeyMap.HalfPageUp.SetEnabled(false)
 74	listKeyMap.Home.SetEnabled(false)
 75	listKeyMap.End.SetEnabled(false)
 76	listKeyMap.DownOneItem = keyMap.Next
 77	listKeyMap.UpOneItem = keyMap.Previous
 78
 79	t := styles.CurrentTheme()
 80	inputStyle := t.S().Base.Padding(0, 1, 0, 1)
 81	modelList := models.NewModelListComponent(listKeyMap, inputStyle, "Find your fave")
 82	apiKeyInput := models.NewAPIKeyInput()
 83
 84	return &splashCmp{
 85		width:        0,
 86		height:       0,
 87		keyMap:       keyMap,
 88		logoRendered: "",
 89		modelList:    modelList,
 90		apiKeyInput:  apiKeyInput,
 91		selectedNo:   false,
 92	}
 93}
 94
 95func (s *splashCmp) SetOnboarding(onboarding bool) {
 96	s.isOnboarding = onboarding
 97	if onboarding {
 98		providers, err := config.Providers()
 99		if err != nil {
100			return
101		}
102		filteredProviders := []provider.Provider{}
103		simpleProviders := []string{
104			"anthropic",
105			"openai",
106			"gemini",
107			"xai",
108			"openrouter",
109		}
110		for _, p := range providers {
111			if slices.Contains(simpleProviders, string(p.ID)) {
112				filteredProviders = append(filteredProviders, p)
113			}
114		}
115		s.modelList.SetProviders(filteredProviders)
116	}
117}
118
119func (s *splashCmp) SetProjectInit(needsInit bool) {
120	s.needsProjectInit = needsInit
121}
122
123// GetSize implements SplashPage.
124func (s *splashCmp) GetSize() (int, int) {
125	return s.width, s.height
126}
127
128// Init implements SplashPage.
129func (s *splashCmp) Init() tea.Cmd {
130	return tea.Batch(s.modelList.Init(), s.apiKeyInput.Init())
131}
132
133// SetSize implements SplashPage.
134func (s *splashCmp) SetSize(width int, height int) tea.Cmd {
135	s.height = height
136	if width != s.width {
137		s.width = width
138		s.logoRendered = s.logoBlock()
139	}
140	// remove padding, logo height, gap, title space
141	s.listHeight = s.height - lipgloss.Height(s.logoRendered) - (SplashScreenPaddingY * 2) - s.logoGap() - 2
142	listWidth := min(60, width)
143	return s.modelList.SetSize(listWidth, s.listHeight)
144}
145
146// Update implements SplashPage.
147func (s *splashCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
148	switch msg := msg.(type) {
149	case tea.WindowSizeMsg:
150		return s, s.SetSize(msg.Width, msg.Height)
151	case tea.KeyPressMsg:
152		switch {
153		case key.Matches(msg, s.keyMap.Back):
154			if s.needsAPIKey {
155				// Go back to model selection
156				s.needsAPIKey = false
157				s.selectedModel = nil
158				return s, nil
159			}
160		case key.Matches(msg, s.keyMap.Select):
161			if s.isOnboarding && !s.needsAPIKey {
162				modelInx := s.modelList.SelectedIndex()
163				items := s.modelList.Items()
164				selectedItem := items[modelInx].(completions.CompletionItem).Value().(models.ModelOption)
165				if s.isProviderConfigured(string(selectedItem.Provider.ID)) {
166					cmd := s.setPreferredModel(selectedItem)
167					s.isOnboarding = false
168					return s, tea.Batch(cmd, util.CmdHandler(OnboardingCompleteMsg{}))
169				} else {
170					// Provider not configured, show API key input
171					s.needsAPIKey = true
172					s.selectedModel = &selectedItem
173					s.apiKeyInput.SetProviderName(selectedItem.Provider.Name)
174					return s, nil
175				}
176			} else if s.needsAPIKey {
177				// Handle API key submission
178				apiKey := s.apiKeyInput.Value()
179				if apiKey != "" {
180					return s, s.saveAPIKeyAndContinue(apiKey)
181				}
182			} else if s.needsProjectInit {
183				return s, s.initializeProject()
184			}
185		case key.Matches(msg, s.keyMap.Tab, s.keyMap.LeftRight):
186			if s.needsProjectInit {
187				s.selectedNo = !s.selectedNo
188				return s, nil
189			}
190		case key.Matches(msg, s.keyMap.Yes):
191			if s.needsProjectInit {
192				return s, s.initializeProject()
193			}
194		case key.Matches(msg, s.keyMap.No):
195			s.selectedNo = true
196			return s, s.initializeProject()
197		default:
198			if s.needsAPIKey {
199				u, cmd := s.apiKeyInput.Update(msg)
200				s.apiKeyInput = u.(*models.APIKeyInput)
201				return s, cmd
202			} else if s.isOnboarding {
203				u, cmd := s.modelList.Update(msg)
204				s.modelList = u
205				return s, cmd
206			}
207		}
208	case tea.PasteMsg:
209		if s.needsAPIKey {
210			u, cmd := s.apiKeyInput.Update(msg)
211			s.apiKeyInput = u.(*models.APIKeyInput)
212			return s, cmd
213		} else if s.isOnboarding {
214			var cmd tea.Cmd
215			s.modelList, cmd = s.modelList.Update(msg)
216			return s, cmd
217		}
218	}
219	return s, nil
220}
221
222func (s *splashCmp) saveAPIKeyAndContinue(apiKey string) tea.Cmd {
223	if s.selectedModel == nil {
224		return util.ReportError(fmt.Errorf("no model selected"))
225	}
226
227	cfg := config.Get()
228	err := cfg.SetProviderAPIKey(string(s.selectedModel.Provider.ID), apiKey)
229	if err != nil {
230		return util.ReportError(fmt.Errorf("failed to save API key: %w", err))
231	}
232
233	// Reset API key state and continue with model selection
234	s.needsAPIKey = false
235	cmd := s.setPreferredModel(*s.selectedModel)
236	s.isOnboarding = false
237	s.selectedModel = nil
238
239	return tea.Batch(cmd, util.CmdHandler(OnboardingCompleteMsg{}))
240}
241
242func (s *splashCmp) initializeProject() tea.Cmd {
243	s.needsProjectInit = false
244
245	if err := config.MarkProjectInitialized(); err != nil {
246		return util.ReportError(err)
247	}
248	var cmds []tea.Cmd
249
250	cmds = append(cmds, util.CmdHandler(OnboardingCompleteMsg{}))
251	if !s.selectedNo {
252		cmds = append(cmds,
253			util.CmdHandler(chat.SessionClearedMsg{}),
254			util.CmdHandler(chat.SendMsg{
255				Text: prompt.Initialize(),
256			}),
257		)
258	}
259	return tea.Sequence(cmds...)
260}
261
262func (s *splashCmp) setPreferredModel(selectedItem models.ModelOption) tea.Cmd {
263	cfg := config.Get()
264	model := cfg.GetModel(string(selectedItem.Provider.ID), selectedItem.Model.ID)
265	if model == nil {
266		return util.ReportError(fmt.Errorf("model %s not found for provider %s", selectedItem.Model.ID, selectedItem.Provider.ID))
267	}
268
269	selectedModel := config.SelectedModel{
270		Model:           selectedItem.Model.ID,
271		Provider:        string(selectedItem.Provider.ID),
272		ReasoningEffort: model.DefaultReasoningEffort,
273		MaxTokens:       model.DefaultMaxTokens,
274	}
275
276	err := cfg.UpdatePreferredModel(config.SelectedModelTypeLarge, selectedModel)
277	if err != nil {
278		return util.ReportError(err)
279	}
280
281	// Now lets automatically setup the small model
282	knownProvider, err := s.getProvider(selectedItem.Provider.ID)
283	if err != nil {
284		return util.ReportError(err)
285	}
286	if knownProvider == nil {
287		// for local provider we just use the same model
288		err = cfg.UpdatePreferredModel(config.SelectedModelTypeSmall, selectedModel)
289		if err != nil {
290			return util.ReportError(err)
291		}
292	} else {
293		smallModel := knownProvider.DefaultSmallModelID
294		model := cfg.GetModel(string(selectedItem.Provider.ID), smallModel)
295		// should never happen
296		if model == nil {
297			err = cfg.UpdatePreferredModel(config.SelectedModelTypeSmall, selectedModel)
298			if err != nil {
299				return util.ReportError(err)
300			}
301			return nil
302		}
303		smallSelectedModel := config.SelectedModel{
304			Model:           smallModel,
305			Provider:        string(selectedItem.Provider.ID),
306			ReasoningEffort: model.DefaultReasoningEffort,
307			MaxTokens:       model.DefaultMaxTokens,
308		}
309		err = cfg.UpdatePreferredModel(config.SelectedModelTypeSmall, smallSelectedModel)
310		if err != nil {
311			return util.ReportError(err)
312		}
313	}
314	return nil
315}
316
317func (s *splashCmp) getProvider(providerID provider.InferenceProvider) (*provider.Provider, error) {
318	providers, err := config.Providers()
319	if err != nil {
320		return nil, err
321	}
322	for _, p := range providers {
323		if p.ID == providerID {
324			return &p, nil
325		}
326	}
327	return nil, nil
328}
329
330func (s *splashCmp) isProviderConfigured(providerID string) bool {
331	cfg := config.Get()
332	if _, ok := cfg.Providers[providerID]; ok {
333		return true
334	}
335	return false
336}
337
338func (s *splashCmp) View() string {
339	t := styles.CurrentTheme()
340	var content string
341	if s.needsAPIKey {
342		remainingHeight := s.height - lipgloss.Height(s.logoRendered) - (SplashScreenPaddingY * 2)
343		apiKeyView := t.S().Base.PaddingLeft(1).Render(s.apiKeyInput.View())
344		apiKeySelector := t.S().Base.AlignVertical(lipgloss.Bottom).Height(remainingHeight).Render(
345			lipgloss.JoinVertical(
346				lipgloss.Left,
347				apiKeyView,
348			),
349		)
350		content = lipgloss.JoinVertical(
351			lipgloss.Left,
352			s.logoRendered,
353			apiKeySelector,
354		)
355	} else if s.isOnboarding {
356		modelListView := s.modelList.View()
357		remainingHeight := s.height - lipgloss.Height(s.logoRendered) - (SplashScreenPaddingY * 2)
358		modelSelector := t.S().Base.AlignVertical(lipgloss.Bottom).Height(remainingHeight).Render(
359			lipgloss.JoinVertical(
360				lipgloss.Left,
361				t.S().Base.PaddingLeft(1).Foreground(t.Primary).Render("Choose a Model"),
362				"",
363				modelListView,
364			),
365		)
366		content = lipgloss.JoinVertical(
367			lipgloss.Left,
368			s.logoRendered,
369			modelSelector,
370		)
371	} else if s.needsProjectInit {
372		titleStyle := t.S().Base.Foreground(t.FgBase)
373		bodyStyle := t.S().Base.Foreground(t.FgMuted)
374		shortcutStyle := t.S().Base.Foreground(t.Success)
375
376		initText := lipgloss.JoinVertical(
377			lipgloss.Left,
378			titleStyle.Render("Would you like to initialize this project?"),
379			"",
380			bodyStyle.Render("When I initialize your codebase I examine the project and put the"),
381			bodyStyle.Render("result into a CRUSH.md file which serves as general context."),
382			"",
383			bodyStyle.Render("You can also initialize anytime via ")+shortcutStyle.Render("ctrl+p")+bodyStyle.Render("."),
384			"",
385			bodyStyle.Render("Would you like to initialize now?"),
386		)
387
388		yesButton := core.SelectableButton(core.ButtonOpts{
389			Text:           "Yep!",
390			UnderlineIndex: 0,
391			Selected:       !s.selectedNo,
392		})
393
394		noButton := core.SelectableButton(core.ButtonOpts{
395			Text:           "Nope",
396			UnderlineIndex: 0,
397			Selected:       s.selectedNo,
398		})
399
400		buttons := lipgloss.JoinHorizontal(lipgloss.Left, yesButton, "  ", noButton)
401		infoSection := s.infoSection()
402
403		remainingHeight := s.height - lipgloss.Height(s.logoRendered) - (SplashScreenPaddingY * 2) - lipgloss.Height(infoSection)
404
405		initContent := t.S().Base.AlignVertical(lipgloss.Bottom).PaddingLeft(1).Height(remainingHeight).Render(
406			lipgloss.JoinVertical(
407				lipgloss.Left,
408				initText,
409				"",
410				buttons,
411			),
412		)
413
414		content = lipgloss.JoinVertical(
415			lipgloss.Left,
416			s.logoRendered,
417			infoSection,
418			initContent,
419		)
420	} else {
421		parts := []string{
422			s.logoRendered,
423			s.infoSection(),
424		}
425		content = lipgloss.JoinVertical(lipgloss.Left, parts...)
426	}
427
428	return t.S().Base.
429		Width(s.width).
430		Height(s.height).
431		PaddingTop(SplashScreenPaddingY).
432		PaddingBottom(SplashScreenPaddingY).
433		Render(content)
434}
435
436func (s *splashCmp) Cursor() *tea.Cursor {
437	if s.needsAPIKey {
438		cursor := s.apiKeyInput.Cursor()
439		if cursor != nil {
440			return s.moveCursor(cursor)
441		}
442	} else if s.isOnboarding {
443		cursor := s.modelList.Cursor()
444		if cursor != nil {
445			return s.moveCursor(cursor)
446		}
447	} else {
448		return nil
449	}
450	return nil
451}
452
453func (s *splashCmp) infoSection() string {
454	t := styles.CurrentTheme()
455	return t.S().Base.PaddingLeft(2).Render(
456		lipgloss.JoinVertical(
457			lipgloss.Left,
458			s.cwd(),
459			"",
460			lipgloss.JoinHorizontal(lipgloss.Left, s.lspBlock(), s.mcpBlock()),
461			"",
462		),
463	)
464}
465
466func (s *splashCmp) logoBlock() string {
467	t := styles.CurrentTheme()
468	return t.S().Base.Padding(0, 2).Width(s.width).Render(
469		logo.Render(version.Version, false, logo.Opts{
470			FieldColor:   t.Primary,
471			TitleColorA:  t.Secondary,
472			TitleColorB:  t.Primary,
473			CharmColor:   t.Secondary,
474			VersionColor: t.Primary,
475			Width:        s.width - 4,
476		}),
477	)
478}
479
480func (s *splashCmp) moveCursor(cursor *tea.Cursor) *tea.Cursor {
481	if cursor == nil {
482		return nil
483	}
484	// Calculate the correct Y offset based on current state
485	logoHeight := lipgloss.Height(s.logoRendered)
486	if s.needsAPIKey {
487		infoSectionHeight := lipgloss.Height(s.infoSection())
488		baseOffset := logoHeight + SplashScreenPaddingY + infoSectionHeight
489		remainingHeight := s.height - baseOffset - lipgloss.Height(s.apiKeyInput.View()) - SplashScreenPaddingY
490		offset := baseOffset + remainingHeight
491		cursor.Y += offset
492		cursor.X = cursor.X + 1
493	} else if s.isOnboarding {
494		offset := logoHeight + SplashScreenPaddingY + s.logoGap() + 3
495		cursor.Y += offset
496		cursor.X = cursor.X + 1
497	}
498
499	return cursor
500}
501
502func (s *splashCmp) logoGap() int {
503	if s.height > 35 {
504		return LogoGap
505	}
506	return 0
507}
508
509// Bindings implements SplashPage.
510func (s *splashCmp) Bindings() []key.Binding {
511	if s.needsAPIKey {
512		return []key.Binding{
513			s.keyMap.Select,
514			s.keyMap.Back,
515		}
516	} else if s.isOnboarding {
517		return []key.Binding{
518			s.keyMap.Select,
519			s.keyMap.Next,
520			s.keyMap.Previous,
521		}
522	} else if s.needsProjectInit {
523		return []key.Binding{
524			s.keyMap.Select,
525			s.keyMap.Yes,
526			s.keyMap.No,
527			s.keyMap.Tab,
528			s.keyMap.LeftRight,
529		}
530	}
531	return []key.Binding{}
532}
533
534func (s *splashCmp) getMaxInfoWidth() int {
535	return min(s.width-2, 40) // 2 for left padding
536}
537
538func (s *splashCmp) cwd() string {
539	cwd := config.Get().WorkingDir()
540	t := styles.CurrentTheme()
541	homeDir, err := os.UserHomeDir()
542	if err == nil && cwd != homeDir {
543		cwd = strings.ReplaceAll(cwd, homeDir, "~")
544	}
545	maxWidth := s.getMaxInfoWidth()
546	return t.S().Muted.Width(maxWidth).Render(cwd)
547}
548
549func LSPList(maxWidth int) []string {
550	t := styles.CurrentTheme()
551	lspList := []string{}
552	lsp := config.Get().LSP.Sorted()
553	if len(lsp) == 0 {
554		return []string{t.S().Base.Foreground(t.Border).Render("None")}
555	}
556	for _, l := range lsp {
557		iconColor := t.Success
558		if l.LSP.Disabled {
559			iconColor = t.FgMuted
560		}
561		lspList = append(lspList,
562			core.Status(
563				core.StatusOpts{
564					IconColor:   iconColor,
565					Title:       l.Name,
566					Description: l.LSP.Command,
567				},
568				maxWidth,
569			),
570		)
571	}
572	return lspList
573}
574
575func (s *splashCmp) lspBlock() string {
576	t := styles.CurrentTheme()
577	maxWidth := s.getMaxInfoWidth() / 2
578	section := t.S().Subtle.Render("LSPs")
579	lspList := append([]string{section, ""}, LSPList(maxWidth-1)...)
580	return t.S().Base.Width(maxWidth).PaddingRight(1).Render(
581		lipgloss.JoinVertical(
582			lipgloss.Left,
583			lspList...,
584		),
585	)
586}
587
588func MCPList(maxWidth int) []string {
589	t := styles.CurrentTheme()
590	mcpList := []string{}
591	mcps := config.Get().MCP.Sorted()
592	if len(mcps) == 0 {
593		return []string{t.S().Base.Foreground(t.Border).Render("None")}
594	}
595	for _, l := range mcps {
596		iconColor := t.Success
597		if l.MCP.Disabled {
598			iconColor = t.FgMuted
599		}
600		mcpList = append(mcpList,
601			core.Status(
602				core.StatusOpts{
603					IconColor:   iconColor,
604					Title:       l.Name,
605					Description: l.MCP.Command,
606				},
607				maxWidth,
608			),
609		)
610	}
611	return mcpList
612}
613
614func (s *splashCmp) mcpBlock() string {
615	t := styles.CurrentTheme()
616	maxWidth := s.getMaxInfoWidth() / 2
617	section := t.S().Subtle.Render("MCPs")
618	mcpList := append([]string{section, ""}, MCPList(maxWidth-1)...)
619	return t.S().Base.Width(maxWidth).PaddingRight(1).Render(
620		lipgloss.JoinVertical(
621			lipgloss.Left,
622			mcpList...,
623		),
624	)
625}
626
627func (s *splashCmp) IsShowingAPIKey() bool {
628	return s.needsAPIKey
629}