agent.go

   1package agent
   2
   3import (
   4	"context"
   5	"errors"
   6	"fmt"
   7	"log/slog"
   8	"maps"
   9	"slices"
  10	"strings"
  11	"time"
  12
  13	"github.com/charmbracelet/catwalk/pkg/catwalk"
  14	"github.com/charmbracelet/crush/internal/config"
  15	"github.com/charmbracelet/crush/internal/csync"
  16	"github.com/charmbracelet/crush/internal/history"
  17	"github.com/charmbracelet/crush/internal/llm/prompt"
  18	"github.com/charmbracelet/crush/internal/llm/provider"
  19	"github.com/charmbracelet/crush/internal/llm/tools"
  20	"github.com/charmbracelet/crush/internal/log"
  21	"github.com/charmbracelet/crush/internal/lsp"
  22	"github.com/charmbracelet/crush/internal/message"
  23	"github.com/charmbracelet/crush/internal/permission"
  24	"github.com/charmbracelet/crush/internal/pubsub"
  25	"github.com/charmbracelet/crush/internal/session"
  26	"github.com/charmbracelet/crush/internal/shell"
  27)
  28
  29// Common errors
  30var (
  31	ErrRequestCancelled = errors.New("request canceled by user")
  32	ErrSessionBusy      = errors.New("session is currently processing another request")
  33)
  34
  35type AgentEventType string
  36
  37const (
  38	AgentEventTypeError     AgentEventType = "error"
  39	AgentEventTypeResponse  AgentEventType = "response"
  40	AgentEventTypeSummarize AgentEventType = "summarize"
  41)
  42
  43type AgentEvent struct {
  44	Type    AgentEventType
  45	Message message.Message
  46	Error   error
  47
  48	// When summarizing
  49	SessionID string
  50	Progress  string
  51	Done      bool
  52}
  53
  54type Service interface {
  55	pubsub.Suscriber[AgentEvent]
  56	Model() catwalk.Model
  57	Run(ctx context.Context, sessionID string, content string, attachments ...message.Attachment) (<-chan AgentEvent, error)
  58	Cancel(sessionID string)
  59	CancelAll()
  60	IsSessionBusy(sessionID string) bool
  61	IsBusy() bool
  62	Summarize(ctx context.Context, sessionID string) error
  63	UpdateModel() error
  64	QueuedPrompts(sessionID string) int
  65	ClearQueue(sessionID string)
  66}
  67
  68type agent struct {
  69	cleanupFuncs []func()
  70
  71	*pubsub.Broker[AgentEvent]
  72	agentCfg config.Agent
  73	sessions session.Service
  74	messages message.Service
  75
  76	baseTools *csync.Map[string, tools.BaseTool]
  77	mcpTools  *csync.Map[string, tools.BaseTool]
  78
  79	provider   provider.Provider
  80	providerID string
  81
  82	titleProvider       provider.Provider
  83	summarizeProvider   provider.Provider
  84	summarizeProviderID string
  85
  86	activeRequests *csync.Map[string, context.CancelFunc]
  87
  88	promptQueue *csync.Map[string, []string]
  89}
  90
  91var agentPromptMap = map[string]prompt.PromptID{
  92	"coder": prompt.PromptCoder,
  93	"task":  prompt.PromptTask,
  94}
  95
  96func NewAgent(
  97	ctx context.Context,
  98	agentCfg config.Agent,
  99	// These services are needed in the tools
 100	permissions permission.Service,
 101	sessions session.Service,
 102	messages message.Service,
 103	history history.Service,
 104	lspClients map[string]*lsp.Client,
 105) (Service, error) {
 106	cfg := config.Get()
 107
 108	var agentTool tools.BaseTool
 109	if agentCfg.ID == "coder" {
 110		taskAgentCfg := config.Get().Agents["task"]
 111		if taskAgentCfg.ID == "" {
 112			return nil, fmt.Errorf("task agent not found in config")
 113		}
 114		taskAgent, err := NewAgent(ctx, taskAgentCfg, permissions, sessions, messages, history, lspClients)
 115		if err != nil {
 116			return nil, fmt.Errorf("failed to create task agent: %w", err)
 117		}
 118
 119		agentTool = NewAgentTool(taskAgent, sessions, messages)
 120	}
 121
 122	providerCfg := config.Get().GetProviderForModel(agentCfg.Model)
 123	if providerCfg == nil {
 124		return nil, fmt.Errorf("provider for agent %s not found in config", agentCfg.Name)
 125	}
 126	model := config.Get().GetModelByType(agentCfg.Model)
 127
 128	if model == nil {
 129		return nil, fmt.Errorf("model not found for agent %s", agentCfg.Name)
 130	}
 131
 132	promptID := agentPromptMap[agentCfg.ID]
 133	if promptID == "" {
 134		promptID = prompt.PromptDefault
 135	}
 136	opts := []provider.ProviderClientOption{
 137		provider.WithModel(agentCfg.Model),
 138		provider.WithSystemMessage(prompt.GetPrompt(promptID, providerCfg.ID, config.Get().Options.ContextPaths...)),
 139	}
 140	agentProvider, err := provider.NewProvider(*providerCfg, opts...)
 141	if err != nil {
 142		return nil, err
 143	}
 144
 145	smallModelCfg := cfg.Models[config.SelectedModelTypeSmall]
 146	var smallModelProviderCfg *config.ProviderConfig
 147	if smallModelCfg.Provider == providerCfg.ID {
 148		smallModelProviderCfg = providerCfg
 149	} else {
 150		smallModelProviderCfg = cfg.GetProviderForModel(config.SelectedModelTypeSmall)
 151
 152		if smallModelProviderCfg.ID == "" {
 153			return nil, fmt.Errorf("provider %s not found in config", smallModelCfg.Provider)
 154		}
 155	}
 156	smallModel := cfg.GetModelByType(config.SelectedModelTypeSmall)
 157	if smallModel.ID == "" {
 158		return nil, fmt.Errorf("model %s not found in provider %s", smallModelCfg.Model, smallModelProviderCfg.ID)
 159	}
 160
 161	titleOpts := []provider.ProviderClientOption{
 162		provider.WithModel(config.SelectedModelTypeSmall),
 163		provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptTitle, smallModelProviderCfg.ID)),
 164	}
 165	titleProvider, err := provider.NewProvider(*smallModelProviderCfg, titleOpts...)
 166	if err != nil {
 167		return nil, err
 168	}
 169
 170	summarizeOpts := []provider.ProviderClientOption{
 171		provider.WithModel(config.SelectedModelTypeLarge),
 172		provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptSummarizer, providerCfg.ID)),
 173	}
 174	summarizeProvider, err := provider.NewProvider(*providerCfg, summarizeOpts...)
 175	if err != nil {
 176		return nil, err
 177	}
 178
 179	baseToolsFn := func() map[string]tools.BaseTool {
 180		slog.Info("Initializing agent base tools", "agent", agentCfg.ID)
 181		defer func() {
 182			slog.Info("Initialized agent base tools", "agent", agentCfg.ID)
 183		}()
 184
 185		// Base tools available to all agents
 186		cwd := cfg.WorkingDir()
 187		result := make(map[string]tools.BaseTool)
 188		for _, tool := range []tools.BaseTool{
 189			tools.NewBashTool(permissions, cwd),
 190			tools.NewDownloadTool(permissions, cwd),
 191			tools.NewEditTool(lspClients, permissions, history, cwd),
 192			tools.NewMultiEditTool(lspClients, permissions, history, cwd),
 193			tools.NewFetchTool(permissions, cwd),
 194			tools.NewGlobTool(cwd),
 195			tools.NewGrepTool(cwd),
 196			tools.NewLsTool(permissions, cwd),
 197			tools.NewSourcegraphTool(),
 198			tools.NewViewTool(lspClients, permissions, cwd),
 199			tools.NewWriteTool(lspClients, permissions, history, cwd),
 200		} {
 201			result[tool.Name()] = tool
 202		}
 203
 204		if len(lspClients) > 0 {
 205			diagnosticsTool := tools.NewDiagnosticsTool(lspClients)
 206			result[diagnosticsTool.Name()] = diagnosticsTool
 207		}
 208
 209		if agentTool != nil {
 210			result[agentTool.Name()] = agentTool
 211		}
 212
 213		return result
 214	}
 215	mcpToolsFn := func() map[string]tools.BaseTool {
 216		slog.Info("Initializing agent mcp tools", "agent", agentCfg.ID)
 217		defer func() {
 218			slog.Info("Initialized agent mcp tools", "agent", agentCfg.ID)
 219		}()
 220
 221		mcpToolsOnce.Do(func() {
 222			doGetMCPTools(ctx, permissions, cfg)
 223		})
 224
 225		result := make(map[string]tools.BaseTool)
 226		for _, mcpTool := range mcpTools.Seq2() {
 227			result[mcpTool.Name()] = mcpTool
 228		}
 229		return result
 230	}
 231
 232	a := &agent{
 233		Broker:              pubsub.NewBroker[AgentEvent](),
 234		agentCfg:            agentCfg,
 235		provider:            agentProvider,
 236		providerID:          string(providerCfg.ID),
 237		messages:            messages,
 238		sessions:            sessions,
 239		titleProvider:       titleProvider,
 240		summarizeProvider:   summarizeProvider,
 241		summarizeProviderID: string(providerCfg.ID),
 242		activeRequests:      csync.NewMap[string, context.CancelFunc](),
 243		mcpTools:            csync.NewLazyMap(mcpToolsFn),
 244		baseTools:           csync.NewLazyMap(baseToolsFn),
 245		promptQueue:         csync.NewMap[string, []string](),
 246	}
 247	a.setupEvents(ctx)
 248	return a, nil
 249}
 250
 251func (a *agent) Model() catwalk.Model {
 252	return *config.Get().GetModelByType(a.agentCfg.Model)
 253}
 254
 255func (a *agent) Cancel(sessionID string) {
 256	// Cancel regular requests
 257	if cancel, ok := a.activeRequests.Take(sessionID); ok && cancel != nil {
 258		slog.Info("Request cancellation initiated", "session_id", sessionID)
 259		cancel()
 260	}
 261
 262	// Also check for summarize requests
 263	if cancel, ok := a.activeRequests.Take(sessionID + "-summarize"); ok && cancel != nil {
 264		slog.Info("Summarize cancellation initiated", "session_id", sessionID)
 265		cancel()
 266	}
 267
 268	if a.QueuedPrompts(sessionID) > 0 {
 269		slog.Info("Clearing queued prompts", "session_id", sessionID)
 270		a.promptQueue.Del(sessionID)
 271	}
 272}
 273
 274func (a *agent) IsBusy() bool {
 275	var busy bool
 276	for cancelFunc := range a.activeRequests.Seq() {
 277		if cancelFunc != nil {
 278			busy = true
 279			break
 280		}
 281	}
 282	return busy
 283}
 284
 285func (a *agent) IsSessionBusy(sessionID string) bool {
 286	_, busy := a.activeRequests.Get(sessionID)
 287	return busy
 288}
 289
 290func (a *agent) QueuedPrompts(sessionID string) int {
 291	l, ok := a.promptQueue.Get(sessionID)
 292	if !ok {
 293		return 0
 294	}
 295	return len(l)
 296}
 297
 298func (a *agent) generateTitle(ctx context.Context, sessionID string, content string) error {
 299	if content == "" {
 300		return nil
 301	}
 302	if a.titleProvider == nil {
 303		return nil
 304	}
 305	session, err := a.sessions.Get(ctx, sessionID)
 306	if err != nil {
 307		return err
 308	}
 309	parts := []message.ContentPart{message.TextContent{
 310		Text: fmt.Sprintf("Generate a concise title for the following content:\n\n%s", content),
 311	}}
 312
 313	// Use streaming approach like summarization
 314	response := a.titleProvider.StreamResponse(
 315		ctx,
 316		[]message.Message{
 317			{
 318				Role:  message.User,
 319				Parts: parts,
 320			},
 321		},
 322		nil,
 323	)
 324
 325	var finalResponse *provider.ProviderResponse
 326	for r := range response {
 327		if r.Error != nil {
 328			return r.Error
 329		}
 330		finalResponse = r.Response
 331	}
 332
 333	if finalResponse == nil {
 334		return fmt.Errorf("no response received from title provider")
 335	}
 336
 337	title := strings.TrimSpace(strings.ReplaceAll(finalResponse.Content, "\n", " "))
 338	if title == "" {
 339		return nil
 340	}
 341
 342	session.Title = title
 343	_, err = a.sessions.Save(ctx, session)
 344	return err
 345}
 346
 347func (a *agent) err(err error) AgentEvent {
 348	return AgentEvent{
 349		Type:  AgentEventTypeError,
 350		Error: err,
 351	}
 352}
 353
 354func (a *agent) Run(ctx context.Context, sessionID string, content string, attachments ...message.Attachment) (<-chan AgentEvent, error) {
 355	if !a.Model().SupportsImages && attachments != nil {
 356		attachments = nil
 357	}
 358	events := make(chan AgentEvent)
 359	if a.IsSessionBusy(sessionID) {
 360		existing, ok := a.promptQueue.Get(sessionID)
 361		if !ok {
 362			existing = []string{}
 363		}
 364		existing = append(existing, content)
 365		a.promptQueue.Set(sessionID, existing)
 366		return nil, nil
 367	}
 368
 369	genCtx, cancel := context.WithCancel(ctx)
 370
 371	a.activeRequests.Set(sessionID, cancel)
 372	go func() {
 373		slog.Debug("Request started", "sessionID", sessionID)
 374		defer log.RecoverPanic("agent.Run", func() {
 375			events <- a.err(fmt.Errorf("panic while running the agent"))
 376		})
 377		var attachmentParts []message.ContentPart
 378		for _, attachment := range attachments {
 379			attachmentParts = append(attachmentParts, message.BinaryContent{Path: attachment.FilePath, MIMEType: attachment.MimeType, Data: attachment.Content})
 380		}
 381		result := a.processGeneration(genCtx, sessionID, content, attachmentParts)
 382		if result.Error != nil && !errors.Is(result.Error, ErrRequestCancelled) && !errors.Is(result.Error, context.Canceled) {
 383			slog.Error(result.Error.Error())
 384		}
 385		slog.Debug("Request completed", "sessionID", sessionID)
 386		a.activeRequests.Del(sessionID)
 387		cancel()
 388		a.Publish(pubsub.CreatedEvent, result)
 389		select {
 390		case events <- result:
 391		case <-genCtx.Done():
 392		}
 393		close(events)
 394	}()
 395	return events, nil
 396}
 397
 398func (a *agent) processGeneration(ctx context.Context, sessionID, content string, attachmentParts []message.ContentPart) AgentEvent {
 399	cfg := config.Get()
 400	// List existing messages; if none, start title generation asynchronously.
 401	msgs, err := a.messages.List(ctx, sessionID)
 402	if err != nil {
 403		return a.err(fmt.Errorf("failed to list messages: %w", err))
 404	}
 405	if len(msgs) == 0 {
 406		go func() {
 407			defer log.RecoverPanic("agent.Run", func() {
 408				slog.Error("panic while generating title")
 409			})
 410			titleErr := a.generateTitle(context.Background(), sessionID, content)
 411			if titleErr != nil && !errors.Is(titleErr, context.Canceled) && !errors.Is(titleErr, context.DeadlineExceeded) {
 412				slog.Error("failed to generate title", "error", titleErr)
 413			}
 414		}()
 415	}
 416	session, err := a.sessions.Get(ctx, sessionID)
 417	if err != nil {
 418		return a.err(fmt.Errorf("failed to get session: %w", err))
 419	}
 420	if session.SummaryMessageID != "" {
 421		summaryMsgInex := -1
 422		for i, msg := range msgs {
 423			if msg.ID == session.SummaryMessageID {
 424				summaryMsgInex = i
 425				break
 426			}
 427		}
 428		if summaryMsgInex != -1 {
 429			msgs = msgs[summaryMsgInex:]
 430			msgs[0].Role = message.User
 431		}
 432	}
 433
 434	userMsg, err := a.createUserMessage(ctx, sessionID, content, attachmentParts)
 435	if err != nil {
 436		return a.err(fmt.Errorf("failed to create user message: %w", err))
 437	}
 438	// Append the new user message to the conversation history.
 439	msgHistory := append(msgs, userMsg)
 440
 441	for {
 442		// Check for cancellation before each iteration
 443		select {
 444		case <-ctx.Done():
 445			return a.err(ctx.Err())
 446		default:
 447			// Continue processing
 448		}
 449		agentMessage, toolResults, err := a.streamAndHandleEvents(ctx, sessionID, msgHistory)
 450		if err != nil {
 451			if errors.Is(err, context.Canceled) {
 452				agentMessage.AddFinish(message.FinishReasonCanceled, "Request cancelled", "")
 453				a.messages.Update(context.Background(), agentMessage)
 454				return a.err(ErrRequestCancelled)
 455			}
 456			return a.err(fmt.Errorf("failed to process events: %w", err))
 457		}
 458		if cfg.Options.Debug {
 459			slog.Info("Result", "message", agentMessage.FinishReason(), "toolResults", toolResults)
 460		}
 461		if (agentMessage.FinishReason() == message.FinishReasonToolUse) && toolResults != nil {
 462			// We are not done, we need to respond with the tool response
 463			msgHistory = append(msgHistory, agentMessage, *toolResults)
 464			// If there are queued prompts, process the next one
 465			nextPrompt, ok := a.promptQueue.Take(sessionID)
 466			if ok {
 467				for _, prompt := range nextPrompt {
 468					// Create a new user message for the queued prompt
 469					userMsg, err := a.createUserMessage(ctx, sessionID, prompt, nil)
 470					if err != nil {
 471						return a.err(fmt.Errorf("failed to create user message for queued prompt: %w", err))
 472					}
 473					// Append the new user message to the conversation history
 474					msgHistory = append(msgHistory, userMsg)
 475				}
 476			}
 477
 478			continue
 479		} else if agentMessage.FinishReason() == message.FinishReasonEndTurn {
 480			queuePrompts, ok := a.promptQueue.Take(sessionID)
 481			if ok {
 482				for _, prompt := range queuePrompts {
 483					if prompt == "" {
 484						continue
 485					}
 486					userMsg, err := a.createUserMessage(ctx, sessionID, prompt, nil)
 487					if err != nil {
 488						return a.err(fmt.Errorf("failed to create user message for queued prompt: %w", err))
 489					}
 490					msgHistory = append(msgHistory, userMsg)
 491				}
 492				continue
 493			}
 494		}
 495		if agentMessage.FinishReason() == "" {
 496			// Kujtim: could not track down where this is happening but this means its cancelled
 497			agentMessage.AddFinish(message.FinishReasonCanceled, "Request cancelled", "")
 498			_ = a.messages.Update(context.Background(), agentMessage)
 499			return a.err(ErrRequestCancelled)
 500		}
 501		return AgentEvent{
 502			Type:    AgentEventTypeResponse,
 503			Message: agentMessage,
 504			Done:    true,
 505		}
 506	}
 507}
 508
 509func (a *agent) createUserMessage(ctx context.Context, sessionID, content string, attachmentParts []message.ContentPart) (message.Message, error) {
 510	parts := []message.ContentPart{message.TextContent{Text: content}}
 511	parts = append(parts, attachmentParts...)
 512	return a.messages.Create(ctx, sessionID, message.CreateMessageParams{
 513		Role:  message.User,
 514		Parts: parts,
 515	})
 516}
 517
 518func (a *agent) streamAndHandleEvents(ctx context.Context, sessionID string, msgHistory []message.Message) (message.Message, *message.Message, error) {
 519	ctx = context.WithValue(ctx, tools.SessionIDContextKey, sessionID)
 520
 521	// Create the assistant message first so the spinner shows immediately
 522	assistantMsg, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{
 523		Role:     message.Assistant,
 524		Parts:    []message.ContentPart{},
 525		Model:    a.Model().ID,
 526		Provider: a.providerID,
 527	})
 528	if err != nil {
 529		return assistantMsg, nil, fmt.Errorf("failed to create assistant message: %w", err)
 530	}
 531
 532	// Now collect tools (which may block on MCP initialization)
 533	eventChan := a.provider.StreamResponse(ctx, msgHistory, a.allTools())
 534
 535	// Add the session and message ID into the context if needed by tools.
 536	ctx = context.WithValue(ctx, tools.MessageIDContextKey, assistantMsg.ID)
 537
 538	// Process each event in the stream.
 539	for event := range eventChan {
 540		if processErr := a.processEvent(ctx, sessionID, &assistantMsg, event); processErr != nil {
 541			if errors.Is(processErr, context.Canceled) {
 542				a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
 543			} else {
 544				a.finishMessage(ctx, &assistantMsg, message.FinishReasonError, "API Error", processErr.Error())
 545			}
 546			return assistantMsg, nil, processErr
 547		}
 548		if ctx.Err() != nil {
 549			a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
 550			return assistantMsg, nil, ctx.Err()
 551		}
 552	}
 553
 554	toolResults := make([]message.ToolResult, len(assistantMsg.ToolCalls()))
 555	toolCalls := assistantMsg.ToolCalls()
 556	for i, toolCall := range toolCalls {
 557		select {
 558		case <-ctx.Done():
 559			a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
 560			// Make all future tool calls cancelled
 561			for j := i; j < len(toolCalls); j++ {
 562				toolResults[j] = message.ToolResult{
 563					ToolCallID: toolCalls[j].ID,
 564					Content:    "Tool execution canceled by user",
 565					IsError:    true,
 566				}
 567			}
 568			goto out
 569		default:
 570			// Continue processing
 571			tool, ok := a.getToolByName(toolCall.Name)
 572
 573			// Tool not found
 574			if !ok {
 575				toolResults[i] = message.ToolResult{
 576					ToolCallID: toolCall.ID,
 577					Content:    fmt.Sprintf("Tool not found: %s", toolCall.Name),
 578					IsError:    true,
 579				}
 580				continue
 581			}
 582
 583			// Run tool in goroutine to allow cancellation
 584			type toolExecResult struct {
 585				response tools.ToolResponse
 586				err      error
 587			}
 588			resultChan := make(chan toolExecResult, 1)
 589
 590			go func() {
 591				response, err := tool.Run(ctx, tools.ToolCall{
 592					ID:    toolCall.ID,
 593					Name:  toolCall.Name,
 594					Input: toolCall.Input,
 595				})
 596				resultChan <- toolExecResult{response: response, err: err}
 597			}()
 598
 599			var toolResponse tools.ToolResponse
 600			var toolErr error
 601
 602			select {
 603			case <-ctx.Done():
 604				a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
 605				// Mark remaining tool calls as cancelled
 606				for j := i; j < len(toolCalls); j++ {
 607					toolResults[j] = message.ToolResult{
 608						ToolCallID: toolCalls[j].ID,
 609						Content:    "Tool execution canceled by user",
 610						IsError:    true,
 611					}
 612				}
 613				goto out
 614			case result := <-resultChan:
 615				toolResponse = result.response
 616				toolErr = result.err
 617			}
 618
 619			if toolErr != nil {
 620				slog.Error("Tool execution error", "toolCall", toolCall.ID, "error", toolErr)
 621				if errors.Is(toolErr, permission.ErrorPermissionDenied) {
 622					toolResults[i] = message.ToolResult{
 623						ToolCallID: toolCall.ID,
 624						Content:    "Permission denied",
 625						IsError:    true,
 626					}
 627					for j := i + 1; j < len(toolCalls); j++ {
 628						toolResults[j] = message.ToolResult{
 629							ToolCallID: toolCalls[j].ID,
 630							Content:    "Tool execution canceled by user",
 631							IsError:    true,
 632						}
 633					}
 634					a.finishMessage(ctx, &assistantMsg, message.FinishReasonPermissionDenied, "Permission denied", "")
 635					break
 636				}
 637			}
 638			toolResults[i] = message.ToolResult{
 639				ToolCallID: toolCall.ID,
 640				Content:    toolResponse.Content,
 641				Metadata:   toolResponse.Metadata,
 642				IsError:    toolResponse.IsError,
 643			}
 644		}
 645	}
 646out:
 647	if len(toolResults) == 0 {
 648		return assistantMsg, nil, nil
 649	}
 650	parts := make([]message.ContentPart, 0)
 651	for _, tr := range toolResults {
 652		parts = append(parts, tr)
 653	}
 654	msg, err := a.messages.Create(context.Background(), assistantMsg.SessionID, message.CreateMessageParams{
 655		Role:     message.Tool,
 656		Parts:    parts,
 657		Provider: a.providerID,
 658	})
 659	if err != nil {
 660		return assistantMsg, nil, fmt.Errorf("failed to create cancelled tool message: %w", err)
 661	}
 662
 663	return assistantMsg, &msg, err
 664}
 665
 666func (a *agent) finishMessage(ctx context.Context, msg *message.Message, finishReason message.FinishReason, message, details string) {
 667	msg.AddFinish(finishReason, message, details)
 668	_ = a.messages.Update(ctx, *msg)
 669}
 670
 671func (a *agent) processEvent(ctx context.Context, sessionID string, assistantMsg *message.Message, event provider.ProviderEvent) error {
 672	select {
 673	case <-ctx.Done():
 674		return ctx.Err()
 675	default:
 676		// Continue processing.
 677	}
 678
 679	switch event.Type {
 680	case provider.EventThinkingDelta:
 681		assistantMsg.AppendReasoningContent(event.Thinking)
 682		return a.messages.Update(ctx, *assistantMsg)
 683	case provider.EventSignatureDelta:
 684		assistantMsg.AppendReasoningSignature(event.Signature)
 685		return a.messages.Update(ctx, *assistantMsg)
 686	case provider.EventContentDelta:
 687		assistantMsg.FinishThinking()
 688		assistantMsg.AppendContent(event.Content)
 689		return a.messages.Update(ctx, *assistantMsg)
 690	case provider.EventToolUseStart:
 691		assistantMsg.FinishThinking()
 692		slog.Info("Tool call started", "toolCall", event.ToolCall)
 693		assistantMsg.AddToolCall(*event.ToolCall)
 694		return a.messages.Update(ctx, *assistantMsg)
 695	case provider.EventToolUseDelta:
 696		assistantMsg.AppendToolCallInput(event.ToolCall.ID, event.ToolCall.Input)
 697		return a.messages.Update(ctx, *assistantMsg)
 698	case provider.EventToolUseStop:
 699		slog.Info("Finished tool call", "toolCall", event.ToolCall)
 700		assistantMsg.FinishToolCall(event.ToolCall.ID)
 701		return a.messages.Update(ctx, *assistantMsg)
 702	case provider.EventError:
 703		return event.Error
 704	case provider.EventComplete:
 705		assistantMsg.FinishThinking()
 706		assistantMsg.SetToolCalls(event.Response.ToolCalls)
 707		assistantMsg.AddFinish(event.Response.FinishReason, "", "")
 708		if err := a.messages.Update(ctx, *assistantMsg); err != nil {
 709			return fmt.Errorf("failed to update message: %w", err)
 710		}
 711		return a.TrackUsage(ctx, sessionID, a.Model(), event.Response.Usage)
 712	}
 713
 714	return nil
 715}
 716
 717func (a *agent) TrackUsage(ctx context.Context, sessionID string, model catwalk.Model, usage provider.TokenUsage) error {
 718	sess, err := a.sessions.Get(ctx, sessionID)
 719	if err != nil {
 720		return fmt.Errorf("failed to get session: %w", err)
 721	}
 722
 723	cost := model.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
 724		model.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
 725		model.CostPer1MIn/1e6*float64(usage.InputTokens) +
 726		model.CostPer1MOut/1e6*float64(usage.OutputTokens)
 727
 728	sess.Cost += cost
 729	sess.CompletionTokens = usage.OutputTokens + usage.CacheReadTokens
 730	sess.PromptTokens = usage.InputTokens + usage.CacheCreationTokens
 731
 732	_, err = a.sessions.Save(ctx, sess)
 733	if err != nil {
 734		return fmt.Errorf("failed to save session: %w", err)
 735	}
 736	return nil
 737}
 738
 739func (a *agent) Summarize(ctx context.Context, sessionID string) error {
 740	if a.summarizeProvider == nil {
 741		return fmt.Errorf("summarize provider not available")
 742	}
 743
 744	// Check if session is busy
 745	if a.IsSessionBusy(sessionID) {
 746		return ErrSessionBusy
 747	}
 748
 749	// Create a new context with cancellation
 750	summarizeCtx, cancel := context.WithCancel(ctx)
 751
 752	// Store the cancel function in activeRequests to allow cancellation
 753	a.activeRequests.Set(sessionID+"-summarize", cancel)
 754
 755	go func() {
 756		defer a.activeRequests.Del(sessionID + "-summarize")
 757		defer cancel()
 758		event := AgentEvent{
 759			Type:     AgentEventTypeSummarize,
 760			Progress: "Starting summarization...",
 761		}
 762
 763		a.Publish(pubsub.CreatedEvent, event)
 764		// Get all messages from the session
 765		msgs, err := a.messages.List(summarizeCtx, sessionID)
 766		if err != nil {
 767			event = AgentEvent{
 768				Type:  AgentEventTypeError,
 769				Error: fmt.Errorf("failed to list messages: %w", err),
 770				Done:  true,
 771			}
 772			a.Publish(pubsub.CreatedEvent, event)
 773			return
 774		}
 775		summarizeCtx = context.WithValue(summarizeCtx, tools.SessionIDContextKey, sessionID)
 776
 777		if len(msgs) == 0 {
 778			event = AgentEvent{
 779				Type:  AgentEventTypeError,
 780				Error: fmt.Errorf("no messages to summarize"),
 781				Done:  true,
 782			}
 783			a.Publish(pubsub.CreatedEvent, event)
 784			return
 785		}
 786
 787		event = AgentEvent{
 788			Type:     AgentEventTypeSummarize,
 789			Progress: "Analyzing conversation...",
 790		}
 791		a.Publish(pubsub.CreatedEvent, event)
 792
 793		// Add a system message to guide the summarization
 794		summarizePrompt := "Provide a detailed but concise summary of our conversation above. Focus on information that would be helpful for continuing the conversation, including what we did, what we're doing, which files we're working on, and what we're going to do next."
 795
 796		// Create a new message with the summarize prompt
 797		promptMsg := message.Message{
 798			Role:  message.User,
 799			Parts: []message.ContentPart{message.TextContent{Text: summarizePrompt}},
 800		}
 801
 802		// Append the prompt to the messages
 803		msgsWithPrompt := append(msgs, promptMsg)
 804
 805		event = AgentEvent{
 806			Type:     AgentEventTypeSummarize,
 807			Progress: "Generating summary...",
 808		}
 809
 810		a.Publish(pubsub.CreatedEvent, event)
 811
 812		// Send the messages to the summarize provider
 813		response := a.summarizeProvider.StreamResponse(
 814			summarizeCtx,
 815			msgsWithPrompt,
 816			nil,
 817		)
 818		var finalResponse *provider.ProviderResponse
 819		for r := range response {
 820			if r.Error != nil {
 821				event = AgentEvent{
 822					Type:  AgentEventTypeError,
 823					Error: fmt.Errorf("failed to summarize: %w", err),
 824					Done:  true,
 825				}
 826				a.Publish(pubsub.CreatedEvent, event)
 827				return
 828			}
 829			finalResponse = r.Response
 830		}
 831
 832		summary := strings.TrimSpace(finalResponse.Content)
 833		if summary == "" {
 834			event = AgentEvent{
 835				Type:  AgentEventTypeError,
 836				Error: fmt.Errorf("empty summary returned"),
 837				Done:  true,
 838			}
 839			a.Publish(pubsub.CreatedEvent, event)
 840			return
 841		}
 842		shell := shell.GetPersistentShell(config.Get().WorkingDir())
 843		summary += "\n\n**Current working directory of the persistent shell**\n\n" + shell.GetWorkingDir()
 844		event = AgentEvent{
 845			Type:     AgentEventTypeSummarize,
 846			Progress: "Creating new session...",
 847		}
 848
 849		a.Publish(pubsub.CreatedEvent, event)
 850		oldSession, err := a.sessions.Get(summarizeCtx, sessionID)
 851		if err != nil {
 852			event = AgentEvent{
 853				Type:  AgentEventTypeError,
 854				Error: fmt.Errorf("failed to get session: %w", err),
 855				Done:  true,
 856			}
 857
 858			a.Publish(pubsub.CreatedEvent, event)
 859			return
 860		}
 861		// Create a message in the new session with the summary
 862		msg, err := a.messages.Create(summarizeCtx, oldSession.ID, message.CreateMessageParams{
 863			Role: message.Assistant,
 864			Parts: []message.ContentPart{
 865				message.TextContent{Text: summary},
 866				message.Finish{
 867					Reason: message.FinishReasonEndTurn,
 868					Time:   time.Now().Unix(),
 869				},
 870			},
 871			Model:    a.summarizeProvider.Model().ID,
 872			Provider: a.summarizeProviderID,
 873		})
 874		if err != nil {
 875			event = AgentEvent{
 876				Type:  AgentEventTypeError,
 877				Error: fmt.Errorf("failed to create summary message: %w", err),
 878				Done:  true,
 879			}
 880
 881			a.Publish(pubsub.CreatedEvent, event)
 882			return
 883		}
 884		oldSession.SummaryMessageID = msg.ID
 885		oldSession.CompletionTokens = finalResponse.Usage.OutputTokens
 886		oldSession.PromptTokens = 0
 887		model := a.summarizeProvider.Model()
 888		usage := finalResponse.Usage
 889		cost := model.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
 890			model.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
 891			model.CostPer1MIn/1e6*float64(usage.InputTokens) +
 892			model.CostPer1MOut/1e6*float64(usage.OutputTokens)
 893		oldSession.Cost += cost
 894		_, err = a.sessions.Save(summarizeCtx, oldSession)
 895		if err != nil {
 896			event = AgentEvent{
 897				Type:  AgentEventTypeError,
 898				Error: fmt.Errorf("failed to save session: %w", err),
 899				Done:  true,
 900			}
 901			a.Publish(pubsub.CreatedEvent, event)
 902		}
 903
 904		event = AgentEvent{
 905			Type:      AgentEventTypeSummarize,
 906			SessionID: oldSession.ID,
 907			Progress:  "Summary complete",
 908			Done:      true,
 909		}
 910		a.Publish(pubsub.CreatedEvent, event)
 911		// Send final success event with the new session ID
 912	}()
 913
 914	return nil
 915}
 916
 917func (a *agent) ClearQueue(sessionID string) {
 918	if a.QueuedPrompts(sessionID) > 0 {
 919		slog.Info("Clearing queued prompts", "session_id", sessionID)
 920		a.promptQueue.Del(sessionID)
 921	}
 922}
 923
 924func (a *agent) CancelAll() {
 925	if !a.IsBusy() {
 926		return
 927	}
 928	for key := range a.activeRequests.Seq2() {
 929		a.Cancel(key) // key is sessionID
 930	}
 931
 932	for _, cleanup := range a.cleanupFuncs {
 933		if cleanup != nil {
 934			cleanup()
 935		}
 936	}
 937
 938	timeout := time.After(5 * time.Second)
 939	for a.IsBusy() {
 940		select {
 941		case <-timeout:
 942			return
 943		default:
 944			time.Sleep(200 * time.Millisecond)
 945		}
 946	}
 947}
 948
 949func (a *agent) UpdateModel() error {
 950	cfg := config.Get()
 951
 952	// Get current provider configuration
 953	currentProviderCfg := cfg.GetProviderForModel(a.agentCfg.Model)
 954	if currentProviderCfg == nil || currentProviderCfg.ID == "" {
 955		return fmt.Errorf("provider for agent %s not found in config", a.agentCfg.Name)
 956	}
 957
 958	// Check if provider has changed
 959	if string(currentProviderCfg.ID) != a.providerID {
 960		// Provider changed, need to recreate the main provider
 961		model := cfg.GetModelByType(a.agentCfg.Model)
 962		if model.ID == "" {
 963			return fmt.Errorf("model not found for agent %s", a.agentCfg.Name)
 964		}
 965
 966		promptID := agentPromptMap[a.agentCfg.ID]
 967		if promptID == "" {
 968			promptID = prompt.PromptDefault
 969		}
 970
 971		opts := []provider.ProviderClientOption{
 972			provider.WithModel(a.agentCfg.Model),
 973			provider.WithSystemMessage(prompt.GetPrompt(promptID, currentProviderCfg.ID, cfg.Options.ContextPaths...)),
 974		}
 975
 976		newProvider, err := provider.NewProvider(*currentProviderCfg, opts...)
 977		if err != nil {
 978			return fmt.Errorf("failed to create new provider: %w", err)
 979		}
 980
 981		// Update the provider and provider ID
 982		a.provider = newProvider
 983		a.providerID = string(currentProviderCfg.ID)
 984	}
 985
 986	// Check if providers have changed for title (small) and summarize (large)
 987	smallModelCfg := cfg.Models[config.SelectedModelTypeSmall]
 988	var smallModelProviderCfg config.ProviderConfig
 989	for p := range cfg.Providers.Seq() {
 990		if p.ID == smallModelCfg.Provider {
 991			smallModelProviderCfg = p
 992			break
 993		}
 994	}
 995	if smallModelProviderCfg.ID == "" {
 996		return fmt.Errorf("provider %s not found in config", smallModelCfg.Provider)
 997	}
 998
 999	largeModelCfg := cfg.Models[config.SelectedModelTypeLarge]
1000	var largeModelProviderCfg config.ProviderConfig
1001	for p := range cfg.Providers.Seq() {
1002		if p.ID == largeModelCfg.Provider {
1003			largeModelProviderCfg = p
1004			break
1005		}
1006	}
1007	if largeModelProviderCfg.ID == "" {
1008		return fmt.Errorf("provider %s not found in config", largeModelCfg.Provider)
1009	}
1010
1011	// Recreate title provider
1012	titleOpts := []provider.ProviderClientOption{
1013		provider.WithModel(config.SelectedModelTypeSmall),
1014		provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptTitle, smallModelProviderCfg.ID)),
1015		provider.WithMaxTokens(40),
1016	}
1017	newTitleProvider, err := provider.NewProvider(smallModelProviderCfg, titleOpts...)
1018	if err != nil {
1019		return fmt.Errorf("failed to create new title provider: %w", err)
1020	}
1021	a.titleProvider = newTitleProvider
1022
1023	// Recreate summarize provider if provider changed (now large model)
1024	if string(largeModelProviderCfg.ID) != a.summarizeProviderID {
1025		largeModel := cfg.GetModelByType(config.SelectedModelTypeLarge)
1026		if largeModel == nil {
1027			return fmt.Errorf("model %s not found in provider %s", largeModelCfg.Model, largeModelProviderCfg.ID)
1028		}
1029		summarizeOpts := []provider.ProviderClientOption{
1030			provider.WithModel(config.SelectedModelTypeLarge),
1031			provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptSummarizer, largeModelProviderCfg.ID)),
1032		}
1033		newSummarizeProvider, err := provider.NewProvider(largeModelProviderCfg, summarizeOpts...)
1034		if err != nil {
1035			return fmt.Errorf("failed to create new summarize provider: %w", err)
1036		}
1037		a.summarizeProvider = newSummarizeProvider
1038		a.summarizeProviderID = string(largeModelProviderCfg.ID)
1039	}
1040
1041	return nil
1042}
1043
1044func (a *agent) allTools() []tools.BaseTool {
1045	result := slices.Collect(a.baseTools.Seq())
1046	result = slices.AppendSeq(result, a.mcpTools.Seq())
1047	return result
1048}
1049
1050func (a *agent) getToolByName(name string) (tools.BaseTool, bool) {
1051	tool, ok := a.baseTools.Get(name)
1052	if !ok {
1053		tool, ok = a.mcpTools.Get(name)
1054	}
1055	return tool, ok
1056}
1057
1058func (a *agent) setupEvents(ctx context.Context) {
1059	ctx, cancel := context.WithCancel(ctx)
1060
1061	go func() {
1062		for event := range SubscribeMCPEvents(ctx) {
1063			switch event.Payload.Type {
1064			case MCPEventToolsListChanged:
1065				name := event.Payload.Name
1066				c, ok := mcpClients.Get(name)
1067				if !ok {
1068					slog.Warn("MCP client not found for tools update", "name", name)
1069					continue
1070				}
1071				tools := getTools(ctx, name, c)
1072				updateMcpTools(name, tools)
1073				// Update the lazy map with the new tools
1074				a.mcpTools = csync.NewMapFrom(maps.Collect(mcpTools.Seq2()))
1075				updateMCPState(name, MCPStateConnected, nil, c, a.mcpTools.Len())
1076			default:
1077				continue
1078			}
1079		}
1080	}()
1081
1082	a.cleanupFuncs = append(a.cleanupFuncs, func() {
1083		cancel()
1084	})
1085}