1package agent
   2
   3import (
   4	"context"
   5	"errors"
   6	"fmt"
   7	"log/slog"
   8	"slices"
   9	"strings"
  10	"time"
  11
  12	"github.com/charmbracelet/catwalk/pkg/catwalk"
  13	"github.com/charmbracelet/crush/internal/config"
  14	"github.com/charmbracelet/crush/internal/csync"
  15	"github.com/charmbracelet/crush/internal/event"
  16	"github.com/charmbracelet/crush/internal/history"
  17	"github.com/charmbracelet/crush/internal/llm/prompt"
  18	"github.com/charmbracelet/crush/internal/llm/provider"
  19	"github.com/charmbracelet/crush/internal/llm/tools"
  20	"github.com/charmbracelet/crush/internal/log"
  21	"github.com/charmbracelet/crush/internal/lsp"
  22	"github.com/charmbracelet/crush/internal/message"
  23	"github.com/charmbracelet/crush/internal/permission"
  24	"github.com/charmbracelet/crush/internal/pubsub"
  25	"github.com/charmbracelet/crush/internal/session"
  26	"github.com/charmbracelet/crush/internal/shell"
  27)
  28
  29const streamChunkTimeout = 2 * time.Minute
  30
  31type AgentEventType string
  32
  33const (
  34	AgentEventTypeError     AgentEventType = "error"
  35	AgentEventTypeResponse  AgentEventType = "response"
  36	AgentEventTypeSummarize AgentEventType = "summarize"
  37)
  38
// AgentEvent is published through the agent's pubsub broker and, for Run, is
// also delivered on the channel Run returns. Type selects which fields are
// meaningful: Message for AgentEventTypeResponse, Error for
// AgentEventTypeError, Progress/SessionID for AgentEventTypeSummarize.
type AgentEvent struct {
	Type    AgentEventType
	Message message.Message
	Error   error

	// SessionID and Progress are filled by Summarize; Done marks the final
	// event of a Run response, of an error, or of a summarization.
	SessionID string
	Progress  string
	Done      bool
}
  49
  50type Service interface {
  51	pubsub.Suscriber[AgentEvent]
  52	Model() catwalk.Model
  53	Run(ctx context.Context, sessionID string, content string, attachments ...message.Attachment) (<-chan AgentEvent, error)
  54	Cancel(sessionID string)
  55	CancelAll()
  56	IsSessionBusy(sessionID string) bool
  57	IsBusy() bool
  58	Summarize(ctx context.Context, sessionID string) error
  59	UpdateModel() error
  60	QueuedPrompts(sessionID string) int
  61	ClearQueue(sessionID string)
  62}
  63
