agent.go

   1package agent
   2
   3import (
   4	"context"
   5	"errors"
   6	"fmt"
   7	"log/slog"
   8	"slices"
   9	"strings"
  10	"time"
  11
  12	"github.com/charmbracelet/catwalk/pkg/catwalk"
  13	"github.com/charmbracelet/crush/internal/config"
  14	"github.com/charmbracelet/crush/internal/csync"
  15	"github.com/charmbracelet/crush/internal/event"
  16	"github.com/charmbracelet/crush/internal/history"
  17	"github.com/charmbracelet/crush/internal/llm/prompt"
  18	"github.com/charmbracelet/crush/internal/llm/provider"
  19	"github.com/charmbracelet/crush/internal/llm/tools"
  20	"github.com/charmbracelet/crush/internal/log"
  21	"github.com/charmbracelet/crush/internal/lsp"
  22	"github.com/charmbracelet/crush/internal/message"
  23	"github.com/charmbracelet/crush/internal/permission"
  24	"github.com/charmbracelet/crush/internal/pubsub"
  25	"github.com/charmbracelet/crush/internal/session"
  26	"github.com/charmbracelet/crush/internal/shell"
  27)
  28
// streamChunkTimeout bounds how long streamAndHandleEvents waits for the next
// provider stream event before failing the request with a "Stream timeout".
const streamChunkTimeout = 80 * time.Second

// AgentEventType identifies the kind of AgentEvent emitted by the agent.
type AgentEventType string

const (
	// AgentEventTypeError reports a failure; the Error field is set.
	AgentEventTypeError     AgentEventType = "error"
	// AgentEventTypeResponse carries the finished assistant Message of a Run.
	AgentEventTypeResponse  AgentEventType = "response"
	// AgentEventTypeSummarize reports progress (and completion) of Summarize.
	AgentEventTypeSummarize AgentEventType = "summarize"
)

// AgentEvent is published on the agent's broker and, for Run, also sent on the
// channel Run returns.
type AgentEvent struct {
	Type    AgentEventType
	Message message.Message // set on AgentEventTypeResponse events
	Error   error           // set on AgentEventTypeError events

	// When summarizing
	SessionID string // session that received the summary (final summarize event)
	Progress  string // human-readable progress text
	Done      bool   // set on terminal events (Run response, Summarize completion/failure)
}
  49
// Service is the public API of an LLM-backed agent.
type Service interface {
	// Subscribers receive every AgentEvent the agent publishes.
	pubsub.Suscriber[AgentEvent]
	// Model returns the model currently configured for this agent.
	Model() catwalk.Model
	// Run starts a generation for sessionID. When the session is already busy
	// the prompt is queued instead and Run returns a nil channel and nil error.
	Run(ctx context.Context, sessionID string, content string, attachments ...message.Attachment) (<-chan AgentEvent, error)
	// Cancel aborts the session's in-flight generation and summarization and
	// drops its queued prompts.
	Cancel(sessionID string)
	// CancelAll cancels every active request and waits (bounded) for them to stop.
	CancelAll()
	// IsSessionBusy reports whether a generation is in flight for sessionID.
	IsSessionBusy(sessionID string) bool
	// IsBusy reports whether any request is in flight.
	IsBusy() bool
	// Summarize starts an asynchronous summarization of sessionID; progress and
	// the outcome are published as AgentEvents.
	Summarize(ctx context.Context, sessionID string) error
	// UpdateModel re-reads the agent's provider/model configuration.
	UpdateModel() error
	// QueuedPrompts returns how many prompts are waiting for sessionID.
	QueuedPrompts(sessionID string) int
	// ClearQueue drops all prompts waiting for sessionID.
	ClearQueue(sessionID string)
}
  63
// agent is the Service implementation that drives LLM providers and tools.
type agent struct {
	*pubsub.Broker[AgentEvent]
	agentCfg    config.Agent
	sessions    session.Service
	messages    message.Service
	permissions permission.Service
	// NOTE(review): NewAgent never assigns this field; tool setup appears to use
	// a package-level mcpTools instead. Confirm whether the field is still needed.
	mcpTools    []McpTool

	// tools is built lazily on first use (see toolFn in NewAgent).
	tools *csync.LazySlice[tools.BaseTool]
	// We need this to be able to update it when model changes
	agentToolFn func() (tools.BaseTool, error)

	provider   provider.Provider
	providerID string

	titleProvider       provider.Provider
	summarizeProvider   provider.Provider
	summarizeProviderID string

	// activeRequests maps a session ID (or "<sessionID>-summarize") to the
	// cancel func of the request running for it.
	activeRequests *csync.Map[string, context.CancelFunc]
	// promptQueue holds prompt texts submitted while a session was busy.
	promptQueue    *csync.Map[string, []string]
}

