agent.go

   1package agent
   2
   3import (
   4	"context"
   5	"errors"
   6	"fmt"
   7	"log/slog"
   8	"slices"
   9	"strings"
  10	"time"
  11
  12	"github.com/charmbracelet/catwalk/pkg/catwalk"
  13	"github.com/charmbracelet/crush/internal/config"
  14	"github.com/charmbracelet/crush/internal/csync"
  15	"github.com/charmbracelet/crush/internal/history"
  16	"github.com/charmbracelet/crush/internal/llm/prompt"
  17	"github.com/charmbracelet/crush/internal/llm/provider"
  18	"github.com/charmbracelet/crush/internal/llm/tools"
  19	"github.com/charmbracelet/crush/internal/log"
  20	"github.com/charmbracelet/crush/internal/lsp"
  21	"github.com/charmbracelet/crush/internal/message"
  22	"github.com/charmbracelet/crush/internal/permission"
  23	"github.com/charmbracelet/crush/internal/pubsub"
  24	"github.com/charmbracelet/crush/internal/session"
  25	"github.com/charmbracelet/crush/internal/shell"
  26)
  27
  28// Common errors
  29var (
  30	ErrRequestCancelled = errors.New("request canceled by user")
  31	ErrSessionBusy      = errors.New("session is currently processing another request")
  32)
  33
  34type AgentEventType string
  35
  36const (
  37	AgentEventTypeError     AgentEventType = "error"
  38	AgentEventTypeResponse  AgentEventType = "response"
  39	AgentEventTypeSummarize AgentEventType = "summarize"
  40)
  41
  42type AgentEvent struct {
  43	Type    AgentEventType
  44	Message message.Message
  45	Error   error
  46
  47	// When summarizing
  48	SessionID string
  49	Progress  string
  50	Done      bool
  51}
  52
  53type Service interface {
  54	pubsub.Suscriber[AgentEvent]
  55	Model() catwalk.Model
  56	Run(ctx context.Context, sessionID string, content string, attachments ...message.Attachment) (<-chan AgentEvent, error)
  57	Cancel(sessionID string)
  58	CancelAll()
  59	IsSessionBusy(sessionID string) bool
  60	IsBusy() bool
  61	Summarize(ctx context.Context, sessionID string) error
  62	UpdateModel() error
  63	QueuedPrompts(sessionID string) int
  64	ClearQueue(sessionID string)
  65}
  66
  67type agent struct {
  68	*pubsub.Broker[AgentEvent]
  69	agentCfg config.Agent
  70	sessions session.Service
  71	messages message.Service
  72	mcpTools []McpTool
  73
  74	tools *csync.LazySlice[tools.BaseTool]
  75	// We need this to be able to update it when model changes
  76	agentToolFn func() (tools.BaseTool, error)
  77
  78	provider   provider.Provider
  79	providerID string
  80
  81	titleProvider       provider.Provider
  82	summarizeProvider   provider.Provider
  83	summarizeProviderID string
  84
  85	activeRequests *csync.Map[string, context.CancelFunc]
  86
  87	promptQueue *csync.Map[string, []string]
  88}
  89
  90var agentPromptMap = map[string]prompt.PromptID{
  91	"coder": prompt.PromptCoder,
  92	"task":  prompt.PromptTask,
  93}
  94
  95func NewAgent(
  96	ctx context.Context,
  97	agentCfg config.Agent,
  98	// These services are needed in the tools
  99	permissions permission.Service,
 100	sessions session.Service,
 101	messages message.Service,
 102	history history.Service,
 103	lspClients map[string]*lsp.Client,
 104) (Service, error) {
 105	cfg := config.Get()
 106
 107	var agentToolFn func() (tools.BaseTool, error)
 108	if agentCfg.ID == "coder" {
 109		agentToolFn = func() (tools.BaseTool, error) {
 110			taskAgentCfg := config.Get().Agents["task"]
 111			if taskAgentCfg.ID == "" {
 112				return nil, fmt.Errorf("task agent not found in config")
 113			}
 114			taskAgent, err := NewAgent(ctx, taskAgentCfg, permissions, sessions, messages, history, lspClients)
 115			if err != nil {
 116				return nil, fmt.Errorf("failed to create task agent: %w", err)
 117			}
 118			return NewAgentTool(taskAgent, sessions, messages), nil
 119		}
 120	}
 121
 122	providerCfg := config.Get().GetProviderForModel(agentCfg.Model)
 123	if providerCfg == nil {
 124		return nil, fmt.Errorf("provider for agent %s not found in config", agentCfg.Name)
 125	}
 126	model := config.Get().GetModelByType(agentCfg.Model)
 127
 128	if model == nil {
 129		return nil, fmt.Errorf("model not found for agent %s", agentCfg.Name)
 130	}
 131
 132	promptID := agentPromptMap[agentCfg.ID]
 133	if promptID == "" {
 134		promptID = prompt.PromptDefault
 135	}
 136	opts := []provider.ProviderClientOption{
 137		provider.WithModel(agentCfg.Model),
 138		provider.WithSystemMessage(prompt.GetPrompt(promptID, providerCfg.ID, config.Get().Options.ContextPaths...)),
 139	}
 140	agentProvider, err := provider.NewProvider(*providerCfg, opts...)
 141	if err != nil {
 142		return nil, err
 143	}
 144
 145	smallModelCfg := cfg.Models[config.SelectedModelTypeSmall]
 146	var smallModelProviderCfg *config.ProviderConfig
 147	if smallModelCfg.Provider == providerCfg.ID {
 148		smallModelProviderCfg = providerCfg
 149	} else {
 150		smallModelProviderCfg = cfg.GetProviderForModel(config.SelectedModelTypeSmall)
 151
 152		if smallModelProviderCfg.ID == "" {
 153			return nil, fmt.Errorf("provider %s not found in config", smallModelCfg.Provider)
 154		}
 155	}
 156	smallModel := cfg.GetModelByType(config.SelectedModelTypeSmall)
 157	if smallModel.ID == "" {
 158		return nil, fmt.Errorf("model %s not found in provider %s", smallModelCfg.Model, smallModelProviderCfg.ID)
 159	}
 160
 161	titleOpts := []provider.ProviderClientOption{
 162		provider.WithModel(config.SelectedModelTypeSmall),
 163		provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptTitle, smallModelProviderCfg.ID)),
 164	}
 165	titleProvider, err := provider.NewProvider(*smallModelProviderCfg, titleOpts...)
 166	if err != nil {
 167		return nil, err
 168	}
 169
 170	summarizeOpts := []provider.ProviderClientOption{
 171		provider.WithModel(config.SelectedModelTypeLarge),
 172		provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptSummarizer, providerCfg.ID)),
 173	}
 174	summarizeProvider, err := provider.NewProvider(*providerCfg, summarizeOpts...)
 175	if err != nil {
 176		return nil, err
 177	}
 178
 179	toolFn := func() []tools.BaseTool {
 180		slog.Info("Initializing agent tools", "agent", agentCfg.ID)
 181		defer func() {
 182			slog.Info("Initialized agent tools", "agent", agentCfg.ID)
 183		}()
 184
 185		cwd := cfg.WorkingDir()
 186		allTools := []tools.BaseTool{
 187			tools.NewBashTool(permissions, cwd),
 188			tools.NewDownloadTool(permissions, cwd),
 189			tools.NewEditTool(lspClients, permissions, history, cwd),
 190			tools.NewMultiEditTool(lspClients, permissions, history, cwd),
 191			tools.NewFetchTool(permissions, cwd),
 192			tools.NewGlobTool(cwd),
 193			tools.NewGrepTool(cwd),
 194			tools.NewLsTool(permissions, cwd),
 195			tools.NewSourcegraphTool(),
 196			tools.NewViewTool(lspClients, permissions, cwd),
 197			tools.NewWriteTool(lspClients, permissions, history, cwd),
 198		}
 199
 200		mcpToolsOnce.Do(func() {
 201			mcpTools = doGetMCPTools(ctx, permissions, cfg)
 202		})
 203
 204		withCoderTools := func(t []tools.BaseTool) []tools.BaseTool {
 205			if agentCfg.ID == "coder" {
 206				t = append(t, mcpTools...)
 207				if len(lspClients) > 0 {
 208					t = append(t, tools.NewDiagnosticsTool(lspClients))
 209				}
 210			}
 211			return t
 212		}
 213
 214		if agentCfg.AllowedTools == nil {
 215			return withCoderTools(allTools)
 216		}
 217
 218		var filteredTools []tools.BaseTool
 219		for _, tool := range allTools {
 220			if slices.Contains(agentCfg.AllowedTools, tool.Name()) {
 221				filteredTools = append(filteredTools, tool)
 222			}
 223		}
 224
 225		if agentCfg.ID == "coder" {
 226			filteredTools = append(filteredTools, mcpTools...)
 227		}
 228		return withCoderTools(filteredTools)
 229	}
 230
 231	return &agent{
 232		Broker:              pubsub.NewBroker[AgentEvent](),
 233		agentCfg:            agentCfg,
 234		provider:            agentProvider,
 235		providerID:          string(providerCfg.ID),
 236		messages:            messages,
 237		sessions:            sessions,
 238		titleProvider:       titleProvider,
 239		summarizeProvider:   summarizeProvider,
 240		summarizeProviderID: string(providerCfg.ID),
 241		agentToolFn:         agentToolFn,
 242		activeRequests:      csync.NewMap[string, context.CancelFunc](),
 243		tools:               csync.NewLazySlice(toolFn),
 244		promptQueue:         csync.NewMap[string, []string](),
 245	}, nil
 246}
 247
 248func (a *agent) Model() catwalk.Model {
 249	return *config.Get().GetModelByType(a.agentCfg.Model)
 250}
 251
 252func (a *agent) Cancel(sessionID string) {
 253	// Cancel regular requests
 254	if cancel, ok := a.activeRequests.Take(sessionID); ok && cancel != nil {
 255		slog.Info("Request cancellation initiated", "session_id", sessionID)
 256		cancel()
 257	}
 258
 259	// Also check for summarize requests
 260	if cancel, ok := a.activeRequests.Take(sessionID + "-summarize"); ok && cancel != nil {
 261		slog.Info("Summarize cancellation initiated", "session_id", sessionID)
 262		cancel()
 263	}
 264
 265	if a.QueuedPrompts(sessionID) > 0 {
 266		slog.Info("Clearing queued prompts", "session_id", sessionID)
 267		a.promptQueue.Del(sessionID)
 268	}
 269}
 270
 271func (a *agent) IsBusy() bool {
 272	var busy bool
 273	for cancelFunc := range a.activeRequests.Seq() {
 274		if cancelFunc != nil {
 275			busy = true
 276			break
 277		}
 278	}
 279	return busy
 280}
 281
 282func (a *agent) IsSessionBusy(sessionID string) bool {
 283	_, busy := a.activeRequests.Get(sessionID)
 284	return busy
 285}
 286
 287func (a *agent) QueuedPrompts(sessionID string) int {
 288	l, ok := a.promptQueue.Get(sessionID)
 289	if !ok {
 290		return 0
 291	}
 292	return len(l)
 293}
 294
 295func (a *agent) generateTitle(ctx context.Context, sessionID string, content string) error {
 296	if content == "" {
 297		return nil
 298	}
 299	if a.titleProvider == nil {
 300		return nil
 301	}
 302	session, err := a.sessions.Get(ctx, sessionID)
 303	if err != nil {
 304		return err
 305	}
 306	parts := []message.ContentPart{message.TextContent{
 307		Text: fmt.Sprintf("Generate a concise title for the following content:\n\n%s", content),
 308	}}
 309
 310	// Use streaming approach like summarization
 311	response := a.titleProvider.StreamResponse(
 312		ctx,
 313		[]message.Message{
 314			{
 315				Role:  message.User,
 316				Parts: parts,
 317			},
 318		},
 319		nil,
 320	)
 321
 322	var finalResponse *provider.ProviderResponse
 323	for r := range response {
 324		if r.Error != nil {
 325			return r.Error
 326		}
 327		finalResponse = r.Response
 328	}
 329
 330	if finalResponse == nil {
 331		return fmt.Errorf("no response received from title provider")
 332	}
 333
 334	title := strings.TrimSpace(strings.ReplaceAll(finalResponse.Content, "\n", " "))
 335	if title == "" {
 336		return nil
 337	}
 338
 339	session.Title = title
 340	_, err = a.sessions.Save(ctx, session)
 341	return err
 342}
 343
 344func (a *agent) err(err error) AgentEvent {
 345	return AgentEvent{
 346		Type:  AgentEventTypeError,
 347		Error: err,
 348	}
 349}
 350
 351func (a *agent) Run(ctx context.Context, sessionID string, content string, attachments ...message.Attachment) (<-chan AgentEvent, error) {
 352	if !a.Model().SupportsImages && attachments != nil {
 353		attachments = nil
 354	}
 355	events := make(chan AgentEvent, 1)
 356	if a.IsSessionBusy(sessionID) {
 357		existing, ok := a.promptQueue.Get(sessionID)
 358		if !ok {
 359			existing = []string{}
 360		}
 361		existing = append(existing, content)
 362		a.promptQueue.Set(sessionID, existing)
 363		return nil, nil
 364	}
 365
 366	genCtx, cancel := context.WithCancel(ctx)
 367
 368	a.activeRequests.Set(sessionID, cancel)
 369	go func() {
 370		slog.Debug("Request started", "sessionID", sessionID)
 371		defer log.RecoverPanic("agent.Run", func() {
 372			events <- a.err(fmt.Errorf("panic while running the agent"))
 373		})
 374		var attachmentParts []message.ContentPart
 375		for _, attachment := range attachments {
 376			attachmentParts = append(attachmentParts, message.BinaryContent{Path: attachment.FilePath, MIMEType: attachment.MimeType, Data: attachment.Content})
 377		}
 378		result := a.processGeneration(genCtx, sessionID, content, attachmentParts)
 379		if result.Error != nil && !errors.Is(result.Error, ErrRequestCancelled) && !errors.Is(result.Error, context.Canceled) {
 380			slog.Error(result.Error.Error())
 381		}
 382		slog.Debug("Request completed", "sessionID", sessionID)
 383		a.activeRequests.Del(sessionID)
 384		cancel()
 385		a.Publish(pubsub.CreatedEvent, result)
 386		events <- result
 387		close(events)
 388	}()
 389	return events, nil
 390}
 391
 392func (a *agent) processGeneration(ctx context.Context, sessionID, content string, attachmentParts []message.ContentPart) AgentEvent {
 393	cfg := config.Get()
 394	// List existing messages; if none, start title generation asynchronously.
 395	msgs, err := a.messages.List(ctx, sessionID)
 396	if err != nil {
 397		return a.err(fmt.Errorf("failed to list messages: %w", err))
 398	}
 399	if len(msgs) == 0 {
 400		go func() {
 401			defer log.RecoverPanic("agent.Run", func() {
 402				slog.Error("panic while generating title")
 403			})
 404			titleErr := a.generateTitle(ctx, sessionID, content)
 405			if titleErr != nil && !errors.Is(titleErr, context.Canceled) && !errors.Is(titleErr, context.DeadlineExceeded) {
 406				slog.Error("failed to generate title", "error", titleErr)
 407			}
 408		}()
 409	}
 410	session, err := a.sessions.Get(ctx, sessionID)
 411	if err != nil {
 412		return a.err(fmt.Errorf("failed to get session: %w", err))
 413	}
 414	if session.SummaryMessageID != "" {
 415		summaryMsgInex := -1
 416		for i, msg := range msgs {
 417			if msg.ID == session.SummaryMessageID {
 418				summaryMsgInex = i
 419				break
 420			}
 421		}
 422		if summaryMsgInex != -1 {
 423			msgs = msgs[summaryMsgInex:]
 424			msgs[0].Role = message.User
 425		}
 426	}
 427
 428	userMsg, err := a.createUserMessage(ctx, sessionID, content, attachmentParts)
 429	if err != nil {
 430		return a.err(fmt.Errorf("failed to create user message: %w", err))
 431	}
 432	// Append the new user message to the conversation history.
 433	msgHistory := append(msgs, userMsg)
 434
 435	for {
 436		// Check for cancellation before each iteration
 437		select {
 438		case <-ctx.Done():
 439			return a.err(ctx.Err())
 440		default:
 441			// Continue processing
 442		}
 443		agentMessage, toolResults, err := a.streamAndHandleEvents(ctx, sessionID, msgHistory)
 444		if err != nil {
 445			if errors.Is(err, context.Canceled) {
 446				agentMessage.AddFinish(message.FinishReasonCanceled, "Request cancelled", "")
 447				a.messages.Update(context.Background(), agentMessage)
 448				return a.err(ErrRequestCancelled)
 449			}
 450			return a.err(fmt.Errorf("failed to process events: %w", err))
 451		}
 452		if cfg.Options.Debug {
 453			slog.Info("Result", "message", agentMessage.FinishReason(), "toolResults", toolResults)
 454		}
 455		if (agentMessage.FinishReason() == message.FinishReasonToolUse) && toolResults != nil {
 456			// We are not done, we need to respond with the tool response
 457			msgHistory = append(msgHistory, agentMessage, *toolResults)
 458			// If there are queued prompts, process the next one
 459			nextPrompt, ok := a.promptQueue.Take(sessionID)
 460			if ok {
 461				for _, prompt := range nextPrompt {
 462					// Create a new user message for the queued prompt
 463					userMsg, err := a.createUserMessage(ctx, sessionID, prompt, nil)
 464					if err != nil {
 465						return a.err(fmt.Errorf("failed to create user message for queued prompt: %w", err))
 466					}
 467					// Append the new user message to the conversation history
 468					msgHistory = append(msgHistory, userMsg)
 469				}
 470			}
 471
 472			continue
 473		} else if agentMessage.FinishReason() == message.FinishReasonEndTurn {
 474			queuePrompts, ok := a.promptQueue.Take(sessionID)
 475			if ok {
 476				for _, prompt := range queuePrompts {
 477					if prompt == "" {
 478						continue
 479					}
 480					userMsg, err := a.createUserMessage(ctx, sessionID, prompt, nil)
 481					if err != nil {
 482						return a.err(fmt.Errorf("failed to create user message for queued prompt: %w", err))
 483					}
 484					msgHistory = append(msgHistory, userMsg)
 485				}
 486				continue
 487			}
 488		}
 489		if agentMessage.FinishReason() == "" {
 490			// Kujtim: could not track down where this is happening but this means its cancelled
 491			agentMessage.AddFinish(message.FinishReasonCanceled, "Request cancelled", "")
 492			_ = a.messages.Update(context.Background(), agentMessage)
 493			return a.err(ErrRequestCancelled)
 494		}
 495		return AgentEvent{
 496			Type:    AgentEventTypeResponse,
 497			Message: agentMessage,
 498			Done:    true,
 499		}
 500	}
 501}
 502
 503func (a *agent) createUserMessage(ctx context.Context, sessionID, content string, attachmentParts []message.ContentPart) (message.Message, error) {
 504	parts := []message.ContentPart{message.TextContent{Text: content}}
 505	parts = append(parts, attachmentParts...)
 506	return a.messages.Create(ctx, sessionID, message.CreateMessageParams{
 507		Role:  message.User,
 508		Parts: parts,
 509	})
 510}
 511
 512func (a *agent) getAllTools() ([]tools.BaseTool, error) {
 513	allTools := slices.Collect(a.tools.Seq())
 514	if a.agentToolFn != nil {
 515		agentTool, agentToolErr := a.agentToolFn()
 516		if agentToolErr != nil {
 517			return nil, agentToolErr
 518		}
 519		allTools = append(allTools, agentTool)
 520	}
 521	return allTools, nil
 522}
 523
 524func (a *agent) streamAndHandleEvents(ctx context.Context, sessionID string, msgHistory []message.Message) (message.Message, *message.Message, error) {
 525	ctx = context.WithValue(ctx, tools.SessionIDContextKey, sessionID)
 526
 527	// Create the assistant message first so the spinner shows immediately
 528	assistantMsg, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{
 529		Role:     message.Assistant,
 530		Parts:    []message.ContentPart{},
 531		Model:    a.Model().ID,
 532		Provider: a.providerID,
 533	})
 534	if err != nil {
 535		return assistantMsg, nil, fmt.Errorf("failed to create assistant message: %w", err)
 536	}
 537
 538	allTools, toolsErr := a.getAllTools()
 539	if toolsErr != nil {
 540		return assistantMsg, nil, toolsErr
 541	}
 542	// Now collect tools (which may block on MCP initialization)
 543	eventChan := a.provider.StreamResponse(ctx, msgHistory, allTools)
 544
 545	// Add the session and message ID into the context if needed by tools.
 546	ctx = context.WithValue(ctx, tools.MessageIDContextKey, assistantMsg.ID)
 547
 548	// Process each event in the stream.
 549	for event := range eventChan {
 550		if processErr := a.processEvent(ctx, sessionID, &assistantMsg, event); processErr != nil {
 551			if errors.Is(processErr, context.Canceled) {
 552				a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
 553			} else {
 554				a.finishMessage(ctx, &assistantMsg, message.FinishReasonError, "API Error", processErr.Error())
 555			}
 556			return assistantMsg, nil, processErr
 557		}
 558		if ctx.Err() != nil {
 559			a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
 560			return assistantMsg, nil, ctx.Err()
 561		}
 562	}
 563
 564	toolResults := make([]message.ToolResult, len(assistantMsg.ToolCalls()))
 565	toolCalls := assistantMsg.ToolCalls()
 566	for i, toolCall := range toolCalls {
 567		select {
 568		case <-ctx.Done():
 569			a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
 570			// Make all future tool calls cancelled
 571			for j := i; j < len(toolCalls); j++ {
 572				toolResults[j] = message.ToolResult{
 573					ToolCallID: toolCalls[j].ID,
 574					Content:    "Tool execution canceled by user",
 575					IsError:    true,
 576				}
 577			}
 578			goto out
 579		default:
 580			// Continue processing
 581			var tool tools.BaseTool
 582			allTools, _ := a.getAllTools()
 583			for _, availableTool := range allTools {
 584				if availableTool.Info().Name == toolCall.Name {
 585					tool = availableTool
 586					break
 587				}
 588			}
 589
 590			// Tool not found
 591			if tool == nil {
 592				toolResults[i] = message.ToolResult{
 593					ToolCallID: toolCall.ID,
 594					Content:    fmt.Sprintf("Tool not found: %s", toolCall.Name),
 595					IsError:    true,
 596				}
 597				continue
 598			}
 599
 600			// Run tool in goroutine to allow cancellation
 601			type toolExecResult struct {
 602				response tools.ToolResponse
 603				err      error
 604			}
 605			resultChan := make(chan toolExecResult, 1)
 606
 607			go func() {
 608				response, err := tool.Run(ctx, tools.ToolCall{
 609					ID:    toolCall.ID,
 610					Name:  toolCall.Name,
 611					Input: toolCall.Input,
 612				})
 613				resultChan <- toolExecResult{response: response, err: err}
 614			}()
 615
 616			var toolResponse tools.ToolResponse
 617			var toolErr error
 618
 619			select {
 620			case <-ctx.Done():
 621				a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
 622				// Mark remaining tool calls as cancelled
 623				for j := i; j < len(toolCalls); j++ {
 624					toolResults[j] = message.ToolResult{
 625						ToolCallID: toolCalls[j].ID,
 626						Content:    "Tool execution canceled by user",
 627						IsError:    true,
 628					}
 629				}
 630				goto out
 631			case result := <-resultChan:
 632				toolResponse = result.response
 633				toolErr = result.err
 634			}
 635
 636			if toolErr != nil {
 637				slog.Error("Tool execution error", "toolCall", toolCall.ID, "error", toolErr)
 638				if errors.Is(toolErr, permission.ErrorPermissionDenied) {
 639					toolResults[i] = message.ToolResult{
 640						ToolCallID: toolCall.ID,
 641						Content:    "Permission denied",
 642						IsError:    true,
 643					}
 644					for j := i + 1; j < len(toolCalls); j++ {
 645						toolResults[j] = message.ToolResult{
 646							ToolCallID: toolCalls[j].ID,
 647							Content:    "Tool execution canceled by user",
 648							IsError:    true,
 649						}
 650					}
 651					a.finishMessage(ctx, &assistantMsg, message.FinishReasonPermissionDenied, "Permission denied", "")
 652					break
 653				}
 654			}
 655			toolResults[i] = message.ToolResult{
 656				ToolCallID: toolCall.ID,
 657				Content:    toolResponse.Content,
 658				Metadata:   toolResponse.Metadata,
 659				IsError:    toolResponse.IsError,
 660			}
 661		}
 662	}
 663out:
 664	if len(toolResults) == 0 {
 665		return assistantMsg, nil, nil
 666	}
 667	parts := make([]message.ContentPart, 0)
 668	for _, tr := range toolResults {
 669		parts = append(parts, tr)
 670	}
 671	msg, err := a.messages.Create(context.Background(), assistantMsg.SessionID, message.CreateMessageParams{
 672		Role:     message.Tool,
 673		Parts:    parts,
 674		Provider: a.providerID,
 675	})
 676	if err != nil {
 677		return assistantMsg, nil, fmt.Errorf("failed to create cancelled tool message: %w", err)
 678	}
 679
 680	return assistantMsg, &msg, err
 681}
 682
 683func (a *agent) finishMessage(ctx context.Context, msg *message.Message, finishReason message.FinishReason, message, details string) {
 684	msg.AddFinish(finishReason, message, details)
 685	_ = a.messages.Update(ctx, *msg)
 686}
 687
 688func (a *agent) processEvent(ctx context.Context, sessionID string, assistantMsg *message.Message, event provider.ProviderEvent) error {
 689	select {
 690	case <-ctx.Done():
 691		return ctx.Err()
 692	default:
 693		// Continue processing.
 694	}
 695
 696	switch event.Type {
 697	case provider.EventThinkingDelta:
 698		assistantMsg.AppendReasoningContent(event.Thinking)
 699		return a.messages.Update(ctx, *assistantMsg)
 700	case provider.EventSignatureDelta:
 701		assistantMsg.AppendReasoningSignature(event.Signature)
 702		return a.messages.Update(ctx, *assistantMsg)
 703	case provider.EventContentDelta:
 704		assistantMsg.FinishThinking()
 705		assistantMsg.AppendContent(event.Content)
 706		return a.messages.Update(ctx, *assistantMsg)
 707	case provider.EventToolUseStart:
 708		assistantMsg.FinishThinking()
 709		slog.Info("Tool call started", "toolCall", event.ToolCall)
 710		assistantMsg.AddToolCall(*event.ToolCall)
 711		return a.messages.Update(ctx, *assistantMsg)
 712	case provider.EventToolUseDelta:
 713		assistantMsg.AppendToolCallInput(event.ToolCall.ID, event.ToolCall.Input)
 714		return a.messages.Update(ctx, *assistantMsg)
 715	case provider.EventToolUseStop:
 716		slog.Info("Finished tool call", "toolCall", event.ToolCall)
 717		assistantMsg.FinishToolCall(event.ToolCall.ID)
 718		return a.messages.Update(ctx, *assistantMsg)
 719	case provider.EventError:
 720		return event.Error
 721	case provider.EventComplete:
 722		assistantMsg.FinishThinking()
 723		assistantMsg.SetToolCalls(event.Response.ToolCalls)
 724		assistantMsg.AddFinish(event.Response.FinishReason, "", "")
 725		if err := a.messages.Update(ctx, *assistantMsg); err != nil {
 726			return fmt.Errorf("failed to update message: %w", err)
 727		}
 728		return a.TrackUsage(ctx, sessionID, a.Model(), event.Response.Usage)
 729	}
 730
 731	return nil
 732}
 733
 734func (a *agent) TrackUsage(ctx context.Context, sessionID string, model catwalk.Model, usage provider.TokenUsage) error {
 735	sess, err := a.sessions.Get(ctx, sessionID)
 736	if err != nil {
 737		return fmt.Errorf("failed to get session: %w", err)
 738	}
 739
 740	cost := model.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
 741		model.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
 742		model.CostPer1MIn/1e6*float64(usage.InputTokens) +
 743		model.CostPer1MOut/1e6*float64(usage.OutputTokens)
 744
 745	sess.Cost += cost
 746	sess.CompletionTokens = usage.OutputTokens + usage.CacheReadTokens
 747	sess.PromptTokens = usage.InputTokens + usage.CacheCreationTokens
 748
 749	_, err = a.sessions.Save(ctx, sess)
 750	if err != nil {
 751		return fmt.Errorf("failed to save session: %w", err)
 752	}
 753	return nil
 754}
 755
 756func (a *agent) Summarize(ctx context.Context, sessionID string) error {
 757	if a.summarizeProvider == nil {
 758		return fmt.Errorf("summarize provider not available")
 759	}
 760
 761	// Check if session is busy
 762	if a.IsSessionBusy(sessionID) {
 763		return ErrSessionBusy
 764	}
 765
 766	// Create a new context with cancellation
 767	summarizeCtx, cancel := context.WithCancel(ctx)
 768
 769	// Store the cancel function in activeRequests to allow cancellation
 770	a.activeRequests.Set(sessionID+"-summarize", cancel)
 771
 772	go func() {
 773		defer a.activeRequests.Del(sessionID + "-summarize")
 774		defer cancel()
 775		event := AgentEvent{
 776			Type:     AgentEventTypeSummarize,
 777			Progress: "Starting summarization...",
 778		}
 779
 780		a.Publish(pubsub.CreatedEvent, event)
 781		// Get all messages from the session
 782		msgs, err := a.messages.List(summarizeCtx, sessionID)
 783		if err != nil {
 784			event = AgentEvent{
 785				Type:  AgentEventTypeError,
 786				Error: fmt.Errorf("failed to list messages: %w", err),
 787				Done:  true,
 788			}
 789			a.Publish(pubsub.CreatedEvent, event)
 790			return
 791		}
 792		summarizeCtx = context.WithValue(summarizeCtx, tools.SessionIDContextKey, sessionID)
 793
 794		if len(msgs) == 0 {
 795			event = AgentEvent{
 796				Type:  AgentEventTypeError,
 797				Error: fmt.Errorf("no messages to summarize"),
 798				Done:  true,
 799			}
 800			a.Publish(pubsub.CreatedEvent, event)
 801			return
 802		}
 803
 804		event = AgentEvent{
 805			Type:     AgentEventTypeSummarize,
 806			Progress: "Analyzing conversation...",
 807		}
 808		a.Publish(pubsub.CreatedEvent, event)
 809
 810		// Add a system message to guide the summarization
 811		summarizePrompt := "Provide a detailed but concise summary of our conversation above. Focus on information that would be helpful for continuing the conversation, including what we did, what we're doing, which files we're working on, and what we're going to do next."
 812
 813		// Create a new message with the summarize prompt
 814		promptMsg := message.Message{
 815			Role:  message.User,
 816			Parts: []message.ContentPart{message.TextContent{Text: summarizePrompt}},
 817		}
 818
 819		// Append the prompt to the messages
 820		msgsWithPrompt := append(msgs, promptMsg)
 821
 822		event = AgentEvent{
 823			Type:     AgentEventTypeSummarize,
 824			Progress: "Generating summary...",
 825		}
 826
 827		a.Publish(pubsub.CreatedEvent, event)
 828
 829		// Send the messages to the summarize provider
 830		response := a.summarizeProvider.StreamResponse(
 831			summarizeCtx,
 832			msgsWithPrompt,
 833			nil,
 834		)
 835		var finalResponse *provider.ProviderResponse
 836		for r := range response {
 837			if r.Error != nil {
 838				event = AgentEvent{
 839					Type:  AgentEventTypeError,
 840					Error: fmt.Errorf("failed to summarize: %w", err),
 841					Done:  true,
 842				}
 843				a.Publish(pubsub.CreatedEvent, event)
 844				return
 845			}
 846			finalResponse = r.Response
 847		}
 848
 849		summary := strings.TrimSpace(finalResponse.Content)
 850		if summary == "" {
 851			event = AgentEvent{
 852				Type:  AgentEventTypeError,
 853				Error: fmt.Errorf("empty summary returned"),
 854				Done:  true,
 855			}
 856			a.Publish(pubsub.CreatedEvent, event)
 857			return
 858		}
 859		shell := shell.GetPersistentShell(config.Get().WorkingDir())
 860		summary += "\n\n**Current working directory of the persistent shell**\n\n" + shell.GetWorkingDir()
 861		event = AgentEvent{
 862			Type:     AgentEventTypeSummarize,
 863			Progress: "Creating new session...",
 864		}
 865
 866		a.Publish(pubsub.CreatedEvent, event)
 867		oldSession, err := a.sessions.Get(summarizeCtx, sessionID)
 868		if err != nil {
 869			event = AgentEvent{
 870				Type:  AgentEventTypeError,
 871				Error: fmt.Errorf("failed to get session: %w", err),
 872				Done:  true,
 873			}
 874
 875			a.Publish(pubsub.CreatedEvent, event)
 876			return
 877		}
 878		// Create a message in the new session with the summary
 879		msg, err := a.messages.Create(summarizeCtx, oldSession.ID, message.CreateMessageParams{
 880			Role: message.Assistant,
 881			Parts: []message.ContentPart{
 882				message.TextContent{Text: summary},
 883				message.Finish{
 884					Reason: message.FinishReasonEndTurn,
 885					Time:   time.Now().Unix(),
 886				},
 887			},
 888			Model:    a.summarizeProvider.Model().ID,
 889			Provider: a.summarizeProviderID,
 890		})
 891		if err != nil {
 892			event = AgentEvent{
 893				Type:  AgentEventTypeError,
 894				Error: fmt.Errorf("failed to create summary message: %w", err),
 895				Done:  true,
 896			}
 897
 898			a.Publish(pubsub.CreatedEvent, event)
 899			return
 900		}
 901		oldSession.SummaryMessageID = msg.ID
 902		oldSession.CompletionTokens = finalResponse.Usage.OutputTokens
 903		oldSession.PromptTokens = 0
 904		model := a.summarizeProvider.Model()
 905		usage := finalResponse.Usage
 906		cost := model.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
 907			model.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
 908			model.CostPer1MIn/1e6*float64(usage.InputTokens) +
 909			model.CostPer1MOut/1e6*float64(usage.OutputTokens)
 910		oldSession.Cost += cost
 911		_, err = a.sessions.Save(summarizeCtx, oldSession)
 912		if err != nil {
 913			event = AgentEvent{
 914				Type:  AgentEventTypeError,
 915				Error: fmt.Errorf("failed to save session: %w", err),
 916				Done:  true,
 917			}
 918			a.Publish(pubsub.CreatedEvent, event)
 919		}
 920
 921		event = AgentEvent{
 922			Type:      AgentEventTypeSummarize,
 923			SessionID: oldSession.ID,
 924			Progress:  "Summary complete",
 925			Done:      true,
 926		}
 927		a.Publish(pubsub.CreatedEvent, event)
 928		// Send final success event with the new session ID
 929	}()
 930
 931	return nil
 932}
 933
 934func (a *agent) ClearQueue(sessionID string) {
 935	if a.QueuedPrompts(sessionID) > 0 {
 936		slog.Info("Clearing queued prompts", "session_id", sessionID)
 937		a.promptQueue.Del(sessionID)
 938	}
 939}
 940
 941func (a *agent) CancelAll() {
 942	if !a.IsBusy() {
 943		return
 944	}
 945	for key := range a.activeRequests.Seq2() {
 946		a.Cancel(key) // key is sessionID
 947	}
 948
 949	timeout := time.After(5 * time.Second)
 950	for a.IsBusy() {
 951		select {
 952		case <-timeout:
 953			return
 954		default:
 955			time.Sleep(200 * time.Millisecond)
 956		}
 957	}
 958}
 959
 960func (a *agent) UpdateModel() error {
 961	cfg := config.Get()
 962
 963	// Get current provider configuration
 964	currentProviderCfg := cfg.GetProviderForModel(a.agentCfg.Model)
 965	if currentProviderCfg == nil || currentProviderCfg.ID == "" {
 966		return fmt.Errorf("provider for agent %s not found in config", a.agentCfg.Name)
 967	}
 968
 969	// Check if provider has changed
 970	if string(currentProviderCfg.ID) != a.providerID {
 971		// Provider changed, need to recreate the main provider
 972		model := cfg.GetModelByType(a.agentCfg.Model)
 973		if model.ID == "" {
 974			return fmt.Errorf("model not found for agent %s", a.agentCfg.Name)
 975		}
 976
 977		promptID := agentPromptMap[a.agentCfg.ID]
 978		if promptID == "" {
 979			promptID = prompt.PromptDefault
 980		}
 981
 982		opts := []provider.ProviderClientOption{
 983			provider.WithModel(a.agentCfg.Model),
 984			provider.WithSystemMessage(prompt.GetPrompt(promptID, currentProviderCfg.ID, cfg.Options.ContextPaths...)),
 985		}
 986
 987		newProvider, err := provider.NewProvider(*currentProviderCfg, opts...)
 988		if err != nil {
 989			return fmt.Errorf("failed to create new provider: %w", err)
 990		}
 991
 992		// Update the provider and provider ID
 993		a.provider = newProvider
 994		a.providerID = string(currentProviderCfg.ID)
 995	}
 996
 997	// Check if providers have changed for title (small) and summarize (large)
 998	smallModelCfg := cfg.Models[config.SelectedModelTypeSmall]
 999	var smallModelProviderCfg config.ProviderConfig
1000	for p := range cfg.Providers.Seq() {
1001		if p.ID == smallModelCfg.Provider {
1002			smallModelProviderCfg = p
1003			break
1004		}
1005	}
1006	if smallModelProviderCfg.ID == "" {
1007		return fmt.Errorf("provider %s not found in config", smallModelCfg.Provider)
1008	}
1009
1010	largeModelCfg := cfg.Models[config.SelectedModelTypeLarge]
1011	var largeModelProviderCfg config.ProviderConfig
1012	for p := range cfg.Providers.Seq() {
1013		if p.ID == largeModelCfg.Provider {
1014			largeModelProviderCfg = p
1015			break
1016		}
1017	}
1018	if largeModelProviderCfg.ID == "" {
1019		return fmt.Errorf("provider %s not found in config", largeModelCfg.Provider)
1020	}
1021
1022	var maxTitleTokens int64 = 40
1023
1024	// if the max output is too low for the gemini provider it won't return anything
1025	if smallModelCfg.Provider == "gemini" {
1026		maxTitleTokens = 1000
1027	}
1028	// Recreate title provider
1029	titleOpts := []provider.ProviderClientOption{
1030		provider.WithModel(config.SelectedModelTypeSmall),
1031		provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptTitle, smallModelProviderCfg.ID)),
1032		provider.WithMaxTokens(maxTitleTokens),
1033	}
1034	newTitleProvider, err := provider.NewProvider(smallModelProviderCfg, titleOpts...)
1035	if err != nil {
1036		return fmt.Errorf("failed to create new title provider: %w", err)
1037	}
1038	a.titleProvider = newTitleProvider
1039
1040	// Recreate summarize provider if provider changed (now large model)
1041	if string(largeModelProviderCfg.ID) != a.summarizeProviderID {
1042		largeModel := cfg.GetModelByType(config.SelectedModelTypeLarge)
1043		if largeModel == nil {
1044			return fmt.Errorf("model %s not found in provider %s", largeModelCfg.Model, largeModelProviderCfg.ID)
1045		}
1046		summarizeOpts := []provider.ProviderClientOption{
1047			provider.WithModel(config.SelectedModelTypeLarge),
1048			provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptSummarizer, largeModelProviderCfg.ID)),
1049		}
1050		newSummarizeProvider, err := provider.NewProvider(largeModelProviderCfg, summarizeOpts...)
1051		if err != nil {
1052			return fmt.Errorf("failed to create new summarize provider: %w", err)
1053		}
1054		a.summarizeProvider = newSummarizeProvider
1055		a.summarizeProviderID = string(largeModelProviderCfg.ID)
1056	}
1057
1058	return nil
1059}