agent.go

   1package agent
   2
   3import (
   4	"context"
   5	"errors"
   6	"fmt"
   7	"log/slog"
   8	"slices"
   9	"strings"
  10	"time"
  11
  12	"github.com/charmbracelet/catwalk/pkg/catwalk"
  13	"github.com/charmbracelet/crush/internal/config"
  14	"github.com/charmbracelet/crush/internal/csync"
  15	"github.com/charmbracelet/crush/internal/event"
  16	"github.com/charmbracelet/crush/internal/history"
  17	"github.com/charmbracelet/crush/internal/llm/prompt"
  18	"github.com/charmbracelet/crush/internal/llm/provider"
  19	"github.com/charmbracelet/crush/internal/llm/tools"
  20	"github.com/charmbracelet/crush/internal/log"
  21	"github.com/charmbracelet/crush/internal/lsp"
  22	"github.com/charmbracelet/crush/internal/message"
  23	"github.com/charmbracelet/crush/internal/permission"
  24	"github.com/charmbracelet/crush/internal/pubsub"
  25	"github.com/charmbracelet/crush/internal/session"
  26	"github.com/charmbracelet/crush/internal/shell"
  27)
  28
// streamChunkTimeout is the longest streamAndHandleEvents waits for the next
// provider event before it gives up on the stream with ErrStreamTimeout. The
// timer is re-armed every time an event arrives.
const streamChunkTimeout = 2 * time.Minute
  30
// AgentEventType discriminates the kinds of AgentEvent the agent emits.
type AgentEventType string

const (
	// AgentEventTypeError marks an event whose Error field is set.
	AgentEventTypeError AgentEventType = "error"
	// AgentEventTypeResponse marks the final assistant message of a run.
	AgentEventTypeResponse AgentEventType = "response"
	// AgentEventTypeSummarize marks progress and completion events published
	// by Summarize.
	AgentEventTypeSummarize AgentEventType = "summarize"
)
  38
// AgentEvent is what the agent delivers on the channel returned by Run and
// publishes through its broker.
type AgentEvent struct {
	Type    AgentEventType  // which kind of event this is
	Message message.Message // assistant message, set on response events
	Error   error           // set on error events

	// When summarizing
	SessionID string // session that was summarized (set on the final summarize event)
	Progress  string // human-readable progress text
	Done      bool   // true on terminal events (final response, summarize result or summarize error)
}
  49
// Service is the public surface of an agent. The notes below describe the
// behavior of the implementation in this file.
type Service interface {
	pubsub.Suscriber[AgentEvent]
	// Model returns the model currently configured for this agent.
	Model() catwalk.Model
	// Run starts a generation for sessionID. When the session is already
	// busy the prompt is queued instead and both return values are nil.
	Run(ctx context.Context, sessionID string, content string, attachments ...message.Attachment) (<-chan AgentEvent, error)
	// Cancel stops the running request and any running summarization for
	// sessionID and drops its queued prompts.
	Cancel(sessionID string)
	// CancelAll cancels every active request and waits up to five seconds
	// for the agent to become idle.
	CancelAll()
	// IsSessionBusy reports whether a request is registered for sessionID.
	IsSessionBusy(sessionID string) bool
	// IsBusy reports whether any request is currently registered.
	IsBusy() bool
	// Summarize starts an asynchronous summarization of sessionID; progress
	// and the outcome are published as events.
	Summarize(ctx context.Context, sessionID string) error
	// UpdateModel re-reads the configuration and refreshes the provider.
	// NOTE(review): implementation is only partially visible here — confirm.
	UpdateModel() error
	// QueuedPrompts returns how many prompts are waiting for sessionID.
	QueuedPrompts(sessionID string) int
	// ClearQueue drops all queued prompts for sessionID.
	ClearQueue(sessionID string)
}
  63
// agent is the Service implementation. It embeds a broker so that every
// AgentEvent can also be observed by subscribers.
type agent struct {
	*pubsub.Broker[AgentEvent]
	agentCfg    config.Agent       // configuration this agent was created from
	sessions    session.Service    // session storage (titles, cost, token counts)
	messages    message.Service    // message storage
	permissions permission.Service // permission service handed to the tools
	mcpTools    []McpTool          // NOTE(review): not populated by NewAgent in this file — confirm usage

	// tools is built lazily on first use by the toolFn created in NewAgent.
	tools *csync.LazySlice[tools.BaseTool]
	// We need this to be able to update it when model changes
	agentToolFn func() (tools.BaseTool, error)

	provider   provider.Provider // main provider used for generations
	providerID string            // ID of the provider config backing provider

	titleProvider       provider.Provider // small-model provider used by generateTitle
	summarizeProvider   provider.Provider // provider used by Summarize
	summarizeProviderID string

	// activeRequests maps a session ID (or sessionID+"-summarize") to the
	// cancel function of the request currently running for it.
	activeRequests *csync.Map[string, context.CancelFunc]
	// promptQueue holds prompts submitted while their session was busy.
	promptQueue *csync.Map[string, []string]
}
  86
