splash.go

  1package splash
  2
  3import (
  4	"fmt"
  5	"log/slog"
  6	"slices"
  7
  8	"github.com/charmbracelet/bubbles/v2/key"
  9	tea "github.com/charmbracelet/bubbletea/v2"
 10	"github.com/charmbracelet/crush/internal/config"
 11	"github.com/charmbracelet/crush/internal/fur/provider"
 12	"github.com/charmbracelet/crush/internal/tui/components/chat"
 13	"github.com/charmbracelet/crush/internal/tui/components/completions"
 14	"github.com/charmbracelet/crush/internal/tui/components/core"
 15	"github.com/charmbracelet/crush/internal/tui/components/core/layout"
 16	"github.com/charmbracelet/crush/internal/tui/components/core/list"
 17	"github.com/charmbracelet/crush/internal/tui/components/dialogs/models"
 18	"github.com/charmbracelet/crush/internal/tui/components/logo"
 19	"github.com/charmbracelet/crush/internal/tui/styles"
 20	"github.com/charmbracelet/crush/internal/tui/util"
 21	"github.com/charmbracelet/crush/internal/version"
 22	"github.com/charmbracelet/lipgloss/v2"
 23)
 24
 25type Splash interface {
 26	util.Model
 27	layout.Sizeable
 28	layout.Help
 29	Cursor() *tea.Cursor
 30	// SetOnboarding controls whether the splash shows model selection UI
 31	SetOnboarding(bool)
 32	// SetProjectInit controls whether the splash shows project initialization prompt
 33	SetProjectInit(bool)
 34}
 35
 36const (
 37	SplashScreenPaddingX = 2 // Padding X for the splash screen
 38	SplashScreenPaddingY = 1 // Padding Y for the splash screen
 39)
 40
 41// OnboardingCompleteMsg is sent when onboarding is complete
 42type OnboardingCompleteMsg struct{}
 43
 44type splashCmp struct {
 45	width, height int
 46	keyMap        KeyMap
 47	logoRendered  string
 48
 49	// State
 50	isOnboarding     bool
 51	needsProjectInit bool
 52	needsAPIKey      bool
 53	selectedNo       bool
 54
 55	modelList     *models.ModelListComponent
 56	apiKeyInput   *models.APIKeyInput
 57	selectedModel *models.ModelOption
 58}
 59
 60func New() Splash {
 61	keyMap := DefaultKeyMap()
 62	listKeyMap := list.DefaultKeyMap()
 63	listKeyMap.Down.SetEnabled(false)
 64	listKeyMap.Up.SetEnabled(false)
 65	listKeyMap.HalfPageDown.SetEnabled(false)
 66	listKeyMap.HalfPageUp.SetEnabled(false)
 67	listKeyMap.Home.SetEnabled(false)
 68	listKeyMap.End.SetEnabled(false)
 69	listKeyMap.DownOneItem = keyMap.Next
 70	listKeyMap.UpOneItem = keyMap.Previous
 71
 72	t := styles.CurrentTheme()
 73	inputStyle := t.S().Base.Padding(0, 1, 0, 1)
 74	modelList := models.NewModelListComponent(listKeyMap, inputStyle, "Find your fave")
 75	apiKeyInput := models.NewAPIKeyInput()
 76
 77	return &splashCmp{
 78		width:        0,
 79		height:       0,
 80		keyMap:       keyMap,
 81		logoRendered: "",
 82		modelList:    modelList,
 83		apiKeyInput:  apiKeyInput,
 84		selectedNo:   false,
 85	}
 86}
 87
 88func (s *splashCmp) SetOnboarding(onboarding bool) {
 89	s.isOnboarding = onboarding
 90	if onboarding {
 91		providers, err := config.Providers()
 92		if err != nil {
 93			return
 94		}
 95		filteredProviders := []provider.Provider{}
 96		simpleProviders := []string{
 97			"anthropic",
 98			"openai",
 99			"gemini",
100			"xai",
101			"openrouter",
102		}
103		for _, p := range providers {
104			if slices.Contains(simpleProviders, string(p.ID)) {
105				filteredProviders = append(filteredProviders, p)
106			}
107		}
108		s.modelList.SetProviders(filteredProviders)
109	}
110}
111
112func (s *splashCmp) SetProjectInit(needsInit bool) {
113	s.needsProjectInit = needsInit
114}
115
116// GetSize implements SplashPage.
117func (s *splashCmp) GetSize() (int, int) {
118	return s.width, s.height
119}
120
121// Init implements SplashPage.
122func (s *splashCmp) Init() tea.Cmd {
123	return tea.Batch(s.modelList.Init(), s.apiKeyInput.Init())
124}
125
126// SetSize implements SplashPage.
127func (s *splashCmp) SetSize(width int, height int) tea.Cmd {
128	s.width = width
129	s.height = height
130	s.logoRendered = s.logoBlock()
131	listHeigh := min(40, height-(SplashScreenPaddingY*2)-lipgloss.Height(s.logoRendered)-2) // -1 for the title
132	listWidth := min(60, width-(SplashScreenPaddingX*2))
133
134	return s.modelList.SetSize(listWidth, listHeigh)
135}
136
137// Update implements SplashPage.
138func (s *splashCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
139	switch msg := msg.(type) {
140	case tea.WindowSizeMsg:
141		return s, s.SetSize(msg.Width, msg.Height)
142	case tea.KeyPressMsg:
143		switch {
144		case key.Matches(msg, s.keyMap.Back):
145			slog.Info("Back key pressed in splash screen")
146			if s.needsAPIKey {
147				// Go back to model selection
148				s.needsAPIKey = false
149				s.selectedModel = nil
150				return s, nil
151			}
152		case key.Matches(msg, s.keyMap.Select):
153			if s.isOnboarding && !s.needsAPIKey {
154				modelInx := s.modelList.SelectedIndex()
155				items := s.modelList.Items()
156				selectedItem := items[modelInx].(completions.CompletionItem).Value().(models.ModelOption)
157				if s.isProviderConfigured(string(selectedItem.Provider.ID)) {
158					cmd := s.setPreferredModel(selectedItem)
159					s.isOnboarding = false
160					return s, tea.Batch(cmd, util.CmdHandler(OnboardingCompleteMsg{}))
161				} else {
162					// Provider not configured, show API key input
163					s.needsAPIKey = true
164					s.selectedModel = &selectedItem
165					s.apiKeyInput.SetProviderName(selectedItem.Provider.Name)
166					return s, nil
167				}
168			} else if s.needsAPIKey {
169				// Handle API key submission
170				apiKey := s.apiKeyInput.Value()
171				if apiKey != "" {
172					return s, s.saveAPIKeyAndContinue(apiKey)
173				}
174			} else if s.needsProjectInit {
175				return s, s.initializeProject()
176			}
177		case key.Matches(msg, s.keyMap.Tab, s.keyMap.LeftRight):
178			if s.needsProjectInit {
179				s.selectedNo = !s.selectedNo
180				return s, nil
181			}
182		case key.Matches(msg, s.keyMap.Yes):
183			if s.needsProjectInit {
184				return s, s.initializeProject()
185			}
186		case key.Matches(msg, s.keyMap.No):
187			if s.needsProjectInit {
188				s.needsProjectInit = false
189				return s, util.CmdHandler(OnboardingCompleteMsg{})
190			}
191		default:
192			if s.needsAPIKey {
193				u, cmd := s.apiKeyInput.Update(msg)
194				s.apiKeyInput = u.(*models.APIKeyInput)
195				return s, cmd
196			} else if s.isOnboarding {
197				u, cmd := s.modelList.Update(msg)
198				s.modelList = u
199				return s, cmd
200			}
201		}
202	case tea.PasteMsg:
203		if s.needsAPIKey {
204			u, cmd := s.apiKeyInput.Update(msg)
205			s.apiKeyInput = u.(*models.APIKeyInput)
206			return s, cmd
207		} else if s.isOnboarding {
208			var cmd tea.Cmd
209			s.modelList, cmd = s.modelList.Update(msg)
210			return s, cmd
211		}
212	}
213	return s, nil
214}
215
216func (s *splashCmp) saveAPIKeyAndContinue(apiKey string) tea.Cmd {
217	if s.selectedModel == nil {
218		return util.ReportError(fmt.Errorf("no model selected"))
219	}
220
221	cfg := config.Get()
222	err := cfg.SetProviderAPIKey(string(s.selectedModel.Provider.ID), apiKey)
223	if err != nil {
224		return util.ReportError(fmt.Errorf("failed to save API key: %w", err))
225	}
226
227	// Reset API key state and continue with model selection
228	s.needsAPIKey = false
229	cmd := s.setPreferredModel(*s.selectedModel)
230	s.isOnboarding = false
231	s.selectedModel = nil
232
233	return tea.Batch(cmd, util.CmdHandler(OnboardingCompleteMsg{}))
234}
235
236func (s *splashCmp) initializeProject() tea.Cmd {
237	s.needsProjectInit = false
238	prompt := `Please analyze this codebase and create a CRUSH.md file containing:
2391. Build/lint/test commands - especially for running a single test
2402. Code style guidelines including imports, formatting, types, naming conventions, error handling, etc.
241
242The file you create will be given to agentic coding agents (such as yourself) that operate in this repository. Make it about 20 lines long.
243If there's already a CRUSH.md, improve it.
244If there are Cursor rules (in .cursor/rules/ or .cursorrules) or Copilot rules (in .github/copilot-instructions.md), make sure to include them.
245Add the .crush directory to the .gitignore file if it's not already there.`
246
247	if err := config.MarkProjectInitialized(); err != nil {
248		return util.ReportError(err)
249	}
250	var cmds []tea.Cmd
251
252	cmds = append(cmds, util.CmdHandler(OnboardingCompleteMsg{}))
253	if !s.selectedNo {
254		cmds = append(cmds,
255			util.CmdHandler(chat.SessionClearedMsg{}),
256			util.CmdHandler(chat.SendMsg{
257				Text: prompt,
258			}),
259		)
260	}
261	return tea.Sequence(cmds...)
262}
263
264func (s *splashCmp) setPreferredModel(selectedItem models.ModelOption) tea.Cmd {
265	cfg := config.Get()
266	model := cfg.GetModel(string(selectedItem.Provider.ID), selectedItem.Model.ID)
267	if model == nil {
268		return util.ReportError(fmt.Errorf("model %s not found for provider %s", selectedItem.Model.ID, selectedItem.Provider.ID))
269	}
270
271	selectedModel := config.SelectedModel{
272		Model:           selectedItem.Model.ID,
273		Provider:        string(selectedItem.Provider.ID),
274		ReasoningEffort: model.DefaultReasoningEffort,
275		MaxTokens:       model.DefaultMaxTokens,
276	}
277
278	err := cfg.UpdatePreferredModel(config.SelectedModelTypeLarge, selectedModel)
279	if err != nil {
280		return util.ReportError(err)
281	}
282
283	// Now lets automatically setup the small model
284	knownProvider, err := s.getProvider(selectedItem.Provider.ID)
285	if err != nil {
286		return util.ReportError(err)
287	}
288	if knownProvider == nil {
289		// for local provider we just use the same model
290		err = cfg.UpdatePreferredModel(config.SelectedModelTypeSmall, selectedModel)
291		if err != nil {
292			return util.ReportError(err)
293		}
294	} else {
295		smallModel := knownProvider.DefaultSmallModelID
296		model := cfg.GetModel(string(selectedItem.Provider.ID), smallModel)
297		// should never happen
298		if model == nil {
299			err = cfg.UpdatePreferredModel(config.SelectedModelTypeSmall, selectedModel)
300			if err != nil {
301				return util.ReportError(err)
302			}
303			return nil
304		}
305		smallSelectedModel := config.SelectedModel{
306			Model:           smallModel,
307			Provider:        string(selectedItem.Provider.ID),
308			ReasoningEffort: model.DefaultReasoningEffort,
309			MaxTokens:       model.DefaultMaxTokens,
310		}
311		err = cfg.UpdatePreferredModel(config.SelectedModelTypeSmall, smallSelectedModel)
312		if err != nil {
313			return util.ReportError(err)
314		}
315	}
316	return nil
317}
318
319func (s *splashCmp) getProvider(providerID provider.InferenceProvider) (*provider.Provider, error) {
320	providers, err := config.Providers()
321	if err != nil {
322		return nil, err
323	}
324	for _, p := range providers {
325		if p.ID == providerID {
326			return &p, nil
327		}
328	}
329	return nil, nil
330}
331
332func (s *splashCmp) isProviderConfigured(providerID string) bool {
333	cfg := config.Get()
334	if _, ok := cfg.Providers[providerID]; ok {
335		return true
336	}
337	return false
338}
339
340func (s *splashCmp) View() string {
341	t := styles.CurrentTheme()
342
343	var content string
344	if s.needsAPIKey {
345		remainingHeight := s.height - lipgloss.Height(s.logoRendered) - (SplashScreenPaddingY * 2)
346		apiKeyView := s.apiKeyInput.View()
347		apiKeySelector := t.S().Base.AlignVertical(lipgloss.Bottom).Height(remainingHeight).Render(
348			lipgloss.JoinVertical(
349				lipgloss.Left,
350				apiKeyView,
351			),
352		)
353		content = lipgloss.JoinVertical(
354			lipgloss.Left,
355			s.logoRendered,
356			apiKeySelector,
357		)
358	} else if s.isOnboarding {
359		remainingHeight := s.height - lipgloss.Height(s.logoRendered) - (SplashScreenPaddingY * 2)
360		modelListView := s.modelList.View()
361		modelSelector := t.S().Base.AlignVertical(lipgloss.Bottom).Height(remainingHeight).Render(
362			lipgloss.JoinVertical(
363				lipgloss.Left,
364				t.S().Base.PaddingLeft(1).Foreground(t.Primary).Render("Choose a Model"),
365				"",
366				modelListView,
367			),
368		)
369		content = lipgloss.JoinVertical(
370			lipgloss.Left,
371			s.logoRendered,
372			modelSelector,
373		)
374	} else if s.needsProjectInit {
375		t := styles.CurrentTheme()
376
377		titleStyle := t.S().Base.Foreground(t.FgBase)
378		bodyStyle := t.S().Base.Foreground(t.FgMuted)
379		shortcutStyle := t.S().Base.Foreground(t.Success)
380
381		initText := lipgloss.JoinVertical(
382			lipgloss.Left,
383			titleStyle.Render("Would you like to initialize this project?"),
384			"",
385			bodyStyle.Render("When I initialize your codebase I examine the project and put the"),
386			bodyStyle.Render("result into a CRUSH.md file which serves as general context."),
387			"",
388			bodyStyle.Render("You can also initialize anytime via ")+shortcutStyle.Render("ctrl+p")+bodyStyle.Render("."),
389			"",
390			bodyStyle.Render("Would you like to initialize now?"),
391		)
392
393		yesButton := core.SelectableButton(core.ButtonOpts{
394			Text:           "Yep!",
395			UnderlineIndex: 0,
396			Selected:       !s.selectedNo,
397		})
398
399		noButton := core.SelectableButton(core.ButtonOpts{
400			Text:           "Nope",
401			UnderlineIndex: 0,
402			Selected:       s.selectedNo,
403		})
404
405		buttons := lipgloss.JoinHorizontal(lipgloss.Left, yesButton, "  ", noButton)
406
407		remainingHeight := s.height - lipgloss.Height(s.logoRendered) - (SplashScreenPaddingY * 2)
408
409		initContent := t.S().Base.AlignVertical(lipgloss.Bottom).Height(remainingHeight).Render(
410			lipgloss.JoinVertical(
411				lipgloss.Left,
412				initText,
413				"",
414				buttons,
415			),
416		)
417
418		content = lipgloss.JoinVertical(
419			lipgloss.Left,
420			s.logoRendered,
421			initContent,
422		)
423	} else {
424		content = s.logoRendered
425	}
426
427	return t.S().Base.
428		Width(s.width).
429		Height(s.height).
430		PaddingTop(SplashScreenPaddingY).
431		PaddingLeft(SplashScreenPaddingX).
432		PaddingRight(SplashScreenPaddingX).
433		PaddingBottom(SplashScreenPaddingY).
434		Render(content)
435}
436
437func (s *splashCmp) Cursor() *tea.Cursor {
438	if s.needsAPIKey {
439		cursor := s.apiKeyInput.Cursor()
440		if cursor != nil {
441			return s.moveCursor(cursor)
442		}
443	} else if s.isOnboarding {
444		cursor := s.modelList.Cursor()
445		if cursor != nil {
446			return s.moveCursor(cursor)
447		}
448	} else {
449		return nil
450	}
451	return nil
452}
453
454func (s *splashCmp) logoBlock() string {
455	t := styles.CurrentTheme()
456	const padding = 2
457	return logo.Render(version.Version, false, logo.Opts{
458		FieldColor:   t.Primary,
459		TitleColorA:  t.Secondary,
460		TitleColorB:  t.Primary,
461		CharmColor:   t.Secondary,
462		VersionColor: t.Primary,
463		Width:        s.width - (SplashScreenPaddingX * 2),
464	})
465}
466
467func (m *splashCmp) moveCursor(cursor *tea.Cursor) *tea.Cursor {
468	if cursor == nil {
469		return nil
470	}
471
472	// Calculate the correct Y offset based on current state
473	logoHeight := lipgloss.Height(m.logoRendered)
474	baseOffset := logoHeight + SplashScreenPaddingY
475
476	if m.needsAPIKey {
477		// For API key input, position at the bottom of the remaining space
478		remainingHeight := m.height - logoHeight - (SplashScreenPaddingY * 2)
479		offset := baseOffset + remainingHeight - lipgloss.Height(m.apiKeyInput.View())
480		cursor.Y += offset
481		// API key input already includes prompt in its cursor positioning
482		cursor.X = cursor.X + SplashScreenPaddingX
483	} else if m.isOnboarding {
484		// For model list, use the original calculation
485		listHeight := min(40, m.height-(SplashScreenPaddingY*2)-logoHeight-2)
486		offset := m.height - listHeight
487		cursor.Y += offset
488		// Model list doesn't have a prompt, so add padding + space for list styling
489		cursor.X = cursor.X + SplashScreenPaddingX + 1
490	}
491
492	return cursor
493}
494
495// Bindings implements SplashPage.
496func (s *splashCmp) Bindings() []key.Binding {
497	if s.needsAPIKey {
498		return []key.Binding{
499			s.keyMap.Select,
500			s.keyMap.Back,
501		}
502	} else if s.isOnboarding {
503		return []key.Binding{
504			s.keyMap.Select,
505			s.keyMap.Next,
506			s.keyMap.Previous,
507		}
508	} else if s.needsProjectInit {
509		return []key.Binding{
510			s.keyMap.Select,
511			s.keyMap.Yes,
512			s.keyMap.No,
513			s.keyMap.Tab,
514			s.keyMap.LeftRight,
515		}
516	}
517	return []key.Binding{}
518}