1package splash
  2
  3import (
  4	"fmt"
  5	"os"
  6	"slices"
  7	"strings"
  8
  9	"github.com/charmbracelet/bubbles/v2/key"
 10	tea "github.com/charmbracelet/bubbletea/v2"
 11	"github.com/charmbracelet/crush/internal/config"
 12	"github.com/charmbracelet/crush/internal/fur/provider"
 13	"github.com/charmbracelet/crush/internal/llm/prompt"
 14	"github.com/charmbracelet/crush/internal/tui/components/chat"
 15	"github.com/charmbracelet/crush/internal/tui/components/completions"
 16	"github.com/charmbracelet/crush/internal/tui/components/core"
 17	"github.com/charmbracelet/crush/internal/tui/components/core/layout"
 18	"github.com/charmbracelet/crush/internal/tui/components/core/list"
 19	"github.com/charmbracelet/crush/internal/tui/components/dialogs/models"
 20	"github.com/charmbracelet/crush/internal/tui/components/logo"
 21	"github.com/charmbracelet/crush/internal/tui/styles"
 22	"github.com/charmbracelet/crush/internal/tui/util"
 23	"github.com/charmbracelet/crush/internal/version"
 24	"github.com/charmbracelet/lipgloss/v2"
 25)
 26
 27type Splash interface {
 28	util.Model
 29	layout.Sizeable
 30	layout.Help
 31	Cursor() *tea.Cursor
 32	// SetOnboarding controls whether the splash shows model selection UI
 33	SetOnboarding(bool)
 34	// SetProjectInit controls whether the splash shows project initialization prompt
 35	SetProjectInit(bool)
 36
 37	// Showing API key input
 38	IsShowingAPIKey() bool
 39}
 40
 41const (
 42	SplashScreenPaddingY = 1 // Padding Y for the splash screen
 43
 44	LogoGap = 6
 45)
 46
 47// OnboardingCompleteMsg is sent when onboarding is complete
 48type OnboardingCompleteMsg struct{}
 49
 50type splashCmp struct {
 51	width, height int
 52	keyMap        KeyMap
 53	logoRendered  string
 54
 55	// State
 56	isOnboarding     bool
 57	needsProjectInit bool
 58	needsAPIKey      bool
 59	selectedNo       bool
 60
 61	listHeight    int
 62	modelList     *models.ModelListComponent
 63	apiKeyInput   *models.APIKeyInput
 64	selectedModel *models.ModelOption
 65}
 66
 67func New() Splash {
 68	keyMap := DefaultKeyMap()
 69	listKeyMap := list.DefaultKeyMap()
 70	listKeyMap.Down.SetEnabled(false)
 71	listKeyMap.Up.SetEnabled(false)
 72	listKeyMap.HalfPageDown.SetEnabled(false)
 73	listKeyMap.HalfPageUp.SetEnabled(false)
 74	listKeyMap.Home.SetEnabled(false)
 75	listKeyMap.End.SetEnabled(false)
 76	listKeyMap.DownOneItem = keyMap.Next
 77	listKeyMap.UpOneItem = keyMap.Previous
 78
 79	t := styles.CurrentTheme()
 80	inputStyle := t.S().Base.Padding(0, 1, 0, 1)
 81	modelList := models.NewModelListComponent(listKeyMap, inputStyle, "Find your fave")
 82	apiKeyInput := models.NewAPIKeyInput()
 83
 84	return &splashCmp{
 85		width:        0,
 86		height:       0,
 87		keyMap:       keyMap,
 88		logoRendered: "",
 89		modelList:    modelList,
 90		apiKeyInput:  apiKeyInput,
 91		selectedNo:   false,
 92	}
 93}
 94
 95func (s *splashCmp) SetOnboarding(onboarding bool) {
 96	s.isOnboarding = onboarding
 97	if onboarding {
 98		providers, err := config.Providers()
 99		if err != nil {
100			return
101		}
102		filteredProviders := []provider.Provider{}
103		simpleProviders := []string{
104			"anthropic",
105			"openai",
106			"gemini",
107			"xai",
108			"groq",
109			"openrouter",
110		}
111		for _, p := range providers {
112			if slices.Contains(simpleProviders, string(p.ID)) {
113				filteredProviders = append(filteredProviders, p)
114			}
115		}
116		s.modelList.SetProviders(filteredProviders)
117	}
118}
119
120func (s *splashCmp) SetProjectInit(needsInit bool) {
121	s.needsProjectInit = needsInit
122}
123
124// GetSize implements SplashPage.
125func (s *splashCmp) GetSize() (int, int) {
126	return s.width, s.height
127}
128
129// Init implements SplashPage.
130func (s *splashCmp) Init() tea.Cmd {
131	return tea.Batch(s.modelList.Init(), s.apiKeyInput.Init())
132}
133
134// SetSize implements SplashPage.
135func (s *splashCmp) SetSize(width int, height int) tea.Cmd {
136	wasSmallScreen := s.isSmallScreen()
137	rerenderLogo := width != s.width
138	s.height = height
139	s.width = width
140	if rerenderLogo || wasSmallScreen != s.isSmallScreen() {
141		s.logoRendered = s.logoBlock()
142	}
143	// remove padding, logo height, gap, title space
144	s.listHeight = s.height - lipgloss.Height(s.logoRendered) - (SplashScreenPaddingY * 2) - s.logoGap() - 2
145	listWidth := min(60, width)
146	return s.modelList.SetSize(listWidth, s.listHeight)
147}
148
149// Update implements SplashPage.
150func (s *splashCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
151	switch msg := msg.(type) {
152	case tea.WindowSizeMsg:
153		return s, s.SetSize(msg.Width, msg.Height)
154	case tea.KeyPressMsg:
155		switch {
156		case key.Matches(msg, s.keyMap.Back):
157			if s.needsAPIKey {
158				// Go back to model selection
159				s.needsAPIKey = false
160				s.selectedModel = nil
161				return s, nil
162			}
163		case key.Matches(msg, s.keyMap.Select):
164			if s.isOnboarding && !s.needsAPIKey {
165				modelInx := s.modelList.SelectedIndex()
166				items := s.modelList.Items()
167				selectedItem := items[modelInx].(completions.CompletionItem).Value().(models.ModelOption)
168				if s.isProviderConfigured(string(selectedItem.Provider.ID)) {
169					cmd := s.setPreferredModel(selectedItem)
170					s.isOnboarding = false
171					return s, tea.Batch(cmd, util.CmdHandler(OnboardingCompleteMsg{}))
172				} else {
173					// Provider not configured, show API key input
174					s.needsAPIKey = true
175					s.selectedModel = &selectedItem
176					s.apiKeyInput.SetProviderName(selectedItem.Provider.Name)
177					return s, nil
178				}
179			} else if s.needsAPIKey {
180				// Handle API key submission
181				apiKey := s.apiKeyInput.Value()
182				if apiKey != "" {
183					return s, s.saveAPIKeyAndContinue(apiKey)
184				}
185			} else if s.needsProjectInit {
186				return s, s.initializeProject()
187			}
188		case key.Matches(msg, s.keyMap.Tab, s.keyMap.LeftRight):
189			if s.needsProjectInit {
190				s.selectedNo = !s.selectedNo
191				return s, nil
192			}
193		case key.Matches(msg, s.keyMap.Yes):
194			if s.needsProjectInit {
195				return s, s.initializeProject()
196			}
197		case key.Matches(msg, s.keyMap.No):
198			s.selectedNo = true
199			return s, s.initializeProject()
200		default:
201			if s.needsAPIKey {
202				u, cmd := s.apiKeyInput.Update(msg)
203				s.apiKeyInput = u.(*models.APIKeyInput)
204				return s, cmd
205			} else if s.isOnboarding {
206				u, cmd := s.modelList.Update(msg)
207				s.modelList = u
208				return s, cmd
209			}
210		}
211	case tea.PasteMsg:
212		if s.needsAPIKey {
213			u, cmd := s.apiKeyInput.Update(msg)
214			s.apiKeyInput = u.(*models.APIKeyInput)
215			return s, cmd
216		} else if s.isOnboarding {
217			var cmd tea.Cmd
218			s.modelList, cmd = s.modelList.Update(msg)
219			return s, cmd
220		}
221	}
222	return s, nil
223}
224
225func (s *splashCmp) saveAPIKeyAndContinue(apiKey string) tea.Cmd {
226	if s.selectedModel == nil {
227		return util.ReportError(fmt.Errorf("no model selected"))
228	}
229
230	cfg := config.Get()
231	err := cfg.SetProviderAPIKey(string(s.selectedModel.Provider.ID), apiKey)
232	if err != nil {
233		return util.ReportError(fmt.Errorf("failed to save API key: %w", err))
234	}
235
236	// Reset API key state and continue with model selection
237	s.needsAPIKey = false
238	cmd := s.setPreferredModel(*s.selectedModel)
239	s.isOnboarding = false
240	s.selectedModel = nil
241
242	return tea.Batch(cmd, util.CmdHandler(OnboardingCompleteMsg{}))
243}
244
245func (s *splashCmp) initializeProject() tea.Cmd {
246	s.needsProjectInit = false
247
248	if err := config.MarkProjectInitialized(); err != nil {
249		return util.ReportError(err)
250	}
251	var cmds []tea.Cmd
252
253	cmds = append(cmds, util.CmdHandler(OnboardingCompleteMsg{}))
254	if !s.selectedNo {
255		cmds = append(cmds,
256			util.CmdHandler(chat.SessionClearedMsg{}),
257			util.CmdHandler(chat.SendMsg{
258				Text: prompt.Initialize(),
259			}),
260		)
261	}
262	return tea.Sequence(cmds...)
263}
264
265func (s *splashCmp) setPreferredModel(selectedItem models.ModelOption) tea.Cmd {
266	cfg := config.Get()
267	model := cfg.GetModel(string(selectedItem.Provider.ID), selectedItem.Model.ID)
268	if model == nil {
269		return util.ReportError(fmt.Errorf("model %s not found for provider %s", selectedItem.Model.ID, selectedItem.Provider.ID))
270	}
271
272	selectedModel := config.SelectedModel{
273		Model:           selectedItem.Model.ID,
274		Provider:        string(selectedItem.Provider.ID),
275		ReasoningEffort: model.DefaultReasoningEffort,
276		MaxTokens:       model.DefaultMaxTokens,
277	}
278
279	err := cfg.UpdatePreferredModel(config.SelectedModelTypeLarge, selectedModel)
280	if err != nil {
281		return util.ReportError(err)
282	}
283
284	// Now lets automatically setup the small model
285	knownProvider, err := s.getProvider(selectedItem.Provider.ID)
286	if err != nil {
287		return util.ReportError(err)
288	}
289	if knownProvider == nil {
290		// for local provider we just use the same model
291		err = cfg.UpdatePreferredModel(config.SelectedModelTypeSmall, selectedModel)
292		if err != nil {
293			return util.ReportError(err)
294		}
295	} else {
296		smallModel := knownProvider.DefaultSmallModelID
297		model := cfg.GetModel(string(selectedItem.Provider.ID), smallModel)
298		// should never happen
299		if model == nil {
300			err = cfg.UpdatePreferredModel(config.SelectedModelTypeSmall, selectedModel)
301			if err != nil {
302				return util.ReportError(err)
303			}
304			return nil
305		}
306		smallSelectedModel := config.SelectedModel{
307			Model:           smallModel,
308			Provider:        string(selectedItem.Provider.ID),
309			ReasoningEffort: model.DefaultReasoningEffort,
310			MaxTokens:       model.DefaultMaxTokens,
311		}
312		err = cfg.UpdatePreferredModel(config.SelectedModelTypeSmall, smallSelectedModel)
313		if err != nil {
314			return util.ReportError(err)
315		}
316	}
317	cfg.SetupAgents()
318	return nil
319}
320
321func (s *splashCmp) getProvider(providerID provider.InferenceProvider) (*provider.Provider, error) {
322	providers, err := config.Providers()
323	if err != nil {
324		return nil, err
325	}
326	for _, p := range providers {
327		if p.ID == providerID {
328			return &p, nil
329		}
330	}
331	return nil, nil
332}
333
334func (s *splashCmp) isProviderConfigured(providerID string) bool {
335	cfg := config.Get()
336	if _, ok := cfg.Providers[providerID]; ok {
337		return true
338	}
339	return false
340}
341
342func (s *splashCmp) View() string {
343	t := styles.CurrentTheme()
344	var content string
345	if s.needsAPIKey {
346		remainingHeight := s.height - lipgloss.Height(s.logoRendered) - (SplashScreenPaddingY * 2)
347		apiKeyView := t.S().Base.PaddingLeft(1).Render(s.apiKeyInput.View())
348		apiKeySelector := t.S().Base.AlignVertical(lipgloss.Bottom).Height(remainingHeight).Render(
349			lipgloss.JoinVertical(
350				lipgloss.Left,
351				apiKeyView,
352			),
353		)
354		content = lipgloss.JoinVertical(
355			lipgloss.Left,
356			s.logoRendered,
357			apiKeySelector,
358		)
359	} else if s.isOnboarding {
360		modelListView := s.modelList.View()
361		remainingHeight := s.height - lipgloss.Height(s.logoRendered) - (SplashScreenPaddingY * 2)
362		modelSelector := t.S().Base.AlignVertical(lipgloss.Bottom).Height(remainingHeight).Render(
363			lipgloss.JoinVertical(
364				lipgloss.Left,
365				t.S().Base.PaddingLeft(1).Foreground(t.Primary).Render("Choose a Model"),
366				"",
367				modelListView,
368			),
369		)
370		content = lipgloss.JoinVertical(
371			lipgloss.Left,
372			s.logoRendered,
373			modelSelector,
374		)
375	} else if s.needsProjectInit {
376		titleStyle := t.S().Base.Foreground(t.FgBase)
377		bodyStyle := t.S().Base.Foreground(t.FgMuted)
378		shortcutStyle := t.S().Base.Foreground(t.Success)
379
380		initText := lipgloss.JoinVertical(
381			lipgloss.Left,
382			titleStyle.Render("Would you like to initialize this project?"),
383			"",
384			bodyStyle.Render("When I initialize your codebase I examine the project and put the"),
385			bodyStyle.Render("result into a CRUSH.md file which serves as general context."),
386			"",
387			bodyStyle.Render("You can also initialize anytime via ")+shortcutStyle.Render("ctrl+p")+bodyStyle.Render("."),
388			"",
389			bodyStyle.Render("Would you like to initialize now?"),
390		)
391
392		yesButton := core.SelectableButton(core.ButtonOpts{
393			Text:           "Yep!",
394			UnderlineIndex: 0,
395			Selected:       !s.selectedNo,
396		})
397
398		noButton := core.SelectableButton(core.ButtonOpts{
399			Text:           "Nope",
400			UnderlineIndex: 0,
401			Selected:       s.selectedNo,
402		})
403
404		buttons := lipgloss.JoinHorizontal(lipgloss.Left, yesButton, "  ", noButton)
405		infoSection := s.infoSection()
406
407		remainingHeight := s.height - lipgloss.Height(s.logoRendered) - (SplashScreenPaddingY * 2) - lipgloss.Height(infoSection)
408
409		initContent := t.S().Base.AlignVertical(lipgloss.Bottom).PaddingLeft(1).Height(remainingHeight).Render(
410			lipgloss.JoinVertical(
411				lipgloss.Left,
412				initText,
413				"",
414				buttons,
415			),
416		)
417
418		content = lipgloss.JoinVertical(
419			lipgloss.Left,
420			s.logoRendered,
421			infoSection,
422			initContent,
423		)
424	} else {
425		parts := []string{
426			s.logoRendered,
427			s.infoSection(),
428		}
429		content = lipgloss.JoinVertical(lipgloss.Left, parts...)
430	}
431
432	return t.S().Base.
433		Width(s.width).
434		Height(s.height).
435		PaddingTop(SplashScreenPaddingY).
436		PaddingBottom(SplashScreenPaddingY).
437		Render(content)
438}
439
440func (s *splashCmp) Cursor() *tea.Cursor {
441	if s.needsAPIKey {
442		cursor := s.apiKeyInput.Cursor()
443		if cursor != nil {
444			return s.moveCursor(cursor)
445		}
446	} else if s.isOnboarding {
447		cursor := s.modelList.Cursor()
448		if cursor != nil {
449			return s.moveCursor(cursor)
450		}
451	} else {
452		return nil
453	}
454	return nil
455}
456
457func (s *splashCmp) isSmallScreen() bool {
458	// Consider a screen small if either the width is less than 40 or if the
459	// height is less than 20
460	return s.width < 55 || s.height < 20
461}
462
463func (s *splashCmp) infoSection() string {
464	t := styles.CurrentTheme()
465	infoStyle := t.S().Base.PaddingLeft(2)
466	if s.isSmallScreen() {
467		infoStyle = infoStyle.MarginTop(1)
468	}
469	return infoStyle.Render(
470		lipgloss.JoinVertical(
471			lipgloss.Left,
472			s.cwd(),
473			"",
474			lipgloss.JoinHorizontal(lipgloss.Left, s.lspBlock(), s.mcpBlock()),
475			"",
476		),
477	)
478}
479
480func (s *splashCmp) logoBlock() string {
481	t := styles.CurrentTheme()
482	logoStyle := t.S().Base.Padding(0, 2).Width(s.width)
483	if s.isSmallScreen() {
484		// If the width is too small, render a smaller version of the logo
485		// NOTE: 20 is not correct because [splashCmp.height] is not the
486		// *actual* window height, instead, it is the height of the splash
487		// component and that depends on other variables like compact mode and
488		// the height of the editor.
489		return logoStyle.Render(
490			logo.SmallRender(s.width - logoStyle.GetHorizontalFrameSize()),
491		)
492	}
493	return logoStyle.Render(
494		logo.Render(version.Version, false, logo.Opts{
495			FieldColor:   t.Primary,
496			TitleColorA:  t.Secondary,
497			TitleColorB:  t.Primary,
498			CharmColor:   t.Secondary,
499			VersionColor: t.Primary,
500			Width:        s.width - logoStyle.GetHorizontalFrameSize(),
501		}),
502	)
503}
504
505func (s *splashCmp) moveCursor(cursor *tea.Cursor) *tea.Cursor {
506	if cursor == nil {
507		return nil
508	}
509	// Calculate the correct Y offset based on current state
510	logoHeight := lipgloss.Height(s.logoRendered)
511	if s.needsAPIKey {
512		infoSectionHeight := lipgloss.Height(s.infoSection())
513		baseOffset := logoHeight + SplashScreenPaddingY + infoSectionHeight
514		remainingHeight := s.height - baseOffset - lipgloss.Height(s.apiKeyInput.View()) - SplashScreenPaddingY
515		offset := baseOffset + remainingHeight
516		cursor.Y += offset
517		cursor.X = cursor.X + 1
518	} else if s.isOnboarding {
519		offset := logoHeight + SplashScreenPaddingY + s.logoGap() + 3
520		cursor.Y += offset
521		cursor.X = cursor.X + 1
522	}
523
524	return cursor
525}
526
527func (s *splashCmp) logoGap() int {
528	if s.height > 35 {
529		return LogoGap
530	}
531	return 0
532}
533
534// Bindings implements SplashPage.
535func (s *splashCmp) Bindings() []key.Binding {
536	if s.needsAPIKey {
537		return []key.Binding{
538			s.keyMap.Select,
539			s.keyMap.Back,
540		}
541	} else if s.isOnboarding {
542		return []key.Binding{
543			s.keyMap.Select,
544			s.keyMap.Next,
545			s.keyMap.Previous,
546		}
547	} else if s.needsProjectInit {
548		return []key.Binding{
549			s.keyMap.Select,
550			s.keyMap.Yes,
551			s.keyMap.No,
552			s.keyMap.Tab,
553			s.keyMap.LeftRight,
554		}
555	}
556	return []key.Binding{}
557}
558
559func (s *splashCmp) getMaxInfoWidth() int {
560	return min(s.width-2, 40) // 2 for left padding
561}
562
563func (s *splashCmp) cwd() string {
564	cwd := config.Get().WorkingDir()
565	t := styles.CurrentTheme()
566	homeDir, err := os.UserHomeDir()
567	if err == nil && cwd != homeDir {
568		cwd = strings.ReplaceAll(cwd, homeDir, "~")
569	}
570	maxWidth := s.getMaxInfoWidth()
571	return t.S().Muted.Width(maxWidth).Render(cwd)
572}
573
574func LSPList(maxWidth int) []string {
575	t := styles.CurrentTheme()
576	lspList := []string{}
577	lsp := config.Get().LSP.Sorted()
578	if len(lsp) == 0 {
579		return []string{t.S().Base.Foreground(t.Border).Render("None")}
580	}
581	for _, l := range lsp {
582		iconColor := t.Success
583		if l.LSP.Disabled {
584			iconColor = t.FgMuted
585		}
586		lspList = append(lspList,
587			core.Status(
588				core.StatusOpts{
589					IconColor:   iconColor,
590					Title:       l.Name,
591					Description: l.LSP.Command,
592				},
593				maxWidth,
594			),
595		)
596	}
597	return lspList
598}
599
600func (s *splashCmp) lspBlock() string {
601	t := styles.CurrentTheme()
602	maxWidth := s.getMaxInfoWidth() / 2
603	section := t.S().Subtle.Render("LSPs")
604	lspList := append([]string{section, ""}, LSPList(maxWidth-1)...)
605	return t.S().Base.Width(maxWidth).PaddingRight(1).Render(
606		lipgloss.JoinVertical(
607			lipgloss.Left,
608			lspList...,
609		),
610	)
611}
612
613func MCPList(maxWidth int) []string {
614	t := styles.CurrentTheme()
615	mcpList := []string{}
616	mcps := config.Get().MCP.Sorted()
617	if len(mcps) == 0 {
618		return []string{t.S().Base.Foreground(t.Border).Render("None")}
619	}
620	for _, l := range mcps {
621		iconColor := t.Success
622		if l.MCP.Disabled {
623			iconColor = t.FgMuted
624		}
625		mcpList = append(mcpList,
626			core.Status(
627				core.StatusOpts{
628					IconColor:   iconColor,
629					Title:       l.Name,
630					Description: l.MCP.Command,
631				},
632				maxWidth,
633			),
634		)
635	}
636	return mcpList
637}
638
639func (s *splashCmp) mcpBlock() string {
640	t := styles.CurrentTheme()
641	maxWidth := s.getMaxInfoWidth() / 2
642	section := t.S().Subtle.Render("MCPs")
643	mcpList := append([]string{section, ""}, MCPList(maxWidth-1)...)
644	return t.S().Base.Width(maxWidth).PaddingRight(1).Render(
645		lipgloss.JoinVertical(
646			lipgloss.Left,
647			mcpList...,
648		),
649	)
650}
651
652func (s *splashCmp) IsShowingAPIKey() bool {
653	return s.needsAPIKey
654}