agent.go

   1package agent
   2
   3import (
   4	"context"
   5	"errors"
   6	"fmt"
   7	"log/slog"
   8	"slices"
   9	"strings"
  10	"time"
  11
  12	"github.com/charmbracelet/catwalk/pkg/catwalk"
  13	"github.com/charmbracelet/crush/internal/config"
  14	"github.com/charmbracelet/crush/internal/csync"
  15	"github.com/charmbracelet/crush/internal/history"
  16	"github.com/charmbracelet/crush/internal/llm/prompt"
  17	"github.com/charmbracelet/crush/internal/llm/provider"
  18	"github.com/charmbracelet/crush/internal/llm/tools"
  19	"github.com/charmbracelet/crush/internal/log"
  20	"github.com/charmbracelet/crush/internal/lsp"
  21	"github.com/charmbracelet/crush/internal/message"
  22	"github.com/charmbracelet/crush/internal/permission"
  23	"github.com/charmbracelet/crush/internal/pubsub"
  24	"github.com/charmbracelet/crush/internal/session"
  25	"github.com/charmbracelet/crush/internal/shell"
  26)
  27
  28// Common errors
  29var (
  30	ErrRequestCancelled = errors.New("request canceled by user")
  31	ErrSessionBusy      = errors.New("session is currently processing another request")
  32)
  33
  34type AgentEventType string
  35
  36const (
  37	AgentEventTypeError     AgentEventType = "error"
  38	AgentEventTypeResponse  AgentEventType = "response"
  39	AgentEventTypeSummarize AgentEventType = "summarize"
  40)
  41
// AgentEvent is the value the agent publishes on its broker and delivers on
// the channel returned by Run. Which fields are meaningful depends on Type.
type AgentEvent struct {
	// Type discriminates the event kind.
	Type    AgentEventType
	// Message is the assistant message carried by a response event.
	Message message.Message
	// Error is set on error events.
	Error   error

	// When summarizing: SessionID is the summarized session (set on the final
	// summarize event) and Progress is a human-readable status line. Done marks
	// a terminal event — the response returned by a Run, or the last event
	// (success or error) of a summarization.
	SessionID string
	Progress  string
	Done      bool
}
  52
// Service is the public interface of an agent. Besides the methods below it
// embeds a pubsub subscriber for the AgentEvent values published while
// requests and summarizations run.
type Service interface {
	pubsub.Suscriber[AgentEvent]
	// Model returns the model currently configured for the agent.
	Model() catwalk.Model
	// Run processes content (and optional attachments) for sessionID in the
	// background and returns a channel that delivers the final event. When the
	// session is already busy the content is queued instead and (nil, nil) is
	// returned.
	Run(ctx context.Context, sessionID string, content string, attachments ...message.Attachment) (<-chan AgentEvent, error)
	// Cancel aborts the session's active request and summarization and clears
	// its prompt queue.
	Cancel(sessionID string)
	// CancelAll cancels every active request and waits briefly for them to stop.
	CancelAll()
	// IsSessionBusy reports whether sessionID has an active request.
	IsSessionBusy(sessionID string) bool
	// IsBusy reports whether any request is active.
	IsBusy() bool
	// Summarize starts an asynchronous summarization of sessionID.
	Summarize(ctx context.Context, sessionID string) error
	// UpdateModel re-reads the configuration and recreates the agent's
	// provider when the configured provider has changed.
	UpdateModel() error
	// QueuedPrompts returns the number of prompts waiting for sessionID.
	QueuedPrompts(sessionID string) int
	// ClearQueue discards the prompts waiting for sessionID.
	ClearQueue(sessionID string)
}
  66
