agent.go

   1package agent
   2
   3import (
   4	"context"
   5	"errors"
   6	"fmt"
   7	"log/slog"
   8	"slices"
   9	"strings"
  10	"time"
  11
  12	"github.com/charmbracelet/catwalk/pkg/catwalk"
  13	"github.com/charmbracelet/crush/internal/config"
  14	"github.com/charmbracelet/crush/internal/csync"
  15	"github.com/charmbracelet/crush/internal/history"
  16	"github.com/charmbracelet/crush/internal/llm/prompt"
  17	"github.com/charmbracelet/crush/internal/llm/provider"
  18	"github.com/charmbracelet/crush/internal/llm/tools"
  19	"github.com/charmbracelet/crush/internal/log"
  20	"github.com/charmbracelet/crush/internal/lsp"
  21	"github.com/charmbracelet/crush/internal/message"
  22	"github.com/charmbracelet/crush/internal/permission"
  23	"github.com/charmbracelet/crush/internal/proto"
  24	"github.com/charmbracelet/crush/internal/pubsub"
  25	"github.com/charmbracelet/crush/internal/session"
  26	"github.com/charmbracelet/crush/internal/shell"
  27)
  28
  29// Common errors
  30var (
  31	ErrRequestCancelled = errors.New("request canceled by user")
  32	ErrSessionBusy      = errors.New("session is currently processing another request")
  33)
  34
  35type (
  36	AgentEventType = proto.AgentEventType
  37	AgentEvent     = proto.AgentEvent
  38)
  39
  40const (
  41	AgentEventTypeError     = proto.AgentEventTypeError
  42	AgentEventTypeResponse  = proto.AgentEventTypeResponse
  43	AgentEventTypeSummarize = proto.AgentEventTypeSummarize
  44)
  45
  46type Service interface {
  47	pubsub.Suscriber[AgentEvent]
  48	Model() catwalk.Model
  49	Run(ctx context.Context, sessionID string, content string, attachments ...message.Attachment) (<-chan AgentEvent, error)
  50	Cancel(sessionID string)
  51	CancelAll()
  52	IsSessionBusy(sessionID string) bool
  53	IsBusy() bool
  54	Summarize(ctx context.Context, sessionID string) error
  55	UpdateModel() error
  56	QueuedPrompts(sessionID string) int
  57	ClearQueue(sessionID string)
  58}
  59
  60type agent struct {
  61	*pubsub.Broker[AgentEvent]
  62	agentCfg config.Agent
  63	sessions session.Service
  64	messages message.Service
  65	mcpTools []McpTool
  66	cfg      *config.Config
  67
  68	tools *csync.LazySlice[tools.BaseTool]
  69	// We need this to be able to update it when model changes
  70	agentToolFn func() (tools.BaseTool, error)
  71
  72	provider   provider.Provider
  73	providerID string
  74
  75	titleProvider       provider.Provider
  76	summarizeProvider   provider.Provider
  77	summarizeProviderID string
  78
  79	activeRequests *csync.Map[string, context.CancelFunc]
  80	promptQueue    *csync.Map[string, []string]
  81}
  82
  83var agentPromptMap = map[string]prompt.PromptID{
  84	"coder": prompt.PromptCoder,
  85	"task":  prompt.PromptTask,
  86}
  87
  88func NewAgent(
  89	ctx context.Context,
  90	cfg *config.Config,
  91	agentCfg config.Agent,
  92	// These services are needed in the tools
  93	permissions permission.Service,
  94	sessions session.Service,
  95	messages message.Service,
  96	history history.Service,
  97	lspClients *csync.Map[string, *lsp.Client],
  98) (Service, error) {
  99	var agentToolFn func() (tools.BaseTool, error)
 100	if agentCfg.ID == "coder" {
 101		agentToolFn = func() (tools.BaseTool, error) {
 102			taskAgentCfg := cfg.Agents["task"]
 103			if taskAgentCfg.ID == "" {
 104				return nil, fmt.Errorf("task agent not found in config")
 105			}
 106			taskAgent, err := NewAgent(ctx, cfg, taskAgentCfg, permissions, sessions, messages, history, lspClients)
 107			if err != nil {
 108				return nil, fmt.Errorf("failed to create task agent: %w", err)
 109			}
 110			return NewAgentTool(taskAgent, sessions, messages), nil
 111		}
 112	}
 113
 114	providerCfg := cfg.GetProviderForModel(agentCfg.Model)
 115	if providerCfg == nil {
 116		return nil, fmt.Errorf("provider for agent %s not found in config", agentCfg.Name)
 117	}
 118	model := cfg.GetModelByType(agentCfg.Model)
 119
 120	if model == nil {
 121		return nil, fmt.Errorf("model not found for agent %s", agentCfg.Name)
 122	}
 123
 124	promptID := agentPromptMap[agentCfg.ID]
 125	if promptID == "" {
 126		promptID = prompt.PromptDefault
 127	}
 128	opts := []provider.ProviderClientOption{
 129		provider.WithModel(agentCfg.Model),
 130		provider.WithSystemMessage(prompt.GetPrompt(cfg, promptID, providerCfg.ID, cfg.Options.ContextPaths...)),
 131	}
 132	agentProvider, err := provider.NewProvider(cfg, *providerCfg, opts...)
 133	if err != nil {
 134		return nil, err
 135	}
 136
 137	smallModelCfg := cfg.Models[config.SelectedModelTypeSmall]
 138	var smallModelProviderCfg *config.ProviderConfig
 139	if smallModelCfg.Provider == providerCfg.ID {
 140		smallModelProviderCfg = providerCfg
 141	} else {
 142		smallModelProviderCfg = cfg.GetProviderForModel(config.SelectedModelTypeSmall)
 143
 144		if smallModelProviderCfg.ID == "" {
 145			return nil, fmt.Errorf("provider %s not found in config", smallModelCfg.Provider)
 146		}
 147	}
 148	smallModel := cfg.GetModelByType(config.SelectedModelTypeSmall)
 149	if smallModel.ID == "" {
 150		return nil, fmt.Errorf("model %s not found in provider %s", smallModelCfg.Model, smallModelProviderCfg.ID)
 151	}
 152
 153	titleOpts := []provider.ProviderClientOption{
 154		provider.WithModel(config.SelectedModelTypeSmall),
 155		provider.WithSystemMessage(prompt.GetPrompt(cfg, prompt.PromptTitle, smallModelProviderCfg.ID)),
 156	}
 157	titleProvider, err := provider.NewProvider(cfg, *smallModelProviderCfg, titleOpts...)
 158	if err != nil {
 159		return nil, err
 160	}
 161
 162	summarizeOpts := []provider.ProviderClientOption{
 163		provider.WithModel(config.SelectedModelTypeLarge),
 164		provider.WithSystemMessage(prompt.GetPrompt(cfg, prompt.PromptSummarizer, providerCfg.ID)),
 165	}
 166	summarizeProvider, err := provider.NewProvider(cfg, *providerCfg, summarizeOpts...)
 167	if err != nil {
 168		return nil, err
 169	}
 170
 171	toolFn := func() []tools.BaseTool {
 172		slog.Info("Initializing agent tools", "agent", agentCfg.ID)
 173		defer func() {
 174			slog.Info("Initialized agent tools", "agent", agentCfg.ID)
 175		}()
 176
 177		cwd := cfg.WorkingDir()
 178		allTools := []tools.BaseTool{
 179			tools.NewBashTool(permissions, cwd, cfg.Options.Attribution),
 180			tools.NewDownloadTool(permissions, cwd),
 181			tools.NewEditTool(lspClients, permissions, history, cwd),
 182			tools.NewMultiEditTool(lspClients, permissions, history, cwd),
 183			tools.NewFetchTool(permissions, cwd),
 184			tools.NewGlobTool(cwd),
 185			tools.NewGrepTool(cwd),
 186			tools.NewLsTool(permissions, cwd),
 187			tools.NewSourcegraphTool(),
 188			tools.NewViewTool(lspClients, permissions, cwd),
 189			tools.NewWriteTool(lspClients, permissions, history, cwd),
 190		}
 191
 192		mcpToolsOnce.Do(func() {
 193			mcpTools = doGetMCPTools(ctx, permissions, cfg)
 194		})
 195
 196		withCoderTools := func(t []tools.BaseTool) []tools.BaseTool {
 197			if agentCfg.ID == "coder" {
 198				t = append(t, mcpTools...)
 199				if lspClients.Len() > 0 {
 200					t = append(t, tools.NewDiagnosticsTool(lspClients))
 201				}
 202			}
 203			return t
 204		}
 205
 206		if agentCfg.AllowedTools == nil {
 207			return withCoderTools(allTools)
 208		}
 209
 210		var filteredTools []tools.BaseTool
 211		for _, tool := range allTools {
 212			if slices.Contains(agentCfg.AllowedTools, tool.Name()) {
 213				filteredTools = append(filteredTools, tool)
 214			}
 215		}
 216		return withCoderTools(filteredTools)
 217	}
 218
 219	return &agent{
 220		Broker:              pubsub.NewBroker[AgentEvent](),
 221		agentCfg:            agentCfg,
 222		provider:            agentProvider,
 223		providerID:          string(providerCfg.ID),
 224		messages:            messages,
 225		sessions:            sessions,
 226		titleProvider:       titleProvider,
 227		summarizeProvider:   summarizeProvider,
 228		summarizeProviderID: string(providerCfg.ID),
 229		agentToolFn:         agentToolFn,
 230		activeRequests:      csync.NewMap[string, context.CancelFunc](),
 231		tools:               csync.NewLazySlice(toolFn),
 232		promptQueue:         csync.NewMap[string, []string](),
 233		cfg:                 cfg,
 234	}, nil
 235}
 236
 237func (a *agent) Model() catwalk.Model {
 238	return *a.cfg.GetModelByType(a.agentCfg.Model)
 239}
 240
 241func (a *agent) Cancel(sessionID string) {
 242	// Cancel regular requests
 243	if cancel, ok := a.activeRequests.Take(sessionID); ok && cancel != nil {
 244		slog.Info("Request cancellation initiated", "session_id", sessionID)
 245		cancel()
 246	}
 247
 248	// Also check for summarize requests
 249	if cancel, ok := a.activeRequests.Take(sessionID + "-summarize"); ok && cancel != nil {
 250		slog.Info("Summarize cancellation initiated", "session_id", sessionID)
 251		cancel()
 252	}
 253
 254	if a.QueuedPrompts(sessionID) > 0 {
 255		slog.Info("Clearing queued prompts", "session_id", sessionID)
 256		a.promptQueue.Del(sessionID)
 257	}
 258}
 259
 260func (a *agent) IsBusy() bool {
 261	var busy bool
 262	for cancelFunc := range a.activeRequests.Seq() {
 263		if cancelFunc != nil {
 264			busy = true
 265			break
 266		}
 267	}
 268	return busy
 269}
 270
 271func (a *agent) IsSessionBusy(sessionID string) bool {
 272	_, busy := a.activeRequests.Get(sessionID)
 273	return busy
 274}
 275
 276func (a *agent) QueuedPrompts(sessionID string) int {
 277	l, ok := a.promptQueue.Get(sessionID)
 278	if !ok {
 279		return 0
 280	}
 281	return len(l)
 282}
 283
 284func (a *agent) generateTitle(ctx context.Context, sessionID string, content string) error {
 285	if content == "" {
 286		return nil
 287	}
 288	if a.titleProvider == nil {
 289		return nil
 290	}
 291	session, err := a.sessions.Get(ctx, sessionID)
 292	if err != nil {
 293		return err
 294	}
 295	parts := []message.ContentPart{message.TextContent{
 296		Text: fmt.Sprintf("Generate a concise title for the following content:\n\n%s", content),
 297	}}
 298
 299	// Use streaming approach like summarization
 300	response := a.titleProvider.StreamResponse(
 301		ctx,
 302		[]message.Message{
 303			{
 304				Role:  message.User,
 305				Parts: parts,
 306			},
 307		},
 308		nil,
 309	)
 310
 311	var finalResponse *provider.ProviderResponse
 312	for r := range response {
 313		if r.Error != nil {
 314			return r.Error
 315		}
 316		finalResponse = r.Response
 317	}
 318
 319	if finalResponse == nil {
 320		return fmt.Errorf("no response received from title provider")
 321	}
 322
 323	title := strings.ReplaceAll(finalResponse.Content, "\n", " ")
 324
 325	if idx := strings.Index(title, "</think>"); idx > 0 {
 326		title = title[idx+len("</think>"):]
 327	}
 328
 329	title = strings.TrimSpace(title)
 330	if title == "" {
 331		return nil
 332	}
 333
 334	session.Title = title
 335	_, err = a.sessions.Save(ctx, session)
 336	return err
 337}
 338
 339func (a *agent) err(err error) AgentEvent {
 340	return AgentEvent{
 341		Type:  AgentEventTypeError,
 342		Error: err,
 343	}
 344}
 345
 346func (a *agent) Run(ctx context.Context, sessionID string, content string, attachments ...message.Attachment) (<-chan AgentEvent, error) {
 347	if !a.Model().SupportsImages && attachments != nil {
 348		attachments = nil
 349	}
 350	events := make(chan AgentEvent, 1)
 351	if a.IsSessionBusy(sessionID) {
 352		existing, ok := a.promptQueue.Get(sessionID)
 353		if !ok {
 354			existing = []string{}
 355		}
 356		existing = append(existing, content)
 357		a.promptQueue.Set(sessionID, existing)
 358		return nil, nil
 359	}
 360
 361	genCtx, cancel := context.WithCancel(ctx)
 362
 363	a.activeRequests.Set(sessionID, cancel)
 364	go func() {
 365		slog.Debug("Request started", "sessionID", sessionID)
 366		defer log.RecoverPanic("agent.Run", func() {
 367			events <- a.err(fmt.Errorf("panic while running the agent"))
 368		})
 369		var attachmentParts []message.ContentPart
 370		for _, attachment := range attachments {
 371			attachmentParts = append(attachmentParts, message.BinaryContent{Path: attachment.FilePath, MIMEType: attachment.MimeType, Data: attachment.Content})
 372		}
 373		result := a.processGeneration(genCtx, sessionID, content, attachmentParts)
 374		if result.Error != nil && !errors.Is(result.Error, ErrRequestCancelled) && !errors.Is(result.Error, context.Canceled) {
 375			slog.Error(result.Error.Error())
 376		}
 377		slog.Debug("Request completed", "sessionID", sessionID)
 378		a.activeRequests.Del(sessionID)
 379		cancel()
 380		a.Publish(pubsub.CreatedEvent, result)
 381		events <- result
 382		close(events)
 383	}()
 384	return events, nil
 385}
 386
 387func (a *agent) processGeneration(ctx context.Context, sessionID, content string, attachmentParts []message.ContentPart) AgentEvent {
 388	// List existing messages; if none, start title generation asynchronously.
 389	msgs, err := a.messages.List(ctx, sessionID)
 390	if err != nil {
 391		return a.err(fmt.Errorf("failed to list messages: %w", err))
 392	}
 393	if len(msgs) == 0 {
 394		go func() {
 395			defer log.RecoverPanic("agent.Run", func() {
 396				slog.Error("panic while generating title")
 397			})
 398			titleErr := a.generateTitle(ctx, sessionID, content)
 399			if titleErr != nil && !errors.Is(titleErr, context.Canceled) && !errors.Is(titleErr, context.DeadlineExceeded) {
 400				slog.Error("failed to generate title", "error", titleErr)
 401			}
 402		}()
 403	}
 404	session, err := a.sessions.Get(ctx, sessionID)
 405	if err != nil {
 406		return a.err(fmt.Errorf("failed to get session: %w", err))
 407	}
 408	if session.SummaryMessageID != "" {
 409		summaryMsgInex := -1
 410		for i, msg := range msgs {
 411			if msg.ID == session.SummaryMessageID {
 412				summaryMsgInex = i
 413				break
 414			}
 415		}
 416		if summaryMsgInex != -1 {
 417			msgs = msgs[summaryMsgInex:]
 418			msgs[0].Role = message.User
 419		}
 420	}
 421
 422	userMsg, err := a.createUserMessage(ctx, sessionID, content, attachmentParts)
 423	if err != nil {
 424		return a.err(fmt.Errorf("failed to create user message: %w", err))
 425	}
 426	// Append the new user message to the conversation history.
 427	msgHistory := append(msgs, userMsg)
 428
 429	for {
 430		// Check for cancellation before each iteration
 431		select {
 432		case <-ctx.Done():
 433			return a.err(ctx.Err())
 434		default:
 435			// Continue processing
 436		}
 437		agentMessage, toolResults, err := a.streamAndHandleEvents(ctx, sessionID, msgHistory)
 438		if err != nil {
 439			if errors.Is(err, context.Canceled) {
 440				agentMessage.AddFinish(message.FinishReasonCanceled, "Request cancelled", "")
 441				a.messages.Update(context.Background(), agentMessage)
 442				return a.err(ErrRequestCancelled)
 443			}
 444			return a.err(fmt.Errorf("failed to process events: %w", err))
 445		}
 446		if a.cfg.Options.Debug {
 447			slog.Info("Result", "message", agentMessage.FinishReason(), "toolResults", toolResults)
 448		}
 449		if (agentMessage.FinishReason() == message.FinishReasonToolUse) && toolResults != nil {
 450			// We are not done, we need to respond with the tool response
 451			msgHistory = append(msgHistory, agentMessage, *toolResults)
 452			// If there are queued prompts, process the next one
 453			nextPrompt, ok := a.promptQueue.Take(sessionID)
 454			if ok {
 455				for _, prompt := range nextPrompt {
 456					// Create a new user message for the queued prompt
 457					userMsg, err := a.createUserMessage(ctx, sessionID, prompt, nil)
 458					if err != nil {
 459						return a.err(fmt.Errorf("failed to create user message for queued prompt: %w", err))
 460					}
 461					// Append the new user message to the conversation history
 462					msgHistory = append(msgHistory, userMsg)
 463				}
 464			}
 465
 466			continue
 467		} else if agentMessage.FinishReason() == message.FinishReasonEndTurn {
 468			queuePrompts, ok := a.promptQueue.Take(sessionID)
 469			if ok {
 470				for _, prompt := range queuePrompts {
 471					if prompt == "" {
 472						continue
 473					}
 474					userMsg, err := a.createUserMessage(ctx, sessionID, prompt, nil)
 475					if err != nil {
 476						return a.err(fmt.Errorf("failed to create user message for queued prompt: %w", err))
 477					}
 478					msgHistory = append(msgHistory, userMsg)
 479				}
 480				continue
 481			}
 482		}
 483		if agentMessage.FinishReason() == "" {
 484			// Kujtim: could not track down where this is happening but this means its cancelled
 485			agentMessage.AddFinish(message.FinishReasonCanceled, "Request cancelled", "")
 486			_ = a.messages.Update(context.Background(), agentMessage)
 487			return a.err(ErrRequestCancelled)
 488		}
 489		return AgentEvent{
 490			Type:    AgentEventTypeResponse,
 491			Message: agentMessage,
 492			Done:    true,
 493		}
 494	}
 495}
 496
 497func (a *agent) createUserMessage(ctx context.Context, sessionID, content string, attachmentParts []message.ContentPart) (message.Message, error) {
 498	parts := []message.ContentPart{message.TextContent{Text: content}}
 499	parts = append(parts, attachmentParts...)
 500	return a.messages.Create(ctx, sessionID, message.CreateMessageParams{
 501		Role:  message.User,
 502		Parts: parts,
 503	})
 504}
 505
 506func (a *agent) getAllTools() ([]tools.BaseTool, error) {
 507	allTools := slices.Collect(a.tools.Seq())
 508	if a.agentToolFn != nil {
 509		agentTool, agentToolErr := a.agentToolFn()
 510		if agentToolErr != nil {
 511			return nil, agentToolErr
 512		}
 513		allTools = append(allTools, agentTool)
 514	}
 515	return allTools, nil
 516}
 517
 518func (a *agent) streamAndHandleEvents(ctx context.Context, sessionID string, msgHistory []message.Message) (message.Message, *message.Message, error) {
 519	ctx = context.WithValue(ctx, tools.SessionIDContextKey, sessionID)
 520
 521	// Create the assistant message first so the spinner shows immediately
 522	assistantMsg, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{
 523		Role:     message.Assistant,
 524		Parts:    []message.ContentPart{},
 525		Model:    a.Model().ID,
 526		Provider: a.providerID,
 527	})
 528	if err != nil {
 529		return assistantMsg, nil, fmt.Errorf("failed to create assistant message: %w", err)
 530	}
 531
 532	allTools, toolsErr := a.getAllTools()
 533	if toolsErr != nil {
 534		return assistantMsg, nil, toolsErr
 535	}
 536	// Now collect tools (which may block on MCP initialization)
 537	eventChan := a.provider.StreamResponse(ctx, msgHistory, allTools)
 538
 539	// Add the session and message ID into the context if needed by tools.
 540	ctx = context.WithValue(ctx, tools.MessageIDContextKey, assistantMsg.ID)
 541
 542	// Process each event in the stream.
 543	for event := range eventChan {
 544		if processErr := a.processEvent(ctx, sessionID, &assistantMsg, event); processErr != nil {
 545			if errors.Is(processErr, context.Canceled) {
 546				a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
 547			} else {
 548				a.finishMessage(ctx, &assistantMsg, message.FinishReasonError, "API Error", processErr.Error())
 549			}
 550			return assistantMsg, nil, processErr
 551		}
 552		if ctx.Err() != nil {
 553			a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
 554			return assistantMsg, nil, ctx.Err()
 555		}
 556	}
 557
 558	toolResults := make([]message.ToolResult, len(assistantMsg.ToolCalls()))
 559	toolCalls := assistantMsg.ToolCalls()
 560	for i, toolCall := range toolCalls {
 561		select {
 562		case <-ctx.Done():
 563			a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
 564			// Make all future tool calls cancelled
 565			for j := i; j < len(toolCalls); j++ {
 566				toolResults[j] = message.ToolResult{
 567					ToolCallID: toolCalls[j].ID,
 568					Content:    "Tool execution canceled by user",
 569					IsError:    true,
 570				}
 571			}
 572			goto out
 573		default:
 574			// Continue processing
 575			var tool tools.BaseTool
 576			allTools, _ := a.getAllTools()
 577			for _, availableTool := range allTools {
 578				if availableTool.Info().Name == toolCall.Name {
 579					tool = availableTool
 580					break
 581				}
 582			}
 583
 584			// Tool not found
 585			if tool == nil {
 586				toolResults[i] = message.ToolResult{
 587					ToolCallID: toolCall.ID,
 588					Content:    fmt.Sprintf("Tool not found: %s", toolCall.Name),
 589					IsError:    true,
 590				}
 591				continue
 592			}
 593
 594			// Run tool in goroutine to allow cancellation
 595			type toolExecResult struct {
 596				response tools.ToolResponse
 597				err      error
 598			}
 599			resultChan := make(chan toolExecResult, 1)
 600
 601			go func() {
 602				response, err := tool.Run(ctx, tools.ToolCall{
 603					ID:    toolCall.ID,
 604					Name:  toolCall.Name,
 605					Input: toolCall.Input,
 606				})
 607				resultChan <- toolExecResult{response: response, err: err}
 608			}()
 609
 610			var toolResponse tools.ToolResponse
 611			var toolErr error
 612
 613			select {
 614			case <-ctx.Done():
 615				a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
 616				// Mark remaining tool calls as cancelled
 617				for j := i; j < len(toolCalls); j++ {
 618					toolResults[j] = message.ToolResult{
 619						ToolCallID: toolCalls[j].ID,
 620						Content:    "Tool execution canceled by user",
 621						IsError:    true,
 622					}
 623				}
 624				goto out
 625			case result := <-resultChan:
 626				toolResponse = result.response
 627				toolErr = result.err
 628			}
 629
 630			if toolErr != nil {
 631				slog.Error("Tool execution error", "toolCall", toolCall.ID, "error", toolErr)
 632				if errors.Is(toolErr, permission.ErrorPermissionDenied) {
 633					toolResults[i] = message.ToolResult{
 634						ToolCallID: toolCall.ID,
 635						Content:    "Permission denied",
 636						IsError:    true,
 637					}
 638					for j := i + 1; j < len(toolCalls); j++ {
 639						toolResults[j] = message.ToolResult{
 640							ToolCallID: toolCalls[j].ID,
 641							Content:    "Tool execution canceled by user",
 642							IsError:    true,
 643						}
 644					}
 645					a.finishMessage(ctx, &assistantMsg, message.FinishReasonPermissionDenied, "Permission denied", "")
 646					break
 647				}
 648			}
 649			toolResults[i] = message.ToolResult{
 650				ToolCallID: toolCall.ID,
 651				Content:    toolResponse.Content,
 652				Metadata:   toolResponse.Metadata,
 653				IsError:    toolResponse.IsError,
 654			}
 655		}
 656	}
 657out:
 658	if len(toolResults) == 0 {
 659		return assistantMsg, nil, nil
 660	}
 661	parts := make([]message.ContentPart, 0)
 662	for _, tr := range toolResults {
 663		parts = append(parts, tr)
 664	}
 665	msg, err := a.messages.Create(context.Background(), assistantMsg.SessionID, message.CreateMessageParams{
 666		Role:     message.Tool,
 667		Parts:    parts,
 668		Provider: a.providerID,
 669	})
 670	if err != nil {
 671		return assistantMsg, nil, fmt.Errorf("failed to create cancelled tool message: %w", err)
 672	}
 673
 674	return assistantMsg, &msg, err
 675}
 676
 677func (a *agent) finishMessage(ctx context.Context, msg *message.Message, finishReason message.FinishReason, message, details string) {
 678	msg.AddFinish(finishReason, message, details)
 679	_ = a.messages.Update(ctx, *msg)
 680}
 681
 682func (a *agent) processEvent(ctx context.Context, sessionID string, assistantMsg *message.Message, event provider.ProviderEvent) error {
 683	select {
 684	case <-ctx.Done():
 685		return ctx.Err()
 686	default:
 687		// Continue processing.
 688	}
 689
 690	switch event.Type {
 691	case provider.EventThinkingDelta:
 692		assistantMsg.AppendReasoningContent(event.Thinking)
 693		return a.messages.Update(ctx, *assistantMsg)
 694	case provider.EventSignatureDelta:
 695		assistantMsg.AppendReasoningSignature(event.Signature)
 696		return a.messages.Update(ctx, *assistantMsg)
 697	case provider.EventContentDelta:
 698		assistantMsg.FinishThinking()
 699		assistantMsg.AppendContent(event.Content)
 700		return a.messages.Update(ctx, *assistantMsg)
 701	case provider.EventToolUseStart:
 702		assistantMsg.FinishThinking()
 703		slog.Info("Tool call started", "toolCall", event.ToolCall)
 704		assistantMsg.AddToolCall(*event.ToolCall)
 705		return a.messages.Update(ctx, *assistantMsg)
 706	case provider.EventToolUseDelta:
 707		assistantMsg.AppendToolCallInput(event.ToolCall.ID, event.ToolCall.Input)
 708		return a.messages.Update(ctx, *assistantMsg)
 709	case provider.EventToolUseStop:
 710		slog.Info("Finished tool call", "toolCall", event.ToolCall)
 711		assistantMsg.FinishToolCall(event.ToolCall.ID)
 712		return a.messages.Update(ctx, *assistantMsg)
 713	case provider.EventError:
 714		return event.Error
 715	case provider.EventComplete:
 716		assistantMsg.FinishThinking()
 717		assistantMsg.SetToolCalls(event.Response.ToolCalls)
 718		assistantMsg.AddFinish(event.Response.FinishReason, "", "")
 719		if err := a.messages.Update(ctx, *assistantMsg); err != nil {
 720			return fmt.Errorf("failed to update message: %w", err)
 721		}
 722		return a.TrackUsage(ctx, sessionID, a.Model(), event.Response.Usage)
 723	}
 724
 725	return nil
 726}
 727
 728func (a *agent) TrackUsage(ctx context.Context, sessionID string, model catwalk.Model, usage provider.TokenUsage) error {
 729	sess, err := a.sessions.Get(ctx, sessionID)
 730	if err != nil {
 731		return fmt.Errorf("failed to get session: %w", err)
 732	}
 733
 734	cost := model.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
 735		model.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
 736		model.CostPer1MIn/1e6*float64(usage.InputTokens) +
 737		model.CostPer1MOut/1e6*float64(usage.OutputTokens)
 738
 739	sess.Cost += cost
 740	sess.CompletionTokens = usage.OutputTokens + usage.CacheReadTokens
 741	sess.PromptTokens = usage.InputTokens + usage.CacheCreationTokens
 742
 743	_, err = a.sessions.Save(ctx, sess)
 744	if err != nil {
 745		return fmt.Errorf("failed to save session: %w", err)
 746	}
 747	return nil
 748}
 749
 750func (a *agent) Summarize(ctx context.Context, sessionID string) error {
 751	if a.summarizeProvider == nil {
 752		return fmt.Errorf("summarize provider not available")
 753	}
 754
 755	// Check if session is busy
 756	if a.IsSessionBusy(sessionID) {
 757		return ErrSessionBusy
 758	}
 759
 760	// Create a new context with cancellation
 761	summarizeCtx, cancel := context.WithCancel(ctx)
 762
 763	// Store the cancel function in activeRequests to allow cancellation
 764	a.activeRequests.Set(sessionID+"-summarize", cancel)
 765
 766	go func() {
 767		defer a.activeRequests.Del(sessionID + "-summarize")
 768		defer cancel()
 769		event := AgentEvent{
 770			Type:     AgentEventTypeSummarize,
 771			Progress: "Starting summarization...",
 772		}
 773
 774		a.Publish(pubsub.CreatedEvent, event)
 775		// Get all messages from the session
 776		msgs, err := a.messages.List(summarizeCtx, sessionID)
 777		if err != nil {
 778			event = AgentEvent{
 779				Type:  AgentEventTypeError,
 780				Error: fmt.Errorf("failed to list messages: %w", err),
 781				Done:  true,
 782			}
 783			a.Publish(pubsub.CreatedEvent, event)
 784			return
 785		}
 786		summarizeCtx = context.WithValue(summarizeCtx, tools.SessionIDContextKey, sessionID)
 787
 788		if len(msgs) == 0 {
 789			event = AgentEvent{
 790				Type:  AgentEventTypeError,
 791				Error: fmt.Errorf("no messages to summarize"),
 792				Done:  true,
 793			}
 794			a.Publish(pubsub.CreatedEvent, event)
 795			return
 796		}
 797
 798		event = AgentEvent{
 799			Type:     AgentEventTypeSummarize,
 800			Progress: "Analyzing conversation...",
 801		}
 802		a.Publish(pubsub.CreatedEvent, event)
 803
 804		// Add a system message to guide the summarization
 805		summarizePrompt := "Provide a detailed but concise summary of our conversation above. Focus on information that would be helpful for continuing the conversation, including what we did, what we're doing, which files we're working on, and what we're going to do next."
 806
 807		// Create a new message with the summarize prompt
 808		promptMsg := message.Message{
 809			Role:  message.User,
 810			Parts: []message.ContentPart{message.TextContent{Text: summarizePrompt}},
 811		}
 812
 813		// Append the prompt to the messages
 814		msgsWithPrompt := append(msgs, promptMsg)
 815
 816		event = AgentEvent{
 817			Type:     AgentEventTypeSummarize,
 818			Progress: "Generating summary...",
 819		}
 820
 821		a.Publish(pubsub.CreatedEvent, event)
 822
 823		// Send the messages to the summarize provider
 824		response := a.summarizeProvider.StreamResponse(
 825			summarizeCtx,
 826			msgsWithPrompt,
 827			nil,
 828		)
 829		var finalResponse *provider.ProviderResponse
 830		for r := range response {
 831			if r.Error != nil {
 832				event = AgentEvent{
 833					Type:  AgentEventTypeError,
 834					Error: fmt.Errorf("failed to summarize: %w", r.Error),
 835					Done:  true,
 836				}
 837				a.Publish(pubsub.CreatedEvent, event)
 838				return
 839			}
 840			finalResponse = r.Response
 841		}
 842
 843		summary := strings.TrimSpace(finalResponse.Content)
 844		if summary == "" {
 845			event = AgentEvent{
 846				Type:  AgentEventTypeError,
 847				Error: fmt.Errorf("empty summary returned"),
 848				Done:  true,
 849			}
 850			a.Publish(pubsub.CreatedEvent, event)
 851			return
 852		}
 853		shell := shell.GetPersistentShell(a.cfg.WorkingDir())
 854		summary += "\n\n**Current working directory of the persistent shell**\n\n" + shell.GetWorkingDir()
 855		event = AgentEvent{
 856			Type:     AgentEventTypeSummarize,
 857			Progress: "Creating new session...",
 858		}
 859
 860		a.Publish(pubsub.CreatedEvent, event)
 861		oldSession, err := a.sessions.Get(summarizeCtx, sessionID)
 862		if err != nil {
 863			event = AgentEvent{
 864				Type:  AgentEventTypeError,
 865				Error: fmt.Errorf("failed to get session: %w", err),
 866				Done:  true,
 867			}
 868
 869			a.Publish(pubsub.CreatedEvent, event)
 870			return
 871		}
 872		// Create a message in the new session with the summary
 873		msg, err := a.messages.Create(summarizeCtx, oldSession.ID, message.CreateMessageParams{
 874			Role: message.Assistant,
 875			Parts: []message.ContentPart{
 876				message.TextContent{Text: summary},
 877				message.Finish{
 878					Reason: message.FinishReasonEndTurn,
 879					Time:   time.Now().Unix(),
 880				},
 881			},
 882			Model:    a.summarizeProvider.Model().ID,
 883			Provider: a.summarizeProviderID,
 884		})
 885		if err != nil {
 886			event = AgentEvent{
 887				Type:  AgentEventTypeError,
 888				Error: fmt.Errorf("failed to create summary message: %w", err),
 889				Done:  true,
 890			}
 891
 892			a.Publish(pubsub.CreatedEvent, event)
 893			return
 894		}
 895		oldSession.SummaryMessageID = msg.ID
 896		oldSession.CompletionTokens = finalResponse.Usage.OutputTokens
 897		oldSession.PromptTokens = 0
 898		model := a.summarizeProvider.Model()
 899		usage := finalResponse.Usage
 900		cost := model.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
 901			model.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
 902			model.CostPer1MIn/1e6*float64(usage.InputTokens) +
 903			model.CostPer1MOut/1e6*float64(usage.OutputTokens)
 904		oldSession.Cost += cost
 905		_, err = a.sessions.Save(summarizeCtx, oldSession)
 906		if err != nil {
 907			event = AgentEvent{
 908				Type:  AgentEventTypeError,
 909				Error: fmt.Errorf("failed to save session: %w", err),
 910				Done:  true,
 911			}
 912			a.Publish(pubsub.CreatedEvent, event)
 913		}
 914
 915		event = AgentEvent{
 916			Type:      AgentEventTypeSummarize,
 917			SessionID: oldSession.ID,
 918			Progress:  "Summary complete",
 919			Done:      true,
 920		}
 921		a.Publish(pubsub.CreatedEvent, event)
 922		// Send final success event with the new session ID
 923	}()
 924
 925	return nil
 926}
 927
 928func (a *agent) ClearQueue(sessionID string) {
 929	if a.QueuedPrompts(sessionID) > 0 {
 930		slog.Info("Clearing queued prompts", "session_id", sessionID)
 931		a.promptQueue.Del(sessionID)
 932	}
 933}
 934
 935func (a *agent) CancelAll() {
 936	if !a.IsBusy() {
 937		return
 938	}
 939	for key := range a.activeRequests.Seq2() {
 940		a.Cancel(key) // key is sessionID
 941	}
 942
 943	timeout := time.After(5 * time.Second)
 944	for a.IsBusy() {
 945		select {
 946		case <-timeout:
 947			return
 948		default:
 949			time.Sleep(200 * time.Millisecond)
 950		}
 951	}
 952}
 953
 954func (a *agent) UpdateModel() error {
 955	// Get current provider configuration
 956	currentProviderCfg := a.cfg.GetProviderForModel(a.agentCfg.Model)
 957	if currentProviderCfg == nil || currentProviderCfg.ID == "" {
 958		return fmt.Errorf("provider for agent %s not found in config", a.agentCfg.Name)
 959	}
 960
 961	// Check if provider has changed
 962	if string(currentProviderCfg.ID) != a.providerID {
 963		// Provider changed, need to recreate the main provider
 964		model := a.cfg.GetModelByType(a.agentCfg.Model)
 965		if model.ID == "" {
 966			return fmt.Errorf("model not found for agent %s", a.agentCfg.Name)
 967		}
 968
 969		promptID := agentPromptMap[a.agentCfg.ID]
 970		if promptID == "" {
 971			promptID = prompt.PromptDefault
 972		}
 973
 974		opts := []provider.ProviderClientOption{
 975			provider.WithModel(a.agentCfg.Model),
 976			provider.WithSystemMessage(prompt.GetPrompt(a.cfg, promptID, currentProviderCfg.ID, a.cfg.Options.ContextPaths...)),
 977		}
 978
 979		newProvider, err := provider.NewProvider(a.cfg, *currentProviderCfg, opts...)
 980		if err != nil {
 981			return fmt.Errorf("failed to create new provider: %w", err)
 982		}
 983
 984		// Update the provider and provider ID
 985		a.provider = newProvider
 986		a.providerID = string(currentProviderCfg.ID)
 987	}
 988
 989	// Check if providers have changed for title (small) and summarize (large)
 990	smallModelCfg := a.cfg.Models[config.SelectedModelTypeSmall]
 991	var smallModelProviderCfg config.ProviderConfig
 992	for p := range a.cfg.Providers.Seq() {
 993		if p.ID == smallModelCfg.Provider {
 994			smallModelProviderCfg = p
 995			break
 996		}
 997	}
 998	if smallModelProviderCfg.ID == "" {
 999		return fmt.Errorf("provider %s not found in config", smallModelCfg.Provider)
1000	}
1001
1002	largeModelCfg := a.cfg.Models[config.SelectedModelTypeLarge]
1003	var largeModelProviderCfg config.ProviderConfig
1004	for p := range a.cfg.Providers.Seq() {
1005		if p.ID == largeModelCfg.Provider {
1006			largeModelProviderCfg = p
1007			break
1008		}
1009	}
1010	if largeModelProviderCfg.ID == "" {
1011		return fmt.Errorf("provider %s not found in config", largeModelCfg.Provider)
1012	}
1013
1014	var maxTitleTokens int64 = 40
1015
1016	// if the max output is too low for the gemini provider it won't return anything
1017	if smallModelCfg.Provider == "gemini" {
1018		maxTitleTokens = 1000
1019	}
1020	// Recreate title provider
1021	titleOpts := []provider.ProviderClientOption{
1022		provider.WithModel(config.SelectedModelTypeSmall),
1023		provider.WithSystemMessage(prompt.GetPrompt(a.cfg, prompt.PromptTitle, smallModelProviderCfg.ID)),
1024		provider.WithMaxTokens(maxTitleTokens),
1025	}
1026	newTitleProvider, err := provider.NewProvider(a.cfg, smallModelProviderCfg, titleOpts...)
1027	if err != nil {
1028		return fmt.Errorf("failed to create new title provider: %w", err)
1029	}
1030	a.titleProvider = newTitleProvider
1031
1032	// Recreate summarize provider if provider changed (now large model)
1033	if string(largeModelProviderCfg.ID) != a.summarizeProviderID {
1034		largeModel := a.cfg.GetModelByType(config.SelectedModelTypeLarge)
1035		if largeModel == nil {
1036			return fmt.Errorf("model %s not found in provider %s", largeModelCfg.Model, largeModelProviderCfg.ID)
1037		}
1038		summarizeOpts := []provider.ProviderClientOption{
1039			provider.WithModel(config.SelectedModelTypeLarge),
1040			provider.WithSystemMessage(prompt.GetPrompt(a.cfg, prompt.PromptSummarizer, largeModelProviderCfg.ID)),
1041		}
1042		newSummarizeProvider, err := provider.NewProvider(a.cfg, largeModelProviderCfg, summarizeOpts...)
1043		if err != nil {
1044			return fmt.Errorf("failed to create new summarize provider: %w", err)
1045		}
1046		a.summarizeProvider = newSummarizeProvider
1047		a.summarizeProviderID = string(largeModelProviderCfg.ID)
1048	}
1049
1050	return nil
1051}