// agentPromptMap maps an agent ID to the system prompt it uses. NewAgent
// falls back to prompt.PromptDefault for IDs that are not listed here.
var agentPromptMap = map[string]prompt.PromptID{
	"coder": prompt.PromptCoder,
	"task":  prompt.PromptTask,
}
  91
  92func NewAgent(
  93	ctx context.Context,
  94	agentCfg config.Agent,
  95	// These services are needed in the tools
  96	permissions permission.Service,
  97	sessions session.Service,
  98	messages message.Service,
  99	history history.Service,
 100	lspClients *csync.Map[string, *lsp.Client],
 101) (Service, error) {
 102	cfg := config.Get()
 103
 104	var agentToolFn func() (tools.BaseTool, error)
 105	if agentCfg.ID == "coder" && slices.Contains(agentCfg.AllowedTools, AgentToolName) {
 106		agentToolFn = func() (tools.BaseTool, error) {
 107			taskAgentCfg := config.Get().Agents["task"]
 108			if taskAgentCfg.ID == "" {
 109				return nil, fmt.Errorf("task agent not found in config")
 110			}
 111			taskAgent, err := NewAgent(ctx, taskAgentCfg, permissions, sessions, messages, history, lspClients)
 112			if err != nil {
 113				return nil, fmt.Errorf("failed to create task agent: %w", err)
 114			}
 115			return NewAgentTool(taskAgent, sessions, messages), nil
 116		}
 117	}
 118
 119	providerCfg := config.Get().GetProviderForModel(agentCfg.Model)
 120	if providerCfg == nil {
 121		return nil, fmt.Errorf("provider for agent %s not found in config", agentCfg.Name)
 122	}
 123	model := config.Get().GetModelByType(agentCfg.Model)
 124
 125	if model == nil {
 126		return nil, fmt.Errorf("model not found for agent %s", agentCfg.Name)
 127	}
 128
 129	promptID := agentPromptMap[agentCfg.ID]
 130	if promptID == "" {
 131		promptID = prompt.PromptDefault
 132	}
 133	opts := []provider.ProviderClientOption{
 134		provider.WithModel(agentCfg.Model),
 135		provider.WithSystemMessage(prompt.GetPrompt(promptID, providerCfg.ID, config.Get().Options.ContextPaths...)),
 136	}
 137	agentProvider, err := provider.NewProvider(*providerCfg, opts...)
 138	if err != nil {
 139		return nil, err
 140	}
 141
 142	smallModelCfg := cfg.Models[config.SelectedModelTypeSmall]
 143	var smallModelProviderCfg *config.ProviderConfig
 144	if smallModelCfg.Provider == providerCfg.ID {
 145		smallModelProviderCfg = providerCfg
 146	} else {
 147		smallModelProviderCfg = cfg.GetProviderForModel(config.SelectedModelTypeSmall)
 148
 149		if smallModelProviderCfg.ID == "" {
 150			return nil, fmt.Errorf("provider %s not found in config", smallModelCfg.Provider)
 151		}
 152	}
 153	smallModel := cfg.GetModelByType(config.SelectedModelTypeSmall)
 154	if smallModel.ID == "" {
 155		return nil, fmt.Errorf("model %s not found in provider %s", smallModelCfg.Model, smallModelProviderCfg.ID)
 156	}
 157
 158	titleOpts := []provider.ProviderClientOption{
 159		provider.WithModel(config.SelectedModelTypeSmall),
 160		provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptTitle, smallModelProviderCfg.ID)),
 161	}
 162	titleProvider, err := provider.NewProvider(*smallModelProviderCfg, titleOpts...)
 163	if err != nil {
 164		return nil, err
 165	}
 166
 167	summarizeOpts := []provider.ProviderClientOption{
 168		provider.WithModel(config.SelectedModelTypeLarge),
 169		provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptSummarizer, providerCfg.ID)),
 170	}
 171	summarizeProvider, err := provider.NewProvider(*providerCfg, summarizeOpts...)
 172	if err != nil {
 173		return nil, err
 174	}
 175
 176	toolFn := func() []tools.BaseTool {
 177		slog.Info("Initializing agent tools", "agent", agentCfg.ID)
 178		defer func() {
 179			slog.Info("Initialized agent tools", "agent", agentCfg.ID)
 180		}()
 181
 182		cwd := cfg.WorkingDir()
 183		allTools := []tools.BaseTool{
 184			tools.NewBashTool(permissions, cwd, cfg.Options.Attribution),
 185			tools.NewDownloadTool(permissions, cwd),
 186			tools.NewEditTool(lspClients, permissions, history, cwd),
 187			tools.NewMultiEditTool(lspClients, permissions, history, cwd),
 188			tools.NewFetchTool(permissions, cwd),
 189			tools.NewGlobTool(cwd),
 190			tools.NewGrepTool(cwd),
 191			tools.NewLsTool(permissions, cwd),
 192			tools.NewSourcegraphTool(),
 193			tools.NewViewTool(lspClients, permissions, cwd),
 194			tools.NewWriteTool(lspClients, permissions, history, cwd),
 195		}
 196
 197		mcpToolsOnce.Do(func() {
 198			mcpTools = doGetMCPTools(ctx, permissions, cfg)
 199		})
 200
 201		withCoderTools := func(t []tools.BaseTool) []tools.BaseTool {
 202			if agentCfg.ID == "coder" {
 203				t = append(t, mcpTools...)
 204				if lspClients.Len() > 0 {
 205					t = append(t, tools.NewDiagnosticsTool(lspClients))
 206				}
 207			}
 208			return t
 209		}
 210
 211		if agentCfg.AllowedTools == nil {
 212			return withCoderTools(allTools)
 213		}
 214
 215		var filteredTools []tools.BaseTool
 216		for _, tool := range allTools {
 217			if slices.Contains(agentCfg.AllowedTools, tool.Name()) {
 218				filteredTools = append(filteredTools, tool)
 219			}
 220		}
 221		return withCoderTools(filteredTools)
 222	}
 223
 224	return &agent{
 225		Broker:              pubsub.NewBroker[AgentEvent](),
 226		agentCfg:            agentCfg,
 227		provider:            agentProvider,
 228		providerID:          string(providerCfg.ID),
 229		messages:            messages,
 230		sessions:            sessions,
 231		titleProvider:       titleProvider,
 232		summarizeProvider:   summarizeProvider,
 233		summarizeProviderID: string(providerCfg.ID),
 234		agentToolFn:         agentToolFn,
 235		activeRequests:      csync.NewMap[string, context.CancelFunc](),
 236		tools:               csync.NewLazySlice(toolFn),
 237		promptQueue:         csync.NewMap[string, []string](),
 238		permissions:         permissions,
 239	}, nil
 240}
 241
 242func (a *agent) Model() catwalk.Model {
 243	return *config.Get().GetModelByType(a.agentCfg.Model)
 244}
 245
 246func (a *agent) Cancel(sessionID string) {
 247	// Cancel regular requests
 248	if cancel, ok := a.activeRequests.Take(sessionID); ok && cancel != nil {
 249		slog.Info("Request cancellation initiated", "session_id", sessionID)
 250		cancel()
 251	}
 252
 253	// Also check for summarize requests
 254	if cancel, ok := a.activeRequests.Take(sessionID + "-summarize"); ok && cancel != nil {
 255		slog.Info("Summarize cancellation initiated", "session_id", sessionID)
 256		cancel()
 257	}
 258
 259	if a.QueuedPrompts(sessionID) > 0 {
 260		slog.Info("Clearing queued prompts", "session_id", sessionID)
 261		a.promptQueue.Del(sessionID)
 262	}
 263}
 264
 265func (a *agent) IsBusy() bool {
 266	var busy bool
 267	for cancelFunc := range a.activeRequests.Seq() {
 268		if cancelFunc != nil {
 269			busy = true
 270			break
 271		}
 272	}
 273	return busy
 274}
 275
 276func (a *agent) IsSessionBusy(sessionID string) bool {
 277	_, busy := a.activeRequests.Get(sessionID)
 278	return busy
 279}
 280
 281func (a *agent) QueuedPrompts(sessionID string) int {
 282	l, ok := a.promptQueue.Get(sessionID)
 283	if !ok {
 284		return 0
 285	}
 286	return len(l)
 287}
 288
 289func (a *agent) generateTitle(ctx context.Context, sessionID string, content string) error {
 290	if content == "" {
 291		return nil
 292	}
 293	if a.titleProvider == nil {
 294		return nil
 295	}
 296	session, err := a.sessions.Get(ctx, sessionID)
 297	if err != nil {
 298		return err
 299	}
 300	parts := []message.ContentPart{message.TextContent{
 301		Text: fmt.Sprintf("Generate a concise title for the following content:\n\n%s", content),
 302	}}
 303
 304	// Use streaming approach like summarization
 305	response := a.titleProvider.StreamResponse(
 306		ctx,
 307		[]message.Message{
 308			{
 309				Role:  message.User,
 310				Parts: parts,
 311			},
 312		},
 313		nil,
 314	)
 315
 316	var finalResponse *provider.ProviderResponse
 317	for r := range response {
 318		if r.Error != nil {
 319			return r.Error
 320		}
 321		finalResponse = r.Response
 322	}
 323
 324	if finalResponse == nil {
 325		return fmt.Errorf("no response received from title provider")
 326	}
 327
 328	title := strings.ReplaceAll(finalResponse.Content, "\n", " ")
 329
 330	if idx := strings.Index(title, "</think>"); idx > 0 {
 331		title = title[idx+len("</think>"):]
 332	}
 333
 334	title = strings.TrimSpace(title)
 335	if title == "" {
 336		return nil
 337	}
 338
 339	session.Title = title
 340	_, err = a.sessions.Save(ctx, session)
 341	return err
 342}
 343
 344func (a *agent) err(err error) AgentEvent {
 345	return AgentEvent{
 346		Type:  AgentEventTypeError,
 347		Error: err,
 348	}
 349}
 350
 351func (a *agent) Run(ctx context.Context, sessionID string, content string, attachments ...message.Attachment) (<-chan AgentEvent, error) {
 352	if !a.Model().SupportsImages && attachments != nil {
 353		attachments = nil
 354	}
 355	events := make(chan AgentEvent, 1)
 356	if a.IsSessionBusy(sessionID) {
 357		existing, ok := a.promptQueue.Get(sessionID)
 358		if !ok {
 359			existing = []string{}
 360		}
 361		existing = append(existing, content)
 362		a.promptQueue.Set(sessionID, existing)
 363		return nil, nil
 364	}
 365
 366	genCtx, cancel := context.WithCancel(ctx)
 367	a.activeRequests.Set(sessionID, cancel)
 368
 369	go func() {
 370		slog.Debug("Request started", "sessionID", sessionID)
 371		defer log.RecoverPanic("agent.Run", func() {
 372			events <- a.err(fmt.Errorf("panic while running the agent"))
 373		})
 374		var attachmentParts []message.ContentPart
 375		for _, attachment := range attachments {
 376			attachmentParts = append(attachmentParts, message.BinaryContent{Path: attachment.FilePath, MIMEType: attachment.MimeType, Data: attachment.Content})
 377		}
 378		result := a.processGeneration(genCtx, sessionID, content, attachmentParts)
 379		if result.Error != nil {
 380			if isCancelledErr(result.Error) {
 381				slog.Error("Request canceled", "sessionID", sessionID)
 382			} else {
 383				slog.Error("Request errored", "sessionID", sessionID, "error", result.Error.Error())
 384				event.Error(result.Error)
 385			}
 386		} else {
 387			slog.Debug("Request completed", "sessionID", sessionID)
 388		}
 389		a.activeRequests.Del(sessionID)
 390		cancel()
 391		a.Publish(pubsub.CreatedEvent, result)
 392		events <- result
 393		close(events)
 394	}()
 395	a.eventPromptSent(sessionID)
 396	return events, nil
 397}
 398
 399func (a *agent) processGeneration(ctx context.Context, sessionID, content string, attachmentParts []message.ContentPart) AgentEvent {
 400	cfg := config.Get()
 401	// List existing messages; if none, start title generation asynchronously.
 402	msgs, err := a.messages.List(ctx, sessionID)
 403	if err != nil {
 404		return a.err(fmt.Errorf("failed to list messages: %w", err))
 405	}
 406	if len(msgs) == 0 {
 407		go func() {
 408			defer log.RecoverPanic("agent.Run", func() {
 409				slog.Error("panic while generating title")
 410			})
 411			titleErr := a.generateTitle(ctx, sessionID, content)
 412			if titleErr != nil && !errors.Is(titleErr, context.Canceled) && !errors.Is(titleErr, context.DeadlineExceeded) {
 413				slog.Error("failed to generate title", "error", titleErr)
 414			}
 415		}()
 416	}
 417	session, err := a.sessions.Get(ctx, sessionID)
 418	if err != nil {
 419		return a.err(fmt.Errorf("failed to get session: %w", err))
 420	}
 421	if session.SummaryMessageID != "" {
 422		summaryMsgInex := -1
 423		for i, msg := range msgs {
 424			if msg.ID == session.SummaryMessageID {
 425				summaryMsgInex = i
 426				break
 427			}
 428		}
 429		if summaryMsgInex != -1 {
 430			msgs = msgs[summaryMsgInex:]
 431			msgs[0].Role = message.User
 432		}
 433	}
 434
 435	userMsg, err := a.createUserMessage(ctx, sessionID, content, attachmentParts)
 436	if err != nil {
 437		return a.err(fmt.Errorf("failed to create user message: %w", err))
 438	}
 439	// Append the new user message to the conversation history.
 440	msgHistory := append(msgs, userMsg)
 441
 442	for {
 443		// Check for cancellation before each iteration
 444		select {
 445		case <-ctx.Done():
 446			return a.err(ctx.Err())
 447		default:
 448			// Continue processing
 449		}
 450		agentMessage, toolResults, err := a.streamAndHandleEvents(ctx, sessionID, msgHistory)
 451		if err != nil {
 452			if errors.Is(err, context.Canceled) {
 453				agentMessage.AddFinish(message.FinishReasonCanceled, "Request cancelled", "")
 454				a.messages.Update(context.Background(), agentMessage)
 455				return a.err(ErrRequestCancelled)
 456			}
 457			return a.err(fmt.Errorf("failed to process events: %w", err))
 458		}
 459		if cfg.Options.Debug {
 460			slog.Info("Result", "message", agentMessage.FinishReason(), "toolResults", toolResults)
 461		}
 462		if (agentMessage.FinishReason() == message.FinishReasonToolUse) && toolResults != nil {
 463			// We are not done, we need to respond with the tool response
 464			msgHistory = append(msgHistory, agentMessage, *toolResults)
 465			// If there are queued prompts, process the next one
 466			nextPrompt, ok := a.promptQueue.Take(sessionID)
 467			if ok {
 468				for _, prompt := range nextPrompt {
 469					// Create a new user message for the queued prompt
 470					userMsg, err := a.createUserMessage(ctx, sessionID, prompt, nil)
 471					if err != nil {
 472						return a.err(fmt.Errorf("failed to create user message for queued prompt: %w", err))
 473					}
 474					// Append the new user message to the conversation history
 475					msgHistory = append(msgHistory, userMsg)
 476				}
 477			}
 478
 479			continue
 480		} else if agentMessage.FinishReason() == message.FinishReasonEndTurn {
 481			queuePrompts, ok := a.promptQueue.Take(sessionID)
 482			if ok {
 483				for _, prompt := range queuePrompts {
 484					if prompt == "" {
 485						continue
 486					}
 487					userMsg, err := a.createUserMessage(ctx, sessionID, prompt, nil)
 488					if err != nil {
 489						return a.err(fmt.Errorf("failed to create user message for queued prompt: %w", err))
 490					}
 491					msgHistory = append(msgHistory, userMsg)
 492				}
 493				continue
 494			}
 495		}
 496		if agentMessage.FinishReason() == "" {
 497			// Kujtim: could not track down where this is happening but this means its cancelled
 498			agentMessage.AddFinish(message.FinishReasonCanceled, "Request cancelled", "")
 499			_ = a.messages.Update(context.Background(), agentMessage)
 500			return a.err(ErrRequestCancelled)
 501		}
 502		return AgentEvent{
 503			Type:    AgentEventTypeResponse,
 504			Message: agentMessage,
 505			Done:    true,
 506		}
 507	}
 508}
 509
 510func (a *agent) createUserMessage(ctx context.Context, sessionID, content string, attachmentParts []message.ContentPart) (message.Message, error) {
 511	parts := []message.ContentPart{message.TextContent{Text: content}}
 512	parts = append(parts, attachmentParts...)
 513	return a.messages.Create(ctx, sessionID, message.CreateMessageParams{
 514		Role:  message.User,
 515		Parts: parts,
 516	})
 517}
 518
 519func (a *agent) getAllTools() ([]tools.BaseTool, error) {
 520	allTools := slices.Collect(a.tools.Seq())
 521	if a.agentToolFn != nil {
 522		agentTool, agentToolErr := a.agentToolFn()
 523		if agentToolErr != nil {
 524			return nil, agentToolErr
 525		}
 526		allTools = append(allTools, agentTool)
 527	}
 528	return allTools, nil
 529}
 530
 531func (a *agent) streamAndHandleEvents(ctx context.Context, sessionID string, msgHistory []message.Message) (message.Message, *message.Message, error) {
 532	ctx = context.WithValue(ctx, tools.SessionIDContextKey, sessionID)
 533
 534	// Create the assistant message first so the spinner shows immediately
 535	assistantMsg, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{
 536		Role:     message.Assistant,
 537		Parts:    []message.ContentPart{},
 538		Model:    a.Model().ID,
 539		Provider: a.providerID,
 540	})
 541	if err != nil {
 542		return assistantMsg, nil, fmt.Errorf("failed to create assistant message: %w", err)
 543	}
 544
 545	allTools, toolsErr := a.getAllTools()
 546	if toolsErr != nil {
 547		return assistantMsg, nil, toolsErr
 548	}
 549	// Now collect tools (which may block on MCP initialization)
 550	eventChan := a.provider.StreamResponse(ctx, msgHistory, allTools)
 551
 552	// Add the session and message ID into the context if needed by tools.
 553	ctx = context.WithValue(ctx, tools.MessageIDContextKey, assistantMsg.ID)
 554
 555	// Process each event in the stream.
 556	timer := time.NewTimer(streamChunkTimeout)
 557	defer timer.Stop()
 558
 559loop:
 560	for {
 561		select {
 562		case event, ok := <-eventChan:
 563			if !ok {
 564				break loop
 565			}
 566			// Reset the timeout timer since we received a chunk
 567			timer.Reset(streamChunkTimeout)
 568
 569			if processErr := a.processEvent(ctx, sessionID, &assistantMsg, event); processErr != nil {
 570				if errors.Is(processErr, context.Canceled) {
 571					a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
 572				} else {
 573					a.finishMessage(ctx, &assistantMsg, message.FinishReasonError, "API Error", processErr.Error())
 574				}
 575				return assistantMsg, nil, processErr
 576			}
 577		case <-timer.C:
 578			a.finishMessage(ctx, &assistantMsg, message.FinishReasonError, "Stream timeout", "No chunk received within timeout")
 579			return assistantMsg, nil, ErrStreamTimeout
 580		case <-ctx.Done():
 581			a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
 582			return assistantMsg, nil, ctx.Err()
 583		}
 584	}
 585
 586	toolResults := make([]message.ToolResult, len(assistantMsg.ToolCalls()))
 587	toolCalls := assistantMsg.ToolCalls()
 588	for i, toolCall := range toolCalls {
 589		select {
 590		case <-ctx.Done():
 591			a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
 592			// Make all future tool calls cancelled
 593			for j := i; j < len(toolCalls); j++ {
 594				toolResults[j] = message.ToolResult{
 595					ToolCallID: toolCalls[j].ID,
 596					Content:    "Tool execution canceled by user",
 597					IsError:    true,
 598				}
 599			}
 600			goto out
 601		default:
 602			// Continue processing
 603			var tool tools.BaseTool
 604			allTools, _ := a.getAllTools()
 605			for _, availableTool := range allTools {
 606				if availableTool.Info().Name == toolCall.Name {
 607					tool = availableTool
 608					break
 609				}
 610			}
 611
 612			// Tool not found
 613			if tool == nil {
 614				toolResults[i] = message.ToolResult{
 615					ToolCallID: toolCall.ID,
 616					Content:    fmt.Sprintf("Tool not found: %s", toolCall.Name),
 617					IsError:    true,
 618				}
 619				continue
 620			}
 621
 622			// Run tool in goroutine to allow cancellation
 623			type toolExecResult struct {
 624				response tools.ToolResponse
 625				err      error
 626			}
 627			resultChan := make(chan toolExecResult, 1)
 628
 629			go func() {
 630				response, err := tool.Run(ctx, tools.ToolCall{
 631					ID:    toolCall.ID,
 632					Name:  toolCall.Name,
 633					Input: toolCall.Input,
 634				})
 635				resultChan <- toolExecResult{response: response, err: err}
 636			}()
 637
 638			var toolResponse tools.ToolResponse
 639			var toolErr error
 640
 641			select {
 642			case <-ctx.Done():
 643				a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
 644				// Mark remaining tool calls as cancelled
 645				for j := i; j < len(toolCalls); j++ {
 646					toolResults[j] = message.ToolResult{
 647						ToolCallID: toolCalls[j].ID,
 648						Content:    "Tool execution canceled by user",
 649						IsError:    true,
 650					}
 651				}
 652				goto out
 653			case result := <-resultChan:
 654				toolResponse = result.response
 655				toolErr = result.err
 656			}
 657
 658			if toolErr != nil {
 659				slog.Error("Tool execution error", "toolCall", toolCall.ID, "error", toolErr)
 660				if errors.Is(toolErr, permission.ErrorPermissionDenied) {
 661					toolResults[i] = message.ToolResult{
 662						ToolCallID: toolCall.ID,
 663						Content:    "Permission denied",
 664						IsError:    true,
 665					}
 666					for j := i + 1; j < len(toolCalls); j++ {
 667						toolResults[j] = message.ToolResult{
 668							ToolCallID: toolCalls[j].ID,
 669							Content:    "Tool execution canceled by user",
 670							IsError:    true,
 671						}
 672					}
 673					a.finishMessage(ctx, &assistantMsg, message.FinishReasonPermissionDenied, "Permission denied", "")
 674					break
 675				}
 676			}
 677			toolResults[i] = message.ToolResult{
 678				ToolCallID: toolCall.ID,
 679				Content:    toolResponse.Content,
 680				Metadata:   toolResponse.Metadata,
 681				IsError:    toolResponse.IsError,
 682			}
 683		}
 684	}
 685out:
 686	if len(toolResults) == 0 {
 687		return assistantMsg, nil, nil
 688	}
 689	parts := make([]message.ContentPart, 0)
 690	for _, tr := range toolResults {
 691		parts = append(parts, tr)
 692	}
 693	msg, err := a.messages.Create(context.Background(), assistantMsg.SessionID, message.CreateMessageParams{
 694		Role:     message.Tool,
 695		Parts:    parts,
 696		Provider: a.providerID,
 697	})
 698	if err != nil {
 699		return assistantMsg, nil, fmt.Errorf("failed to create cancelled tool message: %w", err)
 700	}
 701
 702	return assistantMsg, &msg, err
 703}
 704
 705func (a *agent) finishMessage(ctx context.Context, msg *message.Message, finishReason message.FinishReason, message, details string) {
 706	msg.AddFinish(finishReason, message, details)
 707	_ = a.messages.Update(ctx, *msg)
 708}
 709
 710func (a *agent) processEvent(ctx context.Context, sessionID string, assistantMsg *message.Message, event provider.ProviderEvent) error {
 711	select {
 712	case <-ctx.Done():
 713		return ctx.Err()
 714	default:
 715		// Continue processing.
 716	}
 717
 718	switch event.Type {
 719	case provider.EventThinkingDelta:
 720		assistantMsg.AppendReasoningContent(event.Thinking)
 721		return a.messages.Update(ctx, *assistantMsg)
 722	case provider.EventSignatureDelta:
 723		assistantMsg.AppendReasoningSignature(event.Signature)
 724		return a.messages.Update(ctx, *assistantMsg)
 725	case provider.EventContentDelta:
 726		assistantMsg.FinishThinking()
 727		assistantMsg.AppendContent(event.Content)
 728		return a.messages.Update(ctx, *assistantMsg)
 729	case provider.EventToolUseStart:
 730		assistantMsg.FinishThinking()
 731		slog.Info("Tool call started", "toolCall", event.ToolCall)
 732		assistantMsg.AddToolCall(*event.ToolCall)
 733		return a.messages.Update(ctx, *assistantMsg)
 734	case provider.EventToolUseDelta:
 735		assistantMsg.AppendToolCallInput(event.ToolCall.ID, event.ToolCall.Input)
 736		return a.messages.Update(ctx, *assistantMsg)
 737	case provider.EventToolUseStop:
 738		slog.Info("Finished tool call", "toolCall", event.ToolCall)
 739		assistantMsg.FinishToolCall(event.ToolCall.ID)
 740		return a.messages.Update(ctx, *assistantMsg)
 741	case provider.EventError:
 742		return event.Error
 743	case provider.EventComplete:
 744		assistantMsg.FinishThinking()
 745		assistantMsg.SetToolCalls(event.Response.ToolCalls)
 746		assistantMsg.AddFinish(event.Response.FinishReason, "", "")
 747		if err := a.messages.Update(ctx, *assistantMsg); err != nil {
 748			return fmt.Errorf("failed to update message: %w", err)
 749		}
 750		return a.trackUsage(ctx, sessionID, a.Model(), event.Response.Usage)
 751	}
 752
 753	return nil
 754}
 755
 756func (a *agent) trackUsage(ctx context.Context, sessionID string, model catwalk.Model, usage provider.TokenUsage) error {
 757	sess, err := a.sessions.Get(ctx, sessionID)
 758	if err != nil {
 759		return fmt.Errorf("failed to get session: %w", err)
 760	}
 761
 762	cost := model.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
 763		model.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
 764		model.CostPer1MIn/1e6*float64(usage.InputTokens) +
 765		model.CostPer1MOut/1e6*float64(usage.OutputTokens)
 766
 767	a.eventTokensUsed(sessionID, usage, cost)
 768
 769	sess.Cost += cost
 770	sess.CompletionTokens = usage.OutputTokens + usage.CacheReadTokens
 771	sess.PromptTokens = usage.InputTokens + usage.CacheCreationTokens
 772
 773	_, err = a.sessions.Save(ctx, sess)
 774	if err != nil {
 775		return fmt.Errorf("failed to save session: %w", err)
 776	}
 777	return nil
 778}
 779
 780func (a *agent) Summarize(ctx context.Context, sessionID string) error {
 781	if a.summarizeProvider == nil {
 782		return fmt.Errorf("summarize provider not available")
 783	}
 784
 785	// Check if session is busy
 786	if a.IsSessionBusy(sessionID) {
 787		return ErrSessionBusy
 788	}
 789
 790	// Create a new context with cancellation
 791	summarizeCtx, cancel := context.WithCancel(ctx)
 792
 793	// Store the cancel function in activeRequests to allow cancellation
 794	a.activeRequests.Set(sessionID+"-summarize", cancel)
 795
 796	go func() {
 797		defer a.activeRequests.Del(sessionID + "-summarize")
 798		defer cancel()
 799		event := AgentEvent{
 800			Type:     AgentEventTypeSummarize,
 801			Progress: "Starting summarization...",
 802		}
 803
 804		a.Publish(pubsub.CreatedEvent, event)
 805		// Get all messages from the session
 806		msgs, err := a.messages.List(summarizeCtx, sessionID)
 807		if err != nil {
 808			event = AgentEvent{
 809				Type:  AgentEventTypeError,
 810				Error: fmt.Errorf("failed to list messages: %w", err),
 811				Done:  true,
 812			}
 813			a.Publish(pubsub.CreatedEvent, event)
 814			return
 815		}
 816		summarizeCtx = context.WithValue(summarizeCtx, tools.SessionIDContextKey, sessionID)
 817
 818		if len(msgs) == 0 {
 819			event = AgentEvent{
 820				Type:  AgentEventTypeError,
 821				Error: fmt.Errorf("no messages to summarize"),
 822				Done:  true,
 823			}
 824			a.Publish(pubsub.CreatedEvent, event)
 825			return
 826		}
 827
 828		event = AgentEvent{
 829			Type:     AgentEventTypeSummarize,
 830			Progress: "Analyzing conversation...",
 831		}
 832		a.Publish(pubsub.CreatedEvent, event)
 833
 834		// Add a system message to guide the summarization
 835		summarizePrompt := "Provide a detailed but concise summary of our conversation above. Focus on information that would be helpful for continuing the conversation, including what we did, what we're doing, which files we're working on, and what we're going to do next."
 836
 837		// Create a new message with the summarize prompt
 838		promptMsg := message.Message{
 839			Role:  message.User,
 840			Parts: []message.ContentPart{message.TextContent{Text: summarizePrompt}},
 841		}
 842
 843		// Append the prompt to the messages
 844		msgsWithPrompt := append(msgs, promptMsg)
 845
 846		event = AgentEvent{
 847			Type:     AgentEventTypeSummarize,
 848			Progress: "Generating summary...",
 849		}
 850
 851		a.Publish(pubsub.CreatedEvent, event)
 852
 853		// Send the messages to the summarize provider
 854		response := a.summarizeProvider.StreamResponse(
 855			summarizeCtx,
 856			msgsWithPrompt,
 857			nil,
 858		)
 859		var finalResponse *provider.ProviderResponse
 860		for r := range response {
 861			if r.Error != nil {
 862				event = AgentEvent{
 863					Type:  AgentEventTypeError,
 864					Error: fmt.Errorf("failed to summarize: %w", r.Error),
 865					Done:  true,
 866				}
 867				a.Publish(pubsub.CreatedEvent, event)
 868				return
 869			}
 870			finalResponse = r.Response
 871		}
 872
 873		summary := strings.TrimSpace(finalResponse.Content)
 874		if summary == "" {
 875			event = AgentEvent{
 876				Type:  AgentEventTypeError,
 877				Error: fmt.Errorf("empty summary returned"),
 878				Done:  true,
 879			}
 880			a.Publish(pubsub.CreatedEvent, event)
 881			return
 882		}
 883		shell := shell.GetPersistentShell(config.Get().WorkingDir())
 884		summary += "\n\n**Current working directory of the persistent shell**\n\n" + shell.GetWorkingDir()
 885		event = AgentEvent{
 886			Type:     AgentEventTypeSummarize,
 887			Progress: "Creating new session...",
 888		}
 889
 890		a.Publish(pubsub.CreatedEvent, event)
 891		oldSession, err := a.sessions.Get(summarizeCtx, sessionID)
 892		if err != nil {
 893			event = AgentEvent{
 894				Type:  AgentEventTypeError,
 895				Error: fmt.Errorf("failed to get session: %w", err),
 896				Done:  true,
 897			}
 898
 899			a.Publish(pubsub.CreatedEvent, event)
 900			return
 901		}
 902		// Create a message in the new session with the summary
 903		msg, err := a.messages.Create(summarizeCtx, oldSession.ID, message.CreateMessageParams{
 904			Role: message.Assistant,
 905			Parts: []message.ContentPart{
 906				message.TextContent{Text: summary},
 907				message.Finish{
 908					Reason: message.FinishReasonEndTurn,
 909					Time:   time.Now().Unix(),
 910				},
 911			},
 912			Model:    a.summarizeProvider.Model().ID,
 913			Provider: a.summarizeProviderID,
 914		})
 915		if err != nil {
 916			event = AgentEvent{
 917				Type:  AgentEventTypeError,
 918				Error: fmt.Errorf("failed to create summary message: %w", err),
 919				Done:  true,
 920			}
 921
 922			a.Publish(pubsub.CreatedEvent, event)
 923			return
 924		}
 925		oldSession.SummaryMessageID = msg.ID
 926		oldSession.CompletionTokens = finalResponse.Usage.OutputTokens
 927		oldSession.PromptTokens = 0
 928		model := a.summarizeProvider.Model()
 929		usage := finalResponse.Usage
 930		cost := model.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
 931			model.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
 932			model.CostPer1MIn/1e6*float64(usage.InputTokens) +
 933			model.CostPer1MOut/1e6*float64(usage.OutputTokens)
 934		oldSession.Cost += cost
 935		_, err = a.sessions.Save(summarizeCtx, oldSession)
 936		if err != nil {
 937			event = AgentEvent{
 938				Type:  AgentEventTypeError,
 939				Error: fmt.Errorf("failed to save session: %w", err),
 940				Done:  true,
 941			}
 942			a.Publish(pubsub.CreatedEvent, event)
 943		}
 944
 945		event = AgentEvent{
 946			Type:      AgentEventTypeSummarize,
 947			SessionID: oldSession.ID,
 948			Progress:  "Summary complete",
 949			Done:      true,
 950		}
 951		a.Publish(pubsub.CreatedEvent, event)
 952		// Send final success event with the new session ID
 953	}()
 954
 955	return nil
 956}
 957
 958func (a *agent) ClearQueue(sessionID string) {
 959	if a.QueuedPrompts(sessionID) > 0 {
 960		slog.Info("Clearing queued prompts", "session_id", sessionID)
 961		a.promptQueue.Del(sessionID)
 962	}
 963}
 964
 965func (a *agent) CancelAll() {
 966	if !a.IsBusy() {
 967		return
 968	}
 969	for key := range a.activeRequests.Seq2() {
 970		a.Cancel(key) // key is sessionID
 971	}
 972
 973	timeout := time.After(5 * time.Second)
 974	for a.IsBusy() {
 975		select {
 976		case <-timeout:
 977			return
 978		default:
 979			time.Sleep(200 * time.Millisecond)
 980		}
 981	}
 982}
 983
 984func (a *agent) UpdateModel() error {
 985	cfg := config.Get()
 986
 987	// Get current provider configuration
 988	currentProviderCfg := cfg.GetProviderForModel(a.agentCfg.Model)
 989	if currentProviderCfg == nil || currentProviderCfg.ID == "" {
 990		return fmt.Errorf("provider for agent %s not found in config", a.agentCfg.Name)
 991	}
 992
 993	// Check if provider has changed
 994	if string(currentProviderCfg.ID) != a.providerID {
 995		// Provider changed, need to recreate the main provider
 996		model := cfg.GetModelByType(a.agentCfg.Model)
 997		if model.ID == "" {
 998			return fmt.Errorf("model not found for agent %s", a.agentCfg.Name)
 999		}
1000
1001		promptID := agentPromptMap[a.agentCfg.ID]
1002		if promptID == "" {
1003			promptID = prompt.PromptDefault
1004		}
1005
1006		opts := []provider.ProviderClientOption{
1007			provider.WithModel(a.agentCfg.Model),
1008			provider.WithSystemMessage(prompt.GetPrompt(promptID, currentProviderCfg.ID, cfg.Options.ContextPaths...)),
1009		}
1010
1011		newProvider, err := provider.NewProvider(*currentProviderCfg, opts...)
1012		if err != nil {
1013			return fmt.Errorf("failed to create new provider: %w", err)
1014		}
1015
1016		// Update the provider and provider ID
1017		a.provider = newProvider
1018		a.providerID = string(currentProviderCfg.ID)
1019	}
1020
1021	// Check if providers have changed for title (small) and summarize (large)
1022	smallModelCfg := cfg.Models[config.SelectedModelTypeSmall]
1023	var smallModelProviderCfg config.ProviderConfig
1024	for p := range cfg.Providers.Seq() {
1025		if p.ID == smallModelCfg.Provider {
1026			smallModelProviderCfg = p
1027			break
1028		}
1029	}
1030	if smallModelProviderCfg.ID == "" {
1031		return fmt.Errorf("provider %s not found in config", smallModelCfg.Provider)
1032	}
1033
1034	largeModelCfg := cfg.Models[config.SelectedModelTypeLarge]
1035	var largeModelProviderCfg config.ProviderConfig
1036	for p := range cfg.Providers.Seq() {
1037		if p.ID == largeModelCfg.Provider {
1038			largeModelProviderCfg = p
1039			break
1040		}
1041	}
1042	if largeModelProviderCfg.ID == "" {
1043		return fmt.Errorf("provider %s not found in config", largeModelCfg.Provider)
1044	}
1045
1046	var maxTitleTokens int64 = 40
1047
1048	// if the max output is too low for the gemini provider it won't return anything
1049	if smallModelCfg.Provider == "gemini" {
1050		maxTitleTokens = 1000
1051	}
1052	// Recreate title provider
1053	titleOpts := []provider.ProviderClientOption{
1054		provider.WithModel(config.SelectedModelTypeSmall),
1055		provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptTitle, smallModelProviderCfg.ID)),
1056		provider.WithMaxTokens(maxTitleTokens),
1057	}
1058	newTitleProvider, err := provider.NewProvider(smallModelProviderCfg, titleOpts...)
1059	if err != nil {
1060		return fmt.Errorf("failed to create new title provider: %w", err)
1061	}
1062	a.titleProvider = newTitleProvider
1063
1064	// Recreate summarize provider if provider changed (now large model)
1065	if string(largeModelProviderCfg.ID) != a.summarizeProviderID {
1066		largeModel := cfg.GetModelByType(config.SelectedModelTypeLarge)
1067		if largeModel == nil {
1068			return fmt.Errorf("model %s not found in provider %s", largeModelCfg.Model, largeModelProviderCfg.ID)
1069		}
1070		summarizeOpts := []provider.ProviderClientOption{
1071			provider.WithModel(config.SelectedModelTypeLarge),
1072			provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptSummarizer, largeModelProviderCfg.ID)),
1073		}
1074		newSummarizeProvider, err := provider.NewProvider(largeModelProviderCfg, summarizeOpts...)
1075		if err != nil {
1076			return fmt.Errorf("failed to create new summarize provider: %w", err)
1077		}
1078		a.summarizeProvider = newSummarizeProvider
1079		a.summarizeProviderID = string(largeModelProviderCfg.ID)
1080	}
1081
1082	return nil
1083}