agent.go

   1package agent
   2
   3import (
   4	"context"
   5	"errors"
   6	"fmt"
   7	"log/slog"
   8	"slices"
   9	"strings"
  10	"time"
  11
  12	"github.com/charmbracelet/catwalk/pkg/catwalk"
  13	"github.com/charmbracelet/crush/internal/config"
  14	"github.com/charmbracelet/crush/internal/csync"
  15	"github.com/charmbracelet/crush/internal/history"
  16	"github.com/charmbracelet/crush/internal/llm/prompt"
  17	"github.com/charmbracelet/crush/internal/llm/provider"
  18	"github.com/charmbracelet/crush/internal/llm/tools"
  19	"github.com/charmbracelet/crush/internal/log"
  20	"github.com/charmbracelet/crush/internal/lsp"
  21	"github.com/charmbracelet/crush/internal/message"
  22	"github.com/charmbracelet/crush/internal/permission"
  23	"github.com/charmbracelet/crush/internal/pubsub"
  24	"github.com/charmbracelet/crush/internal/session"
  25	"github.com/charmbracelet/crush/internal/shell"
  26)
  27
// Common errors
//
// ErrRequestCancelled is the error carried by the event that a generation
// returns when its context is canceled or the assistant message ends without a
// finish reason.
//
// ErrSessionBusy is returned by Summarize when the session already has an
// active generation request.
var (
	ErrRequestCancelled = errors.New("request canceled by user")
	ErrSessionBusy      = errors.New("session is currently processing another request")
)
  33
// AgentEventType identifies the kind of an AgentEvent.
type AgentEventType string

// Event kinds published by the agent: "error" for failures, "response" for a
// finished generation, and "summarize" for summarization progress updates.
const (
	AgentEventTypeError     AgentEventType = "error"
	AgentEventTypeResponse  AgentEventType = "response"
	AgentEventTypeSummarize AgentEventType = "summarize"
)
  41
// AgentEvent is the value returned on the Run channel and published through
// the agent's broker. Type selects which of the remaining fields are
// meaningful: Message for responses, Error for errors.
type AgentEvent struct {
	Type    AgentEventType
	Message message.Message
	Error   error

	// When summarizing: SessionID is set on the final "Summary complete"
	// event, Progress holds a human-readable status string, and Done marks
	// the last event of an operation.
	SessionID string
	Progress  string
	Done      bool
}
  52
// Service is the agent API exposed to the rest of the application. It embeds a
// pubsub subscriber so callers can listen for AgentEvent values.
type Service interface {
	pubsub.Suscriber[AgentEvent]
	// Model returns the model currently configured for this agent.
	Model() catwalk.Model
	// Run starts processing content for the session. When the session is
	// already busy the prompt is queued instead and both results are nil.
	Run(ctx context.Context, sessionID string, content string, attachments ...message.Attachment) (<-chan AgentEvent, error)
	// Cancel stops the session's active generation and summarization, and
	// drops its queued prompts.
	Cancel(sessionID string)
	// CancelAll cancels every active request and waits up to five seconds
	// for the agent to become idle.
	CancelAll()
	// IsSessionBusy reports whether the session has an active generation
	// request.
	IsSessionBusy(sessionID string) bool
	// IsBusy reports whether any request is active.
	IsBusy() bool
	// Summarize starts an asynchronous summarization of the session;
	// progress and the outcome are published as events.
	Summarize(ctx context.Context, sessionID string) error
	// UpdateModel re-reads the configuration and recreates the main
	// provider when the configured provider has changed.
	UpdateModel() error
	// QueuedPrompts returns how many prompts are waiting for the session.
	QueuedPrompts(sessionID string) int
	// ClearQueue drops the prompts queued for the session.
	ClearQueue(sessionID string)
}
  66
// agent is the Service implementation. It embeds the broker used to publish
// AgentEvent values and keeps the providers, tools and per-session
// bookkeeping needed to run generations.
type agent struct {
	*pubsub.Broker[AgentEvent]
	agentCfg config.Agent
	sessions session.Service
	messages message.Service
	mcpTools []McpTool // NOTE(review): not set by NewAgent — confirm whether it is still used

	// tools is built lazily by the tool factory passed to NewLazySlice.
	tools *csync.LazySlice[tools.BaseTool]
	// We need this to be able to update it when model changes
	agentToolFn func() (tools.BaseTool, error)

	// provider serves regular generations; providerID is recorded on the
	// messages the agent creates.
	provider   provider.Provider
	providerID string

	// titleProvider generates session titles; summarizeProvider produces
	// conversation summaries.
	titleProvider       provider.Provider
	summarizeProvider   provider.Provider
	summarizeProviderID string

	// activeRequests maps a session ID (or session ID + "-summarize") to
	// the cancel function of its running request.
	activeRequests *csync.Map[string, context.CancelFunc]

	// promptQueue holds prompts received while a session was busy.
	promptQueue *csync.Map[string, []string]
}
  89
