agent.go

   1package agent
   2
   3import (
   4	"context"
   5	"errors"
   6	"fmt"
   7	"log/slog"
   8	"slices"
   9	"strings"
  10	"time"
  11
  12	"github.com/charmbracelet/catwalk/pkg/catwalk"
  13	"github.com/charmbracelet/crush/internal/config"
  14	"github.com/charmbracelet/crush/internal/csync"
  15	"github.com/charmbracelet/crush/internal/history"
  16	"github.com/charmbracelet/crush/internal/llm/prompt"
  17	"github.com/charmbracelet/crush/internal/llm/provider"
  18	"github.com/charmbracelet/crush/internal/llm/tools"
  19	"github.com/charmbracelet/crush/internal/log"
  20	"github.com/charmbracelet/crush/internal/lsp"
  21	"github.com/charmbracelet/crush/internal/message"
  22	"github.com/charmbracelet/crush/internal/permission"
  23	"github.com/charmbracelet/crush/internal/pubsub"
  24	"github.com/charmbracelet/crush/internal/session"
  25	"github.com/charmbracelet/crush/internal/shell"
  26)
  27
  28// Common errors
  29var (
  30	ErrRequestCancelled = errors.New("request canceled by user")
  31	ErrSessionBusy      = errors.New("session is currently processing another request")
  32)
  33
  34type AgentEventType string
  35
  36const (
  37	AgentEventTypeError     AgentEventType = "error"
  38	AgentEventTypeResponse  AgentEventType = "response"
  39	AgentEventTypeSummarize AgentEventType = "summarize"
  40)
  41
  42type AgentEvent struct {
  43	Type    AgentEventType
  44	Message message.Message
  45	Error   error
  46
  47	// When summarizing
  48	SessionID string
  49	Progress  string
  50	Done      bool
  51}
  52
  53type Service interface {
  54	pubsub.Suscriber[AgentEvent]
  55	Model() catwalk.Model
  56	Run(ctx context.Context, sessionID string, content string, attachments ...message.Attachment) (<-chan AgentEvent, error)
  57	Cancel(sessionID string)
  58	CancelAll()
  59	IsSessionBusy(sessionID string) bool
  60	IsBusy() bool
  61	Summarize(ctx context.Context, sessionID string) error
  62	UpdateModel() error
  63	QueuedPrompts(sessionID string) int
  64	ClearQueue(sessionID string)
  65}
  66
  67type agent struct {
  68	*pubsub.Broker[AgentEvent]
  69	agentCfg config.Agent
  70	sessions session.Service
  71	messages message.Service
  72	mcpTools []McpTool
  73
  74	tools *csync.LazySlice[tools.BaseTool]
  75	// We need this to be able to update it when model changes
  76	agentToolFn func() (tools.BaseTool, error)
  77
  78	provider   provider.Provider
  79	providerID string
  80
  81	titleProvider       provider.Provider
  82	summarizeProvider   provider.Provider
  83	summarizeProviderID string
  84
  85	activeRequests *csync.Map[string, context.CancelFunc]
  86	promptQueue    *csync.Map[string, []string]
  87}
  88
  89var agentPromptMap = map[string]prompt.PromptID{
  90	"coder": prompt.PromptCoder,
  91	"task":  prompt.PromptTask,
  92}
  93
  94func NewAgent(
  95	ctx context.Context,
  96	agentCfg config.Agent,
  97	// These services are needed in the tools
  98	permissions permission.Service,
  99	sessions session.Service,
 100	messages message.Service,
 101	history history.Service,
 102	lspClients *csync.Map[string, *lsp.Client],
 103) (Service, error) {
 104	cfg := config.Get()
 105
 106	var agentToolFn func() (tools.BaseTool, error)
 107	if agentCfg.ID == "coder" {
 108		agentToolFn = func() (tools.BaseTool, error) {
 109			taskAgentCfg := config.Get().Agents["task"]
 110			if taskAgentCfg.ID == "" {
 111				return nil, fmt.Errorf("task agent not found in config")
 112			}
 113			taskAgent, err := NewAgent(ctx, taskAgentCfg, permissions, sessions, messages, history, lspClients)
 114			if err != nil {
 115				return nil, fmt.Errorf("failed to create task agent: %w", err)
 116			}
 117			return NewAgentTool(taskAgent, sessions, messages), nil
 118		}
 119	}
 120
 121	providerCfg := config.Get().GetProviderForModel(agentCfg.Model)
 122	if providerCfg == nil {
 123		return nil, fmt.Errorf("provider for agent %s not found in config", agentCfg.Name)
 124	}
 125	model := config.Get().GetModelByType(agentCfg.Model)
 126
 127	if model == nil {
 128		return nil, fmt.Errorf("model not found for agent %s", agentCfg.Name)
 129	}
 130
 131	promptID := agentPromptMap[agentCfg.ID]
 132	if promptID == "" {
 133		promptID = prompt.PromptDefault
 134	}
 135	opts := []provider.ProviderClientOption{
 136		provider.WithModel(agentCfg.Model),
 137		provider.WithSystemMessage(prompt.GetPrompt(promptID, providerCfg.ID, config.Get().Options.ContextPaths...)),
 138	}
 139	agentProvider, err := provider.NewProvider(*providerCfg, opts...)
 140	if err != nil {
 141		return nil, err
 142	}
 143
 144	smallModelCfg := cfg.Models[config.SelectedModelTypeSmall]
 145	var smallModelProviderCfg *config.ProviderConfig
 146	if smallModelCfg.Provider == providerCfg.ID {
 147		smallModelProviderCfg = providerCfg
 148	} else {
 149		smallModelProviderCfg = cfg.GetProviderForModel(config.SelectedModelTypeSmall)
 150
 151		if smallModelProviderCfg.ID == "" {
 152			return nil, fmt.Errorf("provider %s not found in config", smallModelCfg.Provider)
 153		}
 154	}
 155	smallModel := cfg.GetModelByType(config.SelectedModelTypeSmall)
 156	if smallModel.ID == "" {
 157		return nil, fmt.Errorf("model %s not found in provider %s", smallModelCfg.Model, smallModelProviderCfg.ID)
 158	}
 159
 160	titleOpts := []provider.ProviderClientOption{
 161		provider.WithModel(config.SelectedModelTypeSmall),
 162		provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptTitle, smallModelProviderCfg.ID)),
 163	}
 164	titleProvider, err := provider.NewProvider(*smallModelProviderCfg, titleOpts...)
 165	if err != nil {
 166		return nil, err
 167	}
 168
 169	summarizeOpts := []provider.ProviderClientOption{
 170		provider.WithModel(config.SelectedModelTypeLarge),
 171		provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptSummarizer, providerCfg.ID)),
 172	}
 173	summarizeProvider, err := provider.NewProvider(*providerCfg, summarizeOpts...)
 174	if err != nil {
 175		return nil, err
 176	}
 177
 178	toolFn := func() []tools.BaseTool {
 179		slog.Info("Initializing agent tools", "agent", agentCfg.ID)
 180		defer func() {
 181			slog.Info("Initialized agent tools", "agent", agentCfg.ID)
 182		}()
 183
 184		cwd := cfg.WorkingDir()
 185		allTools := []tools.BaseTool{
 186			tools.NewBashTool(permissions, cwd, cfg.Options.Attribution),
 187			tools.NewDownloadTool(permissions, cwd),
 188			tools.NewEditTool(lspClients, permissions, history, cwd),
 189			tools.NewMultiEditTool(lspClients, permissions, history, cwd),
 190			tools.NewFetchTool(permissions, cwd),
 191			tools.NewGlobTool(cwd),
 192			tools.NewGrepTool(cwd),
 193			tools.NewLsTool(permissions, cwd),
 194			tools.NewSourcegraphTool(),
 195			tools.NewViewTool(lspClients, permissions, cwd),
 196			tools.NewWriteTool(lspClients, permissions, history, cwd),
 197		}
 198
 199		mcpToolsOnce.Do(func() {
 200			mcpTools = doGetMCPTools(ctx, permissions, cfg)
 201		})
 202
 203		withCoderTools := func(t []tools.BaseTool) []tools.BaseTool {
 204			if agentCfg.ID == "coder" {
 205				t = append(t, mcpTools...)
 206				if lspClients.Len() > 0 {
 207					t = append(t, tools.NewDiagnosticsTool(lspClients))
 208				}
 209			}
 210			return t
 211		}
 212
 213		if agentCfg.AllowedTools == nil {
 214			return withCoderTools(allTools)
 215		}
 216
 217		var filteredTools []tools.BaseTool
 218		for _, tool := range allTools {
 219			if slices.Contains(agentCfg.AllowedTools, tool.Name()) {
 220				filteredTools = append(filteredTools, tool)
 221			}
 222		}
 223		return withCoderTools(filteredTools)
 224	}
 225
 226	return &agent{
 227		Broker:              pubsub.NewBroker[AgentEvent](),
 228		agentCfg:            agentCfg,
 229		provider:            agentProvider,
 230		providerID:          string(providerCfg.ID),
 231		messages:            messages,
 232		sessions:            sessions,
 233		titleProvider:       titleProvider,
 234		summarizeProvider:   summarizeProvider,
 235		summarizeProviderID: string(providerCfg.ID),
 236		agentToolFn:         agentToolFn,
 237		activeRequests:      csync.NewMap[string, context.CancelFunc](),
 238		tools:               csync.NewLazySlice(toolFn),
 239		promptQueue:         csync.NewMap[string, []string](),
 240	}, nil
 241}
 242
 243func (a *agent) Model() catwalk.Model {
 244	return *config.Get().GetModelByType(a.agentCfg.Model)
 245}
 246
 247func (a *agent) Cancel(sessionID string) {
 248	// Cancel regular requests
 249	if cancel, ok := a.activeRequests.Take(sessionID); ok && cancel != nil {
 250		slog.Info("Request cancellation initiated", "session_id", sessionID)
 251		cancel()
 252	}
 253
 254	// Also check for summarize requests
 255	if cancel, ok := a.activeRequests.Take(sessionID + "-summarize"); ok && cancel != nil {
 256		slog.Info("Summarize cancellation initiated", "session_id", sessionID)
 257		cancel()
 258	}
 259
 260	if a.QueuedPrompts(sessionID) > 0 {
 261		slog.Info("Clearing queued prompts", "session_id", sessionID)
 262		a.promptQueue.Del(sessionID)
 263	}
 264}
 265
 266func (a *agent) IsBusy() bool {
 267	var busy bool
 268	for cancelFunc := range a.activeRequests.Seq() {
 269		if cancelFunc != nil {
 270			busy = true
 271			break
 272		}
 273	}
 274	return busy
 275}
 276
 277func (a *agent) IsSessionBusy(sessionID string) bool {
 278	_, busy := a.activeRequests.Get(sessionID)
 279	return busy
 280}
 281
 282func (a *agent) QueuedPrompts(sessionID string) int {
 283	l, ok := a.promptQueue.Get(sessionID)
 284	if !ok {
 285		return 0
 286	}
 287	return len(l)
 288}
 289
 290func (a *agent) generateTitle(ctx context.Context, sessionID string, content string) error {
 291	if content == "" {
 292		return nil
 293	}
 294	if a.titleProvider == nil {
 295		return nil
 296	}
 297	session, err := a.sessions.Get(ctx, sessionID)
 298	if err != nil {
 299		return err
 300	}
 301	parts := []message.ContentPart{message.TextContent{
 302		Text: fmt.Sprintf("Generate a concise title for the following content:\n\n%s", content),
 303	}}
 304
 305	// Use streaming approach like summarization
 306	response := a.titleProvider.StreamResponse(
 307		ctx,
 308		[]message.Message{
 309			{
 310				Role:  message.User,
 311				Parts: parts,
 312			},
 313		},
 314		nil,
 315	)
 316
 317	var finalResponse *provider.ProviderResponse
 318	for r := range response {
 319		if r.Error != nil {
 320			return r.Error
 321		}
 322		finalResponse = r.Response
 323	}
 324
 325	if finalResponse == nil {
 326		return fmt.Errorf("no response received from title provider")
 327	}
 328
 329	title := strings.TrimSpace(strings.ReplaceAll(finalResponse.Content, "\n", " "))
 330	if title == "" {
 331		return nil
 332	}
 333
 334	session.Title = title
 335	_, err = a.sessions.Save(ctx, session)
 336	return err
 337}
 338
 339func (a *agent) err(err error) AgentEvent {
 340	return AgentEvent{
 341		Type:  AgentEventTypeError,
 342		Error: err,
 343	}
 344}
 345
 346func (a *agent) Run(ctx context.Context, sessionID string, content string, attachments ...message.Attachment) (<-chan AgentEvent, error) {
 347	if !a.Model().SupportsImages && attachments != nil {
 348		attachments = nil
 349	}
 350	events := make(chan AgentEvent, 1)
 351	if a.IsSessionBusy(sessionID) {
 352		existing, ok := a.promptQueue.Get(sessionID)
 353		if !ok {
 354			existing = []string{}
 355		}
 356		existing = append(existing, content)
 357		a.promptQueue.Set(sessionID, existing)
 358		return nil, nil
 359	}
 360
 361	genCtx, cancel := context.WithCancel(ctx)
 362
 363	a.activeRequests.Set(sessionID, cancel)
 364	go func() {
 365		slog.Debug("Request started", "sessionID", sessionID)
 366		defer log.RecoverPanic("agent.Run", func() {
 367			events <- a.err(fmt.Errorf("panic while running the agent"))
 368		})
 369		var attachmentParts []message.ContentPart
 370		for _, attachment := range attachments {
 371			attachmentParts = append(attachmentParts, message.BinaryContent{Path: attachment.FilePath, MIMEType: attachment.MimeType, Data: attachment.Content})
 372		}
 373		result := a.processGeneration(genCtx, sessionID, content, attachmentParts)
 374		if result.Error != nil && !errors.Is(result.Error, ErrRequestCancelled) && !errors.Is(result.Error, context.Canceled) {
 375			slog.Error(result.Error.Error())
 376		}
 377		slog.Debug("Request completed", "sessionID", sessionID)
 378		a.activeRequests.Del(sessionID)
 379		cancel()
 380		a.Publish(pubsub.CreatedEvent, result)
 381		events <- result
 382		close(events)
 383	}()
 384	return events, nil
 385}
 386
 387func (a *agent) processGeneration(ctx context.Context, sessionID, content string, attachmentParts []message.ContentPart) AgentEvent {
 388	cfg := config.Get()
 389	// List existing messages; if none, start title generation asynchronously.
 390	msgs, err := a.messages.List(ctx, sessionID)
 391	if err != nil {
 392		return a.err(fmt.Errorf("failed to list messages: %w", err))
 393	}
 394	if len(msgs) == 0 {
 395		go func() {
 396			defer log.RecoverPanic("agent.Run", func() {
 397				slog.Error("panic while generating title")
 398			})
 399			titleErr := a.generateTitle(ctx, sessionID, content)
 400			if titleErr != nil && !errors.Is(titleErr, context.Canceled) && !errors.Is(titleErr, context.DeadlineExceeded) {
 401				slog.Error("failed to generate title", "error", titleErr)
 402			}
 403		}()
 404	}
 405	session, err := a.sessions.Get(ctx, sessionID)
 406	if err != nil {
 407		return a.err(fmt.Errorf("failed to get session: %w", err))
 408	}
 409	if session.SummaryMessageID != "" {
 410		summaryMsgInex := -1
 411		for i, msg := range msgs {
 412			if msg.ID == session.SummaryMessageID {
 413				summaryMsgInex = i
 414				break
 415			}
 416		}
 417		if summaryMsgInex != -1 {
 418			msgs = msgs[summaryMsgInex:]
 419			msgs[0].Role = message.User
 420		}
 421	}
 422
 423	userMsg, err := a.createUserMessage(ctx, sessionID, content, attachmentParts)
 424	if err != nil {
 425		return a.err(fmt.Errorf("failed to create user message: %w", err))
 426	}
 427	// Append the new user message to the conversation history.
 428	msgHistory := append(msgs, userMsg)
 429
 430	for {
 431		// Check for cancellation before each iteration
 432		select {
 433		case <-ctx.Done():
 434			return a.err(ctx.Err())
 435		default:
 436			// Continue processing
 437		}
 438		agentMessage, toolResults, err := a.streamAndHandleEvents(ctx, sessionID, msgHistory)
 439		if err != nil {
 440			if errors.Is(err, context.Canceled) {
 441				agentMessage.AddFinish(message.FinishReasonCanceled, "Request cancelled", "")
 442				a.messages.Update(context.Background(), agentMessage)
 443				return a.err(ErrRequestCancelled)
 444			}
 445			return a.err(fmt.Errorf("failed to process events: %w", err))
 446		}
 447		if cfg.Options.Debug {
 448			slog.Info("Result", "message", agentMessage.FinishReason(), "toolResults", toolResults)
 449		}
 450		if (agentMessage.FinishReason() == message.FinishReasonToolUse) && toolResults != nil {
 451			// We are not done, we need to respond with the tool response
 452			msgHistory = append(msgHistory, agentMessage, *toolResults)
 453			// If there are queued prompts, process the next one
 454			nextPrompt, ok := a.promptQueue.Take(sessionID)
 455			if ok {
 456				for _, prompt := range nextPrompt {
 457					// Create a new user message for the queued prompt
 458					userMsg, err := a.createUserMessage(ctx, sessionID, prompt, nil)
 459					if err != nil {
 460						return a.err(fmt.Errorf("failed to create user message for queued prompt: %w", err))
 461					}
 462					// Append the new user message to the conversation history
 463					msgHistory = append(msgHistory, userMsg)
 464				}
 465			}
 466
 467			continue
 468		} else if agentMessage.FinishReason() == message.FinishReasonEndTurn {
 469			queuePrompts, ok := a.promptQueue.Take(sessionID)
 470			if ok {
 471				for _, prompt := range queuePrompts {
 472					if prompt == "" {
 473						continue
 474					}
 475					userMsg, err := a.createUserMessage(ctx, sessionID, prompt, nil)
 476					if err != nil {
 477						return a.err(fmt.Errorf("failed to create user message for queued prompt: %w", err))
 478					}
 479					msgHistory = append(msgHistory, userMsg)
 480				}
 481				continue
 482			}
 483		}
 484		if agentMessage.FinishReason() == "" {
 485			// Kujtim: could not track down where this is happening but this means its cancelled
 486			agentMessage.AddFinish(message.FinishReasonCanceled, "Request cancelled", "")
 487			_ = a.messages.Update(context.Background(), agentMessage)
 488			return a.err(ErrRequestCancelled)
 489		}
 490		return AgentEvent{
 491			Type:    AgentEventTypeResponse,
 492			Message: agentMessage,
 493			Done:    true,
 494		}
 495	}
 496}
 497
 498func (a *agent) createUserMessage(ctx context.Context, sessionID, content string, attachmentParts []message.ContentPart) (message.Message, error) {
 499	parts := []message.ContentPart{message.TextContent{Text: content}}
 500	parts = append(parts, attachmentParts...)
 501	return a.messages.Create(ctx, sessionID, message.CreateMessageParams{
 502		Role:  message.User,
 503		Parts: parts,
 504	})
 505}
 506
 507func (a *agent) getAllTools() ([]tools.BaseTool, error) {
 508	allTools := slices.Collect(a.tools.Seq())
 509	if a.agentToolFn != nil {
 510		agentTool, agentToolErr := a.agentToolFn()
 511		if agentToolErr != nil {
 512			return nil, agentToolErr
 513		}
 514		allTools = append(allTools, agentTool)
 515	}
 516	return allTools, nil
 517}
 518
 519func (a *agent) streamAndHandleEvents(ctx context.Context, sessionID string, msgHistory []message.Message) (message.Message, *message.Message, error) {
 520	ctx = context.WithValue(ctx, tools.SessionIDContextKey, sessionID)
 521
 522	// Create the assistant message first so the spinner shows immediately
 523	assistantMsg, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{
 524		Role:     message.Assistant,
 525		Parts:    []message.ContentPart{},
 526		Model:    a.Model().ID,
 527		Provider: a.providerID,
 528	})
 529	if err != nil {
 530		return assistantMsg, nil, fmt.Errorf("failed to create assistant message: %w", err)
 531	}
 532
 533	allTools, toolsErr := a.getAllTools()
 534	if toolsErr != nil {
 535		return assistantMsg, nil, toolsErr
 536	}
 537	// Now collect tools (which may block on MCP initialization)
 538	eventChan := a.provider.StreamResponse(ctx, msgHistory, allTools)
 539
 540	// Add the session and message ID into the context if needed by tools.
 541	ctx = context.WithValue(ctx, tools.MessageIDContextKey, assistantMsg.ID)
 542
 543	// Process each event in the stream.
 544	for event := range eventChan {
 545		if processErr := a.processEvent(ctx, sessionID, &assistantMsg, event); processErr != nil {
 546			if errors.Is(processErr, context.Canceled) {
 547				a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
 548			} else {
 549				a.finishMessage(ctx, &assistantMsg, message.FinishReasonError, "API Error", processErr.Error())
 550			}
 551			return assistantMsg, nil, processErr
 552		}
 553		if ctx.Err() != nil {
 554			a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
 555			return assistantMsg, nil, ctx.Err()
 556		}
 557	}
 558
 559	toolResults := make([]message.ToolResult, len(assistantMsg.ToolCalls()))
 560	toolCalls := assistantMsg.ToolCalls()
 561	for i, toolCall := range toolCalls {
 562		select {
 563		case <-ctx.Done():
 564			a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
 565			// Make all future tool calls cancelled
 566			for j := i; j < len(toolCalls); j++ {
 567				toolResults[j] = message.ToolResult{
 568					ToolCallID: toolCalls[j].ID,
 569					Content:    "Tool execution canceled by user",
 570					IsError:    true,
 571				}
 572			}
 573			goto out
 574		default:
 575			// Continue processing
 576			var tool tools.BaseTool
 577			allTools, _ := a.getAllTools()
 578			for _, availableTool := range allTools {
 579				if availableTool.Info().Name == toolCall.Name {
 580					tool = availableTool
 581					break
 582				}
 583			}
 584
 585			// Tool not found
 586			if tool == nil {
 587				toolResults[i] = message.ToolResult{
 588					ToolCallID: toolCall.ID,
 589					Content:    fmt.Sprintf("Tool not found: %s", toolCall.Name),
 590					IsError:    true,
 591				}
 592				continue
 593			}
 594
 595			// Run tool in goroutine to allow cancellation
 596			type toolExecResult struct {
 597				response tools.ToolResponse
 598				err      error
 599			}
 600			resultChan := make(chan toolExecResult, 1)
 601
 602			go func() {
 603				response, err := tool.Run(ctx, tools.ToolCall{
 604					ID:    toolCall.ID,
 605					Name:  toolCall.Name,
 606					Input: toolCall.Input,
 607				})
 608				resultChan <- toolExecResult{response: response, err: err}
 609			}()
 610
 611			var toolResponse tools.ToolResponse
 612			var toolErr error
 613
 614			select {
 615			case <-ctx.Done():
 616				a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
 617				// Mark remaining tool calls as cancelled
 618				for j := i; j < len(toolCalls); j++ {
 619					toolResults[j] = message.ToolResult{
 620						ToolCallID: toolCalls[j].ID,
 621						Content:    "Tool execution canceled by user",
 622						IsError:    true,
 623					}
 624				}
 625				goto out
 626			case result := <-resultChan:
 627				toolResponse = result.response
 628				toolErr = result.err
 629			}
 630
 631			if toolErr != nil {
 632				slog.Error("Tool execution error", "toolCall", toolCall.ID, "error", toolErr)
 633				if errors.Is(toolErr, permission.ErrorPermissionDenied) {
 634					toolResults[i] = message.ToolResult{
 635						ToolCallID: toolCall.ID,
 636						Content:    "Permission denied",
 637						IsError:    true,
 638					}
 639					for j := i + 1; j < len(toolCalls); j++ {
 640						toolResults[j] = message.ToolResult{
 641							ToolCallID: toolCalls[j].ID,
 642							Content:    "Tool execution canceled by user",
 643							IsError:    true,
 644						}
 645					}
 646					a.finishMessage(ctx, &assistantMsg, message.FinishReasonPermissionDenied, "Permission denied", "")
 647					break
 648				}
 649			}
 650			toolResults[i] = message.ToolResult{
 651				ToolCallID: toolCall.ID,
 652				Content:    toolResponse.Content,
 653				Metadata:   toolResponse.Metadata,
 654				IsError:    toolResponse.IsError,
 655			}
 656		}
 657	}
 658out:
 659	if len(toolResults) == 0 {
 660		return assistantMsg, nil, nil
 661	}
 662	parts := make([]message.ContentPart, 0)
 663	for _, tr := range toolResults {
 664		parts = append(parts, tr)
 665	}
 666	msg, err := a.messages.Create(context.Background(), assistantMsg.SessionID, message.CreateMessageParams{
 667		Role:     message.Tool,
 668		Parts:    parts,
 669		Provider: a.providerID,
 670	})
 671	if err != nil {
 672		return assistantMsg, nil, fmt.Errorf("failed to create cancelled tool message: %w", err)
 673	}
 674
 675	return assistantMsg, &msg, err
 676}
 677
 678func (a *agent) finishMessage(ctx context.Context, msg *message.Message, finishReason message.FinishReason, message, details string) {
 679	msg.AddFinish(finishReason, message, details)
 680	_ = a.messages.Update(ctx, *msg)
 681}
 682
 683func (a *agent) processEvent(ctx context.Context, sessionID string, assistantMsg *message.Message, event provider.ProviderEvent) error {
 684	select {
 685	case <-ctx.Done():
 686		return ctx.Err()
 687	default:
 688		// Continue processing.
 689	}
 690
 691	switch event.Type {
 692	case provider.EventThinkingDelta:
 693		assistantMsg.AppendReasoningContent(event.Thinking)
 694		return a.messages.Update(ctx, *assistantMsg)
 695	case provider.EventSignatureDelta:
 696		assistantMsg.AppendReasoningSignature(event.Signature)
 697		return a.messages.Update(ctx, *assistantMsg)
 698	case provider.EventContentDelta:
 699		assistantMsg.FinishThinking()
 700		assistantMsg.AppendContent(event.Content)
 701		return a.messages.Update(ctx, *assistantMsg)
 702	case provider.EventToolUseStart:
 703		assistantMsg.FinishThinking()
 704		slog.Info("Tool call started", "toolCall", event.ToolCall)
 705		assistantMsg.AddToolCall(*event.ToolCall)
 706		return a.messages.Update(ctx, *assistantMsg)
 707	case provider.EventToolUseDelta:
 708		assistantMsg.AppendToolCallInput(event.ToolCall.ID, event.ToolCall.Input)
 709		return a.messages.Update(ctx, *assistantMsg)
 710	case provider.EventToolUseStop:
 711		slog.Info("Finished tool call", "toolCall", event.ToolCall)
 712		assistantMsg.FinishToolCall(event.ToolCall.ID)
 713		return a.messages.Update(ctx, *assistantMsg)
 714	case provider.EventError:
 715		return event.Error
 716	case provider.EventComplete:
 717		assistantMsg.FinishThinking()
 718		assistantMsg.SetToolCalls(event.Response.ToolCalls)
 719		assistantMsg.AddFinish(event.Response.FinishReason, "", "")
 720		if err := a.messages.Update(ctx, *assistantMsg); err != nil {
 721			return fmt.Errorf("failed to update message: %w", err)
 722		}
 723		return a.TrackUsage(ctx, sessionID, a.Model(), event.Response.Usage)
 724	}
 725
 726	return nil
 727}
 728
 729func (a *agent) TrackUsage(ctx context.Context, sessionID string, model catwalk.Model, usage provider.TokenUsage) error {
 730	sess, err := a.sessions.Get(ctx, sessionID)
 731	if err != nil {
 732		return fmt.Errorf("failed to get session: %w", err)
 733	}
 734
 735	cost := model.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
 736		model.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
 737		model.CostPer1MIn/1e6*float64(usage.InputTokens) +
 738		model.CostPer1MOut/1e6*float64(usage.OutputTokens)
 739
 740	sess.Cost += cost
 741	sess.CompletionTokens = usage.OutputTokens + usage.CacheReadTokens
 742	sess.PromptTokens = usage.InputTokens + usage.CacheCreationTokens
 743
 744	_, err = a.sessions.Save(ctx, sess)
 745	if err != nil {
 746		return fmt.Errorf("failed to save session: %w", err)
 747	}
 748	return nil
 749}
 750
 751func (a *agent) Summarize(ctx context.Context, sessionID string) error {
 752	if a.summarizeProvider == nil {
 753		return fmt.Errorf("summarize provider not available")
 754	}
 755
 756	// Check if session is busy
 757	if a.IsSessionBusy(sessionID) {
 758		return ErrSessionBusy
 759	}
 760
 761	// Create a new context with cancellation
 762	summarizeCtx, cancel := context.WithCancel(ctx)
 763
 764	// Store the cancel function in activeRequests to allow cancellation
 765	a.activeRequests.Set(sessionID+"-summarize", cancel)
 766
 767	go func() {
 768		defer a.activeRequests.Del(sessionID + "-summarize")
 769		defer cancel()
 770		event := AgentEvent{
 771			Type:     AgentEventTypeSummarize,
 772			Progress: "Starting summarization...",
 773		}
 774
 775		a.Publish(pubsub.CreatedEvent, event)
 776		// Get all messages from the session
 777		msgs, err := a.messages.List(summarizeCtx, sessionID)
 778		if err != nil {
 779			event = AgentEvent{
 780				Type:  AgentEventTypeError,
 781				Error: fmt.Errorf("failed to list messages: %w", err),
 782				Done:  true,
 783			}
 784			a.Publish(pubsub.CreatedEvent, event)
 785			return
 786		}
 787		summarizeCtx = context.WithValue(summarizeCtx, tools.SessionIDContextKey, sessionID)
 788
 789		if len(msgs) == 0 {
 790			event = AgentEvent{
 791				Type:  AgentEventTypeError,
 792				Error: fmt.Errorf("no messages to summarize"),
 793				Done:  true,
 794			}
 795			a.Publish(pubsub.CreatedEvent, event)
 796			return
 797		}
 798
 799		event = AgentEvent{
 800			Type:     AgentEventTypeSummarize,
 801			Progress: "Analyzing conversation...",
 802		}
 803		a.Publish(pubsub.CreatedEvent, event)
 804
 805		// Add a system message to guide the summarization
 806		summarizePrompt := "Provide a detailed but concise summary of our conversation above. Focus on information that would be helpful for continuing the conversation, including what we did, what we're doing, which files we're working on, and what we're going to do next."
 807
 808		// Create a new message with the summarize prompt
 809		promptMsg := message.Message{
 810			Role:  message.User,
 811			Parts: []message.ContentPart{message.TextContent{Text: summarizePrompt}},
 812		}
 813
 814		// Append the prompt to the messages
 815		msgsWithPrompt := append(msgs, promptMsg)
 816
 817		event = AgentEvent{
 818			Type:     AgentEventTypeSummarize,
 819			Progress: "Generating summary...",
 820		}
 821
 822		a.Publish(pubsub.CreatedEvent, event)
 823
 824		// Send the messages to the summarize provider
 825		response := a.summarizeProvider.StreamResponse(
 826			summarizeCtx,
 827			msgsWithPrompt,
 828			nil,
 829		)
 830		var finalResponse *provider.ProviderResponse
 831		for r := range response {
 832			if r.Error != nil {
 833				event = AgentEvent{
 834					Type:  AgentEventTypeError,
 835					Error: fmt.Errorf("failed to summarize: %w", err),
 836					Done:  true,
 837				}
 838				a.Publish(pubsub.CreatedEvent, event)
 839				return
 840			}
 841			finalResponse = r.Response
 842		}
 843
 844		summary := strings.TrimSpace(finalResponse.Content)
 845		if summary == "" {
 846			event = AgentEvent{
 847				Type:  AgentEventTypeError,
 848				Error: fmt.Errorf("empty summary returned"),
 849				Done:  true,
 850			}
 851			a.Publish(pubsub.CreatedEvent, event)
 852			return
 853		}
 854		shell := shell.GetPersistentShell(config.Get().WorkingDir())
 855		summary += "\n\n**Current working directory of the persistent shell**\n\n" + shell.GetWorkingDir()
 856		event = AgentEvent{
 857			Type:     AgentEventTypeSummarize,
 858			Progress: "Creating new session...",
 859		}
 860
 861		a.Publish(pubsub.CreatedEvent, event)
 862		oldSession, err := a.sessions.Get(summarizeCtx, sessionID)
 863		if err != nil {
 864			event = AgentEvent{
 865				Type:  AgentEventTypeError,
 866				Error: fmt.Errorf("failed to get session: %w", err),
 867				Done:  true,
 868			}
 869
 870			a.Publish(pubsub.CreatedEvent, event)
 871			return
 872		}
 873		// Create a message in the new session with the summary
 874		msg, err := a.messages.Create(summarizeCtx, oldSession.ID, message.CreateMessageParams{
 875			Role: message.Assistant,
 876			Parts: []message.ContentPart{
 877				message.TextContent{Text: summary},
 878				message.Finish{
 879					Reason: message.FinishReasonEndTurn,
 880					Time:   time.Now().Unix(),
 881				},
 882			},
 883			Model:    a.summarizeProvider.Model().ID,
 884			Provider: a.summarizeProviderID,
 885		})
 886		if err != nil {
 887			event = AgentEvent{
 888				Type:  AgentEventTypeError,
 889				Error: fmt.Errorf("failed to create summary message: %w", err),
 890				Done:  true,
 891			}
 892
 893			a.Publish(pubsub.CreatedEvent, event)
 894			return
 895		}
 896		oldSession.SummaryMessageID = msg.ID
 897		oldSession.CompletionTokens = finalResponse.Usage.OutputTokens
 898		oldSession.PromptTokens = 0
 899		model := a.summarizeProvider.Model()
 900		usage := finalResponse.Usage
 901		cost := model.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
 902			model.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
 903			model.CostPer1MIn/1e6*float64(usage.InputTokens) +
 904			model.CostPer1MOut/1e6*float64(usage.OutputTokens)
 905		oldSession.Cost += cost
 906		_, err = a.sessions.Save(summarizeCtx, oldSession)
 907		if err != nil {
 908			event = AgentEvent{
 909				Type:  AgentEventTypeError,
 910				Error: fmt.Errorf("failed to save session: %w", err),
 911				Done:  true,
 912			}
 913			a.Publish(pubsub.CreatedEvent, event)
 914		}
 915
 916		event = AgentEvent{
 917			Type:      AgentEventTypeSummarize,
 918			SessionID: oldSession.ID,
 919			Progress:  "Summary complete",
 920			Done:      true,
 921		}
 922		a.Publish(pubsub.CreatedEvent, event)
 923		// Send final success event with the new session ID
 924	}()
 925
 926	return nil
 927}
 928
 929func (a *agent) ClearQueue(sessionID string) {
 930	if a.QueuedPrompts(sessionID) > 0 {
 931		slog.Info("Clearing queued prompts", "session_id", sessionID)
 932		a.promptQueue.Del(sessionID)
 933	}
 934}
 935
 936func (a *agent) CancelAll() {
 937	if !a.IsBusy() {
 938		return
 939	}
 940	for key := range a.activeRequests.Seq2() {
 941		a.Cancel(key) // key is sessionID
 942	}
 943
 944	timeout := time.After(5 * time.Second)
 945	for a.IsBusy() {
 946		select {
 947		case <-timeout:
 948			return
 949		default:
 950			time.Sleep(200 * time.Millisecond)
 951		}
 952	}
 953}
 954
 955func (a *agent) UpdateModel() error {
 956	cfg := config.Get()
 957
 958	// Get current provider configuration
 959	currentProviderCfg := cfg.GetProviderForModel(a.agentCfg.Model)
 960	if currentProviderCfg == nil || currentProviderCfg.ID == "" {
 961		return fmt.Errorf("provider for agent %s not found in config", a.agentCfg.Name)
 962	}
 963
 964	// Check if provider has changed
 965	if string(currentProviderCfg.ID) != a.providerID {
 966		// Provider changed, need to recreate the main provider
 967		model := cfg.GetModelByType(a.agentCfg.Model)
 968		if model.ID == "" {
 969			return fmt.Errorf("model not found for agent %s", a.agentCfg.Name)
 970		}
 971
 972		promptID := agentPromptMap[a.agentCfg.ID]
 973		if promptID == "" {
 974			promptID = prompt.PromptDefault
 975		}
 976
 977		opts := []provider.ProviderClientOption{
 978			provider.WithModel(a.agentCfg.Model),
 979			provider.WithSystemMessage(prompt.GetPrompt(promptID, currentProviderCfg.ID, cfg.Options.ContextPaths...)),
 980		}
 981
 982		newProvider, err := provider.NewProvider(*currentProviderCfg, opts...)
 983		if err != nil {
 984			return fmt.Errorf("failed to create new provider: %w", err)
 985		}
 986
 987		// Update the provider and provider ID
 988		a.provider = newProvider
 989		a.providerID = string(currentProviderCfg.ID)
 990	}
 991
 992	// Check if providers have changed for title (small) and summarize (large)
 993	smallModelCfg := cfg.Models[config.SelectedModelTypeSmall]
 994	var smallModelProviderCfg config.ProviderConfig
 995	for p := range cfg.Providers.Seq() {
 996		if p.ID == smallModelCfg.Provider {
 997			smallModelProviderCfg = p
 998			break
 999		}
1000	}
1001	if smallModelProviderCfg.ID == "" {
1002		return fmt.Errorf("provider %s not found in config", smallModelCfg.Provider)
1003	}
1004
1005	largeModelCfg := cfg.Models[config.SelectedModelTypeLarge]
1006	var largeModelProviderCfg config.ProviderConfig
1007	for p := range cfg.Providers.Seq() {
1008		if p.ID == largeModelCfg.Provider {
1009			largeModelProviderCfg = p
1010			break
1011		}
1012	}
1013	if largeModelProviderCfg.ID == "" {
1014		return fmt.Errorf("provider %s not found in config", largeModelCfg.Provider)
1015	}
1016
1017	var maxTitleTokens int64 = 40
1018
1019	// if the max output is too low for the gemini provider it won't return anything
1020	if smallModelCfg.Provider == "gemini" {
1021		maxTitleTokens = 1000
1022	}
1023	// Recreate title provider
1024	titleOpts := []provider.ProviderClientOption{
1025		provider.WithModel(config.SelectedModelTypeSmall),
1026		provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptTitle, smallModelProviderCfg.ID)),
1027		provider.WithMaxTokens(maxTitleTokens),
1028	}
1029	newTitleProvider, err := provider.NewProvider(smallModelProviderCfg, titleOpts...)
1030	if err != nil {
1031		return fmt.Errorf("failed to create new title provider: %w", err)
1032	}
1033	a.titleProvider = newTitleProvider
1034
1035	// Recreate summarize provider if provider changed (now large model)
1036	if string(largeModelProviderCfg.ID) != a.summarizeProviderID {
1037		largeModel := cfg.GetModelByType(config.SelectedModelTypeLarge)
1038		if largeModel == nil {
1039			return fmt.Errorf("model %s not found in provider %s", largeModelCfg.Model, largeModelProviderCfg.ID)
1040		}
1041		summarizeOpts := []provider.ProviderClientOption{
1042			provider.WithModel(config.SelectedModelTypeLarge),
1043			provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptSummarizer, largeModelProviderCfg.ID)),
1044		}
1045		newSummarizeProvider, err := provider.NewProvider(largeModelProviderCfg, summarizeOpts...)
1046		if err != nil {
1047			return fmt.Errorf("failed to create new summarize provider: %w", err)
1048		}
1049		a.summarizeProvider = newSummarizeProvider
1050		a.summarizeProviderID = string(largeModelProviderCfg.ID)
1051	}
1052
1053	return nil
1054}