fix(sessions): generate title with large model if small model fails

Christian Rocha created

Change summary

internal/agent/agent.go        | 76 ++++++++++++++++++++++++-----------
internal/tui/page/chat/chat.go |  5 +
2 files changed, 56 insertions(+), 25 deletions(-)

Detailed changes

internal/agent/agent.go 🔗

@@ -40,6 +40,8 @@ import (
 	"github.com/charmbracelet/crush/internal/stringext"
 )
 
+const defaultSessionName = "Untitled Session"
+
 //go:embed templates/title.md
 var titlePrompt []byte
 
@@ -729,47 +731,68 @@ func (a *sessionAgent) getSessionMessages(ctx context.Context, session session.S
 	return msgs, nil
 }
 
-func (a *sessionAgent) generateTitle(ctx context.Context, sessionID string, prompt string) {
-	if prompt == "" {
+// generateTitle generates a session titled based on the initial prompt.
+func (a *sessionAgent) generateTitle(ctx context.Context, sessionID string, userPrompt string) {
+	if userPrompt == "" {
 		return
 	}
 
-	var maxOutput int64 = 40
+	var maxOutputTokens int64 = 40
 	if a.smallModel.CatwalkCfg.CanReason {
-		maxOutput = a.smallModel.CatwalkCfg.DefaultMaxTokens
+		maxOutputTokens = a.smallModel.CatwalkCfg.DefaultMaxTokens
 	}
 
-	agent := fantasy.NewAgent(a.smallModel.Model,
-		fantasy.WithSystemPrompt(string(titlePrompt)+"\n /no_think"),
-		fantasy.WithMaxOutputTokens(maxOutput),
-	)
+	newAgent := func(m fantasy.LanguageModel, p []byte, tok int64) fantasy.Agent {
+		return fantasy.NewAgent(m,
+			fantasy.WithSystemPrompt(string(p)+"\n /no_think"),
+			fantasy.WithMaxOutputTokens(tok),
+		)
+	}
 
-	resp, err := agent.Stream(ctx, fantasy.AgentStreamCall{
-		Prompt: fmt.Sprintf("Generate a concise title for the following content:\n\n%s\n <think>\n\n</think>", prompt),
-		PrepareStep: func(callContext context.Context, options fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
-			prepared.Messages = options.Messages
+	streamCall := fantasy.AgentStreamCall{
+		Prompt: fmt.Sprintf("Generate a concise title for the following content:\n\n%s\n <think>\n\n</think>", userPrompt),
+		PrepareStep: func(callCtx context.Context, opts fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
+			prepared.Messages = opts.Messages
 			if a.systemPromptPrefix != "" {
-				prepared.Messages = append([]fantasy.Message{fantasy.NewSystemMessage(a.systemPromptPrefix)}, prepared.Messages...)
+				prepared.Messages = append([]fantasy.Message{
+					fantasy.NewSystemMessage(a.systemPromptPrefix),
+				}, prepared.Messages...)
 			}
-			return callContext, prepared, nil
+			return callCtx, prepared, nil
 		},
-	})
-	if err != nil {
-		slog.Error("error generating title", "err", err)
-		return
 	}
 
-	title := resp.Response.Content.Text()
+	// Use the small model to generate the title.
+	model := &a.smallModel
+	agent := newAgent(model.Model, titlePrompt, maxOutputTokens)
+	resp, err := agent.Stream(ctx, streamCall)
+	if err == nil {
+		// We successfully generated a title with the small model.
+		slog.Info("generated title with small model")
+	} else {
+		// It didn't work. Let's try with the big model.
+		slog.Error("error generating title with small model; trying big model", "err", err)
+		model = &a.largeModel
+		agent = newAgent(model.Model, titlePrompt, maxOutputTokens)
+		resp, err = agent.Stream(ctx, streamCall)
+		if err == nil {
+			slog.Info("generated title with large model")
+		} else {
+			// Welp, the large model didn't work either.
+			slog.Error("error generating title with large model", "err", err)
+		}
+	}
 
-	title = strings.ReplaceAll(title, "\n", " ")
+	title := strings.ReplaceAll(resp.Response.Content.Text(), "\n", " ")
+	slog.Info("generated title", "title", title)
 
 	// Remove thinking tags if present.
 	title = thinkTagRegex.ReplaceAllString(title, "")
 
 	title = strings.TrimSpace(title)
 	if title == "" {
-		slog.Warn("failed to generate title", "warn", "empty title")
-		return
+		slog.Warn("empty title; using fallback")
+		title = defaultSessionName
 	}
 
 	// Calculate usage and cost.
@@ -785,7 +808,11 @@ func (a *sessionAgent) generateTitle(ctx context.Context, sessionID string, prom
 		}
 	}
 
-	modelConfig := a.smallModel.CatwalkCfg
+	if model == nil {
+		slog.Error("no model available for cost calculation")
+		return
+	}
+	modelConfig := model.CatwalkCfg
 	cost := modelConfig.CostPer1MInCached/1e6*float64(resp.TotalUsage.CacheCreationTokens) +
 		modelConfig.CostPer1MOutCached/1e6*float64(resp.TotalUsage.CacheReadTokens) +
 		modelConfig.CostPer1MIn/1e6*float64(resp.TotalUsage.InputTokens) +
@@ -807,7 +834,7 @@ func (a *sessionAgent) generateTitle(ctx context.Context, sessionID string, prom
 	// concurrent session updates.
 	saveErr := a.sessions.UpdateTitleAndUsage(ctx, sessionID, title, promptTokens, completionTokens, cost)
 	if saveErr != nil {
-		slog.Error("failed to save session title & usage", "error", saveErr)
+		slog.Error("failed to save session title and usage", "error", saveErr)
 		return
 	}
 }
@@ -949,6 +976,7 @@ func (a *sessionAgent) promptPrefix() string {
 	return a.systemPromptPrefix
 }
 
+// XXX: this should be generalized to cover other subscription plans, like Copilot.
 func (a *sessionAgent) isClaudeCode() bool {
 	cfg := config.Get()
 	pc, ok := cfg.Providers.Get(a.largeModel.ModelCfg.Provider)

internal/tui/page/chat/chat.go 🔗

@@ -961,7 +961,10 @@ func (p *chatPage) sendMessage(text string, attachments []message.Attachment) te
 	session := p.session
 	var cmds []tea.Cmd
 	if p.session.ID == "" {
-		newSession, err := p.app.Sessions.Create(context.Background(), "New Session")
+		// XXX: The second argument here is the session name, which we leave
+		// blank as it will be auto-generated. Ideally, we remove the need for
+		// that argument entirely.
+		newSession, err := p.app.Sessions.Create(context.Background(), "")
 		if err != nil {
 			return util.ReportError(err)
 		}