agent.go

   1package agent
   2
   3import (
   4	"context"
   5	"errors"
   6	"fmt"
   7	"log/slog"
   8	"slices"
   9	"strings"
  10	"time"
  11
  12	"github.com/charmbracelet/catwalk/pkg/catwalk"
  13	"github.com/charmbracelet/crush/internal/config"
  14	"github.com/charmbracelet/crush/internal/csync"
  15	"github.com/charmbracelet/crush/internal/history"
  16	"github.com/charmbracelet/crush/internal/llm/prompt"
  17	"github.com/charmbracelet/crush/internal/llm/provider"
  18	"github.com/charmbracelet/crush/internal/llm/tools"
  19	"github.com/charmbracelet/crush/internal/log"
  20	"github.com/charmbracelet/crush/internal/lsp"
  21	"github.com/charmbracelet/crush/internal/message"
  22	"github.com/charmbracelet/crush/internal/permission"
  23	"github.com/charmbracelet/crush/internal/pubsub"
  24	"github.com/charmbracelet/crush/internal/session"
  25	"github.com/charmbracelet/crush/internal/shell"
  26)
  27
  28// Common errors
  29var (
  30	ErrRequestCancelled = errors.New("request canceled by user")
  31	ErrSessionBusy      = errors.New("session is currently processing another request")
  32)
  33
  34type AgentEventType string
  35
  36const (
  37	AgentEventTypeError     AgentEventType = "error"
  38	AgentEventTypeResponse  AgentEventType = "response"
  39	AgentEventTypeSummarize AgentEventType = "summarize"
  40)
  41
// AgentEvent is the unit of information the agent publishes to subscribers
// and sends on the channel returned by Run.
type AgentEvent struct {
	// Type tells which of the remaining fields are meaningful.
	Type AgentEventType
	// Message is the final assistant message of a run (response events).
	Message message.Message
	// Error is the failure being reported (error events).
	Error error

	// When summarizing
	// SessionID is set on the final "Summary complete" event.
	SessionID string
	// Progress is a human-readable status line for summarize events.
	Progress string
	// Done is true on terminal events; it is also set on the response event
	// that ends a run, not only while summarizing.
	Done bool
}
  52
// Service is the public surface of an agent: it runs prompts against a
// session, reports progress through published AgentEvents and lets callers
// cancel, queue and summarize work.
type Service interface {
	pubsub.Suscriber[AgentEvent]
	// Model returns the model this agent is configured to use.
	Model() catwalk.Model
	// Run starts processing content for sessionID in the background and
	// returns a channel that receives the final event. If the session is
	// already busy the prompt is queued instead and both results are nil.
	Run(ctx context.Context, sessionID string, content string, attachments ...message.Attachment) (<-chan AgentEvent, error)
	// Cancel cancels the running request and any running summarization for
	// sessionID and drops its queued prompts.
	Cancel(sessionID string)
	// CancelAll cancels every active request and waits a bounded time for
	// them to wind down.
	CancelAll()
	// IsSessionBusy reports whether sessionID has a request in flight.
	IsSessionBusy(sessionID string) bool
	// IsBusy reports whether any request is in flight.
	IsBusy() bool
	// Summarize starts a background summarization of sessionID; progress and
	// failures are delivered as published events. It returns ErrSessionBusy
	// when the session has a request in flight.
	Summarize(ctx context.Context, sessionID string) error
	// UpdateModel re-reads the configuration and recreates the agent's
	// provider when the configured provider has changed.
	UpdateModel() error
	// QueuedPrompts returns how many prompts are waiting for sessionID.
	QueuedPrompts(sessionID string) int
	// ClearQueue drops all prompts waiting for sessionID.
	ClearQueue(sessionID string)
}
  66
// agent is the Service implementation. It embeds a pubsub broker so that
// events can be published to subscribers.
type agent struct {
	*pubsub.Broker[AgentEvent]
	// agentCfg is the configuration this agent was created from.
	agentCfg config.Agent
	// sessions and messages are the persistence services for session state
	// and conversation messages.
	sessions session.Service
	messages message.Service
	// NOTE(review): mcpTools is not referenced in the visible part of this
	// file; confirm whether it is still needed.
	mcpTools []McpTool

	// tools is the lazily built list of tools offered to the model.
	tools *csync.LazySlice[tools.BaseTool]
	// We need this to be able to update it when model changes
	agentToolFn func() (tools.BaseTool, error)

	// provider is the main provider used for generation; providerID is the
	// ID of its provider configuration.
	provider   provider.Provider
	providerID string

	// titleProvider generates session titles; summarizeProvider produces
	// conversation summaries.
	titleProvider       provider.Provider
	summarizeProvider   provider.Provider
	summarizeProviderID string

	// activeRequests maps a session ID (or session ID + "-summarize") to the
	// cancel function of the request currently running for it.
	activeRequests *csync.Map[string, context.CancelFunc]

	// promptQueue holds prompts submitted while a session was busy, keyed by
	// session ID.
	promptQueue *csync.Map[string, []string]
}
  89
