1package splash
2
3import (
4 "fmt"
5 "slices"
6
7 "github.com/charmbracelet/bubbles/v2/key"
8 tea "github.com/charmbracelet/bubbletea/v2"
9 "github.com/charmbracelet/crush/internal/config"
10 "github.com/charmbracelet/crush/internal/fur/provider"
11 "github.com/charmbracelet/crush/internal/tui/components/chat"
12 "github.com/charmbracelet/crush/internal/tui/components/completions"
13 "github.com/charmbracelet/crush/internal/tui/components/core"
14 "github.com/charmbracelet/crush/internal/tui/components/core/layout"
15 "github.com/charmbracelet/crush/internal/tui/components/core/list"
16 "github.com/charmbracelet/crush/internal/tui/components/dialogs/models"
17 "github.com/charmbracelet/crush/internal/tui/components/logo"
18 "github.com/charmbracelet/crush/internal/tui/styles"
19 "github.com/charmbracelet/crush/internal/tui/util"
20 "github.com/charmbracelet/crush/internal/version"
21 "github.com/charmbracelet/lipgloss/v2"
22)
23
24type Splash interface {
25 util.Model
26 layout.Sizeable
27 layout.Help
28 Cursor() *tea.Cursor
29 // SetOnboarding controls whether the splash shows model selection UI
30 SetOnboarding(bool)
31 // SetProjectInit controls whether the splash shows project initialization prompt
32 SetProjectInit(bool)
33}
34
// Padding, in cells, applied around the whole splash screen.
const (
	// SplashScreenPaddingX is the left and right padding.
	SplashScreenPaddingX = 2
	// SplashScreenPaddingY is the top and bottom padding.
	SplashScreenPaddingY = 1
)
39
// OnboardingCompleteMsg is emitted by the splash screen once the user has
// finished with it: after picking a model, or after answering the
// project-initialization prompt.
type OnboardingCompleteMsg struct{}
42
43type splashCmp struct {
44 width, height int
45 keyMap KeyMap
46 logoRendered string
47
48 // State
49 isOnboarding bool
50 needsProjectInit bool
51 selectedNo bool
52
53 modelList *models.ModelListComponent
54 cursorRow, cursorCol int
55}
56
57func New() Splash {
58 keyMap := DefaultKeyMap()
59 listKeyMap := list.DefaultKeyMap()
60 listKeyMap.Down.SetEnabled(false)
61 listKeyMap.Up.SetEnabled(false)
62 listKeyMap.HalfPageDown.SetEnabled(false)
63 listKeyMap.HalfPageUp.SetEnabled(false)
64 listKeyMap.Home.SetEnabled(false)
65 listKeyMap.End.SetEnabled(false)
66 listKeyMap.DownOneItem = keyMap.Next
67 listKeyMap.UpOneItem = keyMap.Previous
68
69 t := styles.CurrentTheme()
70 inputStyle := t.S().Base.Padding(0, 1, 0, 1)
71 modelList := models.NewModelListComponent(listKeyMap, inputStyle, "Find your fave")
72 return &splashCmp{
73 width: 0,
74 height: 0,
75 keyMap: keyMap,
76 logoRendered: "",
77 modelList: modelList,
78 selectedNo: false,
79 }
80}
81
// SetOnboarding shows or hides the model picker. When enabling it, the
// picker is populated with the subset of known providers that are offered
// during first-run setup.
func (s *splashCmp) SetOnboarding(onboarding bool) {
	s.isOnboarding = onboarding
	if !onboarding {
		return
	}

	all, err := config.Providers()
	if err != nil {
		// NOTE(review): the error is dropped here and the picker keeps its
		// previous contents; this method has no way to report it.
		return
	}

	offered := []string{"anthropic", "openai", "gemini", "xai", "openrouter"}
	shown := make([]provider.Provider, 0, len(offered))
	for _, p := range all {
		if !slices.Contains(offered, string(p.ID)) {
			continue
		}
		shown = append(shown, p)
	}
	s.modelList.SetProviders(shown)
}
105
106func (s *splashCmp) SetProjectInit(needsInit bool) {
107 s.needsProjectInit = needsInit
108}
109
110// GetSize implements SplashPage.
111func (s *splashCmp) GetSize() (int, int) {
112 return s.width, s.height
113}
114
115// Init implements SplashPage.
116func (s *splashCmp) Init() tea.Cmd {
117 return s.modelList.Init()
118}
119
120// SetSize implements SplashPage.
121func (s *splashCmp) SetSize(width int, height int) tea.Cmd {
122 s.width = width
123 s.height = height
124 s.logoRendered = s.logoBlock()
125 listHeigh := min(40, height-(SplashScreenPaddingY*2)-lipgloss.Height(s.logoRendered)-2) // -1 for the title
126 listWidth := min(60, width-(SplashScreenPaddingX*2))
127
128 // Calculate the cursor position based on the height and logo size
129 s.cursorRow = height - listHeigh
130 return s.modelList.SetSize(listWidth, listHeigh)
131}
132
133// Update implements SplashPage.
134func (s *splashCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
135 switch msg := msg.(type) {
136 case tea.WindowSizeMsg:
137 return s, s.SetSize(msg.Width, msg.Height)
138 case tea.KeyPressMsg:
139 switch {
140 case key.Matches(msg, s.keyMap.Select):
141 if s.isOnboarding {
142 modelInx := s.modelList.SelectedIndex()
143 items := s.modelList.Items()
144 selectedItem := items[modelInx].(completions.CompletionItem).Value().(models.ModelOption)
145 if s.isProviderConfigured(string(selectedItem.Provider.ID)) {
146 cmd := s.setPreferredModel(selectedItem)
147 s.isOnboarding = false
148 return s, tea.Batch(cmd, util.CmdHandler(OnboardingCompleteMsg{}))
149 }
150 } else if s.needsProjectInit {
151 return s, s.initializeProject()
152 }
153 case key.Matches(msg, s.keyMap.Tab, s.keyMap.LeftRight):
154 if s.needsProjectInit {
155 s.selectedNo = !s.selectedNo
156 return s, nil
157 }
158 case key.Matches(msg, s.keyMap.Yes):
159 if s.needsProjectInit {
160 return s, s.initializeProject()
161 }
162 case key.Matches(msg, s.keyMap.No):
163 if s.needsProjectInit {
164 s.needsProjectInit = false
165 return s, util.CmdHandler(OnboardingCompleteMsg{})
166 }
167 default:
168 if s.isOnboarding {
169 u, cmd := s.modelList.Update(msg)
170 s.modelList = u
171 return s, cmd
172 }
173 }
174 }
175 return s, nil
176}
177
178func (s *splashCmp) initializeProject() tea.Cmd {
179 s.needsProjectInit = false
180 prompt := `Please analyze this codebase and create a CRUSH.md file containing:
1811. Build/lint/test commands - especially for running a single test
1822. Code style guidelines including imports, formatting, types, naming conventions, error handling, etc.
183
184The file you create will be given to agentic coding agents (such as yourself) that operate in this repository. Make it about 20 lines long.
185If there's already a CRUSH.md, improve it.
186If there are Cursor rules (in .cursor/rules/ or .cursorrules) or Copilot rules (in .github/copilot-instructions.md), make sure to include them.
187Add the .crush directory to the .gitignore file if it's not already there.`
188
189 if err := config.MarkProjectInitialized(); err != nil {
190 return util.ReportError(err)
191 }
192 var cmds []tea.Cmd
193
194 cmds = append(cmds, util.CmdHandler(OnboardingCompleteMsg{}))
195 if !s.selectedNo {
196 cmds = append(cmds,
197 util.CmdHandler(chat.SessionClearedMsg{}),
198 util.CmdHandler(chat.SendMsg{
199 Text: prompt,
200 }),
201 )
202 }
203 return tea.Sequence(cmds...)
204}
205
// setPreferredModel stores selectedItem as the preferred large model and then
// chooses a small model from the same provider.
//
// The small model is the provider's DefaultSmallModelID when the provider is
// in the known-provider list and that model resolves in the config. Otherwise
// the selected large model is reused. Any failure is returned as an
// error-reporting command; nil means success.
func (s *splashCmp) setPreferredModel(selectedItem models.ModelOption) tea.Cmd {
	cfg := config.Get()
	providerID := string(selectedItem.Provider.ID)

	large := cfg.GetModel(providerID, selectedItem.Model.ID)
	if large == nil {
		return util.ReportError(fmt.Errorf("model %s not found for provider %s", selectedItem.Model.ID, selectedItem.Provider.ID))
	}
	largeChoice := config.SelectedModel{
		Model:           selectedItem.Model.ID,
		Provider:        providerID,
		ReasoningEffort: large.DefaultReasoningEffort,
		MaxTokens:       large.DefaultMaxTokens,
	}
	if err := cfg.UpdatePreferredModel(config.SelectedModelTypeLarge, largeChoice); err != nil {
		return util.ReportError(err)
	}

	known, err := s.getProvider(selectedItem.Provider.ID)
	if err != nil {
		return util.ReportError(err)
	}

	// Fall back to the large model for unknown/local providers, or when the
	// provider's default small model cannot be resolved.
	smallChoice := largeChoice
	if known != nil {
		if small := cfg.GetModel(providerID, known.DefaultSmallModelID); small != nil {
			smallChoice = config.SelectedModel{
				Model:           known.DefaultSmallModelID,
				Provider:        providerID,
				ReasoningEffort: small.DefaultReasoningEffort,
				MaxTokens:       small.DefaultMaxTokens,
			}
		}
	}

	err = cfg.UpdatePreferredModel(config.SelectedModelTypeSmall, smallChoice)
	if err != nil {
		return util.ReportError(err)
	}
	return nil
}
260
261func (s *splashCmp) getProvider(providerID provider.InferenceProvider) (*provider.Provider, error) {
262 providers, err := config.Providers()
263 if err != nil {
264 return nil, err
265 }
266 for _, p := range providers {
267 if p.ID == providerID {
268 return &p, nil
269 }
270 }
271 return nil, nil
272}
273
274func (s *splashCmp) isProviderConfigured(providerID string) bool {
275 cfg := config.Get()
276 if _, ok := cfg.Providers[providerID]; ok {
277 return true
278 }
279 return false
280}
281
282func (s *splashCmp) View() string {
283 t := styles.CurrentTheme()
284
285 var content string
286 if s.isOnboarding {
287 remainingHeight := s.height - lipgloss.Height(s.logoRendered) - (SplashScreenPaddingY * 2)
288 modelListView := s.modelList.View()
289 modelSelector := t.S().Base.AlignVertical(lipgloss.Bottom).Height(remainingHeight).Render(
290 lipgloss.JoinVertical(
291 lipgloss.Left,
292 t.S().Base.PaddingLeft(1).Foreground(t.Primary).Render("Choose a Model"),
293 "",
294 modelListView,
295 ),
296 )
297 content = lipgloss.JoinVertical(
298 lipgloss.Left,
299 s.logoRendered,
300 modelSelector,
301 )
302 } else if s.needsProjectInit {
303 t := styles.CurrentTheme()
304
305 titleStyle := t.S().Base.Foreground(t.FgBase)
306 bodyStyle := t.S().Base.Foreground(t.FgMuted)
307 shortcutStyle := t.S().Base.Foreground(t.Success)
308
309 initText := lipgloss.JoinVertical(
310 lipgloss.Left,
311 titleStyle.Render("Would you like to initialize this project?"),
312 "",
313 bodyStyle.Render("When I initialize your codebase I examine the project and put the"),
314 bodyStyle.Render("result into a CRUSH.md file which serves as general context."),
315 "",
316 bodyStyle.Render("You can also initialize anytime via ")+shortcutStyle.Render("ctrl+p")+bodyStyle.Render("."),
317 "",
318 bodyStyle.Render("Would you like to initialize now?"),
319 )
320
321 yesButton := core.SelectableButton(core.ButtonOpts{
322 Text: "Yep!",
323 UnderlineIndex: 0,
324 Selected: !s.selectedNo,
325 })
326
327 noButton := core.SelectableButton(core.ButtonOpts{
328 Text: "Nope",
329 UnderlineIndex: 0,
330 Selected: s.selectedNo,
331 })
332
333 buttons := lipgloss.JoinHorizontal(lipgloss.Left, yesButton, " ", noButton)
334
335 remainingHeight := s.height - lipgloss.Height(s.logoRendered) - (SplashScreenPaddingY * 2)
336
337 initContent := t.S().Base.AlignVertical(lipgloss.Bottom).Height(remainingHeight).Render(
338 lipgloss.JoinVertical(
339 lipgloss.Left,
340 initText,
341 "",
342 buttons,
343 ),
344 )
345
346 content = lipgloss.JoinVertical(
347 lipgloss.Left,
348 s.logoRendered,
349 initContent,
350 )
351 } else {
352 content = s.logoRendered
353 }
354
355 return t.S().Base.
356 Width(s.width).
357 Height(s.height).
358 PaddingTop(SplashScreenPaddingY).
359 PaddingLeft(SplashScreenPaddingX).
360 PaddingRight(SplashScreenPaddingX).
361 PaddingBottom(SplashScreenPaddingY).
362 Render(content)
363}
364
365func (s *splashCmp) Cursor() *tea.Cursor {
366 if s.isOnboarding {
367 cursor := s.modelList.Cursor()
368 if cursor != nil {
369 return s.moveCursor(cursor)
370 }
371 } else {
372 return nil
373 }
374 return nil
375}
376
377func (s *splashCmp) logoBlock() string {
378 t := styles.CurrentTheme()
379 const padding = 2
380 return logo.Render(version.Version, false, logo.Opts{
381 FieldColor: t.Primary,
382 TitleColorA: t.Secondary,
383 TitleColorB: t.Primary,
384 CharmColor: t.Secondary,
385 VersionColor: t.Primary,
386 Width: s.width - (SplashScreenPaddingX * 2),
387 })
388}
389
390func (m *splashCmp) moveCursor(cursor *tea.Cursor) *tea.Cursor {
391 if cursor == nil {
392 return nil
393 }
394 offset := m.cursorRow
395 cursor.Y += offset
396 cursor.X = cursor.X + 3 // 3 for padding
397 return cursor
398}
399
400// Bindings implements SplashPage.
401func (s *splashCmp) Bindings() []key.Binding {
402 if s.isOnboarding {
403 return []key.Binding{
404 s.keyMap.Select,
405 s.keyMap.Next,
406 s.keyMap.Previous,
407 }
408 } else if s.needsProjectInit {
409 return []key.Binding{
410 s.keyMap.Select,
411 s.keyMap.Yes,
412 s.keyMap.No,
413 s.keyMap.Tab,
414 s.keyMap.LeftRight,
415 }
416 }
417 return []key.Binding{}
418}