@@ -40,6 +40,8 @@ import (
"github.com/charmbracelet/crush/internal/stringext"
)
+const defaultSessionName = "Untitled Session"
+
//go:embed templates/title.md
var titlePrompt []byte
@@ -729,47 +731,68 @@ func (a *sessionAgent) getSessionMessages(ctx context.Context, session session.S
return msgs, nil
}
-func (a *sessionAgent) generateTitle(ctx context.Context, sessionID string, prompt string) {
- if prompt == "" {
+// generateTitle generates a session titled based on the initial prompt.
+func (a *sessionAgent) generateTitle(ctx context.Context, sessionID string, userPrompt string) {
+ if userPrompt == "" {
return
}
- var maxOutput int64 = 40
+ var maxOutputTokens int64 = 40
if a.smallModel.CatwalkCfg.CanReason {
- maxOutput = a.smallModel.CatwalkCfg.DefaultMaxTokens
+ maxOutputTokens = a.smallModel.CatwalkCfg.DefaultMaxTokens
}
- agent := fantasy.NewAgent(a.smallModel.Model,
- fantasy.WithSystemPrompt(string(titlePrompt)+"\n /no_think"),
- fantasy.WithMaxOutputTokens(maxOutput),
- )
+ newAgent := func(m fantasy.LanguageModel, p []byte, tok int64) fantasy.Agent {
+ return fantasy.NewAgent(m,
+ fantasy.WithSystemPrompt(string(p)+"\n /no_think"),
+ fantasy.WithMaxOutputTokens(tok),
+ )
+ }
- resp, err := agent.Stream(ctx, fantasy.AgentStreamCall{
- Prompt: fmt.Sprintf("Generate a concise title for the following content:\n\n%s\n <think>\n\n</think>", prompt),
- PrepareStep: func(callContext context.Context, options fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
- prepared.Messages = options.Messages
+ streamCall := fantasy.AgentStreamCall{
+ Prompt: fmt.Sprintf("Generate a concise title for the following content:\n\n%s\n <think>\n\n</think>", userPrompt),
+ PrepareStep: func(callCtx context.Context, opts fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
+ prepared.Messages = opts.Messages
if a.systemPromptPrefix != "" {
- prepared.Messages = append([]fantasy.Message{fantasy.NewSystemMessage(a.systemPromptPrefix)}, prepared.Messages...)
+ prepared.Messages = append([]fantasy.Message{
+ fantasy.NewSystemMessage(a.systemPromptPrefix),
+ }, prepared.Messages...)
}
- return callContext, prepared, nil
+ return callCtx, prepared, nil
},
- })
- if err != nil {
- slog.Error("error generating title", "err", err)
- return
}
- title := resp.Response.Content.Text()
+ // Use the small model to generate the title.
+ model := &a.smallModel
+ agent := newAgent(model.Model, titlePrompt, maxOutputTokens)
+ resp, err := agent.Stream(ctx, streamCall)
+ if err == nil {
+ // We successfully generated a title with the small model.
+ slog.Info("generated title with small model")
+ } else {
+ // It didn't work. Let's try with the big model.
+ slog.Error("error generating title with small model; trying big model", "err", err)
+ model = &a.largeModel
+ agent = newAgent(model.Model, titlePrompt, maxOutputTokens)
+ resp, err = agent.Stream(ctx, streamCall)
+ if err == nil {
+ slog.Info("generated title with large model")
+ } else {
+ // Welp, the large model didn't work either.
+ slog.Error("error generating title with large model", "err", err)
+ }
+ }
- title = strings.ReplaceAll(title, "\n", " ")
+ title := strings.ReplaceAll(resp.Response.Content.Text(), "\n", " ")
+ slog.Info("generated title", "title", title)
// Remove thinking tags if present.
title = thinkTagRegex.ReplaceAllString(title, "")
title = strings.TrimSpace(title)
if title == "" {
- slog.Warn("failed to generate title", "warn", "empty title")
- return
+ slog.Warn("empty title; using fallback")
+ title = defaultSessionName
}
// Calculate usage and cost.
@@ -785,7 +808,11 @@ func (a *sessionAgent) generateTitle(ctx context.Context, sessionID string, prom
}
}
- modelConfig := a.smallModel.CatwalkCfg
+ if model == nil {
+ slog.Error("no model available for cost calculation")
+ return
+ }
+ modelConfig := model.CatwalkCfg
cost := modelConfig.CostPer1MInCached/1e6*float64(resp.TotalUsage.CacheCreationTokens) +
modelConfig.CostPer1MOutCached/1e6*float64(resp.TotalUsage.CacheReadTokens) +
modelConfig.CostPer1MIn/1e6*float64(resp.TotalUsage.InputTokens) +
@@ -807,7 +834,7 @@ func (a *sessionAgent) generateTitle(ctx context.Context, sessionID string, prom
// concurrent session updates.
saveErr := a.sessions.UpdateTitleAndUsage(ctx, sessionID, title, promptTokens, completionTokens, cost)
if saveErr != nil {
- slog.Error("failed to save session title & usage", "error", saveErr)
+ slog.Error("failed to save session title and usage", "error", saveErr)
return
}
}
@@ -949,6 +976,7 @@ func (a *sessionAgent) promptPrefix() string {
return a.systemPromptPrefix
}
+// XXX: this should be generalized to cover other subscription plans, like Copilot.
func (a *sessionAgent) isClaudeCode() bool {
cfg := config.Get()
pc, ok := cfg.Providers.Get(a.largeModel.ModelCfg.Provider)