agent.go

   1package agent
   2
   3import (
   4	"context"
   5	"errors"
   6	"fmt"
   7	"log/slog"
   8	"slices"
   9	"strings"
  10	"time"
  11
  12	"github.com/charmbracelet/catwalk/pkg/catwalk"
  13	"github.com/charmbracelet/crush/internal/config"
  14	"github.com/charmbracelet/crush/internal/csync"
  15	"github.com/charmbracelet/crush/internal/history"
  16	"github.com/charmbracelet/crush/internal/llm/prompt"
  17	"github.com/charmbracelet/crush/internal/llm/provider"
  18	"github.com/charmbracelet/crush/internal/llm/tools"
  19	"github.com/charmbracelet/crush/internal/log"
  20	"github.com/charmbracelet/crush/internal/lsp"
  21	"github.com/charmbracelet/crush/internal/message"
  22	"github.com/charmbracelet/crush/internal/permission"
  23	"github.com/charmbracelet/crush/internal/pubsub"
  24	"github.com/charmbracelet/crush/internal/session"
  25	"github.com/charmbracelet/crush/internal/shell"
  26)
  27
// Common errors returned by the agent.
var (
	// ErrRequestCancelled is reported (wrapped in an error AgentEvent) when a
	// generation is cancelled before it finished.
	ErrRequestCancelled = errors.New("request canceled by user")
	// ErrSessionBusy is returned by Summarize when the session already has an
	// active request.
	ErrSessionBusy      = errors.New("session is currently processing another request")
)
  33
// AgentEventType identifies what an AgentEvent reports.
type AgentEventType string

const (
	// AgentEventTypeError reports a failure; Error holds the cause.
	AgentEventTypeError     AgentEventType = "error"
	// AgentEventTypeResponse carries the final assistant Message of a Run.
	AgentEventTypeResponse  AgentEventType = "response"
	// AgentEventTypeSummarize reports summarization progress or completion.
	AgentEventTypeSummarize AgentEventType = "summarize"
)

// AgentEvent is published on the agent's broker (and, for Run, also sent on
// the returned channel) to report responses, errors and summarization
// progress.
type AgentEvent struct {
	Type    AgentEventType
	Message message.Message
	Error   error

	// SessionID and Progress are set by summarization events. Done marks the
	// final event of an operation; it is set on summarization completion and
	// errors, and also on Run's response event.
	SessionID string
	Progress  string
	Done      bool
}
  52
// Service is the public API of an agent. It runs prompts against the
// configured model, manages per-session cancellation and prompt queues, and
// publishes AgentEvents to subscribers.
type Service interface {
	pubsub.Suscriber[AgentEvent]
	// Model returns the model currently configured for the agent.
	Model() catwalk.Model
	// Run processes content for sessionID in the background. In the agent
	// implementation, a busy session gets the prompt queued instead and Run
	// returns a nil channel and a nil error.
	Run(ctx context.Context, sessionID string, content string, attachments ...message.Attachment) (<-chan AgentEvent, error)
	// Cancel aborts in-flight work for sessionID and clears its prompt queue.
	Cancel(sessionID string)
	// CancelAll cancels every active request and waits briefly for them to stop.
	CancelAll()
	// IsSessionBusy reports whether sessionID has an active generation request.
	IsSessionBusy(sessionID string) bool
	// IsBusy reports whether any request is active.
	IsBusy() bool
	// Summarize starts an asynchronous summary of the session's conversation.
	Summarize(ctx context.Context, sessionID string) error
	// UpdateModel rebuilds provider clients after the model selection changed.
	UpdateModel() error
	// QueuedPrompts returns how many prompts are queued for sessionID.
	QueuedPrompts(sessionID string) int
	// ClearQueue drops the prompts queued for sessionID.
	ClearQueue(sessionID string)
}
  66
// agent is the Service implementation backed by LLM provider clients.
type agent struct {
	*pubsub.Broker[AgentEvent]
	agentCfg config.Agent
	sessions session.Service
	messages message.Service
	// NOTE(review): NewAgent does not set this field (it uses the
	// package-level mcpTools instead); possibly unused — confirm.
	mcpTools []McpTool

	// tools offered to the model, keyed by tool name.
	tools *csync.Map[string, tools.BaseTool]

	// provider serves the agent's main model; providerID is the ID of its
	// provider config and is compared in UpdateModel to detect changes.
	provider   provider.Provider
	providerID string

	// titleProvider generates session titles; summarizeProvider (with the
	// ID of its provider config) produces conversation summaries.
	titleProvider       provider.Provider
	summarizeProvider   provider.Provider
	summarizeProviderID string

	// activeRequests maps a session ID (or sessionID+"-summarize" for a
	// summarization) to the cancel function of its in-flight request.
	activeRequests *csync.Map[string, context.CancelFunc]

	// promptQueue holds, per session ID, prompts submitted while busy.
	promptQueue *csync.Map[string, []string]
}
  87
