diff --git a/internal/agent/agent.go b/internal/agent/agent.go
index 65655a21385baa5d98926ed62ba89ba0aac2c539..830ed73d2a8134c12b3f823c116052c683847c74 100644
--- a/internal/agent/agent.go
+++ b/internal/agent/agent.go
@@ -16,6 +16,7 @@ import (
 	"fmt"
 	"log/slog"
 	"os"
+	"regexp"
 	"strconv"
 	"strings"
 	"sync"
@@ -39,12 +40,18 @@ import (
 	"github.com/charmbracelet/crush/internal/stringext"
 )
 
+const defaultSessionName = "Untitled Session"
+
 //go:embed templates/title.md
 var titlePrompt []byte
 
 //go:embed templates/summary.md
 var summaryPrompt []byte
 
+// thinkTagRegex matches a <think>...</think> block so it can be removed from
+// generated titles. Titles have newlines flattened before it is applied.
+var thinkTagRegex = regexp.MustCompile(`<think>.*?</think>`)
+
 type SessionAgentCall struct {
 	SessionID string
 	Prompt string
@@ -725,49 +732,79 @@ func (a *sessionAgent) getSessionMessages(ctx context.Context, session session.S
 	return msgs, nil
 }
 
-func (a *sessionAgent) generateTitle(ctx context.Context, sessionID string, prompt string) {
-	if prompt == "" {
+// generateTitle generates a session title based on the initial prompt and
+// saves it, together with the token usage and cost of the generation call.
+// It tries the small model first and falls back to the large model; if both
+// fail, or the generated title is empty, defaultSessionName is used.
+func (a *sessionAgent) generateTitle(ctx context.Context, sessionID string, userPrompt string) {
+	if userPrompt == "" {
 		return
 	}
 
-	var maxOutput int64 = 40
+	var maxOutputTokens int64 = 40
 	if a.smallModel.CatwalkCfg.CanReason {
-		maxOutput = a.smallModel.CatwalkCfg.DefaultMaxTokens
+		maxOutputTokens = a.smallModel.CatwalkCfg.DefaultMaxTokens
 	}
 
-	agent := fantasy.NewAgent(a.smallModel.Model,
-		fantasy.WithSystemPrompt(string(titlePrompt)+"\n /no_think"),
-		fantasy.WithMaxOutputTokens(maxOutput),
-	)
+	newAgent := func(m fantasy.LanguageModel, p []byte, tok int64) fantasy.Agent {
+		return fantasy.NewAgent(m,
+			fantasy.WithSystemPrompt(string(p)+"\n /no_think"),
+			fantasy.WithMaxOutputTokens(tok),
+		)
+	}
 
-	resp, err := agent.Stream(ctx, fantasy.AgentStreamCall{
-		Prompt: fmt.Sprintf("Generate a concise title for the following content:\n\n%s\n \n\n", prompt),
-		PrepareStep: func(callContext context.Context, options fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
-			prepared.Messages = options.Messages
+	streamCall := fantasy.AgentStreamCall{
+		Prompt: fmt.Sprintf("Generate a concise title for the following content:\n\n%s\n \n\n", userPrompt),
+		PrepareStep: func(callCtx context.Context, opts fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
+			prepared.Messages = opts.Messages
 			if a.systemPromptPrefix != "" {
-				prepared.Messages = append([]fantasy.Message{fantasy.NewSystemMessage(a.systemPromptPrefix)}, prepared.Messages...)
+				prepared.Messages = append([]fantasy.Message{
+					fantasy.NewSystemMessage(a.systemPromptPrefix),
+				}, prepared.Messages...)
 			}
-			return callContext, prepared, nil
+			return callCtx, prepared, nil
 		},
-	})
-	if err != nil {
-		slog.Error("error generating title", "err", err)
-		return
 	}
 
-	title := resp.Response.Content.Text()
+	// Use the small model to generate the title.
+	model := &a.smallModel
+	agent := newAgent(model.Model, titlePrompt, maxOutputTokens)
+	resp, err := agent.Stream(ctx, streamCall)
+	if err == nil {
+		// We successfully generated a title with the small model.
+		slog.Info("generated title with small model")
+	} else {
+		// It didn't work. Let's try with the big model.
+		slog.Error("error generating title with small model; trying big model", "err", err)
+		model = &a.largeModel
+		agent = newAgent(model.Model, titlePrompt, maxOutputTokens)
+		resp, err = agent.Stream(ctx, streamCall)
+		if err == nil {
+			slog.Info("generated title with large model")
+		} else {
+			// Welp, the large model didn't work either. There is no
+			// response to read, so save the fallback name and stop.
+			// NOTE(review): assumes zero usage values leave the stored
+			// totals untouched; confirm against UpdateTitleAndUsage.
+			slog.Error("error generating title with large model", "err", err)
+			saveErr := a.sessions.UpdateTitleAndUsage(ctx, sessionID, defaultSessionName, 0, 0, 0)
+			if saveErr != nil {
+				slog.Error("failed to save session title and usage", "error", saveErr)
+			}
+			return
+		}
+	}
 
-	title = strings.ReplaceAll(title, "\n", " ")
+	title := strings.ReplaceAll(resp.Response.Content.Text(), "\n", " ")
+	slog.Info("generated title", "title", title)
 
 	// Remove thinking tags if present.
-	if idx := strings.Index(title, ""); idx > 0 {
-		title = title[idx+len(""):]
-	}
+	title = thinkTagRegex.ReplaceAllString(title, "")
 
 	title = strings.TrimSpace(title)
 	if title == "" {
-		slog.Warn("failed to generate title", "warn", "empty title")
-		return
+		slog.Warn("empty title; using fallback")
+		title = defaultSessionName
 	}
 
 	// Calculate usage and cost.
@@ -783,7 +820,11 @@ func (a *sessionAgent) generateTitle(ctx context.Context, sessionID string, prom
 		}
 	}
 
-	modelConfig := a.smallModel.CatwalkCfg
+	if model == nil {
+		slog.Error("no model available for cost calculation")
+		return
+	}
+	modelConfig := model.CatwalkCfg
 	cost := modelConfig.CostPer1MInCached/1e6*float64(resp.TotalUsage.CacheCreationTokens) +
 		modelConfig.CostPer1MOutCached/1e6*float64(resp.TotalUsage.CacheReadTokens) +
 		modelConfig.CostPer1MIn/1e6*float64(resp.TotalUsage.InputTokens) +
@@ -805,7 +846,7 @@ func (a *sessionAgent) generateTitle(ctx context.Context, sessionID string, prom
 	// concurrent session updates.
 	saveErr := a.sessions.UpdateTitleAndUsage(ctx, sessionID, title, promptTokens, completionTokens, cost)
 	if saveErr != nil {
-		slog.Error("failed to save session title & usage", "error", saveErr)
+		slog.Error("failed to save session title and usage", "error", saveErr)
 		return
 	}
 }
@@ -947,6 +988,7 @@ func (a *sessionAgent) promptPrefix() string {
 	return a.systemPromptPrefix
 }
 
+// XXX: this should be generalized to cover other subscription plans, like Copilot.
 func (a *sessionAgent) isClaudeCode() bool {
 	cfg := config.Get()
 	pc, ok := cfg.Providers.Get(a.largeModel.ModelCfg.Provider)
diff --git a/internal/tui/page/chat/chat.go b/internal/tui/page/chat/chat.go
index 8c54b028f90326ac8cee1cacb0df2377528e4a2b..d86e60c8cdcb0f6d87b7c97a6e40e83bddffeace 100644
--- a/internal/tui/page/chat/chat.go
+++ b/internal/tui/page/chat/chat.go
@@ -961,7 +961,10 @@ func (p *chatPage) sendMessage(text string, attachments []message.Attachment) te
 	session := p.session
 	var cmds []tea.Cmd
 	if p.session.ID == "" {
-		newSession, err := p.app.Sessions.Create(context.Background(), "New Session")
+		// XXX: The second argument here is the session name, which we leave
+		// blank as it will be auto-generated. Ideally, we remove the need for
+		// that argument entirely.
+		newSession, err := p.app.Sessions.Create(context.Background(), "")
 		if err != nil {
 			return util.ReportError(err)
 		}