// agentPromptMap selects the system prompt by agent ID; IDs not listed here
// fall back to prompt.PromptDefault in NewAgent.
var agentPromptMap = map[string]prompt.PromptID{
	"coder": prompt.PromptCoder,
	"task":  prompt.PromptTask,
}
  91
  92func NewAgent(
  93	ctx context.Context,
  94	agentCfg config.Agent,
  95	// These services are needed in the tools
  96	permissions permission.Service,
  97	sessions session.Service,
  98	messages message.Service,
  99	history history.Service,
 100	lspClients *csync.Map[string, *lsp.Client],
 101) (Service, error) {
 102	cfg := config.Get()
 103
 104	var agentToolFn func() (tools.BaseTool, error)
 105	if agentCfg.ID == "coder" && slices.Contains(agentCfg.AllowedTools, AgentToolName) {
 106		agentToolFn = func() (tools.BaseTool, error) {
 107			taskAgentCfg := config.Get().Agents["task"]
 108			if taskAgentCfg.ID == "" {
 109				return nil, fmt.Errorf("task agent not found in config")
 110			}
 111			taskAgent, err := NewAgent(ctx, taskAgentCfg, permissions, sessions, messages, history, lspClients)
 112			if err != nil {
 113				return nil, fmt.Errorf("failed to create task agent: %w", err)
 114			}
 115			return NewAgentTool(taskAgent, sessions, messages), nil
 116		}
 117	}
 118
 119	providerCfg := config.Get().GetProviderForModel(agentCfg.Model)
 120	if providerCfg == nil {
 121		return nil, fmt.Errorf("provider for agent %s not found in config", agentCfg.Name)
 122	}
 123	model := config.Get().GetModelByType(agentCfg.Model)
 124
 125	if model == nil {
 126		return nil, fmt.Errorf("model not found for agent %s", agentCfg.Name)
 127	}
 128
 129	promptID := agentPromptMap[agentCfg.ID]
 130	if promptID == "" {
 131		promptID = prompt.PromptDefault
 132	}
 133	opts := []provider.ProviderClientOption{
 134		provider.WithModel(agentCfg.Model),
 135		provider.WithSystemMessage(prompt.GetPrompt(promptID, providerCfg.ID, config.Get().Options.ContextPaths...)),
 136	}
 137	agentProvider, err := provider.NewProvider(*providerCfg, opts...)
 138	if err != nil {
 139		return nil, err
 140	}
 141
 142	smallModelCfg := cfg.Models[config.SelectedModelTypeSmall]
 143	var smallModelProviderCfg *config.ProviderConfig
 144	if smallModelCfg.Provider == providerCfg.ID {
 145		smallModelProviderCfg = providerCfg
 146	} else {
 147		smallModelProviderCfg = cfg.GetProviderForModel(config.SelectedModelTypeSmall)
 148
 149		if smallModelProviderCfg.ID == "" {
 150			return nil, fmt.Errorf("provider %s not found in config", smallModelCfg.Provider)
 151		}
 152	}
 153	smallModel := cfg.GetModelByType(config.SelectedModelTypeSmall)
 154	if smallModel.ID == "" {
 155		return nil, fmt.Errorf("model %s not found in provider %s", smallModelCfg.Model, smallModelProviderCfg.ID)
 156	}
 157
 158	titleOpts := []provider.ProviderClientOption{
 159		provider.WithModel(config.SelectedModelTypeSmall),
 160		provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptTitle, smallModelProviderCfg.ID)),
 161	}
 162	titleProvider, err := provider.NewProvider(*smallModelProviderCfg, titleOpts...)
 163	if err != nil {
 164		return nil, err
 165	}
 166
 167	summarizeOpts := []provider.ProviderClientOption{
 168		provider.WithModel(config.SelectedModelTypeLarge),
 169		provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptSummarizer, providerCfg.ID)),
 170	}
 171	summarizeProvider, err := provider.NewProvider(*providerCfg, summarizeOpts...)
 172	if err != nil {
 173		return nil, err
 174	}
 175
 176	toolFn := func() []tools.BaseTool {
 177		slog.Info("Initializing agent tools", "agent", agentCfg.ID)
 178		defer func() {
 179			slog.Info("Initialized agent tools", "agent", agentCfg.ID)
 180		}()
 181
 182		cwd := cfg.WorkingDir()
 183		allTools := []tools.BaseTool{
 184			tools.NewBashTool(permissions, cwd, cfg.Options.Attribution),
 185			tools.NewDownloadTool(permissions, cwd),
 186			tools.NewEditTool(lspClients, permissions, history, cwd),
 187			tools.NewMultiEditTool(lspClients, permissions, history, cwd),
 188			tools.NewFetchTool(permissions, cwd),
 189			tools.NewGlobTool(cwd),
 190			tools.NewGrepTool(cwd),
 191			tools.NewLsTool(permissions, cwd),
 192			tools.NewSourcegraphTool(),
 193			tools.NewViewTool(lspClients, permissions, cwd),
 194			tools.NewWriteTool(lspClients, permissions, history, cwd),
 195		}
 196
 197		mcpToolsOnce.Do(func() {
 198			mcpTools = doGetMCPTools(ctx, permissions, cfg)
 199		})
 200
 201		withCoderTools := func(t []tools.BaseTool) []tools.BaseTool {
 202			if agentCfg.ID == "coder" {
 203				t = append(t, mcpTools...)
 204				if lspClients.Len() > 0 {
 205					t = append(t, tools.NewDiagnosticsTool(lspClients))
 206				}
 207			}
 208			return t
 209		}
 210
 211		if agentCfg.AllowedTools == nil {
 212			return withCoderTools(allTools)
 213		}
 214
 215		var filteredTools []tools.BaseTool
 216		for _, tool := range allTools {
 217			if slices.Contains(agentCfg.AllowedTools, tool.Name()) {
 218				filteredTools = append(filteredTools, tool)
 219			}
 220		}
 221		return withCoderTools(filteredTools)
 222	}
 223
 224	return &agent{
 225		Broker:              pubsub.NewBroker[AgentEvent](),
 226		agentCfg:            agentCfg,
 227		provider:            agentProvider,
 228		providerID:          string(providerCfg.ID),
 229		messages:            messages,
 230		sessions:            sessions,
 231		titleProvider:       titleProvider,
 232		summarizeProvider:   summarizeProvider,
 233		summarizeProviderID: string(providerCfg.ID),
 234		agentToolFn:         agentToolFn,
 235		activeRequests:      csync.NewMap[string, context.CancelFunc](),
 236		tools:               csync.NewLazySlice(toolFn),
 237		promptQueue:         csync.NewMap[string, []string](),
 238		permissions:         permissions,
 239	}, nil
 240}
 241
 242func (a *agent) Model() catwalk.Model {
 243	return *config.Get().GetModelByType(a.agentCfg.Model)
 244}
 245
 246func (a *agent) Cancel(sessionID string) {
 247	// Cancel regular requests
 248	if cancel, ok := a.activeRequests.Take(sessionID); ok && cancel != nil {
 249		slog.Info("Request cancellation initiated", "session_id", sessionID)
 250		cancel()
 251	}
 252
 253	// Also check for summarize requests
 254	if cancel, ok := a.activeRequests.Take(sessionID + "-summarize"); ok && cancel != nil {
 255		slog.Info("Summarize cancellation initiated", "session_id", sessionID)
 256		cancel()
 257	}
 258
 259	if a.QueuedPrompts(sessionID) > 0 {
 260		slog.Info("Clearing queued prompts", "session_id", sessionID)
 261		a.promptQueue.Del(sessionID)
 262	}
 263}
 264
 265func (a *agent) IsBusy() bool {
 266	var busy bool
 267	for cancelFunc := range a.activeRequests.Seq() {
 268		if cancelFunc != nil {
 269			busy = true
 270			break
 271		}
 272	}
 273	return busy
 274}
 275
 276func (a *agent) IsSessionBusy(sessionID string) bool {
 277	_, busy := a.activeRequests.Get(sessionID)
 278	return busy
 279}
 280
 281func (a *agent) QueuedPrompts(sessionID string) int {
 282	l, ok := a.promptQueue.Get(sessionID)
 283	if !ok {
 284		return 0
 285	}
 286	return len(l)
 287}
 288
 289func (a *agent) generateTitle(ctx context.Context, sessionID string, content string) error {
 290	if content == "" {
 291		return nil
 292	}
 293	if a.titleProvider == nil {
 294		return nil
 295	}
 296	session, err := a.sessions.Get(ctx, sessionID)
 297	if err != nil {
 298		return err
 299	}
 300	parts := []message.ContentPart{message.TextContent{
 301		Text: fmt.Sprintf("Generate a concise title for the following content:\n\n%s", content),
 302	}}
 303
 304	// Use streaming approach like summarization
 305	response := a.titleProvider.StreamResponse(
 306		ctx,
 307		[]message.Message{
 308			{
 309				Role:  message.User,
 310				Parts: parts,
 311			},
 312		},
 313		nil,
 314	)
 315
 316	var finalResponse *provider.ProviderResponse
 317	for r := range response {
 318		if r.Error != nil {
 319			return r.Error
 320		}
 321		finalResponse = r.Response
 322	}
 323
 324	if finalResponse == nil {
 325		return fmt.Errorf("no response received from title provider")
 326	}
 327
 328	title := strings.ReplaceAll(finalResponse.Content, "\n", " ")
 329
 330	if idx := strings.Index(title, "</think>"); idx > 0 {
 331		title = title[idx+len("</think>"):]
 332	}
 333
 334	title = strings.TrimSpace(title)
 335	if title == "" {
 336		return nil
 337	}
 338
 339	session.Title = title
 340	_, err = a.sessions.Save(ctx, session)
 341	return err
 342}
 343
 344func (a *agent) err(err error) AgentEvent {
 345	return AgentEvent{
 346		Type:  AgentEventTypeError,
 347		Error: err,
 348	}
 349}
 350
 351func (a *agent) Run(ctx context.Context, sessionID string, content string, attachments ...message.Attachment) (<-chan AgentEvent, error) {
 352	if !a.Model().SupportsImages && attachments != nil {
 353		attachments = nil
 354	}
 355	events := make(chan AgentEvent, 1)
 356	if a.IsSessionBusy(sessionID) {
 357		existing, ok := a.promptQueue.Get(sessionID)
 358		if !ok {
 359			existing = []string{}
 360		}
 361		existing = append(existing, content)
 362		a.promptQueue.Set(sessionID, existing)
 363		return nil, nil
 364	}
 365
 366	genCtx, cancel := context.WithCancel(ctx)
 367	a.activeRequests.Set(sessionID, cancel)
 368	startTime := time.Now()
 369
 370	go func() {
 371		slog.Debug("Request started", "sessionID", sessionID)
 372		defer log.RecoverPanic("agent.Run", func() {
 373			events <- a.err(fmt.Errorf("panic while running the agent"))
 374		})
 375		var attachmentParts []message.ContentPart
 376		for _, attachment := range attachments {
 377			attachmentParts = append(attachmentParts, message.BinaryContent{Path: attachment.FilePath, MIMEType: attachment.MimeType, Data: attachment.Content})
 378		}
 379		result := a.processGeneration(genCtx, sessionID, content, attachmentParts)
 380		if result.Error != nil {
 381			if isCancelledErr(result.Error) {
 382				slog.Error("Request canceled", "sessionID", sessionID)
 383			} else {
 384				slog.Error("Request errored", "sessionID", sessionID, "error", result.Error.Error())
 385				event.Error(result.Error)
 386			}
 387		} else {
 388			slog.Debug("Request completed", "sessionID", sessionID)
 389		}
 390		a.eventPromptResponded(sessionID, time.Since(startTime).Truncate(time.Second))
 391		a.activeRequests.Del(sessionID)
 392		cancel()
 393		a.Publish(pubsub.CreatedEvent, result)
 394		events <- result
 395		close(events)
 396	}()
 397	a.eventPromptSent(sessionID)
 398	return events, nil
 399}
 400
 401func (a *agent) processGeneration(ctx context.Context, sessionID, content string, attachmentParts []message.ContentPart) AgentEvent {
 402	cfg := config.Get()
 403	// List existing messages; if none, start title generation asynchronously.
 404	msgs, err := a.messages.List(ctx, sessionID)
 405	if err != nil {
 406		return a.err(fmt.Errorf("failed to list messages: %w", err))
 407	}
 408	if len(msgs) == 0 {
 409		go func() {
 410			defer log.RecoverPanic("agent.Run", func() {
 411				slog.Error("panic while generating title")
 412			})
 413			titleErr := a.generateTitle(ctx, sessionID, content)
 414			if titleErr != nil && !errors.Is(titleErr, context.Canceled) && !errors.Is(titleErr, context.DeadlineExceeded) {
 415				slog.Error("failed to generate title", "error", titleErr)
 416			}
 417		}()
 418	}
 419	session, err := a.sessions.Get(ctx, sessionID)
 420	if err != nil {
 421		return a.err(fmt.Errorf("failed to get session: %w", err))
 422	}
 423	if session.SummaryMessageID != "" {
 424		summaryMsgInex := -1
 425		for i, msg := range msgs {
 426			if msg.ID == session.SummaryMessageID {
 427				summaryMsgInex = i
 428				break
 429			}
 430		}
 431		if summaryMsgInex != -1 {
 432			msgs = msgs[summaryMsgInex:]
 433			msgs[0].Role = message.User
 434		}
 435	}
 436
 437	userMsg, err := a.createUserMessage(ctx, sessionID, content, attachmentParts)
 438	if err != nil {
 439		return a.err(fmt.Errorf("failed to create user message: %w", err))
 440	}
 441	// Append the new user message to the conversation history.
 442	msgHistory := append(msgs, userMsg)
 443
 444	for {
 445		// Check for cancellation before each iteration
 446		select {
 447		case <-ctx.Done():
 448			return a.err(ctx.Err())
 449		default:
 450			// Continue processing
 451		}
 452		agentMessage, toolResults, err := a.streamAndHandleEvents(ctx, sessionID, msgHistory)
 453		if err != nil {
 454			if errors.Is(err, context.Canceled) {
 455				agentMessage.AddFinish(message.FinishReasonCanceled, "Request cancelled", "")
 456				a.messages.Update(context.Background(), agentMessage)
 457				return a.err(ErrRequestCancelled)
 458			}
 459			return a.err(fmt.Errorf("failed to process events: %w", err))
 460		}
 461		if cfg.Options.Debug {
 462			slog.Info("Result", "message", agentMessage.FinishReason(), "toolResults", toolResults)
 463		}
 464		if (agentMessage.FinishReason() == message.FinishReasonToolUse) && toolResults != nil {
 465			// We are not done, we need to respond with the tool response
 466			msgHistory = append(msgHistory, agentMessage, *toolResults)
 467			// If there are queued prompts, process the next one
 468			nextPrompt, ok := a.promptQueue.Take(sessionID)
 469			if ok {
 470				for _, prompt := range nextPrompt {
 471					// Create a new user message for the queued prompt
 472					userMsg, err := a.createUserMessage(ctx, sessionID, prompt, nil)
 473					if err != nil {
 474						return a.err(fmt.Errorf("failed to create user message for queued prompt: %w", err))
 475					}
 476					// Append the new user message to the conversation history
 477					msgHistory = append(msgHistory, userMsg)
 478				}
 479			}
 480
 481			continue
 482		} else if agentMessage.FinishReason() == message.FinishReasonEndTurn {
 483			queuePrompts, ok := a.promptQueue.Take(sessionID)
 484			if ok {
 485				for _, prompt := range queuePrompts {
 486					if prompt == "" {
 487						continue
 488					}
 489					userMsg, err := a.createUserMessage(ctx, sessionID, prompt, nil)
 490					if err != nil {
 491						return a.err(fmt.Errorf("failed to create user message for queued prompt: %w", err))
 492					}
 493					msgHistory = append(msgHistory, userMsg)
 494				}
 495				continue
 496			}
 497		}
 498		if agentMessage.FinishReason() == "" {
 499			// Kujtim: could not track down where this is happening but this means its cancelled
 500			agentMessage.AddFinish(message.FinishReasonCanceled, "Request cancelled", "")
 501			_ = a.messages.Update(context.Background(), agentMessage)
 502			return a.err(ErrRequestCancelled)
 503		}
 504		return AgentEvent{
 505			Type:    AgentEventTypeResponse,
 506			Message: agentMessage,
 507			Done:    true,
 508		}
 509	}
 510}
 511
 512func (a *agent) createUserMessage(ctx context.Context, sessionID, content string, attachmentParts []message.ContentPart) (message.Message, error) {
 513	parts := []message.ContentPart{message.TextContent{Text: content}}
 514	parts = append(parts, attachmentParts...)
 515	return a.messages.Create(ctx, sessionID, message.CreateMessageParams{
 516		Role:  message.User,
 517		Parts: parts,
 518	})
 519}
 520
 521func (a *agent) getAllTools() ([]tools.BaseTool, error) {
 522	allTools := slices.Collect(a.tools.Seq())
 523	if a.agentToolFn != nil {
 524		agentTool, agentToolErr := a.agentToolFn()
 525		if agentToolErr != nil {
 526			return nil, agentToolErr
 527		}
 528		allTools = append(allTools, agentTool)
 529	}
 530	return allTools, nil
 531}
 532
 533func (a *agent) streamAndHandleEvents(ctx context.Context, sessionID string, msgHistory []message.Message) (message.Message, *message.Message, error) {
 534	ctx = context.WithValue(ctx, tools.SessionIDContextKey, sessionID)
 535
 536	// Create the assistant message first so the spinner shows immediately
 537	assistantMsg, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{
 538		Role:     message.Assistant,
 539		Parts:    []message.ContentPart{},
 540		Model:    a.Model().ID,
 541		Provider: a.providerID,
 542	})
 543	if err != nil {
 544		return assistantMsg, nil, fmt.Errorf("failed to create assistant message: %w", err)
 545	}
 546
 547	allTools, toolsErr := a.getAllTools()
 548	if toolsErr != nil {
 549		return assistantMsg, nil, toolsErr
 550	}
 551	// Now collect tools (which may block on MCP initialization)
 552	eventChan := a.provider.StreamResponse(ctx, msgHistory, allTools)
 553
 554	// Add the session and message ID into the context if needed by tools.
 555	ctx = context.WithValue(ctx, tools.MessageIDContextKey, assistantMsg.ID)
 556
 557	// Process each event in the stream.
 558loop:
 559	for {
 560		select {
 561		case event, ok := <-eventChan:
 562			if !ok {
 563				break loop
 564			}
 565			if processErr := a.processEvent(ctx, sessionID, &assistantMsg, event); processErr != nil {
 566				if errors.Is(processErr, context.Canceled) {
 567					a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
 568				} else {
 569					a.finishMessage(ctx, &assistantMsg, message.FinishReasonError, "API Error", processErr.Error())
 570				}
 571				return assistantMsg, nil, processErr
 572			}
 573		case <-time.After(streamChunkTimeout):
 574			a.finishMessage(ctx, &assistantMsg, message.FinishReasonError, "Stream timeout", "No chunk received within timeout")
 575			return assistantMsg, nil, fmt.Errorf("stream chunk timeout")
 576		case <-ctx.Done():
 577			a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
 578			return assistantMsg, nil, ctx.Err()
 579		}
 580	}
 581
 582	toolResults := make([]message.ToolResult, len(assistantMsg.ToolCalls()))
 583	toolCalls := assistantMsg.ToolCalls()
 584	for i, toolCall := range toolCalls {
 585		select {
 586		case <-ctx.Done():
 587			a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
 588			// Make all future tool calls cancelled
 589			for j := i; j < len(toolCalls); j++ {
 590				toolResults[j] = message.ToolResult{
 591					ToolCallID: toolCalls[j].ID,
 592					Content:    "Tool execution canceled by user",
 593					IsError:    true,
 594				}
 595			}
 596			goto out
 597		default:
 598			// Continue processing
 599			var tool tools.BaseTool
 600			allTools, _ := a.getAllTools()
 601			for _, availableTool := range allTools {
 602				if availableTool.Info().Name == toolCall.Name {
 603					tool = availableTool
 604					break
 605				}
 606			}
 607
 608			// Tool not found
 609			if tool == nil {
 610				toolResults[i] = message.ToolResult{
 611					ToolCallID: toolCall.ID,
 612					Content:    fmt.Sprintf("Tool not found: %s", toolCall.Name),
 613					IsError:    true,
 614				}
 615				continue
 616			}
 617
 618			// Run tool in goroutine to allow cancellation
 619			type toolExecResult struct {
 620				response tools.ToolResponse
 621				err      error
 622			}
 623			resultChan := make(chan toolExecResult, 1)
 624
 625			go func() {
 626				response, err := tool.Run(ctx, tools.ToolCall{
 627					ID:    toolCall.ID,
 628					Name:  toolCall.Name,
 629					Input: toolCall.Input,
 630				})
 631				resultChan <- toolExecResult{response: response, err: err}
 632			}()
 633
 634			var toolResponse tools.ToolResponse
 635			var toolErr error
 636
 637			select {
 638			case <-ctx.Done():
 639				a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
 640				// Mark remaining tool calls as cancelled
 641				for j := i; j < len(toolCalls); j++ {
 642					toolResults[j] = message.ToolResult{
 643						ToolCallID: toolCalls[j].ID,
 644						Content:    "Tool execution canceled by user",
 645						IsError:    true,
 646					}
 647				}
 648				goto out
 649			case result := <-resultChan:
 650				toolResponse = result.response
 651				toolErr = result.err
 652			}
 653
 654			if toolErr != nil {
 655				slog.Error("Tool execution error", "toolCall", toolCall.ID, "error", toolErr)
 656				if errors.Is(toolErr, permission.ErrorPermissionDenied) {
 657					toolResults[i] = message.ToolResult{
 658						ToolCallID: toolCall.ID,
 659						Content:    "Permission denied",
 660						IsError:    true,
 661					}
 662					for j := i + 1; j < len(toolCalls); j++ {
 663						toolResults[j] = message.ToolResult{
 664							ToolCallID: toolCalls[j].ID,
 665							Content:    "Tool execution canceled by user",
 666							IsError:    true,
 667						}
 668					}
 669					a.finishMessage(ctx, &assistantMsg, message.FinishReasonPermissionDenied, "Permission denied", "")
 670					break
 671				}
 672			}
 673			toolResults[i] = message.ToolResult{
 674				ToolCallID: toolCall.ID,
 675				Content:    toolResponse.Content,
 676				Metadata:   toolResponse.Metadata,
 677				IsError:    toolResponse.IsError,
 678			}
 679		}
 680	}
 681out:
 682	if len(toolResults) == 0 {
 683		return assistantMsg, nil, nil
 684	}
 685	parts := make([]message.ContentPart, 0)
 686	for _, tr := range toolResults {
 687		parts = append(parts, tr)
 688	}
 689	msg, err := a.messages.Create(context.Background(), assistantMsg.SessionID, message.CreateMessageParams{
 690		Role:     message.Tool,
 691		Parts:    parts,
 692		Provider: a.providerID,
 693	})
 694	if err != nil {
 695		return assistantMsg, nil, fmt.Errorf("failed to create cancelled tool message: %w", err)
 696	}
 697
 698	return assistantMsg, &msg, err
 699}
 700
 701func (a *agent) finishMessage(ctx context.Context, msg *message.Message, finishReason message.FinishReason, message, details string) {
 702	msg.AddFinish(finishReason, message, details)
 703	_ = a.messages.Update(ctx, *msg)
 704}
 705
 706func (a *agent) processEvent(ctx context.Context, sessionID string, assistantMsg *message.Message, event provider.ProviderEvent) error {
 707	select {
 708	case <-ctx.Done():
 709		return ctx.Err()
 710	default:
 711		// Continue processing.
 712	}
 713
 714	switch event.Type {
 715	case provider.EventThinkingDelta:
 716		assistantMsg.AppendReasoningContent(event.Thinking)
 717		return a.messages.Update(ctx, *assistantMsg)
 718	case provider.EventSignatureDelta:
 719		assistantMsg.AppendReasoningSignature(event.Signature)
 720		return a.messages.Update(ctx, *assistantMsg)
 721	case provider.EventContentDelta:
 722		assistantMsg.FinishThinking()
 723		assistantMsg.AppendContent(event.Content)
 724		return a.messages.Update(ctx, *assistantMsg)
 725	case provider.EventToolUseStart:
 726		assistantMsg.FinishThinking()
 727		slog.Info("Tool call started", "toolCall", event.ToolCall)
 728		assistantMsg.AddToolCall(*event.ToolCall)
 729		return a.messages.Update(ctx, *assistantMsg)
 730	case provider.EventToolUseDelta:
 731		assistantMsg.AppendToolCallInput(event.ToolCall.ID, event.ToolCall.Input)
 732		return a.messages.Update(ctx, *assistantMsg)
 733	case provider.EventToolUseStop:
 734		slog.Info("Finished tool call", "toolCall", event.ToolCall)
 735		assistantMsg.FinishToolCall(event.ToolCall.ID)
 736		return a.messages.Update(ctx, *assistantMsg)
 737	case provider.EventError:
 738		return event.Error
 739	case provider.EventComplete:
 740		assistantMsg.FinishThinking()
 741		assistantMsg.SetToolCalls(event.Response.ToolCalls)
 742		assistantMsg.AddFinish(event.Response.FinishReason, "", "")
 743		if err := a.messages.Update(ctx, *assistantMsg); err != nil {
 744			return fmt.Errorf("failed to update message: %w", err)
 745		}
 746		return a.trackUsage(ctx, sessionID, a.Model(), event.Response.Usage)
 747	}
 748
 749	return nil
 750}
 751
 752func (a *agent) trackUsage(ctx context.Context, sessionID string, model catwalk.Model, usage provider.TokenUsage) error {
 753	sess, err := a.sessions.Get(ctx, sessionID)
 754	if err != nil {
 755		return fmt.Errorf("failed to get session: %w", err)
 756	}
 757
 758	cost := model.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
 759		model.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
 760		model.CostPer1MIn/1e6*float64(usage.InputTokens) +
 761		model.CostPer1MOut/1e6*float64(usage.OutputTokens)
 762
 763	a.eventTokensUsed(sessionID, usage, cost)
 764
 765	sess.Cost += cost
 766	sess.CompletionTokens = usage.OutputTokens + usage.CacheReadTokens
 767	sess.PromptTokens = usage.InputTokens + usage.CacheCreationTokens
 768
 769	_, err = a.sessions.Save(ctx, sess)
 770	if err != nil {
 771		return fmt.Errorf("failed to save session: %w", err)
 772	}
 773	return nil
 774}
 775
 776func (a *agent) Summarize(ctx context.Context, sessionID string) error {
 777	if a.summarizeProvider == nil {
 778		return fmt.Errorf("summarize provider not available")
 779	}
 780
 781	// Check if session is busy
 782	if a.IsSessionBusy(sessionID) {
 783		return ErrSessionBusy
 784	}
 785
 786	// Create a new context with cancellation
 787	summarizeCtx, cancel := context.WithCancel(ctx)
 788
 789	// Store the cancel function in activeRequests to allow cancellation
 790	a.activeRequests.Set(sessionID+"-summarize", cancel)
 791
 792	go func() {
 793		defer a.activeRequests.Del(sessionID + "-summarize")
 794		defer cancel()
 795		event := AgentEvent{
 796			Type:     AgentEventTypeSummarize,
 797			Progress: "Starting summarization...",
 798		}
 799
 800		a.Publish(pubsub.CreatedEvent, event)
 801		// Get all messages from the session
 802		msgs, err := a.messages.List(summarizeCtx, sessionID)
 803		if err != nil {
 804			event = AgentEvent{
 805				Type:  AgentEventTypeError,
 806				Error: fmt.Errorf("failed to list messages: %w", err),
 807				Done:  true,
 808			}
 809			a.Publish(pubsub.CreatedEvent, event)
 810			return
 811		}
 812		summarizeCtx = context.WithValue(summarizeCtx, tools.SessionIDContextKey, sessionID)
 813
 814		if len(msgs) == 0 {
 815			event = AgentEvent{
 816				Type:  AgentEventTypeError,
 817				Error: fmt.Errorf("no messages to summarize"),
 818				Done:  true,
 819			}
 820			a.Publish(pubsub.CreatedEvent, event)
 821			return
 822		}
 823
 824		event = AgentEvent{
 825			Type:     AgentEventTypeSummarize,
 826			Progress: "Analyzing conversation...",
 827		}
 828		a.Publish(pubsub.CreatedEvent, event)
 829
 830		// Add a system message to guide the summarization
 831		summarizePrompt := "Provide a detailed but concise summary of our conversation above. Focus on information that would be helpful for continuing the conversation, including what we did, what we're doing, which files we're working on, and what we're going to do next."
 832
 833		// Create a new message with the summarize prompt
 834		promptMsg := message.Message{
 835			Role:  message.User,
 836			Parts: []message.ContentPart{message.TextContent{Text: summarizePrompt}},
 837		}
 838
 839		// Append the prompt to the messages
 840		msgsWithPrompt := append(msgs, promptMsg)
 841
 842		event = AgentEvent{
 843			Type:     AgentEventTypeSummarize,
 844			Progress: "Generating summary...",
 845		}
 846
 847		a.Publish(pubsub.CreatedEvent, event)
 848
 849		// Send the messages to the summarize provider
 850		response := a.summarizeProvider.StreamResponse(
 851			summarizeCtx,
 852			msgsWithPrompt,
 853			nil,
 854		)
 855		var finalResponse *provider.ProviderResponse
 856		for r := range response {
 857			if r.Error != nil {
 858				event = AgentEvent{
 859					Type:  AgentEventTypeError,
 860					Error: fmt.Errorf("failed to summarize: %w", r.Error),
 861					Done:  true,
 862				}
 863				a.Publish(pubsub.CreatedEvent, event)
 864				return
 865			}
 866			finalResponse = r.Response
 867		}
 868
 869		summary := strings.TrimSpace(finalResponse.Content)
 870		if summary == "" {
 871			event = AgentEvent{
 872				Type:  AgentEventTypeError,
 873				Error: fmt.Errorf("empty summary returned"),
 874				Done:  true,
 875			}
 876			a.Publish(pubsub.CreatedEvent, event)
 877			return
 878		}
 879		shell := shell.GetPersistentShell(config.Get().WorkingDir())
 880		summary += "\n\n**Current working directory of the persistent shell**\n\n" + shell.GetWorkingDir()
 881		event = AgentEvent{
 882			Type:     AgentEventTypeSummarize,
 883			Progress: "Creating new session...",
 884		}
 885
 886		a.Publish(pubsub.CreatedEvent, event)
 887		oldSession, err := a.sessions.Get(summarizeCtx, sessionID)
 888		if err != nil {
 889			event = AgentEvent{
 890				Type:  AgentEventTypeError,
 891				Error: fmt.Errorf("failed to get session: %w", err),
 892				Done:  true,
 893			}
 894
 895			a.Publish(pubsub.CreatedEvent, event)
 896			return
 897		}
 898		// Create a message in the new session with the summary
 899		msg, err := a.messages.Create(summarizeCtx, oldSession.ID, message.CreateMessageParams{
 900			Role: message.Assistant,
 901			Parts: []message.ContentPart{
 902				message.TextContent{Text: summary},
 903				message.Finish{
 904					Reason: message.FinishReasonEndTurn,
 905					Time:   time.Now().Unix(),
 906				},
 907			},
 908			Model:    a.summarizeProvider.Model().ID,
 909			Provider: a.summarizeProviderID,
 910		})
 911		if err != nil {
 912			event = AgentEvent{
 913				Type:  AgentEventTypeError,
 914				Error: fmt.Errorf("failed to create summary message: %w", err),
 915				Done:  true,
 916			}
 917
 918			a.Publish(pubsub.CreatedEvent, event)
 919			return
 920		}
 921		oldSession.SummaryMessageID = msg.ID
 922		oldSession.CompletionTokens = finalResponse.Usage.OutputTokens
 923		oldSession.PromptTokens = 0
 924		model := a.summarizeProvider.Model()
 925		usage := finalResponse.Usage
 926		cost := model.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
 927			model.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
 928			model.CostPer1MIn/1e6*float64(usage.InputTokens) +
 929			model.CostPer1MOut/1e6*float64(usage.OutputTokens)
 930		oldSession.Cost += cost
 931		_, err = a.sessions.Save(summarizeCtx, oldSession)
 932		if err != nil {
 933			event = AgentEvent{
 934				Type:  AgentEventTypeError,
 935				Error: fmt.Errorf("failed to save session: %w", err),
 936				Done:  true,
 937			}
 938			a.Publish(pubsub.CreatedEvent, event)
 939		}
 940
 941		event = AgentEvent{
 942			Type:      AgentEventTypeSummarize,
 943			SessionID: oldSession.ID,
 944			Progress:  "Summary complete",
 945			Done:      true,
 946		}
 947		a.Publish(pubsub.CreatedEvent, event)
 948		// Send final success event with the new session ID
 949	}()
 950
 951	return nil
 952}
 953
 954func (a *agent) ClearQueue(sessionID string) {
 955	if a.QueuedPrompts(sessionID) > 0 {
 956		slog.Info("Clearing queued prompts", "session_id", sessionID)
 957		a.promptQueue.Del(sessionID)
 958	}
 959}
 960
 961func (a *agent) CancelAll() {
 962	if !a.IsBusy() {
 963		return
 964	}
 965	for key := range a.activeRequests.Seq2() {
 966		a.Cancel(key) // key is sessionID
 967	}
 968
 969	timeout := time.After(5 * time.Second)
 970	for a.IsBusy() {
 971		select {
 972		case <-timeout:
 973			return
 974		default:
 975			time.Sleep(200 * time.Millisecond)
 976		}
 977	}
 978}
 979
 980func (a *agent) UpdateModel() error {
 981	cfg := config.Get()
 982
 983	// Get current provider configuration
 984	currentProviderCfg := cfg.GetProviderForModel(a.agentCfg.Model)
 985	if currentProviderCfg == nil || currentProviderCfg.ID == "" {
 986		return fmt.Errorf("provider for agent %s not found in config", a.agentCfg.Name)
 987	}
 988
 989	// Check if provider has changed
 990	if string(currentProviderCfg.ID) != a.providerID {
 991		// Provider changed, need to recreate the main provider
 992		model := cfg.GetModelByType(a.agentCfg.Model)
 993		if model.ID == "" {
 994			return fmt.Errorf("model not found for agent %s", a.agentCfg.Name)
 995		}
 996
 997		promptID := agentPromptMap[a.agentCfg.ID]
 998		if promptID == "" {
 999			promptID = prompt.PromptDefault
1000		}
1001
1002		opts := []provider.ProviderClientOption{
1003			provider.WithModel(a.agentCfg.Model),
1004			provider.WithSystemMessage(prompt.GetPrompt(promptID, currentProviderCfg.ID, cfg.Options.ContextPaths...)),
1005		}
1006
1007		newProvider, err := provider.NewProvider(*currentProviderCfg, opts...)
1008		if err != nil {
1009			return fmt.Errorf("failed to create new provider: %w", err)
1010		}
1011
1012		// Update the provider and provider ID
1013		a.provider = newProvider
1014		a.providerID = string(currentProviderCfg.ID)
1015	}
1016
1017	// Check if providers have changed for title (small) and summarize (large)
1018	smallModelCfg := cfg.Models[config.SelectedModelTypeSmall]
1019	var smallModelProviderCfg config.ProviderConfig
1020	for p := range cfg.Providers.Seq() {
1021		if p.ID == smallModelCfg.Provider {
1022			smallModelProviderCfg = p
1023			break
1024		}
1025	}
1026	if smallModelProviderCfg.ID == "" {
1027		return fmt.Errorf("provider %s not found in config", smallModelCfg.Provider)
1028	}
1029
1030	largeModelCfg := cfg.Models[config.SelectedModelTypeLarge]
1031	var largeModelProviderCfg config.ProviderConfig
1032	for p := range cfg.Providers.Seq() {
1033		if p.ID == largeModelCfg.Provider {
1034			largeModelProviderCfg = p
1035			break
1036		}
1037	}
1038	if largeModelProviderCfg.ID == "" {
1039		return fmt.Errorf("provider %s not found in config", largeModelCfg.Provider)
1040	}
1041
1042	var maxTitleTokens int64 = 40
1043
1044	// if the max output is too low for the gemini provider it won't return anything
1045	if smallModelCfg.Provider == "gemini" {
1046		maxTitleTokens = 1000
1047	}
1048	// Recreate title provider
1049	titleOpts := []provider.ProviderClientOption{
1050		provider.WithModel(config.SelectedModelTypeSmall),
1051		provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptTitle, smallModelProviderCfg.ID)),
1052		provider.WithMaxTokens(maxTitleTokens),
1053	}
1054	newTitleProvider, err := provider.NewProvider(smallModelProviderCfg, titleOpts...)
1055	if err != nil {
1056		return fmt.Errorf("failed to create new title provider: %w", err)
1057	}
1058	a.titleProvider = newTitleProvider
1059
1060	// Recreate summarize provider if provider changed (now large model)
1061	if string(largeModelProviderCfg.ID) != a.summarizeProviderID {
1062		largeModel := cfg.GetModelByType(config.SelectedModelTypeLarge)
1063		if largeModel == nil {
1064			return fmt.Errorf("model %s not found in provider %s", largeModelCfg.Model, largeModelProviderCfg.ID)
1065		}
1066		summarizeOpts := []provider.ProviderClientOption{
1067			provider.WithModel(config.SelectedModelTypeLarge),
1068			provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptSummarizer, largeModelProviderCfg.ID)),
1069		}
1070		newSummarizeProvider, err := provider.NewProvider(largeModelProviderCfg, summarizeOpts...)
1071		if err != nil {
1072			return fmt.Errorf("failed to create new summarize provider: %w", err)
1073		}
1074		a.summarizeProvider = newSummarizeProvider
1075		a.summarizeProviderID = string(largeModelProviderCfg.ID)
1076	}
1077
1078	return nil
1079}