splash.go

  1package splash
  2
  3import (
  4	"fmt"
  5	"os"
  6	"slices"
  7	"strings"
  8
  9	"github.com/charmbracelet/bubbles/v2/key"
 10	tea "github.com/charmbracelet/bubbletea/v2"
 11	"github.com/charmbracelet/crush/internal/config"
 12	"github.com/charmbracelet/crush/internal/fur/provider"
 13	"github.com/charmbracelet/crush/internal/llm/prompt"
 14	"github.com/charmbracelet/crush/internal/tui/components/chat"
 15	"github.com/charmbracelet/crush/internal/tui/components/completions"
 16	"github.com/charmbracelet/crush/internal/tui/components/core"
 17	"github.com/charmbracelet/crush/internal/tui/components/core/layout"
 18	"github.com/charmbracelet/crush/internal/tui/components/core/list"
 19	"github.com/charmbracelet/crush/internal/tui/components/dialogs/models"
 20	"github.com/charmbracelet/crush/internal/tui/components/logo"
 21	"github.com/charmbracelet/crush/internal/tui/styles"
 22	"github.com/charmbracelet/crush/internal/tui/util"
 23	"github.com/charmbracelet/crush/internal/version"
 24	"github.com/charmbracelet/lipgloss/v2"
 25)
 26
 27type Splash interface {
 28	util.Model
 29	layout.Sizeable
 30	layout.Help
 31	Cursor() *tea.Cursor
 32	// SetOnboarding controls whether the splash shows model selection UI
 33	SetOnboarding(bool)
 34	// SetProjectInit controls whether the splash shows project initialization prompt
 35	SetProjectInit(bool)
 36
 37	// Showing API key input
 38	IsShowingAPIKey() bool
 39}
 40
 41const (
 42	SplashScreenPaddingY = 1 // Padding Y for the splash screen
 43
 44	LogoGap = 6
 45)
 46
 47// OnboardingCompleteMsg is sent when onboarding is complete
 48type OnboardingCompleteMsg struct{}
 49
 50type splashCmp struct {
 51	width, height int
 52	keyMap        KeyMap
 53	logoRendered  string
 54
 55	// State
 56	isOnboarding     bool
 57	needsProjectInit bool
 58	needsAPIKey      bool
 59	selectedNo       bool
 60
 61	listHeight    int
 62	modelList     *models.ModelListComponent
 63	apiKeyInput   *models.APIKeyInput
 64	selectedModel *models.ModelOption
 65}
 66
 67func New() Splash {
 68	keyMap := DefaultKeyMap()
 69	listKeyMap := list.DefaultKeyMap()
 70	listKeyMap.Down.SetEnabled(false)
 71	listKeyMap.Up.SetEnabled(false)
 72	listKeyMap.HalfPageDown.SetEnabled(false)
 73	listKeyMap.HalfPageUp.SetEnabled(false)
 74	listKeyMap.Home.SetEnabled(false)
 75	listKeyMap.End.SetEnabled(false)
 76	listKeyMap.DownOneItem = keyMap.Next
 77	listKeyMap.UpOneItem = keyMap.Previous
 78
 79	t := styles.CurrentTheme()
 80	inputStyle := t.S().Base.Padding(0, 1, 0, 1)
 81	modelList := models.NewModelListComponent(listKeyMap, inputStyle, "Find your fave")
 82	apiKeyInput := models.NewAPIKeyInput()
 83
 84	return &splashCmp{
 85		width:        0,
 86		height:       0,
 87		keyMap:       keyMap,
 88		logoRendered: "",
 89		modelList:    modelList,
 90		apiKeyInput:  apiKeyInput,
 91		selectedNo:   false,
 92	}
 93}
 94
 95func (s *splashCmp) SetOnboarding(onboarding bool) {
 96	s.isOnboarding = onboarding
 97	if onboarding {
 98		providers, err := config.Providers()
 99		if err != nil {
100			return
101		}
102		filteredProviders := []provider.Provider{}
103		simpleProviders := []string{
104			"anthropic",
105			"openai",
106			"gemini",
107			"xai",
108			"openrouter",
109		}
110		for _, p := range providers {
111			if slices.Contains(simpleProviders, string(p.ID)) {
112				filteredProviders = append(filteredProviders, p)
113			}
114		}
115		s.modelList.SetProviders(filteredProviders)
116	}
117}
118
119func (s *splashCmp) SetProjectInit(needsInit bool) {
120	s.needsProjectInit = needsInit
121}
122
123// GetSize implements SplashPage.
124func (s *splashCmp) GetSize() (int, int) {
125	return s.width, s.height
126}
127
128// Init implements SplashPage.
129func (s *splashCmp) Init() tea.Cmd {
130	return tea.Batch(s.modelList.Init(), s.apiKeyInput.Init())
131}
132
133// SetSize implements SplashPage.
134func (s *splashCmp) SetSize(width int, height int) tea.Cmd {
135	s.height = height
136	if width != s.width {
137		s.width = width
138		s.logoRendered = s.logoBlock()
139	}
140	// remove padding, logo height, gap, title space
141	s.listHeight = s.height - lipgloss.Height(s.logoRendered) - (SplashScreenPaddingY * 2) - s.logoGap() - 2
142	listWidth := min(60, width)
143	return s.modelList.SetSize(listWidth, s.listHeight)
144}
145
146// Update implements SplashPage.
147func (s *splashCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
148	switch msg := msg.(type) {
149	case tea.WindowSizeMsg:
150		return s, s.SetSize(msg.Width, msg.Height)
151	case tea.KeyPressMsg:
152		switch {
153		case key.Matches(msg, s.keyMap.Back):
154			if s.needsAPIKey {
155				// Go back to model selection
156				s.needsAPIKey = false
157				s.selectedModel = nil
158				return s, nil
159			}
160		case key.Matches(msg, s.keyMap.Select):
161			if s.isOnboarding && !s.needsAPIKey {
162				modelInx := s.modelList.SelectedIndex()
163				items := s.modelList.Items()
164				selectedItem := items[modelInx].(completions.CompletionItem).Value().(models.ModelOption)
165				if s.isProviderConfigured(string(selectedItem.Provider.ID)) {
166					cmd := s.setPreferredModel(selectedItem)
167					s.isOnboarding = false
168					return s, tea.Batch(cmd, util.CmdHandler(OnboardingCompleteMsg{}))
169				} else {
170					// Provider not configured, show API key input
171					s.needsAPIKey = true
172					s.selectedModel = &selectedItem
173					s.apiKeyInput.SetProviderName(selectedItem.Provider.Name)
174					return s, nil
175				}
176			} else if s.needsAPIKey {
177				// Handle API key submission
178				apiKey := s.apiKeyInput.Value()
179				if apiKey != "" {
180					return s, s.saveAPIKeyAndContinue(apiKey)
181				}
182			} else if s.needsProjectInit {
183				return s, s.initializeProject()
184			}
185		case key.Matches(msg, s.keyMap.Tab, s.keyMap.LeftRight):
186			if s.needsProjectInit {
187				s.selectedNo = !s.selectedNo
188				return s, nil
189			}
190		case key.Matches(msg, s.keyMap.Yes):
191			if s.needsProjectInit {
192				return s, s.initializeProject()
193			}
194		case key.Matches(msg, s.keyMap.No):
195			s.selectedNo = true
196			return s, s.initializeProject()
197		default:
198			if s.needsAPIKey {
199				u, cmd := s.apiKeyInput.Update(msg)
200				s.apiKeyInput = u.(*models.APIKeyInput)
201				return s, cmd
202			} else if s.isOnboarding {
203				u, cmd := s.modelList.Update(msg)
204				s.modelList = u
205				return s, cmd
206			}
207		}
208	case tea.PasteMsg:
209		if s.needsAPIKey {
210			u, cmd := s.apiKeyInput.Update(msg)
211			s.apiKeyInput = u.(*models.APIKeyInput)
212			return s, cmd
213		} else if s.isOnboarding {
214			var cmd tea.Cmd
215			s.modelList, cmd = s.modelList.Update(msg)
216			return s, cmd
217		}
218	}
219	return s, nil
220}
221
222func (s *splashCmp) saveAPIKeyAndContinue(apiKey string) tea.Cmd {
223	if s.selectedModel == nil {
224		return util.ReportError(fmt.Errorf("no model selected"))
225	}
226
227	cfg := config.Get()
228	err := cfg.SetProviderAPIKey(string(s.selectedModel.Provider.ID), apiKey)
229	if err != nil {
230		return util.ReportError(fmt.Errorf("failed to save API key: %w", err))
231	}
232
233	// Reset API key state and continue with model selection
234	s.needsAPIKey = false
235	cmd := s.setPreferredModel(*s.selectedModel)
236	s.isOnboarding = false
237	s.selectedModel = nil
238
239	return tea.Batch(cmd, util.CmdHandler(OnboardingCompleteMsg{}))
240}
241
242func (s *splashCmp) initializeProject() tea.Cmd {
243	s.needsProjectInit = false
244
245	if err := config.MarkProjectInitialized(); err != nil {
246		return util.ReportError(err)
247	}
248	var cmds []tea.Cmd
249
250	cmds = append(cmds, util.CmdHandler(OnboardingCompleteMsg{}))
251	if !s.selectedNo {
252		cmds = append(cmds,
253			util.CmdHandler(chat.SessionClearedMsg{}),
254			util.CmdHandler(chat.SendMsg{
255				Text: prompt.Initialize(),
256			}),
257		)
258	}
259	return tea.Sequence(cmds...)
260}
261
262func (s *splashCmp) setPreferredModel(selectedItem models.ModelOption) tea.Cmd {
263	cfg := config.Get()
264	model := cfg.GetModel(string(selectedItem.Provider.ID), selectedItem.Model.ID)
265	if model == nil {
266		return util.ReportError(fmt.Errorf("model %s not found for provider %s", selectedItem.Model.ID, selectedItem.Provider.ID))
267	}
268
269	selectedModel := config.SelectedModel{
270		Model:           selectedItem.Model.ID,
271		Provider:        string(selectedItem.Provider.ID),
272		ReasoningEffort: model.DefaultReasoningEffort,
273		MaxTokens:       model.DefaultMaxTokens,
274	}
275
276	err := cfg.UpdatePreferredModel(config.SelectedModelTypeLarge, selectedModel)
277	if err != nil {
278		return util.ReportError(err)
279	}
280
281	// Now lets automatically setup the small model
282	knownProvider, err := s.getProvider(selectedItem.Provider.ID)
283	if err != nil {
284		return util.ReportError(err)
285	}
286	if knownProvider == nil {
287		// for local provider we just use the same model
288		err = cfg.UpdatePreferredModel(config.SelectedModelTypeSmall, selectedModel)
289		if err != nil {
290			return util.ReportError(err)
291		}
292	} else {
293		smallModel := knownProvider.DefaultSmallModelID
294		model := cfg.GetModel(string(selectedItem.Provider.ID), smallModel)
295		// should never happen
296		if model == nil {
297			err = cfg.UpdatePreferredModel(config.SelectedModelTypeSmall, selectedModel)
298			if err != nil {
299				return util.ReportError(err)
300			}
301			return nil
302		}
303		smallSelectedModel := config.SelectedModel{
304			Model:           smallModel,
305			Provider:        string(selectedItem.Provider.ID),
306			ReasoningEffort: model.DefaultReasoningEffort,
307			MaxTokens:       model.DefaultMaxTokens,
308		}
309		err = cfg.UpdatePreferredModel(config.SelectedModelTypeSmall, smallSelectedModel)
310		if err != nil {
311			return util.ReportError(err)
312		}
313	}
314	cfg.SetupAgents()
315	return nil
316}
317
318func (s *splashCmp) getProvider(providerID provider.InferenceProvider) (*provider.Provider, error) {
319	providers, err := config.Providers()
320	if err != nil {
321		return nil, err
322	}
323	for _, p := range providers {
324		if p.ID == providerID {
325			return &p, nil
326		}
327	}
328	return nil, nil
329}
330
331func (s *splashCmp) isProviderConfigured(providerID string) bool {
332	cfg := config.Get()
333	if _, ok := cfg.Providers[providerID]; ok {
334		return true
335	}
336	return false
337}
338
339func (s *splashCmp) View() string {
340	t := styles.CurrentTheme()
341	var content string
342	if s.needsAPIKey {
343		remainingHeight := s.height - lipgloss.Height(s.logoRendered) - (SplashScreenPaddingY * 2)
344		apiKeyView := t.S().Base.PaddingLeft(1).Render(s.apiKeyInput.View())
345		apiKeySelector := t.S().Base.AlignVertical(lipgloss.Bottom).Height(remainingHeight).Render(
346			lipgloss.JoinVertical(
347				lipgloss.Left,
348				apiKeyView,
349			),
350		)
351		content = lipgloss.JoinVertical(
352			lipgloss.Left,
353			s.logoRendered,
354			apiKeySelector,
355		)
356	} else if s.isOnboarding {
357		modelListView := s.modelList.View()
358		remainingHeight := s.height - lipgloss.Height(s.logoRendered) - (SplashScreenPaddingY * 2)
359		modelSelector := t.S().Base.AlignVertical(lipgloss.Bottom).Height(remainingHeight).Render(
360			lipgloss.JoinVertical(
361				lipgloss.Left,
362				t.S().Base.PaddingLeft(1).Foreground(t.Primary).Render("Choose a Model"),
363				"",
364				modelListView,
365			),
366		)
367		content = lipgloss.JoinVertical(
368			lipgloss.Left,
369			s.logoRendered,
370			modelSelector,
371		)
372	} else if s.needsProjectInit {
373		titleStyle := t.S().Base.Foreground(t.FgBase)
374		bodyStyle := t.S().Base.Foreground(t.FgMuted)
375		shortcutStyle := t.S().Base.Foreground(t.Success)
376
377		initText := lipgloss.JoinVertical(
378			lipgloss.Left,
379			titleStyle.Render("Would you like to initialize this project?"),
380			"",
381			bodyStyle.Render("When I initialize your codebase I examine the project and put the"),
382			bodyStyle.Render("result into a CRUSH.md file which serves as general context."),
383			"",
384			bodyStyle.Render("You can also initialize anytime via ")+shortcutStyle.Render("ctrl+p")+bodyStyle.Render("."),
385			"",
386			bodyStyle.Render("Would you like to initialize now?"),
387		)
388
389		yesButton := core.SelectableButton(core.ButtonOpts{
390			Text:           "Yep!",
391			UnderlineIndex: 0,
392			Selected:       !s.selectedNo,
393		})
394
395		noButton := core.SelectableButton(core.ButtonOpts{
396			Text:           "Nope",
397			UnderlineIndex: 0,
398			Selected:       s.selectedNo,
399		})
400
401		buttons := lipgloss.JoinHorizontal(lipgloss.Left, yesButton, "  ", noButton)
402		infoSection := s.infoSection()
403
404		remainingHeight := s.height - lipgloss.Height(s.logoRendered) - (SplashScreenPaddingY * 2) - lipgloss.Height(infoSection)
405
406		initContent := t.S().Base.AlignVertical(lipgloss.Bottom).PaddingLeft(1).Height(remainingHeight).Render(
407			lipgloss.JoinVertical(
408				lipgloss.Left,
409				initText,
410				"",
411				buttons,
412			),
413		)
414
415		content = lipgloss.JoinVertical(
416			lipgloss.Left,
417			s.logoRendered,
418			infoSection,
419			initContent,
420		)
421	} else {
422		parts := []string{
423			s.logoRendered,
424			s.infoSection(),
425		}
426		content = lipgloss.JoinVertical(lipgloss.Left, parts...)
427	}
428
429	return t.S().Base.
430		Width(s.width).
431		Height(s.height).
432		PaddingTop(SplashScreenPaddingY).
433		PaddingBottom(SplashScreenPaddingY).
434		Render(content)
435}
436
437func (s *splashCmp) Cursor() *tea.Cursor {
438	if s.needsAPIKey {
439		cursor := s.apiKeyInput.Cursor()
440		if cursor != nil {
441			return s.moveCursor(cursor)
442		}
443	} else if s.isOnboarding {
444		cursor := s.modelList.Cursor()
445		if cursor != nil {
446			return s.moveCursor(cursor)
447		}
448	} else {
449		return nil
450	}
451	return nil
452}
453
454func (s *splashCmp) infoSection() string {
455	t := styles.CurrentTheme()
456	return t.S().Base.PaddingLeft(2).Render(
457		lipgloss.JoinVertical(
458			lipgloss.Left,
459			s.cwd(),
460			"",
461			lipgloss.JoinHorizontal(lipgloss.Left, s.lspBlock(), s.mcpBlock()),
462			"",
463		),
464	)
465}
466
467func (s *splashCmp) logoBlock() string {
468	t := styles.CurrentTheme()
469	return t.S().Base.Padding(0, 2).Width(s.width).Render(
470		logo.Render(version.Version, false, logo.Opts{
471			FieldColor:   t.Primary,
472			TitleColorA:  t.Secondary,
473			TitleColorB:  t.Primary,
474			CharmColor:   t.Secondary,
475			VersionColor: t.Primary,
476			Width:        s.width - 4,
477		}),
478	)
479}
480
481func (s *splashCmp) moveCursor(cursor *tea.Cursor) *tea.Cursor {
482	if cursor == nil {
483		return nil
484	}
485	// Calculate the correct Y offset based on current state
486	logoHeight := lipgloss.Height(s.logoRendered)
487	if s.needsAPIKey {
488		infoSectionHeight := lipgloss.Height(s.infoSection())
489		baseOffset := logoHeight + SplashScreenPaddingY + infoSectionHeight
490		remainingHeight := s.height - baseOffset - lipgloss.Height(s.apiKeyInput.View()) - SplashScreenPaddingY
491		offset := baseOffset + remainingHeight
492		cursor.Y += offset
493		cursor.X = cursor.X + 1
494	} else if s.isOnboarding {
495		offset := logoHeight + SplashScreenPaddingY + s.logoGap() + 3
496		cursor.Y += offset
497		cursor.X = cursor.X + 1
498	}
499
500	return cursor
501}
502
503func (s *splashCmp) logoGap() int {
504	if s.height > 35 {
505		return LogoGap
506	}
507	return 0
508}
509
510// Bindings implements SplashPage.
511func (s *splashCmp) Bindings() []key.Binding {
512	if s.needsAPIKey {
513		return []key.Binding{
514			s.keyMap.Select,
515			s.keyMap.Back,
516		}
517	} else if s.isOnboarding {
518		return []key.Binding{
519			s.keyMap.Select,
520			s.keyMap.Next,
521			s.keyMap.Previous,
522		}
523	} else if s.needsProjectInit {
524		return []key.Binding{
525			s.keyMap.Select,
526			s.keyMap.Yes,
527			s.keyMap.No,
528			s.keyMap.Tab,
529			s.keyMap.LeftRight,
530		}
531	}
532	return []key.Binding{}
533}
534
535func (s *splashCmp) getMaxInfoWidth() int {
536	return min(s.width-2, 40) // 2 for left padding
537}
538
539func (s *splashCmp) cwd() string {
540	cwd := config.Get().WorkingDir()
541	t := styles.CurrentTheme()
542	homeDir, err := os.UserHomeDir()
543	if err == nil && cwd != homeDir {
544		cwd = strings.ReplaceAll(cwd, homeDir, "~")
545	}
546	maxWidth := s.getMaxInfoWidth()
547	return t.S().Muted.Width(maxWidth).Render(cwd)
548}
549
550func LSPList(maxWidth int) []string {
551	t := styles.CurrentTheme()
552	lspList := []string{}
553	lsp := config.Get().LSP.Sorted()
554	if len(lsp) == 0 {
555		return []string{t.S().Base.Foreground(t.Border).Render("None")}
556	}
557	for _, l := range lsp {
558		iconColor := t.Success
559		if l.LSP.Disabled {
560			iconColor = t.FgMuted
561		}
562		lspList = append(lspList,
563			core.Status(
564				core.StatusOpts{
565					IconColor:   iconColor,
566					Title:       l.Name,
567					Description: l.LSP.Command,
568				},
569				maxWidth,
570			),
571		)
572	}
573	return lspList
574}
575
576func (s *splashCmp) lspBlock() string {
577	t := styles.CurrentTheme()
578	maxWidth := s.getMaxInfoWidth() / 2
579	section := t.S().Subtle.Render("LSPs")
580	lspList := append([]string{section, ""}, LSPList(maxWidth-1)...)
581	return t.S().Base.Width(maxWidth).PaddingRight(1).Render(
582		lipgloss.JoinVertical(
583			lipgloss.Left,
584			lspList...,
585		),
586	)
587}
588
589func MCPList(maxWidth int) []string {
590	t := styles.CurrentTheme()
591	mcpList := []string{}
592	mcps := config.Get().MCP.Sorted()
593	if len(mcps) == 0 {
594		return []string{t.S().Base.Foreground(t.Border).Render("None")}
595	}
596	for _, l := range mcps {
597		iconColor := t.Success
598		if l.MCP.Disabled {
599			iconColor = t.FgMuted
600		}
601		mcpList = append(mcpList,
602			core.Status(
603				core.StatusOpts{
604					IconColor:   iconColor,
605					Title:       l.Name,
606					Description: l.MCP.Command,
607				},
608				maxWidth,
609			),
610		)
611	}
612	return mcpList
613}
614
615func (s *splashCmp) mcpBlock() string {
616	t := styles.CurrentTheme()
617	maxWidth := s.getMaxInfoWidth() / 2
618	section := t.S().Subtle.Render("MCPs")
619	mcpList := append([]string{section, ""}, MCPList(maxWidth-1)...)
620	return t.S().Base.Width(maxWidth).PaddingRight(1).Render(
621		lipgloss.JoinVertical(
622			lipgloss.Left,
623			mcpList...,
624		),
625	)
626}
627
628func (s *splashCmp) IsShowingAPIKey() bool {
629	return s.needsAPIKey
630}