// agentPromptMap selects the system prompt by agent ID. IDs missing from the
// map fall back to prompt.PromptDefault where it is consulted.
var agentPromptMap = map[string]prompt.PromptID{
	"coder": prompt.PromptCoder,
	"task":  prompt.PromptTask,
}
  92
  93func NewAgent(
  94	ctx context.Context,
  95	agentCfg config.Agent,
  96	// These services are needed in the tools
  97	permissions permission.Service,
  98	sessions session.Service,
  99	messages message.Service,
 100	history history.Service,
 101	lspClients map[string]*lsp.Client,
 102) (Service, error) {
 103	cfg := config.Get()
 104
 105	var agentTool tools.BaseTool
 106	if agentCfg.ID == "coder" {
 107		taskAgentCfg := config.Get().Agents["task"]
 108		if taskAgentCfg.ID == "" {
 109			return nil, fmt.Errorf("task agent not found in config")
 110		}
 111		taskAgent, err := NewAgent(ctx, taskAgentCfg, permissions, sessions, messages, history, lspClients)
 112		if err != nil {
 113			return nil, fmt.Errorf("failed to create task agent: %w", err)
 114		}
 115
 116		agentTool = NewAgentTool(taskAgent, sessions, messages)
 117	}
 118
 119	providerCfg := config.Get().GetProviderForModel(agentCfg.Model)
 120	if providerCfg == nil {
 121		return nil, fmt.Errorf("provider for agent %s not found in config", agentCfg.Name)
 122	}
 123	model := config.Get().GetModelByType(agentCfg.Model)
 124
 125	if model == nil {
 126		return nil, fmt.Errorf("model not found for agent %s", agentCfg.Name)
 127	}
 128
 129	promptID := agentPromptMap[agentCfg.ID]
 130	if promptID == "" {
 131		promptID = prompt.PromptDefault
 132	}
 133	opts := []provider.ProviderClientOption{
 134		provider.WithModel(agentCfg.Model),
 135		provider.WithSystemMessage(prompt.GetPrompt(promptID, providerCfg.ID, config.Get().Options.ContextPaths...)),
 136	}
 137	agentProvider, err := provider.NewProvider(*providerCfg, opts...)
 138	if err != nil {
 139		return nil, err
 140	}
 141
 142	smallModelCfg := cfg.Models[config.SelectedModelTypeSmall]
 143	var smallModelProviderCfg *config.ProviderConfig
 144	if smallModelCfg.Provider == providerCfg.ID {
 145		smallModelProviderCfg = providerCfg
 146	} else {
 147		smallModelProviderCfg = cfg.GetProviderForModel(config.SelectedModelTypeSmall)
 148
 149		if smallModelProviderCfg.ID == "" {
 150			return nil, fmt.Errorf("provider %s not found in config", smallModelCfg.Provider)
 151		}
 152	}
 153	smallModel := cfg.GetModelByType(config.SelectedModelTypeSmall)
 154	if smallModel.ID == "" {
 155		return nil, fmt.Errorf("model %s not found in provider %s", smallModelCfg.Model, smallModelProviderCfg.ID)
 156	}
 157
 158	titleOpts := []provider.ProviderClientOption{
 159		provider.WithModel(config.SelectedModelTypeSmall),
 160		provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptTitle, smallModelProviderCfg.ID)),
 161	}
 162	titleProvider, err := provider.NewProvider(*smallModelProviderCfg, titleOpts...)
 163	if err != nil {
 164		return nil, err
 165	}
 166
 167	summarizeOpts := []provider.ProviderClientOption{
 168		provider.WithModel(config.SelectedModelTypeLarge),
 169		provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptSummarizer, providerCfg.ID)),
 170	}
 171	summarizeProvider, err := provider.NewProvider(*providerCfg, summarizeOpts...)
 172	if err != nil {
 173		return nil, err
 174	}
 175
 176	toolFn := func() *csync.Map[string, tools.BaseTool] {
 177		slog.Info("Initializing agent tools", "agent", agentCfg.ID)
 178		defer func() {
 179			slog.Info("Initialized agent tools", "agent", agentCfg.ID)
 180		}()
 181
 182		cwd := cfg.WorkingDir()
 183		toolMap := csync.NewMap[string, tools.BaseTool]()
 184
 185		// Base tools available to all agents
 186		baseTools := []tools.BaseTool{
 187			tools.NewBashTool(permissions, cwd),
 188			tools.NewDownloadTool(permissions, cwd),
 189			tools.NewEditTool(lspClients, permissions, history, cwd),
 190			tools.NewMultiEditTool(lspClients, permissions, history, cwd),
 191			tools.NewFetchTool(permissions, cwd),
 192			tools.NewGlobTool(cwd),
 193			tools.NewGrepTool(cwd),
 194			tools.NewLsTool(permissions, cwd),
 195			tools.NewSourcegraphTool(),
 196			tools.NewViewTool(lspClients, permissions, cwd),
 197			tools.NewWriteTool(lspClients, permissions, history, cwd),
 198		}
 199		for _, tool := range baseTools {
 200			toolMap.Set(tool.Name(), tool)
 201		}
 202
 203		mcpToolsOnce.Do(func() {
 204			mcpTools = doGetMCPTools(ctx, permissions, cfg)
 205		})
 206		for _, mcpTool := range mcpTools {
 207			toolMap.Set(mcpTool.Name(), mcpTool)
 208		}
 209
 210		if len(lspClients) > 0 {
 211			diagnosticsTool := tools.NewDiagnosticsTool(lspClients)
 212			toolMap.Set(diagnosticsTool.Name(), diagnosticsTool)
 213		}
 214
 215		if agentTool != nil {
 216			toolMap.Set(agentTool.Name(), agentTool)
 217		}
 218
 219		if agentCfg.AllowedTools != nil {
 220			// Filter tools based on allowed tools list
 221			for toolName := range toolMap.Seq2() {
 222				if !slices.Contains(agentCfg.AllowedTools, toolName) {
 223					toolMap.Del(toolName)
 224				}
 225			}
 226		}
 227		return toolMap
 228	}
 229
 230	return &agent{
 231		Broker:              pubsub.NewBroker[AgentEvent](),
 232		agentCfg:            agentCfg,
 233		provider:            agentProvider,
 234		providerID:          string(providerCfg.ID),
 235		messages:            messages,
 236		sessions:            sessions,
 237		titleProvider:       titleProvider,
 238		summarizeProvider:   summarizeProvider,
 239		summarizeProviderID: string(providerCfg.ID),
 240		activeRequests:      csync.NewMap[string, context.CancelFunc](),
 241		tools:               toolFn(),
 242		promptQueue:         csync.NewMap[string, []string](),
 243	}, nil
 244}
 245
 246func (a *agent) Model() catwalk.Model {
 247	return *config.Get().GetModelByType(a.agentCfg.Model)
 248}
 249
 250func (a *agent) Cancel(sessionID string) {
 251	// Cancel regular requests
 252	if cancel, ok := a.activeRequests.Take(sessionID); ok && cancel != nil {
 253		slog.Info("Request cancellation initiated", "session_id", sessionID)
 254		cancel()
 255	}
 256
 257	// Also check for summarize requests
 258	if cancel, ok := a.activeRequests.Take(sessionID + "-summarize"); ok && cancel != nil {
 259		slog.Info("Summarize cancellation initiated", "session_id", sessionID)
 260		cancel()
 261	}
 262
 263	if a.QueuedPrompts(sessionID) > 0 {
 264		slog.Info("Clearing queued prompts", "session_id", sessionID)
 265		a.promptQueue.Del(sessionID)
 266	}
 267}
 268
 269func (a *agent) IsBusy() bool {
 270	var busy bool
 271	for cancelFunc := range a.activeRequests.Seq() {
 272		if cancelFunc != nil {
 273			busy = true
 274			break
 275		}
 276	}
 277	return busy
 278}
 279
 280func (a *agent) IsSessionBusy(sessionID string) bool {
 281	_, busy := a.activeRequests.Get(sessionID)
 282	return busy
 283}
 284
 285func (a *agent) QueuedPrompts(sessionID string) int {
 286	l, ok := a.promptQueue.Get(sessionID)
 287	if !ok {
 288		return 0
 289	}
 290	return len(l)
 291}
 292
 293func (a *agent) generateTitle(ctx context.Context, sessionID string, content string) error {
 294	if content == "" {
 295		return nil
 296	}
 297	if a.titleProvider == nil {
 298		return nil
 299	}
 300	session, err := a.sessions.Get(ctx, sessionID)
 301	if err != nil {
 302		return err
 303	}
 304	parts := []message.ContentPart{message.TextContent{
 305		Text: fmt.Sprintf("Generate a concise title for the following content:\n\n%s", content),
 306	}}
 307
 308	// Use streaming approach like summarization
 309	response := a.titleProvider.StreamResponse(
 310		ctx,
 311		[]message.Message{
 312			{
 313				Role:  message.User,
 314				Parts: parts,
 315			},
 316		},
 317		nil,
 318	)
 319
 320	var finalResponse *provider.ProviderResponse
 321	for r := range response {
 322		if r.Error != nil {
 323			return r.Error
 324		}
 325		finalResponse = r.Response
 326	}
 327
 328	if finalResponse == nil {
 329		return fmt.Errorf("no response received from title provider")
 330	}
 331
 332	title := strings.TrimSpace(strings.ReplaceAll(finalResponse.Content, "\n", " "))
 333	if title == "" {
 334		return nil
 335	}
 336
 337	session.Title = title
 338	_, err = a.sessions.Save(ctx, session)
 339	return err
 340}
 341
 342func (a *agent) err(err error) AgentEvent {
 343	return AgentEvent{
 344		Type:  AgentEventTypeError,
 345		Error: err,
 346	}
 347}
 348
 349func (a *agent) Run(ctx context.Context, sessionID string, content string, attachments ...message.Attachment) (<-chan AgentEvent, error) {
 350	if !a.Model().SupportsImages && attachments != nil {
 351		attachments = nil
 352	}
 353	events := make(chan AgentEvent)
 354	if a.IsSessionBusy(sessionID) {
 355		existing, ok := a.promptQueue.Get(sessionID)
 356		if !ok {
 357			existing = []string{}
 358		}
 359		existing = append(existing, content)
 360		a.promptQueue.Set(sessionID, existing)
 361		return nil, nil
 362	}
 363
 364	genCtx, cancel := context.WithCancel(ctx)
 365
 366	a.activeRequests.Set(sessionID, cancel)
 367	go func() {
 368		slog.Debug("Request started", "sessionID", sessionID)
 369		defer log.RecoverPanic("agent.Run", func() {
 370			events <- a.err(fmt.Errorf("panic while running the agent"))
 371		})
 372		var attachmentParts []message.ContentPart
 373		for _, attachment := range attachments {
 374			attachmentParts = append(attachmentParts, message.BinaryContent{Path: attachment.FilePath, MIMEType: attachment.MimeType, Data: attachment.Content})
 375		}
 376		result := a.processGeneration(genCtx, sessionID, content, attachmentParts)
 377		if result.Error != nil && !errors.Is(result.Error, ErrRequestCancelled) && !errors.Is(result.Error, context.Canceled) {
 378			slog.Error(result.Error.Error())
 379		}
 380		slog.Debug("Request completed", "sessionID", sessionID)
 381		a.activeRequests.Del(sessionID)
 382		cancel()
 383		a.Publish(pubsub.CreatedEvent, result)
 384		select {
 385		case events <- result:
 386		case <-genCtx.Done():
 387		}
 388		close(events)
 389	}()
 390	return events, nil
 391}
 392
 393func (a *agent) processGeneration(ctx context.Context, sessionID, content string, attachmentParts []message.ContentPart) AgentEvent {
 394	cfg := config.Get()
 395	// List existing messages; if none, start title generation asynchronously.
 396	msgs, err := a.messages.List(ctx, sessionID)
 397	if err != nil {
 398		return a.err(fmt.Errorf("failed to list messages: %w", err))
 399	}
 400	if len(msgs) == 0 {
 401		go func() {
 402			defer log.RecoverPanic("agent.Run", func() {
 403				slog.Error("panic while generating title")
 404			})
 405			titleErr := a.generateTitle(context.Background(), sessionID, content)
 406			if titleErr != nil && !errors.Is(titleErr, context.Canceled) && !errors.Is(titleErr, context.DeadlineExceeded) {
 407				slog.Error("failed to generate title", "error", titleErr)
 408			}
 409		}()
 410	}
 411	session, err := a.sessions.Get(ctx, sessionID)
 412	if err != nil {
 413		return a.err(fmt.Errorf("failed to get session: %w", err))
 414	}
 415	if session.SummaryMessageID != "" {
 416		summaryMsgInex := -1
 417		for i, msg := range msgs {
 418			if msg.ID == session.SummaryMessageID {
 419				summaryMsgInex = i
 420				break
 421			}
 422		}
 423		if summaryMsgInex != -1 {
 424			msgs = msgs[summaryMsgInex:]
 425			msgs[0].Role = message.User
 426		}
 427	}
 428
 429	userMsg, err := a.createUserMessage(ctx, sessionID, content, attachmentParts)
 430	if err != nil {
 431		return a.err(fmt.Errorf("failed to create user message: %w", err))
 432	}
 433	// Append the new user message to the conversation history.
 434	msgHistory := append(msgs, userMsg)
 435
 436	for {
 437		// Check for cancellation before each iteration
 438		select {
 439		case <-ctx.Done():
 440			return a.err(ctx.Err())
 441		default:
 442			// Continue processing
 443		}
 444		agentMessage, toolResults, err := a.streamAndHandleEvents(ctx, sessionID, msgHistory)
 445		if err != nil {
 446			if errors.Is(err, context.Canceled) {
 447				agentMessage.AddFinish(message.FinishReasonCanceled, "Request cancelled", "")
 448				a.messages.Update(context.Background(), agentMessage)
 449				return a.err(ErrRequestCancelled)
 450			}
 451			return a.err(fmt.Errorf("failed to process events: %w", err))
 452		}
 453		if cfg.Options.Debug {
 454			slog.Info("Result", "message", agentMessage.FinishReason(), "toolResults", toolResults)
 455		}
 456		if (agentMessage.FinishReason() == message.FinishReasonToolUse) && toolResults != nil {
 457			// We are not done, we need to respond with the tool response
 458			msgHistory = append(msgHistory, agentMessage, *toolResults)
 459			// If there are queued prompts, process the next one
 460			nextPrompt, ok := a.promptQueue.Take(sessionID)
 461			if ok {
 462				for _, prompt := range nextPrompt {
 463					// Create a new user message for the queued prompt
 464					userMsg, err := a.createUserMessage(ctx, sessionID, prompt, nil)
 465					if err != nil {
 466						return a.err(fmt.Errorf("failed to create user message for queued prompt: %w", err))
 467					}
 468					// Append the new user message to the conversation history
 469					msgHistory = append(msgHistory, userMsg)
 470				}
 471			}
 472
 473			continue
 474		} else if agentMessage.FinishReason() == message.FinishReasonEndTurn {
 475			queuePrompts, ok := a.promptQueue.Take(sessionID)
 476			if ok {
 477				for _, prompt := range queuePrompts {
 478					if prompt == "" {
 479						continue
 480					}
 481					userMsg, err := a.createUserMessage(ctx, sessionID, prompt, nil)
 482					if err != nil {
 483						return a.err(fmt.Errorf("failed to create user message for queued prompt: %w", err))
 484					}
 485					msgHistory = append(msgHistory, userMsg)
 486				}
 487				continue
 488			}
 489		}
 490		if agentMessage.FinishReason() == "" {
 491			// Kujtim: could not track down where this is happening but this means its cancelled
 492			agentMessage.AddFinish(message.FinishReasonCanceled, "Request cancelled", "")
 493			_ = a.messages.Update(context.Background(), agentMessage)
 494			return a.err(ErrRequestCancelled)
 495		}
 496		return AgentEvent{
 497			Type:    AgentEventTypeResponse,
 498			Message: agentMessage,
 499			Done:    true,
 500		}
 501	}
 502}
 503
 504func (a *agent) createUserMessage(ctx context.Context, sessionID, content string, attachmentParts []message.ContentPart) (message.Message, error) {
 505	parts := []message.ContentPart{message.TextContent{Text: content}}
 506	parts = append(parts, attachmentParts...)
 507	return a.messages.Create(ctx, sessionID, message.CreateMessageParams{
 508		Role:  message.User,
 509		Parts: parts,
 510	})
 511}
 512
 513func (a *agent) streamAndHandleEvents(ctx context.Context, sessionID string, msgHistory []message.Message) (message.Message, *message.Message, error) {
 514	ctx = context.WithValue(ctx, tools.SessionIDContextKey, sessionID)
 515
 516	// Create the assistant message first so the spinner shows immediately
 517	assistantMsg, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{
 518		Role:     message.Assistant,
 519		Parts:    []message.ContentPart{},
 520		Model:    a.Model().ID,
 521		Provider: a.providerID,
 522	})
 523	if err != nil {
 524		return assistantMsg, nil, fmt.Errorf("failed to create assistant message: %w", err)
 525	}
 526
 527	// Now collect tools (which may block on MCP initialization)
 528	eventChan := a.provider.StreamResponse(ctx, msgHistory, slices.Collect(a.tools.Seq()))
 529
 530	// Add the session and message ID into the context if needed by tools.
 531	ctx = context.WithValue(ctx, tools.MessageIDContextKey, assistantMsg.ID)
 532
 533	// Process each event in the stream.
 534	for event := range eventChan {
 535		if processErr := a.processEvent(ctx, sessionID, &assistantMsg, event); processErr != nil {
 536			if errors.Is(processErr, context.Canceled) {
 537				a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
 538			} else {
 539				a.finishMessage(ctx, &assistantMsg, message.FinishReasonError, "API Error", processErr.Error())
 540			}
 541			return assistantMsg, nil, processErr
 542		}
 543		if ctx.Err() != nil {
 544			a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
 545			return assistantMsg, nil, ctx.Err()
 546		}
 547	}
 548
 549	toolResults := make([]message.ToolResult, len(assistantMsg.ToolCalls()))
 550	toolCalls := assistantMsg.ToolCalls()
 551	for i, toolCall := range toolCalls {
 552		select {
 553		case <-ctx.Done():
 554			a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
 555			// Make all future tool calls cancelled
 556			for j := i; j < len(toolCalls); j++ {
 557				toolResults[j] = message.ToolResult{
 558					ToolCallID: toolCalls[j].ID,
 559					Content:    "Tool execution canceled by user",
 560					IsError:    true,
 561				}
 562			}
 563			goto out
 564		default:
 565			// Continue processing
 566			tool, ok := a.tools.Get(toolCall.Name)
 567
 568			// Tool not found
 569			if !ok {
 570				toolResults[i] = message.ToolResult{
 571					ToolCallID: toolCall.ID,
 572					Content:    fmt.Sprintf("Tool not found: %s", toolCall.Name),
 573					IsError:    true,
 574				}
 575				continue
 576			}
 577
 578			// Run tool in goroutine to allow cancellation
 579			type toolExecResult struct {
 580				response tools.ToolResponse
 581				err      error
 582			}
 583			resultChan := make(chan toolExecResult, 1)
 584
 585			go func() {
 586				response, err := tool.Run(ctx, tools.ToolCall{
 587					ID:    toolCall.ID,
 588					Name:  toolCall.Name,
 589					Input: toolCall.Input,
 590				})
 591				resultChan <- toolExecResult{response: response, err: err}
 592			}()
 593
 594			var toolResponse tools.ToolResponse
 595			var toolErr error
 596
 597			select {
 598			case <-ctx.Done():
 599				a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
 600				// Mark remaining tool calls as cancelled
 601				for j := i; j < len(toolCalls); j++ {
 602					toolResults[j] = message.ToolResult{
 603						ToolCallID: toolCalls[j].ID,
 604						Content:    "Tool execution canceled by user",
 605						IsError:    true,
 606					}
 607				}
 608				goto out
 609			case result := <-resultChan:
 610				toolResponse = result.response
 611				toolErr = result.err
 612			}
 613
 614			if toolErr != nil {
 615				slog.Error("Tool execution error", "toolCall", toolCall.ID, "error", toolErr)
 616				if errors.Is(toolErr, permission.ErrorPermissionDenied) {
 617					toolResults[i] = message.ToolResult{
 618						ToolCallID: toolCall.ID,
 619						Content:    "Permission denied",
 620						IsError:    true,
 621					}
 622					for j := i + 1; j < len(toolCalls); j++ {
 623						toolResults[j] = message.ToolResult{
 624							ToolCallID: toolCalls[j].ID,
 625							Content:    "Tool execution canceled by user",
 626							IsError:    true,
 627						}
 628					}
 629					a.finishMessage(ctx, &assistantMsg, message.FinishReasonPermissionDenied, "Permission denied", "")
 630					break
 631				}
 632			}
 633			toolResults[i] = message.ToolResult{
 634				ToolCallID: toolCall.ID,
 635				Content:    toolResponse.Content,
 636				Metadata:   toolResponse.Metadata,
 637				IsError:    toolResponse.IsError,
 638			}
 639		}
 640	}
 641out:
 642	if len(toolResults) == 0 {
 643		return assistantMsg, nil, nil
 644	}
 645	parts := make([]message.ContentPart, 0)
 646	for _, tr := range toolResults {
 647		parts = append(parts, tr)
 648	}
 649	msg, err := a.messages.Create(context.Background(), assistantMsg.SessionID, message.CreateMessageParams{
 650		Role:     message.Tool,
 651		Parts:    parts,
 652		Provider: a.providerID,
 653	})
 654	if err != nil {
 655		return assistantMsg, nil, fmt.Errorf("failed to create cancelled tool message: %w", err)
 656	}
 657
 658	return assistantMsg, &msg, err
 659}
 660
 661func (a *agent) finishMessage(ctx context.Context, msg *message.Message, finishReason message.FinishReason, message, details string) {
 662	msg.AddFinish(finishReason, message, details)
 663	_ = a.messages.Update(ctx, *msg)
 664}
 665
 666func (a *agent) processEvent(ctx context.Context, sessionID string, assistantMsg *message.Message, event provider.ProviderEvent) error {
 667	select {
 668	case <-ctx.Done():
 669		return ctx.Err()
 670	default:
 671		// Continue processing.
 672	}
 673
 674	switch event.Type {
 675	case provider.EventThinkingDelta:
 676		assistantMsg.AppendReasoningContent(event.Thinking)
 677		return a.messages.Update(ctx, *assistantMsg)
 678	case provider.EventSignatureDelta:
 679		assistantMsg.AppendReasoningSignature(event.Signature)
 680		return a.messages.Update(ctx, *assistantMsg)
 681	case provider.EventContentDelta:
 682		assistantMsg.FinishThinking()
 683		assistantMsg.AppendContent(event.Content)
 684		return a.messages.Update(ctx, *assistantMsg)
 685	case provider.EventToolUseStart:
 686		assistantMsg.FinishThinking()
 687		slog.Info("Tool call started", "toolCall", event.ToolCall)
 688		assistantMsg.AddToolCall(*event.ToolCall)
 689		return a.messages.Update(ctx, *assistantMsg)
 690	case provider.EventToolUseDelta:
 691		assistantMsg.AppendToolCallInput(event.ToolCall.ID, event.ToolCall.Input)
 692		return a.messages.Update(ctx, *assistantMsg)
 693	case provider.EventToolUseStop:
 694		slog.Info("Finished tool call", "toolCall", event.ToolCall)
 695		assistantMsg.FinishToolCall(event.ToolCall.ID)
 696		return a.messages.Update(ctx, *assistantMsg)
 697	case provider.EventError:
 698		return event.Error
 699	case provider.EventComplete:
 700		assistantMsg.FinishThinking()
 701		assistantMsg.SetToolCalls(event.Response.ToolCalls)
 702		assistantMsg.AddFinish(event.Response.FinishReason, "", "")
 703		if err := a.messages.Update(ctx, *assistantMsg); err != nil {
 704			return fmt.Errorf("failed to update message: %w", err)
 705		}
 706		return a.TrackUsage(ctx, sessionID, a.Model(), event.Response.Usage)
 707	}
 708
 709	return nil
 710}
 711
 712func (a *agent) TrackUsage(ctx context.Context, sessionID string, model catwalk.Model, usage provider.TokenUsage) error {
 713	sess, err := a.sessions.Get(ctx, sessionID)
 714	if err != nil {
 715		return fmt.Errorf("failed to get session: %w", err)
 716	}
 717
 718	cost := model.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
 719		model.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
 720		model.CostPer1MIn/1e6*float64(usage.InputTokens) +
 721		model.CostPer1MOut/1e6*float64(usage.OutputTokens)
 722
 723	sess.Cost += cost
 724	sess.CompletionTokens = usage.OutputTokens + usage.CacheReadTokens
 725	sess.PromptTokens = usage.InputTokens + usage.CacheCreationTokens
 726
 727	_, err = a.sessions.Save(ctx, sess)
 728	if err != nil {
 729		return fmt.Errorf("failed to save session: %w", err)
 730	}
 731	return nil
 732}
 733
 734func (a *agent) Summarize(ctx context.Context, sessionID string) error {
 735	if a.summarizeProvider == nil {
 736		return fmt.Errorf("summarize provider not available")
 737	}
 738
 739	// Check if session is busy
 740	if a.IsSessionBusy(sessionID) {
 741		return ErrSessionBusy
 742	}
 743
 744	// Create a new context with cancellation
 745	summarizeCtx, cancel := context.WithCancel(ctx)
 746
 747	// Store the cancel function in activeRequests to allow cancellation
 748	a.activeRequests.Set(sessionID+"-summarize", cancel)
 749
 750	go func() {
 751		defer a.activeRequests.Del(sessionID + "-summarize")
 752		defer cancel()
 753		event := AgentEvent{
 754			Type:     AgentEventTypeSummarize,
 755			Progress: "Starting summarization...",
 756		}
 757
 758		a.Publish(pubsub.CreatedEvent, event)
 759		// Get all messages from the session
 760		msgs, err := a.messages.List(summarizeCtx, sessionID)
 761		if err != nil {
 762			event = AgentEvent{
 763				Type:  AgentEventTypeError,
 764				Error: fmt.Errorf("failed to list messages: %w", err),
 765				Done:  true,
 766			}
 767			a.Publish(pubsub.CreatedEvent, event)
 768			return
 769		}
 770		summarizeCtx = context.WithValue(summarizeCtx, tools.SessionIDContextKey, sessionID)
 771
 772		if len(msgs) == 0 {
 773			event = AgentEvent{
 774				Type:  AgentEventTypeError,
 775				Error: fmt.Errorf("no messages to summarize"),
 776				Done:  true,
 777			}
 778			a.Publish(pubsub.CreatedEvent, event)
 779			return
 780		}
 781
 782		event = AgentEvent{
 783			Type:     AgentEventTypeSummarize,
 784			Progress: "Analyzing conversation...",
 785		}
 786		a.Publish(pubsub.CreatedEvent, event)
 787
 788		// Add a system message to guide the summarization
 789		summarizePrompt := "Provide a detailed but concise summary of our conversation above. Focus on information that would be helpful for continuing the conversation, including what we did, what we're doing, which files we're working on, and what we're going to do next."
 790
 791		// Create a new message with the summarize prompt
 792		promptMsg := message.Message{
 793			Role:  message.User,
 794			Parts: []message.ContentPart{message.TextContent{Text: summarizePrompt}},
 795		}
 796
 797		// Append the prompt to the messages
 798		msgsWithPrompt := append(msgs, promptMsg)
 799
 800		event = AgentEvent{
 801			Type:     AgentEventTypeSummarize,
 802			Progress: "Generating summary...",
 803		}
 804
 805		a.Publish(pubsub.CreatedEvent, event)
 806
 807		// Send the messages to the summarize provider
 808		response := a.summarizeProvider.StreamResponse(
 809			summarizeCtx,
 810			msgsWithPrompt,
 811			nil,
 812		)
 813		var finalResponse *provider.ProviderResponse
 814		for r := range response {
 815			if r.Error != nil {
 816				event = AgentEvent{
 817					Type:  AgentEventTypeError,
 818					Error: fmt.Errorf("failed to summarize: %w", err),
 819					Done:  true,
 820				}
 821				a.Publish(pubsub.CreatedEvent, event)
 822				return
 823			}
 824			finalResponse = r.Response
 825		}
 826
 827		summary := strings.TrimSpace(finalResponse.Content)
 828		if summary == "" {
 829			event = AgentEvent{
 830				Type:  AgentEventTypeError,
 831				Error: fmt.Errorf("empty summary returned"),
 832				Done:  true,
 833			}
 834			a.Publish(pubsub.CreatedEvent, event)
 835			return
 836		}
 837		shell := shell.GetPersistentShell(config.Get().WorkingDir())
 838		summary += "\n\n**Current working directory of the persistent shell**\n\n" + shell.GetWorkingDir()
 839		event = AgentEvent{
 840			Type:     AgentEventTypeSummarize,
 841			Progress: "Creating new session...",
 842		}
 843
 844		a.Publish(pubsub.CreatedEvent, event)
 845		oldSession, err := a.sessions.Get(summarizeCtx, sessionID)
 846		if err != nil {
 847			event = AgentEvent{
 848				Type:  AgentEventTypeError,
 849				Error: fmt.Errorf("failed to get session: %w", err),
 850				Done:  true,
 851			}
 852
 853			a.Publish(pubsub.CreatedEvent, event)
 854			return
 855		}
 856		// Create a message in the new session with the summary
 857		msg, err := a.messages.Create(summarizeCtx, oldSession.ID, message.CreateMessageParams{
 858			Role: message.Assistant,
 859			Parts: []message.ContentPart{
 860				message.TextContent{Text: summary},
 861				message.Finish{
 862					Reason: message.FinishReasonEndTurn,
 863					Time:   time.Now().Unix(),
 864				},
 865			},
 866			Model:    a.summarizeProvider.Model().ID,
 867			Provider: a.summarizeProviderID,
 868		})
 869		if err != nil {
 870			event = AgentEvent{
 871				Type:  AgentEventTypeError,
 872				Error: fmt.Errorf("failed to create summary message: %w", err),
 873				Done:  true,
 874			}
 875
 876			a.Publish(pubsub.CreatedEvent, event)
 877			return
 878		}
 879		oldSession.SummaryMessageID = msg.ID
 880		oldSession.CompletionTokens = finalResponse.Usage.OutputTokens
 881		oldSession.PromptTokens = 0
 882		model := a.summarizeProvider.Model()
 883		usage := finalResponse.Usage
 884		cost := model.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
 885			model.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
 886			model.CostPer1MIn/1e6*float64(usage.InputTokens) +
 887			model.CostPer1MOut/1e6*float64(usage.OutputTokens)
 888		oldSession.Cost += cost
 889		_, err = a.sessions.Save(summarizeCtx, oldSession)
 890		if err != nil {
 891			event = AgentEvent{
 892				Type:  AgentEventTypeError,
 893				Error: fmt.Errorf("failed to save session: %w", err),
 894				Done:  true,
 895			}
 896			a.Publish(pubsub.CreatedEvent, event)
 897		}
 898
 899		event = AgentEvent{
 900			Type:      AgentEventTypeSummarize,
 901			SessionID: oldSession.ID,
 902			Progress:  "Summary complete",
 903			Done:      true,
 904		}
 905		a.Publish(pubsub.CreatedEvent, event)
 906		// Send final success event with the new session ID
 907	}()
 908
 909	return nil
 910}
 911
 912func (a *agent) ClearQueue(sessionID string) {
 913	if a.QueuedPrompts(sessionID) > 0 {
 914		slog.Info("Clearing queued prompts", "session_id", sessionID)
 915		a.promptQueue.Del(sessionID)
 916	}
 917}
 918
 919func (a *agent) CancelAll() {
 920	if !a.IsBusy() {
 921		return
 922	}
 923	for key := range a.activeRequests.Seq2() {
 924		a.Cancel(key) // key is sessionID
 925	}
 926
 927	timeout := time.After(5 * time.Second)
 928	for a.IsBusy() {
 929		select {
 930		case <-timeout:
 931			return
 932		default:
 933			time.Sleep(200 * time.Millisecond)
 934		}
 935	}
 936}
 937
 938func (a *agent) UpdateModel() error {
 939	cfg := config.Get()
 940
 941	// Get current provider configuration
 942	currentProviderCfg := cfg.GetProviderForModel(a.agentCfg.Model)
 943	if currentProviderCfg == nil || currentProviderCfg.ID == "" {
 944		return fmt.Errorf("provider for agent %s not found in config", a.agentCfg.Name)
 945	}
 946
 947	// Check if provider has changed
 948	if string(currentProviderCfg.ID) != a.providerID {
 949		// Provider changed, need to recreate the main provider
 950		model := cfg.GetModelByType(a.agentCfg.Model)
 951		if model.ID == "" {
 952			return fmt.Errorf("model not found for agent %s", a.agentCfg.Name)
 953		}
 954
 955		promptID := agentPromptMap[a.agentCfg.ID]
 956		if promptID == "" {
 957			promptID = prompt.PromptDefault
 958		}
 959
 960		opts := []provider.ProviderClientOption{
 961			provider.WithModel(a.agentCfg.Model),
 962			provider.WithSystemMessage(prompt.GetPrompt(promptID, currentProviderCfg.ID, cfg.Options.ContextPaths...)),
 963		}
 964
 965		newProvider, err := provider.NewProvider(*currentProviderCfg, opts...)
 966		if err != nil {
 967			return fmt.Errorf("failed to create new provider: %w", err)
 968		}
 969
 970		// Update the provider and provider ID
 971		a.provider = newProvider
 972		a.providerID = string(currentProviderCfg.ID)
 973	}
 974
 975	// Check if providers have changed for title (small) and summarize (large)
 976	smallModelCfg := cfg.Models[config.SelectedModelTypeSmall]
 977	var smallModelProviderCfg config.ProviderConfig
 978	for p := range cfg.Providers.Seq() {
 979		if p.ID == smallModelCfg.Provider {
 980			smallModelProviderCfg = p
 981			break
 982		}
 983	}
 984	if smallModelProviderCfg.ID == "" {
 985		return fmt.Errorf("provider %s not found in config", smallModelCfg.Provider)
 986	}
 987
 988	largeModelCfg := cfg.Models[config.SelectedModelTypeLarge]
 989	var largeModelProviderCfg config.ProviderConfig
 990	for p := range cfg.Providers.Seq() {
 991		if p.ID == largeModelCfg.Provider {
 992			largeModelProviderCfg = p
 993			break
 994		}
 995	}
 996	if largeModelProviderCfg.ID == "" {
 997		return fmt.Errorf("provider %s not found in config", largeModelCfg.Provider)
 998	}
 999
1000	// Recreate title provider
1001	titleOpts := []provider.ProviderClientOption{
1002		provider.WithModel(config.SelectedModelTypeSmall),
1003		provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptTitle, smallModelProviderCfg.ID)),
1004		provider.WithMaxTokens(40),
1005	}
1006	newTitleProvider, err := provider.NewProvider(smallModelProviderCfg, titleOpts...)
1007	if err != nil {
1008		return fmt.Errorf("failed to create new title provider: %w", err)
1009	}
1010	a.titleProvider = newTitleProvider
1011
1012	// Recreate summarize provider if provider changed (now large model)
1013	if string(largeModelProviderCfg.ID) != a.summarizeProviderID {
1014		largeModel := cfg.GetModelByType(config.SelectedModelTypeLarge)
1015		if largeModel == nil {
1016			return fmt.Errorf("model %s not found in provider %s", largeModelCfg.Model, largeModelProviderCfg.ID)
1017		}
1018		summarizeOpts := []provider.ProviderClientOption{
1019			provider.WithModel(config.SelectedModelTypeLarge),
1020			provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptSummarizer, largeModelProviderCfg.ID)),
1021		}
1022		newSummarizeProvider, err := provider.NewProvider(largeModelProviderCfg, summarizeOpts...)
1023		if err != nil {
1024			return fmt.Errorf("failed to create new summarize provider: %w", err)
1025		}
1026		a.summarizeProvider = newSummarizeProvider
1027		a.summarizeProviderID = string(largeModelProviderCfg.ID)
1028	}
1029
1030	return nil
1031}