// agentPromptMap maps an agent ID to the prompt used for its system message.
// Agent IDs missing from this map fall back to prompt.PromptDefault.
var agentPromptMap = map[string]prompt.PromptID{
	"coder": prompt.PromptCoder,
	"task":  prompt.PromptTask,
}
  94
  95func NewAgent(
  96	ctx context.Context,
  97	agentCfg config.Agent,
  98	// These services are needed in the tools
  99	permissions permission.Service,
 100	sessions session.Service,
 101	messages message.Service,
 102	history history.Service,
 103	lspClients map[string]*lsp.Client,
 104) (Service, error) {
 105	cfg := config.Get()
 106
 107	var agentToolFn func() (tools.BaseTool, error)
 108	if agentCfg.ID == "coder" {
 109		agentToolFn = func() (tools.BaseTool, error) {
 110			taskAgentCfg := config.Get().Agents["task"]
 111			if taskAgentCfg.ID == "" {
 112				return nil, fmt.Errorf("task agent not found in config")
 113			}
 114			taskAgent, err := NewAgent(ctx, taskAgentCfg, permissions, sessions, messages, history, lspClients)
 115			if err != nil {
 116				return nil, fmt.Errorf("failed to create task agent: %w", err)
 117			}
 118			return NewAgentTool(taskAgent, sessions, messages), nil
 119		}
 120	}
 121
 122	providerCfg := config.Get().GetProviderForModel(agentCfg.Model)
 123	if providerCfg == nil {
 124		return nil, fmt.Errorf("provider for agent %s not found in config", agentCfg.Name)
 125	}
 126	model := config.Get().GetModelByType(agentCfg.Model)
 127
 128	if model == nil {
 129		return nil, fmt.Errorf("model not found for agent %s", agentCfg.Name)
 130	}
 131
 132	promptID := agentPromptMap[agentCfg.ID]
 133	if promptID == "" {
 134		promptID = prompt.PromptDefault
 135	}
 136	opts := []provider.ProviderClientOption{
 137		provider.WithModel(agentCfg.Model),
 138		provider.WithSystemMessage(prompt.GetPrompt(promptID, providerCfg.ID, config.Get().Options.ContextPaths...)),
 139	}
 140	agentProvider, err := provider.NewProvider(*providerCfg, opts...)
 141	if err != nil {
 142		return nil, err
 143	}
 144
 145	smallModelCfg := cfg.Models[config.SelectedModelTypeSmall]
 146	var smallModelProviderCfg *config.ProviderConfig
 147	if smallModelCfg.Provider == providerCfg.ID {
 148		smallModelProviderCfg = providerCfg
 149	} else {
 150		smallModelProviderCfg = cfg.GetProviderForModel(config.SelectedModelTypeSmall)
 151
 152		if smallModelProviderCfg.ID == "" {
 153			return nil, fmt.Errorf("provider %s not found in config", smallModelCfg.Provider)
 154		}
 155	}
 156	smallModel := cfg.GetModelByType(config.SelectedModelTypeSmall)
 157	if smallModel.ID == "" {
 158		return nil, fmt.Errorf("model %s not found in provider %s", smallModelCfg.Model, smallModelProviderCfg.ID)
 159	}
 160
 161	titleOpts := []provider.ProviderClientOption{
 162		provider.WithModel(config.SelectedModelTypeSmall),
 163		provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptTitle, smallModelProviderCfg.ID)),
 164	}
 165	titleProvider, err := provider.NewProvider(*smallModelProviderCfg, titleOpts...)
 166	if err != nil {
 167		return nil, err
 168	}
 169
 170	summarizeOpts := []provider.ProviderClientOption{
 171		provider.WithModel(config.SelectedModelTypeLarge),
 172		provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptSummarizer, providerCfg.ID)),
 173	}
 174	summarizeProvider, err := provider.NewProvider(*providerCfg, summarizeOpts...)
 175	if err != nil {
 176		return nil, err
 177	}
 178
 179	toolFn := func() []tools.BaseTool {
 180		slog.Info("Initializing agent tools", "agent", agentCfg.ID)
 181		defer func() {
 182			slog.Info("Initialized agent tools", "agent", agentCfg.ID)
 183		}()
 184
 185		cwd := cfg.WorkingDir()
 186		allTools := []tools.BaseTool{
 187			tools.NewBashTool(permissions, cwd),
 188			tools.NewDownloadTool(permissions, cwd),
 189			tools.NewEditTool(lspClients, permissions, history, cwd),
 190			tools.NewMultiEditTool(lspClients, permissions, history, cwd),
 191			tools.NewFetchTool(permissions, cwd),
 192			tools.NewGlobTool(cwd),
 193			tools.NewGrepTool(cwd),
 194			tools.NewLsTool(permissions, cwd),
 195			tools.NewSourcegraphTool(),
 196			tools.NewViewTool(lspClients, permissions, cwd),
 197			tools.NewWriteTool(lspClients, permissions, history, cwd),
 198		}
 199
 200		if agentCfg.AllowedTools == nil {
 201			return allTools
 202		}
 203
 204		var filteredTools []tools.BaseTool
 205		for _, tool := range allTools {
 206			if slices.Contains(agentCfg.AllowedTools, tool.Name()) {
 207				filteredTools = append(filteredTools, tool)
 208			}
 209		}
 210
 211		if agentCfg.ID == "coder" {
 212			mcpToolsOnce.Do(func() {
 213				mcpTools = doGetMCPTools(ctx, permissions, cfg)
 214			})
 215			filteredTools = append(filteredTools, mcpTools...)
 216			if len(lspClients) > 0 {
 217				filteredTools = append(filteredTools, tools.NewDiagnosticsTool(lspClients))
 218			}
 219
 220		}
 221		return filteredTools
 222	}
 223
 224	return &agent{
 225		Broker:              pubsub.NewBroker[AgentEvent](),
 226		agentCfg:            agentCfg,
 227		provider:            agentProvider,
 228		providerID:          string(providerCfg.ID),
 229		messages:            messages,
 230		sessions:            sessions,
 231		titleProvider:       titleProvider,
 232		summarizeProvider:   summarizeProvider,
 233		summarizeProviderID: string(providerCfg.ID),
 234		agentToolFn:         agentToolFn,
 235		activeRequests:      csync.NewMap[string, context.CancelFunc](),
 236		tools:               csync.NewLazySlice(toolFn),
 237		promptQueue:         csync.NewMap[string, []string](),
 238	}, nil
 239}
 240
 241func (a *agent) Model() catwalk.Model {
 242	return *config.Get().GetModelByType(a.agentCfg.Model)
 243}
 244
 245func (a *agent) Cancel(sessionID string) {
 246	// Cancel regular requests
 247	if cancel, ok := a.activeRequests.Take(sessionID); ok && cancel != nil {
 248		slog.Info("Request cancellation initiated", "session_id", sessionID)
 249		cancel()
 250	}
 251
 252	// Also check for summarize requests
 253	if cancel, ok := a.activeRequests.Take(sessionID + "-summarize"); ok && cancel != nil {
 254		slog.Info("Summarize cancellation initiated", "session_id", sessionID)
 255		cancel()
 256	}
 257
 258	if a.QueuedPrompts(sessionID) > 0 {
 259		slog.Info("Clearing queued prompts", "session_id", sessionID)
 260		a.promptQueue.Del(sessionID)
 261	}
 262}
 263
 264func (a *agent) IsBusy() bool {
 265	var busy bool
 266	for cancelFunc := range a.activeRequests.Seq() {
 267		if cancelFunc != nil {
 268			busy = true
 269			break
 270		}
 271	}
 272	return busy
 273}
 274
 275func (a *agent) IsSessionBusy(sessionID string) bool {
 276	_, busy := a.activeRequests.Get(sessionID)
 277	return busy
 278}
 279
 280func (a *agent) QueuedPrompts(sessionID string) int {
 281	l, ok := a.promptQueue.Get(sessionID)
 282	if !ok {
 283		return 0
 284	}
 285	return len(l)
 286}
 287
 288func (a *agent) generateTitle(ctx context.Context, sessionID string, content string) error {
 289	if content == "" {
 290		return nil
 291	}
 292	if a.titleProvider == nil {
 293		return nil
 294	}
 295	session, err := a.sessions.Get(ctx, sessionID)
 296	if err != nil {
 297		return err
 298	}
 299	parts := []message.ContentPart{message.TextContent{
 300		Text: fmt.Sprintf("Generate a concise title for the following content:\n\n%s", content),
 301	}}
 302
 303	// Use streaming approach like summarization
 304	response := a.titleProvider.StreamResponse(
 305		ctx,
 306		[]message.Message{
 307			{
 308				Role:  message.User,
 309				Parts: parts,
 310			},
 311		},
 312		nil,
 313	)
 314
 315	var finalResponse *provider.ProviderResponse
 316	for r := range response {
 317		if r.Error != nil {
 318			return r.Error
 319		}
 320		finalResponse = r.Response
 321	}
 322
 323	if finalResponse == nil {
 324		return fmt.Errorf("no response received from title provider")
 325	}
 326
 327	title := strings.TrimSpace(strings.ReplaceAll(finalResponse.Content, "\n", " "))
 328	if title == "" {
 329		return nil
 330	}
 331
 332	session.Title = title
 333	_, err = a.sessions.Save(ctx, session)
 334	return err
 335}
 336
 337func (a *agent) err(err error) AgentEvent {
 338	return AgentEvent{
 339		Type:  AgentEventTypeError,
 340		Error: err,
 341	}
 342}
 343
 344func (a *agent) Run(ctx context.Context, sessionID string, content string, attachments ...message.Attachment) (<-chan AgentEvent, error) {
 345	if !a.Model().SupportsImages && attachments != nil {
 346		attachments = nil
 347	}
 348	events := make(chan AgentEvent, 1)
 349	if a.IsSessionBusy(sessionID) {
 350		existing, ok := a.promptQueue.Get(sessionID)
 351		if !ok {
 352			existing = []string{}
 353		}
 354		existing = append(existing, content)
 355		a.promptQueue.Set(sessionID, existing)
 356		return nil, nil
 357	}
 358
 359	genCtx, cancel := context.WithCancel(ctx)
 360
 361	a.activeRequests.Set(sessionID, cancel)
 362	go func() {
 363		slog.Debug("Request started", "sessionID", sessionID)
 364		defer log.RecoverPanic("agent.Run", func() {
 365			events <- a.err(fmt.Errorf("panic while running the agent"))
 366		})
 367		var attachmentParts []message.ContentPart
 368		for _, attachment := range attachments {
 369			attachmentParts = append(attachmentParts, message.BinaryContent{Path: attachment.FilePath, MIMEType: attachment.MimeType, Data: attachment.Content})
 370		}
 371		result := a.processGeneration(genCtx, sessionID, content, attachmentParts)
 372		if result.Error != nil && !errors.Is(result.Error, ErrRequestCancelled) && !errors.Is(result.Error, context.Canceled) {
 373			slog.Error(result.Error.Error())
 374		}
 375		slog.Debug("Request completed", "sessionID", sessionID)
 376		a.activeRequests.Del(sessionID)
 377		cancel()
 378		a.Publish(pubsub.CreatedEvent, result)
 379		events <- result
 380		close(events)
 381	}()
 382	return events, nil
 383}
 384
 385func (a *agent) processGeneration(ctx context.Context, sessionID, content string, attachmentParts []message.ContentPart) AgentEvent {
 386	cfg := config.Get()
 387	// List existing messages; if none, start title generation asynchronously.
 388	msgs, err := a.messages.List(ctx, sessionID)
 389	if err != nil {
 390		return a.err(fmt.Errorf("failed to list messages: %w", err))
 391	}
 392	if len(msgs) == 0 {
 393		go func() {
 394			defer log.RecoverPanic("agent.Run", func() {
 395				slog.Error("panic while generating title")
 396			})
 397			titleErr := a.generateTitle(ctx, sessionID, content)
 398			if titleErr != nil && !errors.Is(titleErr, context.Canceled) && !errors.Is(titleErr, context.DeadlineExceeded) {
 399				slog.Error("failed to generate title", "error", titleErr)
 400			}
 401		}()
 402	}
 403	session, err := a.sessions.Get(ctx, sessionID)
 404	if err != nil {
 405		return a.err(fmt.Errorf("failed to get session: %w", err))
 406	}
 407	if session.SummaryMessageID != "" {
 408		summaryMsgInex := -1
 409		for i, msg := range msgs {
 410			if msg.ID == session.SummaryMessageID {
 411				summaryMsgInex = i
 412				break
 413			}
 414		}
 415		if summaryMsgInex != -1 {
 416			msgs = msgs[summaryMsgInex:]
 417			msgs[0].Role = message.User
 418		}
 419	}
 420
 421	userMsg, err := a.createUserMessage(ctx, sessionID, content, attachmentParts)
 422	if err != nil {
 423		return a.err(fmt.Errorf("failed to create user message: %w", err))
 424	}
 425	// Append the new user message to the conversation history.
 426	msgHistory := append(msgs, userMsg)
 427
 428	for {
 429		// Check for cancellation before each iteration
 430		select {
 431		case <-ctx.Done():
 432			return a.err(ctx.Err())
 433		default:
 434			// Continue processing
 435		}
 436		agentMessage, toolResults, err := a.streamAndHandleEvents(ctx, sessionID, msgHistory)
 437		if err != nil {
 438			if errors.Is(err, context.Canceled) {
 439				agentMessage.AddFinish(message.FinishReasonCanceled, "Request cancelled", "")
 440				a.messages.Update(context.Background(), agentMessage)
 441				return a.err(ErrRequestCancelled)
 442			}
 443			return a.err(fmt.Errorf("failed to process events: %w", err))
 444		}
 445		if cfg.Options.Debug {
 446			slog.Info("Result", "message", agentMessage.FinishReason(), "toolResults", toolResults)
 447		}
 448		if (agentMessage.FinishReason() == message.FinishReasonToolUse) && toolResults != nil {
 449			// We are not done, we need to respond with the tool response
 450			msgHistory = append(msgHistory, agentMessage, *toolResults)
 451			// If there are queued prompts, process the next one
 452			nextPrompt, ok := a.promptQueue.Take(sessionID)
 453			if ok {
 454				for _, prompt := range nextPrompt {
 455					// Create a new user message for the queued prompt
 456					userMsg, err := a.createUserMessage(ctx, sessionID, prompt, nil)
 457					if err != nil {
 458						return a.err(fmt.Errorf("failed to create user message for queued prompt: %w", err))
 459					}
 460					// Append the new user message to the conversation history
 461					msgHistory = append(msgHistory, userMsg)
 462				}
 463			}
 464
 465			continue
 466		} else if agentMessage.FinishReason() == message.FinishReasonEndTurn {
 467			queuePrompts, ok := a.promptQueue.Take(sessionID)
 468			if ok {
 469				for _, prompt := range queuePrompts {
 470					if prompt == "" {
 471						continue
 472					}
 473					userMsg, err := a.createUserMessage(ctx, sessionID, prompt, nil)
 474					if err != nil {
 475						return a.err(fmt.Errorf("failed to create user message for queued prompt: %w", err))
 476					}
 477					msgHistory = append(msgHistory, userMsg)
 478				}
 479				continue
 480			}
 481		}
 482		if agentMessage.FinishReason() == "" {
 483			// Kujtim: could not track down where this is happening but this means its cancelled
 484			agentMessage.AddFinish(message.FinishReasonCanceled, "Request cancelled", "")
 485			_ = a.messages.Update(context.Background(), agentMessage)
 486			return a.err(ErrRequestCancelled)
 487		}
 488		return AgentEvent{
 489			Type:    AgentEventTypeResponse,
 490			Message: agentMessage,
 491			Done:    true,
 492		}
 493	}
 494}
 495
 496func (a *agent) createUserMessage(ctx context.Context, sessionID, content string, attachmentParts []message.ContentPart) (message.Message, error) {
 497	parts := []message.ContentPart{message.TextContent{Text: content}}
 498	parts = append(parts, attachmentParts...)
 499	return a.messages.Create(ctx, sessionID, message.CreateMessageParams{
 500		Role:  message.User,
 501		Parts: parts,
 502	})
 503}
 504
 505func (a *agent) getAllTools() ([]tools.BaseTool, error) {
 506	allTools := slices.Collect(a.tools.Seq())
 507	if a.agentToolFn != nil {
 508		agentTool, agentToolErr := a.agentToolFn()
 509		if agentToolErr != nil {
 510			return nil, agentToolErr
 511		}
 512		allTools = append(allTools, agentTool)
 513	}
 514	return allTools, nil
 515}
 516
 517func (a *agent) streamAndHandleEvents(ctx context.Context, sessionID string, msgHistory []message.Message) (message.Message, *message.Message, error) {
 518	ctx = context.WithValue(ctx, tools.SessionIDContextKey, sessionID)
 519
 520	// Create the assistant message first so the spinner shows immediately
 521	assistantMsg, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{
 522		Role:     message.Assistant,
 523		Parts:    []message.ContentPart{},
 524		Model:    a.Model().ID,
 525		Provider: a.providerID,
 526	})
 527	if err != nil {
 528		return assistantMsg, nil, fmt.Errorf("failed to create assistant message: %w", err)
 529	}
 530
 531	allTools, toolsErr := a.getAllTools()
 532	if toolsErr != nil {
 533		return assistantMsg, nil, toolsErr
 534	}
 535	// Now collect tools (which may block on MCP initialization)
 536	eventChan := a.provider.StreamResponse(ctx, msgHistory, allTools)
 537
 538	// Add the session and message ID into the context if needed by tools.
 539	ctx = context.WithValue(ctx, tools.MessageIDContextKey, assistantMsg.ID)
 540
 541	// Process each event in the stream.
 542	for event := range eventChan {
 543		if processErr := a.processEvent(ctx, sessionID, &assistantMsg, event); processErr != nil {
 544			if errors.Is(processErr, context.Canceled) {
 545				a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
 546			} else {
 547				a.finishMessage(ctx, &assistantMsg, message.FinishReasonError, "API Error", processErr.Error())
 548			}
 549			return assistantMsg, nil, processErr
 550		}
 551		if ctx.Err() != nil {
 552			a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
 553			return assistantMsg, nil, ctx.Err()
 554		}
 555	}
 556
 557	toolResults := make([]message.ToolResult, len(assistantMsg.ToolCalls()))
 558	toolCalls := assistantMsg.ToolCalls()
 559	for i, toolCall := range toolCalls {
 560		select {
 561		case <-ctx.Done():
 562			a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
 563			// Make all future tool calls cancelled
 564			for j := i; j < len(toolCalls); j++ {
 565				toolResults[j] = message.ToolResult{
 566					ToolCallID: toolCalls[j].ID,
 567					Content:    "Tool execution canceled by user",
 568					IsError:    true,
 569				}
 570			}
 571			goto out
 572		default:
 573			// Continue processing
 574			var tool tools.BaseTool
 575			allTools, _ := a.getAllTools()
 576			for _, availableTool := range allTools {
 577				if availableTool.Info().Name == toolCall.Name {
 578					tool = availableTool
 579					break
 580				}
 581			}
 582
 583			// Tool not found
 584			if tool == nil {
 585				toolResults[i] = message.ToolResult{
 586					ToolCallID: toolCall.ID,
 587					Content:    fmt.Sprintf("Tool not found: %s", toolCall.Name),
 588					IsError:    true,
 589				}
 590				continue
 591			}
 592
 593			// Run tool in goroutine to allow cancellation
 594			type toolExecResult struct {
 595				response tools.ToolResponse
 596				err      error
 597			}
 598			resultChan := make(chan toolExecResult, 1)
 599
 600			go func() {
 601				response, err := tool.Run(ctx, tools.ToolCall{
 602					ID:    toolCall.ID,
 603					Name:  toolCall.Name,
 604					Input: toolCall.Input,
 605				})
 606				resultChan <- toolExecResult{response: response, err: err}
 607			}()
 608
 609			var toolResponse tools.ToolResponse
 610			var toolErr error
 611
 612			select {
 613			case <-ctx.Done():
 614				a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
 615				// Mark remaining tool calls as cancelled
 616				for j := i; j < len(toolCalls); j++ {
 617					toolResults[j] = message.ToolResult{
 618						ToolCallID: toolCalls[j].ID,
 619						Content:    "Tool execution canceled by user",
 620						IsError:    true,
 621					}
 622				}
 623				goto out
 624			case result := <-resultChan:
 625				toolResponse = result.response
 626				toolErr = result.err
 627			}
 628
 629			if toolErr != nil {
 630				slog.Error("Tool execution error", "toolCall", toolCall.ID, "error", toolErr)
 631				if errors.Is(toolErr, permission.ErrorPermissionDenied) {
 632					toolResults[i] = message.ToolResult{
 633						ToolCallID: toolCall.ID,
 634						Content:    "Permission denied",
 635						IsError:    true,
 636					}
 637					for j := i + 1; j < len(toolCalls); j++ {
 638						toolResults[j] = message.ToolResult{
 639							ToolCallID: toolCalls[j].ID,
 640							Content:    "Tool execution canceled by user",
 641							IsError:    true,
 642						}
 643					}
 644					a.finishMessage(ctx, &assistantMsg, message.FinishReasonPermissionDenied, "Permission denied", "")
 645					break
 646				}
 647			}
 648			toolResults[i] = message.ToolResult{
 649				ToolCallID: toolCall.ID,
 650				Content:    toolResponse.Content,
 651				Metadata:   toolResponse.Metadata,
 652				IsError:    toolResponse.IsError,
 653			}
 654		}
 655	}
 656out:
 657	if len(toolResults) == 0 {
 658		return assistantMsg, nil, nil
 659	}
 660	parts := make([]message.ContentPart, 0)
 661	for _, tr := range toolResults {
 662		parts = append(parts, tr)
 663	}
 664	msg, err := a.messages.Create(context.Background(), assistantMsg.SessionID, message.CreateMessageParams{
 665		Role:     message.Tool,
 666		Parts:    parts,
 667		Provider: a.providerID,
 668	})
 669	if err != nil {
 670		return assistantMsg, nil, fmt.Errorf("failed to create cancelled tool message: %w", err)
 671	}
 672
 673	return assistantMsg, &msg, err
 674}
 675
 676func (a *agent) finishMessage(ctx context.Context, msg *message.Message, finishReason message.FinishReason, message, details string) {
 677	msg.AddFinish(finishReason, message, details)
 678	_ = a.messages.Update(ctx, *msg)
 679}
 680
 681func (a *agent) processEvent(ctx context.Context, sessionID string, assistantMsg *message.Message, event provider.ProviderEvent) error {
 682	select {
 683	case <-ctx.Done():
 684		return ctx.Err()
 685	default:
 686		// Continue processing.
 687	}
 688
 689	switch event.Type {
 690	case provider.EventThinkingDelta:
 691		assistantMsg.AppendReasoningContent(event.Thinking)
 692		return a.messages.Update(ctx, *assistantMsg)
 693	case provider.EventSignatureDelta:
 694		assistantMsg.AppendReasoningSignature(event.Signature)
 695		return a.messages.Update(ctx, *assistantMsg)
 696	case provider.EventContentDelta:
 697		assistantMsg.FinishThinking()
 698		assistantMsg.AppendContent(event.Content)
 699		return a.messages.Update(ctx, *assistantMsg)
 700	case provider.EventToolUseStart:
 701		assistantMsg.FinishThinking()
 702		slog.Info("Tool call started", "toolCall", event.ToolCall)
 703		assistantMsg.AddToolCall(*event.ToolCall)
 704		return a.messages.Update(ctx, *assistantMsg)
 705	case provider.EventToolUseDelta:
 706		assistantMsg.AppendToolCallInput(event.ToolCall.ID, event.ToolCall.Input)
 707		return a.messages.Update(ctx, *assistantMsg)
 708	case provider.EventToolUseStop:
 709		slog.Info("Finished tool call", "toolCall", event.ToolCall)
 710		assistantMsg.FinishToolCall(event.ToolCall.ID)
 711		return a.messages.Update(ctx, *assistantMsg)
 712	case provider.EventError:
 713		return event.Error
 714	case provider.EventComplete:
 715		assistantMsg.FinishThinking()
 716		assistantMsg.SetToolCalls(event.Response.ToolCalls)
 717		assistantMsg.AddFinish(event.Response.FinishReason, "", "")
 718		if err := a.messages.Update(ctx, *assistantMsg); err != nil {
 719			return fmt.Errorf("failed to update message: %w", err)
 720		}
 721		return a.TrackUsage(ctx, sessionID, a.Model(), event.Response.Usage)
 722	}
 723
 724	return nil
 725}
 726
 727func (a *agent) TrackUsage(ctx context.Context, sessionID string, model catwalk.Model, usage provider.TokenUsage) error {
 728	sess, err := a.sessions.Get(ctx, sessionID)
 729	if err != nil {
 730		return fmt.Errorf("failed to get session: %w", err)
 731	}
 732
 733	cost := model.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
 734		model.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
 735		model.CostPer1MIn/1e6*float64(usage.InputTokens) +
 736		model.CostPer1MOut/1e6*float64(usage.OutputTokens)
 737
 738	sess.Cost += cost
 739	sess.CompletionTokens = usage.OutputTokens + usage.CacheReadTokens
 740	sess.PromptTokens = usage.InputTokens + usage.CacheCreationTokens
 741
 742	_, err = a.sessions.Save(ctx, sess)
 743	if err != nil {
 744		return fmt.Errorf("failed to save session: %w", err)
 745	}
 746	return nil
 747}
 748
 749func (a *agent) Summarize(ctx context.Context, sessionID string) error {
 750	if a.summarizeProvider == nil {
 751		return fmt.Errorf("summarize provider not available")
 752	}
 753
 754	// Check if session is busy
 755	if a.IsSessionBusy(sessionID) {
 756		return ErrSessionBusy
 757	}
 758
 759	// Create a new context with cancellation
 760	summarizeCtx, cancel := context.WithCancel(ctx)
 761
 762	// Store the cancel function in activeRequests to allow cancellation
 763	a.activeRequests.Set(sessionID+"-summarize", cancel)
 764
 765	go func() {
 766		defer a.activeRequests.Del(sessionID + "-summarize")
 767		defer cancel()
 768		event := AgentEvent{
 769			Type:     AgentEventTypeSummarize,
 770			Progress: "Starting summarization...",
 771		}
 772
 773		a.Publish(pubsub.CreatedEvent, event)
 774		// Get all messages from the session
 775		msgs, err := a.messages.List(summarizeCtx, sessionID)
 776		if err != nil {
 777			event = AgentEvent{
 778				Type:  AgentEventTypeError,
 779				Error: fmt.Errorf("failed to list messages: %w", err),
 780				Done:  true,
 781			}
 782			a.Publish(pubsub.CreatedEvent, event)
 783			return
 784		}
 785		summarizeCtx = context.WithValue(summarizeCtx, tools.SessionIDContextKey, sessionID)
 786
 787		if len(msgs) == 0 {
 788			event = AgentEvent{
 789				Type:  AgentEventTypeError,
 790				Error: fmt.Errorf("no messages to summarize"),
 791				Done:  true,
 792			}
 793			a.Publish(pubsub.CreatedEvent, event)
 794			return
 795		}
 796
 797		event = AgentEvent{
 798			Type:     AgentEventTypeSummarize,
 799			Progress: "Analyzing conversation...",
 800		}
 801		a.Publish(pubsub.CreatedEvent, event)
 802
 803		// Add a system message to guide the summarization
 804		summarizePrompt := "Provide a detailed but concise summary of our conversation above. Focus on information that would be helpful for continuing the conversation, including what we did, what we're doing, which files we're working on, and what we're going to do next."
 805
 806		// Create a new message with the summarize prompt
 807		promptMsg := message.Message{
 808			Role:  message.User,
 809			Parts: []message.ContentPart{message.TextContent{Text: summarizePrompt}},
 810		}
 811
 812		// Append the prompt to the messages
 813		msgsWithPrompt := append(msgs, promptMsg)
 814
 815		event = AgentEvent{
 816			Type:     AgentEventTypeSummarize,
 817			Progress: "Generating summary...",
 818		}
 819
 820		a.Publish(pubsub.CreatedEvent, event)
 821
 822		// Send the messages to the summarize provider
 823		response := a.summarizeProvider.StreamResponse(
 824			summarizeCtx,
 825			msgsWithPrompt,
 826			nil,
 827		)
 828		var finalResponse *provider.ProviderResponse
 829		for r := range response {
 830			if r.Error != nil {
 831				event = AgentEvent{
 832					Type:  AgentEventTypeError,
 833					Error: fmt.Errorf("failed to summarize: %w", err),
 834					Done:  true,
 835				}
 836				a.Publish(pubsub.CreatedEvent, event)
 837				return
 838			}
 839			finalResponse = r.Response
 840		}
 841
 842		summary := strings.TrimSpace(finalResponse.Content)
 843		if summary == "" {
 844			event = AgentEvent{
 845				Type:  AgentEventTypeError,
 846				Error: fmt.Errorf("empty summary returned"),
 847				Done:  true,
 848			}
 849			a.Publish(pubsub.CreatedEvent, event)
 850			return
 851		}
 852		shell := shell.GetPersistentShell(config.Get().WorkingDir())
 853		summary += "\n\n**Current working directory of the persistent shell**\n\n" + shell.GetWorkingDir()
 854		event = AgentEvent{
 855			Type:     AgentEventTypeSummarize,
 856			Progress: "Creating new session...",
 857		}
 858
 859		a.Publish(pubsub.CreatedEvent, event)
 860		oldSession, err := a.sessions.Get(summarizeCtx, sessionID)
 861		if err != nil {
 862			event = AgentEvent{
 863				Type:  AgentEventTypeError,
 864				Error: fmt.Errorf("failed to get session: %w", err),
 865				Done:  true,
 866			}
 867
 868			a.Publish(pubsub.CreatedEvent, event)
 869			return
 870		}
 871		// Create a message in the new session with the summary
 872		msg, err := a.messages.Create(summarizeCtx, oldSession.ID, message.CreateMessageParams{
 873			Role: message.Assistant,
 874			Parts: []message.ContentPart{
 875				message.TextContent{Text: summary},
 876				message.Finish{
 877					Reason: message.FinishReasonEndTurn,
 878					Time:   time.Now().Unix(),
 879				},
 880			},
 881			Model:    a.summarizeProvider.Model().ID,
 882			Provider: a.summarizeProviderID,
 883		})
 884		if err != nil {
 885			event = AgentEvent{
 886				Type:  AgentEventTypeError,
 887				Error: fmt.Errorf("failed to create summary message: %w", err),
 888				Done:  true,
 889			}
 890
 891			a.Publish(pubsub.CreatedEvent, event)
 892			return
 893		}
 894		oldSession.SummaryMessageID = msg.ID
 895		oldSession.CompletionTokens = finalResponse.Usage.OutputTokens
 896		oldSession.PromptTokens = 0
 897		model := a.summarizeProvider.Model()
 898		usage := finalResponse.Usage
 899		cost := model.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
 900			model.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
 901			model.CostPer1MIn/1e6*float64(usage.InputTokens) +
 902			model.CostPer1MOut/1e6*float64(usage.OutputTokens)
 903		oldSession.Cost += cost
 904		_, err = a.sessions.Save(summarizeCtx, oldSession)
 905		if err != nil {
 906			event = AgentEvent{
 907				Type:  AgentEventTypeError,
 908				Error: fmt.Errorf("failed to save session: %w", err),
 909				Done:  true,
 910			}
 911			a.Publish(pubsub.CreatedEvent, event)
 912		}
 913
 914		event = AgentEvent{
 915			Type:      AgentEventTypeSummarize,
 916			SessionID: oldSession.ID,
 917			Progress:  "Summary complete",
 918			Done:      true,
 919		}
 920		a.Publish(pubsub.CreatedEvent, event)
 921		// Send final success event with the new session ID
 922	}()
 923
 924	return nil
 925}
 926
 927func (a *agent) ClearQueue(sessionID string) {
 928	if a.QueuedPrompts(sessionID) > 0 {
 929		slog.Info("Clearing queued prompts", "session_id", sessionID)
 930		a.promptQueue.Del(sessionID)
 931	}
 932}
 933
 934func (a *agent) CancelAll() {
 935	if !a.IsBusy() {
 936		return
 937	}
 938	for key := range a.activeRequests.Seq2() {
 939		a.Cancel(key) // key is sessionID
 940	}
 941
 942	timeout := time.After(5 * time.Second)
 943	for a.IsBusy() {
 944		select {
 945		case <-timeout:
 946			return
 947		default:
 948			time.Sleep(200 * time.Millisecond)
 949		}
 950	}
 951}
 952
 953func (a *agent) UpdateModel() error {
 954	cfg := config.Get()
 955
 956	// Get current provider configuration
 957	currentProviderCfg := cfg.GetProviderForModel(a.agentCfg.Model)
 958	if currentProviderCfg == nil || currentProviderCfg.ID == "" {
 959		return fmt.Errorf("provider for agent %s not found in config", a.agentCfg.Name)
 960	}
 961
 962	// Check if provider has changed
 963	if string(currentProviderCfg.ID) != a.providerID {
 964		// Provider changed, need to recreate the main provider
 965		model := cfg.GetModelByType(a.agentCfg.Model)
 966		if model.ID == "" {
 967			return fmt.Errorf("model not found for agent %s", a.agentCfg.Name)
 968		}
 969
 970		promptID := agentPromptMap[a.agentCfg.ID]
 971		if promptID == "" {
 972			promptID = prompt.PromptDefault
 973		}
 974
 975		opts := []provider.ProviderClientOption{
 976			provider.WithModel(a.agentCfg.Model),
 977			provider.WithSystemMessage(prompt.GetPrompt(promptID, currentProviderCfg.ID, cfg.Options.ContextPaths...)),
 978		}
 979
 980		newProvider, err := provider.NewProvider(*currentProviderCfg, opts...)
 981		if err != nil {
 982			return fmt.Errorf("failed to create new provider: %w", err)
 983		}
 984
 985		// Update the provider and provider ID
 986		a.provider = newProvider
 987		a.providerID = string(currentProviderCfg.ID)
 988	}
 989
 990	// Check if providers have changed for title (small) and summarize (large)
 991	smallModelCfg := cfg.Models[config.SelectedModelTypeSmall]
 992	var smallModelProviderCfg config.ProviderConfig
 993	for p := range cfg.Providers.Seq() {
 994		if p.ID == smallModelCfg.Provider {
 995			smallModelProviderCfg = p
 996			break
 997		}
 998	}
 999	if smallModelProviderCfg.ID == "" {
1000		return fmt.Errorf("provider %s not found in config", smallModelCfg.Provider)
1001	}
1002
1003	largeModelCfg := cfg.Models[config.SelectedModelTypeLarge]
1004	var largeModelProviderCfg config.ProviderConfig
1005	for p := range cfg.Providers.Seq() {
1006		if p.ID == largeModelCfg.Provider {
1007			largeModelProviderCfg = p
1008			break
1009		}
1010	}
1011	if largeModelProviderCfg.ID == "" {
1012		return fmt.Errorf("provider %s not found in config", largeModelCfg.Provider)
1013	}
1014
1015	var maxTitleTokens int64 = 40
1016
1017	// if the max output is too low for the gemini provider it won't return anything
1018	if smallModelCfg.Provider == "gemini" {
1019		maxTitleTokens = 1000
1020	}
1021	// Recreate title provider
1022	titleOpts := []provider.ProviderClientOption{
1023		provider.WithModel(config.SelectedModelTypeSmall),
1024		provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptTitle, smallModelProviderCfg.ID)),
1025		provider.WithMaxTokens(maxTitleTokens),
1026	}
1027	newTitleProvider, err := provider.NewProvider(smallModelProviderCfg, titleOpts...)
1028	if err != nil {
1029		return fmt.Errorf("failed to create new title provider: %w", err)
1030	}
1031	a.titleProvider = newTitleProvider
1032
1033	// Recreate summarize provider if provider changed (now large model)
1034	if string(largeModelProviderCfg.ID) != a.summarizeProviderID {
1035		largeModel := cfg.GetModelByType(config.SelectedModelTypeLarge)
1036		if largeModel == nil {
1037			return fmt.Errorf("model %s not found in provider %s", largeModelCfg.Model, largeModelProviderCfg.ID)
1038		}
1039		summarizeOpts := []provider.ProviderClientOption{
1040			provider.WithModel(config.SelectedModelTypeLarge),
1041			provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptSummarizer, largeModelProviderCfg.ID)),
1042		}
1043		newSummarizeProvider, err := provider.NewProvider(largeModelProviderCfg, summarizeOpts...)
1044		if err != nil {
1045			return fmt.Errorf("failed to create new summarize provider: %w", err)
1046		}
1047		a.summarizeProvider = newSummarizeProvider
1048		a.summarizeProviderID = string(largeModelProviderCfg.ID)
1049	}
1050
1051	return nil
1052}