// agent implements Service. It holds the providers used for the main
// conversation, for title generation and for summarization, the agent's tool
// sets, and per-session bookkeeping of active requests and queued prompts.
type agent struct {
	// globalCtx is the context passed to NewAgent.
	globalCtx    context.Context
	// cleanupFuncs are invoked by CancelAll after active requests are cancelled.
	cleanupFuncs []func()

	*pubsub.Broker[AgentEvent]
	agentCfg config.Agent
	sessions session.Service
	messages message.Service

	// Tool sets keyed by tool name, built via csync.NewLazyMap in NewAgent
	// (presumably populated on first use — see csync).
	baseTools *csync.Map[string, tools.BaseTool]
	mcpTools  *csync.Map[string, tools.BaseTool]

	// provider serves the main conversation; providerID is the ID of the
	// provider configuration it was built from.
	provider   provider.Provider
	providerID string

	// titleProvider runs on the small model; summarizeProvider on the large one.
	titleProvider       provider.Provider
	summarizeProvider   provider.Provider
	summarizeProviderID string

	// activeRequests maps a session ID (or "<sessionID>-summarize" for a
	// summarization) to the cancel function of the work running for it.
	activeRequests *csync.Map[string, context.CancelFunc]

	// promptQueue holds, per session ID, prompts submitted while the session
	// was busy.
	promptQueue *csync.Map[string, []string]
}
  90
  91var agentPromptMap = map[string]prompt.PromptID{
  92	"coder": prompt.PromptCoder,
  93	"task":  prompt.PromptTask,
  94}
  95
  96func NewAgent(
  97	ctx context.Context,
  98	agentCfg config.Agent,
  99	// These services are needed in the tools
 100	permissions permission.Service,
 101	sessions session.Service,
 102	messages message.Service,
 103	history history.Service,
 104	lspClients map[string]*lsp.Client,
 105) (Service, error) {
 106	cfg := config.Get()
 107
 108	var agentTool tools.BaseTool
 109	if agentCfg.ID == "coder" {
 110		taskAgentCfg := config.Get().Agents["task"]
 111		if taskAgentCfg.ID == "" {
 112			return nil, fmt.Errorf("task agent not found in config")
 113		}
 114		taskAgent, err := NewAgent(ctx, taskAgentCfg, permissions, sessions, messages, history, lspClients)
 115		if err != nil {
 116			return nil, fmt.Errorf("failed to create task agent: %w", err)
 117		}
 118
 119		agentTool = NewAgentTool(taskAgent, sessions, messages)
 120	}
 121
 122	providerCfg := config.Get().GetProviderForModel(agentCfg.Model)
 123	if providerCfg == nil {
 124		return nil, fmt.Errorf("provider for agent %s not found in config", agentCfg.Name)
 125	}
 126	model := config.Get().GetModelByType(agentCfg.Model)
 127
 128	if model == nil {
 129		return nil, fmt.Errorf("model not found for agent %s", agentCfg.Name)
 130	}
 131
 132	promptID := agentPromptMap[agentCfg.ID]
 133	if promptID == "" {
 134		promptID = prompt.PromptDefault
 135	}
 136	opts := []provider.ProviderClientOption{
 137		provider.WithModel(agentCfg.Model),
 138		provider.WithSystemMessage(prompt.GetPrompt(promptID, providerCfg.ID, config.Get().Options.ContextPaths...)),
 139	}
 140	agentProvider, err := provider.NewProvider(*providerCfg, opts...)
 141	if err != nil {
 142		return nil, err
 143	}
 144
 145	smallModelCfg := cfg.Models[config.SelectedModelTypeSmall]
 146	var smallModelProviderCfg *config.ProviderConfig
 147	if smallModelCfg.Provider == providerCfg.ID {
 148		smallModelProviderCfg = providerCfg
 149	} else {
 150		smallModelProviderCfg = cfg.GetProviderForModel(config.SelectedModelTypeSmall)
 151
 152		if smallModelProviderCfg.ID == "" {
 153			return nil, fmt.Errorf("provider %s not found in config", smallModelCfg.Provider)
 154		}
 155	}
 156	smallModel := cfg.GetModelByType(config.SelectedModelTypeSmall)
 157	if smallModel.ID == "" {
 158		return nil, fmt.Errorf("model %s not found in provider %s", smallModelCfg.Model, smallModelProviderCfg.ID)
 159	}
 160
 161	titleOpts := []provider.ProviderClientOption{
 162		provider.WithModel(config.SelectedModelTypeSmall),
 163		provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptTitle, smallModelProviderCfg.ID)),
 164	}
 165	titleProvider, err := provider.NewProvider(*smallModelProviderCfg, titleOpts...)
 166	if err != nil {
 167		return nil, err
 168	}
 169
 170	summarizeOpts := []provider.ProviderClientOption{
 171		provider.WithModel(config.SelectedModelTypeLarge),
 172		provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptSummarizer, providerCfg.ID)),
 173	}
 174	summarizeProvider, err := provider.NewProvider(*providerCfg, summarizeOpts...)
 175	if err != nil {
 176		return nil, err
 177	}
 178
 179	baseToolsFn := func() map[string]tools.BaseTool {
 180		slog.Info("Initializing agent base tools", "agent", agentCfg.ID)
 181		defer func() {
 182			slog.Info("Initialized agent base tools", "agent", agentCfg.ID)
 183		}()
 184
 185		// Base tools available to all agents
 186		cwd := cfg.WorkingDir()
 187		result := make(map[string]tools.BaseTool)
 188		for _, tool := range []tools.BaseTool{
 189			tools.NewBashTool(permissions, cwd),
 190			tools.NewDownloadTool(permissions, cwd),
 191			tools.NewEditTool(lspClients, permissions, history, cwd),
 192			tools.NewMultiEditTool(lspClients, permissions, history, cwd),
 193			tools.NewFetchTool(permissions, cwd),
 194			tools.NewGlobTool(cwd),
 195			tools.NewGrepTool(cwd),
 196			tools.NewLsTool(permissions, cwd),
 197			tools.NewSourcegraphTool(),
 198			tools.NewViewTool(lspClients, permissions, cwd),
 199			tools.NewWriteTool(lspClients, permissions, history, cwd),
 200		} {
 201			result[tool.Name()] = tool
 202		}
 203
 204		if len(lspClients) > 0 {
 205			diagnosticsTool := tools.NewDiagnosticsTool(lspClients)
 206			result[diagnosticsTool.Name()] = diagnosticsTool
 207		}
 208
 209		if agentTool != nil {
 210			result[agentTool.Name()] = agentTool
 211		}
 212
 213		return result
 214	}
 215	mcpToolsFn := func() map[string]tools.BaseTool {
 216		slog.Info("Initializing agent mcp tools", "agent", agentCfg.ID)
 217		defer func() {
 218			slog.Info("Initialized agent mcp tools", "agent", agentCfg.ID)
 219		}()
 220
 221		mcpToolsOnce.Do(func() {
 222			doGetMCPTools(ctx, permissions, cfg)
 223		})
 224
 225		result := make(map[string]tools.BaseTool)
 226		for _, mcpTool := range mcpTools.Seq2() {
 227			result[mcpTool.Name()] = mcpTool
 228		}
 229		return result
 230	}
 231
 232	a := &agent{
 233		globalCtx:           ctx,
 234		Broker:              pubsub.NewBroker[AgentEvent](),
 235		agentCfg:            agentCfg,
 236		provider:            agentProvider,
 237		providerID:          string(providerCfg.ID),
 238		messages:            messages,
 239		sessions:            sessions,
 240		titleProvider:       titleProvider,
 241		summarizeProvider:   summarizeProvider,
 242		summarizeProviderID: string(providerCfg.ID),
 243		activeRequests:      csync.NewMap[string, context.CancelFunc](),
 244		mcpTools:            csync.NewLazyMap(mcpToolsFn),
 245		baseTools:           csync.NewLazyMap(baseToolsFn),
 246		promptQueue:         csync.NewMap[string, []string](),
 247	}
 248	a.setupEvents()
 249	return a, nil
 250}
 251
 252func (a *agent) Model() catwalk.Model {
 253	return *config.Get().GetModelByType(a.agentCfg.Model)
 254}
 255
 256func (a *agent) Cancel(sessionID string) {
 257	// Cancel regular requests
 258	if cancel, ok := a.activeRequests.Take(sessionID); ok && cancel != nil {
 259		slog.Info("Request cancellation initiated", "session_id", sessionID)
 260		cancel()
 261	}
 262
 263	// Also check for summarize requests
 264	if cancel, ok := a.activeRequests.Take(sessionID + "-summarize"); ok && cancel != nil {
 265		slog.Info("Summarize cancellation initiated", "session_id", sessionID)
 266		cancel()
 267	}
 268
 269	if a.QueuedPrompts(sessionID) > 0 {
 270		slog.Info("Clearing queued prompts", "session_id", sessionID)
 271		a.promptQueue.Del(sessionID)
 272	}
 273}
 274
 275func (a *agent) IsBusy() bool {
 276	var busy bool
 277	for cancelFunc := range a.activeRequests.Seq() {
 278		if cancelFunc != nil {
 279			busy = true
 280			break
 281		}
 282	}
 283	return busy
 284}
 285
 286func (a *agent) IsSessionBusy(sessionID string) bool {
 287	_, busy := a.activeRequests.Get(sessionID)
 288	return busy
 289}
 290
 291func (a *agent) QueuedPrompts(sessionID string) int {
 292	l, ok := a.promptQueue.Get(sessionID)
 293	if !ok {
 294		return 0
 295	}
 296	return len(l)
 297}
 298
 299func (a *agent) generateTitle(ctx context.Context, sessionID string, content string) error {
 300	if content == "" {
 301		return nil
 302	}
 303	if a.titleProvider == nil {
 304		return nil
 305	}
 306	session, err := a.sessions.Get(ctx, sessionID)
 307	if err != nil {
 308		return err
 309	}
 310	parts := []message.ContentPart{message.TextContent{
 311		Text: fmt.Sprintf("Generate a concise title for the following content:\n\n%s", content),
 312	}}
 313
 314	// Use streaming approach like summarization
 315	response := a.titleProvider.StreamResponse(
 316		ctx,
 317		[]message.Message{
 318			{
 319				Role:  message.User,
 320				Parts: parts,
 321			},
 322		},
 323		nil,
 324	)
 325
 326	var finalResponse *provider.ProviderResponse
 327	for r := range response {
 328		if r.Error != nil {
 329			return r.Error
 330		}
 331		finalResponse = r.Response
 332	}
 333
 334	if finalResponse == nil {
 335		return fmt.Errorf("no response received from title provider")
 336	}
 337
 338	title := strings.TrimSpace(strings.ReplaceAll(finalResponse.Content, "\n", " "))
 339	if title == "" {
 340		return nil
 341	}
 342
 343	session.Title = title
 344	_, err = a.sessions.Save(ctx, session)
 345	return err
 346}
 347
 348func (a *agent) err(err error) AgentEvent {
 349	return AgentEvent{
 350		Type:  AgentEventTypeError,
 351		Error: err,
 352	}
 353}
 354
 355func (a *agent) Run(ctx context.Context, sessionID string, content string, attachments ...message.Attachment) (<-chan AgentEvent, error) {
 356	if !a.Model().SupportsImages && attachments != nil {
 357		attachments = nil
 358	}
 359	events := make(chan AgentEvent)
 360	if a.IsSessionBusy(sessionID) {
 361		existing, ok := a.promptQueue.Get(sessionID)
 362		if !ok {
 363			existing = []string{}
 364		}
 365		existing = append(existing, content)
 366		a.promptQueue.Set(sessionID, existing)
 367		return nil, nil
 368	}
 369
 370	genCtx, cancel := context.WithCancel(ctx)
 371
 372	a.activeRequests.Set(sessionID, cancel)
 373	go func() {
 374		slog.Debug("Request started", "sessionID", sessionID)
 375		defer log.RecoverPanic("agent.Run", func() {
 376			events <- a.err(fmt.Errorf("panic while running the agent"))
 377		})
 378		var attachmentParts []message.ContentPart
 379		for _, attachment := range attachments {
 380			attachmentParts = append(attachmentParts, message.BinaryContent{Path: attachment.FilePath, MIMEType: attachment.MimeType, Data: attachment.Content})
 381		}
 382		result := a.processGeneration(genCtx, sessionID, content, attachmentParts)
 383		if result.Error != nil && !errors.Is(result.Error, ErrRequestCancelled) && !errors.Is(result.Error, context.Canceled) {
 384			slog.Error(result.Error.Error())
 385		}
 386		slog.Debug("Request completed", "sessionID", sessionID)
 387		a.activeRequests.Del(sessionID)
 388		cancel()
 389		a.Publish(pubsub.CreatedEvent, result)
 390		select {
 391		case events <- result:
 392		case <-genCtx.Done():
 393		}
 394		close(events)
 395	}()
 396	return events, nil
 397}
 398
 399func (a *agent) processGeneration(ctx context.Context, sessionID, content string, attachmentParts []message.ContentPart) AgentEvent {
 400	cfg := config.Get()
 401	// List existing messages; if none, start title generation asynchronously.
 402	msgs, err := a.messages.List(ctx, sessionID)
 403	if err != nil {
 404		return a.err(fmt.Errorf("failed to list messages: %w", err))
 405	}
 406	if len(msgs) == 0 {
 407		go func() {
 408			defer log.RecoverPanic("agent.Run", func() {
 409				slog.Error("panic while generating title")
 410			})
 411			titleErr := a.generateTitle(context.Background(), sessionID, content)
 412			if titleErr != nil && !errors.Is(titleErr, context.Canceled) && !errors.Is(titleErr, context.DeadlineExceeded) {
 413				slog.Error("failed to generate title", "error", titleErr)
 414			}
 415		}()
 416	}
 417	session, err := a.sessions.Get(ctx, sessionID)
 418	if err != nil {
 419		return a.err(fmt.Errorf("failed to get session: %w", err))
 420	}
 421	if session.SummaryMessageID != "" {
 422		summaryMsgInex := -1
 423		for i, msg := range msgs {
 424			if msg.ID == session.SummaryMessageID {
 425				summaryMsgInex = i
 426				break
 427			}
 428		}
 429		if summaryMsgInex != -1 {
 430			msgs = msgs[summaryMsgInex:]
 431			msgs[0].Role = message.User
 432		}
 433	}
 434
 435	userMsg, err := a.createUserMessage(ctx, sessionID, content, attachmentParts)
 436	if err != nil {
 437		return a.err(fmt.Errorf("failed to create user message: %w", err))
 438	}
 439	// Append the new user message to the conversation history.
 440	msgHistory := append(msgs, userMsg)
 441
 442	for {
 443		// Check for cancellation before each iteration
 444		select {
 445		case <-ctx.Done():
 446			return a.err(ctx.Err())
 447		default:
 448			// Continue processing
 449		}
 450		agentMessage, toolResults, err := a.streamAndHandleEvents(ctx, sessionID, msgHistory)
 451		if err != nil {
 452			if errors.Is(err, context.Canceled) {
 453				agentMessage.AddFinish(message.FinishReasonCanceled, "Request cancelled", "")
 454				a.messages.Update(context.Background(), agentMessage)
 455				return a.err(ErrRequestCancelled)
 456			}
 457			return a.err(fmt.Errorf("failed to process events: %w", err))
 458		}
 459		if cfg.Options.Debug {
 460			slog.Info("Result", "message", agentMessage.FinishReason(), "toolResults", toolResults)
 461		}
 462		if (agentMessage.FinishReason() == message.FinishReasonToolUse) && toolResults != nil {
 463			// We are not done, we need to respond with the tool response
 464			msgHistory = append(msgHistory, agentMessage, *toolResults)
 465			// If there are queued prompts, process the next one
 466			nextPrompt, ok := a.promptQueue.Take(sessionID)
 467			if ok {
 468				for _, prompt := range nextPrompt {
 469					// Create a new user message for the queued prompt
 470					userMsg, err := a.createUserMessage(ctx, sessionID, prompt, nil)
 471					if err != nil {
 472						return a.err(fmt.Errorf("failed to create user message for queued prompt: %w", err))
 473					}
 474					// Append the new user message to the conversation history
 475					msgHistory = append(msgHistory, userMsg)
 476				}
 477			}
 478
 479			continue
 480		} else if agentMessage.FinishReason() == message.FinishReasonEndTurn {
 481			queuePrompts, ok := a.promptQueue.Take(sessionID)
 482			if ok {
 483				for _, prompt := range queuePrompts {
 484					if prompt == "" {
 485						continue
 486					}
 487					userMsg, err := a.createUserMessage(ctx, sessionID, prompt, nil)
 488					if err != nil {
 489						return a.err(fmt.Errorf("failed to create user message for queued prompt: %w", err))
 490					}
 491					msgHistory = append(msgHistory, userMsg)
 492				}
 493				continue
 494			}
 495		}
 496		if agentMessage.FinishReason() == "" {
 497			// Kujtim: could not track down where this is happening but this means its cancelled
 498			agentMessage.AddFinish(message.FinishReasonCanceled, "Request cancelled", "")
 499			_ = a.messages.Update(context.Background(), agentMessage)
 500			return a.err(ErrRequestCancelled)
 501		}
 502		return AgentEvent{
 503			Type:    AgentEventTypeResponse,
 504			Message: agentMessage,
 505			Done:    true,
 506		}
 507	}
 508}
 509
 510func (a *agent) createUserMessage(ctx context.Context, sessionID, content string, attachmentParts []message.ContentPart) (message.Message, error) {
 511	parts := []message.ContentPart{message.TextContent{Text: content}}
 512	parts = append(parts, attachmentParts...)
 513	return a.messages.Create(ctx, sessionID, message.CreateMessageParams{
 514		Role:  message.User,
 515		Parts: parts,
 516	})
 517}
 518
 519func (a *agent) streamAndHandleEvents(ctx context.Context, sessionID string, msgHistory []message.Message) (message.Message, *message.Message, error) {
 520	ctx = context.WithValue(ctx, tools.SessionIDContextKey, sessionID)
 521
 522	// Create the assistant message first so the spinner shows immediately
 523	assistantMsg, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{
 524		Role:     message.Assistant,
 525		Parts:    []message.ContentPart{},
 526		Model:    a.Model().ID,
 527		Provider: a.providerID,
 528	})
 529	if err != nil {
 530		return assistantMsg, nil, fmt.Errorf("failed to create assistant message: %w", err)
 531	}
 532
 533	// Now collect tools (which may block on MCP initialization)
 534	eventChan := a.provider.StreamResponse(ctx, msgHistory, a.allTools())
 535
 536	// Add the session and message ID into the context if needed by tools.
 537	ctx = context.WithValue(ctx, tools.MessageIDContextKey, assistantMsg.ID)
 538
 539	// Process each event in the stream.
 540	for event := range eventChan {
 541		if processErr := a.processEvent(ctx, sessionID, &assistantMsg, event); processErr != nil {
 542			if errors.Is(processErr, context.Canceled) {
 543				a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
 544			} else {
 545				a.finishMessage(ctx, &assistantMsg, message.FinishReasonError, "API Error", processErr.Error())
 546			}
 547			return assistantMsg, nil, processErr
 548		}
 549		if ctx.Err() != nil {
 550			a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
 551			return assistantMsg, nil, ctx.Err()
 552		}
 553	}
 554
 555	toolResults := make([]message.ToolResult, len(assistantMsg.ToolCalls()))
 556	toolCalls := assistantMsg.ToolCalls()
 557	for i, toolCall := range toolCalls {
 558		select {
 559		case <-ctx.Done():
 560			a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
 561			// Make all future tool calls cancelled
 562			for j := i; j < len(toolCalls); j++ {
 563				toolResults[j] = message.ToolResult{
 564					ToolCallID: toolCalls[j].ID,
 565					Content:    "Tool execution canceled by user",
 566					IsError:    true,
 567				}
 568			}
 569			goto out
 570		default:
 571			// Continue processing
 572			tool, ok := a.getToolByName(toolCall.Name)
 573
 574			// Tool not found
 575			if !ok {
 576				toolResults[i] = message.ToolResult{
 577					ToolCallID: toolCall.ID,
 578					Content:    fmt.Sprintf("Tool not found: %s", toolCall.Name),
 579					IsError:    true,
 580				}
 581				continue
 582			}
 583
 584			// Run tool in goroutine to allow cancellation
 585			type toolExecResult struct {
 586				response tools.ToolResponse
 587				err      error
 588			}
 589			resultChan := make(chan toolExecResult, 1)
 590
 591			go func() {
 592				response, err := tool.Run(ctx, tools.ToolCall{
 593					ID:    toolCall.ID,
 594					Name:  toolCall.Name,
 595					Input: toolCall.Input,
 596				})
 597				resultChan <- toolExecResult{response: response, err: err}
 598			}()
 599
 600			var toolResponse tools.ToolResponse
 601			var toolErr error
 602
 603			select {
 604			case <-ctx.Done():
 605				a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
 606				// Mark remaining tool calls as cancelled
 607				for j := i; j < len(toolCalls); j++ {
 608					toolResults[j] = message.ToolResult{
 609						ToolCallID: toolCalls[j].ID,
 610						Content:    "Tool execution canceled by user",
 611						IsError:    true,
 612					}
 613				}
 614				goto out
 615			case result := <-resultChan:
 616				toolResponse = result.response
 617				toolErr = result.err
 618			}
 619
 620			if toolErr != nil {
 621				slog.Error("Tool execution error", "toolCall", toolCall.ID, "error", toolErr)
 622				if errors.Is(toolErr, permission.ErrorPermissionDenied) {
 623					toolResults[i] = message.ToolResult{
 624						ToolCallID: toolCall.ID,
 625						Content:    "Permission denied",
 626						IsError:    true,
 627					}
 628					for j := i + 1; j < len(toolCalls); j++ {
 629						toolResults[j] = message.ToolResult{
 630							ToolCallID: toolCalls[j].ID,
 631							Content:    "Tool execution canceled by user",
 632							IsError:    true,
 633						}
 634					}
 635					a.finishMessage(ctx, &assistantMsg, message.FinishReasonPermissionDenied, "Permission denied", "")
 636					break
 637				}
 638			}
 639			toolResults[i] = message.ToolResult{
 640				ToolCallID: toolCall.ID,
 641				Content:    toolResponse.Content,
 642				Metadata:   toolResponse.Metadata,
 643				IsError:    toolResponse.IsError,
 644			}
 645		}
 646	}
 647out:
 648	if len(toolResults) == 0 {
 649		return assistantMsg, nil, nil
 650	}
 651	parts := make([]message.ContentPart, 0)
 652	for _, tr := range toolResults {
 653		parts = append(parts, tr)
 654	}
 655	msg, err := a.messages.Create(context.Background(), assistantMsg.SessionID, message.CreateMessageParams{
 656		Role:     message.Tool,
 657		Parts:    parts,
 658		Provider: a.providerID,
 659	})
 660	if err != nil {
 661		return assistantMsg, nil, fmt.Errorf("failed to create cancelled tool message: %w", err)
 662	}
 663
 664	return assistantMsg, &msg, err
 665}
 666
 667func (a *agent) finishMessage(ctx context.Context, msg *message.Message, finishReason message.FinishReason, message, details string) {
 668	msg.AddFinish(finishReason, message, details)
 669	_ = a.messages.Update(ctx, *msg)
 670}
 671
 672func (a *agent) processEvent(ctx context.Context, sessionID string, assistantMsg *message.Message, event provider.ProviderEvent) error {
 673	select {
 674	case <-ctx.Done():
 675		return ctx.Err()
 676	default:
 677		// Continue processing.
 678	}
 679
 680	switch event.Type {
 681	case provider.EventThinkingDelta:
 682		assistantMsg.AppendReasoningContent(event.Thinking)
 683		return a.messages.Update(ctx, *assistantMsg)
 684	case provider.EventSignatureDelta:
 685		assistantMsg.AppendReasoningSignature(event.Signature)
 686		return a.messages.Update(ctx, *assistantMsg)
 687	case provider.EventContentDelta:
 688		assistantMsg.FinishThinking()
 689		assistantMsg.AppendContent(event.Content)
 690		return a.messages.Update(ctx, *assistantMsg)
 691	case provider.EventToolUseStart:
 692		assistantMsg.FinishThinking()
 693		slog.Info("Tool call started", "toolCall", event.ToolCall)
 694		assistantMsg.AddToolCall(*event.ToolCall)
 695		return a.messages.Update(ctx, *assistantMsg)
 696	case provider.EventToolUseDelta:
 697		assistantMsg.AppendToolCallInput(event.ToolCall.ID, event.ToolCall.Input)
 698		return a.messages.Update(ctx, *assistantMsg)
 699	case provider.EventToolUseStop:
 700		slog.Info("Finished tool call", "toolCall", event.ToolCall)
 701		assistantMsg.FinishToolCall(event.ToolCall.ID)
 702		return a.messages.Update(ctx, *assistantMsg)
 703	case provider.EventError:
 704		return event.Error
 705	case provider.EventComplete:
 706		assistantMsg.FinishThinking()
 707		assistantMsg.SetToolCalls(event.Response.ToolCalls)
 708		assistantMsg.AddFinish(event.Response.FinishReason, "", "")
 709		if err := a.messages.Update(ctx, *assistantMsg); err != nil {
 710			return fmt.Errorf("failed to update message: %w", err)
 711		}
 712		return a.TrackUsage(ctx, sessionID, a.Model(), event.Response.Usage)
 713	}
 714
 715	return nil
 716}
 717
 718func (a *agent) TrackUsage(ctx context.Context, sessionID string, model catwalk.Model, usage provider.TokenUsage) error {
 719	sess, err := a.sessions.Get(ctx, sessionID)
 720	if err != nil {
 721		return fmt.Errorf("failed to get session: %w", err)
 722	}
 723
 724	cost := model.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
 725		model.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
 726		model.CostPer1MIn/1e6*float64(usage.InputTokens) +
 727		model.CostPer1MOut/1e6*float64(usage.OutputTokens)
 728
 729	sess.Cost += cost
 730	sess.CompletionTokens = usage.OutputTokens + usage.CacheReadTokens
 731	sess.PromptTokens = usage.InputTokens + usage.CacheCreationTokens
 732
 733	_, err = a.sessions.Save(ctx, sess)
 734	if err != nil {
 735		return fmt.Errorf("failed to save session: %w", err)
 736	}
 737	return nil
 738}
 739
 740func (a *agent) Summarize(ctx context.Context, sessionID string) error {
 741	if a.summarizeProvider == nil {
 742		return fmt.Errorf("summarize provider not available")
 743	}
 744
 745	// Check if session is busy
 746	if a.IsSessionBusy(sessionID) {
 747		return ErrSessionBusy
 748	}
 749
 750	// Create a new context with cancellation
 751	summarizeCtx, cancel := context.WithCancel(ctx)
 752
 753	// Store the cancel function in activeRequests to allow cancellation
 754	a.activeRequests.Set(sessionID+"-summarize", cancel)
 755
 756	go func() {
 757		defer a.activeRequests.Del(sessionID + "-summarize")
 758		defer cancel()
 759		event := AgentEvent{
 760			Type:     AgentEventTypeSummarize,
 761			Progress: "Starting summarization...",
 762		}
 763
 764		a.Publish(pubsub.CreatedEvent, event)
 765		// Get all messages from the session
 766		msgs, err := a.messages.List(summarizeCtx, sessionID)
 767		if err != nil {
 768			event = AgentEvent{
 769				Type:  AgentEventTypeError,
 770				Error: fmt.Errorf("failed to list messages: %w", err),
 771				Done:  true,
 772			}
 773			a.Publish(pubsub.CreatedEvent, event)
 774			return
 775		}
 776		summarizeCtx = context.WithValue(summarizeCtx, tools.SessionIDContextKey, sessionID)
 777
 778		if len(msgs) == 0 {
 779			event = AgentEvent{
 780				Type:  AgentEventTypeError,
 781				Error: fmt.Errorf("no messages to summarize"),
 782				Done:  true,
 783			}
 784			a.Publish(pubsub.CreatedEvent, event)
 785			return
 786		}
 787
 788		event = AgentEvent{
 789			Type:     AgentEventTypeSummarize,
 790			Progress: "Analyzing conversation...",
 791		}
 792		a.Publish(pubsub.CreatedEvent, event)
 793
 794		// Add a system message to guide the summarization
 795		summarizePrompt := "Provide a detailed but concise summary of our conversation above. Focus on information that would be helpful for continuing the conversation, including what we did, what we're doing, which files we're working on, and what we're going to do next."
 796
 797		// Create a new message with the summarize prompt
 798		promptMsg := message.Message{
 799			Role:  message.User,
 800			Parts: []message.ContentPart{message.TextContent{Text: summarizePrompt}},
 801		}
 802
 803		// Append the prompt to the messages
 804		msgsWithPrompt := append(msgs, promptMsg)
 805
 806		event = AgentEvent{
 807			Type:     AgentEventTypeSummarize,
 808			Progress: "Generating summary...",
 809		}
 810
 811		a.Publish(pubsub.CreatedEvent, event)
 812
 813		// Send the messages to the summarize provider
 814		response := a.summarizeProvider.StreamResponse(
 815			summarizeCtx,
 816			msgsWithPrompt,
 817			nil,
 818		)
 819		var finalResponse *provider.ProviderResponse
 820		for r := range response {
 821			if r.Error != nil {
 822				event = AgentEvent{
 823					Type:  AgentEventTypeError,
 824					Error: fmt.Errorf("failed to summarize: %w", err),
 825					Done:  true,
 826				}
 827				a.Publish(pubsub.CreatedEvent, event)
 828				return
 829			}
 830			finalResponse = r.Response
 831		}
 832
 833		summary := strings.TrimSpace(finalResponse.Content)
 834		if summary == "" {
 835			event = AgentEvent{
 836				Type:  AgentEventTypeError,
 837				Error: fmt.Errorf("empty summary returned"),
 838				Done:  true,
 839			}
 840			a.Publish(pubsub.CreatedEvent, event)
 841			return
 842		}
 843		shell := shell.GetPersistentShell(config.Get().WorkingDir())
 844		summary += "\n\n**Current working directory of the persistent shell**\n\n" + shell.GetWorkingDir()
 845		event = AgentEvent{
 846			Type:     AgentEventTypeSummarize,
 847			Progress: "Creating new session...",
 848		}
 849
 850		a.Publish(pubsub.CreatedEvent, event)
 851		oldSession, err := a.sessions.Get(summarizeCtx, sessionID)
 852		if err != nil {
 853			event = AgentEvent{
 854				Type:  AgentEventTypeError,
 855				Error: fmt.Errorf("failed to get session: %w", err),
 856				Done:  true,
 857			}
 858
 859			a.Publish(pubsub.CreatedEvent, event)
 860			return
 861		}
 862		// Create a message in the new session with the summary
 863		msg, err := a.messages.Create(summarizeCtx, oldSession.ID, message.CreateMessageParams{
 864			Role: message.Assistant,
 865			Parts: []message.ContentPart{
 866				message.TextContent{Text: summary},
 867				message.Finish{
 868					Reason: message.FinishReasonEndTurn,
 869					Time:   time.Now().Unix(),
 870				},
 871			},
 872			Model:    a.summarizeProvider.Model().ID,
 873			Provider: a.summarizeProviderID,
 874		})
 875		if err != nil {
 876			event = AgentEvent{
 877				Type:  AgentEventTypeError,
 878				Error: fmt.Errorf("failed to create summary message: %w", err),
 879				Done:  true,
 880			}
 881
 882			a.Publish(pubsub.CreatedEvent, event)
 883			return
 884		}
 885		oldSession.SummaryMessageID = msg.ID
 886		oldSession.CompletionTokens = finalResponse.Usage.OutputTokens
 887		oldSession.PromptTokens = 0
 888		model := a.summarizeProvider.Model()
 889		usage := finalResponse.Usage
 890		cost := model.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
 891			model.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
 892			model.CostPer1MIn/1e6*float64(usage.InputTokens) +
 893			model.CostPer1MOut/1e6*float64(usage.OutputTokens)
 894		oldSession.Cost += cost
 895		_, err = a.sessions.Save(summarizeCtx, oldSession)
 896		if err != nil {
 897			event = AgentEvent{
 898				Type:  AgentEventTypeError,
 899				Error: fmt.Errorf("failed to save session: %w", err),
 900				Done:  true,
 901			}
 902			a.Publish(pubsub.CreatedEvent, event)
 903		}
 904
 905		event = AgentEvent{
 906			Type:      AgentEventTypeSummarize,
 907			SessionID: oldSession.ID,
 908			Progress:  "Summary complete",
 909			Done:      true,
 910		}
 911		a.Publish(pubsub.CreatedEvent, event)
 912		// Send final success event with the new session ID
 913	}()
 914
 915	return nil
 916}
 917
 918func (a *agent) ClearQueue(sessionID string) {
 919	if a.QueuedPrompts(sessionID) > 0 {
 920		slog.Info("Clearing queued prompts", "session_id", sessionID)
 921		a.promptQueue.Del(sessionID)
 922	}
 923}
 924
 925func (a *agent) CancelAll() {
 926	if !a.IsBusy() {
 927		return
 928	}
 929	for key := range a.activeRequests.Seq2() {
 930		a.Cancel(key) // key is sessionID
 931	}
 932
 933	for _, cleanup := range a.cleanupFuncs {
 934		if cleanup != nil {
 935			cleanup()
 936		}
 937	}
 938
 939	timeout := time.After(5 * time.Second)
 940	for a.IsBusy() {
 941		select {
 942		case <-timeout:
 943			return
 944		default:
 945			time.Sleep(200 * time.Millisecond)
 946		}
 947	}
 948}
 949
 950func (a *agent) UpdateModel() error {
 951	cfg := config.Get()
 952
 953	// Get current provider configuration
 954	currentProviderCfg := cfg.GetProviderForModel(a.agentCfg.Model)
 955	if currentProviderCfg == nil || currentProviderCfg.ID == "" {
 956		return fmt.Errorf("provider for agent %s not found in config", a.agentCfg.Name)
 957	}
 958
 959	// Check if provider has changed
 960	if string(currentProviderCfg.ID) != a.providerID {
 961		// Provider changed, need to recreate the main provider
 962		model := cfg.GetModelByType(a.agentCfg.Model)
 963		if model.ID == "" {
 964			return fmt.Errorf("model not found for agent %s", a.agentCfg.Name)
 965		}
 966
 967		promptID := agentPromptMap[a.agentCfg.ID]
 968		if promptID == "" {
 969			promptID = prompt.PromptDefault
 970		}
 971
 972		opts := []provider.ProviderClientOption{
 973			provider.WithModel(a.agentCfg.Model),
 974			provider.WithSystemMessage(prompt.GetPrompt(promptID, currentProviderCfg.ID, cfg.Options.ContextPaths...)),
 975		}
 976
 977		newProvider, err := provider.NewProvider(*currentProviderCfg, opts...)
 978		if err != nil {
 979			return fmt.Errorf("failed to create new provider: %w", err)
 980		}
 981
 982		// Update the provider and provider ID
 983		a.provider = newProvider
 984		a.providerID = string(currentProviderCfg.ID)
 985	}
 986
 987	// Check if providers have changed for title (small) and summarize (large)
 988	smallModelCfg := cfg.Models[config.SelectedModelTypeSmall]
 989	var smallModelProviderCfg config.ProviderConfig
 990	for p := range cfg.Providers.Seq() {
 991		if p.ID == smallModelCfg.Provider {
 992			smallModelProviderCfg = p
 993			break
 994		}
 995	}
 996	if smallModelProviderCfg.ID == "" {
 997		return fmt.Errorf("provider %s not found in config", smallModelCfg.Provider)
 998	}
 999
1000	largeModelCfg := cfg.Models[config.SelectedModelTypeLarge]
1001	var largeModelProviderCfg config.ProviderConfig
1002	for p := range cfg.Providers.Seq() {
1003		if p.ID == largeModelCfg.Provider {
1004			largeModelProviderCfg = p
1005			break
1006		}
1007	}
1008	if largeModelProviderCfg.ID == "" {
1009		return fmt.Errorf("provider %s not found in config", largeModelCfg.Provider)
1010	}
1011
1012	// Recreate title provider
1013	titleOpts := []provider.ProviderClientOption{
1014		provider.WithModel(config.SelectedModelTypeSmall),
1015		provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptTitle, smallModelProviderCfg.ID)),
1016		provider.WithMaxTokens(40),
1017	}
1018	newTitleProvider, err := provider.NewProvider(smallModelProviderCfg, titleOpts...)
1019	if err != nil {
1020		return fmt.Errorf("failed to create new title provider: %w", err)
1021	}
1022	a.titleProvider = newTitleProvider
1023
1024	// Recreate summarize provider if provider changed (now large model)
1025	if string(largeModelProviderCfg.ID) != a.summarizeProviderID {
1026		largeModel := cfg.GetModelByType(config.SelectedModelTypeLarge)
1027		if largeModel == nil {
1028			return fmt.Errorf("model %s not found in provider %s", largeModelCfg.Model, largeModelProviderCfg.ID)
1029		}
1030		summarizeOpts := []provider.ProviderClientOption{
1031			provider.WithModel(config.SelectedModelTypeLarge),
1032			provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptSummarizer, largeModelProviderCfg.ID)),
1033		}
1034		newSummarizeProvider, err := provider.NewProvider(largeModelProviderCfg, summarizeOpts...)
1035		if err != nil {
1036			return fmt.Errorf("failed to create new summarize provider: %w", err)
1037		}
1038		a.summarizeProvider = newSummarizeProvider
1039		a.summarizeProviderID = string(largeModelProviderCfg.ID)
1040	}
1041
1042	return nil
1043}
1044
1045func (a *agent) allTools() []tools.BaseTool {
1046	result := slices.Collect(a.baseTools.Seq())
1047	result = slices.AppendSeq(result, a.mcpTools.Seq())
1048	return result
1049}
1050
1051func (a *agent) getToolByName(name string) (tools.BaseTool, bool) {
1052	tool, ok := a.baseTools.Get(name)
1053	if !ok {
1054		tool, ok = a.mcpTools.Get(name)
1055	}
1056	return tool, ok
1057}
1058
1059func (a *agent) setupEvents() {
1060	ctx, cancel := context.WithCancel(a.globalCtx)
1061
1062	go func() {
1063		for event := range SubscribeMCPEvents(ctx) {
1064			switch event.Payload.Type {
1065			case MCPEventToolsListChanged:
1066				name := event.Payload.Name
1067				c, ok := mcpClients.Get(name)
1068				if !ok {
1069					slog.Warn("MCP client not found for tools update", "name", name)
1070					continue
1071				}
1072				tools := getTools(ctx, name, c)
1073				updateMcpTools(name, tools)
1074				// Update the lazy map with the new tools
1075				a.mcpTools = csync.NewMapFromSeq(mcpTools.Seq2())
1076				updateMCPState(name, MCPStateConnected, nil, c, a.mcpTools.Len())
1077			default:
1078				continue
1079			}
1080		}
1081	}()
1082
1083	a.cleanupFuncs = append(a.cleanupFuncs, func() {
1084		cancel()
1085	})
1086}