agent.go

   1package agent
   2
   3import (
   4	"context"
   5	"errors"
   6	"fmt"
   7	"log/slog"
   8	"slices"
   9	"strings"
  10	"time"
  11
  12	"github.com/charmbracelet/catwalk/pkg/catwalk"
  13	"github.com/charmbracelet/crush/internal/config"
  14	"github.com/charmbracelet/crush/internal/csync"
  15	"github.com/charmbracelet/crush/internal/history"
  16	"github.com/charmbracelet/crush/internal/llm/prompt"
  17	"github.com/charmbracelet/crush/internal/llm/provider"
  18	"github.com/charmbracelet/crush/internal/llm/tools"
  19	"github.com/charmbracelet/crush/internal/log"
  20	"github.com/charmbracelet/crush/internal/lsp"
  21	"github.com/charmbracelet/crush/internal/message"
  22	"github.com/charmbracelet/crush/internal/permission"
  23	"github.com/charmbracelet/crush/internal/pubsub"
  24	"github.com/charmbracelet/crush/internal/session"
  25	"github.com/charmbracelet/crush/internal/shell"
  26)
  27
  28// Common errors
  29var (
  30	ErrRequestCancelled = errors.New("request canceled by user")
  31	ErrSessionBusy      = errors.New("session is currently processing another request")
  32)
  33
  34type AgentEventType string
  35
  36const (
  37	AgentEventTypeError     AgentEventType = "error"
  38	AgentEventTypeResponse  AgentEventType = "response"
  39	AgentEventTypeSummarize AgentEventType = "summarize"
  40)
  41
  42type AgentEvent struct {
  43	Type    AgentEventType
  44	Message message.Message
  45	Error   error
  46
  47	// When summarizing
  48	SessionID string
  49	Progress  string
  50	Done      bool
  51}
  52
  53type Service interface {
  54	pubsub.Suscriber[AgentEvent]
  55	Model() catwalk.Model
  56	Run(ctx context.Context, sessionID string, content string, attachments ...message.Attachment) (<-chan AgentEvent, error)
  57	Cancel(sessionID string)
  58	CancelAll()
  59	IsSessionBusy(sessionID string) bool
  60	IsBusy() bool
  61	Summarize(ctx context.Context, sessionID string) error
  62	UpdateModel() error
  63	QueuedPrompts(sessionID string) int
  64	ClearQueue(sessionID string)
  65}
  66
  67type agent struct {
  68	*pubsub.Broker[AgentEvent]
  69	agentCfg config.Agent
  70	sessions session.Service
  71	messages message.Service
  72	mcpTools []McpTool
  73
  74	tools *csync.LazySlice[tools.BaseTool]
  75
  76	provider   provider.Provider
  77	providerID string
  78
  79	titleProvider       provider.Provider
  80	summarizeProvider   provider.Provider
  81	summarizeProviderID string
  82
  83	activeRequests *csync.Map[string, context.CancelFunc]
  84
  85	promptQueue *csync.Map[string, []string]
  86}
  87
  88var agentPromptMap = map[string]prompt.PromptID{
  89	"coder": prompt.PromptCoder,
  90	"task":  prompt.PromptTask,
  91}
  92
  93func NewAgent(
  94	ctx context.Context,
  95	agentCfg config.Agent,
  96	// These services are needed in the tools
  97	permissions permission.Service,
  98	sessions session.Service,
  99	messages message.Service,
 100	history history.Service,
 101	lspClients map[string]*lsp.Client,
 102) (Service, error) {
 103	cfg := config.Get()
 104
 105	var agentTool tools.BaseTool
 106	if agentCfg.ID == "coder" {
 107		taskAgentCfg := config.Get().Agents["task"]
 108		if taskAgentCfg.ID == "" {
 109			return nil, fmt.Errorf("task agent not found in config")
 110		}
 111		taskAgent, err := NewAgent(ctx, taskAgentCfg, permissions, sessions, messages, history, lspClients)
 112		if err != nil {
 113			return nil, fmt.Errorf("failed to create task agent: %w", err)
 114		}
 115
 116		agentTool = NewAgentTool(taskAgent, sessions, messages)
 117	}
 118
 119	providerCfg := config.Get().GetProviderForModel(agentCfg.Model)
 120	if providerCfg == nil {
 121		return nil, fmt.Errorf("provider for agent %s not found in config", agentCfg.Name)
 122	}
 123	model := config.Get().GetModelByType(agentCfg.Model)
 124
 125	if model == nil {
 126		return nil, fmt.Errorf("model not found for agent %s", agentCfg.Name)
 127	}
 128
 129	promptID := agentPromptMap[agentCfg.ID]
 130	if promptID == "" {
 131		promptID = prompt.PromptDefault
 132	}
 133	opts := []provider.ProviderClientOption{
 134		provider.WithModel(agentCfg.Model),
 135		provider.WithSystemMessage(prompt.GetPrompt(promptID, providerCfg.ID, config.Get().Options.ContextPaths...)),
 136	}
 137	agentProvider, err := provider.NewProvider(*providerCfg, opts...)
 138	if err != nil {
 139		return nil, err
 140	}
 141
 142	smallModelCfg := cfg.Models[config.SelectedModelTypeSmall]
 143	var smallModelProviderCfg *config.ProviderConfig
 144	if smallModelCfg.Provider == providerCfg.ID {
 145		smallModelProviderCfg = providerCfg
 146	} else {
 147		smallModelProviderCfg = cfg.GetProviderForModel(config.SelectedModelTypeSmall)
 148
 149		if smallModelProviderCfg.ID == "" {
 150			return nil, fmt.Errorf("provider %s not found in config", smallModelCfg.Provider)
 151		}
 152	}
 153	smallModel := cfg.GetModelByType(config.SelectedModelTypeSmall)
 154	if smallModel.ID == "" {
 155		return nil, fmt.Errorf("model %s not found in provider %s", smallModelCfg.Model, smallModelProviderCfg.ID)
 156	}
 157
 158	titleOpts := []provider.ProviderClientOption{
 159		provider.WithModel(config.SelectedModelTypeSmall),
 160		provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptTitle, smallModelProviderCfg.ID)),
 161	}
 162	titleProvider, err := provider.NewProvider(*smallModelProviderCfg, titleOpts...)
 163	if err != nil {
 164		return nil, err
 165	}
 166
 167	summarizeOpts := []provider.ProviderClientOption{
 168		provider.WithModel(config.SelectedModelTypeLarge),
 169		provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptSummarizer, providerCfg.ID)),
 170	}
 171	summarizeProvider, err := provider.NewProvider(*providerCfg, summarizeOpts...)
 172	if err != nil {
 173		return nil, err
 174	}
 175
 176	toolFn := func() []tools.BaseTool {
 177		slog.Info("Initializing agent tools", "agent", agentCfg.ID)
 178		defer func() {
 179			slog.Info("Initialized agent tools", "agent", agentCfg.ID)
 180		}()
 181
 182		cwd := cfg.WorkingDir()
 183		allTools := []tools.BaseTool{
 184			tools.NewBashTool(permissions, cwd),
 185			tools.NewDownloadTool(permissions, cwd),
 186			tools.NewEditTool(lspClients, permissions, history, cwd),
 187			tools.NewMultiEditTool(lspClients, permissions, history, cwd),
 188			tools.NewFetchTool(permissions, cwd),
 189			tools.NewGlobTool(cwd),
 190			tools.NewGrepTool(cwd),
 191			tools.NewLsTool(permissions, cwd),
 192			tools.NewSourcegraphTool(),
 193			tools.NewViewTool(lspClients, permissions, cwd),
 194			tools.NewWriteTool(lspClients, permissions, history, cwd),
 195		}
 196
 197		mcpToolsOnce.Do(func() {
 198			mcpTools = doGetMCPTools(ctx, permissions, cfg)
 199		})
 200		allTools = append(allTools, mcpTools...)
 201
 202		if len(lspClients) > 0 {
 203			allTools = append(allTools, tools.NewDiagnosticsTool(lspClients))
 204		}
 205
 206		if agentTool != nil {
 207			allTools = append(allTools, agentTool)
 208		}
 209
 210		if agentCfg.AllowedTools == nil {
 211			return allTools
 212		}
 213
 214		var filteredTools []tools.BaseTool
 215		for _, tool := range allTools {
 216			if slices.Contains(agentCfg.AllowedTools, tool.Name()) {
 217				filteredTools = append(filteredTools, tool)
 218			}
 219		}
 220		return filteredTools
 221	}
 222
 223	return &agent{
 224		Broker:              pubsub.NewBroker[AgentEvent](),
 225		agentCfg:            agentCfg,
 226		provider:            agentProvider,
 227		providerID:          string(providerCfg.ID),
 228		messages:            messages,
 229		sessions:            sessions,
 230		titleProvider:       titleProvider,
 231		summarizeProvider:   summarizeProvider,
 232		summarizeProviderID: string(providerCfg.ID),
 233		activeRequests:      csync.NewMap[string, context.CancelFunc](),
 234		tools:               csync.NewLazySlice(toolFn),
 235		promptQueue:         csync.NewMap[string, []string](),
 236	}, nil
 237}
 238
 239func (a *agent) Model() catwalk.Model {
 240	return *config.Get().GetModelByType(a.agentCfg.Model)
 241}
 242
 243func (a *agent) Cancel(sessionID string) {
 244	// Cancel regular requests
 245	if cancel, ok := a.activeRequests.Take(sessionID); ok && cancel != nil {
 246		slog.Info("Request cancellation initiated", "session_id", sessionID)
 247		cancel()
 248	}
 249
 250	// Also check for summarize requests
 251	if cancel, ok := a.activeRequests.Take(sessionID + "-summarize"); ok && cancel != nil {
 252		slog.Info("Summarize cancellation initiated", "session_id", sessionID)
 253		cancel()
 254	}
 255
 256	if a.QueuedPrompts(sessionID) > 0 {
 257		slog.Info("Clearing queued prompts", "session_id", sessionID)
 258		a.promptQueue.Del(sessionID)
 259	}
 260}
 261
 262func (a *agent) IsBusy() bool {
 263	var busy bool
 264	for cancelFunc := range a.activeRequests.Seq() {
 265		if cancelFunc != nil {
 266			busy = true
 267			break
 268		}
 269	}
 270	return busy
 271}
 272
 273func (a *agent) IsSessionBusy(sessionID string) bool {
 274	_, busy := a.activeRequests.Get(sessionID)
 275	return busy
 276}
 277
 278func (a *agent) QueuedPrompts(sessionID string) int {
 279	l, ok := a.promptQueue.Get(sessionID)
 280	if !ok {
 281		return 0
 282	}
 283	return len(l)
 284}
 285
 286func (a *agent) generateTitle(ctx context.Context, sessionID string, content string) error {
 287	if content == "" {
 288		return nil
 289	}
 290	if a.titleProvider == nil {
 291		return nil
 292	}
 293	session, err := a.sessions.Get(ctx, sessionID)
 294	if err != nil {
 295		return err
 296	}
 297	parts := []message.ContentPart{message.TextContent{
 298		Text: fmt.Sprintf("Generate a concise title for the following content:\n\n%s", content),
 299	}}
 300
 301	// Use streaming approach like summarization
 302	response := a.titleProvider.StreamResponse(
 303		ctx,
 304		[]message.Message{
 305			{
 306				Role:  message.User,
 307				Parts: parts,
 308			},
 309		},
 310		nil,
 311	)
 312
 313	var finalResponse *provider.ProviderResponse
 314	for r := range response {
 315		if r.Error != nil {
 316			return r.Error
 317		}
 318		finalResponse = r.Response
 319	}
 320
 321	if finalResponse == nil {
 322		return fmt.Errorf("no response received from title provider")
 323	}
 324
 325	title := strings.TrimSpace(strings.ReplaceAll(finalResponse.Content, "\n", " "))
 326	if title == "" {
 327		return nil
 328	}
 329
 330	session.Title = title
 331	_, err = a.sessions.Save(ctx, session)
 332	return err
 333}
 334
 335func (a *agent) err(err error) AgentEvent {
 336	return AgentEvent{
 337		Type:  AgentEventTypeError,
 338		Error: err,
 339	}
 340}
 341
 342func (a *agent) Run(ctx context.Context, sessionID string, content string, attachments ...message.Attachment) (<-chan AgentEvent, error) {
 343	if !a.Model().SupportsImages && attachments != nil {
 344		attachments = nil
 345	}
 346	events := make(chan AgentEvent)
 347	if a.IsSessionBusy(sessionID) {
 348		existing, ok := a.promptQueue.Get(sessionID)
 349		if !ok {
 350			existing = []string{}
 351		}
 352		existing = append(existing, content)
 353		a.promptQueue.Set(sessionID, existing)
 354		return nil, nil
 355	}
 356
 357	genCtx, cancel := context.WithCancel(ctx)
 358
 359	a.activeRequests.Set(sessionID, cancel)
 360	go func() {
 361		slog.Debug("Request started", "sessionID", sessionID)
 362		defer log.RecoverPanic("agent.Run", func() {
 363			events <- a.err(fmt.Errorf("panic while running the agent"))
 364		})
 365		var attachmentParts []message.ContentPart
 366		for _, attachment := range attachments {
 367			attachmentParts = append(attachmentParts, message.BinaryContent{Path: attachment.FilePath, MIMEType: attachment.MimeType, Data: attachment.Content})
 368		}
 369		result := a.processGeneration(genCtx, sessionID, content, attachmentParts)
 370		if result.Error != nil && !errors.Is(result.Error, ErrRequestCancelled) && !errors.Is(result.Error, context.Canceled) {
 371			slog.Error(result.Error.Error())
 372		}
 373		slog.Debug("Request completed", "sessionID", sessionID)
 374		a.activeRequests.Del(sessionID)
 375		cancel()
 376		a.Publish(pubsub.CreatedEvent, result)
 377		events <- result
 378		close(events)
 379	}()
 380	return events, nil
 381}
 382
 383func (a *agent) processGeneration(ctx context.Context, sessionID, content string, attachmentParts []message.ContentPart) AgentEvent {
 384	cfg := config.Get()
 385	// List existing messages; if none, start title generation asynchronously.
 386	msgs, err := a.messages.List(ctx, sessionID)
 387	if err != nil {
 388		return a.err(fmt.Errorf("failed to list messages: %w", err))
 389	}
 390
 391	if len(msgs) == 0 {
 392		// Use a context with timeout for title generation
 393		titleCtx, titleCancel := context.WithTimeout(context.Background(), 30*time.Second)
 394		go func() {
 395			defer titleCancel()
 396			defer log.RecoverPanic("agent.Run", func() {
 397				slog.Error("panic while generating title")
 398			})
 399			titleErr := a.generateTitle(titleCtx, sessionID, content)
 400			if titleErr != nil && !errors.Is(titleErr, context.Canceled) && !errors.Is(titleErr, context.DeadlineExceeded) {
 401				slog.Error("failed to generate title", "error", titleErr)
 402			}
 403		}()
 404	}
 405	session, err := a.sessions.Get(ctx, sessionID)
 406	if err != nil {
 407		return a.err(fmt.Errorf("failed to get session: %w", err))
 408	}
 409	if session.SummaryMessageID != "" {
 410		summaryMsgInex := -1
 411		for i, msg := range msgs {
 412			if msg.ID == session.SummaryMessageID {
 413				summaryMsgInex = i
 414				break
 415			}
 416		}
 417		if summaryMsgInex != -1 {
 418			msgs = msgs[summaryMsgInex:]
 419			msgs[0].Role = message.User
 420		}
 421	}
 422
 423	userMsg, err := a.createUserMessage(ctx, sessionID, content, attachmentParts)
 424	if err != nil {
 425		return a.err(fmt.Errorf("failed to create user message: %w", err))
 426	}
 427	// Append the new user message to the conversation history.
 428	msgHistory := append(msgs, userMsg)
 429
 430	for {
 431		// Check for cancellation before each iteration
 432		select {
 433		case <-ctx.Done():
 434			return a.err(ctx.Err())
 435		default:
 436			// Continue processing
 437		}
 438		agentMessage, toolResults, err := a.streamAndHandleEvents(ctx, sessionID, msgHistory)
 439		if err != nil {
 440			if errors.Is(err, context.Canceled) {
 441				agentMessage.AddFinish(message.FinishReasonCanceled, "Request cancelled", "")
 442				a.messages.Update(context.Background(), agentMessage)
 443				return a.err(ErrRequestCancelled)
 444			}
 445			return a.err(fmt.Errorf("failed to process events: %w", err))
 446		}
 447		if cfg.Options.Debug {
 448			slog.Info("Result", "message", agentMessage.FinishReason(), "toolResults", toolResults)
 449		}
 450		if (agentMessage.FinishReason() == message.FinishReasonToolUse) && toolResults != nil {
 451			// We are not done, we need to respond with the tool response
 452			msgHistory = append(msgHistory, agentMessage, *toolResults)
 453			// If there are queued prompts, process the next one
 454			nextPrompt, ok := a.promptQueue.Take(sessionID)
 455			if ok {
 456				for _, prompt := range nextPrompt {
 457					// Create a new user message for the queued prompt
 458					userMsg, err := a.createUserMessage(ctx, sessionID, prompt, nil)
 459					if err != nil {
 460						return a.err(fmt.Errorf("failed to create user message for queued prompt: %w", err))
 461					}
 462					// Append the new user message to the conversation history
 463					msgHistory = append(msgHistory, userMsg)
 464				}
 465			}
 466
 467			continue
 468		} else if agentMessage.FinishReason() == message.FinishReasonEndTurn {
 469			queuePrompts, ok := a.promptQueue.Take(sessionID)
 470			if ok {
 471				for _, prompt := range queuePrompts {
 472					if prompt == "" {
 473						continue
 474					}
 475					userMsg, err := a.createUserMessage(ctx, sessionID, prompt, nil)
 476					if err != nil {
 477						return a.err(fmt.Errorf("failed to create user message for queued prompt: %w", err))
 478					}
 479					msgHistory = append(msgHistory, userMsg)
 480				}
 481				continue
 482			}
 483		}
 484		if agentMessage.FinishReason() == "" {
 485			// Kujtim: could not track down where this is happening but this means its cancelled
 486			agentMessage.AddFinish(message.FinishReasonCanceled, "Request cancelled", "")
 487			_ = a.messages.Update(context.Background(), agentMessage)
 488			return a.err(ErrRequestCancelled)
 489		}
 490		return AgentEvent{
 491			Type:    AgentEventTypeResponse,
 492			Message: agentMessage,
 493			Done:    true,
 494		}
 495	}
 496}
 497
 498func (a *agent) createUserMessage(ctx context.Context, sessionID, content string, attachmentParts []message.ContentPart) (message.Message, error) {
 499	parts := []message.ContentPart{message.TextContent{Text: content}}
 500	parts = append(parts, attachmentParts...)
 501	return a.messages.Create(ctx, sessionID, message.CreateMessageParams{
 502		Role:  message.User,
 503		Parts: parts,
 504	})
 505}
 506
 507func (a *agent) streamAndHandleEvents(ctx context.Context, sessionID string, msgHistory []message.Message) (message.Message, *message.Message, error) {
 508	ctx = context.WithValue(ctx, tools.SessionIDContextKey, sessionID)
 509
 510	// Create the assistant message first so the spinner shows immediately
 511	assistantMsg, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{
 512		Role:     message.Assistant,
 513		Parts:    []message.ContentPart{},
 514		Model:    a.Model().ID,
 515		Provider: a.providerID,
 516	})
 517	if err != nil {
 518		return assistantMsg, nil, fmt.Errorf("failed to create assistant message: %w", err)
 519	}
 520
 521	// Now collect tools (which may block on MCP initialization)
 522	eventChan := a.provider.StreamResponse(ctx, msgHistory, slices.Collect(a.tools.Seq()))
 523
 524	// Add the session and message ID into the context if needed by tools.
 525	ctx = context.WithValue(ctx, tools.MessageIDContextKey, assistantMsg.ID)
 526
 527	// Process each event in the stream.
 528	for event := range eventChan {
 529		if processErr := a.processEvent(ctx, sessionID, &assistantMsg, event); processErr != nil {
 530			if errors.Is(processErr, context.Canceled) {
 531				a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
 532			} else {
 533				a.finishMessage(ctx, &assistantMsg, message.FinishReasonError, "API Error", processErr.Error())
 534			}
 535			return assistantMsg, nil, processErr
 536		}
 537		if ctx.Err() != nil {
 538			a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
 539			return assistantMsg, nil, ctx.Err()
 540		}
 541	}
 542
 543	toolResults := make([]message.ToolResult, len(assistantMsg.ToolCalls()))
 544	toolCalls := assistantMsg.ToolCalls()
 545	for i, toolCall := range toolCalls {
 546		select {
 547		case <-ctx.Done():
 548			a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
 549			// Make all future tool calls cancelled
 550			for j := i; j < len(toolCalls); j++ {
 551				toolResults[j] = message.ToolResult{
 552					ToolCallID: toolCalls[j].ID,
 553					Content:    "Tool execution canceled by user",
 554					IsError:    true,
 555				}
 556			}
 557			goto out
 558		default:
 559			// Continue processing
 560			var tool tools.BaseTool
 561			for availableTool := range a.tools.Seq() {
 562				if availableTool.Info().Name == toolCall.Name {
 563					tool = availableTool
 564					break
 565				}
 566			}
 567
 568			// Tool not found
 569			if tool == nil {
 570				toolResults[i] = message.ToolResult{
 571					ToolCallID: toolCall.ID,
 572					Content:    fmt.Sprintf("Tool not found: %s", toolCall.Name),
 573					IsError:    true,
 574				}
 575				continue
 576			}
 577
 578			// Run tool in goroutine to allow cancellation
 579			type toolExecResult struct {
 580				response tools.ToolResponse
 581				err      error
 582			}
 583			resultChan := make(chan toolExecResult, 1)
 584
 585			go func() {
 586				response, err := tool.Run(ctx, tools.ToolCall{
 587					ID:    toolCall.ID,
 588					Name:  toolCall.Name,
 589					Input: toolCall.Input,
 590				})
 591				resultChan <- toolExecResult{response: response, err: err}
 592			}()
 593
 594			var toolResponse tools.ToolResponse
 595			var toolErr error
 596
 597			select {
 598			case <-ctx.Done():
 599				a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
 600				// Mark remaining tool calls as cancelled
 601				for j := i; j < len(toolCalls); j++ {
 602					toolResults[j] = message.ToolResult{
 603						ToolCallID: toolCalls[j].ID,
 604						Content:    "Tool execution canceled by user",
 605						IsError:    true,
 606					}
 607				}
 608				goto out
 609			case result := <-resultChan:
 610				toolResponse = result.response
 611				toolErr = result.err
 612			}
 613
 614			if toolErr != nil {
 615				slog.Error("Tool execution error", "toolCall", toolCall.ID, "error", toolErr)
 616				if errors.Is(toolErr, permission.ErrorPermissionDenied) {
 617					toolResults[i] = message.ToolResult{
 618						ToolCallID: toolCall.ID,
 619						Content:    "Permission denied",
 620						IsError:    true,
 621					}
 622					for j := i + 1; j < len(toolCalls); j++ {
 623						toolResults[j] = message.ToolResult{
 624							ToolCallID: toolCalls[j].ID,
 625							Content:    "Tool execution canceled by user",
 626							IsError:    true,
 627						}
 628					}
 629					a.finishMessage(ctx, &assistantMsg, message.FinishReasonPermissionDenied, "Permission denied", "")
 630					break
 631				}
 632			}
 633			toolResults[i] = message.ToolResult{
 634				ToolCallID: toolCall.ID,
 635				Content:    toolResponse.Content,
 636				Metadata:   toolResponse.Metadata,
 637				IsError:    toolResponse.IsError,
 638			}
 639		}
 640	}
 641out:
 642	if len(toolResults) == 0 {
 643		return assistantMsg, nil, nil
 644	}
 645	parts := make([]message.ContentPart, 0)
 646	for _, tr := range toolResults {
 647		parts = append(parts, tr)
 648	}
 649	msg, err := a.messages.Create(context.Background(), assistantMsg.SessionID, message.CreateMessageParams{
 650		Role:     message.Tool,
 651		Parts:    parts,
 652		Provider: a.providerID,
 653	})
 654	if err != nil {
 655		return assistantMsg, nil, fmt.Errorf("failed to create cancelled tool message: %w", err)
 656	}
 657
 658	return assistantMsg, &msg, err
 659}
 660
 661func (a *agent) finishMessage(ctx context.Context, msg *message.Message, finishReason message.FinishReason, message, details string) {
 662	msg.AddFinish(finishReason, message, details)
 663	_ = a.messages.Update(ctx, *msg)
 664}
 665
 666func (a *agent) processEvent(ctx context.Context, sessionID string, assistantMsg *message.Message, event provider.ProviderEvent) error {
 667	select {
 668	case <-ctx.Done():
 669		return ctx.Err()
 670	default:
 671		// Continue processing.
 672	}
 673
 674	switch event.Type {
 675	case provider.EventThinkingDelta:
 676		assistantMsg.AppendReasoningContent(event.Thinking)
 677		return a.messages.Update(ctx, *assistantMsg)
 678	case provider.EventSignatureDelta:
 679		assistantMsg.AppendReasoningSignature(event.Signature)
 680		return a.messages.Update(ctx, *assistantMsg)
 681	case provider.EventContentDelta:
 682		assistantMsg.FinishThinking()
 683		assistantMsg.AppendContent(event.Content)
 684		return a.messages.Update(ctx, *assistantMsg)
 685	case provider.EventToolUseStart:
 686		assistantMsg.FinishThinking()
 687		slog.Info("Tool call started", "toolCall", event.ToolCall)
 688		assistantMsg.AddToolCall(*event.ToolCall)
 689		return a.messages.Update(ctx, *assistantMsg)
 690	case provider.EventToolUseDelta:
 691		assistantMsg.AppendToolCallInput(event.ToolCall.ID, event.ToolCall.Input)
 692		return a.messages.Update(ctx, *assistantMsg)
 693	case provider.EventToolUseStop:
 694		slog.Info("Finished tool call", "toolCall", event.ToolCall)
 695		assistantMsg.FinishToolCall(event.ToolCall.ID)
 696		return a.messages.Update(ctx, *assistantMsg)
 697	case provider.EventError:
 698		return event.Error
 699	case provider.EventComplete:
 700		assistantMsg.FinishThinking()
 701		assistantMsg.SetToolCalls(event.Response.ToolCalls)
 702		assistantMsg.AddFinish(event.Response.FinishReason, "", "")
 703		if err := a.messages.Update(ctx, *assistantMsg); err != nil {
 704			return fmt.Errorf("failed to update message: %w", err)
 705		}
 706		return a.TrackUsage(ctx, sessionID, a.Model(), event.Response.Usage)
 707	}
 708
 709	return nil
 710}
 711
 712func (a *agent) TrackUsage(ctx context.Context, sessionID string, model catwalk.Model, usage provider.TokenUsage) error {
 713	sess, err := a.sessions.Get(ctx, sessionID)
 714	if err != nil {
 715		return fmt.Errorf("failed to get session: %w", err)
 716	}
 717
 718	cost := model.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
 719		model.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
 720		model.CostPer1MIn/1e6*float64(usage.InputTokens) +
 721		model.CostPer1MOut/1e6*float64(usage.OutputTokens)
 722
 723	sess.Cost += cost
 724	sess.CompletionTokens = usage.OutputTokens + usage.CacheReadTokens
 725	sess.PromptTokens = usage.InputTokens + usage.CacheCreationTokens
 726
 727	_, err = a.sessions.Save(ctx, sess)
 728	if err != nil {
 729		return fmt.Errorf("failed to save session: %w", err)
 730	}
 731	return nil
 732}
 733
 734func (a *agent) Summarize(ctx context.Context, sessionID string) error {
 735	if a.summarizeProvider == nil {
 736		return fmt.Errorf("summarize provider not available")
 737	}
 738
 739	// Check if session is busy
 740	if a.IsSessionBusy(sessionID) {
 741		return ErrSessionBusy
 742	}
 743
 744	// Create a new context with cancellation
 745	summarizeCtx, cancel := context.WithCancel(ctx)
 746
 747	// Store the cancel function in activeRequests to allow cancellation
 748	a.activeRequests.Set(sessionID+"-summarize", cancel)
 749
 750	go func() {
 751		defer a.activeRequests.Del(sessionID + "-summarize")
 752		defer cancel()
 753		event := AgentEvent{
 754			Type:     AgentEventTypeSummarize,
 755			Progress: "Starting summarization...",
 756		}
 757
 758		a.Publish(pubsub.CreatedEvent, event)
 759		// Get all messages from the session
 760		msgs, err := a.messages.List(summarizeCtx, sessionID)
 761		if err != nil {
 762			event = AgentEvent{
 763				Type:  AgentEventTypeError,
 764				Error: fmt.Errorf("failed to list messages: %w", err),
 765				Done:  true,
 766			}
 767			a.Publish(pubsub.CreatedEvent, event)
 768			return
 769		}
 770		summarizeCtx = context.WithValue(summarizeCtx, tools.SessionIDContextKey, sessionID)
 771
 772		if len(msgs) == 0 {
 773			event = AgentEvent{
 774				Type:  AgentEventTypeError,
 775				Error: fmt.Errorf("no messages to summarize"),
 776				Done:  true,
 777			}
 778			a.Publish(pubsub.CreatedEvent, event)
 779			return
 780		}
 781
 782		event = AgentEvent{
 783			Type:     AgentEventTypeSummarize,
 784			Progress: "Analyzing conversation...",
 785		}
 786		a.Publish(pubsub.CreatedEvent, event)
 787
 788		// Add a system message to guide the summarization
 789		summarizePrompt := "Provide a detailed but concise summary of our conversation above. Focus on information that would be helpful for continuing the conversation, including what we did, what we're doing, which files we're working on, and what we're going to do next."
 790
 791		// Create a new message with the summarize prompt
 792		promptMsg := message.Message{
 793			Role:  message.User,
 794			Parts: []message.ContentPart{message.TextContent{Text: summarizePrompt}},
 795		}
 796
 797		// Append the prompt to the messages
 798		msgsWithPrompt := append(msgs, promptMsg)
 799
 800		event = AgentEvent{
 801			Type:     AgentEventTypeSummarize,
 802			Progress: "Generating summary...",
 803		}
 804
 805		a.Publish(pubsub.CreatedEvent, event)
 806
 807		// Send the messages to the summarize provider
 808		response := a.summarizeProvider.StreamResponse(
 809			summarizeCtx,
 810			msgsWithPrompt,
 811			nil,
 812		)
 813		var finalResponse *provider.ProviderResponse
 814		for r := range response {
 815			if r.Error != nil {
 816				event = AgentEvent{
 817					Type:  AgentEventTypeError,
 818					Error: fmt.Errorf("failed to summarize: %w", err),
 819					Done:  true,
 820				}
 821				a.Publish(pubsub.CreatedEvent, event)
 822				return
 823			}
 824			finalResponse = r.Response
 825		}
 826
 827		summary := strings.TrimSpace(finalResponse.Content)
 828		if summary == "" {
 829			event = AgentEvent{
 830				Type:  AgentEventTypeError,
 831				Error: fmt.Errorf("empty summary returned"),
 832				Done:  true,
 833			}
 834			a.Publish(pubsub.CreatedEvent, event)
 835			return
 836		}
 837		shell := shell.GetPersistentShell(config.Get().WorkingDir())
 838		summary += "\n\n**Current working directory of the persistent shell**\n\n" + shell.GetWorkingDir()
 839		event = AgentEvent{
 840			Type:     AgentEventTypeSummarize,
 841			Progress: "Creating new session...",
 842		}
 843
 844		a.Publish(pubsub.CreatedEvent, event)
 845		oldSession, err := a.sessions.Get(summarizeCtx, sessionID)
 846		if err != nil {
 847			event = AgentEvent{
 848				Type:  AgentEventTypeError,
 849				Error: fmt.Errorf("failed to get session: %w", err),
 850				Done:  true,
 851			}
 852
 853			a.Publish(pubsub.CreatedEvent, event)
 854			return
 855		}
 856		// Create a message in the new session with the summary
 857		msg, err := a.messages.Create(summarizeCtx, oldSession.ID, message.CreateMessageParams{
 858			Role: message.Assistant,
 859			Parts: []message.ContentPart{
 860				message.TextContent{Text: summary},
 861				message.Finish{
 862					Reason: message.FinishReasonEndTurn,
 863					Time:   time.Now().Unix(),
 864				},
 865			},
 866			Model:    a.summarizeProvider.Model().ID,
 867			Provider: a.summarizeProviderID,
 868		})
 869		if err != nil {
 870			event = AgentEvent{
 871				Type:  AgentEventTypeError,
 872				Error: fmt.Errorf("failed to create summary message: %w", err),
 873				Done:  true,
 874			}
 875
 876			a.Publish(pubsub.CreatedEvent, event)
 877			return
 878		}
 879		oldSession.SummaryMessageID = msg.ID
 880		oldSession.CompletionTokens = finalResponse.Usage.OutputTokens
 881		oldSession.PromptTokens = 0
 882		model := a.summarizeProvider.Model()
 883		usage := finalResponse.Usage
 884		cost := model.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
 885			model.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
 886			model.CostPer1MIn/1e6*float64(usage.InputTokens) +
 887			model.CostPer1MOut/1e6*float64(usage.OutputTokens)
 888		oldSession.Cost += cost
 889		_, err = a.sessions.Save(summarizeCtx, oldSession)
 890		if err != nil {
 891			event = AgentEvent{
 892				Type:  AgentEventTypeError,
 893				Error: fmt.Errorf("failed to save session: %w", err),
 894				Done:  true,
 895			}
 896			a.Publish(pubsub.CreatedEvent, event)
 897		}
 898
 899		event = AgentEvent{
 900			Type:      AgentEventTypeSummarize,
 901			SessionID: oldSession.ID,
 902			Progress:  "Summary complete",
 903			Done:      true,
 904		}
 905		a.Publish(pubsub.CreatedEvent, event)
 906		// Send final success event with the new session ID
 907	}()
 908
 909	return nil
 910}
 911
 912func (a *agent) ClearQueue(sessionID string) {
 913	if a.QueuedPrompts(sessionID) > 0 {
 914		slog.Info("Clearing queued prompts", "session_id", sessionID)
 915		a.promptQueue.Del(sessionID)
 916	}
 917}
 918
 919func (a *agent) CancelAll() {
 920	if !a.IsBusy() {
 921		return
 922	}
 923	for key := range a.activeRequests.Seq2() {
 924		a.Cancel(key) // key is sessionID
 925	}
 926
 927	timeout := time.After(5 * time.Second)
 928	for a.IsBusy() {
 929		select {
 930		case <-timeout:
 931			return
 932		default:
 933			time.Sleep(200 * time.Millisecond)
 934		}
 935	}
 936}
 937
 938func (a *agent) UpdateModel() error {
 939	cfg := config.Get()
 940
 941	// Get current provider configuration
 942	currentProviderCfg := cfg.GetProviderForModel(a.agentCfg.Model)
 943	if currentProviderCfg == nil || currentProviderCfg.ID == "" {
 944		return fmt.Errorf("provider for agent %s not found in config", a.agentCfg.Name)
 945	}
 946
 947	// Check if provider has changed
 948	if string(currentProviderCfg.ID) != a.providerID {
 949		// Provider changed, need to recreate the main provider
 950		model := cfg.GetModelByType(a.agentCfg.Model)
 951		if model.ID == "" {
 952			return fmt.Errorf("model not found for agent %s", a.agentCfg.Name)
 953		}
 954
 955		promptID := agentPromptMap[a.agentCfg.ID]
 956		if promptID == "" {
 957			promptID = prompt.PromptDefault
 958		}
 959
 960		opts := []provider.ProviderClientOption{
 961			provider.WithModel(a.agentCfg.Model),
 962			provider.WithSystemMessage(prompt.GetPrompt(promptID, currentProviderCfg.ID, cfg.Options.ContextPaths...)),
 963		}
 964
 965		newProvider, err := provider.NewProvider(*currentProviderCfg, opts...)
 966		if err != nil {
 967			return fmt.Errorf("failed to create new provider: %w", err)
 968		}
 969
 970		// Update the provider and provider ID
 971		a.provider = newProvider
 972		a.providerID = string(currentProviderCfg.ID)
 973	}
 974
 975	// Check if providers have changed for title (small) and summarize (large)
 976	smallModelCfg := cfg.Models[config.SelectedModelTypeSmall]
 977	var smallModelProviderCfg config.ProviderConfig
 978	for p := range cfg.Providers.Seq() {
 979		if p.ID == smallModelCfg.Provider {
 980			smallModelProviderCfg = p
 981			break
 982		}
 983	}
 984	if smallModelProviderCfg.ID == "" {
 985		return fmt.Errorf("provider %s not found in config", smallModelCfg.Provider)
 986	}
 987
 988	largeModelCfg := cfg.Models[config.SelectedModelTypeLarge]
 989	var largeModelProviderCfg config.ProviderConfig
 990	for p := range cfg.Providers.Seq() {
 991		if p.ID == largeModelCfg.Provider {
 992			largeModelProviderCfg = p
 993			break
 994		}
 995	}
 996	if largeModelProviderCfg.ID == "" {
 997		return fmt.Errorf("provider %s not found in config", largeModelCfg.Provider)
 998	}
 999
1000	// Recreate title provider
1001	titleOpts := []provider.ProviderClientOption{
1002		provider.WithModel(config.SelectedModelTypeSmall),
1003		provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptTitle, smallModelProviderCfg.ID)),
1004		provider.WithMaxTokens(40),
1005	}
1006	newTitleProvider, err := provider.NewProvider(smallModelProviderCfg, titleOpts...)
1007	if err != nil {
1008		return fmt.Errorf("failed to create new title provider: %w", err)
1009	}
1010	a.titleProvider = newTitleProvider
1011
1012	// Recreate summarize provider if provider changed (now large model)
1013	if string(largeModelProviderCfg.ID) != a.summarizeProviderID {
1014		largeModel := cfg.GetModelByType(config.SelectedModelTypeLarge)
1015		if largeModel == nil {
1016			return fmt.Errorf("model %s not found in provider %s", largeModelCfg.Model, largeModelProviderCfg.ID)
1017		}
1018		summarizeOpts := []provider.ProviderClientOption{
1019			provider.WithModel(config.SelectedModelTypeLarge),
1020			provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptSummarizer, largeModelProviderCfg.ID)),
1021		}
1022		newSummarizeProvider, err := provider.NewProvider(largeModelProviderCfg, summarizeOpts...)
1023		if err != nil {
1024			return fmt.Errorf("failed to create new summarize provider: %w", err)
1025		}
1026		a.summarizeProvider = newSummarizeProvider
1027		a.summarizeProviderID = string(largeModelProviderCfg.ID)
1028	}
1029
1030	return nil
1031}