// agent is the Service implementation returned by NewAgent.
//
// NOTE(review): NewAgent does not set the mcpTools field; the tool builder in
// NewAgent reads and writes the package-level mcpTools instead — confirm
// whether this field is still needed.
type agent struct {
	*pubsub.Broker[AgentEvent]
	agentCfg    config.Agent
	sessions    session.Service
	messages    message.Service
	permissions permission.Service
	mcpTools    []McpTool

	tools *csync.LazySlice[tools.BaseTool]
	// We need this to be able to update it when model changes
	agentToolFn func() (tools.BaseTool, error)

	// Main provider for agentCfg.Model and the ID of its provider config.
	provider   provider.Provider
	providerID string

	// Auxiliary providers used by generateTitle and Summarize.
	titleProvider       provider.Provider
	summarizeProvider   provider.Provider
	summarizeProviderID string

	// activeRequests maps sessionID (Run) or sessionID+"-summarize"
	// (Summarize) to the cancel func of the in-flight request; promptQueue
	// holds prompts submitted through Run while the session was busy.
	activeRequests *csync.Map[string, context.CancelFunc]
	promptQueue    *csync.Map[string, []string]
}
  86
  87var agentPromptMap = map[string]prompt.PromptID{
  88	"coder": prompt.PromptCoder,
  89	"task":  prompt.PromptTask,
  90}
  91
  92func NewAgent(
  93	ctx context.Context,
  94	agentCfg config.Agent,
  95	// These services are needed in the tools
  96	permissions permission.Service,
  97	sessions session.Service,
  98	messages message.Service,
  99	history history.Service,
 100	lspClients *csync.Map[string, *lsp.Client],
 101) (Service, error) {
 102	cfg := config.Get()
 103
 104	var agentToolFn func() (tools.BaseTool, error)
 105	if agentCfg.ID == "coder" && slices.Contains(agentCfg.AllowedTools, AgentToolName) {
 106		agentToolFn = func() (tools.BaseTool, error) {
 107			taskAgentCfg := config.Get().Agents["task"]
 108			if taskAgentCfg.ID == "" {
 109				return nil, fmt.Errorf("task agent not found in config")
 110			}
 111			taskAgent, err := NewAgent(ctx, taskAgentCfg, permissions, sessions, messages, history, lspClients)
 112			if err != nil {
 113				return nil, fmt.Errorf("failed to create task agent: %w", err)
 114			}
 115			return NewAgentTool(taskAgent, sessions, messages), nil
 116		}
 117	}
 118
 119	providerCfg := config.Get().GetProviderForModel(agentCfg.Model)
 120	if providerCfg == nil {
 121		return nil, fmt.Errorf("provider for agent %s not found in config", agentCfg.Name)
 122	}
 123	model := config.Get().GetModelByType(agentCfg.Model)
 124
 125	if model == nil {
 126		return nil, fmt.Errorf("model not found for agent %s", agentCfg.Name)
 127	}
 128
 129	promptID := agentPromptMap[agentCfg.ID]
 130	if promptID == "" {
 131		promptID = prompt.PromptDefault
 132	}
 133	opts := []provider.ProviderClientOption{
 134		provider.WithModel(agentCfg.Model),
 135		provider.WithSystemMessage(prompt.GetPrompt(promptID, providerCfg.ID, config.Get().Options.ContextPaths...)),
 136	}
 137	agentProvider, err := provider.NewProvider(*providerCfg, opts...)
 138	if err != nil {
 139		return nil, err
 140	}
 141
 142	smallModelCfg := cfg.Models[config.SelectedModelTypeSmall]
 143	var smallModelProviderCfg *config.ProviderConfig
 144	if smallModelCfg.Provider == providerCfg.ID {
 145		smallModelProviderCfg = providerCfg
 146	} else {
 147		smallModelProviderCfg = cfg.GetProviderForModel(config.SelectedModelTypeSmall)
 148
 149		if smallModelProviderCfg.ID == "" {
 150			return nil, fmt.Errorf("provider %s not found in config", smallModelCfg.Provider)
 151		}
 152	}
 153	smallModel := cfg.GetModelByType(config.SelectedModelTypeSmall)
 154	if smallModel.ID == "" {
 155		return nil, fmt.Errorf("model %s not found in provider %s", smallModelCfg.Model, smallModelProviderCfg.ID)
 156	}
 157
 158	titleOpts := []provider.ProviderClientOption{
 159		provider.WithModel(config.SelectedModelTypeSmall),
 160		provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptTitle, smallModelProviderCfg.ID)),
 161	}
 162	titleProvider, err := provider.NewProvider(*smallModelProviderCfg, titleOpts...)
 163	if err != nil {
 164		return nil, err
 165	}
 166
 167	summarizeOpts := []provider.ProviderClientOption{
 168		provider.WithModel(config.SelectedModelTypeLarge),
 169		provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptSummarizer, providerCfg.ID)),
 170	}
 171	summarizeProvider, err := provider.NewProvider(*providerCfg, summarizeOpts...)
 172	if err != nil {
 173		return nil, err
 174	}
 175
 176	toolFn := func() []tools.BaseTool {
 177		slog.Info("Initializing agent tools", "agent", agentCfg.ID)
 178		defer func() {
 179			slog.Info("Initialized agent tools", "agent", agentCfg.ID)
 180		}()
 181
 182		cwd := cfg.WorkingDir()
 183		allTools := []tools.BaseTool{
 184			tools.NewBashTool(permissions, cwd, cfg.Options.Attribution),
 185			tools.NewDownloadTool(permissions, cwd),
 186			tools.NewEditTool(lspClients, permissions, history, cwd),
 187			tools.NewMultiEditTool(lspClients, permissions, history, cwd),
 188			tools.NewFetchTool(permissions, cwd),
 189			tools.NewGlobTool(cwd),
 190			tools.NewGrepTool(cwd),
 191			tools.NewLsTool(permissions, cwd),
 192			tools.NewSourcegraphTool(),
 193			tools.NewViewTool(lspClients, permissions, cwd),
 194			tools.NewWriteTool(lspClients, permissions, history, cwd),
 195		}
 196
 197		mcpToolsOnce.Do(func() {
 198			mcpTools = doGetMCPTools(ctx, permissions, cfg)
 199		})
 200
 201		withCoderTools := func(t []tools.BaseTool) []tools.BaseTool {
 202			if agentCfg.ID == "coder" {
 203				t = append(t, mcpTools...)
 204				if lspClients.Len() > 0 {
 205					t = append(t, tools.NewDiagnosticsTool(lspClients))
 206				}
 207			}
 208			return t
 209		}
 210
 211		if agentCfg.AllowedTools == nil {
 212			return withCoderTools(allTools)
 213		}
 214
 215		var filteredTools []tools.BaseTool
 216		for _, tool := range allTools {
 217			if slices.Contains(agentCfg.AllowedTools, tool.Name()) {
 218				filteredTools = append(filteredTools, tool)
 219			}
 220		}
 221		return withCoderTools(filteredTools)
 222	}
 223
 224	return &agent{
 225		Broker:              pubsub.NewBroker[AgentEvent](),
 226		agentCfg:            agentCfg,
 227		provider:            agentProvider,
 228		providerID:          string(providerCfg.ID),
 229		messages:            messages,
 230		sessions:            sessions,
 231		titleProvider:       titleProvider,
 232		summarizeProvider:   summarizeProvider,
 233		summarizeProviderID: string(providerCfg.ID),
 234		agentToolFn:         agentToolFn,
 235		activeRequests:      csync.NewMap[string, context.CancelFunc](),
 236		tools:               csync.NewLazySlice(toolFn),
 237		promptQueue:         csync.NewMap[string, []string](),
 238		permissions:         permissions,
 239	}, nil
 240}
 241
 242func (a *agent) Model() catwalk.Model {
 243	return *config.Get().GetModelByType(a.agentCfg.Model)
 244}
 245
 246func (a *agent) Cancel(sessionID string) {
 247	// Cancel regular requests
 248	if cancel, ok := a.activeRequests.Take(sessionID); ok && cancel != nil {
 249		slog.Info("Request cancellation initiated", "session_id", sessionID)
 250		cancel()
 251	}
 252
 253	// Also check for summarize requests
 254	if cancel, ok := a.activeRequests.Take(sessionID + "-summarize"); ok && cancel != nil {
 255		slog.Info("Summarize cancellation initiated", "session_id", sessionID)
 256		cancel()
 257	}
 258
 259	if a.QueuedPrompts(sessionID) > 0 {
 260		slog.Info("Clearing queued prompts", "session_id", sessionID)
 261		a.promptQueue.Del(sessionID)
 262	}
 263}
 264
 265func (a *agent) IsBusy() bool {
 266	var busy bool
 267	for cancelFunc := range a.activeRequests.Seq() {
 268		if cancelFunc != nil {
 269			busy = true
 270			break
 271		}
 272	}
 273	return busy
 274}
 275
 276func (a *agent) IsSessionBusy(sessionID string) bool {
 277	_, busy := a.activeRequests.Get(sessionID)
 278	return busy
 279}
 280
 281func (a *agent) QueuedPrompts(sessionID string) int {
 282	l, ok := a.promptQueue.Get(sessionID)
 283	if !ok {
 284		return 0
 285	}
 286	return len(l)
 287}
 288
 289func (a *agent) generateTitle(ctx context.Context, sessionID string, content string) error {
 290	if content == "" {
 291		return nil
 292	}
 293	if a.titleProvider == nil {
 294		return nil
 295	}
 296	session, err := a.sessions.Get(ctx, sessionID)
 297	if err != nil {
 298		return err
 299	}
 300	parts := []message.ContentPart{message.TextContent{
 301		Text: fmt.Sprintf("Generate a concise title for the following content:\n\n%s", content),
 302	}}
 303
 304	// Use streaming approach like summarization
 305	response := a.titleProvider.StreamResponse(
 306		ctx,
 307		[]message.Message{
 308			{
 309				Role:  message.User,
 310				Parts: parts,
 311			},
 312		},
 313		nil,
 314	)
 315
 316	var finalResponse *provider.ProviderResponse
 317	for r := range response {
 318		if r.Error != nil {
 319			return r.Error
 320		}
 321		finalResponse = r.Response
 322	}
 323
 324	if finalResponse == nil {
 325		return fmt.Errorf("no response received from title provider")
 326	}
 327
 328	title := strings.ReplaceAll(finalResponse.Content, "\n", " ")
 329
 330	if idx := strings.Index(title, "</think>"); idx > 0 {
 331		title = title[idx+len("</think>"):]
 332	}
 333
 334	title = strings.TrimSpace(title)
 335	if title == "" {
 336		return nil
 337	}
 338
 339	session.Title = title
 340	_, err = a.sessions.Save(ctx, session)
 341	return err
 342}
 343
 344func (a *agent) err(err error) AgentEvent {
 345	return AgentEvent{
 346		Type:  AgentEventTypeError,
 347		Error: err,
 348	}
 349}
 350
 351func (a *agent) Run(ctx context.Context, sessionID string, content string, attachments ...message.Attachment) (<-chan AgentEvent, error) {
 352	if !a.Model().SupportsImages && attachments != nil {
 353		attachments = nil
 354	}
 355	events := make(chan AgentEvent, 1)
 356	if a.IsSessionBusy(sessionID) {
 357		existing, ok := a.promptQueue.Get(sessionID)
 358		if !ok {
 359			existing = []string{}
 360		}
 361		existing = append(existing, content)
 362		a.promptQueue.Set(sessionID, existing)
 363		return nil, nil
 364	}
 365
 366	genCtx, cancel := context.WithCancel(ctx)
 367	a.activeRequests.Set(sessionID, cancel)
 368	startTime := time.Now()
 369
 370	go func() {
 371		slog.Debug("Request started", "sessionID", sessionID)
 372		defer log.RecoverPanic("agent.Run", func() {
 373			events <- a.err(fmt.Errorf("panic while running the agent"))
 374		})
 375		var attachmentParts []message.ContentPart
 376		for _, attachment := range attachments {
 377			attachmentParts = append(attachmentParts, message.BinaryContent{Path: attachment.FilePath, MIMEType: attachment.MimeType, Data: attachment.Content})
 378		}
 379		result := a.processGeneration(genCtx, sessionID, content, attachmentParts)
 380		if result.Error != nil {
 381			if isCancelledErr(result.Error) {
 382				slog.Error("Request canceled", "sessionID", sessionID)
 383			} else {
 384				slog.Error("Request errored", "sessionID", sessionID, "error", result.Error.Error())
 385				event.Error(result.Error)
 386			}
 387		} else {
 388			slog.Debug("Request completed", "sessionID", sessionID)
 389		}
 390		a.eventPromptResponded(sessionID, time.Since(startTime).Truncate(time.Second))
 391		a.activeRequests.Del(sessionID)
 392		cancel()
 393		a.Publish(pubsub.CreatedEvent, result)
 394		events <- result
 395		close(events)
 396	}()
 397	a.eventPromptSent(sessionID)
 398	return events, nil
 399}
 400
 401func (a *agent) processGeneration(ctx context.Context, sessionID, content string, attachmentParts []message.ContentPart) AgentEvent {
 402	cfg := config.Get()
 403	// List existing messages; if none, start title generation asynchronously.
 404	msgs, err := a.messages.List(ctx, sessionID)
 405	if err != nil {
 406		return a.err(fmt.Errorf("failed to list messages: %w", err))
 407	}
 408	if len(msgs) == 0 {
 409		go func() {
 410			defer log.RecoverPanic("agent.Run", func() {
 411				slog.Error("panic while generating title")
 412			})
 413			titleErr := a.generateTitle(ctx, sessionID, content)
 414			if titleErr != nil && !errors.Is(titleErr, context.Canceled) && !errors.Is(titleErr, context.DeadlineExceeded) {
 415				slog.Error("failed to generate title", "error", titleErr)
 416			}
 417		}()
 418	}
 419	session, err := a.sessions.Get(ctx, sessionID)
 420	if err != nil {
 421		return a.err(fmt.Errorf("failed to get session: %w", err))
 422	}
 423	if session.SummaryMessageID != "" {
 424		summaryMsgInex := -1
 425		for i, msg := range msgs {
 426			if msg.ID == session.SummaryMessageID {
 427				summaryMsgInex = i
 428				break
 429			}
 430		}
 431		if summaryMsgInex != -1 {
 432			msgs = msgs[summaryMsgInex:]
 433			msgs[0].Role = message.User
 434		}
 435	}
 436
 437	userMsg, err := a.createUserMessage(ctx, sessionID, content, attachmentParts)
 438	if err != nil {
 439		return a.err(fmt.Errorf("failed to create user message: %w", err))
 440	}
 441	// Append the new user message to the conversation history.
 442	msgHistory := append(msgs, userMsg)
 443
 444	for {
 445		// Check for cancellation before each iteration
 446		select {
 447		case <-ctx.Done():
 448			return a.err(ctx.Err())
 449		default:
 450			// Continue processing
 451		}
 452		agentMessage, toolResults, err := a.streamAndHandleEvents(ctx, sessionID, msgHistory)
 453		if err != nil {
 454			if errors.Is(err, context.Canceled) {
 455				agentMessage.AddFinish(message.FinishReasonCanceled, "Request cancelled", "")
 456				a.messages.Update(context.Background(), agentMessage)
 457				return a.err(ErrRequestCancelled)
 458			}
 459			return a.err(fmt.Errorf("failed to process events: %w", err))
 460		}
 461		if cfg.Options.Debug {
 462			slog.Info("Result", "message", agentMessage.FinishReason(), "toolResults", toolResults)
 463		}
 464		if (agentMessage.FinishReason() == message.FinishReasonToolUse) && toolResults != nil {
 465			// We are not done, we need to respond with the tool response
 466			msgHistory = append(msgHistory, agentMessage, *toolResults)
 467			// If there are queued prompts, process the next one
 468			nextPrompt, ok := a.promptQueue.Take(sessionID)
 469			if ok {
 470				for _, prompt := range nextPrompt {
 471					// Create a new user message for the queued prompt
 472					userMsg, err := a.createUserMessage(ctx, sessionID, prompt, nil)
 473					if err != nil {
 474						return a.err(fmt.Errorf("failed to create user message for queued prompt: %w", err))
 475					}
 476					// Append the new user message to the conversation history
 477					msgHistory = append(msgHistory, userMsg)
 478				}
 479			}
 480
 481			continue
 482		} else if agentMessage.FinishReason() == message.FinishReasonEndTurn {
 483			queuePrompts, ok := a.promptQueue.Take(sessionID)
 484			if ok {
 485				for _, prompt := range queuePrompts {
 486					if prompt == "" {
 487						continue
 488					}
 489					userMsg, err := a.createUserMessage(ctx, sessionID, prompt, nil)
 490					if err != nil {
 491						return a.err(fmt.Errorf("failed to create user message for queued prompt: %w", err))
 492					}
 493					msgHistory = append(msgHistory, userMsg)
 494				}
 495				continue
 496			}
 497		}
 498		if agentMessage.FinishReason() == "" {
 499			// Kujtim: could not track down where this is happening but this means its cancelled
 500			agentMessage.AddFinish(message.FinishReasonCanceled, "Request cancelled", "")
 501			_ = a.messages.Update(context.Background(), agentMessage)
 502			return a.err(ErrRequestCancelled)
 503		}
 504		return AgentEvent{
 505			Type:    AgentEventTypeResponse,
 506			Message: agentMessage,
 507			Done:    true,
 508		}
 509	}
 510}
 511
 512func (a *agent) createUserMessage(ctx context.Context, sessionID, content string, attachmentParts []message.ContentPart) (message.Message, error) {
 513	parts := []message.ContentPart{message.TextContent{Text: content}}
 514	parts = append(parts, attachmentParts...)
 515	return a.messages.Create(ctx, sessionID, message.CreateMessageParams{
 516		Role:  message.User,
 517		Parts: parts,
 518	})
 519}
 520
 521func (a *agent) getAllTools() ([]tools.BaseTool, error) {
 522	allTools := slices.Collect(a.tools.Seq())
 523	if a.agentToolFn != nil {
 524		agentTool, agentToolErr := a.agentToolFn()
 525		if agentToolErr != nil {
 526			return nil, agentToolErr
 527		}
 528		allTools = append(allTools, agentTool)
 529	}
 530	return allTools, nil
 531}
 532
 533func (a *agent) streamAndHandleEvents(ctx context.Context, sessionID string, msgHistory []message.Message) (message.Message, *message.Message, error) {
 534	ctx = context.WithValue(ctx, tools.SessionIDContextKey, sessionID)
 535
 536	// Create the assistant message first so the spinner shows immediately
 537	assistantMsg, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{
 538		Role:     message.Assistant,
 539		Parts:    []message.ContentPart{},
 540		Model:    a.Model().ID,
 541		Provider: a.providerID,
 542	})
 543	if err != nil {
 544		return assistantMsg, nil, fmt.Errorf("failed to create assistant message: %w", err)
 545	}
 546
 547	allTools, toolsErr := a.getAllTools()
 548	if toolsErr != nil {
 549		return assistantMsg, nil, toolsErr
 550	}
 551	// Now collect tools (which may block on MCP initialization)
 552	eventChan := a.provider.StreamResponse(ctx, msgHistory, allTools)
 553
 554	// Add the session and message ID into the context if needed by tools.
 555	ctx = context.WithValue(ctx, tools.MessageIDContextKey, assistantMsg.ID)
 556
 557	// Process each event in the stream.
 558	timer := time.NewTimer(streamChunkTimeout)
 559	defer timer.Stop()
 560
 561loop:
 562	for {
 563		select {
 564		case event, ok := <-eventChan:
 565			if !ok {
 566				break loop
 567			}
 568			// Reset the timeout timer since we received a chunk
 569			timer.Reset(streamChunkTimeout)
 570
 571			if processErr := a.processEvent(ctx, sessionID, &assistantMsg, event); processErr != nil {
 572				if errors.Is(processErr, context.Canceled) {
 573					a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
 574				} else {
 575					a.finishMessage(ctx, &assistantMsg, message.FinishReasonError, "API Error", processErr.Error())
 576				}
 577				return assistantMsg, nil, processErr
 578			}
 579		case <-timer.C:
 580			a.finishMessage(ctx, &assistantMsg, message.FinishReasonError, "Stream timeout", "No chunk received within timeout")
 581			return assistantMsg, nil, ErrStreamTimeout
 582		case <-ctx.Done():
 583			a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
 584			return assistantMsg, nil, ctx.Err()
 585		}
 586	}
 587
 588	toolResults := make([]message.ToolResult, len(assistantMsg.ToolCalls()))
 589	toolCalls := assistantMsg.ToolCalls()
 590	for i, toolCall := range toolCalls {
 591		select {
 592		case <-ctx.Done():
 593			a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
 594			// Make all future tool calls cancelled
 595			for j := i; j < len(toolCalls); j++ {
 596				toolResults[j] = message.ToolResult{
 597					ToolCallID: toolCalls[j].ID,
 598					Content:    "Tool execution canceled by user",
 599					IsError:    true,
 600				}
 601			}
 602			goto out
 603		default:
 604			// Continue processing
 605			var tool tools.BaseTool
 606			allTools, _ := a.getAllTools()
 607			for _, availableTool := range allTools {
 608				if availableTool.Info().Name == toolCall.Name {
 609					tool = availableTool
 610					break
 611				}
 612			}
 613
 614			// Tool not found
 615			if tool == nil {
 616				toolResults[i] = message.ToolResult{
 617					ToolCallID: toolCall.ID,
 618					Content:    fmt.Sprintf("Tool not found: %s", toolCall.Name),
 619					IsError:    true,
 620				}
 621				continue
 622			}
 623
 624			// Run tool in goroutine to allow cancellation
 625			type toolExecResult struct {
 626				response tools.ToolResponse
 627				err      error
 628			}
 629			resultChan := make(chan toolExecResult, 1)
 630
 631			go func() {
 632				response, err := tool.Run(ctx, tools.ToolCall{
 633					ID:    toolCall.ID,
 634					Name:  toolCall.Name,
 635					Input: toolCall.Input,
 636				})
 637				resultChan <- toolExecResult{response: response, err: err}
 638			}()
 639
 640			var toolResponse tools.ToolResponse
 641			var toolErr error
 642
 643			select {
 644			case <-ctx.Done():
 645				a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
 646				// Mark remaining tool calls as cancelled
 647				for j := i; j < len(toolCalls); j++ {
 648					toolResults[j] = message.ToolResult{
 649						ToolCallID: toolCalls[j].ID,
 650						Content:    "Tool execution canceled by user",
 651						IsError:    true,
 652					}
 653				}
 654				goto out
 655			case result := <-resultChan:
 656				toolResponse = result.response
 657				toolErr = result.err
 658			}
 659
 660			if toolErr != nil {
 661				slog.Error("Tool execution error", "toolCall", toolCall.ID, "error", toolErr)
 662				if errors.Is(toolErr, permission.ErrorPermissionDenied) {
 663					toolResults[i] = message.ToolResult{
 664						ToolCallID: toolCall.ID,
 665						Content:    "Permission denied",
 666						IsError:    true,
 667					}
 668					for j := i + 1; j < len(toolCalls); j++ {
 669						toolResults[j] = message.ToolResult{
 670							ToolCallID: toolCalls[j].ID,
 671							Content:    "Tool execution canceled by user",
 672							IsError:    true,
 673						}
 674					}
 675					a.finishMessage(ctx, &assistantMsg, message.FinishReasonPermissionDenied, "Permission denied", "")
 676					break
 677				}
 678			}
 679			toolResults[i] = message.ToolResult{
 680				ToolCallID: toolCall.ID,
 681				Content:    toolResponse.Content,
 682				Metadata:   toolResponse.Metadata,
 683				IsError:    toolResponse.IsError,
 684			}
 685		}
 686	}
 687out:
 688	if len(toolResults) == 0 {
 689		return assistantMsg, nil, nil
 690	}
 691	parts := make([]message.ContentPart, 0)
 692	for _, tr := range toolResults {
 693		parts = append(parts, tr)
 694	}
 695	msg, err := a.messages.Create(context.Background(), assistantMsg.SessionID, message.CreateMessageParams{
 696		Role:     message.Tool,
 697		Parts:    parts,
 698		Provider: a.providerID,
 699	})
 700	if err != nil {
 701		return assistantMsg, nil, fmt.Errorf("failed to create cancelled tool message: %w", err)
 702	}
 703
 704	return assistantMsg, &msg, err
 705}
 706
 707func (a *agent) finishMessage(ctx context.Context, msg *message.Message, finishReason message.FinishReason, message, details string) {
 708	msg.AddFinish(finishReason, message, details)
 709	_ = a.messages.Update(ctx, *msg)
 710}
 711
 712func (a *agent) processEvent(ctx context.Context, sessionID string, assistantMsg *message.Message, event provider.ProviderEvent) error {
 713	select {
 714	case <-ctx.Done():
 715		return ctx.Err()
 716	default:
 717		// Continue processing.
 718	}
 719
 720	switch event.Type {
 721	case provider.EventThinkingDelta:
 722		assistantMsg.AppendReasoningContent(event.Thinking)
 723		return a.messages.Update(ctx, *assistantMsg)
 724	case provider.EventSignatureDelta:
 725		assistantMsg.AppendReasoningSignature(event.Signature)
 726		return a.messages.Update(ctx, *assistantMsg)
 727	case provider.EventContentDelta:
 728		assistantMsg.FinishThinking()
 729		assistantMsg.AppendContent(event.Content)
 730		return a.messages.Update(ctx, *assistantMsg)
 731	case provider.EventToolUseStart:
 732		assistantMsg.FinishThinking()
 733		slog.Info("Tool call started", "toolCall", event.ToolCall)
 734		assistantMsg.AddToolCall(*event.ToolCall)
 735		return a.messages.Update(ctx, *assistantMsg)
 736	case provider.EventToolUseDelta:
 737		assistantMsg.AppendToolCallInput(event.ToolCall.ID, event.ToolCall.Input)
 738		return a.messages.Update(ctx, *assistantMsg)
 739	case provider.EventToolUseStop:
 740		slog.Info("Finished tool call", "toolCall", event.ToolCall)
 741		assistantMsg.FinishToolCall(event.ToolCall.ID)
 742		return a.messages.Update(ctx, *assistantMsg)
 743	case provider.EventError:
 744		return event.Error
 745	case provider.EventComplete:
 746		assistantMsg.FinishThinking()
 747		assistantMsg.SetToolCalls(event.Response.ToolCalls)
 748		assistantMsg.AddFinish(event.Response.FinishReason, "", "")
 749		if err := a.messages.Update(ctx, *assistantMsg); err != nil {
 750			return fmt.Errorf("failed to update message: %w", err)
 751		}
 752		return a.trackUsage(ctx, sessionID, a.Model(), event.Response.Usage)
 753	}
 754
 755	return nil
 756}
 757
 758func (a *agent) trackUsage(ctx context.Context, sessionID string, model catwalk.Model, usage provider.TokenUsage) error {
 759	sess, err := a.sessions.Get(ctx, sessionID)
 760	if err != nil {
 761		return fmt.Errorf("failed to get session: %w", err)
 762	}
 763
 764	cost := model.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
 765		model.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
 766		model.CostPer1MIn/1e6*float64(usage.InputTokens) +
 767		model.CostPer1MOut/1e6*float64(usage.OutputTokens)
 768
 769	a.eventTokensUsed(sessionID, usage, cost)
 770
 771	sess.Cost += cost
 772	sess.CompletionTokens = usage.OutputTokens + usage.CacheReadTokens
 773	sess.PromptTokens = usage.InputTokens + usage.CacheCreationTokens
 774
 775	_, err = a.sessions.Save(ctx, sess)
 776	if err != nil {
 777		return fmt.Errorf("failed to save session: %w", err)
 778	}
 779	return nil
 780}
 781
 782func (a *agent) Summarize(ctx context.Context, sessionID string) error {
 783	if a.summarizeProvider == nil {
 784		return fmt.Errorf("summarize provider not available")
 785	}
 786
 787	// Check if session is busy
 788	if a.IsSessionBusy(sessionID) {
 789		return ErrSessionBusy
 790	}
 791
 792	// Create a new context with cancellation
 793	summarizeCtx, cancel := context.WithCancel(ctx)
 794
 795	// Store the cancel function in activeRequests to allow cancellation
 796	a.activeRequests.Set(sessionID+"-summarize", cancel)
 797
 798	go func() {
 799		defer a.activeRequests.Del(sessionID + "-summarize")
 800		defer cancel()
 801		event := AgentEvent{
 802			Type:     AgentEventTypeSummarize,
 803			Progress: "Starting summarization...",
 804		}
 805
 806		a.Publish(pubsub.CreatedEvent, event)
 807		// Get all messages from the session
 808		msgs, err := a.messages.List(summarizeCtx, sessionID)
 809		if err != nil {
 810			event = AgentEvent{
 811				Type:  AgentEventTypeError,
 812				Error: fmt.Errorf("failed to list messages: %w", err),
 813				Done:  true,
 814			}
 815			a.Publish(pubsub.CreatedEvent, event)
 816			return
 817		}
 818		summarizeCtx = context.WithValue(summarizeCtx, tools.SessionIDContextKey, sessionID)
 819
 820		if len(msgs) == 0 {
 821			event = AgentEvent{
 822				Type:  AgentEventTypeError,
 823				Error: fmt.Errorf("no messages to summarize"),
 824				Done:  true,
 825			}
 826			a.Publish(pubsub.CreatedEvent, event)
 827			return
 828		}
 829
 830		event = AgentEvent{
 831			Type:     AgentEventTypeSummarize,
 832			Progress: "Analyzing conversation...",
 833		}
 834		a.Publish(pubsub.CreatedEvent, event)
 835
 836		// Add a system message to guide the summarization
 837		summarizePrompt := "Provide a detailed but concise summary of our conversation above. Focus on information that would be helpful for continuing the conversation, including what we did, what we're doing, which files we're working on, and what we're going to do next."
 838
 839		// Create a new message with the summarize prompt
 840		promptMsg := message.Message{
 841			Role:  message.User,
 842			Parts: []message.ContentPart{message.TextContent{Text: summarizePrompt}},
 843		}
 844
 845		// Append the prompt to the messages
 846		msgsWithPrompt := append(msgs, promptMsg)
 847
 848		event = AgentEvent{
 849			Type:     AgentEventTypeSummarize,
 850			Progress: "Generating summary...",
 851		}
 852
 853		a.Publish(pubsub.CreatedEvent, event)
 854
 855		// Send the messages to the summarize provider
 856		response := a.summarizeProvider.StreamResponse(
 857			summarizeCtx,
 858			msgsWithPrompt,
 859			nil,
 860		)
 861		var finalResponse *provider.ProviderResponse
 862		for r := range response {
 863			if r.Error != nil {
 864				event = AgentEvent{
 865					Type:  AgentEventTypeError,
 866					Error: fmt.Errorf("failed to summarize: %w", r.Error),
 867					Done:  true,
 868				}
 869				a.Publish(pubsub.CreatedEvent, event)
 870				return
 871			}
 872			finalResponse = r.Response
 873		}
 874
 875		summary := strings.TrimSpace(finalResponse.Content)
 876		if summary == "" {
 877			event = AgentEvent{
 878				Type:  AgentEventTypeError,
 879				Error: fmt.Errorf("empty summary returned"),
 880				Done:  true,
 881			}
 882			a.Publish(pubsub.CreatedEvent, event)
 883			return
 884		}
 885		shell := shell.GetPersistentShell(config.Get().WorkingDir())
 886		summary += "\n\n**Current working directory of the persistent shell**\n\n" + shell.GetWorkingDir()
 887		event = AgentEvent{
 888			Type:     AgentEventTypeSummarize,
 889			Progress: "Creating new session...",
 890		}
 891
 892		a.Publish(pubsub.CreatedEvent, event)
 893		oldSession, err := a.sessions.Get(summarizeCtx, sessionID)
 894		if err != nil {
 895			event = AgentEvent{
 896				Type:  AgentEventTypeError,
 897				Error: fmt.Errorf("failed to get session: %w", err),
 898				Done:  true,
 899			}
 900
 901			a.Publish(pubsub.CreatedEvent, event)
 902			return
 903		}
 904		// Create a message in the new session with the summary
 905		msg, err := a.messages.Create(summarizeCtx, oldSession.ID, message.CreateMessageParams{
 906			Role: message.Assistant,
 907			Parts: []message.ContentPart{
 908				message.TextContent{Text: summary},
 909				message.Finish{
 910					Reason: message.FinishReasonEndTurn,
 911					Time:   time.Now().Unix(),
 912				},
 913			},
 914			Model:    a.summarizeProvider.Model().ID,
 915			Provider: a.summarizeProviderID,
 916		})
 917		if err != nil {
 918			event = AgentEvent{
 919				Type:  AgentEventTypeError,
 920				Error: fmt.Errorf("failed to create summary message: %w", err),
 921				Done:  true,
 922			}
 923
 924			a.Publish(pubsub.CreatedEvent, event)
 925			return
 926		}
 927		oldSession.SummaryMessageID = msg.ID
 928		oldSession.CompletionTokens = finalResponse.Usage.OutputTokens
 929		oldSession.PromptTokens = 0
 930		model := a.summarizeProvider.Model()
 931		usage := finalResponse.Usage
 932		cost := model.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
 933			model.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
 934			model.CostPer1MIn/1e6*float64(usage.InputTokens) +
 935			model.CostPer1MOut/1e6*float64(usage.OutputTokens)
 936		oldSession.Cost += cost
 937		_, err = a.sessions.Save(summarizeCtx, oldSession)
 938		if err != nil {
 939			event = AgentEvent{
 940				Type:  AgentEventTypeError,
 941				Error: fmt.Errorf("failed to save session: %w", err),
 942				Done:  true,
 943			}
 944			a.Publish(pubsub.CreatedEvent, event)
 945		}
 946
 947		event = AgentEvent{
 948			Type:      AgentEventTypeSummarize,
 949			SessionID: oldSession.ID,
 950			Progress:  "Summary complete",
 951			Done:      true,
 952		}
 953		a.Publish(pubsub.CreatedEvent, event)
 954		// Send final success event with the new session ID
 955	}()
 956
 957	return nil
 958}
 959
 960func (a *agent) ClearQueue(sessionID string) {
 961	if a.QueuedPrompts(sessionID) > 0 {
 962		slog.Info("Clearing queued prompts", "session_id", sessionID)
 963		a.promptQueue.Del(sessionID)
 964	}
 965}
 966
 967func (a *agent) CancelAll() {
 968	if !a.IsBusy() {
 969		return
 970	}
 971	for key := range a.activeRequests.Seq2() {
 972		a.Cancel(key) // key is sessionID
 973	}
 974
 975	timeout := time.After(5 * time.Second)
 976	for a.IsBusy() {
 977		select {
 978		case <-timeout:
 979			return
 980		default:
 981			time.Sleep(200 * time.Millisecond)
 982		}
 983	}
 984}
 985
 986func (a *agent) UpdateModel() error {
 987	cfg := config.Get()
 988
 989	// Get current provider configuration
 990	currentProviderCfg := cfg.GetProviderForModel(a.agentCfg.Model)
 991	if currentProviderCfg == nil || currentProviderCfg.ID == "" {
 992		return fmt.Errorf("provider for agent %s not found in config", a.agentCfg.Name)
 993	}
 994
 995	// Check if provider has changed
 996	if string(currentProviderCfg.ID) != a.providerID {
 997		// Provider changed, need to recreate the main provider
 998		model := cfg.GetModelByType(a.agentCfg.Model)
 999		if model.ID == "" {
1000			return fmt.Errorf("model not found for agent %s", a.agentCfg.Name)
1001		}
1002
1003		promptID := agentPromptMap[a.agentCfg.ID]
1004		if promptID == "" {
1005			promptID = prompt.PromptDefault
1006		}
1007
1008		opts := []provider.ProviderClientOption{
1009			provider.WithModel(a.agentCfg.Model),
1010			provider.WithSystemMessage(prompt.GetPrompt(promptID, currentProviderCfg.ID, cfg.Options.ContextPaths...)),
1011		}
1012
1013		newProvider, err := provider.NewProvider(*currentProviderCfg, opts...)
1014		if err != nil {
1015			return fmt.Errorf("failed to create new provider: %w", err)
1016		}
1017
1018		// Update the provider and provider ID
1019		a.provider = newProvider
1020		a.providerID = string(currentProviderCfg.ID)
1021	}
1022
1023	// Check if providers have changed for title (small) and summarize (large)
1024	smallModelCfg := cfg.Models[config.SelectedModelTypeSmall]
1025	var smallModelProviderCfg config.ProviderConfig
1026	for p := range cfg.Providers.Seq() {
1027		if p.ID == smallModelCfg.Provider {
1028			smallModelProviderCfg = p
1029			break
1030		}
1031	}
1032	if smallModelProviderCfg.ID == "" {
1033		return fmt.Errorf("provider %s not found in config", smallModelCfg.Provider)
1034	}
1035
1036	largeModelCfg := cfg.Models[config.SelectedModelTypeLarge]
1037	var largeModelProviderCfg config.ProviderConfig
1038	for p := range cfg.Providers.Seq() {
1039		if p.ID == largeModelCfg.Provider {
1040			largeModelProviderCfg = p
1041			break
1042		}
1043	}
1044	if largeModelProviderCfg.ID == "" {
1045		return fmt.Errorf("provider %s not found in config", largeModelCfg.Provider)
1046	}
1047
1048	var maxTitleTokens int64 = 40
1049
1050	// if the max output is too low for the gemini provider it won't return anything
1051	if smallModelCfg.Provider == "gemini" {
1052		maxTitleTokens = 1000
1053	}
1054	// Recreate title provider
1055	titleOpts := []provider.ProviderClientOption{
1056		provider.WithModel(config.SelectedModelTypeSmall),
1057		provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptTitle, smallModelProviderCfg.ID)),
1058		provider.WithMaxTokens(maxTitleTokens),
1059	}
1060	newTitleProvider, err := provider.NewProvider(smallModelProviderCfg, titleOpts...)
1061	if err != nil {
1062		return fmt.Errorf("failed to create new title provider: %w", err)
1063	}
1064	a.titleProvider = newTitleProvider
1065
1066	// Recreate summarize provider if provider changed (now large model)
1067	if string(largeModelProviderCfg.ID) != a.summarizeProviderID {
1068		largeModel := cfg.GetModelByType(config.SelectedModelTypeLarge)
1069		if largeModel == nil {
1070			return fmt.Errorf("model %s not found in provider %s", largeModelCfg.Model, largeModelProviderCfg.ID)
1071		}
1072		summarizeOpts := []provider.ProviderClientOption{
1073			provider.WithModel(config.SelectedModelTypeLarge),
1074			provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptSummarizer, largeModelProviderCfg.ID)),
1075		}
1076		newSummarizeProvider, err := provider.NewProvider(largeModelProviderCfg, summarizeOpts...)
1077		if err != nil {
1078			return fmt.Errorf("failed to create new summarize provider: %w", err)
1079		}
1080		a.summarizeProvider = newSummarizeProvider
1081		a.summarizeProviderID = string(largeModelProviderCfg.ID)
1082	}
1083
1084	return nil
1085}