agent.go

   1package agent
   2
   3import (
   4	"context"
   5	"errors"
   6	"fmt"
   7	"log/slog"
   8	"slices"
   9	"strings"
  10	"time"
  11
  12	"github.com/charmbracelet/catwalk/pkg/catwalk"
  13	"github.com/charmbracelet/crush/internal/config"
  14	"github.com/charmbracelet/crush/internal/csync"
  15	"github.com/charmbracelet/crush/internal/history"
  16	"github.com/charmbracelet/crush/internal/llm/prompt"
  17	"github.com/charmbracelet/crush/internal/llm/provider"
  18	"github.com/charmbracelet/crush/internal/llm/tools"
  19	"github.com/charmbracelet/crush/internal/log"
  20	"github.com/charmbracelet/crush/internal/lsp"
  21	"github.com/charmbracelet/crush/internal/message"
  22	"github.com/charmbracelet/crush/internal/permission"
  23	"github.com/charmbracelet/crush/internal/pubsub"
  24	"github.com/charmbracelet/crush/internal/session"
  25	"github.com/charmbracelet/crush/internal/shell"
  26)
  27
  28// Common errors
  29var (
  30	ErrRequestCancelled = errors.New("request canceled by user")
  31	ErrSessionBusy      = errors.New("session is currently processing another request")
  32)
  33
  34type AgentEventType string
  35
  36const (
  37	AgentEventTypeError     AgentEventType = "error"
  38	AgentEventTypeResponse  AgentEventType = "response"
  39	AgentEventTypeSummarize AgentEventType = "summarize"
  40)
  41
  42type AgentEvent struct {
  43	Type    AgentEventType
  44	Message message.Message
  45	Error   error
  46
  47	// When summarizing
  48	SessionID string
  49	Progress  string
  50	Done      bool
  51}
  52
  53type Service interface {
  54	pubsub.Suscriber[AgentEvent]
  55	Model() catwalk.Model
  56	Run(ctx context.Context, sessionID string, content string, attachments ...message.Attachment) (<-chan AgentEvent, error)
  57	Cancel(sessionID string)
  58	CancelAll()
  59	IsSessionBusy(sessionID string) bool
  60	IsBusy() bool
  61	Summarize(ctx context.Context, sessionID string) error
  62	UpdateModel() error
  63	QueuedPrompts(sessionID string) int
  64	ClearQueue(sessionID string)
  65}
  66
  67type agent struct {
  68	*pubsub.Broker[AgentEvent]
  69	agentCfg config.Agent
  70	sessions session.Service
  71	messages message.Service
  72	mcpTools []McpTool
  73
  74	tools *csync.LazySlice[tools.BaseTool]
  75	// We need this to be able to update it when model changes
  76	agentToolFn func() (tools.BaseTool, error)
  77
  78	provider   provider.Provider
  79	providerID string
  80
  81	titleProvider       provider.Provider
  82	summarizeProvider   provider.Provider
  83	summarizeProviderID string
  84
  85	activeRequests *csync.Map[string, context.CancelFunc]
  86	promptQueue    *csync.Map[string, []string]
  87}
  88
  89var agentPromptMap = map[string]prompt.PromptID{
  90	"coder": prompt.PromptCoder,
  91	"task":  prompt.PromptTask,
  92}
  93
  94func NewAgent(
  95	ctx context.Context,
  96	agentCfg config.Agent,
  97	// These services are needed in the tools
  98	permissions permission.Service,
  99	sessions session.Service,
 100	messages message.Service,
 101	history history.Service,
 102	lspClients *csync.Map[string, *lsp.Client],
 103) (Service, error) {
 104	cfg := config.Get()
 105
 106	var agentToolFn func() (tools.BaseTool, error)
 107	if agentCfg.ID == "coder" {
 108		agentToolFn = func() (tools.BaseTool, error) {
 109			taskAgentCfg := config.Get().Agents["task"]
 110			if taskAgentCfg.ID == "" {
 111				return nil, fmt.Errorf("task agent not found in config")
 112			}
 113			taskAgent, err := NewAgent(ctx, taskAgentCfg, permissions, sessions, messages, history, lspClients)
 114			if err != nil {
 115				return nil, fmt.Errorf("failed to create task agent: %w", err)
 116			}
 117			return NewAgentTool(taskAgent, sessions, messages), nil
 118		}
 119	}
 120
 121	providerCfg := config.Get().GetProviderForModel(agentCfg.Model)
 122	if providerCfg == nil {
 123		return nil, fmt.Errorf("provider for agent %s not found in config", agentCfg.Name)
 124	}
 125	model := config.Get().GetModelByType(agentCfg.Model)
 126
 127	if model == nil {
 128		return nil, fmt.Errorf("model not found for agent %s", agentCfg.Name)
 129	}
 130
 131	promptID := agentPromptMap[agentCfg.ID]
 132	if promptID == "" {
 133		promptID = prompt.PromptDefault
 134	}
 135	opts := []provider.ProviderClientOption{
 136		provider.WithModel(agentCfg.Model),
 137		provider.WithSystemMessage(prompt.GetPrompt(promptID, providerCfg.ID, config.Get().Options.ContextPaths...)),
 138	}
 139	agentProvider, err := provider.NewProvider(*providerCfg, opts...)
 140	if err != nil {
 141		return nil, err
 142	}
 143
 144	smallModelCfg := cfg.Models[config.SelectedModelTypeSmall]
 145	var smallModelProviderCfg *config.ProviderConfig
 146	if smallModelCfg.Provider == providerCfg.ID {
 147		smallModelProviderCfg = providerCfg
 148	} else {
 149		smallModelProviderCfg = cfg.GetProviderForModel(config.SelectedModelTypeSmall)
 150
 151		if smallModelProviderCfg.ID == "" {
 152			return nil, fmt.Errorf("provider %s not found in config", smallModelCfg.Provider)
 153		}
 154	}
 155	smallModel := cfg.GetModelByType(config.SelectedModelTypeSmall)
 156	if smallModel.ID == "" {
 157		return nil, fmt.Errorf("model %s not found in provider %s", smallModelCfg.Model, smallModelProviderCfg.ID)
 158	}
 159
 160	titleOpts := []provider.ProviderClientOption{
 161		provider.WithModel(config.SelectedModelTypeSmall),
 162		provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptTitle, smallModelProviderCfg.ID)),
 163	}
 164	titleProvider, err := provider.NewProvider(*smallModelProviderCfg, titleOpts...)
 165	if err != nil {
 166		return nil, err
 167	}
 168
 169	summarizeOpts := []provider.ProviderClientOption{
 170		provider.WithModel(config.SelectedModelTypeLarge),
 171		provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptSummarizer, providerCfg.ID)),
 172	}
 173	summarizeProvider, err := provider.NewProvider(*providerCfg, summarizeOpts...)
 174	if err != nil {
 175		return nil, err
 176	}
 177
 178	toolFn := func() []tools.BaseTool {
 179		slog.Info("Initializing agent tools", "agent", agentCfg.ID)
 180		defer func() {
 181			slog.Info("Initialized agent tools", "agent", agentCfg.ID)
 182		}()
 183
 184		cwd := cfg.WorkingDir()
 185		allTools := []tools.BaseTool{
 186			tools.NewBashTool(permissions, cwd, cfg.Options.Attribution),
 187			tools.NewDownloadTool(permissions, cwd),
 188			tools.NewEditTool(lspClients, permissions, history, cwd),
 189			tools.NewMultiEditTool(lspClients, permissions, history, cwd),
 190			tools.NewFetchTool(permissions, cwd),
 191			tools.NewGlobTool(cwd),
 192			tools.NewGrepTool(cwd),
 193			tools.NewLsTool(permissions, cwd),
 194			tools.NewSourcegraphTool(),
 195			tools.NewViewTool(lspClients, permissions, cwd),
 196			tools.NewWriteTool(lspClients, permissions, history, cwd),
 197		}
 198
 199		mcpToolsOnce.Do(func() {
 200			mcpTools = doGetMCPTools(ctx, permissions, cfg)
 201		})
 202
 203		withCoderTools := func(t []tools.BaseTool) []tools.BaseTool {
 204			if agentCfg.ID == "coder" {
 205				t = append(t, mcpTools...)
 206				if lspClients.Len() > 0 {
 207					t = append(t, tools.NewDiagnosticsTool(lspClients))
 208				}
 209			}
 210			return t
 211		}
 212
 213		if agentCfg.AllowedTools == nil {
 214			return withCoderTools(allTools)
 215		}
 216
 217		var filteredTools []tools.BaseTool
 218		for _, tool := range allTools {
 219			if slices.Contains(agentCfg.AllowedTools, tool.Name()) {
 220				filteredTools = append(filteredTools, tool)
 221			}
 222		}
 223		return withCoderTools(filteredTools)
 224	}
 225
 226	return &agent{
 227		Broker:              pubsub.NewBroker[AgentEvent](),
 228		agentCfg:            agentCfg,
 229		provider:            agentProvider,
 230		providerID:          string(providerCfg.ID),
 231		messages:            messages,
 232		sessions:            sessions,
 233		titleProvider:       titleProvider,
 234		summarizeProvider:   summarizeProvider,
 235		summarizeProviderID: string(providerCfg.ID),
 236		agentToolFn:         agentToolFn,
 237		activeRequests:      csync.NewMap[string, context.CancelFunc](),
 238		tools:               csync.NewLazySlice(toolFn),
 239		promptQueue:         csync.NewMap[string, []string](),
 240	}, nil
 241}
 242
 243func (a *agent) Model() catwalk.Model {
 244	return *config.Get().GetModelByType(a.agentCfg.Model)
 245}
 246
 247func (a *agent) Cancel(sessionID string) {
 248	// Cancel regular requests
 249	if cancel, ok := a.activeRequests.Take(sessionID); ok && cancel != nil {
 250		slog.Info("Request cancellation initiated", "session_id", sessionID)
 251		cancel()
 252	}
 253
 254	// Also check for summarize requests
 255	if cancel, ok := a.activeRequests.Take(sessionID + "-summarize"); ok && cancel != nil {
 256		slog.Info("Summarize cancellation initiated", "session_id", sessionID)
 257		cancel()
 258	}
 259
 260	if a.QueuedPrompts(sessionID) > 0 {
 261		slog.Info("Clearing queued prompts", "session_id", sessionID)
 262		a.promptQueue.Del(sessionID)
 263	}
 264}
 265
 266func (a *agent) IsBusy() bool {
 267	var busy bool
 268	for cancelFunc := range a.activeRequests.Seq() {
 269		if cancelFunc != nil {
 270			busy = true
 271			break
 272		}
 273	}
 274	return busy
 275}
 276
 277func (a *agent) IsSessionBusy(sessionID string) bool {
 278	_, busy := a.activeRequests.Get(sessionID)
 279	return busy
 280}
 281
 282func (a *agent) QueuedPrompts(sessionID string) int {
 283	l, ok := a.promptQueue.Get(sessionID)
 284	if !ok {
 285		return 0
 286	}
 287	return len(l)
 288}
 289
 290func (a *agent) generateTitle(ctx context.Context, sessionID string, content string) error {
 291	if content == "" {
 292		return nil
 293	}
 294	if a.titleProvider == nil {
 295		return nil
 296	}
 297	session, err := a.sessions.Get(ctx, sessionID)
 298	if err != nil {
 299		return err
 300	}
 301	parts := []message.ContentPart{message.TextContent{
 302		Text: fmt.Sprintf("Generate a concise title for the following content:\n\n%s", content),
 303	}}
 304
 305	// Use streaming approach like summarization
 306	response := a.titleProvider.StreamResponse(
 307		ctx,
 308		[]message.Message{
 309			{
 310				Role:  message.User,
 311				Parts: parts,
 312			},
 313		},
 314		nil,
 315	)
 316
 317	var finalResponse *provider.ProviderResponse
 318	for r := range response {
 319		if r.Error != nil {
 320			return r.Error
 321		}
 322		finalResponse = r.Response
 323	}
 324
 325	if finalResponse == nil {
 326		return fmt.Errorf("no response received from title provider")
 327	}
 328
 329	title := strings.ReplaceAll(finalResponse.Content, "\n", " ")
 330
 331	if idx := strings.Index(title, "</think>"); idx > 0 {
 332		title = title[idx+len("</think>"):]
 333	}
 334
 335	title = strings.TrimSpace(title)
 336	if title == "" {
 337		return nil
 338	}
 339
 340	session.Title = title
 341	_, err = a.sessions.Save(ctx, session)
 342	return err
 343}
 344
 345func (a *agent) err(err error) AgentEvent {
 346	return AgentEvent{
 347		Type:  AgentEventTypeError,
 348		Error: err,
 349	}
 350}
 351
 352func (a *agent) Run(ctx context.Context, sessionID string, content string, attachments ...message.Attachment) (<-chan AgentEvent, error) {
 353	if !a.Model().SupportsImages && attachments != nil {
 354		attachments = nil
 355	}
 356	events := make(chan AgentEvent, 1)
 357	if a.IsSessionBusy(sessionID) {
 358		existing, ok := a.promptQueue.Get(sessionID)
 359		if !ok {
 360			existing = []string{}
 361		}
 362		existing = append(existing, content)
 363		a.promptQueue.Set(sessionID, existing)
 364		return nil, nil
 365	}
 366
 367	genCtx, cancel := context.WithCancel(ctx)
 368
 369	a.activeRequests.Set(sessionID, cancel)
 370	go func() {
 371		slog.Debug("Request started", "sessionID", sessionID)
 372		defer log.RecoverPanic("agent.Run", func() {
 373			events <- a.err(fmt.Errorf("panic while running the agent"))
 374		})
 375		var attachmentParts []message.ContentPart
 376		for _, attachment := range attachments {
 377			attachmentParts = append(attachmentParts, message.BinaryContent{Path: attachment.FilePath, MIMEType: attachment.MimeType, Data: attachment.Content})
 378		}
 379		result := a.processGeneration(genCtx, sessionID, content, attachmentParts)
 380		if result.Error != nil && !errors.Is(result.Error, ErrRequestCancelled) && !errors.Is(result.Error, context.Canceled) {
 381			slog.Error(result.Error.Error())
 382		}
 383		slog.Debug("Request completed", "sessionID", sessionID)
 384		a.activeRequests.Del(sessionID)
 385		cancel()
 386		a.Publish(pubsub.CreatedEvent, result)
 387		events <- result
 388		close(events)
 389	}()
 390	return events, nil
 391}
 392
 393func (a *agent) processGeneration(ctx context.Context, sessionID, content string, attachmentParts []message.ContentPart) AgentEvent {
 394	cfg := config.Get()
 395	// List existing messages; if none, start title generation asynchronously.
 396	msgs, err := a.messages.List(ctx, sessionID)
 397	if err != nil {
 398		return a.err(fmt.Errorf("failed to list messages: %w", err))
 399	}
 400	if len(msgs) == 0 {
 401		go func() {
 402			defer log.RecoverPanic("agent.Run", func() {
 403				slog.Error("panic while generating title")
 404			})
 405			titleErr := a.generateTitle(ctx, sessionID, content)
 406			if titleErr != nil && !errors.Is(titleErr, context.Canceled) && !errors.Is(titleErr, context.DeadlineExceeded) {
 407				slog.Error("failed to generate title", "error", titleErr)
 408			}
 409		}()
 410	}
 411	session, err := a.sessions.Get(ctx, sessionID)
 412	if err != nil {
 413		return a.err(fmt.Errorf("failed to get session: %w", err))
 414	}
 415	if session.SummaryMessageID != "" {
 416		summaryMsgInex := -1
 417		for i, msg := range msgs {
 418			if msg.ID == session.SummaryMessageID {
 419				summaryMsgInex = i
 420				break
 421			}
 422		}
 423		if summaryMsgInex != -1 {
 424			msgs = msgs[summaryMsgInex:]
 425			msgs[0].Role = message.User
 426		}
 427	}
 428
 429	userMsg, err := a.createUserMessage(ctx, sessionID, content, attachmentParts)
 430	if err != nil {
 431		return a.err(fmt.Errorf("failed to create user message: %w", err))
 432	}
 433	// Append the new user message to the conversation history.
 434	msgHistory := append(msgs, userMsg)
 435
 436	for {
 437		// Check for cancellation before each iteration
 438		select {
 439		case <-ctx.Done():
 440			return a.err(ctx.Err())
 441		default:
 442			// Continue processing
 443		}
 444		agentMessage, toolResults, err := a.streamAndHandleEvents(ctx, sessionID, msgHistory)
 445		if err != nil {
 446			if errors.Is(err, context.Canceled) {
 447				agentMessage.AddFinish(message.FinishReasonCanceled, "Request cancelled", "")
 448				a.messages.Update(context.Background(), agentMessage)
 449				return a.err(ErrRequestCancelled)
 450			}
 451			return a.err(fmt.Errorf("failed to process events: %w", err))
 452		}
 453		if cfg.Options.Debug {
 454			slog.Info("Result", "message", agentMessage.FinishReason(), "toolResults", toolResults)
 455		}
 456		if (agentMessage.FinishReason() == message.FinishReasonToolUse) && toolResults != nil {
 457			// We are not done, we need to respond with the tool response
 458			msgHistory = append(msgHistory, agentMessage, *toolResults)
 459			// If there are queued prompts, process the next one
 460			nextPrompt, ok := a.promptQueue.Take(sessionID)
 461			if ok {
 462				for _, prompt := range nextPrompt {
 463					// Create a new user message for the queued prompt
 464					userMsg, err := a.createUserMessage(ctx, sessionID, prompt, nil)
 465					if err != nil {
 466						return a.err(fmt.Errorf("failed to create user message for queued prompt: %w", err))
 467					}
 468					// Append the new user message to the conversation history
 469					msgHistory = append(msgHistory, userMsg)
 470				}
 471			}
 472
 473			continue
 474		} else if agentMessage.FinishReason() == message.FinishReasonEndTurn {
 475			queuePrompts, ok := a.promptQueue.Take(sessionID)
 476			if ok {
 477				for _, prompt := range queuePrompts {
 478					if prompt == "" {
 479						continue
 480					}
 481					userMsg, err := a.createUserMessage(ctx, sessionID, prompt, nil)
 482					if err != nil {
 483						return a.err(fmt.Errorf("failed to create user message for queued prompt: %w", err))
 484					}
 485					msgHistory = append(msgHistory, userMsg)
 486				}
 487				continue
 488			}
 489		}
 490		if agentMessage.FinishReason() == "" {
 491			// Kujtim: could not track down where this is happening but this means its cancelled
 492			agentMessage.AddFinish(message.FinishReasonCanceled, "Request cancelled", "")
 493			_ = a.messages.Update(context.Background(), agentMessage)
 494			return a.err(ErrRequestCancelled)
 495		}
 496		return AgentEvent{
 497			Type:    AgentEventTypeResponse,
 498			Message: agentMessage,
 499			Done:    true,
 500		}
 501	}
 502}
 503
 504func (a *agent) createUserMessage(ctx context.Context, sessionID, content string, attachmentParts []message.ContentPart) (message.Message, error) {
 505	parts := []message.ContentPart{message.TextContent{Text: content}}
 506	parts = append(parts, attachmentParts...)
 507	return a.messages.Create(ctx, sessionID, message.CreateMessageParams{
 508		Role:  message.User,
 509		Parts: parts,
 510	})
 511}
 512
 513func (a *agent) getAllTools() ([]tools.BaseTool, error) {
 514	allTools := slices.Collect(a.tools.Seq())
 515	if a.agentToolFn != nil {
 516		agentTool, agentToolErr := a.agentToolFn()
 517		if agentToolErr != nil {
 518			return nil, agentToolErr
 519		}
 520		allTools = append(allTools, agentTool)
 521	}
 522	return allTools, nil
 523}
 524
 525func (a *agent) streamAndHandleEvents(ctx context.Context, sessionID string, msgHistory []message.Message) (message.Message, *message.Message, error) {
 526	ctx = context.WithValue(ctx, tools.SessionIDContextKey, sessionID)
 527
 528	// Create the assistant message first so the spinner shows immediately
 529	assistantMsg, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{
 530		Role:     message.Assistant,
 531		Parts:    []message.ContentPart{},
 532		Model:    a.Model().ID,
 533		Provider: a.providerID,
 534	})
 535	if err != nil {
 536		return assistantMsg, nil, fmt.Errorf("failed to create assistant message: %w", err)
 537	}
 538
 539	allTools, toolsErr := a.getAllTools()
 540	if toolsErr != nil {
 541		return assistantMsg, nil, toolsErr
 542	}
 543	// Now collect tools (which may block on MCP initialization)
 544	eventChan := a.provider.StreamResponse(ctx, msgHistory, allTools)
 545
 546	// Add the session and message ID into the context if needed by tools.
 547	ctx = context.WithValue(ctx, tools.MessageIDContextKey, assistantMsg.ID)
 548
 549	// Process each event in the stream.
 550	for event := range eventChan {
 551		if processErr := a.processEvent(ctx, sessionID, &assistantMsg, event); processErr != nil {
 552			if errors.Is(processErr, context.Canceled) {
 553				a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
 554			} else {
 555				a.finishMessage(ctx, &assistantMsg, message.FinishReasonError, "API Error", processErr.Error())
 556			}
 557			return assistantMsg, nil, processErr
 558		}
 559		if ctx.Err() != nil {
 560			a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
 561			return assistantMsg, nil, ctx.Err()
 562		}
 563	}
 564
 565	toolResults := make([]message.ToolResult, len(assistantMsg.ToolCalls()))
 566	toolCalls := assistantMsg.ToolCalls()
 567	for i, toolCall := range toolCalls {
 568		select {
 569		case <-ctx.Done():
 570			a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
 571			// Make all future tool calls cancelled
 572			for j := i; j < len(toolCalls); j++ {
 573				toolResults[j] = message.ToolResult{
 574					ToolCallID: toolCalls[j].ID,
 575					Content:    "Tool execution canceled by user",
 576					IsError:    true,
 577				}
 578			}
 579			goto out
 580		default:
 581			// Continue processing
 582			var tool tools.BaseTool
 583			allTools, _ := a.getAllTools()
 584			for _, availableTool := range allTools {
 585				if availableTool.Info().Name == toolCall.Name {
 586					tool = availableTool
 587					break
 588				}
 589			}
 590
 591			// Tool not found
 592			if tool == nil {
 593				toolResults[i] = message.ToolResult{
 594					ToolCallID: toolCall.ID,
 595					Content:    fmt.Sprintf("Tool not found: %s", toolCall.Name),
 596					IsError:    true,
 597				}
 598				continue
 599			}
 600
 601			// Run tool in goroutine to allow cancellation
 602			type toolExecResult struct {
 603				response tools.ToolResponse
 604				err      error
 605			}
 606			resultChan := make(chan toolExecResult, 1)
 607
 608			go func() {
 609				response, err := tool.Run(ctx, tools.ToolCall{
 610					ID:    toolCall.ID,
 611					Name:  toolCall.Name,
 612					Input: toolCall.Input,
 613				})
 614				resultChan <- toolExecResult{response: response, err: err}
 615			}()
 616
 617			var toolResponse tools.ToolResponse
 618			var toolErr error
 619
 620			select {
 621			case <-ctx.Done():
 622				a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
 623				// Mark remaining tool calls as cancelled
 624				for j := i; j < len(toolCalls); j++ {
 625					toolResults[j] = message.ToolResult{
 626						ToolCallID: toolCalls[j].ID,
 627						Content:    "Tool execution canceled by user",
 628						IsError:    true,
 629					}
 630				}
 631				goto out
 632			case result := <-resultChan:
 633				toolResponse = result.response
 634				toolErr = result.err
 635			}
 636
 637			if toolErr != nil {
 638				slog.Error("Tool execution error", "toolCall", toolCall.ID, "error", toolErr)
 639				if errors.Is(toolErr, permission.ErrorPermissionDenied) {
 640					toolResults[i] = message.ToolResult{
 641						ToolCallID: toolCall.ID,
 642						Content:    "Permission denied",
 643						IsError:    true,
 644					}
 645					for j := i + 1; j < len(toolCalls); j++ {
 646						toolResults[j] = message.ToolResult{
 647							ToolCallID: toolCalls[j].ID,
 648							Content:    "Tool execution canceled by user",
 649							IsError:    true,
 650						}
 651					}
 652					a.finishMessage(ctx, &assistantMsg, message.FinishReasonPermissionDenied, "Permission denied", "")
 653					break
 654				}
 655			}
 656			toolResults[i] = message.ToolResult{
 657				ToolCallID: toolCall.ID,
 658				Content:    toolResponse.Content,
 659				Metadata:   toolResponse.Metadata,
 660				IsError:    toolResponse.IsError,
 661			}
 662		}
 663	}
 664out:
 665	if len(toolResults) == 0 {
 666		return assistantMsg, nil, nil
 667	}
 668	parts := make([]message.ContentPart, 0)
 669	for _, tr := range toolResults {
 670		parts = append(parts, tr)
 671	}
 672	msg, err := a.messages.Create(context.Background(), assistantMsg.SessionID, message.CreateMessageParams{
 673		Role:     message.Tool,
 674		Parts:    parts,
 675		Provider: a.providerID,
 676	})
 677	if err != nil {
 678		return assistantMsg, nil, fmt.Errorf("failed to create cancelled tool message: %w", err)
 679	}
 680
 681	return assistantMsg, &msg, err
 682}
 683
 684func (a *agent) finishMessage(ctx context.Context, msg *message.Message, finishReason message.FinishReason, message, details string) {
 685	msg.AddFinish(finishReason, message, details)
 686	_ = a.messages.Update(ctx, *msg)
 687}
 688
 689func (a *agent) processEvent(ctx context.Context, sessionID string, assistantMsg *message.Message, event provider.ProviderEvent) error {
 690	select {
 691	case <-ctx.Done():
 692		return ctx.Err()
 693	default:
 694		// Continue processing.
 695	}
 696
 697	switch event.Type {
 698	case provider.EventThinkingDelta:
 699		assistantMsg.AppendReasoningContent(event.Thinking)
 700		return a.messages.Update(ctx, *assistantMsg)
 701	case provider.EventSignatureDelta:
 702		assistantMsg.AppendReasoningSignature(event.Signature)
 703		return a.messages.Update(ctx, *assistantMsg)
 704	case provider.EventContentDelta:
 705		assistantMsg.FinishThinking()
 706		assistantMsg.AppendContent(event.Content)
 707		return a.messages.Update(ctx, *assistantMsg)
 708	case provider.EventToolUseStart:
 709		assistantMsg.FinishThinking()
 710		slog.Info("Tool call started", "toolCall", event.ToolCall)
 711		assistantMsg.AddToolCall(*event.ToolCall)
 712		return a.messages.Update(ctx, *assistantMsg)
 713	case provider.EventToolUseDelta:
 714		assistantMsg.AppendToolCallInput(event.ToolCall.ID, event.ToolCall.Input)
 715		return a.messages.Update(ctx, *assistantMsg)
 716	case provider.EventToolUseStop:
 717		slog.Info("Finished tool call", "toolCall", event.ToolCall)
 718		assistantMsg.FinishToolCall(event.ToolCall.ID)
 719		return a.messages.Update(ctx, *assistantMsg)
 720	case provider.EventError:
 721		return event.Error
 722	case provider.EventComplete:
 723		assistantMsg.FinishThinking()
 724		assistantMsg.SetToolCalls(event.Response.ToolCalls)
 725		assistantMsg.AddFinish(event.Response.FinishReason, "", "")
 726		if err := a.messages.Update(ctx, *assistantMsg); err != nil {
 727			return fmt.Errorf("failed to update message: %w", err)
 728		}
 729		return a.TrackUsage(ctx, sessionID, a.Model(), event.Response.Usage)
 730	}
 731
 732	return nil
 733}
 734
 735func (a *agent) TrackUsage(ctx context.Context, sessionID string, model catwalk.Model, usage provider.TokenUsage) error {
 736	sess, err := a.sessions.Get(ctx, sessionID)
 737	if err != nil {
 738		return fmt.Errorf("failed to get session: %w", err)
 739	}
 740
 741	cost := model.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
 742		model.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
 743		model.CostPer1MIn/1e6*float64(usage.InputTokens) +
 744		model.CostPer1MOut/1e6*float64(usage.OutputTokens)
 745
 746	sess.Cost += cost
 747	sess.CompletionTokens = usage.OutputTokens + usage.CacheReadTokens
 748	sess.PromptTokens = usage.InputTokens + usage.CacheCreationTokens
 749
 750	_, err = a.sessions.Save(ctx, sess)
 751	if err != nil {
 752		return fmt.Errorf("failed to save session: %w", err)
 753	}
 754	return nil
 755}
 756
 757func (a *agent) Summarize(ctx context.Context, sessionID string) error {
 758	if a.summarizeProvider == nil {
 759		return fmt.Errorf("summarize provider not available")
 760	}
 761
 762	// Check if session is busy
 763	if a.IsSessionBusy(sessionID) {
 764		return ErrSessionBusy
 765	}
 766
 767	// Create a new context with cancellation
 768	summarizeCtx, cancel := context.WithCancel(ctx)
 769
 770	// Store the cancel function in activeRequests to allow cancellation
 771	a.activeRequests.Set(sessionID+"-summarize", cancel)
 772
 773	go func() {
 774		defer a.activeRequests.Del(sessionID + "-summarize")
 775		defer cancel()
 776		event := AgentEvent{
 777			Type:     AgentEventTypeSummarize,
 778			Progress: "Starting summarization...",
 779		}
 780
 781		a.Publish(pubsub.CreatedEvent, event)
 782		// Get all messages from the session
 783		msgs, err := a.messages.List(summarizeCtx, sessionID)
 784		if err != nil {
 785			event = AgentEvent{
 786				Type:  AgentEventTypeError,
 787				Error: fmt.Errorf("failed to list messages: %w", err),
 788				Done:  true,
 789			}
 790			a.Publish(pubsub.CreatedEvent, event)
 791			return
 792		}
 793		summarizeCtx = context.WithValue(summarizeCtx, tools.SessionIDContextKey, sessionID)
 794
 795		if len(msgs) == 0 {
 796			event = AgentEvent{
 797				Type:  AgentEventTypeError,
 798				Error: fmt.Errorf("no messages to summarize"),
 799				Done:  true,
 800			}
 801			a.Publish(pubsub.CreatedEvent, event)
 802			return
 803		}
 804
 805		event = AgentEvent{
 806			Type:     AgentEventTypeSummarize,
 807			Progress: "Analyzing conversation...",
 808		}
 809		a.Publish(pubsub.CreatedEvent, event)
 810
 811		// Add a system message to guide the summarization
 812		summarizePrompt := "Provide a detailed but concise summary of our conversation above. Focus on information that would be helpful for continuing the conversation, including what we did, what we're doing, which files we're working on, and what we're going to do next."
 813
 814		// Create a new message with the summarize prompt
 815		promptMsg := message.Message{
 816			Role:  message.User,
 817			Parts: []message.ContentPart{message.TextContent{Text: summarizePrompt}},
 818		}
 819
 820		// Append the prompt to the messages
 821		msgsWithPrompt := append(msgs, promptMsg)
 822
 823		event = AgentEvent{
 824			Type:     AgentEventTypeSummarize,
 825			Progress: "Generating summary...",
 826		}
 827
 828		a.Publish(pubsub.CreatedEvent, event)
 829
 830		// Send the messages to the summarize provider
 831		response := a.summarizeProvider.StreamResponse(
 832			summarizeCtx,
 833			msgsWithPrompt,
 834			nil,
 835		)
 836		var finalResponse *provider.ProviderResponse
 837		for r := range response {
 838			if r.Error != nil {
 839				event = AgentEvent{
 840					Type:  AgentEventTypeError,
 841					Error: fmt.Errorf("failed to summarize: %w", r.Error),
 842					Done:  true,
 843				}
 844				a.Publish(pubsub.CreatedEvent, event)
 845				return
 846			}
 847			finalResponse = r.Response
 848		}
 849
 850		summary := strings.TrimSpace(finalResponse.Content)
 851		if summary == "" {
 852			event = AgentEvent{
 853				Type:  AgentEventTypeError,
 854				Error: fmt.Errorf("empty summary returned"),
 855				Done:  true,
 856			}
 857			a.Publish(pubsub.CreatedEvent, event)
 858			return
 859		}
 860		shell := shell.GetPersistentShell(config.Get().WorkingDir())
 861		summary += "\n\n**Current working directory of the persistent shell**\n\n" + shell.GetWorkingDir()
 862		event = AgentEvent{
 863			Type:     AgentEventTypeSummarize,
 864			Progress: "Creating new session...",
 865		}
 866
 867		a.Publish(pubsub.CreatedEvent, event)
 868		oldSession, err := a.sessions.Get(summarizeCtx, sessionID)
 869		if err != nil {
 870			event = AgentEvent{
 871				Type:  AgentEventTypeError,
 872				Error: fmt.Errorf("failed to get session: %w", err),
 873				Done:  true,
 874			}
 875
 876			a.Publish(pubsub.CreatedEvent, event)
 877			return
 878		}
 879		// Create a message in the new session with the summary
 880		msg, err := a.messages.Create(summarizeCtx, oldSession.ID, message.CreateMessageParams{
 881			Role: message.Assistant,
 882			Parts: []message.ContentPart{
 883				message.TextContent{Text: summary},
 884				message.Finish{
 885					Reason: message.FinishReasonEndTurn,
 886					Time:   time.Now().Unix(),
 887				},
 888			},
 889			Model:    a.summarizeProvider.Model().ID,
 890			Provider: a.summarizeProviderID,
 891		})
 892		if err != nil {
 893			event = AgentEvent{
 894				Type:  AgentEventTypeError,
 895				Error: fmt.Errorf("failed to create summary message: %w", err),
 896				Done:  true,
 897			}
 898
 899			a.Publish(pubsub.CreatedEvent, event)
 900			return
 901		}
 902		oldSession.SummaryMessageID = msg.ID
 903		oldSession.CompletionTokens = finalResponse.Usage.OutputTokens
 904		oldSession.PromptTokens = 0
 905		model := a.summarizeProvider.Model()
 906		usage := finalResponse.Usage
 907		cost := model.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
 908			model.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
 909			model.CostPer1MIn/1e6*float64(usage.InputTokens) +
 910			model.CostPer1MOut/1e6*float64(usage.OutputTokens)
 911		oldSession.Cost += cost
 912		_, err = a.sessions.Save(summarizeCtx, oldSession)
 913		if err != nil {
 914			event = AgentEvent{
 915				Type:  AgentEventTypeError,
 916				Error: fmt.Errorf("failed to save session: %w", err),
 917				Done:  true,
 918			}
 919			a.Publish(pubsub.CreatedEvent, event)
 920		}
 921
 922		event = AgentEvent{
 923			Type:      AgentEventTypeSummarize,
 924			SessionID: oldSession.ID,
 925			Progress:  "Summary complete",
 926			Done:      true,
 927		}
 928		a.Publish(pubsub.CreatedEvent, event)
 929		// Send final success event with the new session ID
 930	}()
 931
 932	return nil
 933}
 934
 935func (a *agent) ClearQueue(sessionID string) {
 936	if a.QueuedPrompts(sessionID) > 0 {
 937		slog.Info("Clearing queued prompts", "session_id", sessionID)
 938		a.promptQueue.Del(sessionID)
 939	}
 940}
 941
 942func (a *agent) CancelAll() {
 943	if !a.IsBusy() {
 944		return
 945	}
 946	for key := range a.activeRequests.Seq2() {
 947		a.Cancel(key) // key is sessionID
 948	}
 949
 950	timeout := time.After(5 * time.Second)
 951	for a.IsBusy() {
 952		select {
 953		case <-timeout:
 954			return
 955		default:
 956			time.Sleep(200 * time.Millisecond)
 957		}
 958	}
 959}
 960
 961func (a *agent) UpdateModel() error {
 962	cfg := config.Get()
 963
 964	// Get current provider configuration
 965	currentProviderCfg := cfg.GetProviderForModel(a.agentCfg.Model)
 966	if currentProviderCfg == nil || currentProviderCfg.ID == "" {
 967		return fmt.Errorf("provider for agent %s not found in config", a.agentCfg.Name)
 968	}
 969
 970	// Check if provider has changed
 971	if string(currentProviderCfg.ID) != a.providerID {
 972		// Provider changed, need to recreate the main provider
 973		model := cfg.GetModelByType(a.agentCfg.Model)
 974		if model.ID == "" {
 975			return fmt.Errorf("model not found for agent %s", a.agentCfg.Name)
 976		}
 977
 978		promptID := agentPromptMap[a.agentCfg.ID]
 979		if promptID == "" {
 980			promptID = prompt.PromptDefault
 981		}
 982
 983		opts := []provider.ProviderClientOption{
 984			provider.WithModel(a.agentCfg.Model),
 985			provider.WithSystemMessage(prompt.GetPrompt(promptID, currentProviderCfg.ID, cfg.Options.ContextPaths...)),
 986		}
 987
 988		newProvider, err := provider.NewProvider(*currentProviderCfg, opts...)
 989		if err != nil {
 990			return fmt.Errorf("failed to create new provider: %w", err)
 991		}
 992
 993		// Update the provider and provider ID
 994		a.provider = newProvider
 995		a.providerID = string(currentProviderCfg.ID)
 996	}
 997
 998	// Check if providers have changed for title (small) and summarize (large)
 999	smallModelCfg := cfg.Models[config.SelectedModelTypeSmall]
1000	var smallModelProviderCfg config.ProviderConfig
1001	for p := range cfg.Providers.Seq() {
1002		if p.ID == smallModelCfg.Provider {
1003			smallModelProviderCfg = p
1004			break
1005		}
1006	}
1007	if smallModelProviderCfg.ID == "" {
1008		return fmt.Errorf("provider %s not found in config", smallModelCfg.Provider)
1009	}
1010
1011	largeModelCfg := cfg.Models[config.SelectedModelTypeLarge]
1012	var largeModelProviderCfg config.ProviderConfig
1013	for p := range cfg.Providers.Seq() {
1014		if p.ID == largeModelCfg.Provider {
1015			largeModelProviderCfg = p
1016			break
1017		}
1018	}
1019	if largeModelProviderCfg.ID == "" {
1020		return fmt.Errorf("provider %s not found in config", largeModelCfg.Provider)
1021	}
1022
1023	var maxTitleTokens int64 = 40
1024
1025	// if the max output is too low for the gemini provider it won't return anything
1026	if smallModelCfg.Provider == "gemini" {
1027		maxTitleTokens = 1000
1028	}
1029	// Recreate title provider
1030	titleOpts := []provider.ProviderClientOption{
1031		provider.WithModel(config.SelectedModelTypeSmall),
1032		provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptTitle, smallModelProviderCfg.ID)),
1033		provider.WithMaxTokens(maxTitleTokens),
1034	}
1035	newTitleProvider, err := provider.NewProvider(smallModelProviderCfg, titleOpts...)
1036	if err != nil {
1037		return fmt.Errorf("failed to create new title provider: %w", err)
1038	}
1039	a.titleProvider = newTitleProvider
1040
1041	// Recreate summarize provider if provider changed (now large model)
1042	if string(largeModelProviderCfg.ID) != a.summarizeProviderID {
1043		largeModel := cfg.GetModelByType(config.SelectedModelTypeLarge)
1044		if largeModel == nil {
1045			return fmt.Errorf("model %s not found in provider %s", largeModelCfg.Model, largeModelProviderCfg.ID)
1046		}
1047		summarizeOpts := []provider.ProviderClientOption{
1048			provider.WithModel(config.SelectedModelTypeLarge),
1049			provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptSummarizer, largeModelProviderCfg.ID)),
1050		}
1051		newSummarizeProvider, err := provider.NewProvider(largeModelProviderCfg, summarizeOpts...)
1052		if err != nil {
1053			return fmt.Errorf("failed to create new summarize provider: %w", err)
1054		}
1055		a.summarizeProvider = newSummarizeProvider
1056		a.summarizeProviderID = string(largeModelProviderCfg.ID)
1057	}
1058
1059	return nil
1060}