// agentPromptMap maps an agent ID to the prompt used to build its system
// message. Agents without an entry fall back to prompt.PromptDefault.
var agentPromptMap = map[string]prompt.PromptID{
	"coder": prompt.PromptCoder,
	"task":  prompt.PromptTask,
}
  94
  95func NewAgent(
  96	ctx context.Context,
  97	agentCfg config.Agent,
  98	// These services are needed in the tools
  99	permissions permission.Service,
 100	sessions session.Service,
 101	messages message.Service,
 102	history history.Service,
 103	lspClients map[string]*lsp.Client,
 104) (Service, error) {
 105	cfg := config.Get()
 106
 107	var agentToolFn func() (tools.BaseTool, error)
 108	if agentCfg.ID == "coder" {
 109		agentToolFn = func() (tools.BaseTool, error) {
 110			taskAgentCfg := config.Get().Agents["task"]
 111			if taskAgentCfg.ID == "" {
 112				return nil, fmt.Errorf("task agent not found in config")
 113			}
 114			taskAgent, err := NewAgent(ctx, taskAgentCfg, permissions, sessions, messages, history, lspClients)
 115			if err != nil {
 116				return nil, fmt.Errorf("failed to create task agent: %w", err)
 117			}
 118			return NewAgentTool(taskAgent, sessions, messages), nil
 119		}
 120	}
 121
 122	providerCfg := config.Get().GetProviderForModel(agentCfg.Model)
 123	if providerCfg == nil {
 124		return nil, fmt.Errorf("provider for agent %s not found in config", agentCfg.Name)
 125	}
 126	model := config.Get().GetModelByType(agentCfg.Model)
 127
 128	if model == nil {
 129		return nil, fmt.Errorf("model not found for agent %s", agentCfg.Name)
 130	}
 131
 132	promptID := agentPromptMap[agentCfg.ID]
 133	if promptID == "" {
 134		promptID = prompt.PromptDefault
 135	}
 136	opts := []provider.ProviderClientOption{
 137		provider.WithModel(agentCfg.Model),
 138		provider.WithSystemMessage(prompt.GetPrompt(promptID, providerCfg.ID, config.Get().Options.ContextPaths...)),
 139	}
 140	agentProvider, err := provider.NewProvider(*providerCfg, opts...)
 141	if err != nil {
 142		return nil, err
 143	}
 144
 145	smallModelCfg := cfg.Models[config.SelectedModelTypeSmall]
 146	var smallModelProviderCfg *config.ProviderConfig
 147	if smallModelCfg.Provider == providerCfg.ID {
 148		smallModelProviderCfg = providerCfg
 149	} else {
 150		smallModelProviderCfg = cfg.GetProviderForModel(config.SelectedModelTypeSmall)
 151
 152		if smallModelProviderCfg.ID == "" {
 153			return nil, fmt.Errorf("provider %s not found in config", smallModelCfg.Provider)
 154		}
 155	}
 156	smallModel := cfg.GetModelByType(config.SelectedModelTypeSmall)
 157	if smallModel.ID == "" {
 158		return nil, fmt.Errorf("model %s not found in provider %s", smallModelCfg.Model, smallModelProviderCfg.ID)
 159	}
 160
 161	titleOpts := []provider.ProviderClientOption{
 162		provider.WithModel(config.SelectedModelTypeSmall),
 163		provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptTitle, smallModelProviderCfg.ID)),
 164	}
 165	titleProvider, err := provider.NewProvider(*smallModelProviderCfg, titleOpts...)
 166	if err != nil {
 167		return nil, err
 168	}
 169
 170	summarizeOpts := []provider.ProviderClientOption{
 171		provider.WithModel(config.SelectedModelTypeLarge),
 172		provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptSummarizer, providerCfg.ID)),
 173	}
 174	summarizeProvider, err := provider.NewProvider(*providerCfg, summarizeOpts...)
 175	if err != nil {
 176		return nil, err
 177	}
 178
 179	toolFn := func() []tools.BaseTool {
 180		slog.Info("Initializing agent tools", "agent", agentCfg.ID)
 181		defer func() {
 182			slog.Info("Initialized agent tools", "agent", agentCfg.ID)
 183		}()
 184
 185		cwd := cfg.WorkingDir()
 186		allTools := []tools.BaseTool{
 187			tools.NewBashTool(permissions, cwd, cfg.Options.Attribution),
 188			tools.NewDownloadTool(permissions, cwd),
 189			tools.NewEditTool(lspClients, permissions, history, cwd),
 190			tools.NewMultiEditTool(lspClients, permissions, history, cwd),
 191			tools.NewFetchTool(permissions, cwd),
 192			tools.NewGlobTool(cwd),
 193			tools.NewGrepTool(cwd),
 194			tools.NewLsTool(permissions, cwd),
 195			tools.NewSourcegraphTool(),
 196			tools.NewViewTool(lspClients, permissions, cwd),
 197			tools.NewWriteTool(lspClients, permissions, history, cwd),
 198		}
 199
 200		mcpToolsOnce.Do(func() {
 201			mcpTools = doGetMCPTools(ctx, permissions, cfg)
 202		})
 203
 204		withCoderTools := func(t []tools.BaseTool) []tools.BaseTool {
 205			if agentCfg.ID == "coder" {
 206				t = append(t, mcpTools...)
 207				if len(lspClients) > 0 {
 208					t = append(t, tools.NewDiagnosticsTool(lspClients))
 209				}
 210			}
 211			return t
 212		}
 213
 214		if agentCfg.AllowedTools == nil {
 215			return withCoderTools(allTools)
 216		}
 217
 218		var filteredTools []tools.BaseTool
 219		for _, tool := range allTools {
 220			if slices.Contains(agentCfg.AllowedTools, tool.Name()) {
 221				filteredTools = append(filteredTools, tool)
 222			}
 223		}
 224		return withCoderTools(filteredTools)
 225	}
 226
 227	return &agent{
 228		Broker:              pubsub.NewBroker[AgentEvent](),
 229		agentCfg:            agentCfg,
 230		provider:            agentProvider,
 231		providerID:          string(providerCfg.ID),
 232		messages:            messages,
 233		sessions:            sessions,
 234		titleProvider:       titleProvider,
 235		summarizeProvider:   summarizeProvider,
 236		summarizeProviderID: string(providerCfg.ID),
 237		agentToolFn:         agentToolFn,
 238		activeRequests:      csync.NewMap[string, context.CancelFunc](),
 239		tools:               csync.NewLazySlice(toolFn),
 240		promptQueue:         csync.NewMap[string, []string](),
 241	}, nil
 242}
 243
 244func (a *agent) Model() catwalk.Model {
 245	return *config.Get().GetModelByType(a.agentCfg.Model)
 246}
 247
 248func (a *agent) Cancel(sessionID string) {
 249	// Cancel regular requests
 250	if cancel, ok := a.activeRequests.Take(sessionID); ok && cancel != nil {
 251		slog.Info("Request cancellation initiated", "session_id", sessionID)
 252		cancel()
 253	}
 254
 255	// Also check for summarize requests
 256	if cancel, ok := a.activeRequests.Take(sessionID + "-summarize"); ok && cancel != nil {
 257		slog.Info("Summarize cancellation initiated", "session_id", sessionID)
 258		cancel()
 259	}
 260
 261	if a.QueuedPrompts(sessionID) > 0 {
 262		slog.Info("Clearing queued prompts", "session_id", sessionID)
 263		a.promptQueue.Del(sessionID)
 264	}
 265}
 266
 267func (a *agent) IsBusy() bool {
 268	var busy bool
 269	for cancelFunc := range a.activeRequests.Seq() {
 270		if cancelFunc != nil {
 271			busy = true
 272			break
 273		}
 274	}
 275	return busy
 276}
 277
 278func (a *agent) IsSessionBusy(sessionID string) bool {
 279	_, busy := a.activeRequests.Get(sessionID)
 280	return busy
 281}
 282
 283func (a *agent) QueuedPrompts(sessionID string) int {
 284	l, ok := a.promptQueue.Get(sessionID)
 285	if !ok {
 286		return 0
 287	}
 288	return len(l)
 289}
 290
 291func (a *agent) generateTitle(ctx context.Context, sessionID string, content string) error {
 292	if content == "" {
 293		return nil
 294	}
 295	if a.titleProvider == nil {
 296		return nil
 297	}
 298	session, err := a.sessions.Get(ctx, sessionID)
 299	if err != nil {
 300		return err
 301	}
 302	parts := []message.ContentPart{message.TextContent{
 303		Text: fmt.Sprintf("Generate a concise title for the following content:\n\n%s", content),
 304	}}
 305
 306	// Use streaming approach like summarization
 307	response := a.titleProvider.StreamResponse(
 308		ctx,
 309		[]message.Message{
 310			{
 311				Role:  message.User,
 312				Parts: parts,
 313			},
 314		},
 315		nil,
 316	)
 317
 318	var finalResponse *provider.ProviderResponse
 319	for r := range response {
 320		if r.Error != nil {
 321			return r.Error
 322		}
 323		finalResponse = r.Response
 324	}
 325
 326	if finalResponse == nil {
 327		return fmt.Errorf("no response received from title provider")
 328	}
 329
 330	title := strings.TrimSpace(strings.ReplaceAll(finalResponse.Content, "\n", " "))
 331	if title == "" {
 332		return nil
 333	}
 334
 335	session.Title = title
 336	_, err = a.sessions.Save(ctx, session)
 337	return err
 338}
 339
 340func (a *agent) err(err error) AgentEvent {
 341	return AgentEvent{
 342		Type:  AgentEventTypeError,
 343		Error: err,
 344	}
 345}
 346
 347func (a *agent) Run(ctx context.Context, sessionID string, content string, attachments ...message.Attachment) (<-chan AgentEvent, error) {
 348	if !a.Model().SupportsImages && attachments != nil {
 349		attachments = nil
 350	}
 351	events := make(chan AgentEvent, 1)
 352	if a.IsSessionBusy(sessionID) {
 353		existing, ok := a.promptQueue.Get(sessionID)
 354		if !ok {
 355			existing = []string{}
 356		}
 357		existing = append(existing, content)
 358		a.promptQueue.Set(sessionID, existing)
 359		return nil, nil
 360	}
 361
 362	genCtx, cancel := context.WithCancel(ctx)
 363
 364	a.activeRequests.Set(sessionID, cancel)
 365	go func() {
 366		slog.Debug("Request started", "sessionID", sessionID)
 367		defer log.RecoverPanic("agent.Run", func() {
 368			events <- a.err(fmt.Errorf("panic while running the agent"))
 369		})
 370		var attachmentParts []message.ContentPart
 371		for _, attachment := range attachments {
 372			attachmentParts = append(attachmentParts, message.BinaryContent{Path: attachment.FilePath, MIMEType: attachment.MimeType, Data: attachment.Content})
 373		}
 374		result := a.processGeneration(genCtx, sessionID, content, attachmentParts)
 375		if result.Error != nil && !errors.Is(result.Error, ErrRequestCancelled) && !errors.Is(result.Error, context.Canceled) {
 376			slog.Error(result.Error.Error())
 377		}
 378		slog.Debug("Request completed", "sessionID", sessionID)
 379		a.activeRequests.Del(sessionID)
 380		cancel()
 381		a.Publish(pubsub.CreatedEvent, result)
 382		events <- result
 383		close(events)
 384	}()
 385	return events, nil
 386}
 387
 388func (a *agent) processGeneration(ctx context.Context, sessionID, content string, attachmentParts []message.ContentPart) AgentEvent {
 389	cfg := config.Get()
 390	// List existing messages; if none, start title generation asynchronously.
 391	msgs, err := a.messages.List(ctx, sessionID)
 392	if err != nil {
 393		return a.err(fmt.Errorf("failed to list messages: %w", err))
 394	}
 395	if len(msgs) == 0 {
 396		go func() {
 397			defer log.RecoverPanic("agent.Run", func() {
 398				slog.Error("panic while generating title")
 399			})
 400			titleErr := a.generateTitle(ctx, sessionID, content)
 401			if titleErr != nil && !errors.Is(titleErr, context.Canceled) && !errors.Is(titleErr, context.DeadlineExceeded) {
 402				slog.Error("failed to generate title", "error", titleErr)
 403			}
 404		}()
 405	}
 406	session, err := a.sessions.Get(ctx, sessionID)
 407	if err != nil {
 408		return a.err(fmt.Errorf("failed to get session: %w", err))
 409	}
 410	if session.SummaryMessageID != "" {
 411		summaryMsgInex := -1
 412		for i, msg := range msgs {
 413			if msg.ID == session.SummaryMessageID {
 414				summaryMsgInex = i
 415				break
 416			}
 417		}
 418		if summaryMsgInex != -1 {
 419			msgs = msgs[summaryMsgInex:]
 420			msgs[0].Role = message.User
 421		}
 422	}
 423
 424	userMsg, err := a.createUserMessage(ctx, sessionID, content, attachmentParts)
 425	if err != nil {
 426		return a.err(fmt.Errorf("failed to create user message: %w", err))
 427	}
 428	// Append the new user message to the conversation history.
 429	msgHistory := append(msgs, userMsg)
 430
 431	for {
 432		// Check for cancellation before each iteration
 433		select {
 434		case <-ctx.Done():
 435			return a.err(ctx.Err())
 436		default:
 437			// Continue processing
 438		}
 439		agentMessage, toolResults, err := a.streamAndHandleEvents(ctx, sessionID, msgHistory)
 440		if err != nil {
 441			if errors.Is(err, context.Canceled) {
 442				agentMessage.AddFinish(message.FinishReasonCanceled, "Request cancelled", "")
 443				a.messages.Update(context.Background(), agentMessage)
 444				return a.err(ErrRequestCancelled)
 445			}
 446			return a.err(fmt.Errorf("failed to process events: %w", err))
 447		}
 448		if cfg.Options.Debug {
 449			slog.Info("Result", "message", agentMessage.FinishReason(), "toolResults", toolResults)
 450		}
 451		if (agentMessage.FinishReason() == message.FinishReasonToolUse) && toolResults != nil {
 452			// We are not done, we need to respond with the tool response
 453			msgHistory = append(msgHistory, agentMessage, *toolResults)
 454			// If there are queued prompts, process the next one
 455			nextPrompt, ok := a.promptQueue.Take(sessionID)
 456			if ok {
 457				for _, prompt := range nextPrompt {
 458					// Create a new user message for the queued prompt
 459					userMsg, err := a.createUserMessage(ctx, sessionID, prompt, nil)
 460					if err != nil {
 461						return a.err(fmt.Errorf("failed to create user message for queued prompt: %w", err))
 462					}
 463					// Append the new user message to the conversation history
 464					msgHistory = append(msgHistory, userMsg)
 465				}
 466			}
 467
 468			continue
 469		} else if agentMessage.FinishReason() == message.FinishReasonEndTurn {
 470			queuePrompts, ok := a.promptQueue.Take(sessionID)
 471			if ok {
 472				for _, prompt := range queuePrompts {
 473					if prompt == "" {
 474						continue
 475					}
 476					userMsg, err := a.createUserMessage(ctx, sessionID, prompt, nil)
 477					if err != nil {
 478						return a.err(fmt.Errorf("failed to create user message for queued prompt: %w", err))
 479					}
 480					msgHistory = append(msgHistory, userMsg)
 481				}
 482				continue
 483			}
 484		}
 485		if agentMessage.FinishReason() == "" {
 486			// Kujtim: could not track down where this is happening but this means its cancelled
 487			agentMessage.AddFinish(message.FinishReasonCanceled, "Request cancelled", "")
 488			_ = a.messages.Update(context.Background(), agentMessage)
 489			return a.err(ErrRequestCancelled)
 490		}
 491		return AgentEvent{
 492			Type:    AgentEventTypeResponse,
 493			Message: agentMessage,
 494			Done:    true,
 495		}
 496	}
 497}
 498
 499func (a *agent) createUserMessage(ctx context.Context, sessionID, content string, attachmentParts []message.ContentPart) (message.Message, error) {
 500	parts := []message.ContentPart{message.TextContent{Text: content}}
 501	parts = append(parts, attachmentParts...)
 502	return a.messages.Create(ctx, sessionID, message.CreateMessageParams{
 503		Role:  message.User,
 504		Parts: parts,
 505	})
 506}
 507
 508func (a *agent) getAllTools() ([]tools.BaseTool, error) {
 509	allTools := slices.Collect(a.tools.Seq())
 510	if a.agentToolFn != nil {
 511		agentTool, agentToolErr := a.agentToolFn()
 512		if agentToolErr != nil {
 513			return nil, agentToolErr
 514		}
 515		allTools = append(allTools, agentTool)
 516	}
 517	return allTools, nil
 518}
 519
 520func (a *agent) streamAndHandleEvents(ctx context.Context, sessionID string, msgHistory []message.Message) (message.Message, *message.Message, error) {
 521	ctx = context.WithValue(ctx, tools.SessionIDContextKey, sessionID)
 522
 523	// Create the assistant message first so the spinner shows immediately
 524	assistantMsg, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{
 525		Role:     message.Assistant,
 526		Parts:    []message.ContentPart{},
 527		Model:    a.Model().ID,
 528		Provider: a.providerID,
 529	})
 530	if err != nil {
 531		return assistantMsg, nil, fmt.Errorf("failed to create assistant message: %w", err)
 532	}
 533
 534	allTools, toolsErr := a.getAllTools()
 535	if toolsErr != nil {
 536		return assistantMsg, nil, toolsErr
 537	}
 538	// Now collect tools (which may block on MCP initialization)
 539	eventChan := a.provider.StreamResponse(ctx, msgHistory, allTools)
 540
 541	// Add the session and message ID into the context if needed by tools.
 542	ctx = context.WithValue(ctx, tools.MessageIDContextKey, assistantMsg.ID)
 543
 544	// Process each event in the stream.
 545	for event := range eventChan {
 546		if processErr := a.processEvent(ctx, sessionID, &assistantMsg, event); processErr != nil {
 547			if errors.Is(processErr, context.Canceled) {
 548				a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
 549			} else {
 550				a.finishMessage(ctx, &assistantMsg, message.FinishReasonError, "API Error", processErr.Error())
 551			}
 552			return assistantMsg, nil, processErr
 553		}
 554		if ctx.Err() != nil {
 555			a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
 556			return assistantMsg, nil, ctx.Err()
 557		}
 558	}
 559
 560	toolResults := make([]message.ToolResult, len(assistantMsg.ToolCalls()))
 561	toolCalls := assistantMsg.ToolCalls()
 562	for i, toolCall := range toolCalls {
 563		select {
 564		case <-ctx.Done():
 565			a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
 566			// Make all future tool calls cancelled
 567			for j := i; j < len(toolCalls); j++ {
 568				toolResults[j] = message.ToolResult{
 569					ToolCallID: toolCalls[j].ID,
 570					Content:    "Tool execution canceled by user",
 571					IsError:    true,
 572				}
 573			}
 574			goto out
 575		default:
 576			// Continue processing
 577			var tool tools.BaseTool
 578			allTools, _ := a.getAllTools()
 579			for _, availableTool := range allTools {
 580				if availableTool.Info().Name == toolCall.Name {
 581					tool = availableTool
 582					break
 583				}
 584			}
 585
 586			// Tool not found
 587			if tool == nil {
 588				toolResults[i] = message.ToolResult{
 589					ToolCallID: toolCall.ID,
 590					Content:    fmt.Sprintf("Tool not found: %s", toolCall.Name),
 591					IsError:    true,
 592				}
 593				continue
 594			}
 595
 596			// Run tool in goroutine to allow cancellation
 597			type toolExecResult struct {
 598				response tools.ToolResponse
 599				err      error
 600			}
 601			resultChan := make(chan toolExecResult, 1)
 602
 603			go func() {
 604				response, err := tool.Run(ctx, tools.ToolCall{
 605					ID:    toolCall.ID,
 606					Name:  toolCall.Name,
 607					Input: toolCall.Input,
 608				})
 609				resultChan <- toolExecResult{response: response, err: err}
 610			}()
 611
 612			var toolResponse tools.ToolResponse
 613			var toolErr error
 614
 615			select {
 616			case <-ctx.Done():
 617				a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
 618				// Mark remaining tool calls as cancelled
 619				for j := i; j < len(toolCalls); j++ {
 620					toolResults[j] = message.ToolResult{
 621						ToolCallID: toolCalls[j].ID,
 622						Content:    "Tool execution canceled by user",
 623						IsError:    true,
 624					}
 625				}
 626				goto out
 627			case result := <-resultChan:
 628				toolResponse = result.response
 629				toolErr = result.err
 630			}
 631
 632			if toolErr != nil {
 633				slog.Error("Tool execution error", "toolCall", toolCall.ID, "error", toolErr)
 634				if errors.Is(toolErr, permission.ErrorPermissionDenied) {
 635					toolResults[i] = message.ToolResult{
 636						ToolCallID: toolCall.ID,
 637						Content:    "Permission denied",
 638						IsError:    true,
 639					}
 640					for j := i + 1; j < len(toolCalls); j++ {
 641						toolResults[j] = message.ToolResult{
 642							ToolCallID: toolCalls[j].ID,
 643							Content:    "Tool execution canceled by user",
 644							IsError:    true,
 645						}
 646					}
 647					a.finishMessage(ctx, &assistantMsg, message.FinishReasonPermissionDenied, "Permission denied", "")
 648					break
 649				}
 650			}
 651			toolResults[i] = message.ToolResult{
 652				ToolCallID: toolCall.ID,
 653				Content:    toolResponse.Content,
 654				Metadata:   toolResponse.Metadata,
 655				IsError:    toolResponse.IsError,
 656			}
 657		}
 658	}
 659out:
 660	if len(toolResults) == 0 {
 661		return assistantMsg, nil, nil
 662	}
 663	parts := make([]message.ContentPart, 0)
 664	for _, tr := range toolResults {
 665		parts = append(parts, tr)
 666	}
 667	msg, err := a.messages.Create(context.Background(), assistantMsg.SessionID, message.CreateMessageParams{
 668		Role:     message.Tool,
 669		Parts:    parts,
 670		Provider: a.providerID,
 671	})
 672	if err != nil {
 673		return assistantMsg, nil, fmt.Errorf("failed to create cancelled tool message: %w", err)
 674	}
 675
 676	return assistantMsg, &msg, err
 677}
 678
 679func (a *agent) finishMessage(ctx context.Context, msg *message.Message, finishReason message.FinishReason, message, details string) {
 680	msg.AddFinish(finishReason, message, details)
 681	_ = a.messages.Update(ctx, *msg)
 682}
 683
 684func (a *agent) processEvent(ctx context.Context, sessionID string, assistantMsg *message.Message, event provider.ProviderEvent) error {
 685	select {
 686	case <-ctx.Done():
 687		return ctx.Err()
 688	default:
 689		// Continue processing.
 690	}
 691
 692	switch event.Type {
 693	case provider.EventThinkingDelta:
 694		assistantMsg.AppendReasoningContent(event.Thinking)
 695		return a.messages.Update(ctx, *assistantMsg)
 696	case provider.EventSignatureDelta:
 697		assistantMsg.AppendReasoningSignature(event.Signature)
 698		return a.messages.Update(ctx, *assistantMsg)
 699	case provider.EventContentDelta:
 700		assistantMsg.FinishThinking()
 701		assistantMsg.AppendContent(event.Content)
 702		return a.messages.Update(ctx, *assistantMsg)
 703	case provider.EventToolUseStart:
 704		assistantMsg.FinishThinking()
 705		slog.Info("Tool call started", "toolCall", event.ToolCall)
 706		assistantMsg.AddToolCall(*event.ToolCall)
 707		return a.messages.Update(ctx, *assistantMsg)
 708	case provider.EventToolUseDelta:
 709		assistantMsg.AppendToolCallInput(event.ToolCall.ID, event.ToolCall.Input)
 710		return a.messages.Update(ctx, *assistantMsg)
 711	case provider.EventToolUseStop:
 712		slog.Info("Finished tool call", "toolCall", event.ToolCall)
 713		assistantMsg.FinishToolCall(event.ToolCall.ID)
 714		return a.messages.Update(ctx, *assistantMsg)
 715	case provider.EventError:
 716		return event.Error
 717	case provider.EventComplete:
 718		assistantMsg.FinishThinking()
 719		assistantMsg.SetToolCalls(event.Response.ToolCalls)
 720		assistantMsg.AddFinish(event.Response.FinishReason, "", "")
 721		if err := a.messages.Update(ctx, *assistantMsg); err != nil {
 722			return fmt.Errorf("failed to update message: %w", err)
 723		}
 724		return a.TrackUsage(ctx, sessionID, a.Model(), event.Response.Usage)
 725	}
 726
 727	return nil
 728}
 729
 730func (a *agent) TrackUsage(ctx context.Context, sessionID string, model catwalk.Model, usage provider.TokenUsage) error {
 731	sess, err := a.sessions.Get(ctx, sessionID)
 732	if err != nil {
 733		return fmt.Errorf("failed to get session: %w", err)
 734	}
 735
 736	cost := model.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
 737		model.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
 738		model.CostPer1MIn/1e6*float64(usage.InputTokens) +
 739		model.CostPer1MOut/1e6*float64(usage.OutputTokens)
 740
 741	sess.Cost += cost
 742	sess.CompletionTokens = usage.OutputTokens + usage.CacheReadTokens
 743	sess.PromptTokens = usage.InputTokens + usage.CacheCreationTokens
 744
 745	_, err = a.sessions.Save(ctx, sess)
 746	if err != nil {
 747		return fmt.Errorf("failed to save session: %w", err)
 748	}
 749	return nil
 750}
 751
 752func (a *agent) Summarize(ctx context.Context, sessionID string) error {
 753	if a.summarizeProvider == nil {
 754		return fmt.Errorf("summarize provider not available")
 755	}
 756
 757	// Check if session is busy
 758	if a.IsSessionBusy(sessionID) {
 759		return ErrSessionBusy
 760	}
 761
 762	// Create a new context with cancellation
 763	summarizeCtx, cancel := context.WithCancel(ctx)
 764
 765	// Store the cancel function in activeRequests to allow cancellation
 766	a.activeRequests.Set(sessionID+"-summarize", cancel)
 767
 768	go func() {
 769		defer a.activeRequests.Del(sessionID + "-summarize")
 770		defer cancel()
 771		event := AgentEvent{
 772			Type:     AgentEventTypeSummarize,
 773			Progress: "Starting summarization...",
 774		}
 775
 776		a.Publish(pubsub.CreatedEvent, event)
 777		// Get all messages from the session
 778		msgs, err := a.messages.List(summarizeCtx, sessionID)
 779		if err != nil {
 780			event = AgentEvent{
 781				Type:  AgentEventTypeError,
 782				Error: fmt.Errorf("failed to list messages: %w", err),
 783				Done:  true,
 784			}
 785			a.Publish(pubsub.CreatedEvent, event)
 786			return
 787		}
 788		summarizeCtx = context.WithValue(summarizeCtx, tools.SessionIDContextKey, sessionID)
 789
 790		if len(msgs) == 0 {
 791			event = AgentEvent{
 792				Type:  AgentEventTypeError,
 793				Error: fmt.Errorf("no messages to summarize"),
 794				Done:  true,
 795			}
 796			a.Publish(pubsub.CreatedEvent, event)
 797			return
 798		}
 799
 800		event = AgentEvent{
 801			Type:     AgentEventTypeSummarize,
 802			Progress: "Analyzing conversation...",
 803		}
 804		a.Publish(pubsub.CreatedEvent, event)
 805
 806		// Add a system message to guide the summarization
 807		summarizePrompt := "Provide a detailed but concise summary of our conversation above. Focus on information that would be helpful for continuing the conversation, including what we did, what we're doing, which files we're working on, and what we're going to do next."
 808
 809		// Create a new message with the summarize prompt
 810		promptMsg := message.Message{
 811			Role:  message.User,
 812			Parts: []message.ContentPart{message.TextContent{Text: summarizePrompt}},
 813		}
 814
 815		// Append the prompt to the messages
 816		msgsWithPrompt := append(msgs, promptMsg)
 817
 818		event = AgentEvent{
 819			Type:     AgentEventTypeSummarize,
 820			Progress: "Generating summary...",
 821		}
 822
 823		a.Publish(pubsub.CreatedEvent, event)
 824
 825		// Send the messages to the summarize provider
 826		response := a.summarizeProvider.StreamResponse(
 827			summarizeCtx,
 828			msgsWithPrompt,
 829			nil,
 830		)
 831		var finalResponse *provider.ProviderResponse
 832		for r := range response {
 833			if r.Error != nil {
 834				event = AgentEvent{
 835					Type:  AgentEventTypeError,
 836					Error: fmt.Errorf("failed to summarize: %w", err),
 837					Done:  true,
 838				}
 839				a.Publish(pubsub.CreatedEvent, event)
 840				return
 841			}
 842			finalResponse = r.Response
 843		}
 844
 845		summary := strings.TrimSpace(finalResponse.Content)
 846		if summary == "" {
 847			event = AgentEvent{
 848				Type:  AgentEventTypeError,
 849				Error: fmt.Errorf("empty summary returned"),
 850				Done:  true,
 851			}
 852			a.Publish(pubsub.CreatedEvent, event)
 853			return
 854		}
 855		shell := shell.GetPersistentShell(config.Get().WorkingDir())
 856		summary += "\n\n**Current working directory of the persistent shell**\n\n" + shell.GetWorkingDir()
 857		event = AgentEvent{
 858			Type:     AgentEventTypeSummarize,
 859			Progress: "Creating new session...",
 860		}
 861
 862		a.Publish(pubsub.CreatedEvent, event)
 863		oldSession, err := a.sessions.Get(summarizeCtx, sessionID)
 864		if err != nil {
 865			event = AgentEvent{
 866				Type:  AgentEventTypeError,
 867				Error: fmt.Errorf("failed to get session: %w", err),
 868				Done:  true,
 869			}
 870
 871			a.Publish(pubsub.CreatedEvent, event)
 872			return
 873		}
 874		// Create a message in the new session with the summary
 875		msg, err := a.messages.Create(summarizeCtx, oldSession.ID, message.CreateMessageParams{
 876			Role: message.Assistant,
 877			Parts: []message.ContentPart{
 878				message.TextContent{Text: summary},
 879				message.Finish{
 880					Reason: message.FinishReasonEndTurn,
 881					Time:   time.Now().Unix(),
 882				},
 883			},
 884			Model:    a.summarizeProvider.Model().ID,
 885			Provider: a.summarizeProviderID,
 886		})
 887		if err != nil {
 888			event = AgentEvent{
 889				Type:  AgentEventTypeError,
 890				Error: fmt.Errorf("failed to create summary message: %w", err),
 891				Done:  true,
 892			}
 893
 894			a.Publish(pubsub.CreatedEvent, event)
 895			return
 896		}
 897		oldSession.SummaryMessageID = msg.ID
 898		oldSession.CompletionTokens = finalResponse.Usage.OutputTokens
 899		oldSession.PromptTokens = 0
 900		model := a.summarizeProvider.Model()
 901		usage := finalResponse.Usage
 902		cost := model.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
 903			model.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
 904			model.CostPer1MIn/1e6*float64(usage.InputTokens) +
 905			model.CostPer1MOut/1e6*float64(usage.OutputTokens)
 906		oldSession.Cost += cost
 907		_, err = a.sessions.Save(summarizeCtx, oldSession)
 908		if err != nil {
 909			event = AgentEvent{
 910				Type:  AgentEventTypeError,
 911				Error: fmt.Errorf("failed to save session: %w", err),
 912				Done:  true,
 913			}
 914			a.Publish(pubsub.CreatedEvent, event)
 915		}
 916
 917		event = AgentEvent{
 918			Type:      AgentEventTypeSummarize,
 919			SessionID: oldSession.ID,
 920			Progress:  "Summary complete",
 921			Done:      true,
 922		}
 923		a.Publish(pubsub.CreatedEvent, event)
 924		// Send final success event with the new session ID
 925	}()
 926
 927	return nil
 928}
 929
 930func (a *agent) ClearQueue(sessionID string) {
 931	if a.QueuedPrompts(sessionID) > 0 {
 932		slog.Info("Clearing queued prompts", "session_id", sessionID)
 933		a.promptQueue.Del(sessionID)
 934	}
 935}
 936
 937func (a *agent) CancelAll() {
 938	if !a.IsBusy() {
 939		return
 940	}
 941	for key := range a.activeRequests.Seq2() {
 942		a.Cancel(key) // key is sessionID
 943	}
 944
 945	timeout := time.After(5 * time.Second)
 946	for a.IsBusy() {
 947		select {
 948		case <-timeout:
 949			return
 950		default:
 951			time.Sleep(200 * time.Millisecond)
 952		}
 953	}
 954}
 955
 956func (a *agent) UpdateModel() error {
 957	cfg := config.Get()
 958
 959	// Get current provider configuration
 960	currentProviderCfg := cfg.GetProviderForModel(a.agentCfg.Model)
 961	if currentProviderCfg == nil || currentProviderCfg.ID == "" {
 962		return fmt.Errorf("provider for agent %s not found in config", a.agentCfg.Name)
 963	}
 964
 965	// Check if provider has changed
 966	if string(currentProviderCfg.ID) != a.providerID {
 967		// Provider changed, need to recreate the main provider
 968		model := cfg.GetModelByType(a.agentCfg.Model)
 969		if model.ID == "" {
 970			return fmt.Errorf("model not found for agent %s", a.agentCfg.Name)
 971		}
 972
 973		promptID := agentPromptMap[a.agentCfg.ID]
 974		if promptID == "" {
 975			promptID = prompt.PromptDefault
 976		}
 977
 978		opts := []provider.ProviderClientOption{
 979			provider.WithModel(a.agentCfg.Model),
 980			provider.WithSystemMessage(prompt.GetPrompt(promptID, currentProviderCfg.ID, cfg.Options.ContextPaths...)),
 981		}
 982
 983		newProvider, err := provider.NewProvider(*currentProviderCfg, opts...)
 984		if err != nil {
 985			return fmt.Errorf("failed to create new provider: %w", err)
 986		}
 987
 988		// Update the provider and provider ID
 989		a.provider = newProvider
 990		a.providerID = string(currentProviderCfg.ID)
 991	}
 992
 993	// Check if providers have changed for title (small) and summarize (large)
 994	smallModelCfg := cfg.Models[config.SelectedModelTypeSmall]
 995	var smallModelProviderCfg config.ProviderConfig
 996	for p := range cfg.Providers.Seq() {
 997		if p.ID == smallModelCfg.Provider {
 998			smallModelProviderCfg = p
 999			break
1000		}
1001	}
1002	if smallModelProviderCfg.ID == "" {
1003		return fmt.Errorf("provider %s not found in config", smallModelCfg.Provider)
1004	}
1005
1006	largeModelCfg := cfg.Models[config.SelectedModelTypeLarge]
1007	var largeModelProviderCfg config.ProviderConfig
1008	for p := range cfg.Providers.Seq() {
1009		if p.ID == largeModelCfg.Provider {
1010			largeModelProviderCfg = p
1011			break
1012		}
1013	}
1014	if largeModelProviderCfg.ID == "" {
1015		return fmt.Errorf("provider %s not found in config", largeModelCfg.Provider)
1016	}
1017
1018	var maxTitleTokens int64 = 40
1019
1020	// if the max output is too low for the gemini provider it won't return anything
1021	if smallModelCfg.Provider == "gemini" {
1022		maxTitleTokens = 1000
1023	}
1024	// Recreate title provider
1025	titleOpts := []provider.ProviderClientOption{
1026		provider.WithModel(config.SelectedModelTypeSmall),
1027		provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptTitle, smallModelProviderCfg.ID)),
1028		provider.WithMaxTokens(maxTitleTokens),
1029	}
1030	newTitleProvider, err := provider.NewProvider(smallModelProviderCfg, titleOpts...)
1031	if err != nil {
1032		return fmt.Errorf("failed to create new title provider: %w", err)
1033	}
1034	a.titleProvider = newTitleProvider
1035
1036	// Recreate summarize provider if provider changed (now large model)
1037	if string(largeModelProviderCfg.ID) != a.summarizeProviderID {
1038		largeModel := cfg.GetModelByType(config.SelectedModelTypeLarge)
1039		if largeModel == nil {
1040			return fmt.Errorf("model %s not found in provider %s", largeModelCfg.Model, largeModelProviderCfg.ID)
1041		}
1042		summarizeOpts := []provider.ProviderClientOption{
1043			provider.WithModel(config.SelectedModelTypeLarge),
1044			provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptSummarizer, largeModelProviderCfg.ID)),
1045		}
1046		newSummarizeProvider, err := provider.NewProvider(largeModelProviderCfg, summarizeOpts...)
1047		if err != nil {
1048			return fmt.Errorf("failed to create new summarize provider: %w", err)
1049		}
1050		a.summarizeProvider = newSummarizeProvider
1051		a.summarizeProviderID = string(largeModelProviderCfg.ID)
1052	}
1053
1054	return nil
1055}